# @package      hubzero-submit-common
# @file         MessageCore.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012 HUBzero Foundation, LLC.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2012 HUBzero Foundation, LLC.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

from OpenSSL import SSL
import sys
import os
import socket
import time
import traceback

from hubzero.submit.LogMessage import logID as log

class MessageCore:
   """TCP (optionally SSL/TLS via pyOpenSSL) message transport.

   An instance plays one of two roles, selected by bindPort:

   * bindPort > 0  -- server role: a listening socket is bound to
     (bindHost,bindPort).  It is wrapped in an SSL.Connection when the key,
     certificate and CA certificate files are all readable.
   * bindPort <= 0 -- client role: when the CA certificate file is readable an
     SSL context that verifies the peer certificate is prepared; connections to
     (listenerHost,listenerPort) are then opened on demand.

   Blocking sends can be space-padded to a fixed frame size; blocking receives
   read an exact number of bytes.
   """

   def __init__(self,
                bindHost="",
                bindPort=0,
                bindLabel="",
                sslKeyPath="",
                sslCertPath="",
                sslCACertPath="",
                reuseAddress=False,
                blocking=True,
                listenerHost="",
                listenerPort=0,
                repeatDelay=5):
      """Create the transport and, in the server role, bind and listen.

      bindHost,bindPort  -- address to listen on; bindPort > 0 selects the server role.
      bindLabel          -- name used in connection log messages.
      sslKeyPath,sslCertPath,sslCACertPath
                         -- private key, certificate and CA certificate files.  SSL is
                            used only when the files needed by the role are readable;
                            otherwise plain TCP is used.
      reuseAddress       -- set SO_REUSEADDR on the listening socket.
      blocking           -- when False the listening socket is made non-blocking.
      listenerHost,listenerPort
                         -- remote address used by the client-side helpers.
      repeatDelay        -- seconds to wait between bind/connect attempts.

      A bind failure (after 10 attempts) is logged rather than raised;
      isBound() then returns False.
      """
      self.bindHost     = bindHost
      self.bindPort     = bindPort
      self.bindLabel    = bindLabel
      self.bindSocket   = None
      self.listenerHost = listenerHost
      self.listenerPort = listenerPort
      self.repeatDelay  = repeatDelay
      self.sslContext   = None

      if bindPort > 0:
         sock = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
         if reuseAddress:
            sock.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)

         if os.access(sslKeyPath,os.R_OK) and os.access(sslCertPath,os.R_OK) and os.access(sslCACertPath,os.R_OK):
            # NOTE(review): TLSv1_METHOD pins TLS 1.0, which is deprecated; changing it
            # has to be coordinated with every peer that talks to this listener.
            self.sslContext = SSL.Context(SSL.TLSv1_METHOD)
            self.sslContext.use_privatekey_file(sslKeyPath)
            self.sslContext.use_certificate_file(sslCertPath)
            # NOTE(review): the CA file is loaded but set_verify() is not called on the
            # server context, so client certificates are presumably not requested.
            self.sslContext.load_verify_locations(sslCACertPath)
            self.bindSocket = SSL.Connection(self.sslContext,sock)
         else:
            self.bindSocket = sock

         maxBindAttempts = 10
         bound = False
         lastError = (None,None)
         for attempt in range(maxBindAttempts):
            try:
               self.bindSocket.bind((bindHost,bindPort))
               self.bindSocket.listen(512)
               if not blocking:
                  self.bindSocket.setblocking(0)
               bound = True
               break
            except Exception:
               # Remember why this attempt failed (e.g. the port may still be in use)
               # and wait before retrying; there is no point waiting after the last try.
               lastError = sys.exc_info()[:2]
               if attempt < maxBindAttempts-1:
                  time.sleep(repeatDelay)

         if not bound:
            # Release the descriptor instead of leaking it.
            sock.close()
            self.bindSocket = None
            log("Can't bind to port %d: %s %s" % (bindPort,lastError[0],lastError[1]))
      else:
         if os.access(sslCACertPath,os.R_OK):
            self.sslContext = SSL.Context(SSL.TLSv1_METHOD)
            self.sslContext.set_verify(SSL.VERIFY_PEER,self.__verifyCert) # Demand a certificate
            self.sslContext.load_verify_locations(sslCACertPath)


   def __verifyCert(self,
                    conn,
                    cert,
                    errnum,
                    depth,
                    ok):
      """pyOpenSSL verify callback: accept exactly what OpenSSL already decided (ok)."""
      return(ok)


   def getMessageSocket(self):
      """Return a new, unconnected TCP socket, SSL-wrapped when an SSL context exists."""
      if self.sslContext:
         messageSocket = SSL.Connection(self.sslContext,socket.socket(socket.AF_INET,socket.SOCK_STREAM))
      else:
         messageSocket = socket.socket(socket.AF_INET,socket.SOCK_STREAM)

      return(messageSocket)


   def isBound(self):
      """Return True when the listening socket was successfully bound."""
      return(self.bindSocket is not None)


   def close(self):
      """Close the listening socket, if any; safe to call repeatedly."""
      if self.bindSocket:
         self.bindSocket.close()
         self.bindSocket = None


   def boundSocket(self):
      """Return the listening socket (None when not bound)."""
      return(self.bindSocket)


   def boundFileDescriptor(self):
      """Return the file descriptor of the listening socket (requires isBound())."""
      return(self.bindSocket.fileno())


   def acceptConnection(self,
                        logConnection=False):
      """Accept one pending connection and return only the new channel."""
      return(self.acceptConnectionDetailed(logConnection)[0])


   def acceptConnectionDetailed(self,
                                logConnection=False):
      """Accept one pending connection.

      Returns (channel,remoteIP,remotePort).  When logConnection is True the
      peer address is logged together with bindLabel.
      """
      channel,details = self.bindSocket.accept()
      if logConnection:
         log("====================================================")
         log("Connection to %s from %s" % (self.bindLabel,details))
      remoteIP,remotePort = details

      return(channel,remoteIP,remotePort)


   def openListenerChannel(self,
                           recordTraceback=False):
      """Connect to (listenerHost,listenerPort).

      Returns the connected channel, or None on any failure (the traceback is
      logged when recordTraceback is True).
      """
      # Initialised up front so the cleanup below works even when socket
      # creation itself fails.
      listenerChannel = None
      try:
         listenerChannel = self.getMessageSocket()
         listenerChannel.connect((self.listenerHost,self.listenerPort))
      except Exception:
         if recordTraceback:
            log(traceback.format_exc())
         if listenerChannel is not None:
            listenerChannel.close()
         listenerChannel = None

      return(listenerChannel)


   def closeListenerChannel(self,
                            listenerChannel):
      """Close a channel returned by openListenerChannel(); None is ignored."""
      if listenerChannel:
         listenerChannel.close()


   def __receiveNonBlockingMessage(self,
                                   channel,
                                   bufferSize):
      """Read whatever is currently available on a non-blocking channel.

      Reads bufferSize-sized chunks until the channel would block, the peer
      closes, or an error occurs.  Returns the accumulated data ("" after an
      unexpected error).
      """
      message = ""

      while True:
         try:
            messageChunk = channel.recv(bufferSize)
         except socket.error:
            # Happens on non-blocking TCP socket when there's nothing to read
            break
         except (SSL.WantReadError, SSL.WantWriteError, SSL.WantX509LookupError):
            break
         except SSL.ZeroReturnError:
            # Clean TLS shutdown by the peer: same as an empty read (EOF).
            break
         except SSL.SysCallError as err:
            # errnum -1 is OpenSSL's "unexpected EOF"; anything else is a real
            # system error.  Either way retrying would raise again, so stop.
            errnum = err.args[0] if err.args else None
            if errnum != -1:
               log("SSL system call error in receiveNonBlockingMessage(): %s" % (err,))
            break
         except Exception:
            log("Unexpected error in receiveNonBlockingMessage() " + message)
            message = ""
            log(traceback.format_exc())
            break

         if not messageChunk:
            break
         message += messageChunk

      return(message)


   def __receiveBlockingMessage(self,
                                channel,
                                messageLength,
                                bufferSize):
      """Read exactly messageLength bytes from a blocking channel.

      Never requests more than the bytes still missing, so data belonging to a
      following message is left in the socket.  Returns "" if the peer closes
      early or an error occurs.
      """
      bytesRemaining = messageLength
      message = ""

      try:
         while bytesRemaining > 0:
            messageChunk = channel.recv(min(bufferSize,bytesRemaining))
            if not messageChunk:
               if message:
                  log("socket connection broken in receiveBlockingMessage()")
               message = ""
               break
            message += messageChunk
            bytesRemaining -= len(messageChunk)
      except Exception:
         log("Unexpected error in receiveBlockingMessage() " + message)
         message = ""
         log(traceback.format_exc())

      return(message)


   def receiveMessage(self,
                      channel,
                      messageLength,
                      bufferSize=128):
      """Receive a message from channel.

      A channel whose gettimeout() is None is read in blocking mode for exactly
      messageLength bytes.  Otherwise (non-blocking or timeout set) whatever is
      currently available is returned and messageLength is ignored.
      """
      if channel.gettimeout() is None:
         message = self.__receiveBlockingMessage(channel,messageLength,bufferSize)
      else:
         message = self.__receiveNonBlockingMessage(channel,bufferSize)

      return(message)


   def __sendNonBlockingMessage(self,
                                channel,
                                message):
      """Issue a single send(); return the bytes sent, or -1 on failure."""
      try:
         transmittedLength = channel.send(message)
         if transmittedLength == 0:
            log("socket connection broken in sendNonBlockingMessage()")
            transmittedLength = -1
      except Exception:
         log("Unexpected error in sendNonBlockingMessage(%s)" % (message))
         log(traceback.format_exc())
         transmittedLength = -1

      return(transmittedLength)


   def __sendBlockingMessage(self,
                             channel,
                             message,
                             fixedBufferSize):
      """Send the whole message on a blocking channel.

      When fixedBufferSize > 0 the message is right-padded with spaces to that
      size.  Returns the total number of bytes sent, or -1 on failure.
      """
      if fixedBufferSize > 0:
         payload = "%-*s" % (fixedBufferSize,message)
         if len(payload) > fixedBufferSize:
            # Padding never truncates; the receiver's fixed-size frame will not
            # line up with what is sent.
            log("message length %d exceeds fixed buffer size %d in sendBlockingMessage()" % \
                                                             (len(payload),fixedBufferSize))
      else:
         payload = message

      totalSent = 0
      try:
         while totalSent < len(payload):
            transmittedLength = channel.send(payload[totalSent:])
            if transmittedLength == 0:
               log("socket connection broken in sendBlockingMessage()")
               return(-1)
            totalSent += transmittedLength
      except Exception:
         log("Unexpected error in sendBlockingMessage(%s)" % (message))
         log(traceback.format_exc())
         totalSent = -1

      return(totalSent)


   def sendMessage(self,
                   channel,
                   message,
                   fixedBufferSize=0):
      """Send message on channel; blocking mode when gettimeout() is None.

      Returns the number of bytes sent, or -1 on failure.  fixedBufferSize
      (padding) applies only in blocking mode.
      """
      if channel.gettimeout() is None:
         transmittedLength = self.__sendBlockingMessage(channel,message,fixedBufferSize)
      else:
         transmittedLength = self.__sendNonBlockingMessage(channel,message)

      return(transmittedLength)


   def __exchangeWithRetry(self,
                           message,
                           messageBufferSize,
                           readResponse,
                           recordTraceback):
      """Connect, send message, and read a reply until readResponse succeeds.

      readResponse(channel) returns (posted,result).  Each attempt uses a fresh
      connection that is always closed; attempts are separated by repeatDelay
      seconds and are retried without limit.  Returns (nTry,result).
      """
      nTry = 0
      delay = 0
      while True:
         time.sleep(delay)
         delay = self.repeatDelay
         nTry += 1
         messageResponseSocket = None
         try:
            messageResponseSocket = self.getMessageSocket()
            messageResponseSocket.connect((self.listenerHost,self.listenerPort))
            if self.__sendBlockingMessage(messageResponseSocket,message,messageBufferSize) > 0:
               posted,result = readResponse(messageResponseSocket)
               if posted:
                  return(nTry,result)
         except Exception:
            if recordTraceback:
               log(traceback.format_exc())
         finally:
            if messageResponseSocket is not None:
               messageResponseSocket.close()


   def requestMessageResponse(self,
                              message,
                              messageBufferSize,
                              responseBufferSize,
                              recordTraceback=False):
      """Send message and wait for a fixed-size (responseBufferSize) reply.

      Retries until a non-empty reply arrives.  Returns (nTry,response).
      """
      def readResponse(channel):
         response = self.__receiveBlockingMessage(channel,responseBufferSize,128)
         return(response != "",response)

      return(self.__exchangeWithRetry(message,messageBufferSize,readResponse,recordTraceback))


   def requestMessageVariableResponse(self,
                                      message,
                                      messageBufferSize,
                                      responseBufferSize,
                                      recordTraceback=False):
      """Send message and read a length-prefixed reply.

      The reply starts with a responseBufferSize-byte header holding the body
      length as text, followed by that many bytes.  A length <= 0 yields "".
      Retries until successful.  Returns (nTry,response).
      """
      def readResponse(channel):
         responseHeader = self.__receiveBlockingMessage(channel,responseBufferSize,responseBufferSize)
         if responseHeader == "":
            return(False,None)
         responseLength = int(responseHeader.strip())
         if responseLength <= 0:
            return(True,"")
         response = self.__receiveBlockingMessage(channel,responseLength,responseLength)
         return(response != "",response)

      return(self.__exchangeWithRetry(message,messageBufferSize,readResponse,recordTraceback))


   def requestMessageTimestampResponse(self,
                                       message,
                                       messageBufferSize,
                                       responseBufferSize,
                                       recordTraceback=False):
      """Send message and read a length-and-timestamp-prefixed reply.

      The reply starts with a responseBufferSize-byte header of the form
      "<length> <timestamp>", followed by <length> bytes.  A length <= 0
      yields "".  Retries until successful.
      Returns (nTry,response,responseTimestamp); the timestamp is the raw header text.
      """
      def readResponse(channel):
         responseHeader = self.__receiveBlockingMessage(channel,responseBufferSize,responseBufferSize)
         if responseHeader == "":
            return(False,None)
         responseLength,responseTimestamp = responseHeader.strip().split()
         responseLength = int(responseLength)
         if responseLength <= 0:
            return(True,("",responseTimestamp))
         response = self.__receiveBlockingMessage(channel,responseLength,responseLength)
         return(response != "",(response,responseTimestamp))

      nTry,(response,responseTimestamp) = self.__exchangeWithRetry(message,messageBufferSize,
                                                                   readResponse,recordTraceback)

      return(nTry,response,responseTimestamp)


