# @package      hubzero-submit-server
# @file         BoundConnection.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2015 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2015 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#
import os.path
import copy
import time
import json
import logging

from hubzero.submit.LogMessage  import getLogIDMessage as getLogMessage
from hubzero.submit.MessageCore import MessageCore

class BoundConnection:
   """Own a set of listening sockets plus at most one active client channel.

   Listeners are created from URIs of the form ``tls://host:port``,
   ``tcp://host:port`` or ``file://path``.  Data received from the active
   channel accumulates in ``fromBuffer``; data queued for sending
   accumulates in ``toBuffer``.  The caller drives the I/O: it decides when
   to call acceptConnection/sendMessage/receiveMessage, presumably from a
   select()-style loop over getInputObjects()/getOutputObjects() (confirm
   against callers).
   """

   def __init__(self,
                listenURIs,
                submitSSLcert="",
                submitSSLkey="",
                submitSSLCA="",
                bufferSize=1024):
      """Create and bind one listener per entry of listenURIs.

      listenURIs    -- iterable of URI strings (see class docstring).
      submitSSLcert -- certificate path passed to 'tls' listeners.
      submitSSLkey  -- private key path passed to 'tls' listeners.
      submitSSLCA   -- CA certificate path passed to 'tls' listeners.
      bufferSize    -- default buffer size for listeners and the read size
                       used by receiveMessage().

      URIs that cannot be parsed or bound are logged and skipped; use
      isListening() to check that at least one listener was bound.
      """
      self.logger              = logging.getLogger(__name__)
      self.listeners           = {}    # bound socket -> MessageCore listener
      self.boundSockets        = []    # bound sockets, in creation order
      self.remoteIP            = None
      self.remoteHost          = None
      self.activeChannel       = None  # accepted client channel, or None
      self.activeSocket        = None  # listening socket that produced activeChannel
      self.connectionReadTime  = 0.
      self.connectionWriteTime = 0.
      self.bufferSize          = bufferSize
      self.fromBuffer          = ""    # received data not yet pulled
      self.toBuffer            = ""    # posted data not yet sent

      for listenURI in listenURIs:
         # Reset on every iteration.  Otherwise an entry that creates no
         # listener (port 0 / empty path) would raise NameError on the first
         # URI, or re-register the previous URI's listener on later ones.
         listener = None
         protocol,bindHost,bindPort,bindFile = self.__parseURI(listenURI)
         if   protocol == 'tls':
            if bindPort > 0:
               self.logger.log(logging.INFO,getLogMessage("Listening: protocol='%s', host='%s', port=%d" % \
                                                                        (protocol,bindHost,bindPort)))
               listener = MessageCore(protocol='tls',
                                      bindHost=bindHost,bindPort=bindPort,bindLabel=listenURI,
                                      sslKeyPath=submitSSLkey,sslCertPath=submitSSLcert,
                                      sslCACertPath=submitSSLCA,
                                      reuseAddress=True,blocking=False,
                                      defaultBufferSize=self.bufferSize)
         elif protocol == 'tcp':
            if bindPort > 0:
               self.logger.log(logging.INFO,getLogMessage("Listening: protocol='%s', host='%s', port=%d" % \
                                                                        (protocol,bindHost,bindPort)))
               listener = MessageCore(protocol='tcp',
                                      bindHost=bindHost,bindPort=bindPort,bindLabel=listenURI,
                                      reuseAddress=True,blocking=False,
                                      defaultBufferSize=self.bufferSize)
         elif protocol == 'file':
            if bindFile:
               self.logger.log(logging.INFO,getLogMessage("Listening: '%s'" % (bindFile)))
               listener = MessageCore(protocol='file',
                                      bindFile=bindFile,
                                      bindLabel="UD:%s" % (bindFile),
                                      defaultBufferSize=self.bufferSize)
         else:
            self.logger.log(logging.ERROR,getLogMessage("Unknown protocol: %s" % (protocol)))

         if listener is not None and listener.isBound():
            boundSocket = listener.boundSocket()
            self.listeners[boundSocket] = listener
            self.boundSockets.append(boundSocket)
         else:
            # Covers both "no listener created" and "created but not bound";
            # the latter was previously dropped without any log message.
            self.logger.log(logging.ERROR,getLogMessage("Could not bind connection to: %s" % (listenURI)))

      if len(self.listeners) == 0:
         self.logger.log(logging.ERROR,getLogMessage("No listening devices configured"))


   def __parseURI(self,
                  uri):
      """Split a listen URI into (protocol, host, port, filePath).

      'proto://host:port' -> (lower-cased proto, host without leading '/',
                              int port, "").
      'proto://path'      -> (lower-cased proto, "", 0, path without the
                              leading '//').
      Any other number of ':'-separated parts yields protocol "".  Malformed
      input (e.g. a non-integer port) is logged and yields
      ("", "", 0, "").
      """
      protocol = ""
      host     = ""
      port     = 0
      filePath = ""
      try:
         parts = uri.split(':')
         if   len(parts) == 3:
            protocol,host,port = parts
            protocol = protocol.lower()
            host     = host.lstrip('/')
            port     = int(port)
         elif len(parts) == 2:
            protocol,filePath = parts
            protocol = protocol.lower()
            # Strip only the '//' that follows the scheme.  The previous
            # replace('/','',2) also removed slashes inside the path
            # (e.g. 'file:/tmp/a/b' became 'tmpa/b').
            if filePath.startswith('//'):
               filePath = filePath[2:]
      except (AttributeError,TypeError,ValueError):
         # Non-string uri or non-integer port: do not hand back half-parsed
         # fields to the caller.
         protocol = ""
         host     = ""
         port     = 0
         filePath = ""
         self.logger.log(logging.ERROR,getLogMessage("Improper network specification: %s" % (uri)))

      return(protocol,host,port,filePath)


   def acceptHandshake(self,
                       listeningSocket):
      """Run the listener's handshake on the active channel.

      Returns whatever the listener's acceptHandshake() returns; when that
      value is truthy the read/write timestamps are reset to now.
      """
      valid = self.listeners[listeningSocket].acceptHandshake(self.activeChannel)
      if valid:
         self.connectionReadTime  = time.time()
         self.connectionWriteTime = self.connectionReadTime

      return(valid)


   def isListening(self):
      """Return True if at least one listener was successfully bound."""
      return(len(self.listeners) > 0)


   def isConnected(self):
      """Return True if there is an active client channel."""
      return(self.activeChannel is not None)


   def closeConnection(self):
      """Close the active client channel, if any, and forget it."""
      if self.activeChannel:
         self.activeChannel.close()
         self.activeChannel = None


   def setChannelAndBuffers(self,
                            newChannel,
                            newFromBuffer,
                            newToBuffer):
      """Swap in a new channel and buffers, returning the previous ones.

      Returns (previousChannel, previousFromBuffer, previousToBuffer).  Both
      timestamps are reset to now.
      """
      activeChannel = self.activeChannel
      fromBuffer    = self.fromBuffer
      toBuffer      = self.toBuffer

      self.activeChannel = newChannel
      self.fromBuffer    = newFromBuffer
      self.toBuffer      = newToBuffer

      self.connectionReadTime  = time.time()
      self.connectionWriteTime = self.connectionReadTime

      return(activeChannel,fromBuffer,toBuffer)


   def closeListeningConnections(self):
      """Close every listening socket; safe to call more than once.

      self.listeners is deliberately left populated: sendMessage() and
      receiveMessage() look up self.listeners[self.activeSocket] to talk to
      an already-accepted channel.
      """
      for boundSocket,listener in self.listeners.items():
         # Guard so a repeated call neither raises ValueError from
         # list.remove nor closes a listener twice.
         if boundSocket in self.boundSockets:
            self.boundSockets.remove(boundSocket)
            listener.close()


   def acceptConnection(self,
                        listeningSocket):
      """Accept a pending client on listeningSocket.

      'file' listeners return only a channel, so remote details stay None;
      other listeners also report remote IP, port and host.  On success the
      channel becomes the active channel, the timestamps are reset and True
      is returned; otherwise False is returned and state is unchanged.
      """
      listener = self.listeners[listeningSocket]
      if listener.getProtocol() == 'file':
         activeChannel = listener.acceptConnection(logConnection=True)
         remoteIP   = None
         remotePort = None
         remoteHost = None
      else:
         activeChannel,remoteIP,remotePort,remoteHost = listener.acceptConnection(logConnection=True,
                                                                                  determineDetails=True)
      if not activeChannel:
         return(False)

      self.activeSocket        = listeningSocket
      self.activeChannel       = activeChannel
      # A single timestamp, consistent with acceptHandshake/setChannelAndBuffers.
      self.connectionReadTime  = time.time()
      self.connectionWriteTime = self.connectionReadTime
      self.remoteIP            = remoteIP
      self.remoteHost          = remoteHost

      return(True)


   def getConnectionReadTime(self):
      """Return the time.time() value of the last read or connection reset."""
      return(self.connectionReadTime)


   def getConnectionWriteTime(self):
      """Return the time.time() value of the last write or connection reset."""
      return(self.connectionWriteTime)


   def getRemoteIP(self):
      """Return the remote IP of the last accepted connection (None for 'file')."""
      return(self.remoteIP)


   def getRemoteHost(self):
      """Return the remote host of the last accepted connection (None for 'file')."""
      return(self.remoteHost)


   def getInputObjects(self):
      """Return (copy of bound listening sockets, [active channel] or [])."""
      listeningSockets = list(self.boundSockets)
      activeReader = [self.activeChannel] if self.activeChannel else []

      return(listeningSockets,activeReader)


   def getOutputObjects(self):
      """Return [active channel] when it exists and output is queued, else []."""
      if self.activeChannel and self.toBuffer != "":
         return([self.activeChannel])

      return([])


   def sendMessage(self):
      """Send as much of toBuffer as the listener accepts.

      Sent bytes are dropped from toBuffer and the write time is updated.  A
      negative length from the listener is treated as a failure and closes
      the connection.
      """
      transmittedLength = self.listeners[self.activeSocket].sendMessage(self.activeChannel,self.toBuffer)
      if   transmittedLength > 0:
         self.toBuffer = self.toBuffer[transmittedLength:]
         self.connectionWriteTime = time.time()
      elif transmittedLength < 0:
         self.closeConnection()


   def receiveMessage(self):
      """Read up to bufferSize from the active channel into fromBuffer.

      A None result from the listener is treated as end-of-connection and
      closes the channel.  (The meaning of the literal 0 argument is defined
      by MessageCore.receiveMessage -- not visible here.)
      """
      message = self.listeners[self.activeSocket].receiveMessage(self.activeChannel,0,self.bufferSize)
      if message is None:
         self.closeConnection()
      else:
         self.fromBuffer += message
         self.connectionReadTime = time.time()


   def pullMessage(self,
                   messageLength):
      """Remove and return a message from the front of fromBuffer.

      messageLength == 0 -> one newline-terminated line (newline consumed but
                            not returned), or "" if no full line is buffered.
      messageLength <  0 -> up to -messageLength characters (may be fewer).
      messageLength >  0 -> exactly messageLength characters, or "" (buffer
                            untouched) if not enough data is buffered yet.
      """
      if   messageLength == 0:
         nl = self.fromBuffer.find('\n')
         if nl < 0:
            return("")
         message = self.fromBuffer[:nl]
         self.fromBuffer = self.fromBuffer[nl+1:]
      elif messageLength < 0:
         ml = min(len(self.fromBuffer),-messageLength)
         message = self.fromBuffer[:ml]
         self.fromBuffer = self.fromBuffer[ml:]
      else:
         if len(self.fromBuffer) < messageLength:
            return("")
         message = self.fromBuffer[:messageLength]
         self.fromBuffer = self.fromBuffer[messageLength:]

      return(message)


   def pushMessage(self,
                   message):
      """Put message back at the front of fromBuffer (undo of pullMessage)."""
      self.fromBuffer = message + self.fromBuffer


   def postMessage(self,
                   message):
      """Append message verbatim to the outgoing buffer."""
      self.toBuffer += message


   def postMessageBySize(self,
                         command,
                         message):
      """Queue '<command> <len>\\n<message>'; empty messages are ignored."""
      if len(message) > 0:
         self.toBuffer += command + " %d\n" % (len(message))
         self.toBuffer += message


   def postMessagesBySize(self,
                          command,
                          messages):
      """Queue '<command> <len1> <len2> ...\\n<msg1><msg2>...'.

      messages may be any iterable of strings (it is materialized once).
      Nothing is queued when the concatenated text is empty.
      """
      messages = list(messages)
      text = "".join(messages)
      if len(text) > 0:
         sizes = "".join(" %d" % (len(message)) for message in messages)
         self.toBuffer += command + sizes + "\n" + text


   def postJsonMessage(self,
                       jsonObject):
      """Queue jsonObject JSON-encoded under the 'json' command.

      Falsy objects are ignored.  Objects json.dumps cannot encode
      (unsupported types -> TypeError, circular references -> ValueError)
      are logged and dropped.
      """
      if jsonObject:
         try:
            message = json.dumps(jsonObject)
         except (TypeError,ValueError):
            self.logger.log(logging.ERROR,getLogMessage("JSON object %s could not be encoded" % (jsonObject)))
         else:
            if len(message) > 0:
               self.postMessageBySize('json',message)


   def isMessagePending(self):
      """Return True if the outgoing buffer holds unsent data."""
      return(len(self.toBuffer) > 0)


   def logPendingMessage(self,
                         logMessage):
      """Log the unsent outgoing buffer at INFO level, tagged with logMessage."""
      if self.isMessagePending():
         self.logger.log(logging.INFO,getLogMessage("logPendingMessage(%s): %s" % (logMessage,self.toBuffer)))


