#
# @package      hubzero-submit-distributor
# @file         VenueMechanismCore.py
# @author       Steve Clark <clarks@purdue.edu>
# @copyright    Copyright 2004-2011 Purdue University. All rights reserved.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2004-2011 Purdue University
# All rights reserved.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of Purdue University.
#
import sys
import os
import select
import popen2
import time
import re
import sets

from LogMessage   import logID as log
from JobStatistic import JobStatistic

class VenueMechanismCore:
   """Shared machinery for venue-specific job submission mechanisms.

   Provides:
     * command execution with concurrent capture of stdout/stderr
       (executeCommand) plus retry-with-backoff wrappers for transient
       ssh / batch-launch failures;
     * collection of per-job timing and exit statistics from "time files"
       left in self.trialDirectory (recordJobStatistics and helpers);
     * recovery of files from the trial directory into the current
       working directory.

   The attributes initialised in __init__ are empty placeholders;
   presumably subclasses/callers set trialDirectory, localJobId, trial,
   nCpus and the time-file basenames before the statistics and recovery
   methods are used -- TODO confirm.
   """

   def __init__(self):
      self.childPid   = 0      # pid of the command executeCommand is running; 0 when idle
      self.bufferSize = 4096   # bytes per os.read() when draining child pipes

      self.nCpus                = 0
      self.jobStatistics        = {}   # job index (1-based) -> JobStatistic
      self.scriptsCreated       = False
      self.filesSent            = False
      self.jobSubmitted         = False
      self.filesRetrieved       = False
      self.filesCleanedup       = False
      self.scriptsKilled        = False
      self.localJobId           = ""
      self.trial                = -1
      self.trialDirectory       = ""
      # Basenames of the statistics files found in trialDirectory; each file
      # is either the bare basename (job 1) or basename_<jobIndex>.
      self.timestampTransferred = ""
      self.timestampStart       = ""
      self.timestampFinish      = ""
      self.timeResults          = ""


   def executeCommand(self,
                      command,
                      streamOutput=False):
      """Run command and capture its standard output and standard error.

      command is handed to popen2.Popen3 (a string is run via /bin/sh).  Both
      pipes are drained together with select() so the child cannot stall on
      a full pipe while we wait on the other one.  When streamOutput is True
      each chunk is also echoed to our own stdout/stderr as it arrives;
      otherwise the captured stderr is logged if the command fails.

      Returns (err,stdout,stderr) where err is 0 on success, the exit code
      when the child exited non-zero, or the raw wait() status when the
      child was terminated by a signal.
      """
      child = popen2.Popen3(command,1)   # second argument: capture stderr separately
      self.childPid = child.pid
      child.tochild.close()              # nothing is ever written to the child
      childout   = child.fromchild
      childoutFd = childout.fileno()
      childerr   = child.childerr
      childerrFd = childerr.fileno()

      outData = []
      errData = []
      outEOF  = False
      errEOF  = False
      try:
         while not (outEOF and errEOF):
            toCheck = []
            if not outEOF:
               toCheck.append(childoutFd)
            if not errEOF:
               toCheck.append(childerrFd)
            ready = select.select(toCheck,[],[])[0]   # block until a pipe is readable

            if childoutFd in ready:
               outChunk = os.read(childoutFd,self.bufferSize)
               if outChunk:
                  outData.append(outChunk)
                  if streamOutput:
                     sys.stdout.write(outChunk)
                     sys.stdout.flush()
               else:
                  outEOF = True          # empty read == writer closed the pipe

            if childerrFd in ready:
               errChunk = os.read(childerrFd,self.bufferSize)
               if errChunk:
                  errData.append(errChunk)
                  if streamOutput:
                     sys.stderr.write(errChunk)
                     sys.stderr.flush()
               else:
                  errEOF = True
      finally:
         # Release our read ends even if draining was interrupted.
         childout.close()
         childerr.close()

      err = child.wait()
      self.childPid = 0
      if err != 0:
         if os.WIFSIGNALED(err):
            log("%s failed w/ exit code %d signal %d" % (command,os.WEXITSTATUS(err),os.WTERMSIG(err)))
         else:
            err = os.WEXITSTATUS(err)
            log("%s failed w/ exit code %d" % (command,err))
         if not streamOutput:
            log("%s" % ("".join(errData)))

      return(err,"".join(outData),"".join(errData))


   def _executeWithRetry(self,
                         command,
                         transientErrors,
                         streamOutput):
      """Run command, re-running it while it fails with a transient error.

      A failure is considered transient when stderr contains any of the
      strings in transientErrors.  Between attempts we sleep 1,2,4,...
      seconds (capped at 256) and give up once the accumulated sleep
      reaches 900 seconds.  The result of the last attempt is returned.
      """
      minimumDelay     = 1       #  1 2 4 8 16 32 64 128 256
      maximumDelay     = 256
      updateFrequency  = 1       # attempts between doublings of sleepTime
      maximumDelayTime = 900

      delayTime = 0
      sleepTime = minimumDelay
      nDelays   = 0

      exitStatus,stdOutput,stdError = self.executeCommand(command,streamOutput)
      while exitStatus and [marker for marker in transientErrors if marker in stdError]:
         nDelays += 1
         time.sleep(sleepTime)
         delayTime += sleepTime
         if nDelays == updateFrequency:
            nDelays = 0
            sleepTime *= 2
            if sleepTime > maximumDelay:
               sleepTime = maximumDelay

         exitStatus,stdOutput,stdError = self.executeCommand(command,streamOutput)

         if delayTime >= maximumDelayTime:
            break

      return(exitStatus,stdOutput,stdError)


   def executeSSHCommand(self,
                         sshCommand,
                         remoteTunnelMonitor,
                         tunnelDesignator,
                         streamOutput=False):
      """Run an ssh command, retrying while the remote side rejects the user.

      The "You don't exist, go away!" stderr message is retried with
      exponential backoff (see _executeWithRetry).  When tunnelDesignator is
      non-empty the tunnel's use count is held for the duration of the call
      and is always released, even if the command raises.

      Returns (exitStatus,stdOutput,stdError) from the last attempt.
      """
      if tunnelDesignator != "":
         remoteTunnelMonitor.incrementTunnelUse(tunnelDesignator)
      try:
         result = self._executeWithRetry(sshCommand,
                                         ["You don't exist, go away!"],
                                         streamOutput)
      finally:
         if tunnelDesignator != "":
            remoteTunnelMonitor.decrementTunnelUse(tunnelDesignator)

      return(result)


   def executeLaunchCommand(self,
                            launchCommand,
                            streamOutput=False):
      """Run a job launch command, retrying transient scheduler/LDAP errors.

      Failures whose stderr mentions "qsub: cannot connect to server" or
      "ldap-nss.c" are retried with exponential backoff (see
      _executeWithRetry).

      Returns (exitStatus,stdOutput,stdError) from the last attempt.
      """
      return(self._executeWithRetry(launchCommand,
                                    ["qsub: cannot connect to server","ldap-nss.c"],
                                    streamOutput))


   def _findTimeFiles(self,
                      timeFileBasename):
      """Return the sorted names in trialDirectory belonging to timeFileBasename.

      A name belongs to the basename when it is the basename itself or the
      basename followed by _<digits>.  The basename is matched literally and
      anchored at both ends, so files that merely end with it are ignored.
      An empty basename matches nothing (rather than every file).
      """
      if not timeFileBasename:
         return([])
      reFiles = re.compile("^" + re.escape(timeFileBasename) + "(_[0-9]+)?$")
      matchingFiles = [dirFile for dirFile in os.listdir(self.trialDirectory) if reFiles.match(dirFile)]
      matchingFiles.sort()
      return(matchingFiles)


   def _jobIndexFromFile(self,
                         timeFileBasename,
                         matchingFile):
      """Return the job index encoded in matchingFile (1 when there is no _N suffix)."""
      if len(matchingFile) > len(timeFileBasename):
         return(int(matchingFile[len(timeFileBasename)+1:]))
      return(1)


   def _ensureJobStatistic(self,
                           jobIndex):
      """Return the statistics entry for jobIndex, creating it if needed."""
      if not jobIndex in self.jobStatistics:
         self.jobStatistics[jobIndex] = JobStatistic(self.nCpus)
      return(self.jobStatistics[jobIndex])


   def _logMissingJobIndexes(self,
                             timeFileBasename,
                             foundIndexes):
      """Log every index in 1..max(jobStatistics) that has no file for timeFileBasename."""
      maximumJobIndex = max(self.jobStatistics)
      if len(foundIndexes) != maximumJobIndex:
         missingIndexes = set(range(1,maximumJobIndex+1)).difference(foundIndexes)
         for missingIndex in sorted(missingIndexes):
            log(timeFileBasename + '_' + str(missingIndex) + " is missing")


   def recordJobStatisticTime(self,
                              statistic,
                              timeFileBasename):
      """Store a single timestamp per job under jobStatistics[jobIndex][statistic].

      Each matching time file (see _findTimeFiles) is expected to hold a
      float on its first line.  Empty files leave the statistic untouched;
      unparsable ones are logged and skipped.  Every matching file is deleted
      after it is read.  Missing job indexes are logged.
      """
      matchingFiles = self._findTimeFiles(timeFileBasename)
      if not matchingFiles:
         log(timeFileBasename + " not present")
         return

      foundIndexes = []
      for matchingFile in matchingFiles:
         jobIndex = self._jobIndexFromFile(timeFileBasename,matchingFile)
         foundIndexes.append(jobIndex)
         jobStatistic = self._ensureJobStatistic(jobIndex)

         timeFilePath = os.path.join(self.trialDirectory,matchingFile)
         if os.path.getsize(timeFilePath) > 0:
            fpTime = open(timeFilePath,'r')
            try:
               timeLine = fpTime.readline()
            finally:
               fpTime.close()
            try:
               jobStatistic[statistic] = float(timeLine)
            except ValueError:
               # One corrupt file should not abort collection for the other jobs.
               log("%s: invalid %s value %r" % (matchingFile,statistic,timeLine))
         os.remove(timeFilePath)

      self._logMissingJobIndexes(timeFileBasename,foundIndexes)


   def _parseTimerFile(self,
                       timerPath):
      """Parse one timer file and return (realTime,userTime,sysTime,exitCode).

      Two-token lines "real|user|sys <seconds>" are accumulated: real takes
      the maximum, user and sys are summed.  Other lines may carry an exit
      status ("... status N"), a terminating signal ("... signal N", which
      also sets status 1) or the word "Killed" (status 1, signal 15).
      exitCode is status << 8 | signal when a signal was seen, else status.
      """
      realTime   = 0.
      userTime   = 0.
      sysTime    = 0.
      exitStatus = 0
      exitSignal = 0

      fpTimer = open(timerPath,'r')
      try:
         for line in fpTimer:
            parts = line.split()
            if not parts:
               continue                  # blank / whitespace-only line

            if len(parts) == 2:
               timeType = parts[0]
               timeUsed = float(parts[1])
               if   timeType == 'real':
                  realTime = max(realTime,timeUsed)
               elif timeType == 'user':
                  userTime = userTime + timeUsed
               elif timeType == 'sys':
                  sysTime  = sysTime + timeUsed
               continue

            if len(parts) > 2:
               if parts[-2] == "status":
                  exitStatus = int(float(parts[-1]))
               if parts[-2] == "signal":
                  # e.g. "Command terminated by signal 2"
                  exitStatus = 1
                  exitSignal = int(float(parts[-1]))
            if "Killed" in line:
               exitStatus = 1
               exitSignal = 15
      finally:
         fpTimer.close()

      if exitSignal > 0:
         exitCode = exitStatus << 8 | exitSignal
      else:
         exitCode = exitStatus
      return(realTime,userTime,sysTime,exitCode)


   def recordJobStatisticTimer(self,
                               timeFileBasename):
      """Record exit code and real/user/sys times for every job.

      Each matching timer file (see _findTimeFiles) is parsed by
      _parseTimerFile and then deleted.  Missing job indexes are logged; the
      complete absence of timer files is only logged once a job has been
      submitted.
      """
      matchingFiles = self._findTimeFiles(timeFileBasename)
      if not matchingFiles:
         if self.jobSubmitted:
            log(timeFileBasename + " not present")
         return

      foundIndexes = []
      for matchingFile in matchingFiles:
         jobIndex = self._jobIndexFromFile(timeFileBasename,matchingFile)
         foundIndexes.append(jobIndex)
         jobStatistic = self._ensureJobStatistic(jobIndex)

         timerPath = os.path.join(self.trialDirectory,matchingFile)
         realTime,userTime,sysTime,exitCode = self._parseTimerFile(timerPath)
         jobStatistic['exitCode'] = exitCode
         jobStatistic['realTime'] = realTime
         jobStatistic['userTime'] = userTime
         jobStatistic['sysTime']  = sysTime
         os.remove(timerPath)

      self._logMissingJobIndexes(timeFileBasename,foundIndexes)


   def recordJobStatistics(self):
      """Collect all per-job statistics from the trial directory.

      Timer results are always processed; the transfer/start/finish
      timestamps only once a job has been submitted.  Each JobStatistic then
      derives its waiting and elapsed run times.
      """
      self.recordJobStatisticTimer(self.timeResults)

      if self.jobSubmitted:
         self.recordJobStatisticTime('transferCompleteTime',self.timestampTransferred)
         self.recordJobStatisticTime('jobStartedTime',self.timestampStart)
         # NOTE(review): 'jobFinshedTime' looks misspelled, but it is a key
         # consumed by JobStatistic (defined elsewhere) -- rename both together.
         self.recordJobStatisticTime('jobFinshedTime',self.timestampFinish)

      for jobIndex in self.jobStatistics:
         self.jobStatistics[jobIndex].setWaitingTime()
         self.jobStatistics[jobIndex].setElapsedRunTime()


   def wasJobSuccessful(self):
      """Return True if at least one recorded job has exitCode 0."""
      for jobStatistic in self.jobStatistics.values():
         if jobStatistic['exitCode'] == 0:
            return(True)
      return(False)


   def removeJobStatisticsFiles(self):
      """Delete every timer/timestamp file belonging to this trial.

      Unset (empty) basenames are skipped, so an uninitialised mechanism
      never deletes unrelated files from trialDirectory.
      """
      for timeFileBasename in (self.timeResults,
                               self.timestampTransferred,
                               self.timestampStart,
                               self.timestampFinish):
         for matchingFile in self._findTimeFiles(timeFileBasename):
            os.remove(os.path.join(self.trialDirectory,matchingFile))


   def moveTree(self,
                src,
                dst,
                symlinks=False):
      """Move the contents of directory src into directory dst, recursively.

      dst is created when missing; if it exists it must be a directory.
      Files are moved with os.rename.  With symlinks=True a symbolic link is
      recreated at the destination (the original link is left in place).
      Per-entry failures are logged and skipped.  Emptied source
      directories are not removed.

      NOTE(review): os.path.isdir() follows symlinks, so with symlinks=False
      a link to a directory is descended into and the link *target's*
      contents are moved -- confirm this is intended.
      """
      if not os.path.isdir(src):
         log("moveTree: %s must be a directory" % (src))
         return

      if os.path.exists(dst):
         if not os.path.isdir(dst):
            log("moveTree: %s must be a directory" % (dst))
            return
      else:
         os.mkdir(dst)

      for name in os.listdir(src):
         srcPath = os.path.join(src,name)
         dstPath = os.path.join(dst,name)
         try:
            if symlinks and os.path.islink(srcPath):
               os.symlink(os.readlink(srcPath),dstPath)
            elif os.path.isdir(srcPath):
               self.moveTree(srcPath,dstPath,symlinks)
            else:
               os.rename(srcPath,dstPath)
         except (IOError,OSError):
            # sys.exc_info() keeps this valid on every Python 2.x and 3.x.
            why = sys.exc_info()[1]
            log("moveTree: Can't move %s to %s: %s" % (srcPath,dstPath,str(why)))


   def recoverFiles(self,
                    ingoreTrialFiles):
      """Move everything from trialDirectory into the current directory.

      Does nothing when trialDirectory is the current directory.  When
      ingoreTrialFiles is true (name kept as-is for keyword callers), names
      containing "<localJobId>_<trial:02d>" are left behind.  Directories are
      merged via moveTree; failures are logged per entry.
      """
      jobDirectory = os.getcwd()
      if self.trialDirectory == jobDirectory:
         return

      dirFiles = os.listdir(self.trialDirectory)
      if ingoreTrialFiles:
         trialTag = "%s_%02d" % (self.localJobId,self.trial)
         dirFiles = [dirFile for dirFile in dirFiles if not trialTag in dirFile]

      for dirFile in dirFiles:
         srcFile = os.path.join(self.trialDirectory,dirFile)
         dstFile = os.path.join(jobDirectory,dirFile)
         try:
            if os.path.isdir(srcFile):
               self.moveTree(srcFile,dstFile)
            else:
               os.rename(srcFile,dstFile)
         except (IOError,OSError):
            log("%s recovery failed" % (srcFile))


   def recoverStdFiles(self):
      """Append this trial's stderr/stdout/FAILURE files to the job-level copies.

      For each type present as "<localJobId>_<trial:02d>.<type>" in
      trialDirectory, its bytes are appended to "<localJobId>.<type>" in the
      current directory.  Does nothing when trialDirectory is the current
      directory.
      """
      if self.trialDirectory == os.getcwd():
         return

      for fileType in 'stderr','stdout','FAILURE':
         fileToRecover = os.path.join(self.trialDirectory,"%s_%02d.%s" % (self.localJobId,self.trial,fileType))
         if not os.path.exists(fileToRecover):
            continue
         fileToAppend = "%s.%s" % (self.localJobId,fileType)
         fpAppend = open(fileToAppend,'ab')
         try:
            fpRecover = open(fileToRecover,'rb')
            try:
               # Stream in chunks instead of loading the whole file.
               while True:
                  recoverChunk = fpRecover.read(65536)
                  if not recoverChunk:
                     break
                  fpAppend.write(recoverChunk)
            finally:
               fpRecover.close()
         finally:
            fpAppend.close()

