#!/usr/bin/python
# @package      hubzero-mw2-front-virtualssh
# @file         ssh_session.py
# @author       Pascal Meunier <pmeunier@purdue.edu>
# @copyright    Copyright (c) 2016-2017 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Based on prior work by Richard L. Kennell
#
# Copyright (c) 2016-2017 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

"""
Designed by Richard L. Kennell
rewritten and extended by Pascal Meunier

This runs on the webserver.
1. sshd configuration is crucial for this to work properly (see /etc/ssh/sshd_config).
-Execution hosts SSH back to run a forced command as www-data to notify of session events
-Regular users are forced into chrooted SFTP sessions
-Members of group mw-login are forced to run the command "/usr/bin/session".
-The forced command /usr/bin/session then calls ssh_session.py "$@".

2. /etc/sudoers allows everyone to invoke this (ssh_session.py) with sudo.
3. Without virtual SSH, /etc/ldap-auth.conf blocks logins with pam_filter.  With virtual SSH that filter needs to be removed so everyone can log in, but only if SSH is configured properly as above.

In effect, authorization checks are done here.
4. We determine if we're doing scp, sftp or accessing a workspace session,
depending on the arguments.
5. If it's sftp or scp we put the user in a chrooted environment.
6. If we're accessing a workspace, from the session we figure out the execution host, or exit.
7. We copy the precious private key for maxwell to a temporary directory.  We
   make it owned and readable only by the user.

8. We assume the identity of the user.
9. We forward the communication to root@host by using the maxwell key and starting the special
 client that accesses a container on the host through ssh.

  Security implications:
  i. Users logged in directly on the web server can send any command to root@execution host (they
     must be in the hub-login or www-data group).  Workspace users don't have this access.  A possible
     future improvement is to set up a separate SSH key pair used only for virtual SSH, restricted
     by a forced command on the execution host.
  ii. Normal users can't access the copy of the key with scp/sftp because they are chrooted into a
     different directory.  They can ssh only to execution host workstations, so they can't access
     the key directly to send arbitrary commands.  That's OK.
"""
import os
import sys
import signal
import time
import tempfile
import shutil
import re
import subprocess

# setup error.py as a symlink pointing to /usr/lib/mw/bin/error.py
from hubzero.mw.constants import HOST_K, CONFIG_FILE, LIB_PATH, \
     MW_USER, MYSQL_CONN_ATTEMPTS, VERBOSE, NAME_REGEXP, APP_K
from hubzero.mw.heartbeat import do_heartbeat
from hubzero.mw.mw_db import MW_DB
from hubzero.mw.user_account import User_account
from hubzero.mw.log import log, open_log, log_setid

# When True, many code paths below emit extra diagnostic log() lines.
DEBUG = False
# NOTE(review): LOG_ID is not referenced in the visible code.
LOG_ID = 'vssh'
# Working directory of the script; also stripped (with 'guest/') from the
# command shown in the log.
INSTALL_PATH = "/usr/bin/"
# Helper exec'd with the execution host name as its argument before connecting,
# to fetch that host's SSH host key if it is not already known.
ENSURE_PATH = "/usr/bin/ensure-known-host"

# Configuration file to talk to server inside containers
SSH_CONFIG_PATH = '/etc/mw-virtualssh/front_ssh_config'
# Executable ssh client script; its path starts the command string that is
# sent to root@<execution host> when connecting to a session.
CLIENT_PATH = '/usr/bin/vssh_exec_proxy'
# Prefix of the tool instance used for default workspace sessions
# (matched as LIKE 'workspace%' in the tool table).
DEFAULT_APP = 'workspace'
# Side effect at import time: the process working directory becomes INSTALL_PATH.
os.chdir(INSTALL_PATH)
VSSH_LOG_FILENAME = '/var/log/mw-virtualssh/front-vssh.log'

# Duplicates of the user's stdout/stderr file descriptors; assigned by the main
# program (os.dup(1) / os.dup(2)) before the log is opened.
orig_stdout = None
orig_stderr = None

# Interval handed to do_heartbeat() for refreshing the view's heartbeat.
REFRESH_INTERVAL = 60*60 # Refresh once per hour
#REFRESH_INTERVAL = 5 # DEBUG: Refresh every 5 seconds
# NOTE(review): view_is_done and need_to_refresh are not referenced in the
# visible code; possibly leftovers.
view_is_done = False
need_to_refresh = False
# Table-name prefix (used for the tool table and the <prefix>_users table).
# Expected to be set by CONFIG_FILE: concatenating the None default with a
# string would raise TypeError.
mysql_prefix = None

#=============================================================================
# Load default parameters...
#=============================================================================
default_domains = [""]
dns_retries = 10
mysql_host = ""
mysql_user = ""
mysql_password = ""
mysql_db = ""
# NOTE(review): DB_PARAMS below uses the MYSQL_CONN_ATTEMPTS constant, so a
# value assigned here or in CONFIG_FILE has no effect — confirm intent.
mysql_connect_attempts = 120
# Appended to session numbers (e.g. in the SESSION environment variable).
session_suffix = ""

#=============================================================================
# Load the configuration and override the variables above.
#=============================================================================
# CONFIG_FILE is executed as Python code in this module's global namespace
# (running as root), so it can override any name defined above.  A missing or
# unreadable file (IOError) is ignored and the defaults stay in effect.
try:
  execfile(CONFIG_FILE)
except IOError:
  pass
# Positional arguments for every MW_DB(*DB_PARAMS) connection in this script.
DB_PARAMS = (mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)


#=============================================================================
# Communicate, e.g., to launch a new workspace session.
#=============================================================================
def tell_maxwell(cmd, message):
  """Run the maxwell middleware client as MW_USER and wait for it to exit.

  cmd: argument list appended after HOST_K["MAXWELL_PATH"], e.g.
       ['start', 'user=<name>', 'ip=<addr>', 'app=<app>'] or ['stop', '<sessnum>'].
  message: written to the user's original stdout first, but only when the
       connection has a TTY (SSH_TTY is set); may be empty.

  The client runs in a forked child that drops privileges to MW_USER, works
  from LIB_PATH with stdin/stdout/stderr on /dev/null, and gets an
  environment scrubbed of the sudo variables.  The parent blocks until that
  child exits; the child's exit status is not examined.
  """
  if 'SSH_TTY' in os.environ:
    os.write(orig_stdout, message)
  pid = os.fork()
  if pid == 0:
    # Child: whatever happens below, it must never return into the caller's
    # code, which would otherwise carry on doing the parent's work here too.
    try:
      mwu = User_account(MW_USER)
      mwu.relinquish()
      os.chdir(LIB_PATH)
      fd = os.open("/dev/null", os.O_RDWR)
      os.dup2(fd, 0)
      os.dup2(fd, 1)
      os.dup2(fd, 2)
      # Clean up the environment so the middleware believes this is www-data.
      # pop() tolerates variables that are not set (del would raise KeyError).
      for e in ('SUDO_USER', 'SUDO_UID', 'SUDO_COMMAND'):
        os.environ.pop(e, None)
      os.environ['USER'] = MW_USER
      os.environ['MAIL'] = '/var/mail/%s' % MW_USER
      os.environ['USERNAME'] = MW_USER
      os.environ['LOGNAME'] = MW_USER
      os.environ['HOME'] = mwu.homedir
      os.execve(HOST_K["MAXWELL_PATH"],
        [HOST_K["MAXWELL_PATH"]] + cmd,
        os.environ)
    except Exception as err:
      log("Unable to run maxwell %s: %s" % (' '.join(cmd), err))
    finally:
      # Only reached if execve() did not replace this process.
      os._exit(1)
  else:
    os.waitpid(pid, 0)
  # It's possible to SSH into a session before it's ready.
  # It results in errors like "Unknown id: pmeunier" because the password file
  # has not been created yet inside the container.
  # This is more likely when the container is new and caching hasn't completed yet.
  # vsshd_start and vssh_exec_proxy will wait 1 second and retry 100 times, so
  # no extra delay is added here.

#=============================================================================
# Create a view entry for the SSH session.
#=============================================================================
def start_view(db, sessnum):
  """Record that this SSH connection is now viewing session `sessnum`.

  Touches session.accesstime, inserts a `view` row for the current user
  (global `u`) and client address (global `remoteip`), and returns the new
  view id obtained from LAST_INSERT_ID().

  Closes `db` before returning; the caller must not use it afterwards.
  """
  sessnum_param = str(sessnum)
  # Parameters are always passed as a sequence: DB-API drivers may treat a
  # bare string as a sequence of one-character parameters.
  db.c.execute("""UPDATE session SET accesstime=now()
                   WHERE sessnum = %s""", (sessnum_param,))
  db.c.execute("""
    INSERT INTO view(sessnum,username,remoteip,start,heartbeat)
    VALUES(%s, %s, %s,now(),now())""", (sessnum_param, u.user, remoteip))

  viewid = db.getsingle("""SELECT LAST_INSERT_ID()""", ())
  if VERBOSE:
    log("started view %d for session %d" % (viewid, int(sessnum)))
  db.close()
  return viewid

#=============================================================================
# Move the view table entry to viewlog.
#=============================================================================
def finish_view(viewid, starttime, sessnum):
  """Archive view `viewid` into `viewlog` and remove it from `view`.

  viewid: id returned by start_view().
  starttime: time.time() value taken when the view began; the elapsed time
    is stored as the viewlog duration and logged.
  sessnum: session number, used for the final log line only when the view
    row can no longer be found in the database.

  Also refreshes session.accesstime for the view's session.  Opens its own
  database connection and always closes it.
  """
  db = MW_DB(*DB_PARAMS)
  viewid_param = (str(viewid),)
  viewtime = time.time() - starttime
  try:
    db.c.execute("""
      INSERT INTO viewlog(sessnum,username,remoteip,time,duration)
      SELECT sessnum, username, remoteip, start, %s
      FROM view WHERE viewid=%s
    """, ("%f" % viewtime, str(viewid)))

    # Prefer the session number recorded on the view row itself.
    row_sessnum = db.getsingle("""SELECT sessnum from view where viewid=%s""", viewid_param)
    if row_sessnum is not None:
      sessnum = row_sessnum
      db.c.execute("""UPDATE session SET session.accesstime=NOW()
        WHERE sessnum=%s""", (str(sessnum),))

    db.c.execute("""DELETE FROM view WHERE viewid=%s""", viewid_param)
  finally:
    db.close()
  # %s rather than %d: sessnum may be a fallback value of any type.
  log("SSH into %s end %s@%s (%0.1f sec)" % (sessnum, u.user, remoteip, viewtime))

def wait_for_signals(viewid, pid, sessnum):
  """Keep view `viewid` alive via do_heartbeat(), then close it out.

  Blocks inside do_heartbeat(), which refreshes view.heartbeat every
  REFRESH_INTERVAL (presumably until the connection ends; see
  hubzero.mw.heartbeat).  Afterwards the view is archived with finish_view()
  and process `pid` (the ssh client) is sent SIGKILL; an OSError from kill()
  (e.g. the process is already gone) is ignored.
  """
  began_at = time.time()
  heartbeat_sql = """UPDATE view SET heartbeat=now() WHERE viewid=%s"""
  do_heartbeat(heartbeat_sql, str(viewid), DB_PARAMS, REFRESH_INTERVAL)

  finish_view(viewid, began_at, sessnum)
  try:
    os.kill(pid, signal.SIGKILL)
  except OSError:
    pass

def print_help():
  """Write the virtual SSH usage summary to the user's original stdout, then exit(0)."""
  help_chunks = (
    "Virtual SSH help:\n\n",
    "ssh    user@<hub>: create a session if none exists, or enter the first\n",
    "                   session found (interactive shell)\n\n",
    "ssh    user@<hub> help: this message.\n\n",
    "ssh    user@<hub> session: create a session if none exists, or enter the\n",
    "                           first session found (non-interactive shell)\n\n",
    "ssh -t user@<hub> session: same as above but get a command prompt\n",
    "                           (interactive shell).\n\n",
    "ssh -t user@<hub> session <command>: execute the command, if necessary\n",
    "                           creating a workspace.\n\n",
    "ssh    user@<hub> session list: provide a listing of your existing\n",
    "                                sessions (workspaces)\n\n",
    "ssh -t user@<hub> session <session #>: Access session # (interactive).\n\n",
    "ssh    user@<hub> session <session #> <command>: Access session # and\n",
    "                                                 execute command.\n\n",
    "ssh -t user@<hub> session create <session title>: create a new session\n",
    "                                              with the specified name.\n\n",
    "ssh -t user@<hub> session start: start a new session.\n\n",
    "ssh -t user@<hub> session stop <session #>: stop that session.\n\n",
  )
  # One write per chunk, in order, exactly as the help text is laid out above.
  for chunk in help_chunks:
    os.write(orig_stdout, chunk)
  sys.exit(0)

def ssh_flags(environ=None):
  """Build ssh client flags that mirror features of the incoming connection.

  environ: mapping of environment variables to inspect; defaults to os.environ.

  Returns a new list: '-t' when SSH_TTY is set (the user has a terminal),
  otherwise '-T'; then '-A' when SSH_AUTH_SOCK is set (agent forwarding);
  then '-X', '-Y' when DISPLAY is set (X11 forwarding).
  """
  if environ is None:
    environ = os.environ
  # `in` instead of the deprecated dict.has_key() (removed in Python 3).
  flags = ['-t' if 'SSH_TTY' in environ else '-T']
  if 'SSH_AUTH_SOCK' in environ:
    flags.append('-A')
  if 'DISPLAY' in environ:
    flags.extend(['-X', '-Y'])
  if DEBUG:
    log("flags are %s" % flags)
  return flags

def _ensure_sftp_passwd_entry():
  """Make sure /sftp/etc/passwd contains the current user's entry (global `u`).

  Creates /sftp/etc if needed and makes an existing passwd file readable by
  others.  Try to avoid writing to the file as much as possible, because
  there's a race condition between users: if a line for this user already
  matches u.passwd_entry(), nothing is written.  If the user's line differs,
  the whole file is deleted and rebuilt starting with this entry (other users
  re-add theirs on their next login).
  """
  import stat
  try:
    os.mkdir("/sftp/etc")
  except EnvironmentError:
    pass
  user_entry_found = False
  try:
    pwdfile = open("/sftp/etc/passwd", 'r')
  except IOError:
    # we'll create the file when appending
    pass
  else:
    try:
      # verify that it's readable by others (O_NOFOLLOW: refuse a symlink).
      fd = os.open('/sftp/etc/passwd', os.O_NOFOLLOW | os.O_NOATIME)
      try:
        pw_mode = os.fstat(fd)[stat.ST_MODE]
        if pw_mode & stat.S_IROTH == 0:
          os.fchmod(fd, pw_mode | stat.S_IROTH)
          log('fixed /sftp/etc/passwd not readable by others')
      finally:
        os.close(fd)
      for line in pwdfile:
        # Compare the login-name field exactly: a substring test would let
        # "bob" match the line of "bobby" and wrongly discard the file.
        if line.split(':', 1)[0] != u.user:
          continue
        if line.find(u.passwd_entry()) == -1:
          # something changed, delete password file and start over
          # there is a race here with other scp/sftp attempts
          # that may be trying to append a password entry
          os.unlink("/sftp/etc/passwd")
        else:
          user_entry_found = True
        break
    finally:
      pwdfile.close()
  if not user_entry_found:
    log("Creating user entry in /sftp/etc/passwd")
    os.umask(0o022)
    # Must have a return at the end of the line
    with open("/sftp/etc/passwd", "a") as pw_out:
      pw_out.write(u.passwd_entry() + '\n')
    os.umask(0o027)


def do_sftp(executable):
  """Run `executable` (sftp-server or scp) for the user inside the /sftp chroot.

  After ensuring the user's entry in /sftp/etc/passwd, forks.  The child
  initializes the user's groups, chroots into /sftp, drops privileges, moves
  to the user's home directory and execs `executable` with the global `args`
  as its argv, attached to the user's original stdout/stderr.  The parent
  waits for the child and exits with the child's exit status (1 if it was
  killed by a signal), so scp/sftp clients can see failures.  Never returns.
  """
  if DEBUG:
    log("Forcing %s for %s@%s" % (' '.join(args), u.user, remoteip))
  _ensure_sftp_passwd_entry()
  pid = os.fork()
  if pid == 0:
    # Must invoke initgroups before chroot.  Otherwise access to NSS goes away.
    u.initgroups()
    try:
      os.chroot("/sftp")
    except OSError:
      log("Unable to chroot")
      sys.exit(1)

    u.relinquish()
    if DEBUG:
      log("chdir into %s" % u.homedir)
    try:
      os.chdir(u.homedir)
    except OSError:
      log("Unable to chdir into %s" % u.homedir)
      sys.exit(1)

    # Use this for debugging. (ssh username@host internal-sftp)
    #os.execve("/usr/bin/strace", ['strace','/usr/lib/sftp-server'], os.environ)
    # Keep the current fd 1 so it can be restored for logging if exec fails.
    log_stdout = os.dup(1)
    os.dup2(orig_stdout, 1)
    os.dup2(orig_stderr, 2)
    os.umask(0o027)
    try:
      os.execve(executable, args, os.environ)
    except OSError:
      os.dup2(log_stdout, 1)
      os.dup2(log_stdout, 2)
      log("Unable to exec %s" % executable)
    # Only reached when exec failed.
    sys.exit(1)

  # Parent.  The 1 second pause predates this rewrite; its reason is not documented.
  time.sleep(1)
  _, status = os.waitpid(pid, 0)
  if os.WIFEXITED(status):
    sys.exit(os.WEXITSTATUS(status))
  sys.exit(1)

def get_sessnum(arg, suffix=None):
  """Parse a session number from a command-line word.

  Accepts a plain integer ("123") or, when the session suffix is non-empty,
  an integer immediately followed by that suffix ("123x" for suffix "x").

  arg: the word to parse.
  suffix: accepted session suffix; defaults to the configured `session_suffix`.
  Returns the session number, or 0 when `arg` is not a session number.
  """
  if suffix is None:
    suffix = session_suffix
  try:
    return int(arg)
  except ValueError:
    pass
  # Inspect `arg` itself (not the global argv) and strip the whole suffix,
  # which may be longer than one character.
  if suffix and arg.endswith(suffix):
    try:
      return int(arg[:-len(suffix)])
    except ValueError:
      pass
  return 0

def get_appname(db):
  """Return the newest published tool instance whose name starts with DEFAULT_APP.

  Queries the tool table (mysql_prefix + APP_K["TOOL_TABLE"]) for rows with
  state = 1 and takes the highest instance name.  Logs and exits with
  status 1 when no such instance exists.
  """
  tool_table = mysql_prefix + APP_K["TOOL_TABLE"]
  # '%%' becomes a literal '%' (LIKE wildcard) once the driver formats the query.
  query = ("\n    SELECT instance FROM " + tool_table +
           "\n    WHERE instance LIKE '" + DEFAULT_APP + "%%'" +
           "\n    AND state = 1 ORDER BY instance DESC LIMIT 1")
  instance = db.getsingle(query, ())
  if instance is not None:
    return instance
  log("ERROR: Unable to find the name of the application to start for %s@%s;  default is %s" %
      (u.user, remoteip, DEFAULT_APP))
  sys.exit(1)

def get_revision(db, app):
  """Return the tool instance to start for the user-requested application `app`.

  First looks for the highest published (state = 1) instance whose name
  starts with `app`; failing that, accepts an instance named exactly `app`
  regardless of state.  Logs and exits with status 1 when neither exists.

  `app` comes from the user's command line, so it is passed to the database
  as a query parameter instead of being spliced into the SQL text.
  """
  tool_table = mysql_prefix + APP_K["TOOL_TABLE"]
  appname = db.getsingle("""
    SELECT instance FROM """ + tool_table + """
    WHERE instance LIKE %s
    AND state = 1 ORDER BY instance DESC LIMIT 1""", (app + '%',))
  if appname is None:
    # try again without the published restriction but require the exact revision number
    appname = db.getsingle("""
      SELECT instance FROM """ + tool_table + """
      WHERE instance =%s
      ORDER BY instance DESC LIMIT 1""", (app,))
  if appname is None:
    log("ERROR: Unable to find the name of the application to start for %s@%s;  requested %s" %
    (u.user, remoteip, app))
    sys.exit(1)
  return appname

def _list_sessions(db):
  """Print the user's sessions (number, default marker, app, title) and exit(0).

  The first session whose appname starts with DEFAULT_APP gets the '*'
  marker.  Closes `db`.
  """
  # username is trusted
  sessions = db.getall("""
    SELECT sessnum,appname,sessname FROM session
    WHERE username=%s ORDER BY sessnum""", (u.user,))
  db.close()
  if not sessions:
    os.write(orig_stdout, "No sessions are active.\n")
    sys.exit(0)
  os.write(orig_stdout, "%8s %8s %-20s %s\n" % ("Number", "Default", "Name", "Title"))
  selected = False
  for (sessnum, appname, sessname) in sessions:
    if not selected and appname.startswith(DEFAULT_APP):
      selected = True
      isdefault = '*   '
    else:
      isdefault = ' '
    os.write(orig_stdout, "%8s %8s %-20s %s\n" % (str(sessnum), isdefault, appname, sessname))
  sys.exit(0)


def _create_session(db, sess_args):
  """Handle "session create [title [app]]": start a session without attaching, then exit.

  sess_args[1] (optional) is the title; it must match NAME_REGEXP and becomes
  the session's sessname.  sess_args[2] (optional) selects the application
  through get_revision(), but only for members of the 'apps' group; otherwise
  get_appname() supplies the default workspace app.  Writes the new session
  number to the user's stdout.  Refuses (exit 0) when the user already has
  more than 3 sessions.
  """
  if len(sess_args) >= 2 and re.match(NAME_REGEXP, sess_args[1]) is None:
    log("Input validation failed for create session title '%s'" % sess_args[1])
    sys.exit(1)
  if len(sess_args) >= 3 and 'apps' in u.groups():
    # restrict to members of the apps group
    if re.match(NAME_REGEXP, sess_args[2]) is None:
      log("Input validation failed for application name '%s'" % sess_args[2])
      sys.exit(1)
    appname = get_revision(db, sess_args[2])
  else:
    if DEBUG and len(sess_args) >= 2:
      log("groups: %s" % (u.groups(),))
    appname = get_appname(db)

  sesscount = db.getsingle("""SELECT count(*) FROM session WHERE username=%s""", (u.user,))
  if sesscount > 3:
    log("user '%s' has too many sessions to start a new one" % u.user)
    os.write(orig_stdout, "Error: You already have %s sessions\n" % sesscount)
    sys.exit(0)
  # scripted session start, i.e., batch mode
  tell_maxwell(['start', 'user=' + u.user, 'ip=' + remoteip, 'app=' + appname], "")
  sessnum = db.getsingle("""
    SELECT sessnum FROM session
    WHERE username=%s AND appname=%s
    ORDER BY sessnum DESC LIMIT 1""", (u.user, appname))
  if sessnum is None:
    log("ERROR: Unable to create a workspace session for %s@%s." % (u.user, remoteip))
    log("Unable to create a workspace session.")
    sys.exit(1)
  os.write(orig_stdout, "%s\n" % sessnum)
  if len(sess_args) >= 2:
    # give the new session the requested title
    db.c.execute("""
      UPDATE session set sessname = %s
      WHERE sessnum = %s""", (sess_args[1], str(sessnum)))
  sys.exit(0)


def _stop_session(db, sess_args):
  """Handle "session stop <number|title>": stop one of the user's sessions, then exit.

  The session is identified by number (see get_sessnum) or, failing that, by
  title; either way it must belong to the current user.  Exits 0 after asking
  maxwell to stop it, 1 when it cannot be identified.
  """
  if DEBUG:
    log("stop command")
  if len(sess_args) < 2:
    log("ERROR: stop requested without a session for %s@%s" % (u.user, remoteip))
    os.write(orig_stderr, "Usage: session stop <session #>\n")
    sys.exit(1)
  target = sess_args[1]
  sessnum = get_sessnum(target)
  if sessnum != 0:
    row = db.getrow("""
      SELECT sessnum FROM session
      WHERE username=%s AND sessnum=%s""", (u.user, str(sessnum)))
    if row is None:
      log("ERROR: Unauthorized attempt to stop session %d by %s@%s" %
        (sessnum, u.user, remoteip))
      os.write(orig_stderr, "No such session %d.\n" % sessnum)
      sys.exit(1)
    tell_maxwell(['stop', '%d' % sessnum], "stopping session %d\n" % sessnum)
    sys.exit(0)

  if DEBUG:
    log("check if it's a title instead of a session number: %s" % target)
  if re.match(NAME_REGEXP, target) is None:
    log("Input validation failed for session title '%s'" % target)
    sys.exit(1)
  sessnum = db.getsingle("""
    SELECT sessnum
    FROM session WHERE username=%s AND sessname=%s""", (u.user, target))
  if sessnum is None:
    log("ERROR: No such session '%s' for %s@%s" % (target, u.user, remoteip))
    os.write(orig_stdout, "No such session.\n")
    sys.exit(1)
  if DEBUG:
    log("stopping session %s which has number %s" % (target, sessnum))
  tell_maxwell(['stop', '%d' % sessnum], "stopping session %s\n" % sessnum)
  sys.exit(0)


def _find_requested_session(db, sess_args):
  """Resolve sess_args[0] as one of the user's sessions, by number or by title.

  Returns (row, remaining_args), row being (sessnum, username, exechost,
  dispnum).  A session number the user does not own is fatal (exit 1).  When
  sess_args[0] names no session, returns (None, sess_args) unchanged so the
  words are run as a command in the default session.
  """
  sessnum = get_sessnum(sess_args[0])
  if sessnum != 0:
    row = db.getrow("""
      SELECT sessnum,username,exechost,dispnum
      FROM session WHERE username=%s AND sessnum=%s""", (u.user, str(sessnum)))
    if row is None:
      log("ERROR: No such session '%s' for %s@%s" % (sess_args[0], u.user, remoteip))
      os.write(orig_stdout, "No such session.\n")
      sys.exit(1)
    return (row, sess_args[1:])
  if DEBUG:
    log("last chance: check if it's a title instead of a session number")
  if re.match(NAME_REGEXP, sess_args[0]) is not None:
    row = db.getrow("""
      SELECT sessnum,username,exechost,dispnum
      FROM session WHERE username=%s AND sessname=%s""", (u.user, sess_args[0]))
    if row is not None:
      return (row, sess_args[1:])
  return (None, sess_args)


def _find_or_start_default_session(db):
  """Return the user's default workspace session row, starting one if none exists.

  Only sessions running the default workspace app (get_appname()) qualify,
  not just any session; otherwise people can unexpectedly end up in a
  different container version and be surprised by different version numbers
  of software (ticket 264210 on nanoHUB).  Exits 1 if no session can be started.
  """
  appname = get_appname(db)
  query = """
    SELECT sessnum,username,exechost,dispnum FROM session
    WHERE username=%s AND appname=%s"""
  row = db.getrow(query, (u.user, appname))
  if row is None:
    # None found, attempt to start a new one
    tell_maxwell(['start', 'user='+u.user, 'ip='+remoteip, 'app='+appname], "Starting session\n")
    row = db.getrow(query, (u.user, appname))
    if row is None:
      log("ERROR: Unable to create a workspace session for %s@%s." % (u.user, remoteip))
      log("Unable to create a workspace session.")
      sys.exit(1)
  return row


def _connect_to_session(db, row, sess_args):
  """SSH into the container of session `row`, running `sess_args` as a command if any.

  row: (sessnum, username, exechost, dispnum); must belong to the current
    user and have an execution host and a display number.
  Copies the maxwell private key into a private temporary directory owned by
  the user, drops to the user's identity, records a view (start_view closes
  `db`), runs ssh to root@exechost in a child, and waits through
  wait_for_signals().  Never returns.
  """
  (sessnum, session_username, exechost, dispnum) = row
  try:
    sessnum = int(sessnum)
    dispnum = int(dispnum)
  except (ValueError, TypeError):
    # TypeError: NULL columns come back as None.
    sessnum = 0
    dispnum = 0
    exechost = ''

  if session_username != u.user or not exechost or dispnum == 0:
    log("ERROR: Unauthorized attempt to ssh to session %d for %s@%s" % (sessnum, u.user, remoteip))
    os.write(orig_stderr, "No such session %d.\n" % sessnum)
    sys.exit(1)

  command = "%s %d %s" % (CLIENT_PATH, dispnum, u.user)
  rundir = tempfile.mkdtemp("", "ssh-passthru-")

  if sess_args:
    # NOTE(review): these words come from the user and are joined into one
    # string that the remote sshd hands to root's shell on the execution host
    # (unless that host restricts the key with a forced command), so shell
    # metacharacters in them are interpreted there before reaching
    # CLIENT_PATH.  Quoting each word (pipes.quote) would close this, but may
    # change what CLIENT_PATH receives; confirm against vssh_exec_proxy.
    command += ' ' + ' '.join(sess_args)
  elif got_session:
    # Do not print a message when passing a command because this is used for
    # scripted commands (e.g. with git) that don't handle them.  Only print it
    # when "session" or "session 12345" is specified by the user.
    os.write(orig_stdout,
         "Accessing session %d.  Use ssh -t to get a command prompt.\n" % sessnum)

  # Make sure the host is known.
  pid = os.fork()
  if pid == 0:
    # Get host key of execution host if we don't have it already, and make it
    # available to all users in one place.
    try:
      os.execv(ENSURE_PATH, ["ensure-known-host", exechost])
    except OSError as err:
      log("Unable to exec %s: %s" % (ENSURE_PATH, err))
    finally:
      os._exit(1)
  try:
    os.waitpid(pid, 0)
  except OSError:
    pass

  #============================================================================
  # Switch to the runtime directory and set up its key.
  # Everyone gets a copy of the maxwell private key...
  # Anyone who can login directly on web server can get the middleware's private key
  # by using Virtual SSH...
  # It would be better to use a different key pair generated on the fly for each
  # user and use static NAT on the execution host to SSH directly into the tool container
  #============================================================================
  os.chdir(rundir)
  (idfd, idfile) = tempfile.mkstemp("", "id-", rundir)
  # Close our descriptor so it is not inherited by the ssh child.
  os.close(idfd)
  shutil.copyfile(HOST_K["KEY_PATH"], idfile)
  os.chmod(idfile, 0o600)
  os.chown(idfile, u.uid, u.gid)
  os.chmod(rundir, 0o700)
  # chown dir as the last operation, no more operations as root after that
  # this assumes we're not running a /tmp cleaner, and rundir has not been deleted by it
  # and replaced by a symlink, etc...
  os.chown(rundir, u.uid, u.gid)

  #============================================================================
  # Change user.  Everything below happens as the requesting user.
  #============================================================================
  u.initgroups()
  u.relinquish()

  log("SSH into %s start %s@%s %s" % (sessnum, u.user, remoteip, exechost))
  log("SSH into %s %s %s" % (sessnum, idfile, command.replace(INSTALL_PATH+'guest/', '')))

  view_id = start_view(db, sessnum)  # start_view() will close the database.
  pid = os.fork()
  if pid == 0:
    try:
      session_name = str(sessnum) + session_suffix
      os.environ['SESSION'] = session_name
      os.environ['SESSIONDIR'] = u.homedir + '/data/sessions/' + session_name
      os.environ['RESULTSDIR'] = u.homedir + '/data/results/' + session_name
      os.dup2(orig_stdout, 1)
      os.dup2(orig_stderr, 2)
      os.execve("/usr/bin/ssh",
        ['ssh', '-F', SSH_CONFIG_PATH, '-i', idfile] + ssh_flags() + ['root@' + exechost, command],
        os.environ)
    except Exception as e:
      log("Exception: %s" % e)
    sys.exit(1)

  wait_for_signals(view_id, pid, sessnum)
  shutil.rmtree(rundir)
  if 'DISPLAY' in os.environ:
    xauthdir = '/tmp/' + u.user + '-' + os.environ['DISPLAY'].replace('$', '_')
    try:
      shutil.rmtree(xauthdir)
    except (OSError, shutil.Error):
      if VERBOSE:
        log("Unable to unlink " + xauthdir)
  sys.exit(0)


def manage_session(sess_args):
  """Dispatch a "session" request for the current user (global `u`).  Never returns.

  sess_args: the words after "session" (or the whole command when "session"
  was omitted).  Recognized forms:
    help | list | create [title [app]] | stop <number|title>
    start [command...]
    <number|title> [command...]
    [command...]
  help, list, create and stop exit after doing their work.  The other forms
  pick a session (newly started, named explicitly, or the user's default
  workspace, started if needed) and open an SSH connection into it, running
  any remaining words as a command.  Exits 1 if the user is not a hub user.
  """
  if DEBUG and sess_args:
    log("argument 0: %s" % sess_args[0])
  if DEBUG:
    log("looking up %s in %s_users" % (u.user, mysql_prefix))
  db = MW_DB(*DB_PARAMS)
  user_row = db.getrow(
    "SELECT username FROM " + mysql_prefix + """_users
    WHERE username=%s""", (str(u.user),))
  if user_row is None:
    log("ERROR: No such HUB user (%s).\n" % (u.user))
    # OK to give out that much information because user has already authenticated.
    # Therefore, we're not revealing user names to third parties.
    os.write(orig_stdout, "No such HUB user.\n")
    sys.exit(1)

  row = None
  if sess_args:
    if VERBOSE:
      log("Processing session command '%s' for user '%s'" % (" ".join(sess_args), str(u.user)))
    verb = sess_args[0]
    if verb in ("help", '--help', '-h'):
      print_help()  # exits
    elif verb in ('list', '--list', '-l'):
      _list_sessions(db)  # exits
    elif verb in ('create', '-c'):
      _create_session(db, sess_args)  # exits
    elif verb in ("stop", '--stop'):
      _stop_session(db, sess_args)  # exits
    elif verb in ("start", '--start', '-s'):
      # interactive session start; any remaining words become the command
      appname = get_appname(db)
      tell_maxwell(['start', 'user='+u.user, 'ip='+remoteip, 'app='+appname], "Starting session\n")
      row = db.getrow("""
        SELECT sessnum,username,exechost,dispnum FROM session
        WHERE username=%s AND appname=%s
        ORDER BY sessnum DESC LIMIT 1""", (u.user, appname))
      if row is None:
        log("ERROR: Unable to create a workspace session for %s@%s." % (u.user, remoteip))
        log("Unable to create a workspace session.")
        sys.exit(1)
      sess_args = sess_args[1:]
    else:
      # If the first argument names one of the user's sessions, use it and shift it off.
      (row, sess_args) = _find_requested_session(db, sess_args)

  if row is None:
    row = _find_or_start_default_session(db)

  _connect_to_session(db, row, sess_args)

#=============================================================================
# Main program...
#
# We recognize three types of commands:
#   session <sessnum> <command> ...
#   session <command> ...
#   session
#
#=============================================================================

#
# Entry point.  Runs as root through sudo on behalf of the user named by
# SUDO_USER.  The module globals assigned below (orig_stdout, orig_stderr, u,
# uid, remoteip, got_session, args) are read by the functions defined above.
#
if 0 != os.geteuid():
  # verify assumption
  print("Must be run as root")
  sys.exit(1)

# Keep copies of the user's stdout/stderr; child processes are re-attached to them.
orig_stdout = os.dup(1)
orig_stderr = os.dup(2)
# use open_log instead of setup_log because setup_log messes with stdout and stderr
# The purpose of this program is to redirect stdin and stdout to tool sessions, so it's incompatible with setup_log
open_log(VSSH_LOG_FILENAME)

try:
  u = User_account(os.environ["SUDO_USER"])
except KeyError:
  log("Cannot find the SUDO_USER variable.  Bailing out.")
  os.write(orig_stdout, "No SUDO_USER variable.\n")
  sys.exit(1)

log_setid(u.user)
if DEBUG:
  log("Starting processing")

try:
  uid = int(os.environ["SUDO_UID"])
except KeyError:
  log("Cannot find the SUDO_UID variable.  Bailing out.")
  sys.exit(1)
except ValueError:
  log("SUDO_UID variable is not an integer.  Bailing out.")
  sys.exit(1)

if DEBUG:
  log("uid %d, user name %s and userid %d" % (uid, os.environ["SUDO_USER"], u.uid))

# SUDO_USER and SUDO_UID must describe the same account.
if u.uid != uid:
  log("Mismatched SUDO_USER (%s) and SUDO_UID (%s) variables (expected uid %d)." %
      (os.environ["SUDO_USER"], os.environ["SUDO_UID"], u.uid))
  sys.exit(1)

#
# Find the IP address of the client: first field of SSH_CLIENT, else of
# SSH_CONNECTION.  IndexError covers a variable that is set but empty.
#
try:
  remoteip = os.environ['SSH_CLIENT'].split()[0]
except (KeyError, TypeError, IndexError):
  try:
    remoteip = os.environ['SSH_CONNECTION'].split()[0]
  except (KeyError, TypeError, IndexError):
    log("Unable to determine client IP address.")
    remoteip = '0.0.0.0'

#=============================================================================
# Assemble the user command.
#=============================================================================
got_session = False
args = sys.argv[1:]
if DEBUG:
  log("args: " + ' '.join(args))

# If we don't have any arguments, try to pull them from SSH_ORIGINAL_COMMAND
# (split on whitespace; quoting is not interpreted here).
if not args and 'SSH_ORIGINAL_COMMAND' in os.environ:
  args = os.environ['SSH_ORIGINAL_COMMAND'].split()
  if DEBUG:
    log("args from environment: " + ' '.join(args))

# Each handler below exits the process.  Anything unrecognized, including no
# arguments at all, is handed to manage_session() as a command to run.
if args:
  if args[0] == "help":
    print_help()  # exits

  if args[0] == 'session':
    # shift it off
    got_session = True
    manage_session(args[1:])  # exits

  if args[0] == "internal-sftp":
    # do_sftp() execs the server with the global args as its argv.
    args[0] = "sftp-server"
    do_sftp("/usr/lib/sftp-server")  # exits

  if args[0] == "scp":
    do_sftp("/usr/bin/scp")  # exits

manage_session(args)

