#!/usr/bin/python
# @package      hubzero-mw2-exec-service
# @file         maxwell_service
# @author       Pascal Meunier <pmeunier@purdue.edu>
# @copyright    Copyright (c) 2016-2017 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Based on previous work by Richard L. Kennell and Nicholas Kisseberth
#
# Copyright (c) 2016-2017 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

"""
Maxwell service script, to be run on hosts.  The exact services offered depend on the database
settings for that host, which consist of flags that can be ORed together.  This helps match
applications to hosts.  However, a host has no knowledge of which services it is supposed to offer;
attempting to invoke a service for which a host wasn't configured may produce unpredictable results.

This script does not access the SQL database, which simplifies the handling of forks and child
exits.

Logging requirements: this script does not work if its file handles are closed, as is done for
the maxwell script.

Files used by this script:
EXEC_CONFIG_FILE : the path to the configuration information.
              Note that fixed paths such as this one are set in hubzero.mw.constants.


log files for an application:
SERVICE_LOG_PATH/<session number>.err
SERVICE_LOG_PATH/<session number>.out

Configuration information for a particular session and application are stored
inside a user's directory:
resource file = homedir + "/data/sessions/<session number>/resources"

vncdir + "xstartup": "Editing the file $HOME/.vnc/xstartup allows you to change the applications
                    run at startup".  We keep it as a NOOP shell invocation.
vncdir + "pass.<display number>" : password information for VNC (8 bytes).  This file is fed to
             startxvnc via a pipe.
"""

import os
import sys
import time
import re
import stat
import subprocess
from hubzero.mw.log import dissociate, log, setup_log, log_exc, ttyprint, save_out
from hubzero.mw.support import check_rundir
from hubzero.mw.host import BasicHost
from hubzero.mw.constants import MW_USER, HOST_K, EXEC_CONFIG_FILE, SERVICE_LOG, \
    SESSION_K, VERBOSE, SHELLARG_REGEXP, GEOM_REGEXP, USER_REGEXP, SESSNAME_REGEXP, ALPHANUM_REGEXP
from hubzero.mw.errors import  MaxwellError, InputError, PrivateError

try:
  import pwd
  from hubzero.mw.support import get_dirlock, release_dirlock
except ImportError:
  IS_WINDOWS = True
  EXEC_HOST_DRIVE = "c:"
  from hubzero.mw.win_user_account import User_account
  from hubzero.mw.win_container import Container
else:  # do UNIX imports by default
  IS_WINDOWS = False
  EXEC_HOST_DRIVE = ""
  from hubzero.mw.user_account import User_account
  from hubzero.mw.container import Container

#=============================================================================
# Set up default parameter values.
#=============================================================================
# Defaults for notifying the controller when a command finishes.  The
# configuration file is execfile()'d into this module's namespace, so it can
# override any of these module-level settings.
notify_retries = 10         # rounds of notification attempts before giving up
notify_timeout = 60
notify_hosts = 'localhost'  # comma-separated host names; empty means "use SSH_CONNECTION"

# Every OpenVZ host should have a unique number.  If not set, it is calculated in Containers.
machine_number = 0

# Three distinct (initially empty) configuration dictionaries, presumably
# populated by the configuration file.
CONTAINER_CONF, SESSION_CONF, HOST_CONF = {}, {}, {}

VIEW_PATH = "/var/run/vncproxyd/"          # one file per noVNC viewtoken, holding a display number
NOTIFY_SOURCE_PATH = "/var/run/mw-service/"
DEBUG = True

#=============================================================================
# Tell the caller that a session is finished.
#=============================================================================
def notify_command_finished(sessnum, hostname):
  """Tell hostname that session sessnum is finished.

  Runs the remote command "notify session <sessnum>" over ssh as MW_USER,
  authenticating with HOST_MERGED["NOTIFY_KEY"].

  Returns True if the remote command exited with status 0; otherwise logs
  the status and returns False.
  """
  sessnum = int(sessnum)
  remote = BasicHost(hostname, MW_USER, HOST_MERGED["NOTIFY_KEY"])
  exit_status = remote.ssh(["notify", "session",  "%d" % sessnum])
  if exit_status == 0:
    return True
  log("Unable to notify %s@%s about %s;  status is %s" % (MW_USER, hostname, sessnum, str(exit_status)))
  return False


#=============================================================================
# Move the session directory to a new name to indicate it's expired.
#=============================================================================
def expire_session_dir(user, session_id):
  """Rename a finished session's directory to "<dir>-expired".

  On Linux hosts the directory is User_account(user).session_dir(session_id);
  on Windows hosts it is "<HOST_K['WIN_HOME_DRIVE']>/<user>".  When
  CONTAINER_CONF has a true "LOCAL_SESSIONDIR", the session data lives in
  /home/sessions/<user>/<session_id> and the usual path is a symlink: the
  symlink is removed and the data is moved to "<dir>-expired".  The commands
  run as the user through su.  A non-zero exit status from the move is
  logged rather than raised.
  """
  if IS_WINDOWS: # use alternate path for windows execution hosts
    olddir = ("%s/%s" % (HOST_K["WIN_HOME_DRIVE"], user))
    log("olddir = '%s'" % olddir)
    # There is no User_account on Windows hosts; use the user's directory as
    # HOME for the move command.
    homedir = olddir
  else: # use for linux execution hosts
    account = User_account(user)
    olddir = account.session_dir(session_id)
    homedir = account.homedir

  newdir = olddir + "-expired"

  # If using local storage for the session directory, move it instead of just renaming it
  if CONTAINER_CONF.get("LOCAL_SESSIONDIR", False):
    local_dir = "/home/sessions/%s/%s" % (user, session_id)

    # remove symlink
    subprocess.call(['/bin/su', user, '-c', '/bin/rm -f %s' % olddir])

    # move directory back to user home, with "-expired" suffix
    args = ['/bin/su', user, '-c', 'mv %s %s' % (local_dir, newdir)]
    if VERBOSE:
      log("Executing %s" % " ".join(args))
    if subprocess.call(args) != 0:
      log("Could not move '%s' to '%s'" % (local_dir, newdir))
    return

  if IS_WINDOWS:
    # NOTE(review): "move" is a cmd.exe builtin; without shell=True this
    # single-string argv may not run on Windows -- confirm on a Windows host.
    args = ['move /y %s %s' % (olddir, newdir)]
    move_path = ""
  else:
    args = ['su', user, '-c', 'mv %s %s' % (olddir, newdir)]
    move_path = "/bin:/usr/bin:/usr/bin/X11:/sbin:/usr/sbin"

  if VERBOSE:
    log("Executing %s" % " ".join(args))
  cmd_env = {
    "HOME": homedir,
    "LOGNAME": user,
    "PATH": move_path,
    "USER": user
  }
  # note it's possible the directories don't exist if there was an error starting the session
  if subprocess.call(args, env=cmd_env) != 0:
    log("Unable to move the session directory")

#=============================================================================
# Get the list of hosts to try to notify when the command exits.
#=============================================================================
def get_notify_hosts():
  """Return the list of hosts to notify when a command exits.

  The module setting notify_hosts may be a comma-separated string or a
  list/tuple of host names.  Whitespace around names and empty entries are
  dropped.  When nothing usable remains, the first field of the
  SSH_CONNECTION environment variable is used instead (presumably the
  address of the host that invoked this script over ssh).

  Raises MaxwellError if no host is configured and SSH_CONNECTION is unset
  or empty.
  """
  if isinstance(notify_hosts, (list, tuple)):
    candidates = notify_hosts
  elif notify_hosts:
    candidates = notify_hosts.split(',')
  else:
    candidates = []

  hosts = []
  for host in candidates:
    host = host.strip()
    if host:
      hosts.append(host)
  if hosts:
    return hosts

  try:
    fields = os.environ["SSH_CONNECTION"].split()
  except KeyError:
    raise MaxwellError("get_notify_hosts: can't find SSH_CONNECTION")
  if not fields:
    raise MaxwellError("get_notify_hosts: can't find SSH_CONNECTION")
  return [fields[0]]


#=============================================================================
# Start an X application
#=============================================================================
def startxapp(user, session_id, timeout, disp, command):
  """
    Command-level function.  Run command as user in the container of display
    disp, then stop the display, expire the session directory and tell the
    controller that the session ended.

    session_id is the session number, optionally followed by a one-character
    non-digit suffix.  command may start with "notify " (stripped; left over
    from an older protocol) and then "params " (stripped; passed on as
    params=True to invoke_unix_command).

    Raises MaxwellError if the container is not running.  Otherwise the
    process dissociates and never returns: it exits with status 0 after a
    notification round in which at least one host acknowledged, or with
    status 1 after notify_retries unsuccessful rounds.
  """
  vps = Container(disp, machine_number, CONTAINER_CONF)
  if vps.get_status().find("running") == -1:
    raise MaxwellError("Container is not running")

  # Get the session number by removing the optional suffix
  if '0' <= session_id[-1] <= '9':
    sessnum = int(session_id)
  else:
    sessnum = int(session_id[0:-1])
  if DEBUG:
    log("session is %d" % sessnum)

  # There used to be a 'notify' protocol and a 'stream' protocol.
  if command[0:7] == "notify ":
    command = command[7:]
    if DEBUG:
      log("notify found")

  if command[0:7] == "params ":
    command = command[7:]
    params = True
    if DEBUG:
      log("params found")
  else:
    params = False
    if DEBUG:
      log("params is False")

  # get user information in cache before dissociating, to mitigate race condition between
  # populating /etc/group in container, and starting SSH server with Virtual SSH.
  # otherwise, the groups command will return only "public".  See Hubzero ticket #1752
  if user != "anonymous":
    User_account(user)

  if DEBUG:
    log("going to dissociate")
  dissociate()

  # Close every inherited descriptor above stderr (errors are ignored).
  os.closerange(3, 1024)

  # We used to dup stdin, stdout and stderr here
  # instead now we use subprocess.Popen facilities to redirect input/output/err to a file

  # At this point, we're the child of a child and dissociated from
  # the original session.  Invoke the command, wait for it to exit,
  # and then notify the remote caller.
  vps.invoke_unix_command(user, session_id, timeout, command, params, """%s/%d""" % (SESSION_MERGED["SERVICE_LOG_PATH"], sessnum))

  setup_log(SERVICE_LOG, 'cleanup')
  if VERBOSE:
    log ("invoke_unix_command %s, %s, %d, %s, %d is done" %(user, session_id, timeout, command, disp))

  # When we're done with the command, we're done with this VE.  Stop it.
  stopvnc(disp)
  if VERBOSE:
    log("Stopped display")

  if user != "anonymous":
    expire_session_dir(user, session_id)

  if saved_clientIP is not None and saved_clientIP != "":
    targets = [ saved_clientIP ]
  else:
    targets = get_notify_hosts()

  for _round in range(notify_retries):
    # Every host is told in each round; one acknowledgement is enough.
    notify_success = False
    for hostname in targets:
      log("notifying %s about session %d" % (hostname, sessnum))
      if notify_command_finished(sessnum, hostname):
        notify_success = True
    if notify_success:
      os._exit(0) # success!
    # Pause between rounds so transient ssh/network failures can clear.
    time.sleep(1)
  log("Unable to notify controller about session %d" % sessnum)
  os._exit(1)

#=============================================================================
# Kill a tree of processes in a depth-first manner.
#=============================================================================
def _nonroot_ppid(status_path):
  """Return the parent pid recorded in a /proc/<pid>/status file, or 0.

  0 means "leave this process out of the tree".  It is returned when the file
  cannot be opened, when the PPid is missing, unparsable or 0, when any Uid
  field is 0 (a root-related process), or when no Uid: line is found.  If the
  file becomes unreadable part-way through, the PPid read so far is returned.
  """
  ppid = 0
  try:
    with open(status_path) as fh:
      for line in fh:
        fields = line.split()
        if not fields:
          continue  # tolerate blank lines instead of failing on fields[0]
        if fields[0] == "PPid:":
          try:
            ppid = int(fields[1])
          except (ValueError, TypeError, IndexError):
            return 0
          if ppid == 0:
            return 0
        elif fields[0] == "Uid:":
          # if it's a root-related process, leave it alone
          if any(int(uid) == 0 for uid in fields[1:]):
            return 0
          return ppid
  except (IOError, OSError):
    return ppid
  # EOF without a Uid: line: ownership unknown, so do not include it.
  return 0


def killtree(pid):
  """
    Command-level function.  Kill pid and all of its descendants, descendants first.

    The process tree is rebuilt from the PPid: line of every /proc/<n>/status.
    A process that has any Uid field equal to 0 is not recorded as a child of
    its parent, so the walk never descends into it.  pid itself is always
    targeted.  Each process still present is sent SIGHUP (1), then SIGTERM
    (15), then SIGKILL (9), with a 0.5 s pause after each signal, stopping as
    soon as /proc/<n>/status disappears.  A process that exits before it can
    be signalled is skipped without aborting the rest of the tree.
  """
  children = {}
  for entry in os.listdir("/proc"):
    try:
      child_pid = int(entry)
    except ValueError:
      continue
    ppid = _nonroot_ppid("/proc/" + entry + "/status")
    if ppid != 0:
      children.setdefault(ppid, []).append(child_pid)

  # Inner function to do some printing.
  def doprint(indent, node):
    log("  " * indent + str(node))
    for child in children.get(node, ()):
      doprint(indent + 1, child)

  if VERBOSE:
    doprint(0, pid)

  # Inner function to do the killing, depth-first so the youngest die first.
  def dokill(node):
    for child in children.get(node, ()):
      dokill(child)
    for sig in (1, 15, 9):  # SIGHUP, SIGTERM, SIGKILL
      if not os.path.exists("/proc/%d/status" % node):
        return
      log("Killing %d with %d" % (node, sig))
      try:
        os.kill(node, sig)
      except OSError:
        return  # exited (or cannot be signalled); continue with the rest of the tree
      time.sleep(0.5)

  dokill(pid)

#=============================================================================
# Kill all processes belonging to a particular user.
#=============================================================================
def killall(user):
  """Terminate every process belonging to user, then exit.

    Command-level function.  A process is sent SIGKILL (9) when every uid field
    on the "Uid:" line of its /proc/<pid>/status equals User_account(user).uid.
    Processes that vanish while being inspected or signalled are skipped
    instead of aborting the sweep.  Always finishes with sys.exit(0).
  """
  uid = User_account(user).uid

  for entry in os.listdir("/proc"):
    try:
      ipid = int(entry)
    except ValueError:
      continue

    uid_fields = None
    try:
      with open("/proc/" + entry + "/status") as fh:
        for line in fh:
          fields = line.split()
          if fields and fields[0] == "Uid:":
            uid_fields = fields[1:]
            break
    except (IOError, OSError):
      continue  # process exited while we were looking at it

    if uid_fields is None:
      continue  # no Uid: line; ownership unknown, leave it alone
    if all(int(field) == uid for field in uid_fields):
      try:
        os.kill(ipid, 9)
      except OSError:
        pass  # already gone; keep killing the rest

  sys.exit(0)


def screenshot(user, session_id, display):
  """Delegate to Container.screenshot() for display's container."""
  container = Container(display, machine_number, CONTAINER_CONF)
  container.screenshot(user, session_id)

def resize(disp, geom):
  """Apply geometry geom to display disp via Container.resize(); log it when VERBOSE."""
  Container(disp, machine_number, CONTAINER_CONF).resize(geom)
  if not VERBOSE:
    return
  log("vnc %s geometry changed to %s" % (disp, geom))

def mount_paths(display, paths):
  """Ask display's container to mount paths (see Container.mount_paths).

  The attempt is logged first when VERBOSE.
  """
  if VERBOSE:
    log("trying to mount %s on display %s" % (paths, display))
  container = Container(display, machine_number, CONTAINER_CONF)
  container.mount_paths(paths)

#=============================================================================
# Start a VNC server...
#=============================================================================

def _apply_client_ip(client_ip):
  """Fill client_ip into the '%'-templated OpenVZ session settings.

  For OVZ_SESSION_MOUNT, OVZ_SESSION_CONF and OVZ_SESSION_UMOUNT (in that
  order), a CONTAINER_CONF value containing '%' is replaced in place by
  value % client_ip.  If that formatting raises TypeError, the value is left
  unchanged.  The resulting mount and config paths are logged.
  """
  for key, label in (("OVZ_SESSION_MOUNT", "container mount path"),
                     ("OVZ_SESSION_CONF", "container config path"),
                     ("OVZ_SESSION_UMOUNT", None)):
    if CONTAINER_CONF[key].find('%') == -1:
      continue
    try:
      CONTAINER_CONF[key] = CONTAINER_CONF[key] % client_ip
    except TypeError:
      # ignore error if the setting doesn't have a format specifier for the IP address
      pass
    if label is not None:
      log("%s is %s" % (label, CONTAINER_CONF[key]))


def startvnc(disp, geom, depth):
  """
    Command-level function.  Start the container for display disp together
    with its Xvnc server at geometry geom.  depth is unused for now.

    The per-container lock (LOCK_PATH/lock_<veid>) is held while working and
    is released on every exit path.  A container that is not "down" is
    stopped first.  Setup then proceeds in stages.  If a stage fails with
    MaxwellError or CalledProcessError, the lock is released, the completed
    stages are undone, and MaxwellError is raised.  Does nothing on Windows
    execution hosts.
  """
  if IS_WINDOWS:
    log("startvnc skipped for Windows execution host")
    return

  start_time = time.time()
  stage = 0 # for unwinding the start

  #Virtual Private Server "vps" a.k.a. container, a.k.a. display
  if saved_clientIP is not None and saved_clientIP != "":
    _apply_client_ip(saved_clientIP)
  vps = Container(disp, machine_number, CONTAINER_CONF)

  # get lock to arbitrate race conditions.  What if container is still being shut down?
  lock_path = HOST_MERGED["LOCK_PATH"] + "/lock_%d" % vps.veid
  get_dirlock(lock_path, expiration=300, retries=310, sleeptime=1)
  lock_held = True
  try:
    if vps.get_status().find("down") == -1:
      log("Container %d is not down" % disp)
      try:
        vps.stop()
      except StandardError:
        raise MaxwellError("Unable to stop container %d" %disp)
      log("Container %d stopped.  Hopefully there was no notify event that put it in the absent state, at which point there was another start command..." % disp)

    try:
      vps.create_xstartup()
      vps.read_passwd()

      # Sanitize the environment just in case.
      vps.umount()
      vps.stunnel()
      stage = 1

      vps.setup_template()
      if VERBOSE:
        end_time1 = time.time()
        log("stage 1 finished in %f seconds" % (end_time1 - start_time))
      stage = 2

      vps.delete_confs() # Delete in case dirty links were left over
      vps.create_confs() # Recreate the links.
      if VERBOSE:
        end_time2 = time.time()
        log("stage 2 finished in %f seconds" % (end_time2 - end_time1))
      stage = 3

      vps.start(geom) # both container and xvnc
      if VERBOSE:
        end_time3 = time.time()
        log("vps started; stage 3 finished in %f seconds" % (end_time3 - end_time2))
      stage = 4

      vps.set_ipaddress()
      if VERBOSE:
        log("ip address set")
      vps.start_filexfer()
      if VERBOSE:
        end_time4 = time.time()
        log("filexfer started; stage 4 finished in %f seconds" % (end_time4 - end_time3))
    except (MaxwellError, subprocess.CalledProcessError) as e:
      # stopvnc() takes the same lock, so release it before unwinding.
      release_dirlock(lock_path)
      lock_held = False
      try:
        # unwind depending on stage
        if stage > 2:
          stopvnc(disp)
        if stage > 1:
          vps.delete_confs()
        if stage > 0:
          vps.umount()
      except MaxwellError as e2:
        log("Second error while aborting container start: %s" % e2)
        raise MaxwellError("Original error on container start %d was %s" % (disp, e))
      raise MaxwellError("Recovered from aborted container %d start; %s" % (disp, e))
  finally:
    # Covers success and any unexpected exception (IOError, OSError, ...),
    # which previously left the lock held until it expired.
    if lock_held:
      release_dirlock(lock_path)

  end_time = time.time()
  if VERBOSE:
    log("entire startvnc call took %f seconds" % (end_time - start_time))

#=============================================================================
# Return the status of a display.
#=============================================================================
def status(disp):
  """Print (via ttyprint) the status string reported by display disp's container."""
  container = Container(disp, machine_number, CONTAINER_CONF)  # Virtual Private Server
  ttyprint(container.get_status())

#=============================================================================
# Stop a VNC server and kill the processes for that display.
#=============================================================================
def _delete_viewtokens(disp):
  """Remove the noVNC viewtoken files in VIEW_PATH whose first line is display disp.

  This is best effort.  A missing VIEW_PATH ends the sweep.  A token file that
  cannot be read, or whose first line is not an integer, is skipped.  A
  removal failure skips that file.  None of these abort the caller, which
  holds the container lock at this point.
  """
  try:
    filenames = os.listdir(VIEW_PATH)
  except OSError:
    return
  for filename in filenames:
    token_path = VIEW_PATH + filename
    try:
      with open(token_path) as vfile:
        target_disp = int(vfile.readline())
    except (IOError, OSError, ValueError):
      if VERBOSE:
        log("ignoring unreadable viewtoken %s" % filename)
      continue
    if target_disp != disp:
      if VERBOSE:
        log("%d is not equal to %d"  % (target_disp, disp))
      continue
    try:
      os.remove(token_path)
    except OSError:
      continue
    if VERBOSE:
      log("deleted viewtoken %s for container %d" % (filename, disp))


def stopvnc(disp):
  """To shut down a container, we kill all the processes within using increasing signal
  severity, until only init remains:  1 HUP , 2 INT, 15 TERM, 9 KILL
  Then do a hard halt, then a stop.  Cleanup leftover symlinks we created during startvnc.
  Command-level function.

  The per-container lock (LOCK_PATH/lock_<veid>) is held from acquisition
  until return, and is released even if a step raises.  Raises MaxwellError
  if the container is not "down" after up to 5 stop attempts, or if its
  configuration is not "deleted" afterwards.

  Issue: this function may be called simultaneously by the web server,
          and by the maxwell_service that started an application.
  """
  vps = Container(disp, machine_number, CONTAINER_CONF) #Virtual Private Server
  if IS_WINDOWS:
    try:
      vps.stop()
    except MaxwellError as e:
      log("%s" %e)
    return

  # get lock to arbitrate race conditions between two processes trying to shut down the same container
  lock_path = HOST_MERGED["LOCK_PATH"] + "/lock_%d" % vps.veid
  try:
    get_dirlock(lock_path, expiration=300, sleeptime=0)
  except PrivateError:
    if VERBOSE:
      log("Display %d busy (starting or stopping), waiting for completion" % vps.veid)
    # wait to prevent maxwell from thinking this display is now available to start
    get_dirlock(lock_path, expiration=300, retries=310, sleeptime=1)

  try:
    # delete any viewtoken mappings for novnc
    _delete_viewtokens(disp)

    vps.stop_submit_local()
    # stop all processes except init
    vps.killall()
    time.sleep(1) # give time for processes to finish going away and OpenVZ to catch up

    # Stop the container
    for _attempt in range(5):
      vps.wait_unlock() # OpenVZ has its own lock for VEs;  wait until it's done
      if vps.get_status().find("running") == -1:
        break
      try:
        vps.stop()
        break
      except MaxwellError as e:
        log("%s" %e)
        time.sleep(1) # give time for the container to finish shutting down, if error is caused by simultaneous shut down

    if vps.get_status().find("down") == -1:
      raise MaxwellError("Unable to stop container %d" % vps.veid)
    vps.delete_confs()
    if vps.get_status().find("deleted") == -1:
      raise MaxwellError("Unable to delete configuration files for %d" % vps.veid)
  finally:
    release_dirlock(lock_path)

#=============================================================================
# Clean up sessions that are believed to be running after the system reboots.
#=============================================================================
def notify_restarted():
  """Clean up sessions believed to be running after the host restarts.

    Command-level function.  Execution hosts keep
    SERVICE_LOG_PATH/<session number>.err and .out files, which are scp'ed and
    analyzed at the end of a session.  An .err file containing a line starting
    with "Exit_Status:" is treated as a session that already ended.  For every
    other <number>.err file, "System crashed" and "Exit_Status: 65534" are
    appended.  The notify hosts are then tried in turn (1 s apart) until one
    accepts, and the failure is logged if none does.  Leftover X11 sockets
    and lock files (/tmp/.X11-unix/*, /tmp/.X*-lock) are also deleted.

    Raises MaxwellError if /tmp/.X11-unix exists but is not owned by the
    effective user, since deleting its contents would then be unsafe.
  """
  os.chdir(SESSION_MERGED["SERVICE_LOG_PATH"])
  dir_ls = os.listdir(".")

  try:
    lock_stat = os.lstat("/tmp/.X11-unix")
  except OSError:
    lock_stat = None  # does not exist
  if lock_stat is not None and lock_stat[stat.ST_UID] != os.geteuid():
    raise MaxwellError("Not owner of /tmp/.X11-unix/ , can't cleanup securely.")

  # os.system reports failure through its return value rather than raising;
  # files that cannot be removed are not fatal here.
  os.system("rm -f /tmp/.X11-unix/*")
  os.system("rm -f /tmp/.X*-lock")

  # look for crashed sessions, append to the log file, and tell master the bad news
  hosts = None  # looked up lazily: only needed if a crashed session is found
  for afile in dir_ls:
    if not afile.endswith(".err"): # look for session log files
      continue
    try:
      sessnum = int(afile.split(".")[0])  # expected format is <number>.err
    except ValueError:
      continue # OK, not a number so look at next file

    log("Cleaning up session %d" % sessnum)
    # Trying to see if the session ended before the host was restarted, or if it crashed.
    with open(afile, 'r') as fp:
      has_exit_status = any(line.startswith("Exit_Status:") for line in fp)
    if has_exit_status:
      continue

    log("Adding exit status to '%s'" % afile)
    with open(afile, 'a') as fp:
      fp.write("\nSystem crashed\nExit_Status: 65534")

    # inform a master that the session crashed and is done
    if hosts is None:
      hosts = get_notify_hosts()
    for hostname in hosts:
      if notify_command_finished(sessnum, hostname):
        break
      time.sleep(1)
    else:
      log("Unable to notify controller about crashed session %d" % sessnum)

def anonymous(session, disp, params):
  """For anonymous users, set up a volatile home directory.

  Delegates to Container.create_anonymous(session, params) for display disp.
  """
  # NOTE(review): this passes 0 as the machine number where most other
  # commands pass machine_number -- confirm that is intentional.
  container = Container(disp, 0, CONTAINER_CONF)
  container.create_anonymous(session, params)

def update_resources(user, session, disp):
  """For anonymous users, add information to the resources file.

  Delegates to Container.update_resources(session) for display disp.
  Raises InputError for any user other than "anonymous".
  """
  if user == "anonymous":
    # NOTE(review): machine number 0 rather than machine_number -- confirm.
    Container(disp, 0, CONTAINER_CONF).update_resources(session)
    return
  raise InputError("setup_dir on maxwell_service is only for anonymous users")

def set_viewtoken(disp, viewtoken):
  """Record that noVNC viewtoken maps to display disp.

  Writes "<disp>\\n" to VIEW_PATH + viewtoken (VIEW_PATH ends with a slash),
  creating VIEW_PATH first if needed.  An existing token file is overwritten.
  """
  if not os.path.isdir(VIEW_PATH):
    try:
      os.makedirs(VIEW_PATH)
    except OSError:
      # Another process may have created it concurrently; only fail if it
      # still is not a directory.
      if not os.path.isdir(VIEW_PATH):
        raise
  with open(VIEW_PATH + viewtoken, "w") as vfile:
    vfile.write("%d\n" % disp)
  if VERBOSE:
    log("wrote viewtoken %s for container %d" % (viewtoken, disp))

#=============================================================================
# Make sure that passed userid is valid, to prevent OS injection attacks
#=============================================================================
def validate_user(i):
  """Return the user ID in sys.argv[i], validated against USER_REGEXP.

  Guards against OS command injection.  The returned value is the text
  matched by re.match, i.e. the matching prefix of the argument.
  Raises InputError if the argument does not match at all.
  """
  candidate = sys.argv[i]
  match = re.match(USER_REGEXP, candidate)
  if match is not None:
    return match.group()
  raise InputError("Bad user ID '%s'" % candidate)

def validate_viewtoken(i):
  """Return the viewtoken in sys.argv[i], validated against ALPHANUM_REGEXP.

  The returned value is the text matched by re.match (the matching prefix).
  Raises InputError if the argument does not match at all.
  """
  candidate = sys.argv[i]
  match = re.match(ALPHANUM_REGEXP, candidate)
  if match is not None:
    return match.group()
  raise InputError("Bad viewtoken '%s'" % candidate)

#=============================================================================
# Check the number of arguments
#=============================================================================
def check_nargs(minarg, maxarg=0):
  """Validate len(sys.argv) (the program name counts as one entry).

  Raises InputError when there are fewer than minarg entries, or more than
  max(minarg, maxarg) entries.  With the default maxarg=0 this demands
  exactly minarg entries.
  """
  argc = len(sys.argv)
  if argc < minarg:
    raise InputError("Incomplete command: %s" % " ".join(sys.argv))
  if argc > max(minarg, maxarg):
    raise InputError("Too many arguments: %s" % " ".join(sys.argv))

#=============================================================================
# Validate a session identifier, number + suffix
#=============================================================================
def validate_session(i):
  """Return the session identifier in sys.argv[i], validated against SESSNAME_REGEXP.

  A session identifier is a number plus a suffix (d for development,
  p for production).  The returned value is the text matched by re.match
  (the matching prefix).  Raises InputError if the argument does not match.
  """
  candidate = sys.argv[i]
  match = re.match(SESSNAME_REGEXP, candidate)
  if match is not None:
    return match.group()
  raise InputError("Bad session ID '%s'" % candidate)

#=============================================================================
# Windows Functions
#=============================================================================
#=============================================================================
# Start an X application
#=============================================================================
def win_startapp(user, session_id, timeout, command, disp):
  """
  Windows version of startxapp.

  command must start with "notify " (stripped); anything else raises
  MaxwellError.  stdin is redirected from the null device.  stdout and
  stderr go to <WIN_S_EXEC_HOST_LOG_PATH>/<sessnum>.out and .err, which are
  created exclusively, so os.open raises OSError if they already exist.
  After the command finishes, the display is stopped and the notify hosts
  are tried in turn for notify_retries rounds.  The process exits with
  status 0 on the first success and 1 otherwise.
  """
  # There used to be a 'notify' protocol and a 'stream' protocol.
  if command[0:7] != "notify ":
    raise MaxwellError("unknown win_startapp command '%s'" % command[0:7])
  command = command[7:]

  # Get the session number by removing the optional suffix
  if '0' <= session_id[-1] <= '9':
    sessnum = int(session_id)
  else:
    sessnum = int(session_id[0:-1])

  # Set up <sessnum>.err and <sessnum>.out files.
  # os.devnull is "/dev/null" on POSIX and "nul" on Windows.
  in_fd = os.open(os.devnull, os.O_RDONLY)
  # file should not already exist, so use os.O_EXCL
  # we don't use the plain open command because it doesn't allow specifying O_EXCL
  log_base = "%s/%d" % (SESSION_MERGED["WIN_S_EXEC_HOST_LOG_PATH"], sessnum)
  out_fd = os.open(log_base + ".out", os.O_WRONLY|os.O_CREAT|os.O_EXCL)
  error_fd = os.open(log_base + ".err", os.O_WRONLY|os.O_CREAT|os.O_EXCL)

  # overwrite stdin, stdout and stderr
  # due to this trick, the log() function still works, as it writes to stderr.
  os.dup2(in_fd, 0)
  os.dup2(out_fd, 1)
  os.dup2(error_fd, 2)
  # A descriptor <= 2 has just been repurposed as a standard stream by dup2
  # and must stay open; anything higher is a spare copy.
  if in_fd > 2:
    os.close(in_fd)

  # Invoke the command, wait for it to exit,
  # and then notify the remote caller.
  vps = Container(disp, machine_number, CONTAINER_CONF)
  vps.invoke_unix_command(user, session_id, timeout, command)

  setup_log(SERVICE_LOG, 'cleanup')
  if VERBOSE:
    log ("invoke_unix_command %s, %s, %d, %s, %d is done" %(user, session_id, timeout, command, disp))

  # Close .out and .err files so they can be xfered to web server.
  # This must happen whether or not VERBOSE is set.
  if out_fd > 2:
    os.close(out_fd)
  if error_fd > 2:
    os.close(error_fd)

  # When we're done with the command, we're done with this VE.  Stop it.
  stopvnc(disp)
  if VERBOSE:
    log("Stopped display")

  # FIXIT: the session directory is not expired on Windows hosts
  # (expire_session_dir is not called here).
  notify_hosts = get_notify_hosts()
  for _round in range(notify_retries):
    for hostname in notify_hosts:
      if notify_command_finished(sessnum, hostname):
        os._exit(0) # success!
      time.sleep(1)
  log("Unable to notify controller about session %d" % sessnum)
  os._exit(1)

#=============================================================================
# Kill a tree of processes in a depth-first manner.
#=============================================================================
def win_killtree(pid):
  """
  Windows version of killtree.  Not implemented: does nothing and returns None.
  """
  return None

#=============================================================================
# Kill all processes belonging to a particular user.
#=============================================================================
def win_killall(user):
  """Windows counterpart of killall(): meant to kill every process owned by user.

  Not implemented on Windows yet -- the call does nothing and returns None,
  so a "killall" request on a Windows host finishes without killing anything.
  """
  return None

#=============================================================================
#=============================================================================
# Main program...
# by session name, we mean an integer followed by a letter
# by sessnum, we mean just the integer.
# Whatever is printed gets sent over ssh to the command host as an answer (e.g., "OK")
#
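# We recognize the following commands (dispatched on sys.argv[1]):
#  check
#  status <dispnum>
#  startvnc <dispnum> <geometry> <depth>
#  resize <dispnum> <geometry>
#  stopvnc <dispnum>
#  anonymous <session name> <dispnum> [params]
#  update_resources <user> <session name> <dispnum>
#  screenshot <user> <session name> <dispnum>
#  mount_paths <dispnum> [path...]
#  startxapp <user> <session name> <timeout> <dispnum> <command>...
#  killtree <pid>
#  killall <user>
#  purgeoutputs <sessnum>
#  notifyrestarted [masterhostname]   (the argument is currently ignored)
#  erase_sessdir                      (no-op for anonymous sessions)
#  set_viewtoken <dispnum> <viewtoken>
# Note: setup_dir <user> <session name> is not dispatched by this script; it is
# rejected as an unknown command.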
#=============================================================================
#=============================================================================


#=============================================================================
# Configuration and Safety
# This program receives commands from maxwell, which runs on a different host;
# this program therefore follows different rules.
#
# note that hosts have varying capabilities (roles) assigned to them
# if it's a fileserver, it will respond to setup_dir, etc...
#=============================================================================

if IS_WINDOWS:
  try:
    execfile(EXEC_CONFIG_FILE)
  except EnvironmentError:
    print "Unable to read configuration file, exiting."
    print "The configuration file '%s' needs to exist" % EXEC_CONFIG_FILE
    sys.exit(1)
else:
  check_rundir()
  try:
    mode = os.lstat(EXEC_CONFIG_FILE)[stat.ST_MODE]
    if mode & stat.S_IWOTH:
      print "configuration file is writable by others; exiting.\n"
      sys.exit(1)
    try:
      execfile(EXEC_CONFIG_FILE)
    except EnvironmentError:
      print "Unable to read configuration file, exiting."
      print "The configuration file '%s' needs to exist" % EXEC_CONFIG_FILE
      sys.exit(1)

  except OSError:
    # no configuration file is present, use defaults
    pass

SESSION_MERGED = SESSION_K
SESSION_MERGED.update(SESSION_CONF)
HOST_MERGED = HOST_K
HOST_MERGED.update(HOST_CONF)

if not IS_WINDOWS:
  # check that user is correct
  login =  pwd.getpwuid(os.geteuid())[0]
  if login != HOST_MERGED["SVC_HOST_USER"]:
    print "maxwell: access denied to %s. Must be run as %s (see %s)" % \
      (login, HOST_MERGED["SVC_HOST_USER"], EXEC_CONFIG_FILE)
    sys.exit(1)

    
#=============================================================================
# Input parsing
#=============================================================================

inputs = {}

# Every command is selected by sys.argv[1]; without it the checks below would
# die with an uncaught IndexError traceback.  Logging is not set up yet, so
# report on stderr (stdout carries the answer back to the caller).
if len(sys.argv) < 2:
  sys.stderr.write("maxwell: no command given\n")
  sys.exit(1)

if sys.argv[1] == "check":
  # the only place in this script where we need to print to standard out.
  # However we want to capture stdout and stderr.  If the logging redirection is setup
  # before this point, we can't print unless we call save_out().  Then later startvnc hangs
  # because save_out() prevents the dissociation of processes.
  # One solution is to setup logging after this.  Alternatively, call save_out() and then discard_out().
  print("OK")
  sys.exit(0)

if sys.argv[1] == "status":
  save_out() # this script needs ttyprint functionality
  setup_log(SERVICE_LOG, 'mw-service')
  check_nargs(3)
  try:
    inputs["disp"] = int(sys.argv[2])
  except ValueError:
    # Same report and exit code as the main dispatcher for non-integer input,
    # instead of an uncaught traceback.
    log("Integer input expected for this command: '%s'" % " ".join(sys.argv))
    sys.exit(1)
  status(**inputs)
  sys.exit(0)

# record IP address of server that contacted us:
# 1. For notification purposes
# 2. So we have the option of using a different template based on the IP address of the webserver (client)
# 3. For logging purposes
# The client address is taken as the first whitespace-separated field of
# SSH_CONNECTION.  When the variable is absent or blank, no client IP is recorded.
try:
  conn = os.environ["SSH_CONNECTION"]
except KeyError:
  saved_clientIP = None
else:
  # a blank value would make conn.split()[0] raise IndexError
  saved_clientIP = (conn.split() or [None])[0]

setup_log(SERVICE_LOG, 'mw-service')
if VERBOSE:
  # saved_clientIP is either None or a non-empty string here
  if saved_clientIP:
    log("received command %s from %s" % (" ".join(sys.argv), saved_clientIP))
  else:
    log("received command %s" % " ".join(sys.argv))

try:
  if sys.argv[1] == "startvnc":
    check_nargs(5)
    inputs["disp"] = int(sys.argv[2])
    inputs["geom"] = sys.argv[3]
    # validate geometry, expecting number x number
    if re.match(GEOM_REGEXP, inputs["geom"]) is None:
      raise InputError("Invalid geometry: '%s'" % inputs["geom"])

    inputs["depth"] = int(sys.argv[4])
    startvnc(**inputs)

  elif sys.argv[1] == "resize":
    check_nargs(4)
    inputs["disp"] = int(sys.argv[2])
    inputs["geom"] = sys.argv[3]
    # validate geometry, expecting number x number
    if re.match(GEOM_REGEXP, inputs["geom"]) is None:
      raise InputError("Invalid geometry: '%s'" % inputs["geom"])
    resize(**inputs)

  elif sys.argv[1] == "stopvnc":
    if not IS_WINDOWS:
      os.nice(20)
    check_nargs(3)
    inputs["disp"] = int(sys.argv[2])
    stopvnc(**inputs)

  elif sys.argv[1] == "anonymous":
    check_nargs(4, 5)
    inputs["session"] = validate_session(2)
    inputs["disp"] = int(sys.argv[3])
    if len(sys.argv) > 4:
      inputs["params"] = sys.argv[4]
    else:
      inputs["params"] = None
    anonymous(**inputs)

  elif sys.argv[1] == "update_resources":
    check_nargs(5)
    inputs["user"] = validate_user(2)
    inputs["session"] = validate_session(3)
    inputs["disp"] = int(sys.argv[4])
    update_resources(**inputs)

  elif sys.argv[1] == "screenshot":
    check_nargs(5)
    inputs["user"] = validate_user(2)
    inputs["session_id"] = validate_session(3)
    inputs["display"] = int(sys.argv[4])
    screenshot(**inputs)

  elif sys.argv[1] == "mount_paths":
    # in a given display, mount the following paths
    check_nargs(3, 99)
    # To Do: input validation with regexp and map (see startxapp)
    inputs["display"] = int(sys.argv[2])
    inputs["paths"] = sys.argv[3: len(sys.argv)]
    mount_paths(**inputs)
    
  elif sys.argv[1] == "startxapp":
    check_nargs(7, 99)
    inputs["user"] = validate_user(2)
    inputs["session_id"] = validate_session(3)
    inputs["timeout"] = int(sys.argv[4])
    inputs["disp"] = int(sys.argv[5])
    # command input validation
    # each arg should be a single word containing only
    # alphanumerics, dashes, underscores, dots and slashes, with redirections > <
    # see maxwell/client in function "invoke_command"
    prog = re.compile(SHELLARG_REGEXP)
    try:
      cmds = map(lambda x:prog.match(x).group(0), sys.argv[6: len(sys.argv)])
    except AttributeError:
      raise InputError("Invalid command passed to startxapp: '%s'" % "' '".join(sys.argv[6: len(sys.argv)]))

    inputs["command"] = " ".join(cmds)
    if IS_WINDOWS:
      win_startapp(**inputs)
    else:
      startxapp(**inputs)

  elif sys.argv[1] == "killtree":
    check_nargs(3)
    inputs["pid"] = int(sys.argv[2])
    if IS_WINDOWS:
      win_killtree(**inputs)
    else:
      killtree(**inputs)

  elif sys.argv[1] == "killall":
    check_nargs(3)
    inputs["user"] = validate_user(2)
    if IS_WINDOWS:
      win_killall(**inputs)
    else:
      killall(**inputs)

  elif sys.argv[1] == "purgeoutputs":
    if IS_WINDOWS == False:
      os.nice(20)
    # Idea: add checks here to make sure that the session is finished?
    check_nargs(3)
    # expecting sessnum, an int, not a session identifier with a letter at end
    sess_param = int(sys.argv[2])
    if IS_WINDOWS: # use alternate path for windows execution hosts
      log("os.unlink = %s/%d" % (SESSION_MERGED["WIN_S_EXEC_HOST_LOG_PATH"], sess_param))
    else:
      log("os.unlink = %s/%d" % (SESSION_MERGED["SERVICE_LOG_PATH"], sess_param))
    try:
      if IS_WINDOWS: # use alternate path for windows execution hosts
        os.unlink("%s/%d.out" % (SESSION_MERGED["WIN_S_EXEC_HOST_LOG_PATH"], sess_param))
      else:
        os.unlink("%s/%d.out" % (SESSION_MERGED["SERVICE_LOG_PATH"], sess_param))
    except EnvironmentError:
      pass
    else:
      if VERBOSE:
        if IS_WINDOWS: # use alternate path for windows execution hosts
          log("deleted %s/%d.out" % (SESSION_MERGED["WIN_S_EXEC_HOST_LOG_PATH"], sess_param))
        else:
          log("deleted %s/%d.out" % (SESSION_MERGED["SERVICE_LOG_PATH"], sess_param))
    try:
      if IS_WINDOWS: # use alternate path for windows execution hosts
        os.unlink("%s/%d.err" % (SESSION_MERGED["WIN_S_EXEC_HOST_LOG_PATH"], sess_param))
      else:
        os.unlink("%s/%d.err" % (SESSION_MERGED["SERVICE_LOG_PATH"], sess_param))
    except EnvironmentError:
      pass
    else:
      if VERBOSE:
        if IS_WINDOWS: # use alternate path for windows execution hosts
          log("deleted %s/%d.err" % (SESSION_MERGED["WIN_S_EXEC_HOST_LOG_PATH"], sess_param))
        else:
          log("deleted %s/%d.err" % (SESSION_MERGED["SERVICE_LOG_PATH"], sess_param))
    sys.exit(0)

  elif sys.argv[1] == "notifyrestarted":
    if IS_WINDOWS:
      pass
    else:
      notify_restarted()
      sys.exit(0)

  elif sys.argv[1] == "erase_sessdir":
    pass # for anonymous sessions, no action needed

  elif sys.argv[1] == "set_viewtoken":
    # dispnum, viewtoken
    check_nargs(4)
    inputs["disp"] = int(sys.argv[2])
    inputs["viewtoken"] = validate_viewtoken(3)
    set_viewtoken(**inputs)

  else:
    raise InputError("Unknown command: %s" % " ".join(sys.argv))

# attempted conversion of alpha chars to int results in ValueError
except ValueError, e:
  log("Integer input expected for this command: '%s'" % " ".join(sys.argv))
  sys.exit(1)

except InputError, e:
  log("%s" % e)
  if VERBOSE:
    log_exc(e)
  sys.exit(1)

except MaxwellError, e:
  log("%s" % e)
  if VERBOSE:
    log_exc(e)
  sys.exit(2)

except Exception, e:
  log("%s" % e)
  if VERBOSE:
    log_exc(e)
  sys.exit(5)

if VERBOSE:
  log("done")
sys.exit(0)
