#
# Copyright (c) 2004-2011 Purdue University All rights reserved.
#
# Developed by: HUBzero Technology Group, Purdue University
#               http://hubzero.org
#
# HUBzero is free software: you can redistribute it and/or modify it under the terms of the
# GNU Lesser General Public License as published by the Free Software Foundation, either
# version 3 of the License, or (at your option) any later version.
#
# HUBzero is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.  You should have received a
# copy of the GNU Lesser General Public License along with HUBzero.
# If not, see <http://www.gnu.org/licenses/>.
#
# GNU LESSER GENERAL PUBLIC LICENSE
# Version 3, 29 June 2007
# Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
#
import sys
import os
import select
import subprocess
import time
import datetime
import re
import shutil
import copy

from hubzero.submit.LogMessage   import logID as log
from hubzero.submit.JobStatistic import JobStatistic

class VenueMechanismCore:
   def __init__(self,
                timeHistoryLogs,
                siteInfo,
                managerInfo,
                remoteMonitors,
                isMultiCoreRequest,
                nCpus,
                nNodes,
                ppn):
      """
      Set up the state shared by the venue mechanism methods.

      timeHistoryLogs, siteInfo and managerInfo are stored as shallow copies.
      remoteMonitors must provide the keys 'job', 'tunnel' and 'cloud'.
      The first JobStatistic (index 0) is created with nCpus when
      isMultiCoreRequest is true and with 1 otherwise.
      """
      # Private shallow copies of the caller supplied configuration
      self.timeHistoryLogs = copy.copy(timeHistoryLogs)
      self.siteInfo        = copy.copy(siteInfo)
      self.managerInfo     = copy.copy(managerInfo)

      # Monitors are shared with the caller, not copied
      self.remoteJobMonitor    = remoteMonitors['job']
      self.remoteTunnelMonitor = remoteMonitors['tunnel']
      self.remoteCloudMonitor  = remoteMonitors['cloud']

      # Child process bookkeeping used by executeCommand
      self.childPid   = 0
      self.bufferSize = 4096

      self.venueMechanism    = ''
      self.remoteBatchSystem = ''
      self.enteredCommand    = ""

      # Resource request
      self.isMultiCoreRequest = isMultiCoreRequest
      self.nCpus              = nCpus
      self.nNodes             = nNodes
      self.ppn                = ppn

      # Per job statistics, keyed by job index
      self.jobSubmitted = False
      self.jobIndex     = 0
      if self.isMultiCoreRequest:
         initialCpuCount = self.nCpus
      else:
         initialCpuCount = 1
      self.jobStatistics = {self.jobIndex:JobStatistic(initialCpuCount)}

      # Progress flags
      self.scriptsCreated = False
      self.filesSent      = False
      self.filesRetrieved = False
      self.filesCleanedup = False
      self.scriptsKilled  = False

      # Job/instance identification
      self.localJobId        = ""
      self.instanceId        = "-1"
      self.instanceDirectory = ""
      self.stageInTarFile    = ""


   def recordGridResourceUse(self):
      reFiles = re.compile(self.timeHistoryLogs['jobGridResource'] + "_[0-9]+$")
      dirFiles = os.listdir(self.instanceDirectory)
      matchingFiles = filter(reFiles.search,dirFiles)
      if len(matchingFiles) > 0:
         matchingFiles.sort()
         for matchingFile in matchingFiles:
            if len(matchingFile) > len(self.timeHistoryLogs['jobGridResource']):
               jobIndex = int(matchingFile[len(self.timeHistoryLogs['jobGridResource'])+1:])
            else:
               jobIndex = 1

            if not jobIndex in self.jobStatistics:
               self.jobStatistics[jobIndex] = JobStatistic(self.nCpus)

            fpGridResource = open(os.path.join(self.instanceDirectory,matchingFile),'r')
            if fpGridResource:
               gridResource = fpGridResource.readline().strip()
               fpGridResource.close()
               gridType = gridResource.split()[0]
               if   gridType == 'gt2':
                  gateKeeper = gridResource.split()[1].split('/')[0]
                  jobManager = gridResource.split()[1].split('/')[1].split('-')[1].upper()
               elif gridType == 'gt4':
                  gateKeeper = gridResource.split()[1]
                  jobManager = gridResource.split()[2]
               elif gridType == 'gt5':
                  gateKeeper = gridResource.split()[1].split('/')[0]
                  jobManager = gridResource.split()[1].split('/')[1].split('-')[1].upper()

               self.jobStatistics[jobIndex]['venue']                  = gateKeeper
               self.jobStatistics[jobIndex]['jobSubmissionMechanism'] = self.venueMechanism + jobManager

            os.remove(os.path.join(self.instanceDirectory,matchingFile))
      else:
         log(self.timeHistoryLogs['jobGridResource'] + " not present in " + self.instanceDirectory)


   def logGridHistories(self):

      fpGridHistoryLog = open(self.timeHistoryLogs['hubGridJobHistoryLogPath'],'a')

      reFiles = re.compile(self.timeHistoryLogs['jobGridHistory'] + "_[0-9]+$")
      dirFiles = os.listdir(self.instanceDirectory)
      matchingFiles = filter(reFiles.search,dirFiles)
      if len(matchingFiles) > 0:
         matchingFiles.sort()
         for matchingFile in matchingFiles:
            if len(matchingFile) > len(self.timeHistoryLogs['jobGridHistory']):
               jobIndex = int(matchingFile[len(self.timeHistoryLogs['jobGridHistory'])+1:])
            else:
               jobIndex = 1

            if not jobIndex in self.jobStatistics:
               self.jobStatistics[jobIndex] = JobStatistic(self.nCpus)

            if fpGridHistoryLog:
               fpGridHistory = open(os.path.join(self.instanceDirectory,matchingFile),'r')
               if fpGridHistory:
                  gridHistory = fpGridHistory.readline()
                  fpGridHistoryLog.write("%s_%s\t%d\t%s\t%s\n" % (self.localJobId,self.instanceId,jobIndex, \
                                                                  gridHistory.strip(),self.enteredCommand))
                  fpGridHistory.close()
            os.remove(os.path.join(self.instanceDirectory,matchingFile))

      reFiles = re.compile(self.timeHistoryLogs['jobGridHistory'] + "_[0-9]+_$")
      dirFiles = os.listdir(self.instanceDirectory)
      matchingFiles = filter(reFiles.search,dirFiles)
      if len(matchingFiles) > 0:
         matchingFiles.sort()
         for matchingFile in matchingFiles:
            if len(matchingFile) > len(self.timeHistoryLogs['jobGridHistory']):
               jobIndex = int(matchingFile[len(self.timeHistoryLogs['jobGridHistory'])+1:-1])
            else:
               jobIndex = 1

            if not jobIndex in self.jobStatistics:
               self.jobStatistics[jobIndex] = JobStatistic(self.nCpus)

            if fpGridHistoryLog:
               fpGridHistory = open(os.path.join(self.instanceDirectory,matchingFile),'r')
               if fpGridHistory:
                  jobInProgressSite,jobInProgressSetupCompleted = fpGridHistory.readline().rstrip().split()
                  jobInProgressCompleted = int(time.mktime(datetime.datetime.utcnow().timetuple()))
                  jobInProgressExitStatus = -3
                  fpGridHistoryLog.write("%s_%s\t%d\t%s\t%s\t%d\t%d\t%s\n" % (self.localJobId,self.instanceId,jobIndex, \
                                                                              jobInProgressSite,jobInProgressSetupCompleted, \
                                                                              jobInProgressCompleted,jobInProgressExitStatus, \
                                                                              self.enteredCommand))
                  fpGridHistory.close()
            os.remove(os.path.join(self.instanceDirectory,matchingFile))

      if fpGridHistoryLog:
         fpGridHistoryLog.close()


   def logGridJobId(self):
      reFiles = re.compile(self.timeHistoryLogs['jobGridJobId'] + "_[0-9]+$")
      dirFiles = os.listdir(self.instanceDirectory)
      matchingFiles = filter(reFiles.search,dirFiles)
      if len(matchingFiles) > 0:
         fpGridJobIdLog = open(self.timeHistoryLogs['hubGridJobIdLogPath'],'a')
         matchingFiles.sort()
         for matchingFile in matchingFiles:
            if len(matchingFile) > len(self.timeHistoryLogs['jobGridJobId']):
               jobIndex = int(matchingFile[len(self.timeHistoryLogs['jobGridJobId'])+1:])
            else:
               jobIndex = 1

            if not jobIndex in self.jobStatistics:
               self.jobStatistics[jobIndex] = JobStatistic(self.nCpus)

            if fpGridJobIdLog:
               fpGridJobId = open(os.path.join(self.instanceDirectory,matchingFile),'r')
               if fpGridJobId:
                  gridJobId = fpGridJobId.readline()
                  fpGridJobIdLog.write("%s_%s %s\n" % (self.localJobId.lstrip('0'),self.instanceId,gridJobId.strip()))
                  fpGridJobId.close()
            os.remove(os.path.join(self.instanceDirectory,matchingFile))
         if fpGridJobIdLog:
            fpGridJobIdLog.close()
      else:
         log(self.timeHistoryLogs['jobGridJobId'] + " not present in " + self.instanceDirectory)


   def executeCommand(self,
                      command,
                      streamOutput=False):
      child = subprocess.Popen(command,shell=True,bufsize=self.bufferSize,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               close_fds=True)
      self.childPid = child.pid
      childout      = child.stdout
      childoutFd    = childout.fileno()
      childerr      = child.stderr
      childerrFd    = childerr.fileno()

      outEOF = False
      errEOF = False

      outData = []
      errData = []

      while True:
         toCheck = []
         if not outEOF:
            toCheck.append(childoutFd)
         if not errEOF:
            toCheck.append(childerrFd)
         try:
            ready = select.select(toCheck,[],[]) # wait for input
         except select.error,err:
            ready = {}
            ready[0] = []
         if childoutFd in ready[0]:
            outChunk = os.read(childoutFd,self.bufferSize)
            if outChunk == '':
               outEOF = True
            outData.append(outChunk)
            if streamOutput:
               sys.stdout.write(outChunk)
               sys.stdout.flush()

         if childerrFd in ready[0]:
            errChunk = os.read(childerrFd,self.bufferSize)
            if errChunk == '':
               errEOF = True
            errData.append(errChunk)
            if streamOutput:
               sys.stderr.write(errChunk)
               sys.stderr.flush()

         if outEOF and errEOF:
            break

      pid,err = os.waitpid(self.childPid,0)
      self.childPid = 0
      if err != 0:
         if   os.WIFSIGNALED(err):
            log("%s failed w/ signal %d" % (command,os.WTERMSIG(err)))
         else:
            if os.WIFEXITED(err):
               err = os.WEXITSTATUS(err)
            log("%s failed w/ exit code %d" % (command,err))
         if not streamOutput:
            log("%s" % ("".join(errData)))

      return(err,"".join(outData),"".join(errData))


   def executeSSHCommand(self,
                         sshCommand,
                         tunnelDesignator,
                         streamOutput=False):

      minimumDelay = 1       #  1 2 4 8 16 32 64 128 256
      maximumDelay = 256
      updateFrequency = 1
      maximumDelayTime = 900

      delayTime = 0
      sleepTime = minimumDelay
      nDelays = 0

      if tunnelDesignator != "":
         self.remoteTunnelMonitor.incrementTunnelUse(tunnelDesignator)
      exitStatus,stdOutput,stdError = self.executeCommand(sshCommand,streamOutput)

      while exitStatus and (stdError.count("You don't exist, go away!") > 0):
         nDelays += 1
         time.sleep(sleepTime)
         delayTime += sleepTime
         if nDelays == updateFrequency:
            nDelays = 0
            sleepTime *= 2
            if sleepTime > maximumDelay:
               sleepTime = maximumDelay

         exitStatus,stdOutput,stdError = self.executeCommand(sshCommand,streamOutput)

         if delayTime >= maximumDelayTime:
            break

      if tunnelDesignator != "":
         self.remoteTunnelMonitor.decrementTunnelUse(tunnelDesignator)

      return(exitStatus,stdOutput,stdError)


   def executeLaunchCommand(self,
                            launchCommand,
                            streamOutput=False):

      minimumDelay = 1       #  1 2 4 8 16 32 64 128 256
      maximumDelay = 256
      updateFrequency = 1
      maximumDelayTime = 900

      delayTime = 0
      sleepTime = minimumDelay
      nDelays = 0
      exitStatus,stdOutput,stdError = self.executeCommand(launchCommand,streamOutput)

      while exitStatus and ((stdError.count("qsub: cannot connect to server") > 0) or (stdError.count("ldap-nss.c") > 0)):
         nDelays += 1
         time.sleep(sleepTime)
         delayTime += sleepTime
         if nDelays == updateFrequency:
            nDelays = 0
            sleepTime *= 2
            if sleepTime > maximumDelay:
               sleepTime = maximumDelay

         exitStatus,stdOutput,stdError = self.executeCommand(launchCommand,streamOutput)

         if delayTime >= maximumDelayTime:
            break

      return(exitStatus,stdOutput,stdError)


   def updateVenue(self,
                   executionVenue):
      if self.remoteBatchSystem == 'FACTORY':
         self.jobStatistics[self.jobIndex]['venue'] = executionVenue
      if self.remoteBatchSystem == 'PEGASUS':
         self.jobStatistics[self.jobIndex]['venue'] = executionVenue


   def getVenue(self):
      executionVenue = self.jobStatistics[self.jobIndex]['venue']

      return(executionVenue)


   def recordJobStatisticTime(self,
                              statistic,
                              timeFileBasename):
      reFiles = re.compile(timeFileBasename + "(_[0-9]+)?$")
      dirFiles = os.listdir(self.instanceDirectory)
      matchingFiles = filter(reFiles.search,dirFiles)
      if len(matchingFiles) > 0:
         foundIndexes = []
         matchingFiles.sort()
         for matchingFile in matchingFiles:
            if len(matchingFile) > len(timeFileBasename):
               jobIndex = int(matchingFile[len(timeFileBasename)+1:])
            else:
               jobIndex = 1
            foundIndexes.append(jobIndex)

            if not jobIndex in self.jobStatistics:
               self.jobStatistics[jobIndex] = JobStatistic(self.nCpus)

            if(os.path.getsize(os.path.join(self.instanceDirectory,matchingFile))):
               fpTime = open(os.path.join(self.instanceDirectory,matchingFile),'r')
               if fpTime:
                  self.jobStatistics[jobIndex][statistic] = float(fpTime.readline())
                  fpTime.close()
            os.remove(os.path.join(self.instanceDirectory,matchingFile))
         if self.remoteBatchSystem != 'PEGASUS':
            maximumJobIndex = max(self.jobStatistics.keys())
            if len(foundIndexes) != maximumJobIndex:
               for missingIndex in set(range(1,maximumJobIndex+1)).difference(set(foundIndexes)):
                  log(timeFileBasename + '_' + str(missingIndex) + " is missing")
         del foundIndexes
      else:
         log(timeFileBasename + " not present in " + self.instanceDirectory)


   def recordJobStatisticTimer(self,
                               timeFileBasename):
      reFiles = re.compile(timeFileBasename + "(_[0-9]+)?$")
      dirFiles = os.listdir(self.instanceDirectory)
      matchingFiles = filter(reFiles.search,dirFiles)
      if len(matchingFiles) > 0:
         foundIndexes = []
         matchingFiles.sort()
         for matchingFile in matchingFiles:
            if len(matchingFile) > len(timeFileBasename):
               jobIndex = int(matchingFile[len(timeFileBasename)+1:])
            else:
               jobIndex = 1
            foundIndexes.append(jobIndex)

            if not jobIndex in self.jobStatistics:
               self.jobStatistics[jobIndex] = JobStatistic(self.nCpus)

            venue       = None
            jobId       = None
            event       = None
            realTime    = 0.
            userTime    = 0.
            sysTime     = 0.
            elapsedTime = 0.
            waitTime    = 0.
            exitStatus  = 0
            exitSignal  = 0

            fpTimer = open(os.path.join(self.instanceDirectory,matchingFile),'r')
            if fpTimer:
               while 1:
                  line = fpTimer.readline()
                  if not line:
                     break
                  line = line.strip()

                  if line != "":
                     parts = line.split()
                     if len(parts) == 2:
                        timeType = parts[0]
                        try:
                           timeUsed = float(parts[1])

                           if   timeType == 'real':
                              realTime    = max(realTime,timeUsed)
                           elif timeType == 'user':
                              userTime    = userTime + timeUsed
                           elif timeType == 'sys':
                              sysTime     = sysTime + timeUsed
                           elif timeType == 'elapsed':
                              elapsedTime = elapsedTime + timeUsed
                           elif timeType == 'wait':
                              waitTime    = waitTime + timeUsed
                           elif timeType == 'jobId':
                              jobId       = parts[1]
                        except:
                           if   timeType == 'site':
                              venue = parts[1]
                           elif timeType == 'jobId':
                              jobId = parts[1]
                           elif timeType == 'event':
                              event = parts[1]
                     else:
                        if len(parts) > 2:
                           if parts[-2] == "status":
                              exitStatus = int(float(parts[-1]))
                           if parts[-2] == "signal":
# Killed by signal 2.
                              exitStatus = 1
                              exitSignal = int(float(parts[-1]))
                        if re.search("Killed",line):
                           exitStatus = 1
                           exitSignal = 15

               if self.jobStatistics[jobIndex]['exitCode'] == 2:
                  if exitSignal > 0:
                     self.jobStatistics[jobIndex]['exitCode'] = exitStatus << 8 | exitSignal
                  else:
                     self.jobStatistics[jobIndex]['exitCode'] = exitStatus
               self.jobStatistics[jobIndex]['realTime']       = realTime
               self.jobStatistics[jobIndex]['userTime']       = userTime
               self.jobStatistics[jobIndex]['sysTime']        = sysTime
               self.jobStatistics[jobIndex]['elapsedRunTime'] = elapsedTime
               self.jobStatistics[jobIndex]['waitingTime']    = waitTime
               if venue and not self.jobStatistics[jobIndex]['venue']:
                  self.jobStatistics[jobIndex]['venue']             = venue
               if jobId:
                  self.jobStatistics[jobIndex]['remoteJobIdNumber'] = jobId
               if event:
                  self.jobStatistics[jobIndex]['event']             = event

               fpTimer.close()
            os.remove(os.path.join(self.instanceDirectory,matchingFile))
         if self.remoteBatchSystem != 'PEGASUS':
            maximumJobIndex = max(self.jobStatistics.keys())
            if len(foundIndexes) != maximumJobIndex:
               for missingIndex in set(range(1,maximumJobIndex+1)).difference(set(foundIndexes)):
                  log(timeFileBasename + '_' + str(missingIndex) + " is missing")
         del foundIndexes
      else:
         if self.jobSubmitted:
            log(timeFileBasename + " not present in " + self.instanceDirectory)


   def recordJobStatistics(self):
      self.recordJobStatisticTimer(self.timeHistoryLogs['timeResults'])

      if self.jobSubmitted:
         self.recordJobStatisticTime('transferCompleteTime',self.timeHistoryLogs['timestampTransferred'])
         self.recordJobStatisticTime('jobStartedTime',self.timeHistoryLogs['timestampStart'])
         self.recordJobStatisticTime('jobFinshedTime',self.timeHistoryLogs['timestampFinish'])

      for jobIndex in self.jobStatistics:
         self.jobStatistics[jobIndex].setWaitingTime()
         self.jobStatistics[jobIndex].setElapsedRunTime()


   def wasJobSuccessful(self):
      success = False
      for jobIndex in self.jobStatistics:
         if self.jobStatistics[jobIndex]['exitCode'] == 0:
            success = True
            break

      return(success)


   def removeJobStatisticsFiles(self):
      for timeFileBasename in self.timeHistoryLogs['timeResults'], \
                              self.timeHistoryLogs['timestampTransferred'], \
                              self.timeHistoryLogs['timestampStart'], \
                              self.timeHistoryLogs['timestampFinish']:
         reFiles = re.compile(timeFileBasename + "(_[0-9]+)?$")
         dirFiles = os.listdir(self.instanceDirectory)
         matchingFiles = filter(reFiles.search,dirFiles)
         if len(matchingFiles) > 0:
            for matchingFile in matchingFiles:
               os.remove(os.path.join(self.instanceDirectory,matchingFile))


   def stageFilesToInstanceDirectory(self):
      if self.instanceDirectory != "":
         if self.instanceDirectory != os.getcwd():
            if os.path.isdir(self.instanceDirectory):
               if self.stageInTarFile != "":
                  srcPath = os.path.join(os.getcwd(),self.stageInTarFile)
                  if(os.path.exists(srcPath)):
                     dstPath = os.path.join(self.instanceDirectory,self.stageInTarFile)
                     os.rename(srcPath,dstPath)


   def removeInstanceDirectory(self):
      if self.instanceDirectory != os.getcwd():
         shutil.rmtree(self.instanceDirectory,True)
         jobDirectory = os.path.dirname(self.instanceDirectory)
         if jobDirectory != os.getcwd():
            try:
               os.rmdir(jobDirectory)
            except:
               pass


   def moveTree(self,
                src,
                dst,
                symlinks=False):
      if os.path.isdir(src):
         if(os.path.exists(dst)):
            if not os.path.isdir(dst):
               log("moveTree: %s must be a directory" % (dst))
               return
         else:
            os.mkdir(dst)
         names = os.listdir(src)
         for name in names:
            srcPath = os.path.join(src,name)
            dstPath = os.path.join(dst,name)
            try:
               if symlinks and os.path.islink(srcPath):
                  linkto = os.readlink(srcPath)
                  os.symlink(linkto,dstPath)
               elif os.path.isdir(srcPath):
                  self.moveTree(srcPath,dstPath,symlinks)
               else:
                  os.rename(srcPath,dstPath)
            except (IOError,os.error),why:
               log("moveTree: Can't move %s to %s: %s" % (srcPath,dstPath,str(why)))
      else:
         log("moveTree: %s must be a directory" % (src))


   def recoverFiles(self,
                    ingoreInstanceFiles):
      if self.instanceDirectory != os.getcwd():
         if(os.path.exists(self.instanceDirectory)):
            jobDirectory = os.getcwd()
            dirFiles = os.listdir(self.instanceDirectory)
            if ingoreInstanceFiles:
               reIgnoreFiles = re.compile(".*%s_%s.*" % (self.localJobId,self.instanceId))
               ignoreFiles = filter(reIgnoreFiles.search,dirFiles)
               for ignoreFile in ignoreFiles:
                  dirFiles.remove(ignoreFile)

            for dirFile in dirFiles:
               srcFile = os.path.join(self.instanceDirectory,dirFile)
               dstFile = os.path.join(jobDirectory,dirFile)
               try:
                  if os.path.isdir(srcFile):
                     self.moveTree(srcFile,dstFile)
                  else:
                     os.rename(srcFile,dstFile)
               except:
                  log("%s recovery failed" % (srcFile))


   def recoverStdFiles(self):
      if self.instanceDirectory != os.getcwd():
         for fileType in 'stderr','stdout','FAILURE':
            fileToRecover = os.path.join(self.instanceDirectory,"%s_%s.%s" % (self.localJobId,self.instanceId,fileType))
            if(os.path.exists(fileToRecover)):
               fileToAppend = "%s.%s" % (self.localJobId,fileType)
               fpAppend = open(fileToAppend,'a')
               if fpAppend:
                  fpRecover = open(fileToRecover,'r')
                  if fpRecover:
                     recoverText = fpRecover.readlines()
                     fpRecover.close()
                     fpAppend.writelines(recoverText)
                  fpAppend.close()


