#!/usr/bin/env python
#
# @package      hubzero-submit-monitors
# @file         BatchMonitors/monitorSLURM.py
# @copyright    Copyright (c) 2004-2020 The Regents of the University of California.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2004-2020 The Regents of the University of California.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of The Regents of the University of California.
#
# ----------------------------------------------------------------------
#  monitorSLURM.py
#
#  script which monitors the SLURM queue and reports changes in job status
#
import sys
import os
import select
import subprocess
import re
import signal
import socket
import json
import traceback
import time

from LogMessage import openLog, log

# Site designator compiled into the script.  When non-empty, a designator
# given on the command line must match it.
SITEDESIGNATOR     = ""
# Monitored user, taken from the USER environment variable (None if unset).
USERNAME           = os.getenv("USER")
MONITORROOT        = os.path.dirname(os.path.abspath(__file__))

# squeue invocation: no header line, this user's jobs only, and per the
# squeue -o format codes one "jobId state partition" line per job.
QSTATCOMMAND       = ['squeue', '--noheader', '-u', USERNAME, '-o', '%.i %t %P']

# Log file: <MONITORROOT>/log/monitorSLURM.log
MONITORLOGLOCATION = os.path.join(MONITORROOT, 'log')
MONITORLOGFILENAME = "monitorSLURM.log"
LOGPATH            = os.path.join(MONITORLOGLOCATION, MONITORLOGFILENAME)

# Job state history file: <MONITORROOT>/monitorSLURM.history
HISTORYFILENAME    = "monitorSLURM.history"
HISTORYFILEPATH    = os.path.join(MONITORROOT, HISTORYFILENAME)

SLEEPTIME       = 60          # seconds spent reading stdin between queue polls
PAUSETIME       = 15.0        # timeout, in seconds, of each select() call
MAXIMUMIDLETIME = 30 * 60     # seconds of consecutive empty queue polls before exiting


class QueueMonitor:
   """Monitor one user's SLURM queue and report job state changes.

   Input (stdin): newline-delimited JSON messages.  'newJobId' and
   'orphanJobId' messages add jobs to the monitored set; a 'pipeFlusher'
   message ends the batch currently being read.

   Output (stdout): newline-delimited JSON 'siteUpdate' messages carrying
   at most updatesPerBucket job states each, every one followed by a
   'pipeFlusher' message.

   History file: one "<siteDesignator>:<jobId> <status> <stage> <queue>"
   line per recorded state.  It is compacted to the still-active jobs after
   every queue poll and can be reloaded with loadHistory().

   Besides the squeue state codes, a job status may be
     'N' -- registered, queue state not yet seen
     'n' -- registered during the current cycle; gets one queue poll of
            grace before an absence from squeue counts as completion
     'o' -- orphan job announced by the central monitor
     'D' -- done; recorded once and then dropped
   """
   def __init__(self,
                siteDesignator,
                userName,
                qstatCommand,
                historyFilePath,
                sleepTime,
                pauseTime,
                maximumIdleTime):
      """Store the configuration and install the termination signal handlers.

      siteDesignator  -- prefix of history records; sent in every site update
      userName        -- monitored user (stored only; the query is qstatCommand)
      qstatCommand    -- argv list printing one "jobId status queue" line per job
      historyFilePath -- path of the job state history file
      sleepTime       -- seconds spent reading stdin between queue polls; also
                         the no-output timeout applied to qstatCommand
      pauseTime       -- timeout, in seconds, of each select() in monitorQ
      maximumIdleTime -- the monitor exits after maximumIdleTime/sleepTime
                         consecutive empty queue polls

      signal.signal() may only be called from the main thread, so the
      monitor must be constructed there.
      """
      self.siteDesignator                = siteDesignator
      self.userName                      = userName
      self.qstatCommand                  = qstatCommand
      self.historyFilePath               = historyFilePath
      self.sleepTime                     = sleepTime
      self.pauseTime                     = pauseTime
      # float on Python 3; monitorQ therefore compares with '>='
      self.maximumConsecutiveEmptyQueues = maximumIdleTime/sleepTime

      self.historyFile = None
      self.activeJobs  = {}
      self.bufferSize  = 4096

      # job states queued for the next site update, split into buckets of
      # at most updatesPerBucket entries (one stdout message per bucket)
      self.updates                    = {}
      self.updateBucket               = 0
      self.updates[self.updateBucket] = []
      self.updatesPerBucket           = 50

      signal.signal(signal.SIGINT,self.sigINT_handler)
      signal.signal(signal.SIGHUP,self.sigHUP_handler)
      signal.signal(signal.SIGQUIT,self.sigQUIT_handler)
      signal.signal(signal.SIGABRT,self.sigABRT_handler)
      signal.signal(signal.SIGTERM,self.sigTERM_handler)


   def cleanup(self):
      """Close the history file if it is open."""
      if self.historyFile:
         self.historyFile.close()
         self.historyFile = None


   def sigGEN_handler(self,
                      signalNumber,
                      frame):
      """Common signal action: close the history file and exit with status 1."""
      self.cleanup()
      log("%s signal monitor stopped" % (self.siteDesignator))
      sys.exit(1)


   def sigINT_handler(self,
                      signalNumber,
                      frame):
      log("Received SIGINT!")
      self.sigGEN_handler(signalNumber,frame)


   def sigHUP_handler(self,
                      signalNumber,
                      frame):
      log("Received SIGHUP!")
      self.sigGEN_handler(signalNumber,frame)


   def sigQUIT_handler(self,
                       signalNumber,
                       frame):
      log("Received SIGQUIT!")
      self.sigGEN_handler(signalNumber,frame)


   def sigABRT_handler(self,
                       signalNumber,
                       frame):
      log("Received SIGABRT!")
      self.sigGEN_handler(signalNumber,frame)


   def sigTERM_handler(self,
                       signalNumber,
                       frame):
      log("Received SIGTERM!")
      self.sigGEN_handler(signalNumber,frame)


   def openHistory(self,
                   accessMode):
      """Open the history file into self.historyFile.

      In mode 'r' a missing file leaves self.historyFile as None; any other
      mode always opens (creating the file when the mode allows it).
      """
      if accessMode == 'r':
         if os.path.isfile(self.historyFilePath):
            self.historyFile = open(self.historyFilePath,accessMode)
         else:
            self.historyFile = None
      else:
         self.historyFile = open(self.historyFilePath,accessMode)


   def recordHistory(self,
                     jobId):
      """Append jobId's current state to the history file and queue it for reporting.

      The queued job state carries, for every tail file configured for the
      job, up to nLines of the newest text written since the last report
      (always empty once the job status is 'D').  The history file must be
      open for writing.
      """
      job = self.activeJobs[jobId]
      self.historyFile.write("%s:%s %s %s %s\n" % (self.siteDesignator,str(jobId),job['status'],
                                                                                  job['stage'],
                                                                                  job['queue']))
      self.historyFile.flush()
      jobState = {'jobId':jobId,
                  'status':job['status'],
                  'stage':job['stage'],
                  'queue':job['queue'],
                  'tailFiles':{}}
      for tailFile,tailInfo in job.get('tailFiles',{}).items():
         if job['status'] == 'D':
            text = ""
         elif tailInfo['path']:
            text,endPosition = self.__tailFile(tailInfo['path'],tailInfo['nLines'],tailInfo['endPosition'])
            tailInfo['endPosition'] = endPosition
         else:
            text = ""
         jobState['tailFiles'][tailFile] = text

      if len(self.updates[self.updateBucket]) == self.updatesPerBucket:
         self.updateBucket += 1
         self.updates[self.updateBucket] = []

      self.updates[self.updateBucket].append(jobState)


   def loadHistory(self):
      """Rebuild self.activeJobs from the history file, if it exists.

      Records are replayed in order: a job whose latest record has status
      'D' is removed, any other record (re)defines the job.  Lines without
      a '<site>:' prefix or without exactly four fields after it (e.g. a
      truncated last line) are skipped instead of aborting start-up.
      """
      self.openHistory('r')
      if self.historyFile:
         try:
            for record in self.historyFile:
               colon = record.find(':')
               if colon <= 0:
                  continue
               fields = record[colon+1:].split()
               if len(fields) != 4:
                  continue
               jobId,status,stage,queue = fields
               if status == 'D':
                  self.activeJobs.pop(jobId,None)
               else:
                  self.activeJobs[jobId] = {'jobId':jobId,
                                            'status':status,
                                            'stage':stage,
                                            'queue':queue}
         finally:
            self.historyFile.close()
            self.historyFile = None


   def saveHistory(self):
      """Rewrite the history file with one record per active job."""
      self.openHistory('w')
      if self.historyFile:
         try:
            for activeJob,job in self.activeJobs.items():
               self.historyFile.write("%s:%s %s %s %s\n" % (self.siteDesignator,str(activeJob),
                                                                                job['status'],
                                                                                job['stage'],
                                                                                job['queue']))
         finally:
            self.historyFile.close()
            self.historyFile = None


   @staticmethod
   def __decodeOutput(data):
      """Return child output as native str (unchanged on Python 2, UTF-8 decoded on Python 3)."""
      if isinstance(data,str):
         return data
      return data.decode('utf-8','replace')


   def executeQstatCommand(self,
                           command):
      """Run command and collect its stdout and stderr.

      Both pipes are drained with select() so neither can fill up and block
      the child.  Whenever no output arrives for sleepTime seconds the child
      is sent SIGTERM.

      Returns (status,stdoutText,stderrText).  status is 0 on success, the
      exit code after a non-zero exit, or the raw waitpid() status when the
      child was terminated by a signal.  Failures are logged.
      """
      child = subprocess.Popen(command,bufsize=self.bufferSize,
                               stdin=None,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               close_fds=True)
      childPid   = child.pid
      childoutFd = child.stdout.fileno()
      childerrFd = child.stderr.fileno()

      outEOF = False
      errEOF = False

      outData = []
      errData = []

      while not (outEOF and errEOF):
         toCheck = []
         if not outEOF:
            toCheck.append(childoutFd)
         if not errEOF:
            toCheck.append(childerrFd)
         ready = select.select(toCheck,[],[],self.sleepTime) # wait for input
         # os.read() returns an empty string at EOF -- bytes on Python 3, so
         # test truthiness rather than comparing with ''.
         if childoutFd in ready[0]:
            outChunk = os.read(childoutFd,self.bufferSize)
            if not outChunk:
               outEOF = True
            outData.append(outChunk)

         if childerrFd in ready[0]:
            errChunk = os.read(childerrFd,self.bufferSize)
            if not errChunk:
               errEOF = True
            errData.append(errChunk)

         if len(ready[0]) == 0:
            os.kill(childPid,signal.SIGTERM)

      child.stdout.close()
      child.stderr.close()

      _,err = os.waitpid(childPid,0)
      outText = self.__decodeOutput(b"".join(outData))
      errText = self.__decodeOutput(b"".join(errData))
      if err != 0:
         if os.WIFSIGNALED(err):
            log("%s failed w/ signal %d" % (command,os.WTERMSIG(err)))
         else:
            if os.WIFEXITED(err):
               err = os.WEXITSTATUS(err)
            log("%s failed w/ exit code %d" % (command,err))
         log("%s" % (errText))

      return(err,outText,errText)


   def filterQstat(self,
                   qstatRecords):
      """Convert queue listing lines ("jobId status queue") into job dicts.

      Each dict has keys jobId, queue, status and stage ('Simulation').
      Blank lines are ignored; lines without exactly three fields are logged
      and skipped so one odd line cannot abort the whole poll.
      """
      filteredRecords = []
      for qstatRecord in qstatRecords:
         fields = qstatRecord.split()
         if not fields:
            continue
         if len(fields) != 3:
            log("Unexpected queue record ignored: %s" % (qstatRecord.strip()))
            continue
         jobId,status,queue = fields
         filteredRecords.append({'jobId':str(jobId),
                                 'queue':queue,
                                 'status':status,
                                 'stage':'Simulation'})

      return(filteredRecords)


   @staticmethod
   def __tailFile(tailPath,
                  nLines,
                  lastEndPosition=0):
      """Return (text,endPosition) describing new output in tailPath.

      text holds at most the last nLines lines written after
      lastEndPosition (none when nLines <= 0).  It is preceded by a "..."
      line when lastEndPosition > 0 and intervening new lines were skipped.
      endPosition is the file position after reading.  When the file is
      missing or has not grown beyond lastEndPosition, ("",lastEndPosition)
      is returned.  Open and read errors are logged.
      """
      text        = ""
      nTextLines  = 0
      endPosition = lastEndPosition
      bufsize     = 2048
      if os.path.exists(tailPath):
         fileSize = os.stat(tailPath).st_size
         if fileSize > lastEndPosition:
            try:
               fpTail = open(tailPath,'r')
            except (IOError,OSError):
               log("%s could not be opened" % (tailPath))
            else:
               try:
                  # Step back from the end bufsize bytes at a time until the
                  # chunk holds more than nLines lines (so a partial first
                  # line can be dropped) or lastEndPosition is reached.
                  attempt = 0
                  while True:
                     attempt += 1
                     location = max(fileSize-bufsize*attempt,lastEndPosition)
                     fpTail.seek(location)
                     data = fpTail.readlines()
                     endPosition = fpTail.tell()
                     if len(data) > nLines or location == lastEndPosition:
                        # data[-0:] would be the whole list, so nLines <= 0 is special-cased
                        tailLines = data[-nLines:] if nLines > 0 else []
                        text = ''.join(tailLines)
                        nTextLines += len(tailLines)
                        if tailLines and lastEndPosition > 0 and location > lastEndPosition+1:
                           text = "...\n" + text
                        break
               except (IOError,OSError,UnicodeDecodeError):
                  # On Python 3 a seek can land inside a multi-byte character.
                  log("%s could not be read" % (tailPath))
               finally:
                  fpTail.close()

      if nTextLines > 0:
         log("tailed %d lines of %s" % (nTextLines,tailPath))

      return(text,endPosition)


   @staticmethod
   def __tailFilesGrew(job):
      """Return True if any tail file of job is larger than its last reported position."""
      for tailInfo in job.get('tailFiles',{}).values():
         tailPath = tailInfo['path']
         if tailPath and os.path.exists(tailPath):
            if os.stat(tailPath).st_size > tailInfo['endPosition']:
               return True
      return False


   def __registerNewJob(self,
                        centralMessage):
      """Start monitoring the job named by a 'newJobId' message.

      Tail file paths are resolved inside the job's working directory;
      '#STDOUT#' and '#STDERR#' stand for <runName>_<instanceId>.stdout and
      .stderr.  The job is recorded once with status 'N' and then marked
      'n' so the next queue poll does not declare it finished before squeue
      lists it.  Jobs already monitored are left untouched.
      """
      newJob = centralMessage.get('remoteJobId')
      if newJob is None:
         log("newJobId message without remoteJobId: %s" % (centralMessage,))
         return
      if newJob in self.activeJobs:
         return

      job = {'jobId':newJob,
             'status':'N',
             'stage':'Job',
             'queue':'?',
             'jobWorkDirectory':centralMessage.get('jobWorkDirectory','?'),
             'localJobId':centralMessage.get('localJobId','?'),
             'instanceId':centralMessage.get('instanceId','?'),
             'runName':centralMessage.get('runName','?'),
             'tailFiles':centralMessage.get('tailFiles') or {}}
      self.activeJobs[newJob] = job

      jobWorkDirectory = os.path.expandvars(os.path.expanduser(job['jobWorkDirectory']))
      stdFileExtensions = {"#STDOUT#":"stdout","#STDERR#":"stderr"}
      for tailFile,tailInfo in job['tailFiles'].items():
         if tailFile in stdFileExtensions:
            fileName = "%s_%s.%s" % (job['runName'],job['instanceId'],stdFileExtensions[tailFile])
         else:
            fileName = tailFile
         tailInfo['path']        = os.path.join(jobWorkDirectory,fileName)
         tailInfo['endPosition'] = 0

      self.recordHistory(newJob)
      job['status'] = 'n'


   def __registerOrphanJob(self,
                           centralMessage):
      """Monitor the job named by an 'orphanJobId' message with status 'o'.

      An 'o' job is reported done on the next queue poll that no longer
      lists it.
      """
      orphanJob = centralMessage.get('remoteJobId')
      if orphanJob is None:
         log("orphanJobId message without remoteJobId: %s" % (centralMessage,))
         return
      if orphanJob not in self.activeJobs:
         self.activeJobs[orphanJob] = {'jobId':orphanJob,
                                       'status':'N',
                                       'stage':'Job',
                                       'queue':'?',
                                       'jobWorkDirectory':'?',
                                       'localJobId':centralMessage.get('localJobId','?'),
                                       'instanceId':centralMessage.get('instanceId','?'),
                                       'runName':centralMessage.get('runName','?'),
                                       'tailFiles':{}}
      self.activeJobs[orphanJob]['status'] = 'o'


   def __readCentralMessages(self,
                             message):
      """Process JSON messages from stdin, starting with the already-read message.

      Reading continues until a 'pipeFlusher' message or EOF.  Unparsable
      lines and unknown or malformed messages are logged and skipped.
      Returns the number of 'newJobId' messages seen.
      """
      nNewJobs    = 0
      nOrphanJobs = 0
      while message != "":
         try:
            centralMessage = json.loads(message)
         except ValueError:
            log(traceback.format_exc())
         else:
            if isinstance(centralMessage,dict):
               messageType = centralMessage.get('messageType')
            else:
               messageType = None
            if   messageType == 'newJobId':
               nNewJobs += 1
               self.__registerNewJob(centralMessage)
            elif messageType == 'orphanJobId':
               nOrphanJobs += 1
               self.__registerOrphanJob(centralMessage)
            elif messageType == 'pipeFlusher':
               break
            else:
               log(centralMessage)

         message = sys.stdin.readline()

      if nNewJobs > 0:
         log("%d newJobs" % (nNewJobs))
      if nOrphanJobs > 0:
         log("%d orphanJobs" % (nOrphanJobs))

      return(nNewJobs)


   def __reconcileJobs(self,
                       currentJobs):
      """Merge one queue poll (jobId -> job dict) into self.activeJobs.

      Every change is recorded with recordHistory().  Jobs missing from the
      queue are marked 'D' and dropped, except 'n' jobs, which are reset to
      'N' and given until the next poll.  A job whose state is unchanged is
      still recorded when one of its tail files has grown.
      """
   #      CA  CANCELLED       Job was explicitly cancelled by the user or system administrator.  The job may or may  not  have  been
   #                          initiated.
   #      CD  COMPLETED       Job has terminated all processes on all nodes.
   #      CG  COMPLETING      Job is in the process of completing. Some processes on some nodes may still be active.
   #      F   FAILED          Job terminated with non-zero exit code or other failure condition.
   #      NF  NODE_FAIL       Job terminated due to failure of one or more allocated nodes.
   #      PD  PENDING         Job is awaiting resource allocation.
   #      R   RUNNING         Job currently has an allocation.
   #      S   SUSPENDED       Job has an allocation, but execution has been suspended.
   #      TO  TIMEOUT         Job terminated upon reaching its time limit.
      completedJobs = []
      for activeJob,job in self.activeJobs.items():
         if job['status'] == 'n':
            job['status'] = 'N'
            job['stage']  = 'Job'
            job['queue']  = '?'
         elif activeJob not in currentJobs:
            job['status'] = 'D'
            self.recordHistory(activeJob)
            completedJobs.append(activeJob)

      for currentJob,queueJob in currentJobs.items():
         if currentJob not in self.activeJobs:
            self.activeJobs[currentJob] = queueJob
            self.recordHistory(currentJob)
         else:
            job = self.activeJobs[currentJob]
            if any(queueJob[key] != job[key] for key in queueJob):
               job.update(queueJob)
               self.recordHistory(currentJob)
            elif self.__tailFilesGrew(job):
               self.recordHistory(currentJob)

         if self.activeJobs[currentJob]['status'] == 'D':
            completedJobs.append(currentJob)

      for completedJob in completedJobs:
         del self.activeJobs[completedJob]


   @staticmethod
   def __writeMessage(message):
      """Write message to stdout as one JSON line, then flush.

      Returns False if serialising or writing failed.  A failing flush is
      not caught and propagates to the caller.
      """
      try:
         sys.stdout.write(json.dumps(message) + '\n')
      except Exception:
         return False
      sys.stdout.flush()
      return True


   def __reportUpdates(self):
      """Send one 'siteUpdate' message per non-empty bucket, each followed by a 'pipeFlusher'.

      NOTE(review): nothing -- not even a 'pipeFlusher' -- is written when no
      job state changed; confirm the central monitor does not expect a
      flusher after every poll.
      """
      for bucket in sorted(self.updates):
         jobStates = self.updates[bucket]
         nUpdates  = len(jobStates)
         if nUpdates == 0:
            continue
         siteMessage = {'messageType':'siteUpdate','siteDesignator':self.siteDesignator,
                        'nJobStates':nUpdates,'jobStates':jobStates}
         if self.__writeMessage(siteMessage):
            log("Site updated %d jobs" % (nUpdates))
         else:
            log("Site update of %d jobs failed" % (nUpdates))

         if not self.__writeMessage({'messageType':'pipeFlusher'}):
            log("Site pipe flush failed")


   def monitorQ(self):
      """Main loop: alternate between reading job registrations and polling the queue.

      Each cycle
        1. reads messages from stdin for about sleepTime seconds,
        2. once stdout is writable, runs qstatCommand, reconciles the result
           with the monitored jobs and reports any state changes,
        3. compacts the history file.
      Never returns: exits with status 0 on stdin EOF or after
      maximumConsecutiveEmptyQueues consecutive empty queue polls.
      """
      self.openHistory('a')
      consecutiveEmptyQueues     = 0
      lastReportedActiveJobCount = -1

      readers = [sys.stdin.fileno()]
      writers = [sys.stdout.fileno()]

      while True:
         activeJobCount = len(self.activeJobs)
         if activeJobCount != lastReportedActiveJobCount:
            log("%d monitored jobs" % (activeJobCount))
            lastReportedActiveJobCount = activeJobCount

         self.updates                    = {}
         self.updateBucket               = 0
         self.updates[self.updateBucket] = []

         startReadTime = time.time()
         while time.time()-startReadTime <= self.sleepTime:
            if os.getppid() == 1:
               # parent process id 1: presumably our parent exited and we were
               # re-parented to init -- stop through the SIGTERM handler
               os.kill(os.getpid(),signal.SIGTERM)

            readyReaders = select.select(readers,[],[],self.pauseTime)[0] # wait for input
            if sys.stdin.fileno() in readyReaders:
               message = sys.stdin.readline()
               if message == "":
                  self.cleanup()
                  log("%s empty job monitor stopped" % (self.siteDesignator))
                  sys.exit(0)
               if self.__readCentralMessages(message) > 0:
                  consecutiveEmptyQueues = 0

         try:
            readyWriters = select.select([],writers,[],self.pauseTime)[1]
         except (select.error,OSError,ValueError):
            # Deliberately narrower than a bare except: a SystemExit raised by
            # the signal handlers while blocked here must not be swallowed.
            readyWriters = []
         if sys.stdout.fileno() in readyWriters:
            errStatus,qstatOutput,qstatError = self.executeQstatCommand(self.qstatCommand)
            if errStatus == 0:
               currentJobs = {}
               for job in self.filterQstat(qstatOutput.splitlines()):
                  currentJobs[job['jobId']] = job

               if currentJobs:
                  consecutiveEmptyQueues = 0
               else:
                  consecutiveEmptyQueues += 1

               self.__reconcileJobs(currentJobs)
               self.__reportUpdates()
            else:
               log("Error %d in %s command:\n%s" % (errStatus,self.qstatCommand,qstatError))

            if self.historyFile:
               # compact: rewrite the history with only the active jobs, then
               # reopen it for appending
               self.historyFile.close()
               self.historyFile = None
               self.saveHistory()
               self.openHistory('a')

            # '>=' rather than '==': on Python 3 the threshold is a float that
            # an integer counter may never equal exactly.
            if consecutiveEmptyQueues >= self.maximumConsecutiveEmptyQueues:
               self.cleanup()
               log("%s idle monitor stopped" % (self.siteDesignator))
               sys.exit(0)


if __name__ == '__main__':

   openLog(LOGPATH)

   # A designator given on the command line is used unless one is compiled
   # into the script (SITEDESIGNATOR) and the two disagree.
   siteDesignator = SITEDESIGNATOR
   if len(sys.argv) == 2:
      externalDesignator = sys.argv[1]
      if SITEDESIGNATOR and externalDesignator != SITEDESIGNATOR:
         log("Site designators do not match.\n   internal name = %s\n   external name = %s" % (SITEDESIGNATOR,externalDesignator))
         siteDesignator = ""
      else:
         siteDesignator = externalDesignator

   # no usable designator: exit with status 2
   if not siteDesignator:
      sys.exit(2)

   log("%s monitor started on %s" % (siteDesignator,socket.gethostname()))

   __queueMonitor__ = QueueMonitor(siteDesignator,USERNAME,QSTATCOMMAND,
                                   HISTORYFILEPATH,SLEEPTIME,PAUSETIME,MAXIMUMIDLETIME)

   # restore monitored jobs from the history file, rewrite it compacted,
   # then run the monitor loop
   __queueMonitor__.loadHistory()
   __queueMonitor__.saveHistory()
   __queueMonitor__.monitorQ()


