#!/usr/bin/env python
#
# Copyright 2006-2009 by Purdue University.
# All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

from OpenSSL import SSL
import sys
import os
import socket
import MySQLdb
import time
import signal

# ---------------------------------------------------------------------------
# Built-in defaults.  The configuration file loaded further down may replace
# any of these names.
# ---------------------------------------------------------------------------

# Network: the TCP port the proxy listens on, and the base that is added to
# a session's display number to obtain the port the proxy connects to.
vnc_listen_port = 8080
ssl_vnc_portbase = 4000

# Process control: a non-zero `foreground` keeps the program attached to the
# terminal (no log redirection, no daemonizing).  `logfile` is the stream
# that log() writes to; it is assigned by openlog().
foreground = 0
pidfile = "/var/run/vncproxy.pid"
logfile = None

# Database: connection parameters handed to MySQLdb.connect(), plus the
# module-wide connection handle shared by translate() and do_proxy_vnc().
mysql_host = None
mysql_user = None
mysql_password = None
mysql_db = None
db = None

#=============================================================================
# Load the configuration and override the variables above.
#=============================================================================
# The configuration file is executed as Python source in this module's
# namespace, so every assignment it makes replaces the default of the same
# name defined above.
# NOTE(review): because the file is executed rather than parsed, it must be
# writable only by trusted administrators.
try:
  execfile('/etc/hubzero/maxwell.conf')
except IOError:
  # A missing or unreadable file is not an error: the defaults stay in force.
  pass

#=============================================================================
# Set up errors to go to the log file.
#=============================================================================
def openlog(logfile_name):
  """Point the process's standard streams at the log file.

  In foreground mode nothing is redirected and log output simply goes to
  stderr.  Otherwise stdin is re-opened on /dev/null, every descriptor from
  3 to 1023 is closed, and stdout and stderr are both redirected to
  `logfile_name` (created with mode 0644 if needed, opened for append).

  If the log file cannot be opened, a notice is written to stdout and the
  global `foreground` flag is set to 1.

  Parameters:
    logfile_name -- path of the log file to append to.
  """
  global logfile
  global foreground

  # log() always writes to sys.stderr; the redirection below only changes
  # which file descriptor 2 refers to.
  logfile = sys.stderr

  if foreground:
    return

  # Detach stdin.  Best effort: a failure here is not fatal.
  try:
    null_fd = os.open("/dev/null", os.O_RDONLY)
    os.dup2(null_fd, 0)
  except OSError:
    pass

  # Drop everything inherited beyond stdin/stdout/stderr (this also closes
  # the /dev/null descriptor opened above).  Most of these are not open,
  # so EBADF is expected and ignored.
  for fd in range(3, 1024):
    try:
      os.close(fd)
    except OSError:
      pass

  try:
    log_fd = os.open(logfile_name, os.O_CREAT|os.O_APPEND|os.O_WRONLY, 0o644)
    os.dup2(log_fd, 1)
    os.dup2(log_fd, 2)
    # Only close the original descriptor when it is not itself one of the
    # standard streams; otherwise we would close the stream just set up.
    if log_fd > 2:
      os.close(log_fd)
    return
  except OSError:
    pass

  sys.stdout.write("Logfile open failed.  Remaining in foreground.\n")
  foreground = 1

#=============================================================================
# Log a message.
#=============================================================================
def log(msg):
  """Write one line to the log stream and flush it immediately.

  When the program is not in foreground mode the line is prefixed with the
  current time in square brackets; in foreground mode the message is
  written as-is.

  Parameters:
    msg -- the text to log (a newline is appended).
  """
  prefix = ""
  if not foreground:
    prefix = "[" + time.asctime() + "] "
  logfile.write(prefix + msg + "\n")
  logfile.flush()

#=============================================================================
# Create database connection.
#=============================================================================
def db_connect():
  """Open a MySQL connection, retrying until one is established.

  Uses the module-level mysql_host, mysql_user, mysql_password and mysql_db
  settings.  After each failed attempt the function sleeps, starting at one
  second and doubling up to a ceiling of one hour, then tries again.

  Returns:
    the connection object returned by MySQLdb.connect().
  """
  delay = 1
  maxdelay = 3600
  while 1:
    try:
      return MySQLdb.connect(host=mysql_host, user=mysql_user,
                             passwd=mysql_password, db=mysql_db)
    except Exception:
      # Deliberately not a bare "except:": SystemExit raised by one of this
      # program's signal handlers (or KeyboardInterrupt) must be able to
      # break out of the retry loop instead of being logged and retried.
      log("Exception in db_connect: %s" % sys.exc_info()[1])
    time.sleep(delay)
    delay = min(delay * 2, maxdelay)

#=============================================================================
# MySQL helpers
#=============================================================================
def mysql(c, cmd, args=None):
  """Execute a query and return all of its result rows.

  Errors are logged rather than raised; in that case an empty tuple is
  returned, so callers cannot tell a failed query from an empty result.

  Parameters:
    c    -- an open database cursor.
    cmd  -- the SQL text to execute.
    args -- optional parameters bound by the driver to placeholders in
            `cmd`.  Prefer this over building SQL with string formatting.

  Returns:
    whatever c.fetchall() returns, or () on any error.
  """
  try:
    if args is None:
      c.execute(cmd)
    else:
      c.execute(cmd, args)
    return c.fetchall()
  except MySQLdb.MySQLError as e:
    # Not every MySQLdb error carries an (errno, message) pair, so take the
    # last argument instead of unpacking exactly two.
    if e.args:
      expl = e.args[-1]
    else:
      expl = e
    log("%s.  SQL was: %s" % (expl, cmd))
    return ()
  except Exception:
    log("Some other MySQL exception.")
    return ()

def mysql_act(c, cmd, args=None):
  """Execute a statement whose result rows are not needed.

  Parameters:
    c    -- an open database cursor.
    cmd  -- the SQL text to execute.
    args -- optional parameters bound by the driver to placeholders in
            `cmd`.  Prefer this over building SQL with string formatting.

  Returns:
    "" on success, otherwise a non-empty description of the MySQL error.
    Exceptions that are not MySQLdb errors propagate to the caller.
  """
  try:
    if args is None:
      c.execute(cmd)
    else:
      c.execute(cmd, args)
    return ""
  except MySQLdb.MySQLError as e:
    # Not every MySQLdb error carries an (errno, message) pair, so take the
    # last argument instead of unpacking exactly two.
    if e.args:
      expl = "%s" % (e.args[-1],)
    else:
      expl = ""
    # Callers test the result against "", so never report a failure with an
    # empty string.
    return expl or "Unknown MySQL error."


#=========================================================================
# Update the heartbeat indicator of the view.
#=========================================================================
def refresh_heartbeat(viewid,interval):
  """Stamp the view's heartbeat column with the current time and re-arm
  the alarm.

  A fresh database connection is opened for the update and closed again
  afterwards.

  Parameters:
    viewid   -- integer id of the row in the `view` table.
    interval -- seconds until the next SIGALRM should be delivered.
  """
  log("Refreshing heartbeat for view %d" % viewid)
  conn = db_connect()
  cursor = conn.cursor()
  # The result is of no interest; mysql() logs any failure itself.
  mysql(cursor,"""UPDATE view SET heartbeat=now()
                   WHERE viewid='%d'
                """ % viewid)
  conn.close()
  signal.alarm(interval)

#=========================================================================
# The child has exited.  End the view.
#=========================================================================
def end_view(viewid, starttime, sessnum, username, remoteip):
  """Record the end of a view in the database and terminate the process.

  Waits for a child process, copies the view into `viewlog` together with
  its duration, updates the session's access time, deletes the `view` row,
  logs a summary line and exits with status 0.  Database errors are logged
  and do not stop the remaining steps.

  Parameters:
    viewid    -- integer id of the row in the `view` table.
    starttime -- time.time() value taken when the view began.
    sessnum   -- session number (not used here).
    username  -- user name, used only in the log line.
    remoteip  -- client address, used only in the log line.
  """
  # Collect the child's exit status; having no child to wait for is fine.
  try:
    os.wait()
  except OSError:
    pass

  endtime = time.time()
  conn = db_connect()
  cursor = conn.cursor()

  statements = (
    """
    INSERT INTO viewlog(sessnum,username,remoteip,time,duration)
    SELECT sessnum, username, remoteip, start, %f
    FROM view WHERE viewid='%d'
  """ % ((endtime-starttime), viewid),
    """UPDATE session JOIN view ON session.sessnum=view.sessnum
                       SET session.accesstime=now()
                       WHERE viewid='%d'""" % viewid,
    """DELETE FROM view WHERE viewid='%d'""" % viewid,
  )
  # Run the statements in order; a failure is reported but not fatal.
  for statement in statements:
    problem = mysql_act(cursor, statement)
    if problem != "":
      log("ERROR: %s" % problem)

  conn.close()
  log("View %d (%s@%s) ended after %f seconds." % (viewid,username,remoteip,endtime-starttime))
  sys.exit(0)

#=========================================================================
# forward a VNC connection
#=========================================================================
def do_proxy_vnc(connect,sessnum,username,remoteip,timeout):
  """Relay the client connection on stdin/stdout to a VNC server.

  Steps, in order:
    1. Using the already-open global `db` connection, update the session's
       access time, insert a row into `view`, read back its id, and close
       the connection.
    2. Fork once and let the original process exit, then become a session
       leader.
    3. Install the signal handlers, send the HTTP "200 Connection
       Established" reply to the client, and fork `socat` to shuttle bytes
       between stdin/stdout and `connect`.
    4. Sleep in signal.pause(): SIGALRM refreshes the view's heartbeat;
       SIGCHLD, SIGHUP and SIGTERM end the view via end_view(), which
       exits the process.

  This function never returns.

  Parameters:
    connect  -- "host:port" string handed to socat as the tcp4 target.
    sessnum  -- integer session number.
    username -- user the view is recorded for.
    remoteip -- address of the connecting client.
    timeout  -- session timeout in seconds; the heartbeat is refreshed
                every 40% of it.
  """
  global db

  #
  # Create an entry in the database to mark the time this view began.
  # Then close the database.
  #
  starttime = time.time()
  c = db.cursor()
  err = mysql_act(c,"""UPDATE session SET accesstime=now()
                   WHERE sessnum = %d""" % (sessnum))
  if err != "":
    log("ERROR: %s" % err)

  # The two string values are escaped so that a quote character in either
  # cannot break out of the SQL literal.
  err = mysql_act(c,"""
    INSERT INTO view(sessnum,username,remoteip,start,heartbeat)
    VALUES(%d,'%s','%s',now(),now())""" % (sessnum,
                                            db.escape_string(str(username)),
                                            db.escape_string(str(remoteip))))
  if err != "":
    log("ERROR: %s" % err)
    sys.exit(1)
  arr = mysql(c,"""SELECT LAST_INSERT_ID()""")
  if len(arr) == 0:
    # mysql() returns () on failure; without an id the view cannot be
    # tracked, so give up instead of failing on arr[0].
    log("ERROR: Unable to determine the id of the new view.")
    sys.exit(1)
  viewid = arr[0][0]
  db.close()

  #
  # This is a wrapper fork to separate the workers from the main program.
  #
  if os.fork() != 0:
    sys.exit(0)

  #
  # Make this child a process group leader.
  #
  os.setsid()

  #
  # Heartbeat period.  At least one second: signal.alarm(0) would cancel
  # the alarm and the heartbeat would silently stop.
  #
  interval = max(1, int(timeout * 0.40))

  def sighandler(sig,stuff):
    if sig == signal.SIGALRM:
      refresh_heartbeat(viewid,interval)
    else:
      end_view(viewid,starttime,sessnum,username,remoteip)

  #
  # Cancel any alarm still pending from the caller, then install the
  # handlers BEFORE socat is started.  If they were installed afterwards, a
  # socat that exits immediately could deliver SIGCHLD while it is still
  # unhandled and this process would then wait forever.
  #
  signal.alarm(0)
  signal.signal(signal.SIGALRM, sighandler)
  signal.signal(signal.SIGCHLD, sighandler)
  signal.signal(signal.SIGHUP, sighandler)
  signal.signal(signal.SIGTERM, sighandler)

  #
  # Answer the CONNECT request before socat exists, so the reply is
  # guaranteed to reach the client ahead of any relayed data.
  #
  sys.stdout.write("HTTP/1.0 200 Connection Established\n")
  sys.stdout.write("Proxy-agent: HUBzero connection redirector\n")
  sys.stdout.write("\n")
  sys.stdout.flush()

  #
  # fork off the socat command and wait for it.
  #
  pid = os.fork()
  if pid == 0:
    #
    # Close all descriptors except stdin/stdout/stderr.
    #
    for i in range(3,1024):
      try:
        os.close(i)
      except OSError:
        pass
    try:
      os.execlp("socat", "socat", "-", "tcp4:%s" % connect)
    except OSError:
      log("ERROR: Unable to execute socat.")
    # Only reached when the exec failed.  Leave without unwinding through
    # the parent's code path; the parent sees SIGCHLD and ends the view.
    os._exit(127)

  # Parent: socat owns the client connection from here on.
  sys.stdin.close()
  sys.stdout.close()

  signal.alarm(interval)
  while 1:
    signal.pause()

#=============================================================================
# In case of errors.
#=============================================================================
def do_proxy_null(connect,sessnum,username,remoteip,timeout):
  """Placeholder proxy function: log its own name and exit with status 0.

  It takes the same parameters as do_proxy_vnc so the two are
  interchangeable, but uses none of them.
  """
  log("do_proxy_null")
  raise SystemExit(0)

#=============================================================================
# Translate the VNC ID to host:port
#=============================================================================
def translate(ip,host):
  """Resolve the target of a CONNECT request to a proxy function and address.

  The target must have the form "vncsession:<token>".  The token is looked
  up in the `viewperm` table (joined with display, session and host) and
  must match exactly one row.

  On success the global `db` is left open for the proxy function to use.

  Parameters:
    ip   -- address of the connecting client (used for logging only).
    host -- the target string taken from the CONNECT request line.

  Returns:
    a 5-tuple (func, "host:port", sessnum, username, timeout).  On any
    failure this is (do_proxy_null, "", -1, "", -1).
  """
  global db

  # The idea is that there could be several proxy functions; for now
  # do_proxy_vnc is the only meaningful one.
  failure = do_proxy_null, "", -1, "", -1

  if host == "":
    log("No CONNECT host specified.")
    return failure

  parts = host.split(":")
  if len(parts) != 2 or parts[0] != "vncsession":
    log("Improper CONNECT host specified.")
    return failure

  token = parts[1]

  db = db_connect()
  c = db.cursor()
  # The token comes straight from the client, so it must be escaped before
  # being placed inside the SQL string literal.
  arr = mysql(c,"""
     SELECT viewperm.sessnum,viewuser,display.hostname,session.dispnum,
            session.timeout,portbase
     FROM viewperm
     JOIN display ON viewperm.sessnum = display.sessnum
     JOIN session ON session.sessnum = display.sessnum
     JOIN host ON display.hostname = host.hostname
     WHERE viewperm.viewtoken='%s'""" % db.escape_string(token))

  if len(arr) != 1:
    log("No database entry found for ID %s." % token)
    db.close()
    return failure

  row = arr[0]
  sessnum = row[0]
  username = row[1]
  host = row[2]
  dispnum = int(row[3])
  timeout = int(row[4])

  # The query also returns host.portbase (row[5]), but for vncsession
  # targets the configured ssl_vnc_portbase is always used instead.
  port = dispnum + ssl_vnc_portbase

  log("Map %s@%s for %s to %s:%d" % (username,ip,token,host,port))
  return do_proxy_vnc, host + ":" + str(port), sessnum, username, timeout

#=========================================================================
# Indicate to the client that the request failed.
#=========================================================================
def bad_request():
  """Send an HTTP 400 status line to the client, close stdin and stdout,
  and exit with status 1.  Never returns."""
  sys.stdout.write("HTTP/1.1 400 Bad Request\n")
  for stream in (sys.stdin, sys.stdout):
    stream.close()
  sys.exit(1)

#=========================================================================
# handle one proxy connection
#=========================================================================
def do_proxy(remoteip):
  """Read one HTTP request header from stdin and start the matching proxy.

  At most HEADER_LIMIT lines are read, under a 10 second alarm.  A GET
  request is answered with a small HTML page and the process exits with
  status 0.  A CONNECT request has its target translated by translate()
  and is then handed to the proxy function translate() returns.

  The process exits with status 1 when the request times out, contains no
  usable CONNECT line, uses up the header limit, or names a target that
  cannot be translated (the last two are answered with a 400 reply).

  Parameters:
    remoteip -- address of the connecting client.
  """
  HEADER_LIMIT=10

  def on_timeout(sig,frame):
    log("Connection from %s timed out." % remoteip)
    sys.exit(1)
  def ignore(sig,frame):
    pass
  signal.signal(signal.SIGALRM, on_timeout)
  signal.alarm(10)

  connect=""
  orig_connect=""

  lineno=0
  while lineno < HEADER_LIMIT:
    line=sys.stdin.readline()
    lineno=lineno+1
    array=line.split()
    if array == []:
      # A blank line (or end of input) terminates the header.
      log("Done parsing HTTP header.")
      break

    log(line.strip())
    method = array[0].upper()
    if method == "CONNECT":
      if len(array) < 2:
        # "CONNECT" with nothing after it: leave `connect` unset so the
        # request is rejected below instead of failing on array[1].
        log("CONNECT without a target from %s" % remoteip)
        continue
      connect = array[1]
      orig_connect=connect
    elif method == "GET":
      log("GET received.")
      sys.stdout.write("""<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
<html xmlns="http://www.w3.org/1999/xhtml" lang="en" xml:lang="en">
<body>
Connection to proxy succeeded.
</body>
</html>""")
      sys.stdout.write("\n\n")
      sys.exit(0)

  # The header has been read: cancel the pending alarm so it cannot
  # interrupt the database work that follows, and neutralize the handler.
  signal.alarm(0)
  signal.signal(signal.SIGALRM, ignore)

  if connect == "":
    log("No CONNECT specified in request from %s" % remoteip)
    sys.exit(1)

  # The loop stops once lineno reaches HEADER_LIMIT, so ">=" (not ">") is
  # the condition that detects a header that used up the limit.
  if lineno >= HEADER_LIMIT:
    log("ERROR: Host %s exceeded header limit." % remoteip)
    bad_request()

  func, connect, sessnum, username, timeout = translate(remoteip,orig_connect)

  if connect == "":
    log("ERROR: No translation for %s@%s." % (orig_connect,remoteip))
    bad_request()

  func(connect, sessnum, username, remoteip, timeout)

def openSocket(host, port):
  """Create a listening IPv4 TCP socket.

  The socket has SO_REUSEADDR set, is bound to (host, port) and listens
  with a backlog of 10.

  Parameters:
    host -- address or host name to bind to.
    port -- TCP port number to bind to.

  Returns:
    the listening socket object.
  """
  listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  address = (host, port)
  listener.bind(address)
  listener.listen(10)
  return listener

def openSSLSocket(host, port, configdir="."):
  """Create a listening SSL/TLS socket.

  The key, certificate and CA certificate are read from server.key,
  server.crt and ca.crt inside `configdir`.  The underlying IPv4 TCP socket
  has SO_REUSEADDR set, is bound to (host, port) and listens with a
  backlog of 10.

  Parameters:
    host      -- address or host name to bind to.
    port      -- TCP port number to bind to.
    configdir -- directory holding the three PEM files; defaults to the
                 current working directory.

  Returns:
    the listening OpenSSL.SSL.Connection object.
  """
  # SSLv23_METHOD negotiates the highest protocol version both peers
  # support.  SSLv2 and SSLv3 are switched off explicitly because both
  # protocols are broken; the previous hard-coded SSLv3_METHOD allowed
  # nothing but SSLv3.
  ctx = SSL.Context(SSL.SSLv23_METHOD)
  ctx.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
  ctx.use_privatekey_file (os.path.join(configdir,"server.key"))
  ctx.use_certificate_file(os.path.join(configdir,"server.crt"))
  ctx.load_verify_locations(os.path.join(configdir,"ca.crt"))
  s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  server = SSL.Connection(ctx, s)
  server.bind( (host, port) )
  server.listen(10)
  return server

def daemonize():
  """Detach from the launching process and record the daemon's pid.

  Forks, starts a new session, changes directory to "/", and forks again;
  each parent exits immediately with os._exit(0).  The surviving process
  writes its pid to `pidfile`.  Failure to write the pid file is logged
  and otherwise ignored.
  """
  log("Backgrounding process.")
  if os.fork():
    os._exit(0)
  os.setsid()
  os.chdir("/")
  if os.fork():
    os._exit(0)
  try:
    f = open(pidfile,"w")
    try:
      f.write("%d\n" % os.getpid())
    finally:
      # Close the file even when the write fails.
      f.close()
  except (IOError, OSError):
    log("Unable to write pid (%d) to %s" % (os.getpid(),pidfile))

def sighandler(sig,frame):
  """Signal handler: log the signal number and exit with status 1.

  Parameters:
    sig   -- number of the signal that was caught.
    frame -- interrupted stack frame (unused).
  """
  message = "Caught signal %d.  Exiting." % sig
  log(message)
  sys.exit(1)

#=========================================================================
# main listener program
#=========================================================================

openlog("/var/log/hubzero/vncproxy.log")

# Log and exit on the usual termination signals.  (SIGKILL cannot be caught,
# so no handler is registered for it.)
signal.signal(signal.SIGHUP, sighandler)
signal.signal(signal.SIGINT, sighandler)
signal.signal(signal.SIGQUIT, sighandler)
signal.signal(signal.SIGTERM, sighandler)

l = openSocket(socket.gethostname(), vnc_listen_port)

if not foreground:
  daemonize()

log("Server is ready.")

while 1:
  s, [host,port] = l.accept()

  if os.fork() == 0:
    # First-level child: put the client socket on stdin/stdout, fork the
    # real worker and exit at once so the listener can reap this process.
    os.setsid()
    l.close()
    log("Accepted connection from " + host)
    os.dup2(s.fileno(), 0)
    os.dup2(s.fileno(), 1)
    if os.fork() == 0:
      do_proxy(host)
      # do_proxy() is not expected to return; if it ever does, a worker
      # must not fall back into the accept loop.
      sys.exit(0)
    else:
      os._exit(0)
  else:
    s.close()
    # Reap every child that has already exited, not just one, so zombies
    # cannot pile up.  waitpid() returns pid 0 when children exist but none
    # has exited yet, and raises OSError when there are no children at all.
    try:
      while os.waitpid(-1, os.WNOHANG)[0] != 0:
        pass
    except OSError:
      pass

