/**
 * @file       path.c
 * @copyright  Copyright (c) 2016-2020 The Regents of the University of California.
 * @license    http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
 *
 * Copyright (c) 2016-2020 The Regents of the University of California.
 *
 * This file is part of: The HUBzero(R) Platform for Scientific Collaboration
 *
 * The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
 * software: you can redistribute it and/or modify it under the terms of
 * the GNU Lesser General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any
 * later version.
 *
 * HUBzero is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 * HUBzero is a registered trademark of The Regents of the University of California.
 */

#include <stddef.h>
#include <malloc.h>
#include <string.h>
#include <limits.h> // for PATH_MAX
#include <stdlib.h>
#include <errno.h>
#include <fuse.h>
#include <pwd.h> // struct passwd
#include <grp.h> // struct group
#include <unistd.h>

#include "globals.h"
#include "path.h"



/* Some basic helper functions
 */
static void dump_path(path_t *path) {
  /* Emit one debug() line per string field that is set; a NULL path or a
   * NULL field is silently skipped. */
  if (path != NULL) {
    if (path->full_path != NULL)
      debug("full_path: %s", path->full_path);
    if (path->file_name != NULL)
      debug("file_name: %s", path->file_name);
    if (path->repo_name != NULL)
      debug("repo_name: %s", path->repo_name);
    if (path->file_path != NULL)
      debug("file_path: %s", path->file_path);
  }
}

/* Release a path_t (as allocated by normalize_path() or dup_path()) together
 * with every string it owns. Passing NULL is a no-op.
 */
void free_path(path_t *path) {
  if (path == NULL) return;
  // free(NULL) is a no-op, so fields that were never set need no guard
  free(path->full_path);
  free(path->file_name);
  free(path->repo_name);
  free(path->file_path);
  free(path);
}

/* Release a single path_list_t node together with the path_t it owns (and,
 * when built with SHADOW_WRITES, its tmp_path string). Only this node is
 * freed; p->next is not followed. Passing NULL is a no-op.
 */
void free_path_list(path_list_t *p) {
  if (p == NULL) return;
  free_path(p->path);   // free_path() already accepts NULL
#ifdef SHADOW_WRITES
  free(p->tmp_path);    // free(NULL) is a no-op
#endif
  free(p);
}

/* strdup() that passes NULL through instead of dereferencing it (strdup(NULL)
 * is undefined behaviour). */
static char *dup_str_or_null(const char *s) {
  return (s != NULL) ? strdup(s) : NULL;
}

/* Return a deep copy of `path` with every string field duplicated, or NULL if
 * `path` is NULL or memory runs out. Fields that are NULL in the source stay
 * NULL in the copy; normalize_path() leaves file_name/file_path unset when
 * the path has no component below the repository. Release the result with
 * free_path().
 */
path_t *dup_path(path_t *path) {
  path_t *p;

  if (path == NULL) return NULL;

  // calloc so that any field not copied below starts out zeroed
  p = calloc(1, sizeof(*p));
  if (p == NULL) return NULL;

  p->full_path = dup_str_or_null(path->full_path);
  p->file_name = dup_str_or_null(path->file_name);
  p->repo_name = dup_str_or_null(path->repo_name);
  p->file_path = dup_str_or_null(path->file_path);

  // A NULL copy of a non-NULL field means strdup() ran out of memory
  if ((path->full_path != NULL && p->full_path == NULL) ||
      (path->file_name != NULL && p->file_name == NULL) ||
      (path->repo_name != NULL && p->repo_name == NULL) ||
      (path->file_path != NULL && p->file_path == NULL)) {
    free_path(p);
    return NULL;
  }
  return p;
}
 
/* Log every entry of the global open_files list via debug(): its path fields,
 * its tmp_path (only when built with SHADOW_WRITES), its uid and its fd.
 */
void dump_open_files(void) {
  debug("--- open_files ---");
  for (path_list_t *of = open_files; of != NULL; of = of->next) {
    dump_path(of->path);
#ifdef SHADOW_WRITES
    if (of->tmp_path)  debug("tmp_path: %s", of->tmp_path);
#endif
    debug("uid: %d", of->uid);
    debug("fd: %d", of->fd);
  }
  debug("------");
}

/* Return nonzero when `full` names `root` itself or something beneath it.
 * A bare prefix match is not enough: "/srv/data2" starts with "/srv/data"
 * but is not inside it, so the character following the prefix must end the
 * string or be a '/' (unless `root` itself already ends in '/').
 */
static int path_is_within(const char *full, const char *root) {
  size_t len = strlen(root);

  if (strncmp(full, root, len) != 0) return 0;
  if (full[len] == '\0' || full[len] == '/') return 1;
  return len > 0 && root[len - 1] == '/';
}

/* Parse a FUSE path of the form "/<repo>/<sub/path>" into a newly allocated
 * path_t. The path is resolved against source_dir, and the result is checked
 * to still lie inside source_dir.
 *
 * The first component becomes repo_name. The remainder maps to
 * "<source_dir>/<repo>/<remainder>", or "<source_dir>/<repo>/files/<remainder>"
 * when built with USE_GIT. A path with no repository component maps to
 * source_dir itself.
 *
 * If the target does not exist yet (e.g. a file about to be created), only
 * its parent must resolve; the final component is appended to the resolved
 * parent.
 *
 * On success full_path and repo_name are set. file_path (the directory part
 * of the remainder, keeping its trailing '/') and file_name (the last
 * component) are set only when the remainder is non-empty.
 *
 * Returns NULL if the path is too long, cannot be resolved, escapes
 * source_dir, or memory runs out. The caller releases the result with
 * free_path().
 */
path_t *normalize_path(const char *path) {
  char start_path[PATH_MAX] = {0};
  char resolved_path[PATH_MAX] = {0};
  char final_path[PATH_MAX] = {0};
  char repo_name[PATH_MAX];
  char file_path[PATH_MAX];
  const char *full = resolved_path;
  path_t *rpath = NULL;
  size_t path_len;
  size_t off = 0;
  size_t i;
  int repo = 1;
  int n;

  if (path == NULL) return NULL;

  debug("normalize_path(%s)", path);

  /* Both halves of `path` are copied into PATH_MAX buffers below, so reject
   * anything that could not fit before copying a single byte. */
  path_len = strlen(path);
  if (path_len >= PATH_MAX) {
    debug("normalize_path(): path too long");
    return NULL;
  }

  /* Split at the first '/' after the optional leading one. Everything before
   * it is the repository name. Everything after it is the path inside the
   * repository, which may contain further '/'s for subdirectories (which is
   * why strtok() is not used). */
  i = (path[0] == '/') ? 1 : 0;
  for (; i < path_len; i++) {
    if (repo) {
      if (path[i] == '/') {
        repo_name[off] = '\0';
        repo = 0;
        off = 0;
      }
      else {
        repo_name[off++] = path[i];
      }
    }
    else {
      file_path[off++] = path[i];
    }
  }

  if (repo) {
    // No file
    repo_name[off] = '\0';
    file_path[0] = '\0';
  }
  else {
    file_path[off] = '\0';
  }

  // All paths should be relative to the repository root
  if (repo_name[0] != '\0') {
#ifdef USE_GIT
    n = snprintf(start_path, sizeof(start_path), "%s/%s/files/%s", source_dir,
                 repo_name, file_path);
#else
    n = snprintf(start_path, sizeof(start_path), "%s/%s/%s", source_dir,
                 repo_name, file_path);
#endif
  }
  else {
    n = snprintf(start_path, sizeof(start_path), "%s", source_dir);
  }
  /* A truncated start_path names a different (shorter) path, so refuse it
   * rather than resolve and hand out the wrong file. */
  if (n < 0 || (size_t)n >= sizeof(start_path)) {
    debug("normalize_path(): path too long");
    return NULL;
  }

  debug("normalize_path(): start_path = %s", start_path);

  if (realpath(start_path, resolved_path) == NULL) {
    /* realpath failed, but it's possible that it was because this is a new
     * file or directory. Remove the trailing component and validate the
     * remainder of the path instead.
     */
    const char *slash;
    size_t dir_len;
    size_t final_len;
    size_t tail_len;

    debug("normalize_path(): file doesn't exist. Removing filename");
    slash = strrchr(start_path, '/');
    if (slash == NULL || slash == start_path) {
      // Nothing would be left to validate once the last component is removed
      return NULL;
    }
    dir_len = (size_t)(slash - start_path);
    memcpy(resolved_path, start_path, dir_len);
    resolved_path[dir_len] = '\0';

    debug("normalize_path(): new resolved_path: %s", resolved_path);
    if (realpath(resolved_path, final_path) == NULL) {
      // The path leading to the hypothetical file is also invalid, fail.
      return NULL;
    }

    // Tack "/<filename>" onto the end of the fully qualified parent path
    final_len = strlen(final_path);
    tail_len = strlen(slash);
    if (final_len + tail_len >= sizeof(final_path)) {
      debug("normalize_path(): path too long");
      return NULL;
    }
    memcpy(final_path + final_len, slash, tail_len + 1);
    full = final_path;
  }

  // This is a valid path, begin building the path_t structure
  rpath = calloc(1, sizeof(*rpath));
  if (rpath == NULL)
    return NULL;
  rpath->full_path = strdup(full);
  if (rpath->full_path == NULL)
    goto fail;

  /* Make sure the resolved path (symlinks and ".." included) is still inside
   * source_dir, on a path-component boundary. */
  if (!path_is_within(rpath->full_path, source_dir)) {
    error("normalize_path(): '%s' outside of repository scope!",
          rpath->full_path);
    goto fail;
  }

  rpath->repo_name = strdup(repo_name);
  if (rpath->repo_name == NULL)
    goto fail;

  if (file_path[0] != '\0') {
    char *name;

    rpath->file_path = strdup(file_path);
    if (rpath->file_path == NULL)
      goto fail;

    // Split "dir/sub/name" into file_path "dir/sub/" and file_name "name"
    name = strrchr(rpath->file_path, '/');
    name = (name != NULL) ? name + 1 : rpath->file_path;
    rpath->file_name = strdup(name);
    if (rpath->file_name == NULL)
      goto fail;
    *name = '\0';
  }

  dump_path(rpath);
  return rpath;

fail:
  free_path(rpath);
  return NULL;
}

/* Determine whether the uid of the current FUSE request
 * (fuse_get_context()->uid) may access `path`.
 *
 * Returns ACC_OK when any of the following holds:
 *   - the same uid already has this resolved path in open_files;
 *   - the path resolves to source_dir itself (everyone may list the root);
 *   - the user belongs to the group named REPO_PREFIX + repo_name.
 *
 * Returns ACC_BLACKLISTED when the resolved path contains any black_list
 * entry (checked after the two short-circuits above). Returns ACC_NONE
 * otherwise, including when the path cannot be normalized.
 *
 * normal_path: if non-NULL and ACC_OK is returned, receives the normalized
 *   path_t, which the caller must release with free_path(). On
 *   ACC_BLACKLISTED it is filled in only when free_on_blacklist == BL_NOFREE.
 *   In every other case the normalized path is freed here.
 */
int can_access(const char *path, path_t **normal_path, int free_on_blacklist) {
  path_t *resolved_path = NULL;
  char repo_name[PATH_MAX+3] = {0};
  struct fuse_context *fc = fuse_get_context();
  int i;
  int n;

  debug("can_access(%s)", path);

  if (fc == NULL) return ACC_NONE;

  if ((resolved_path = normalize_path(path)) == NULL) {
    return ACC_NONE;
  }

  path_list_t *p = open_files;
  while (p) {
    if (strcmp(p->path->full_path, resolved_path->full_path) == 0 &&
        p->uid == fc->uid) {
      debug("can_access(): found file in open_files, short circuiting");
      // Hand ownership to the caller, or release it if they didn't ask for it
      if (normal_path != NULL) *normal_path = resolved_path;
      else free_path(resolved_path);
      return ACC_OK;
    }
    p = p->next;
  }

  if (strcmp(source_dir, resolved_path->full_path) == 0) {
    // Allow everyone to list the root
    if (normal_path != NULL) *normal_path = resolved_path;
    else free_path(resolved_path);
    debug("can_access(): root allowed");
    return ACC_OK;
  }

  for (i = 0; i < black_list_size; i++) {
    if (strstr(resolved_path->full_path, black_list[i]) != NULL) {
      /* Sometimes the blacklisted file needs to be exposed to the next
       * layer up (eg, xmp_getattr()). free_on_blacklist tells us when,
       * so (most of) the upper layer doesn't have to worry about freeing
       * things.
       */
      if (free_on_blacklist == BL_NOFREE && normal_path != NULL) {
        *normal_path = resolved_path;
      }
      else {
        free_path(resolved_path);
      }
      debug("can_access(): blacklisted!");
      return ACC_BLACKLISTED;
    }
  }

  // Otherwise need to check if they're a member of the right group
  debug("can_access(): get_groups(%d)", fc->uid);

  /* NOTE(review): getpwuid()/getgrnam() return pointers into static storage
   * and are not thread-safe; if this filesystem runs multithreaded, consider
   * getpwuid_r()/getgrnam_r(). */
  struct passwd *passwd = getpwuid(fc->uid);
  if (passwd == NULL) {
    error("can_access(): user '%d' is not known", fc->uid);
    free_path(resolved_path);
    return ACC_NONE;
  }

  if (resolved_path->repo_name) {
    // Bounded build of REPO_PREFIX + repo_name; a truncated group name is refused
    n = snprintf(repo_name, sizeof(repo_name), "%s%s", REPO_PREFIX,
                 resolved_path->repo_name);
    if (n < 0 || (size_t)n >= sizeof(repo_name)) {
      debug("can_access(): group name for '%s' too long",
            resolved_path->repo_name);
      free_path(resolved_path);
      return ACC_NONE;
    }

    struct group *grdata = getgrnam(repo_name);
    if (!grdata) {
      // This used to be error(), but it just turned out to be annoying
      debug("can_access(): cannot find group id for '%s'", repo_name);
      free_path(resolved_path);
      return ACC_NONE;
    }
    // Copy the gid out before anything else can reuse getgrnam()'s buffer
    gid_t repo_gid = grdata->gr_gid;

    gid_t grouplist[NGROUPS_MAX];
    int ngroups = NGROUPS_MAX;

    if (getgrouplist(passwd->pw_name, passwd->pw_gid, grouplist, &ngroups) != -1)
    {
      for (i = 0; i < ngroups; i++) {
        if (grouplist[i] == repo_gid) {
          if (normal_path != NULL) *normal_path = resolved_path;
          else free_path(resolved_path);
          debug("can_access(): allowed");
          return ACC_OK;
        }
      }
    }
  }

  free_path(resolved_path);
  return ACC_NONE;
}

/* Copy the contents of file `src` to `dst`. An existing writable `dst` is
 * truncated and overwritten; otherwise `dst` is created with mode 0640.
 * It assumes that permission has already been determined.
 *
 * Returns 0 on success or a negative errno value on failure. If reading or
 * writing fails part way through the copy, `dst` is unlinked so that a
 * partial copy is not left behind.
 */
int copy_file(const char *src, const char *dst)
{
  char buf[4096];
  off_t cur = 0;
  ssize_t nread;
  int sfd;
  int dfd;
  int saved_errno = 0;

  debug("copy_file(%s, %s)", src, dst);

  if (access(src, R_OK) != 0)
    return -errno;

  sfd = open(src, O_RDONLY);
  if (sfd < 0) {
    // Save errno first: logging and close() may overwrite it
    saved_errno = errno;
    error("copy_file(): Unable to open src '%s'", src);
    return -saved_errno;
  }

  if (access(dst, W_OK) == 0) {
    // dst exists and is writable: empty it and overwrite in place
    if (truncate(dst, 0) != 0) {
      saved_errno = errno;
      error("copy_file(): truncate(%s) failure", dst);
      goto fail_close_src;
    }
    dfd = open(dst, O_WRONLY);
    if (dfd < 0) {
      saved_errno = errno;
      error("copy_file(): open(%s, O_WRONLY) failure", dst);
      goto fail_close_src;
    }
  }
  else {
    dfd = open(dst, O_WRONLY|O_CREAT, 0640);
    if (dfd < 0) {
      saved_errno = errno;
      error("copy_file(): open(%s, O_WRONLY|O_CREAT) failure", dst);
      goto fail_close_src;
    }
  }

  for (;;) {
    nread = pread(sfd, buf, sizeof(buf), cur);
    if (nread == 0)
      break;  // EOF
    if (nread < 0) {
      if (errno == EINTR)
        continue;
      saved_errno = errno;
      error("copy_file(): read failure on '%s'", src);
      goto fail_unlink;
    }

    // pwrite() may legitimately write fewer bytes than asked; finish the chunk
    ssize_t done = 0;
    while (done < nread) {
      ssize_t nwritten = pwrite(dfd, buf + done, (size_t)(nread - done),
                                cur + done);
      if (nwritten < 0 && errno == EINTR)
        continue;
      if (nwritten <= 0) {
        // A zero-byte write would loop forever; report it as an I/O error
        saved_errno = (nwritten < 0) ? errno : EIO;
        error("copy_file(): read write mismatch");
        goto fail_unlink;
      }
      done += nwritten;
    }
    cur += nread;
  }

  debug("copy_file(): wrote %lld bytes", (long long)cur);

  close(sfd);
  // close() can report deferred write errors (e.g. on network filesystems)
  if (close(dfd) != 0) {
    saved_errno = errno;
    error("copy_file(): close(%s) failure", dst);
    return -saved_errno;
  }
  return 0;

fail_unlink:
  close(sfd);
  close(dfd);
  // Whelp: don't leave a partially copied dst behind
  if (unlink(dst) != 0)
    error("copy_file(): unable to unlink '%s'", dst);
  return -saved_errno;

fail_close_src:
  close(sfd);
  return -saved_errno;
}
