#!/usr/bin/python
# @package      hubzero-mw2-front-proxy
# @file         front-proxy.py
# @author       Pascal Meunier <pmeunier@purdue.edu>
# @copyright    Copyright (c) 2016-2017 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Packaging of original work by Richard L. Kennell
#
# Copyright (c) 2016-2017 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

#==============================================================================
# Front-end proxy
# This proxy runs on the web front-end and handles AJAX and WebSocket
# connections.  Apache 2.4 or better can be used to redirect AJAX and WSS
# into a locally-running copy of this proxy.  Earlier versions of Apache
# cannot act as a reverse proxy for WS/WSS, so this proxy must run on a
# separate IP address.
# This proxy forks a child process for each connection, interprets the
# HTTP header, looks up the information in a database, and chooses a
# host:port to forward the connection to.
#==============================================================================

import socket
import time
import sys
import os
import errno
import traceback
import fcntl
from OpenSSL import SSL
import select
import signal
import MySQLdb
from hubzero.mw.mw_db import MW_DB
from hubzero.mw.log import setup_log, log

#
# This config file can override any of the following global constants.
#
FRONT_PROXY_CONFIG_FILE = '/etc/mw-proxy/front-proxy.conf'
FRONT_CONFIGDIR = '/etc/mw-proxy'

FRONT_LISTEN_HOST = None  # IP address to bind to
FRONT_LISTEN_PORT = None  # port number to bind to
FRONT_LISTEN_SSL = None   # Listen as SSL socket (True/False)
FRONT_LISTEN_TLSv1_2 = True # Newer TLS version requiring Python 2.7.9 at least
FRONT_KEY = None          # SSL private key
FRONT_CERT = None         # SSL certificate
FRONT_CACERT = None       # SSL CA certificate
FRONT_DHPARAM = None      # Diffie-Hellman param
FRONT_CIPHERS = None      # List of ciphers to use/exclude
FORWARD_PORT = 443        # Port to connect to on execution hosts
FORWARD_SSL = True        # Use SSL to connect to execution hosts (True/False)
FORWARD_TLSv1_2 = True    # Use TLS 1.2 to connect to execution hosts (True/False)

#FORWARD_PORT = 8001
#FORWARD_SSL = False
PIDFILE = '/var/run/front-proxy.pid'
PROXY_LOG = '/var/log/front-proxy/front-proxy.log'
HEADER_TIMEOUT = 20     # How long should we wait for an HTTP header
CACHE_FLUSH_PERCENT = 1 # How many times per hundred accesses is cache flushed
CACHE_DURATION = 86400  # Max age of something to keep in cache
RUN_UID = 'hz-front-proxy'  # UID to run the proxy
RUN_GID = 'hz-front-proxy'  # GID to run the proxy
VERBOSITY = 0

from hubzero.mw.constants import MYSQL_CONN_ATTEMPTS, CONFIG_FILE
execfile(CONFIG_FILE)
if os.path.isfile(FRONT_PROXY_CONFIG_FILE):
  execfile(FRONT_PROXY_CONFIG_FILE)

def _config_path(path, use_as_is):
  """Resolve an SSL file setting against FRONT_CONFIGDIR.

  None and '' are returned unchanged. Optional settings such as FRONT_CACERT
  default to None, and os.path.join() would raise TypeError on them.
  """
  if not path or use_as_is:
    return path
  return os.path.join(FRONT_CONFIGDIR, path)

# If FRONT_KEY names an existing file, all three SSL settings are used as
# given. Otherwise they are taken to be relative to FRONT_CONFIGDIR.
_ssl_paths_as_is = bool(FRONT_KEY) and os.path.isfile(FRONT_KEY)
ABSOLUTE_FRONT_KEY = _config_path(FRONT_KEY, _ssl_paths_as_is)
ABSOLUTE_FRONT_CERT = _config_path(FRONT_CERT, _ssl_paths_as_is)
ABSOLUTE_FRONT_CACERT = _config_path(FRONT_CACERT, _ssl_paths_as_is)

verbosity = int(VERBOSITY)


#==============================================================================
# The default heartbeat interval is 1 hour.
#==============================================================================
heartbeat_interval = 3600

#==============================================================================
# Various printing functions.
#==============================================================================
def verbose(str):
  """Debug-message hook.

  This is intentionally a no-op that always returns None. To re-enable
  verbose output, make it call log(str).
  """
  return None

def fatal(str):
  """Log `str` with a 'FATAL: ' prefix, then abort the current activity.

  Raises SystemExit(1). The exception unwinds to the handlers in
  handle_connection(), main_loop() or main(), which decide what happens next.
  """
  message = 'FATAL: ' + str
  log(message)
  raise SystemExit(1)

def log_exception(title):
  """Log the exception currently being handled.

  Writes the title plus the exception class name, the exception text, and up
  to 10 traceback frames between two separator lines. Must be called from
  inside an `except` clause.
  """
  exc_type, exc_value, exc_tb = sys.exc_info()
  type_name = exc_type.__name__
  separator = '-' * 50
  log(title + ' ' + type_name + ':')
  log(str(exc_value))
  log(separator)
  for frame_text in traceback.format_tb(exc_tb, 10):
    for text_line in frame_text.strip().split('\n'):
      log(text_line.strip())
  log(separator)

def print_exception(title):
  """Print `title` and the current exception's traceback to stdout.

  At most 10 frames are printed, framed by two separator lines. main() uses
  this before logging is set up. It is also safe to call when no exception is
  being handled: then only the title and the two separator lines are printed.
  (The previous version read an unused exception class name, which raised
  AttributeError in that case.)
  """
  rule = '-' * 50
  print(title)
  print(rule)
  trbk = sys.exc_info()[2]
  for entry in traceback.format_tb(trbk, 10):
    for line in entry.strip().split('\n'):
      print(line.strip())
  print(rule)

#==============================================================================
# Enable/disable blocking for a SSL, socket, or file descriptor.
#==============================================================================
def setblocking(x, value):
  """Switch blocking mode on an SSL connection, a socket, or a raw fd.

  pyOpenSSL objects (type name contains 'SSL') are delegated to their own
  setblocking(). Sockets are reduced to their file descriptor. Integer fds
  have O_NONBLOCK cleared (value true) or set (value false) with fcntl().
  Any other type is fatal().
  """
  type_name = str(type(x))
  if 'SSL' in type_name:
    x.setblocking(value)
    return
  fd = x.fileno() if 'socket' in type_name else x
  if type(fd) != type(1):
    fatal("Don't know how to set blocking of type %s" % str(type(fd)))
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
  if value:
    flags &= ~os.O_NONBLOCK   # block on I/O
  else:
    flags |= os.O_NONBLOCK    # do not block on I/O
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)

#==============================================================================
# Occasionally flush old entries from the cache.
#==============================================================================
def flush_cache(cache, t):
  """Occasionally drop cache entries older than CACHE_DURATION seconds.

  The sub-second part of the timestamp `t` serves as a cheap pseudo-random
  number, so the scan runs on roughly CACHE_FLUSH_PERCENT percent of calls.
  Each cache value is a tuple whose index 4 is its insertion time.
  """
  if t - int(t) >= CACHE_FLUSH_PERCENT / 100.0:
    return
  stale = [sessnum for sessnum in list(cache.keys())
           if cache[sessnum][4] + CACHE_DURATION < t]
  for sessnum in stale:
    del cache[sessnum]

#==============================================================================
# Try to read all pipes open to all child processes.  Get the session number,
# username, filexfer cookie, execution host, and container number.  Incorporate
# that information into the cache---as well as the time at which it occurred.
# Close the pipe reader when it has been read.
#==============================================================================
def read_cache_pipes(rlist, cache, t):
  """Collect session info reported by child processes into the cache.

  Each readable pipe in `rlist` should contain one record of the form
  "sessnum,username,cookie,exechost,ctnum" (see write_cache_pipe()). A pipe
  that has been read (or reached EOF) is closed and removed from `rlist` in
  place, since the caller keeps using the same list. Pipes that are not ready
  (EAGAIN) stay in the list. Valid records are stored as
  cache[sessnum] = (username, cookie, exechost, ctnum, t).
  Malformed records are logged and skipped; other read errors are fatal().
  """
  # Iterate over a snapshot. Removing from the list being iterated would
  # silently skip the element that follows each removed reader.
  for r in rlist[:]:
    try:
      stuff = os.read(r, 1000)
      os.close(r)
      rlist.remove(r)
    except OSError as error:
      if error.errno != errno.EAGAIN:
        fatal("Error %d in read pipe" % error.errno)
      continue
    if not isinstance(stuff, str):
      # Python 3 returns bytes. On Python 2 this branch is never taken.
      stuff = stuff.decode('ascii', 'replace')
    if stuff == '':
      continue
    try:
      sessnum, username, cookie, exechost, ctnum = stuff.split(',')
    except ValueError:
      log("Unable to split cache info '%s'" % stuff)
      continue
    try:
      sessnum = int(sessnum)
      ctnum = int(ctnum)
    except ValueError:
      # Previously called an undefined error(), which crashed the server.
      log("Bad format for cache values '%s'" % stuff)
      continue
    cache[sessnum] = (username, cookie, exechost, ctnum, t)

#==============================================================================
# Write the sessnum, username, filexfer cookie, execution host and container
# number to the pipe open with the parent.
#==============================================================================
def write_cache_pipe(w, sessnum, username, cookie, host, ct):
  """Report one session record to the parent over pipe `w`, then close it.

  The record is "sessnum,username,cookie,host,ct". Failures are ignored on
  purpose: the write can fail if the reader is already gone (for example the
  parent process was killed). The pipe is closed either way.
  """
  record_format = "%d,%s,%s,%s,%d"
  try:
    os.write(w, record_format % (sessnum, username, cookie, host, ct))
  except:
    pass
  finally:
    os.close(w)

#==============================================================================
# Create a socket to listen on.  Optionally wrap it with SSL.
#==============================================================================
def create_listener(host,port):
  """Create a TCP listening socket on (host, port), optionally SSL-wrapped.

  When FRONT_LISTEN_SSL is set, the socket is wrapped in a pyOpenSSL
  Connection. The context uses TLS 1.2 only (FRONT_LISTEN_TLSv1_2) or TLSv1,
  with SSLv2/SSLv3 disabled. It loads the configured key, certificate,
  optional CA bundle, optional DH parameters and optional cipher list.
  Returns the listening socket. On any failure the socket is closed and the
  exception is re-raised so that main() can report it.
  """
  sock = socket.socket()
  try:
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    if FRONT_LISTEN_SSL:
      if FRONT_LISTEN_TLSv1_2:
        ctx = SSL.Context(SSL.TLSv1_2_METHOD)
        ctx.set_options(SSL.OP_NO_TLSv1)
        ctx.set_options(SSL.OP_NO_TLSv1_1)
      else:
        ctx = SSL.Context(SSL.TLSv1_METHOD)
      # Options are OR-ed into a bitmask, so setting each one once is enough.
      ctx.set_options(SSL.OP_NO_SSLv2)
      ctx.set_options(SSL.OP_NO_SSLv3)
      ctx.set_options(SSL.OP_SINGLE_DH_USE)
      ctx.use_privatekey_file(ABSOLUTE_FRONT_KEY)
      ctx.use_certificate_file(ABSOLUTE_FRONT_CERT)
      if ABSOLUTE_FRONT_CACERT:
        ctx.load_verify_locations(ABSOLUTE_FRONT_CACERT)
      if FRONT_DHPARAM is not None:
        ctx.load_tmp_dh(os.path.join(FRONT_CONFIGDIR, FRONT_DHPARAM))
      if FRONT_CIPHERS is not None:
        ctx.set_cipher_list(FRONT_CIPHERS)
      sock = SSL.Connection(ctx, sock)
    sock.bind((host, port))
    sock.listen(5)
  except Exception:
    # Do not leak the descriptor when configuration or bind fails.
    sock.close()
    raise
  return sock

#==============================================================================
# Exceptions for reading/writing.
#==============================================================================
class ReaderClose(StandardError):
  """The socket or SSL connection can no longer be read.

  read_chunk() raises this for a clean SSL close, an empty recv(), and
  unrecoverable socket/SSL/OS errors.
  """

class WriterClose(StandardError):
  """The socket or SSL connection can no longer be written."""

#==============================================================================
# Read a string from a socket.
#==============================================================================
def read_chunk(s,maxlen=65536):
  """Read up to `maxlen` bytes from a socket or pyOpenSSL connection.

  Returns the data read, or '' when nothing is available right now: EAGAIN
  on a non-blocking socket, an SSL want-read/want-write/want-X509 condition,
  or a signal interrupting the call (EINTR). Callers simply try again later.

  Raises ReaderClose when the peer has closed (empty read or SSL zero
  return) or on any other socket/SSL/OS error.
  """
  try:
    msg = s.recv(maxlen)
  except SSL.ZeroReturnError:
    raise ReaderClose()
  except socket.error as error:
    if error.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
      log('Caught EAGAIN for %s' % str(type(s)))
      return ''
    if error.errno == errno.EINTR:
      # Interrupted by a signal (e.g. the SIGALRM heartbeat); not a close.
      return ''
    log("read_chunk socket.error" + str(error))
    raise ReaderClose()
  except (SSL.WantReadError, SSL.WantWriteError, SSL.WantX509LookupError):
    return ''
  except SSL.Error as errors:
    log("read_chunk SSL.Error " + str(errors))
    raise ReaderClose()
  except OSError as error:
    # On Python 2, socket.error is not an OSError, so this clause handles
    # errors from non-socket objects.
    if error.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
      log('Caught EAGAIN for %s' % str(type(s)))
      return ''
    if error.errno == errno.EINTR:
      return ''
    log("Error %d in read_chunk" % error.errno)
    raise ReaderClose()

  # An empty read means the peer closed the stream ('' on Py2, b'' on Py3).
  if not msg:
    raise ReaderClose()
  return msg

#==============================================================================
# Write a string to a socket.  Return the unwritten part of the string.
#==============================================================================
def _wait_writable(s, reason):
  """Wait up to 60 seconds for `s` to become writable; fatal(reason) on timeout."""
  _, wlist, _ = select.select([], [s], [], 60)
  if not wlist:
    fatal(reason)

def write_chunk(s,msg):
  """Send as much of `msg` as possible and return the unsent remainder.

  When the socket is temporarily full (EAGAIN or SSL.WantWriteError), this
  waits with select() until it is writable and retries with the SAME buffer.
  OpenSSL requires this, see the notes below. EINTR is retried immediately.
  On SSL.WantReadError/WantX509LookupError, `msg` is returned unchanged so
  the caller can retry later. Every other error is fatal().
  """
  # Upon upgrade from Python 2.7.3 to Python 2.7.9 installation, new errors required adding the "while True" loop
  # Without while loop, this program would log "SSL.Error [('SSL routines', 'ssl3_write_pending', 'bad write retry')]"
  # because it retried a write using a different buffer (even if same contents).  It would then abort and transferred files were truncated
  # Without while loop and without SSL, it would log "FATAL: write_chunk: socket.error [Errno 11] Resource temporarily unavailable"
  # because EAGAIN errors were retried using a different buffer
  # the while loop prevents the buffer from being changed
  # "When you retry a write, you must retry with the exact same buffer"
  # "the same contents are not sufficient and, of course, different contents is absolutely prohibited"
  while True:
    try:
      n = s.send(msg)
      return msg[n:]
    except OSError as error:
      if error.errno == errno.EINTR:
        continue
      if error.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
        _wait_writable(s, "write_chunk: socket not ready to write after EAGAIN")
        continue
      fatal("Error %d in write_chunk" % error.errno)
    except SSL.WantWriteError:
      _wait_writable(s, "write_chunk: socket not ready to write after SSL.WantWriteError")
      continue
    except (SSL.WantReadError, SSL.WantX509LookupError):
      return msg
    except SSL.ZeroReturnError:
      fatal("write_chunk: Zero Return")
    except SSL.Error as errors:
      fatal("write_chunk: SSL.Error " + str(errors))
    except socket.error as errors:
      # Python 2: socket.error is an IOError, not an OSError, so EAGAIN on a
      # plain non-blocking socket arrives here. It must be retried, not
      # treated as fatal.
      if errors.errno == errno.EINTR:
        continue
      if errors.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
        _wait_writable(s, "write_chunk: socket not ready to write after EAGAIN")
        continue
      fatal("write_chunk: socket.error " + str(errors))

#==============================================================================
# Send a shutdown only for an SSL socket.  Then close it.
#==============================================================================
def shutdown(s):
  """Close `s`, first sending an SSL shutdown if it is a pyOpenSSL object.

  The object is switched back to blocking mode before closing. A failing SSL
  shutdown is ignored, because the connection may already be shut down.
  """
  if 'SSL' not in str(type(s)):
    setblocking(s.fileno(), True)
    s.close()
    return
  setblocking(s, True)
  try:
    s.shutdown()
  except:
    pass
  s.close()

#==============================================================================
# Read an HTTP header, return it as an array.
# Any additional data read after the double CRLF is returned as body.
#==============================================================================
# First line of the header being read; timeout_handler() reports it.
debug_header=''
def read_header(ns):
  """Read an HTTP header from `ns` and return (lines, body).

  Reads until a blank line (CRLF CRLF) arrives or the peer closes. `lines`
  is the header split on '\\n' with each line stripped. `body` is whatever
  followed the blank line in the data read so far ('' if none). More than
  100000 bytes without a complete header is fatal().
  """
  global debug_header
  data = ''
  while '\r\n\r\n' not in data:
    was_empty = not data
    try:
      data += read_chunk(ns)
    except ReaderClose:
      break
    if was_empty:
      debug_header = data.split('\r')[0]
    if len(data) > 100000:
      fatal('Header is too long')
  head, _, body = data.partition('\r\n\r\n')
  hdr = [line.strip() for line in head.split('\n')]
  if len(hdr) < 1:
    fatal('Malformed header1: ' + str(hdr))
  return hdr, body

#==============================================================================
# Print a header.
#==============================================================================
def header_print(hdr):
  """Print the header lines to stdout, one per line (debugging aid)."""
  text = '\n'.join(hdr)
  print(text)

#==============================================================================
# Send the header (and any body data) through a socket/SSL.
#==============================================================================
def send_header(ns,hdr,body):
  """Send the header lines (CRLF-joined, blank-line terminated) plus `body`.

  Loops on write_chunk() until everything is sent. When verbosity > 1 it
  also checks for a Connection: line (fatal() if missing) and logs it when it
  is not 'Connection: close'.
  """
  pending = '\r\n'.join(hdr) + '\r\n\r\n' + body
  while pending:
    pending = write_chunk(ns, pending)

  if verbosity <= 1:
    return
  conn_idx = header_find(hdr, 'Connection:')
  if conn_idx < 0:
    fatal("No Connection: statement")
  if hdr[conn_idx] != 'Connection: close':
    log('>>>>> ' + hdr[conn_idx])

#==============================================================================
# Log the full HTTP header.
#==============================================================================
def log_header(hdr,body):
  """Log a separator, every header line, and the body if it is non-empty."""
  log('================================')
  for line in hdr:
    log(line)
  if body:
    log(body)

#==============================================================================
# Find an entry in the header.
#==============================================================================
def header_find(hdr, s):
  """Return the index of the first header line containing substring `s`, or -1.

  The match is a case-sensitive substring test anywhere in the line.
  """
  return next((idx for idx, line in enumerate(hdr) if s in line), -1)

#==============================================================================
# Find a key in the header and replace its value, or create it if nonexistent.
#==============================================================================
def header_set(hdr, key, value):
  """Set header `key` to `value` in the list `hdr`, modifying it in place.

  Every existing line whose field name matches `key` case-insensitively (HTTP
  field names are case-insensitive) is replaced by "key: value". If there is
  none, the line is appended. Case-insensitive matching keeps a client-supplied
  lowercase duplicate (e.g. "x-forwarded-for:") from surviving next to the
  value the proxy sets.
  """
  start = key + ':'
  wanted = start.lower()
  line = start + ' ' + value
  found = False
  for n, entry in enumerate(hdr):
    if entry.lower().startswith(wanted):
      hdr[n] = line
      found = True
  if not found:
    hdr.append(line)

#==============================================================================
# Return the content-length field or -1 if not found.
#==============================================================================
def header_content_length(hdr):
  """Return the integer Content-Length of the header, or -1.

  Returns -1 when no line contains 'Content-Length:' or when its value is not
  an integer; the latter case is logged.
  """
  field = 'Content-Length:'
  matches = [line for line in hdr if field in line]
  if not matches:
    return -1
  line = matches[0]
  try:
    return int(line[len(field):])
  except ValueError:
    log('Bad Content-Length: ' + line)
    return -1

#==============================================================================
# This function is used if we want to verify an SSL server we connect to.
#==============================================================================
def verify_cb(conn, crt, errnum, depth, ok):
  """Certificate verification callback for an SSL context's set_verify().

  Reports the peer certificate's subject, issuer and chain depth through
  verbose(), then returns OpenSSL's own verdict `ok` unchanged. (The body
  previously referred to an undefined name `cert` instead of the `crt`
  parameter, so it raised NameError whenever it ran.)
  """
  verbose('Got cert: %s' % crt.get_subject())
  verbose('Issuer: %s' % crt.get_issuer())
  verbose('Depth: %s' % str(depth))
  return ok

#==============================================================================
# Look up the username, filexfer cookie, execution host, and container number.
#==============================================================================
def find_session_params(db, sessnum, remoteip):
  """Look up (username, cookie, exec host, container number) for `sessnum`.

  Joins fileperm and display on sessnum. Exactly one matching row yields its
  values. Anything else yields ('', '', '', -1), which forward() treats as an
  unknown session. `remoteip` is accepted for interface compatibility and is
  not used. The session number is passed as a query parameter, consistent
  with add_viewer(), instead of being formatted into the SQL text.
  """
  rows = db.getall("""
    SELECT fileperm.fileuser, fileperm.cookie, display.hostname, display.dispnum
    FROM fileperm, display
    WHERE fileperm.sessnum = display.sessnum
    AND display.sessnum = %s""", (sessnum,))
  if len(rows) == 1:
    user, cookie, host, ct = rows[0]
  else:
    user, cookie, host, ct = '', '', '', -1
  return user, cookie, host, ct

#==============================================================================
# Interpret the header and open a socket/SSL to the execution host's proxy.
#==============================================================================
def forward(hdr, body, cache_writer, cache, remoteip):
  """Authorize the request and open a connection to its execution host.

  The request-line URL must look like
  /<weber|notebook>/<sessnum>/<cookie>/<container>[/...]; any query string is
  ignored. Session parameters come from `cache` or, on a miss, from the
  database. The result (including "not found") is reported to the parent
  through `cache_writer`. The cookie and container number in the URL must
  match the session. Unless this is a WebSocket upgrade, 'Connection: close'
  is forced. The header and any already-read body are then sent to
  host:FORWARD_PORT, over SSL when FORWARD_SSL is set.

  Returns (sock, sessnum, username). Every failure ends in fatal().
  """
  global verbosity
  port = FORWARD_PORT
  arr = hdr[0].split(' ')
  if len(arr) < 2:
    if hdr[0] != '':
      fatal('Malformed header2: ' + hdr[0])
    else:
      log('Malformed header2: ' + hdr[0])
      log_header(hdr,body)
      fatal("Can't go on")
  action = arr[0]
  # Drop any query string; its parameters are not used for routing.
  url = arr[1].split('?', 1)[0]
  comp = url.split('/')[1:]
  if verbosity > 2:
    log_header(hdr,body)
    if action != 'POST':
      verbosity -= 1
  elif verbosity > 0:
    log(url)
  # Check the component count first so short URLs get the same fatal()
  # as other malformed URLs instead of an IndexError.
  if len(comp) < 4 or comp[0] not in ('weber', 'notebook'):
    fatal('Malformed URL: ' + url)
  try:
    sessnum = int(comp[1])
    ct = int(comp[3])
  except ValueError:
    fatal('Malformed URL: ' + url)
  cookie = comp[2]

  if header_find(hdr, 'Upgrade: websocket') == -1:
    header_set(hdr, 'Connection', 'close')

  if sessnum in cache:
    values = cache[sessnum]
  else:
    db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
    values = find_session_params(db,sessnum,remoteip)
    db.close()

  username = values[0]
  host = values[2]
  if host == '':
    # Cache the invalid entry so we don't continually
    # bother the database for invalid sessions.
    write_cache_pipe(cache_writer, sessnum, '', '', '', -1)
    fatal('Session %d not found. rhost=%s' % (sessnum,remoteip))
  if values[1] != cookie:
    fatal('Session cookie does not match. rhost=%s' % remoteip)
  if values[3] != ct:
    fatal('Session container does not match. rhost=%s' % remoteip)

  write_cache_pipe(cache_writer, sessnum, username, cookie, host, ct)

  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  if FORWARD_SSL:
    if FORWARD_TLSv1_2:
      ctx = SSL.Context(SSL.TLSv1_2_METHOD)
    else:
      ctx = SSL.Context(SSL.TLSv1_METHOD)
    # The execution host's certificate is not verified. To verify it, call
    # ctx.set_verify(SSL.VERIFY_PEER, verify_cb) and load_verify_locations().
    sock = SSL.Connection(ctx, sock)
  try:
    sock.connect((host, port))
    send_header(sock,hdr,body)
  except Exception as e:
    # SystemExit from a fatal() inside send_header() is already logged and
    # passes straight through.
    fatal("Unable to forward to %s:%d (%s)" % (host, port, str(e)))
  return sock,sessnum,username

#==============================================================================
# Copy a plaintext socket or SSL socket bidirectionally.
#==============================================================================
def _select_errno(e):
  """Return the errno carried by a select() failure, or None.

  Python 2's select.error is not an EnvironmentError and has no .errno
  attribute; the code is only available as args[0]. Python 3 raises an
  OSError, which has .errno.
  """
  err = getattr(e, 'errno', None)
  if err is None and e.args:
    err = e.args[0]
  return err

def bidirectional_copy(a,b):
  """Relay data both ways between `a` and `b` until both directions end.

  Both endpoints are made non-blocking and multiplexed with select(). Data
  read from one side is buffered until it can be written to the other. When
  one side stops being readable and its buffered data has been flushed, the
  other side is shut down (at most once).
  """
  setblocking(a, False)
  setblocking(b, False)
  af = a.fileno()
  bf = b.fileno()
  amsg = ''          # read from a, not yet written to b
  bmsg = ''          # read from b, not yet written to a
  a_closed = False   # shutdown(a) already done
  b_closed = False   # shutdown(b) already done
  rfds = [af,bf]
  wfds = []
  while len(rfds) > 0 or len(wfds) > 0:
    try:
      rd,wr,_ = select.select(rfds,wfds,[])
    except select.error as e:
      err = _select_errno(e)
      if err in (errno.EAGAIN, errno.EINTR):
        # EINTR: a signal such as the SIGALRM heartbeat interrupted select().
        continue
      elif err == errno.EBADF:
        fatal("bidirectional_copy failed with EBADF (Bad file descriptor)")
      else:
        fatal("bidirectional_copy failed with error %s" % str(err))
    if af in rd:
      try:
        amsg += read_chunk(a)
        if bf not in wfds:
          wfds += [bf]
          wr += [bf]
      except ReaderClose:
        rfds.remove(af)
        if bf not in wfds:
          wfds += [bf]
        continue
    if bf in rd:
      try:
        chunk = read_chunk(b)
        if verbosity > 2:
          for ln in chunk.split('\r\n'):
            log('REPLY: ' + ln)
        bmsg += chunk
        if af not in wfds:
          wfds += [af]
          wr += [af]
      except ReaderClose:
        rfds.remove(bf)
        if af not in wfds:
          wfds += [af]
        continue
    if bf in wr:
      if len(amsg) > 0:
        amsg = write_chunk(b,amsg)
      if len(amsg) == 0:
        wfds.remove(bf)
    if af in wr:
      if len(bmsg) > 0:
        bmsg = write_chunk(a,bmsg)
      if len(bmsg) == 0:
        wfds.remove(af)
    # Shut each side down only once. The old code repeated shutdown() on an
    # already-closed socket while the other direction was still draining.
    if af not in rfds and len(amsg) == 0 and not b_closed:
      shutdown(b)
      b_closed = True
      if bf in rfds:
        rfds.remove(bf)
      if bf in wfds:
        wfds.remove(bf)
    if bf not in rfds and len(bmsg) == 0 and not a_closed:
      shutdown(a)
      a_closed = True
      if af in rfds:
        rfds.remove(af)
      if af in wfds:
        wfds.remove(af)

#==============================================================================
# Forward content of a particular length from a to b.
#==============================================================================
def forward_body(a, b, blen):
  """Copy exactly `blen` more request-body bytes from `a` to `b`.

  Both sides are blocking while copying. If `a` closes early, the copy stops
  quietly. Afterwards `a` is polled without blocking: data beyond the
  declared length is fatal(). A close (or half-close) by the client at that
  point is accepted, so the reply can still be relayed.
  """
  orig_blen = blen
  setblocking(a, True)
  setblocking(b, True)

  # Read the stream for the precise number of bytes, then check for close.
  while blen > 0:
    try:
      chunk = read_chunk(a,min(blen,4096))
    except ReaderClose:
      return
    blen -= len(chunk)
    while len(chunk) > 0:
      chunk = write_chunk(b, chunk)

  setblocking(a, False)
  try:
    chunk = read_chunk(a)
  except ReaderClose:
    # Previously this escaped uncaught and aborted the whole request.
    chunk = ''
  if chunk:
    fatal('Read %d excess bytes beyond content length of %d' % (len(chunk),orig_blen))
  setblocking(a, True)

#==============================================================================
# Add an entry into the view table for an active WebSocket connection.
#==============================================================================
def add_viewer(db, sessnum, username, host):
  """Record an active WebSocket viewer and pick a heartbeat interval.

  Reads the session's timeout, inserts a row into `view` (start and heartbeat
  set to NOW()), and returns (view_id, timeout // 4). The second value is used
  as the SIGALRM heartbeat period.
  """
  timeout = db.getsingle("""SELECT timeout FROM session WHERE sessnum=%s""",
                         sessnum)
  if timeout is None:
    # This lookup happens before the insert; the old message said "after".
    fatal('Unable to get session timeout')
  db.c.execute(
    """INSERT INTO view(sessnum,username,remoteip,start,heartbeat)
       VALUES(%s, %s, %s, NOW(), NOW())""", (sessnum, username, host))
  view_str = db.getsingle("""SELECT last_insert_id()""", ())
  if view_str is None:
    fatal("Unable to get view ID after insert")
  # Explicit floor division: same as Python 2's int/int, and signal.alarm()
  # always gets an integer.
  return int(view_str), int(timeout) // 4

#==============================================================================
# Remove the view entry when the WebSocket is closed.
#==============================================================================
def delete_viewer(viewid):
  """Move view entry `viewid` into viewlog (with its duration), then delete it.

  A viewid of 0 means there is no active view, and nothing is done. The view
  id is passed as a query parameter, and the database connection is closed
  even if a statement fails.
  """
  if viewid == 0:
    return
  db = MW_DB(mysql_host,mysql_user,mysql_password,mysql_db, MYSQL_CONN_ATTEMPTS)
  try:
    db.c.execute(
      """INSERT INTO viewlog (sessnum,username,remoteip,time,duration)
         SELECT view.sessnum, view.username, view.remoteip, view.start,
                TIMESTAMPDIFF(SECOND, start, NOW())
         FROM view WHERE view.viewid = %s""", (viewid,))
    db.c.execute("""DELETE FROM view WHERE viewid=%s""", (viewid,))
  finally:
    db.close()

#==============================================================================
# This function is invoked by a signal and updates the view table entry to
# indicate that a connection is still in progress.
#==============================================================================
# View id of this process's active WebSocket connection (0 = none).
heartbeat_viewid = 0
def heartbeat_handler(signo,extra):
  """SIGALRM handler: refresh the view's heartbeat and re-arm the alarm.

  handle_connection() installs it for WebSocket connections. The view id is
  passed as a query parameter, and the database connection is always closed.
  """
  db = MW_DB(mysql_host,mysql_user,mysql_password,mysql_db, MYSQL_CONN_ATTEMPTS)
  try:
    db.c.execute("""UPDATE view SET heartbeat=NOW() WHERE viewid=%s""",
                 (heartbeat_viewid,))
    signal.alarm(heartbeat_interval)
  finally:
    db.close()

#==============================================================================
# Don't let the connection read too long without a full header.
#==============================================================================
# Reset here too; read_header() stores the first request line in it.
debug_header=''
def timeout_handler(signo,extra):
  """SIGALRM handler while waiting for a header: abort via fatal().

  The message includes the first header line seen so far (debug_header).
  """
  fatal('Timeout waiting for header "%s"' % str(debug_header))

#==============================================================================
# Handle a new connection.
#==============================================================================
def handle_connection(ns, cache_writer, cache, remoteip, remoteport):
  """Serve one client connection inside a child process.

  Steps:
  - Read the request header under a HEADER_TIMEOUT alarm.
  - Fix up the client address headers.
  - Open the forwarded connection with forward().
  - For a WebSocket upgrade, record the viewer, keep it alive with SIGALRM
    heartbeats, and relay both ways.
  - Otherwise, forward the request body, rewrite the reply with
    'Connection: close', and relay the rest.
  `remoteport` is currently unused.
  """
  global heartbeat_viewid
  global heartbeat_interval
  signal.signal(signal.SIGALRM, timeout_handler)
  signal.alarm(HEADER_TIMEOUT)
  hdr,body = read_header(ns)
  signal.alarm(0)

  if len(hdr) < 1 or hdr[0] == '':
    fatal('Empty HTTP header')

  if FRONT_LISTEN_SSL:
    # The proxy is the primary front-end web interface.  Create the
    # X-Real-IP and X-Forwarded-For header entries.  Replace others.
    header_set(hdr, 'X-Forwarded-For', remoteip)
    header_set(hdr, 'X-Forwarded-Port', str(FRONT_LISTEN_PORT))
    header_set(hdr, 'X-Forwarded-Proto', "https")
    header_set(hdr, 'X-Real-IP', remoteip)
    header_set(hdr, 'X-Scheme', "https")

  else:
    # If not using SSL, assume that the connection is forwarded from
    # Apache or some other kind of front-facing web server that adds an
    # X-Real-IP or X-Forwarded-For entry to the HTTP header to indicate
    # the true origin of the connection.
    idx = header_find(hdr, 'X-Real-IP')
    if idx == -1:
      idx = header_find(hdr, 'X-Forwarded-For')
    if idx != -1:
      # Split on the first ':' only, so IPv6 addresses stay intact.
      arr = hdr[idx].split(':', 1)
      if len(arr) != 2:
        log('Malformed X-Real-IP line: ' + hdr[idx])
      else:
        remoteip = arr[1].strip()

  # Set up the forwarded connection to the execution host proxy.
  newconn,sessnum,username = forward(hdr, body, cache_writer, cache, remoteip)
  if not newconn:
    # TODO: issue a 404 response or something.
    shutdown(ns)
    return

  if header_find(hdr, 'Upgrade: websocket') != -1:
    # The connection is upgraded to a websocket.  Record the duration of
    # the connection in the 'view' table of the database.
    db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
    heartbeat_viewid,heartbeat_interval = add_viewer(db,sessnum,username,remoteip)
    db.close()
    signal.signal(signal.SIGALRM, heartbeat_handler)
    signal.alarm(heartbeat_interval)
    try:
      bidirectional_copy(ns, newconn)
    except SystemExit:
      pass
    finally:
      # Runs for unexpected exceptions too, so the heartbeat stops and the
      # view entry is logged and removed in every case.
      signal.alarm(0)
      delete_viewer(heartbeat_viewid)
      heartbeat_viewid = 0
  else:
    # The connection is a standard transient HTTP request.  Forward the body
    # of the request, then get the header for the reply.  Ensure that the
    # 'Connection: close' entry is set in the header so that the browser
    # knows it cannot reuse the connection for request pipelining.
    bodylen = header_content_length(hdr)
    if bodylen < 0:
      bodylen = 0
    if bodylen > 0:
      forward_body(ns, newconn, bodylen - len(body))
    hdr,body = read_header(newconn)
    header_set(hdr, 'Connection', 'close')
    send_header(ns, hdr, body)
    bidirectional_copy(newconn, ns)


#==============================================================================
# Accept a new connection, handle it as a child process, incorporate the
# child's database use into a cache.
#==============================================================================
def main_loop(ls):
  """Accept connections forever, serving each in a double-forked child.

  The grandchild process runs handle_connection(). It reports the session
  record back over a pipe, which the parent reads non-blockingly into a
  session cache that is flushed from time to time. Transient accept(),
  pipe() and fork() failures are logged and the connection is dropped; the
  server keeps running.
  """
  global PIDFILE
  session_cache = {}
  rlist = []
  while True:
    try:
      ns,addr = ls.accept()
    except (OSError, socket.error) as e:
      # accept() can be interrupted by a signal or fail transiently
      # (ECONNABORTED, EMFILE, ...). On Python 2 these errors are
      # socket.error, which is not an OSError. Catching only OSError
      # would let them terminate the whole proxy.
      err = getattr(e, 'errno', None)
      if err not in (errno.EINTR, errno.EAGAIN, errno.ECONNABORTED):
        log('accept() failed: %s' % str(e))
        # Brief back-off so a persistent error (e.g. EMFILE) cannot spin.
        time.sleep(0.1)
      continue
    t = time.time()
    read_cache_pipes(rlist, session_cache, t)
    try:
      r,w = os.pipe()
    except OSError as e:
      log('Unable to create cache pipe: %s' % str(e))
      ns.close()
      continue
    try:
      pid = os.fork()
    except OSError as e:
      log('Unable to fork for %s:%s: %s' % (addr[0], addr[1], str(e)))
      os.close(r)
      os.close(w)
      ns.close()
      continue
    if pid == 0:
      PIDFILE = None # Don't let a child delete the pidfile.
      ls.close() # Close the listener in the child
      os.close(r)
      if os.fork() == 0:
        os.setsid()
        if verbosity > 1:
          log('Connect from ' + addr[0] + ':' + str(addr[1]))
        try:
          handle_connection(ns, w, session_cache, addr[0], addr[1])
        except SystemExit:
          pass
        except:
          log_exception("Exception in child:")
        # If handle_connection() failed because of an exception, try to do
        # a delete_viewer() here.  This one will not be subject to exception
        # notification, so we won't catch any bugs with this call.  Instead,
        # we rely on the call inside handle_connection() to happen in most
        # circumstances and notify us if it fails.
        delete_viewer(heartbeat_viewid)
        if verbosity > 1:
          log('Closed ' + addr[0] + ':' + str(addr[1]))
      os._exit(0)
    else:
      # Add the pipe reader to a list of other pipe readers.
      # Make it non-blocking so that a read on a pipe that is not ready
      # will result in EAGAIN.  This allows us to try reading all of the
      # pipes instead of having to select() on them.
      rlist += [r]
      os.close(w)
      setblocking(r, False)
      try:
        # Always do two waits.  The second one should normally issue an
        # ECHILD because there will always be exactly one child to wait for
        # at this point.  However, if the first wait() is interrupted by a
        # signal, we'll be left with an extra child waiting to be reaped.
        # With each signal, we'll have another unreaped child.  So we wait
        # twice to ensure we have no backlog.
        os.wait()
        os.wait()
      except:
        pass

    flush_cache(session_cache, t)

    ns.close()

#==============================================================================
# Run the process in the background.
#==============================================================================
def daemonize():
  """Detach into the background.

  Changes to '/', forks, and lets the parent exit immediately. The child
  continues in a new session created with setsid().
  """
  os.chdir('/')
  if os.fork() != 0:
    os._exit(0)
  os.setsid()

#==============================================================================
# Create the PID file.
#==============================================================================
def write_pidfile(name):
  """Write this process's pid, newline-terminated, to the file `name`.

  A `with` block ensures the file is closed even if the write fails.
  """
  with open(name, 'w') as f:
    f.write(str(os.getpid()) + '\n')

#==============================================================================
# Remove the PID file.
#==============================================================================
def remove_pidfile(name):
  """Delete the pid file `name`, regaining effective root first. Best effort.

  A name of None means there is nothing to remove; main_loop() sets PIDFILE
  to None in children. OS failures are logged, not raised, because this runs
  during shutdown.
  """
  if name is None:
    return
  try:
    # main() only changes the effective uid, so seteuid(0) can restore root
    # to delete a file written before privileges were dropped.
    os.seteuid(0)
    os.unlink(name)
  except (OSError, IOError) as e:
    log('Unable to remove %s: %s' % (name, str(e)))

#==============================================================================
# Clean up on termination.
#==============================================================================
def termination_handler(signo,extra):
  """Handle SIGINT/SIGTERM. main() also calls it with signo 0 at the end.

  Logs the signal (if any), removes any active view entry and the pid file,
  and exits with status 1.
  """
  if signo:
    log('Terminating on signal %d' % signo)
  delete_viewer(heartbeat_viewid)
  remove_pidfile(PIDFILE)
  # The proxy is never expected to stop, so any termination is a failure.
  os._exit(1)

#==============================================================================
# The main function creates the main listening socket, sets up logging and
# daemonization, creates the pid file, invokes main_loop(), and handles
# top-level exceptions.
#==============================================================================
def main():
  """Start the proxy.

  Steps, in order:
  - Bind the listener while still root.
  - Install the termination handlers and parse -f (foreground) and -v
    (more verbose).
  - Unless in the foreground, detach from the terminal and daemonize.
  - Write the pid file.
  - Drop the effective gid/uid.
  - Start logging and run main_loop() until it fails.
  """
  global verbosity
  import syslog
  # we're root so write to syslog until we change uid instead of writing as root in a directory owned by a different user
  syslog.syslog("front-proxy starting up")

  # Create the listening socket before backgrounding or writing the pidfile.
  try:
    ls = create_listener(FRONT_LISTEN_HOST, FRONT_LISTEN_PORT)
  except Exception as e:
    # Format the port with %s: it may be None when unconfigured, and the old
    # %d raised a TypeError here that hid the real reason for the failure.
    where = "%s:%s" % (FRONT_LISTEN_HOST, FRONT_LISTEN_PORT)
    syslog.syslog("FATAL: Unable to listen to " + where)
    syslog.syslog("FATAL: Exception is " + str(e))
    print_exception("FATAL: Unable to listen to " + where)
    print("FATAL: Exception is " + str(e))
    os._exit(1)
  signal.signal(signal.SIGINT, termination_handler)
  signal.signal(signal.SIGTERM, termination_handler)
  foreground = False
  for arg in sys.argv:
    if arg == '-f':
      foreground = True
    if arg == '-v':
      verbosity += 1
  if not foreground:
    try:
      fd = os.open("/dev/null", os.O_RDWR)
      os.dup2(fd, 0)
      os.dup2(fd, 1)
      os.dup2(fd, 2)
      os.close(fd)
    except Exception as e:
      syslog.syslog("FATAL: Unable to use /dev/null for input/output")
      syslog.syslog("FATAL: Exception is " + str(e))
      print("FATAL: Unable to use /dev/null for input/output")
      os._exit(1)
    daemonize()
  write_pidfile(PIDFILE)

  # Only the effective ids change, so remove_pidfile() can later seteuid(0).
  # On failure, remove the pid file just written before giving up.
  try:
    import grp
    os.setegid(grp.getgrnam(RUN_GID).gr_gid)
  except (ImportError, KeyError, OSError):
    syslog.syslog('FATAL: Unable to set gid to %s' % RUN_GID)
    remove_pidfile(PIDFILE)
    fatal('Unable to set gid to %s' % RUN_GID)

  try:
    import pwd
    os.seteuid(pwd.getpwnam(RUN_UID).pw_uid)
  except (ImportError, KeyError, OSError):
    syslog.syslog('FATAL: Unable to set uid to %s' % RUN_UID)
    remove_pidfile(PIDFILE)
    fatal('Unable to set uid to %s' % RUN_UID)

  setup_log(PROXY_LOG, None)
  log('Starting up')
  if verbosity > 0:
    log('Verbosity level %d' % verbosity)

  # Invoke the handler loop.  Catch and print any unusual exceptions.
  try:
    main_loop(ls)
  except SystemExit: # Something called fatal()
    pass
  except:
    log_exception("Exception in server:")

  # Once it's done, clean up.  A zero indicates no signal.
  termination_handler(0,0)

#==============================================================================
# Invoke the main function when run as a script (not when imported).
#==============================================================================
if __name__ == '__main__':
  main()

