# @package      hubzero-mw2-common
# @file         zone.py
# @author       Pascal Meunier <pmeunier@purdue.edu>
# @copyright    Copyright (c) 2016-2017 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
#
# Copyright (c) 2016-2017 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

"""
Zones are an experimental feature for running tool containers at
various locations around the world to reduce latency.  They require
the zones MySQL table, whose schema is:

mysql> describe zones;
+--------------+--------------+------+-----+---------+----------------+
| Field        | Type         | Null | Key | Default | Extra          |
+--------------+--------------+------+-----+---------+----------------+
| id           | int(11)      | NO   | PRI | NULL    | auto_increment |
| zone         | varchar(40)  | YES  |     | NULL    |                |
| title        | varchar(255) | YES  |     | NULL    |                |
| state        | varchar(15)  | YES  |     | NULL    |                |
| type         | varchar(10)  | YES  |     | NULL    |                |
| master       | varchar(255) | YES  |     | NULL    |                |
| mw_version   | varchar(3)   | YES  |     | NULL    |                |
| ssh_key_path | varchar(200) | YES  |     | NULL    |                |
| picture      | varchar(250) | YES  |     | NULL    |                |
+--------------+--------------+------+-----+---------+----------------+

"""

from constants import VERBOSE
from errors import MaxwellError
from log import log
from host import BasicHost
# Debug switch: when True, Zone.is_up() logs the zone state it read and the
# boolean it is about to return.
# NOTE(review): enabled unconditionally here -- confirm this is intended
# outside of testing.
DEBUG = True

class Zone:
  """A site (row of the zones table) where tool sessions can be run.

  Column values (type, state, ssh_key_path) are fetched lazily by the get_*
  methods and cached on the instance; the static constructors fill them in
  directly from a single query.
  """

  # Command run on the zone master by tell() and ask(); the caller's argument
  # list is appended to it.
  _REMOTE_CMD = "/usr/lib/mw/bin/maxwell_remote "

  def __init__(self, name):
    """Create a zone object for the zone called name.

    No database access happens here; the remaining attributes start as None
    and are filled in by the get_* methods or the static constructors.
    """
    self.zone = name          # zones.zone
    self.zone_id = None       # zones.id
    self.type = None          # zones.type, expected 'local' or 'remote'
    self.host = None          # BasicHost for the zone master, built by get_host()
    self.ssh_key_path = None  # zones.ssh_key_path, used to reach the master
    self.state = None         # zones.state, e.g. 'up'

  def get_type(self, db):
    """Return the zone type ('local' or 'remote'), querying zones if not cached.

    Raises MaxwellError if the zone has no type in the database.
    """
    if self.type is not None:
      return self.type
    self.type = db.getsingle("""
      SELECT type FROM zones
      WHERE zone=%s""", self.zone
      )
    if self.type is None:
      raise MaxwellError("No type for zone %s" % self.zone)
    return self.type

  def is_local(self, db):
    """Return True if the zone type is 'local'."""
    return self.get_type(db) == 'local'

  def is_remote(self, db):
    """Return True if the zone type is anything other than 'local'."""
    return not self.is_local(db)

  def get_ssh_key_path(self, db):
    """Return the SSH key path used to reach the remote master (cached).

    Raises MaxwellError if the zone has no ssh_key_path in the database.
    """
    if self.ssh_key_path is not None:
      return self.ssh_key_path
    self.ssh_key_path = db.getsingle("""
      SELECT ssh_key_path FROM zones
      WHERE zone=%s""", self.zone
      )
    if self.ssh_key_path is None:
      raise MaxwellError("No ssh_key_path for zone %s" % self.zone)
    return self.ssh_key_path

  def get_host(self, db):
    """Return a BasicHost for the zone master, building and caching it on first use.

    The master hostname is zones.master for this zone_id; the connection uses
    user 'www-data' and the key from get_ssh_key_path().

    Raises MaxwellError if zone_id is not set or the zone has no master.
    """
    if self.host is not None:
      return self.host
    if self.zone_id is None:
      # A None id cannot identify a row; fail before issuing a useless query.
      raise MaxwellError("zone_id is None")
    mr_name = db.getsingle("SELECT master FROM zones WHERE id = %s", str(self.zone_id))
    if mr_name is None:
      raise MaxwellError("No command host for zone %s" % (self.zone))
    self.host = BasicHost(mr_name, 'www-data', self.get_ssh_key_path(db))
    if VERBOSE:
      log("sending tool session to %s" % mr_name)
    return self.host

  def tell(self, db, cmd):
    """Run maxwell_remote with the argument list cmd on the zone master.

    Raises MaxwellError if the remote command returns a non-zero code.
    """
    remote_cmd = [self._REMOTE_CMD] + cmd
    if VERBOSE:
      log("sending command '%s'" % remote_cmd)
    code = self.get_host(db).ssh(remote_cmd)
    if code != 0:
      raise MaxwellError("Error in tell zone")

  def ask(self, db, cmd):
    """Run maxwell_remote with the argument list cmd on the zone master.

    Returns whatever BasicHost.ask_ssh returns (the remote answer).
    """
    remote_cmd = [self._REMOTE_CMD] + cmd
    if VERBOSE:
      log("asking command '%s'" % remote_cmd)
    return self.get_host(db).ask_ssh(remote_cmd)

  def get_state(self, db):
    """Return the zone state, querying zones if not cached.

    Raises MaxwellError if the zone has no state in the database.
    """
    if self.state is not None:
      return self.state
    self.state = db.getsingle("""
      SELECT state FROM zones
      WHERE zone=%s""", self.zone
      )
    if self.state is None:
      raise MaxwellError("No state for zone %s" % self.zone)
    return self.state

  def is_up(self, db):
    """Return True if the zone state is 'up' (logged when DEBUG is set)."""
    state = self.get_state(db)
    up = state == 'up'
    if DEBUG:
      log("checking state of %s returned %s" % (self.zone, state))
      if up:
        log("returning True")
      else:
        log("returning False")
    return up

  def supports(self, db, app):
    """Return True if some 'up' host in this zone provides every bit of app.hostreq.

    app.hostreq is compared as a bitmask against host.provisions.
    NOTE(review): zone_id must be set; %d raises TypeError on None.
    """
    # note that library can't be used to escape app.hostreq because it will add quotes!
    # %d only accepts numbers, so nothing but integers reaches the SQL text.
    row = db.getrow(
      """SELECT count(*) FROM host
         WHERE zone_id = %d
         AND host.provisions & %d = %d
         AND host.status = 'up'""" % (self.zone_id, app.hostreq, app.hostreq), ())
    # count(*) always produces one row, so the count itself must be tested.
    if row is None:
      return False
    return int(row[0]) > 0

  @staticmethod
  def get_default(db):
    """Return the first 'up' local zone.

    Raises MaxwellError if there is none.
    """
    zone_row = db.getrow(
      "SELECT zone, id FROM zones WHERE type = 'local' AND state = 'up'", ())
    if zone_row is None:
      raise MaxwellError("Error: No local zone available.")
    z = Zone(zone_row[0])
    z.type = 'local'
    z.zone_id = int(zone_row[1])
    return z

  @staticmethod
  def get_zone_by_name(db, name):
    """Return the 'up' zone called name, with id, type and ssh key filled in.

    Raises MaxwellError if no such zone exists or it is not up.
    """
    # NOTE(review): (name) is a bare value, not a 1-tuple; this matches how
    # getsingle is called elsewhere in this class -- confirm getrow accepts it.
    row = db.getrow(
      "SELECT id, type, ssh_key_path FROM zones WHERE zone = %s AND state = 'up'", (name))
    if row is None:
      raise MaxwellError("Error: No zone with name '%s', or zone is down." % name)
    z = Zone(name)
    z.zone_id = int(row[0])
    z.type = row[1]
    z.state = 'up'
    z.ssh_key_path = row[2]
    return z


  @staticmethod
  def get_zone_by_id(db, zone_id):
    """Return the zone with the given id, with name, type, state and ssh key filled in.

    Raises MaxwellError if no zone has that id.
    """
    zone_id = int(zone_id)
    row = db.getrow(
      "SELECT zone, type, state, ssh_key_path FROM zones WHERE id = %s", (zone_id))
    if row is None:
      raise MaxwellError("Error: No zone with id '%s'" % zone_id)
    z = Zone(row[0])
    z.zone_id = zone_id
    z.type = row[1]
    z.state = row[2]
    z.ssh_key_path = row[3]
    return z

  @staticmethod
  def get_zone_by_master(db, master):
    """Return the zone whose master host is master, with its columns filled in.

    Raises MaxwellError if no zone has that master.
    """
    row = db.getrow(
      "SELECT zone, id, type, state, ssh_key_path FROM zones WHERE master = %s", (master))
    if row is None:
      raise MaxwellError("Error: No zone with master '%s'" % master)
    z = Zone(row[0])
    z.zone_id = int(row[1])
    z.type = row[2]
    z.state = row[3]
    z.ssh_key_path = row[4]
    return z

  @staticmethod
  def find(db, app):
    """Find an 'up' zone with an 'up' host providing app.hostreq, or None.

    Candidate hosts must be the zone master or belong to a local zone.
    Results are ordered by zones.type ('local' sorts before 'remote'),
    then by the uses column ascending.
    """
    # note that library can't be used to escape app.hostreq because it will add quotes!
    row = db.getrow(
      """SELECT zone, zones.id, zones.type, zones.ssh_key_path
         FROM zones JOIN host ON zones.id = host.zone_id
         WHERE host.provisions & %d = %d
         AND host.status = 'up' AND zones.state = 'up'
         AND (zones.master = host.hostname OR zones.type = 'local')
         ORDER BY zones.type, uses ASC LIMIT 1""" % (app.hostreq, app.hostreq), ())
    if row is None:
      return None
    z = Zone(row[0])
    z.zone_id = int(row[1])
    z.type = row[2]
    z.state = 'up'
    z.ssh_key_path = row[3]
    return z

