# @package      hubzero-submit-server
# @file         BoundConnection.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2014 HUBzero Foundation, LLC.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2012-2014 HUBzero Foundation, LLC.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import os.path
import copy
import time
import logging

from hubzero.submit.LogMessage  import getLogIDMessage as getLogMessage
from hubzero.submit.MessageCore import MessageCore

class BoundConnection:
   """Own a set of listening endpoints plus one active client channel.

   Listen URIs look like ``tls://host:port``, ``tcp://host:port`` or
   ``file:///path``; each successfully bound endpoint is wrapped in a
   MessageCore listener keyed by its bound socket.  Outgoing data is queued
   in ``toBuffer`` and incoming data accumulates in ``fromBuffer``;
   ``sendMessage``/``receiveMessage`` move bytes between those buffers and
   the active channel through the listener that accepted it.
   """
   def __init__(self,
                listenURIs,
                submitSSLcert="",
                submitSSLkey="",
                submitSSLCA=""):
      """Create and bind one listener per entry of listenURIs.

      listenURIs    -- iterable of listen URI strings.
      submitSSLcert -- certificate path, used only for 'tls' URIs.
      submitSSLkey  -- private key path, used only for 'tls' URIs.
      submitSSLCA   -- CA certificate path, used only for 'tls' URIs.

      URIs with an unknown protocol, a non-positive port, an empty file path,
      or that fail to bind are logged and skipped.
      """
      self.logger                = logging.getLogger(__name__)
      self.listeners             = {}    # bound socket -> MessageCore listener
      self.boundSockets          = []    # bound sockets, in creation order
      self.remoteIP              = None
      self.remoteHost            = None
      self.activeChannel         = None  # channel of the accepted client, if any
      self.activeSocket          = None  # listening socket that accepted activeChannel
      self.connectionCheckedTime = 0.
      self.bufferSize            = 1024  # max bytes requested per receiveMessage()
      self.fromBuffer            = ""    # received, not yet consumed data
      self.toBuffer              = ""    # queued, not yet transmitted data

      for listenURI in listenURIs:
         # Reset per URI: a skipped entry must never re-register the
         # listener created for the previous URI.
         listener = None
         protocol,bindHost,bindPort,bindFile = self.__parseURI(listenURI)
         if   protocol == 'tls':
            if bindPort > 0:
               self.logger.log(logging.INFO,getLogMessage("Listening: protocol='%s', host='%s', port=%d" % \
                                                                        (protocol,bindHost,bindPort)))
               listener = MessageCore(protocol='tls',
                                      bindHost=bindHost,bindPort=bindPort,bindLabel=listenURI,
                                      sslKeyPath=submitSSLkey,sslCertPath=submitSSLcert,
                                      sslCACertPath=submitSSLCA,
                                      reuseAddress=True,blocking=False)
         elif protocol == 'tcp':
            if bindPort > 0:
               self.logger.log(logging.INFO,getLogMessage("Listening: protocol='%s', host='%s', port=%d" % \
                                                                        (protocol,bindHost,bindPort)))
               listener = MessageCore(protocol='tcp',
                                      bindHost=bindHost,bindPort=bindPort,bindLabel=listenURI,
                                      reuseAddress=True,blocking=False)
         elif protocol == 'file':
            if bindFile:
               self.logger.log(logging.INFO,getLogMessage("Listening: '%s'" % (bindFile)))
               listener = MessageCore(protocol='file',
                                      bindFile=bindFile,
                                      bindLabel="UD:%s" % (bindFile))
         else:
            self.logger.log(logging.ERROR,getLogMessage("Unknown protocol: %s" % (protocol)))

         if listener is not None and listener.isBound():
            boundSocket = listener.boundSocket()
            self.listeners[boundSocket] = listener
            self.boundSockets.append(boundSocket)
         else:
            # Covers both "no listener was created" and "created but not bound".
            self.logger.log(logging.ERROR,getLogMessage("Could not bind connection to: %s" % (listenURI)))

      if len(self.listeners) == 0:
         self.logger.log(logging.ERROR,getLogMessage("No listening devices configured"))


   def __parseURI(self,
                  uri):
      """Split a listen URI into (protocol, host, port, filePath).

      'proto://host:port' yields (proto.lower(), host, int(port), "").
      'proto://path' yields (proto.lower(), "", 0, path-with-first-two-'/'-removed).
      Any other number of ':'-separated parts yields an empty protocol.
      On a parse error every field is reset to its default and an error is
      logged.
      """
      protocol = ""
      host     = ""
      port     = 0
      filePath = ""
      try:
         parts = uri.split(':')
         if   len(parts) == 3:
            protocol,host,port = parts
            protocol = protocol.lower()
            host     = host.lstrip('/')
            port     = int(port)
         elif len(parts) == 2:
            protocol,filePath = parts
            protocol = protocol.lower()
            # Removes the first two '/' characters (meant to be the '//'
            # separator): 'file:///tmp/s' -> '/tmp/s'.
            filePath = filePath.replace('/','',2)
      except (AttributeError,TypeError,ValueError):
         # Do not return a half-parsed tuple (e.g. a non-integer port string).
         protocol = ""
         host     = ""
         port     = 0
         filePath = ""
         self.logger.log(logging.ERROR,getLogMessage("Improper network specification: %s" % (uri)))

      return(protocol,host,port,filePath)


   def __handshake(self,
                   listeningSocket):
      """Send "Hello.\\n" on the active channel and require it echoed back.

      Returns True only when the exact same text is received.  Any exception
      raised by the listener is logged and treated as a failed handshake.
      """
      valid = False
      message = "Hello.\n"
      reply = ""

      try:
         listener = self.listeners[listeningSocket]
         # Write the message.
         nSent = listener.sendMessage(self.activeChannel,message)

         if nSent > 0:
            # Expect the same message back.
            reply = listener.receiveMessage(self.activeChannel,nSent,nSent)
            if reply == message:
               valid = True
      except Exception as err:
         self.logger.log(logging.ERROR,getLogMessage("ERROR: Connection handshake failed.  Protocol mismatch?"))
         self.logger.log(logging.ERROR,getLogMessage("handshake(%s): %s" % (message.strip(),str(reply).strip())))
         self.logger.log(logging.ERROR,getLogMessage("err = %s" % (str(err))))

      return(valid)


   def handshake(self,
                 listeningSocket):
      """Run the echo handshake; on success switch the channel to non-blocking."""
      valid = self.__handshake(listeningSocket)
      if valid:
         self.activeChannel.setblocking(0)
         self.connectionCheckedTime = time.time()

      return(valid)


   def isListening(self):
      """Return True if at least one listener was registered."""
      return(len(self.listeners) > 0)


   def isConnected(self):
      """Return True if there is an active channel (refreshes the check time)."""
      if self.activeChannel:
         self.connectionCheckedTime = time.time()

      return(self.activeChannel is not None)


   def closeConnection(self):
      """Close and forget the active channel, if any."""
      if self.activeChannel:
         self.connectionCheckedTime = time.time()
         self.activeChannel.close()
         self.activeChannel = None


   def setChannelAndBuffers(self,
                            newChannel,
                            newFromBuffer,
                            newToBuffer):
      """Install a new channel and buffers; return the previous (channel, fromBuffer, toBuffer)."""
      previous = (self.activeChannel,self.fromBuffer,self.toBuffer)

      self.activeChannel = newChannel
      self.fromBuffer    = newFromBuffer
      self.toBuffer      = newToBuffer

      return(previous)


   def closeListeningConnections(self):
      """Close every listener and drop its socket from boundSockets.

      The listeners mapping itself is kept: sendMessage/receiveMessage look up
      self.listeners[self.activeSocket], presumably so an already accepted
      channel can keep being served after the listening sockets are closed.
      Safe to call more than once.
      """
      for listener in list(self.listeners.values()):
         boundSocket = listener.boundSocket()
         if boundSocket in self.boundSockets:
            self.boundSockets.remove(boundSocket)
         listener.close()


   def acceptConnection(self,
                        listeningSocket):
      """Accept a pending client on listeningSocket and make it the active channel.

      For 'file' listeners no remote details are available, so remoteIP and
      remoteHost become None.  Returns True if a channel was accepted.
      """
      connectionAccepted = False
      listener = self.listeners[listeningSocket]
      if listener.getProtocol() == 'file':
         activeChannel = listener.acceptConnection(logConnection=True)
         remoteIP   = None
         remoteHost = None
      else:
         activeChannel,remoteIP,remotePort,remoteHost = listener.acceptConnection(logConnection=True,
                                                                                  determineDetails=True)
      if activeChannel:
         connectionAccepted         = True
         self.activeSocket          = listeningSocket
         self.activeChannel         = activeChannel
         self.connectionCheckedTime = time.time()
         self.remoteIP              = remoteIP
         self.remoteHost            = remoteHost

      return(connectionAccepted)


   def getConnectionCheckedTime(self):
      """Return the time.time() value of the last connection check."""
      return(self.connectionCheckedTime)


   def getRemoteIP(self):
      """Return the remote IP of the active connection (None for 'file')."""
      return(self.remoteIP)


   def getRemoteHost(self):
      """Return the remote host of the active connection (None for 'file')."""
      return(self.remoteHost)


   def getInputObjects(self):
      """Return (copy of bound listening sockets, [active channel] or [])."""
      listeningSockets = list(self.boundSockets)
      if self.activeChannel:
         activeReader = [self.activeChannel]
         self.connectionCheckedTime = time.time()
      else:
         activeReader = []

      return(listeningSockets,activeReader)


   def getOutputObjects(self):
      """Return [active channel] when data is queued for it, else []."""
      activeWriter = []
      if self.activeChannel:
         self.connectionCheckedTime = time.time()
         if self.toBuffer != "":
            activeWriter = [self.activeChannel]

      return(activeWriter)


   def sendMessage(self):
      """Transmit queued data; drop what was sent, close on a negative result."""
      transmittedLength = self.listeners[self.activeSocket].sendMessage(self.activeChannel,self.toBuffer)
      if   transmittedLength > 0:
         self.toBuffer = self.toBuffer[transmittedLength:]
      elif transmittedLength < 0:
         self.closeConnection()


   def receiveMessage(self):
      """Read up to bufferSize bytes into fromBuffer; close the channel on None."""
      message = self.listeners[self.activeSocket].receiveMessage(self.activeChannel,0,self.bufferSize)
      if message is None:
         # NOTE(review): None appears to signal a closed/failed peer.
         self.closeConnection()
      else:
         self.fromBuffer += message


   def pullMessage(self,
                   messageLength):
      """Remove and return data from the front of fromBuffer.

      messageLength == 0 -- one line without its '\\n' ("" if no full line yet).
      messageLength  < 0 -- up to -messageLength characters (whatever is there).
      messageLength  > 0 -- exactly messageLength characters, or "" if fewer
                            are buffered (nothing is consumed in that case).
      """
      if   messageLength == 0:
         nl = self.fromBuffer.find('\n')
         if nl < 0:
            message = ""
         else:
            message = self.fromBuffer[0:nl]
            self.fromBuffer = self.fromBuffer[nl+1:]
      elif messageLength < 0:
         ml = min(len(self.fromBuffer),-messageLength)
         message = self.fromBuffer[0:ml]
         self.fromBuffer = self.fromBuffer[ml:]
      else:
         if len(self.fromBuffer) >= messageLength:
            message = self.fromBuffer[0:messageLength]
            self.fromBuffer = self.fromBuffer[messageLength:]
         else:
            message = ""

      return(message)


   def pushMessage(self,
                   message):
      """Put message back at the front of fromBuffer."""
      self.fromBuffer = message + self.fromBuffer


   def postMessage(self,
                   message):
      """Append message verbatim to the outgoing buffer."""
      self.toBuffer += message


   def postMessageBySize(self,
                         command,
                         message):
      """Queue '<command> <len>\\n<message>' unless message is empty."""
      if message:
         self.toBuffer += command + " %d\n" % (len(message)) + message


   def postMessagesBySize(self,
                          command,
                          messages):
      """Queue '<command> <len1> <len2> ...\\n<msg1><msg2>...'.

      Nothing is queued when all messages together are empty.  messages may be
      any iterable; it is materialized once so generators work too.
      """
      messages = list(messages)
      text = "".join(messages)
      if text:
         sizes = "".join([" %d" % (len(message)) for message in messages])
         self.toBuffer += command + sizes + "\n" + text


   def isMessagePending(self):
      """Return True if outgoing data is still queued."""
      return(len(self.toBuffer) > 0)


   def logPendingMessage(self,
                         logMessage):
      """Log the queued outgoing data, tagged with logMessage, if any is pending."""
      if self.isMessagePending():
         self.logger.log(logging.INFO,getLogMessage("logPendingMessage(%s): %s" % (logMessage,self.toBuffer)))


