# @package      hubzero-submit-common
# @file         ServerConnection.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012 HUBzero Foundation, LLC.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2012 HUBzero Foundation, LLC.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import os.path
import time

from hubzero.submit.LogMessage  import logID as log
from hubzero.submit.MessageCore import MessageCore

# Absolute directory (<root>/etc/submit) under which the CA certificate is looked up.
CERTIFICATEDIRECTORY = os.path.join(os.sep,
                                    'etc',
                                    'submit')
# CA certificate path handed to MessageCore as sslCACertPath for 'tls' connections.
SSLCACERTPATH = os.path.join(CERTIFICATEDIRECTORY,
                             'submit_server_ca.crt')

class ServerConnection:
   """Buffered connection to one of a list of candidate servers.

   On construction every URI in listenURIs is tried in order until one
   yields a channel that passes the "Hello." echo handshake.  Traffic is
   then staged through two string buffers:

      fromServerBuffer  text received from the server, not yet consumed
      toServerBuffer    text queued for the server, not yet transmitted

   getServerInputObject()/getServerOutputObject() return lists holding the
   channel (or nothing); presumably these feed a select() style loop in the
   caller - that loop is not part of this class.
   """

   def __init__(self,
                listenURIs):
      """Connect to the first URI in listenURIs that completes the handshake.

      listenURIs -- sequence of 'protocol://host:port' strings where
                    protocol is 'tls' or 'tcp' (case insensitive).

      If listenURIs is empty a message is logged and the object is left
      unconnected.  Otherwise the whole list is retried, sleeping 10
      seconds between passes, until a connection succeeds; the constructor
      does not return before then.
      """
      self.serverMessageCore = None
      self.serverChannel     = None
      self.fromServerBuffer  = ""
      self.toServerBuffer    = ""
      self.bufferSize        = 1024
      delay                  = 0.

      if len(listenURIs) > 0:
         while not self.serverChannel:
            # No wait before the first pass, 10 seconds before each retry.
            time.sleep(delay)
            for listenURI in listenURIs:
               if self.__connect(listenURI):
                  break

            delay = 10.
      else:
         log("No servers to be configured")


   def __connect(self,
                 listenURI):
      """Attempt a single connection plus handshake to listenURI.

      Returns True with serverMessageCore/serverChannel set on success.
      On any failure both attributes are reset to None and False is
      returned.
      """
      protocol,serverHost,serverPort = self.__parseURL(listenURI)
      if serverPort <= 0:
         # Unparsable URI (already logged by __parseURL) or unusable port.
         return(False)

      log("Connecting: protocol='%s', host='%s', port=%d" % (protocol,serverHost,serverPort))
      if   protocol == 'tls':
         self.serverMessageCore = MessageCore(bindLabel=listenURI,
                                              sslCACertPath=SSLCACERTPATH,
                                              listenerHost=serverHost,listenerPort=serverPort)
      elif protocol == 'tcp':
         self.serverMessageCore = MessageCore(bindLabel=listenURI,
                                              listenerHost=serverHost,listenerPort=serverPort)
      else:
         log("Unknown protocol: %s" % (protocol))
         return(False)

      if self.serverMessageCore:
         self.serverChannel = self.serverMessageCore.openListenerChannel(True)
         if self.serverChannel:
            if self.handshake():
               return(True)
            # Channel opened but the peer did not echo the greeting.
            self.serverMessageCore.closeListenerChannel(self.serverChannel)

      self.serverMessageCore = None
      self.serverChannel     = None

      return(False)


   def isConnected(self):
      """Return True while a server channel is held."""
      return(self.serverChannel is not None)


   def __parseURL(self,
                  url):
      """Split 'protocol://host:port' into (protocol,host,port).

      The protocol is lower-cased, leading '/' characters are stripped from
      the host and the port is converted to int.  If url does not consist
      of exactly three ':' separated fields, or the port is not an integer,
      the problem is logged and ("","",0) is returned so that callers can
      rely on port always being an int.
      """
      try:
         protocol,host,port = url.split(':')
         protocol = protocol.lower()
         host     = host.lstrip('/')
         port     = int(port)
      except (ValueError,AttributeError):
         log("Improper network specification: %s" % (url))
         # Discard partial results; in particular never return the port
         # as an unconverted string.
         protocol = ""
         host     = ""
         port     = 0

      return(protocol,host,port)


   def __handshake(self):
      """Send "Hello.\\n" and check that the server echoes it back.

      Returns True only if the reply equals the greeting.  Exceptions
      raised while sending or receiving are logged and reported as False.
      """
      valid = False
      message = "Hello.\n"
      reply  = ""

      try:
         # Write the message.
         nSent = self.serverMessageCore.sendMessage(self.serverChannel,message)

         # Expect the same message back.
         reply = self.serverMessageCore.receiveMessage(self.serverChannel,nSent,nSent)
         if reply == message:
            valid = True
      except Exception as err:
         log("ERROR: Connection handshake failed.  Protocol mismatch?")
         log("handshake(%s): %s" % (message.strip(),reply.strip()))
         log("err = %s" % (str(err)))

      return(valid)


   def handshake(self):
      """Run the greeting exchange; on success switch the channel to non-blocking.

      Returns True if the handshake succeeded, False otherwise.
      """
      valid = self.__handshake()
      if valid:
         self.serverChannel.setblocking(0)

      return(valid)


   def getServerInputObject(self):
      """Return [serverChannel] when a channel is held, otherwise []."""
      serverReader = []
      if self.serverChannel:
         serverReader = [self.serverChannel]

      return(serverReader)


   def getServerOutputObject(self):
      """Return [serverChannel] when a channel is held and output is pending, otherwise []."""
      serverWriter = []
      if self.serverChannel and self.toServerBuffer != "":
         serverWriter = [self.serverChannel]

      return(serverWriter)


   def receiveServerMessage(self):
      """Read up to bufferSize characters from the server into fromServerBuffer.

      A None or empty result from the message core closes the connection.
      Does nothing when no channel is held.
      """
      if not self.isConnected():
         return

      serverMessage = self.serverMessageCore.receiveMessage(self.serverChannel,0,self.bufferSize)
      if serverMessage is None or serverMessage == "":
         self.closeServerConnection()
      else:
         self.fromServerBuffer += serverMessage


   def pullServerMessage(self,
                         messageLength):
      """Remove and return the next message from fromServerBuffer.

      messageLength == 0 -- return the text up to the first newline; the
                            newline is consumed but not returned.
      otherwise          -- return exactly messageLength characters.

      Returns "" and leaves the buffer untouched when a complete message is
      not yet available.
      """
      message = ""
      if messageLength == 0:
         nl = self.fromServerBuffer.find('\n')
         if nl >= 0:
            message = self.fromServerBuffer[0:nl]
            self.fromServerBuffer = self.fromServerBuffer[nl+1:]
      else:
         if len(self.fromServerBuffer) >= messageLength:
            message = self.fromServerBuffer[0:messageLength]
            self.fromServerBuffer = self.fromServerBuffer[messageLength:]

      return(message)


   def pushServerMessage(self,
                         message):
      """Put message back at the front of fromServerBuffer."""
      self.fromServerBuffer = message + self.fromServerBuffer


   def postServerMessage(self,
                         message):
      """Append message verbatim to the outgoing buffer."""
      self.toServerBuffer += message


   def postServerMessageBySize(self,
                               command,
                               message):
      """Queue 'command <len(message)>\\n' followed by message.

      Nothing is queued when message is empty.
      """
      if len(message) > 0:
         self.toServerBuffer += command + " %d\n" % (len(message))
         self.toServerBuffer += message


   def postServerMessagesBySize(self,
                                command,
                                messages):
      """Queue 'command <len1> <len2> ...\\n' followed by the concatenated messages.

      Every message contributes a length field, including empty ones.
      Nothing is queued when all messages are empty.
      """
      text = "".join(messages)
      if len(text) > 0:
         lengths = "".join(" %d" % (len(message)) for message in messages)
         self.toServerBuffer += command + lengths + "\n" + text


   def sendServerMessage(self):
      """Transmit as much of toServerBuffer as the channel accepts.

      The transmitted prefix is removed from the buffer.  Does nothing
      when no channel is held.
      """
      if self.isConnected():
         transmittedLength = self.serverMessageCore.sendMessage(self.serverChannel,self.toServerBuffer)
         # Only a positive count is a number of characters sent.  Slicing
         # with a negative value would silently discard all but the tail
         # of the pending data.
         # NOTE(review): sendMessage presumably signals failure with a
         # non-positive value - confirm against MessageCore.
         if transmittedLength and transmittedLength > 0:
            self.toServerBuffer = self.toServerBuffer[transmittedLength:]


   def closeServerConnection(self):
      """Close the server channel, if any, and mark the connection as down.

      Safe to call when already closed.  The channel reference is dropped
      before close() is invoked so the object never keeps a stale channel,
      even if close() raises.
      """
      channel = self.serverChannel
      self.serverChannel = None
      if channel is not None:
         channel.close()


