/*
 * WSSSH SCP (wsscp) - C Implementation
 * SCP Wrapper with wsssht ProxyCommand support.
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <getopt.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <pthread.h>
#include <sys/select.h>
#include <errno.h>
#include <limits.h>
#include <signal.h>

#include "wsscp.h"

// Global state shared between the SIGINT handler and the main flow.
// sigint_received is set to 1 by the handler; scp_pid holds the PID of the
// SCP child while it is running and -1 otherwise.
volatile sig_atomic_t sigint_received = 0;
pid_t scp_pid = -1;

/*
 * SIGINT handler.
 *
 * Records that an interrupt arrived and, if an SCP child is currently
 * running, asks it to terminate with SIGTERM.
 *
 * errno is saved and restored because kill() may modify it, and the handler
 * can interrupt code that is about to inspect errno (e.g. before perror()).
 */
void sigint_handler(int sig __attribute__((unused))) {
    int saved_errno = errno;
    pid_t child = scp_pid;  // read once so the check and the kill() agree

    sigint_received = 1;

    // If we have an SCP process running, terminate it gracefully
    if (child > 0) {
        kill(child, SIGTERM);
    }

    errno = saved_errno;
}

/*
 * Write the command-line help text to stderr.
 *
 * program_name is interpolated into the usage line and the two example
 * invocations; everything else is fixed text.
 */
void print_usage(const char *program_name) {
    FILE *out = stderr;

    fprintf(out, "Usage: %s [options] source destination\n", program_name);

    // Fixed description, option list and destination format
    fputs("WebSocket SCP Wrapper - SCP through WebSocket tunnels\n\n"
          "Protect the dolls!\n\n"
          "Options:\n"
          "  --clientid CLIENT_ID    Client ID of the registered wssshc endpoint\n"
          "  --wssshd-host HOST      wssshd relay host\n"
          "  --wssshd-port PORT      wssshd relay websocket port (default: 9898)\n"
          "  --debug                 Enable debug output\n"
          "  --tunnel TYPES          Transport types for data channel (comma-separated or 'any', default: any)\n"
          "  --tunnel-control TYPES  Transport types for control channel (comma-separated or 'any', default: any)\n"
          "  --enc ENCODING          Data encoding: hex, base64, or bin\n"
          "  --help                  Show this help\n"
          "\nDestination format:\n"
          "  user@client_id[.wssshd_host]:/remote/path\n"
          "\nExamples:\n",
          out);

    // Examples are the only other lines that depend on the program name
    fprintf(out, "  %s localfile user@myclient:/remote/path\n", program_name);
    fprintf(out, "  %s --wssshd-port 9898 localfile user@myclient.server.com:/remote/path\n", program_name);

    fputs("\nDonations:\n"
          "  BTC: bc1q3zlkpu95amtcltsk85y0eacyzzk29v68tgc5hx\n"
          "  ETH: 0xdA6dAb526515b5cb556d20269207D43fcc760E51\n",
          out);
}

/*
 * Store a heap copy of `value` in *slot, releasing whatever was there before
 * so that a repeated option does not leak the earlier copy.
 *
 * Returns 1 on success, 0 if the copy could not be allocated (in which case
 * *slot is left unchanged).
 */
static int set_option_string(char **slot, const char *value) {
    char *copy = strdup(value);
    if (!copy) {
        perror("strdup");
        return 0;
    }
    free(*slot);
    *slot = copy;
    return 1;
}

/*
 * Parse `text` as a decimal TCP port in the range 1..65535.
 *
 * Returns 1 and writes *port_out on success; returns 0 and leaves *port_out
 * untouched when the text is empty, has trailing characters or is out of range.
 */
static int parse_port_value(const char *text, int *port_out) {
    char *end = NULL;

    errno = 0;
    long value = strtol(text, &end, 10);
    if (errno != 0 || end == text || *end != '\0' || value < 1 || value > 65535) {
        return 0;
    }

    *port_out = (int)value;
    return 1;
}

/*
 * Parse the wsscp-specific command-line options into `config`.
 *
 * String fields of `config` must be NULL or heap-allocated on entry, because
 * a value supplied on the command line replaces (and frees) the previous one.
 * Arguments left over after option parsing are exposed through
 * config->remaining_argc / config->remaining_argv (pointing into argv).
 *
 * Returns 1 on success. Returns 0 when help was requested, an unknown option
 * was seen (usage is printed in both cases), the port value is invalid, or an
 * allocation failed.
 */
int parse_wsscp_args(int argc, char *argv[], wsscp_wrapper_config_t *config) {
    static struct option long_options[] = {
        {"clientid", required_argument, 0, 'c'},
        {"wssshd-host", required_argument, 0, 'H'},
        {"wssshd-port", required_argument, 0, 'P'},
        {"debug", no_argument, 0, 'd'},
        {"tunnel", required_argument, 0, 't'},
        {"tunnel-control", required_argument, 0, 'T'},
        {"enc", required_argument, 0, 'e'},
        {"help", no_argument, 0, 'h'},
        {0, 0, 0, 0}
    };

    int opt;
    int option_index = 0;

    while ((opt = getopt_long(argc, argv, "c:H:P:dt:T:e:h", long_options, &option_index)) != -1) {
        int stored = 1;

        switch (opt) {
            case 'c':
                stored = set_option_string(&config->client_id, optarg);
                break;
            case 'H':
                stored = set_option_string(&config->wssshd_host, optarg);
                break;
            case 'P':
                // Reject garbage instead of silently turning it into port 0
                if (!parse_port_value(optarg, &config->wssshd_port)) {
                    fprintf(stderr, "Error: Invalid --wssshd-port value: %s\n", optarg);
                    return 0;
                }
                config->wssshd_port_explicit = 1;
                break;
            case 'd':
                config->debug = 1;
                break;
            case 't':
                stored = set_option_string(&config->tunnel, optarg);
                break;
            case 'T':
                stored = set_option_string(&config->tunnel_control, optarg);
                break;
            case 'e':
                stored = set_option_string(&config->enc, optarg);
                break;
            case 'h':
            default:
                print_usage(argv[0]);
                return 0;
        }

        if (!stored) {
            return 0;
        }
    }

    // Store remaining arguments (source and destination)
    config->remaining_argc = argc - optind;
    config->remaining_argv = &argv[optind];

    return 1;
}

/*
 * Parse an SCP-style target of the form
 *     [user@]client_id[.wssshd_host][:path]
 * and store the pieces in `config`.
 *
 * - The text after the first ':' becomes config->destination.
 * - An '@' is only treated as the user separator when it appears before that
 *   ':'; an '@' inside the remote path is part of the path.
 * - With a user part present, the host portion is split at its first '.'
 *   into client_id and wssshd_host. Without a user part the whole host
 *   portion becomes client_id.
 *   NOTE(review): the user-less form does not split off a wssshd_host; this
 *   mirrors the existing behaviour - confirm whether that is intended.
 *
 * Fields of `config` that the target supplies are replaced (the previous
 * value is freed, so they must be NULL or heap-allocated on entry); fields
 * the target does not mention are left alone. `config` is only modified when
 * parsing succeeds.
 *
 * Returns 1 on success, 0 if target is NULL, the client id is empty, or an
 * allocation fails.
 */
int parse_target_string(const char *target, wsscp_wrapper_config_t *config) {
    if (!target) return 0;

    char *work = strdup(target);
    if (!work) return 0;

    char *user = NULL;
    char *client_id = NULL;
    char *wssshd_host = NULL;
    char *destination = NULL;
    char *host_part = work;
    char *colon_pos = NULL;
    char *at_pos = NULL;
    int ok = 0;

    // Cut off ":path" first so the searches below only see [user@]host
    colon_pos = strchr(work, ':');
    if (colon_pos) {
        *colon_pos = '\0';
        destination = strdup(colon_pos + 1);
        if (!destination) goto done;
    }

    at_pos = strchr(work, '@');
    if (at_pos) {
        *at_pos = '\0';
        host_part = at_pos + 1;

        user = strdup(work);
        if (!user) goto done;

        // client_id.wssshd_host
        char *dot_pos = strchr(host_part, '.');
        if (dot_pos) {
            *dot_pos = '\0';
            wssshd_host = strdup(dot_pos + 1);
            if (!wssshd_host) goto done;
        }
    }

    // A target such as "user@:/path" has no client to connect to
    if (*host_part == '\0') goto done;

    client_id = strdup(host_part);
    if (!client_id) goto done;

    // Everything parsed and allocated: commit, transferring ownership to config
    if (user) {
        free(config->user);
        config->user = user;
        user = NULL;
    }
    if (destination) {
        free(config->destination);
        config->destination = destination;
        destination = NULL;
    }
    if (wssshd_host) {
        free(config->wssshd_host);
        config->wssshd_host = wssshd_host;
        wssshd_host = NULL;
    }
    free(config->client_id);
    config->client_id = client_id;
    client_id = NULL;
    ok = 1;

done:
    // On failure these still own their allocations; on success they are NULL
    free(user);
    free(client_id);
    free(wssshd_host);
    free(destination);
    free(work);
    return ok;
}

/*
 * Look for an SCP-style "-P <port>" pair among the remaining (non-wsscp)
 * arguments stored in `config`.
 *
 * Returns the first port that parses as a decimal number in 1..65535, or 0
 * when there are fewer than two remaining arguments or no valid "-P" pair
 * is present.
 */
int parse_scp_port_from_args(wsscp_wrapper_config_t *config) {
    if (!config->remaining_argv || config->remaining_argc < 2) {
        return 0;
    }

    // Look for -P option in remaining arguments (SCP uses -P, not -p)
    for (int i = 0; i + 1 < config->remaining_argc; i++) {
        if (strcmp(config->remaining_argv[i], "-P") != 0) {
            continue;
        }

        // Found -P, next argument should be the port
        const char *text = config->remaining_argv[i + 1];
        char *endptr = NULL;

        errno = 0;
        // Keep the value as long until it has been range-checked: narrowing to
        // int first would let e.g. 4294967318 wrap around into a "valid" port.
        long port = strtol(text, &endptr, 10);
        if (errno == 0 && endptr != text && *endptr == '\0' && port > 0 && port <= 65535) {
            return (int)port;
        }
    }

    return 0; // No valid -P option found
}

/*
 * Return 1 if an executable named "wsssht" is reachable through one of the
 * directories listed in the PATH environment variable, 0 otherwise.
 *
 * The directories are probed directly with access(), so no shell and no
 * external `which` binary are needed.
 */
static int wsssht_on_search_path(void) {
    const char *entry = getenv("PATH");
    if (!entry || *entry == '\0') {
        return 0;
    }

    for (;;) {
        size_t dir_len = strcspn(entry, ":");
        char candidate[PATH_MAX];

        if (dir_len < sizeof(candidate)) {
            int written;
            if (dir_len == 0) {
                // An empty PATH entry means the current directory
                written = snprintf(candidate, sizeof(candidate), "./wsssht");
            } else {
                written = snprintf(candidate, sizeof(candidate), "%.*s/wsssht", (int)dir_len, entry);
            }
            if (written > 0 && (size_t)written < sizeof(candidate) && access(candidate, X_OK) == 0) {
                return 1;
            }
        }

        if (entry[dir_len] == '\0') {
            break;
        }
        entry += dir_len + 1;
    }

    return 0;
}

/*
 * Locate the wsssht helper binary.
 *
 * Returns a newly allocated string the caller must free():
 *   - "wsssht" when it is found through PATH, or
 *   - the full path of a "wsssht" executable sitting next to this program
 *     (resolved through /proc/self/exe).
 * Returns NULL when neither location has it, or if allocation fails.
 */
char *find_wsssht_path() {
    // Check if wsssht is in PATH
    if (wsssht_on_search_path()) {
        return strdup("wsssht");
    }

    // If not in PATH, check in the same directory as wsscp
    char exe_path[PATH_MAX];
    ssize_t len = readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1);
    if (len == -1) {
        return NULL;
    }
    exe_path[len] = '\0';  // readlink() does not NUL-terminate

    char *dir_end = strrchr(exe_path, '/');
    if (!dir_end) {
        return NULL;
    }
    *dir_end = '\0';  // keep only the directory part

    char sibling[PATH_MAX];
    int written = snprintf(sibling, sizeof(sibling), "%s/wsssht", exe_path);
    if (written < 0 || (size_t)written >= sizeof(sibling)) {
        return NULL;  // path would not fit
    }

    if (access(sibling, X_OK) != 0) {
        return NULL;
    }
    return strdup(sibling);
}

// Capacity of the buffer returned by build_proxy_command(), including the NUL
enum { WSSCP_PROXY_COMMAND_MAX = 2048 };

/*
 * Append `prefix` followed by `value` to the NUL-terminated string in `cmd`,
 * which has room for `cap` bytes in total.
 *
 * Returns 1 on success. Returns 0 if the result would not fit, in which case
 * `cmd` is left exactly as it was.
 */
static int proxy_cmd_append(char *cmd, size_t cap, const char *prefix, const char *value) {
    size_t used = strlen(cmd);
    size_t room = cap - used;

    int written = snprintf(cmd + used, room, "%s%s", prefix, value);
    if (written < 0 || (size_t)written >= room) {
        cmd[used] = '\0';  // undo the partial write
        return 0;
    }
    return 1;
}

/*
 * Build the wsssht command line used as SSH ProxyCommand:
 *
 *   <wsssht> --pipe [--debug] [--tunnel T] [--tunnel-control T] [--enc E]
 *            [--wssshd-port N] ssh://client_id[@wssshd_host]
 *
 * The port is taken from --wssshd-port when it was given explicitly;
 * otherwise a "-P <port>" found among the remaining SCP arguments is used,
 * and if neither exists no port option is emitted.
 *
 * Returns a heap-allocated string shorter than WSSCP_PROXY_COMMAND_MAX bytes
 * that the caller must free(), or NULL (after printing an error where
 * applicable) if the client id is missing, wsssht cannot be found, memory
 * cannot be allocated, or the command would not fit.
 */
char *build_proxy_command(wsscp_wrapper_config_t *config) {
    if (!config || !config->client_id) {
        fprintf(stderr, "Error: Client ID is required to build the proxy command\n");
        return NULL;
    }

    char *wsssht_path = find_wsssht_path();
    if (!wsssht_path) {
        fprintf(stderr, "Error: wsssht not found in PATH or in the same directory as wsscp\n");
        fprintf(stderr, "Please install wsssht to use wsscp\n");
        return NULL;
    }

    char *cmd = malloc(WSSCP_PROXY_COMMAND_MAX);
    if (!cmd) {
        free(wsssht_path);
        return NULL;
    }
    cmd[0] = '\0';

    // Start with wsssht path --pipe
    int ok = proxy_cmd_append(cmd, WSSCP_PROXY_COMMAND_MAX, wsssht_path, " --pipe");
    free(wsssht_path);

    // Add debug flag if enabled
    if (ok && config->debug) {
        ok = proxy_cmd_append(cmd, WSSCP_PROXY_COMMAND_MAX, " --debug", "");
    }

    // Add tunnel options if specified
    if (ok && config->tunnel) {
        ok = proxy_cmd_append(cmd, WSSCP_PROXY_COMMAND_MAX, " --tunnel ", config->tunnel);
    }
    if (ok && config->tunnel_control) {
        ok = proxy_cmd_append(cmd, WSSCP_PROXY_COMMAND_MAX, " --tunnel-control ", config->tunnel_control);
    }

    // Add enc if specified
    if (ok && config->enc) {
        ok = proxy_cmd_append(cmd, WSSCP_PROXY_COMMAND_MAX, " --enc ", config->enc);
    }

    // An explicit --wssshd-port wins; otherwise fall back to -P in SCP arguments
    if (ok) {
        int port = config->wssshd_port_explicit ? config->wssshd_port
                                                : parse_scp_port_from_args(config);
        if (port > 0) {
            char port_text[16];
            snprintf(port_text, sizeof(port_text), "%d", port);
            ok = proxy_cmd_append(cmd, WSSCP_PROXY_COMMAND_MAX, " --wssshd-port ", port_text);
        }
    }

    // Add the SSH URL part
    if (ok) {
        ok = proxy_cmd_append(cmd, WSSCP_PROXY_COMMAND_MAX, " ssh://", config->client_id);
    }
    if (ok && config->wssshd_host) {
        ok = proxy_cmd_append(cmd, WSSCP_PROXY_COMMAND_MAX, "@", config->wssshd_host);
    }

    if (!ok) {
        fprintf(stderr, "Error: Proxy command is too long\n");
        free(cmd);
        return NULL;
    }

    return cmd;
}

/*
 * Run `command` through "/bin/sh -c" in a child process and wait for it.
 *
 * While the child runs its PID is published in the global scp_pid so the
 * SIGINT handler can forward a termination request; scp_pid is reset to -1
 * before returning.
 *
 * Returns the child's exit status when it exited normally, and 1 when fork()
 * or waitpid() failed or the child was terminated by a signal. When `debug`
 * is non-zero the command (and a terminating signal, if any) is printed to
 * stdout.
 */
int execute_scp_command(char *command, int debug) {
    if (debug) {
        printf("[DEBUG] Executing: %s\n", command);
    }

    // Push out buffered output now so it appears before the child's output
    // and is not duplicated through the forked copy of the stdio buffer.
    fflush(stdout);

    // Fork and execute the command so we can track the process
    pid_t pid = fork();
    if (pid == -1) {
        perror("fork");
        return 1;
    }

    if (pid == 0) {
        // Child process: execute the command
        execl("/bin/sh", "sh", "-c", command, (char *)NULL);
        // If we get here, exec failed
        perror("execl");
        // _exit: do not run the parent's atexit handlers or flush its buffers
        _exit(1);
    }

    // Parent process: store the PID and wait for completion
    scp_pid = pid;

    int status = 0;
    pid_t waited;
    do {
        // A signal (e.g. SIGINT being forwarded) may interrupt the wait; the
        // child still has to be reaped, so retry instead of giving up.
        waited = waitpid(pid, &status, 0);
    } while (waited == -1 && errno == EINTR);

    scp_pid = -1; // Reset PID

    if (waited == -1) {
        perror("waitpid");
        return 1;
    }

    if (WIFEXITED(status)) {
        return WEXITSTATUS(status);
    }

    if (WIFSIGNALED(status) && debug) {
        printf("[DEBUG] SCP process terminated by signal %d\n", WTERMSIG(status));
    }

    return 1;
}

int main(int argc, char *argv[]) {
    // Initialize configuration
    wsscp_wrapper_config_t config = {
        .client_id = NULL,
        .wssshd_host = NULL,
        .wssshd_port = 9898,
        .wssshd_port_explicit = 0,
        .debug = 0,
        .tunnel = NULL,
        .tunnel_control = NULL,
        .enc = NULL,
        .user = NULL,
        .target_host = NULL,
        .ssh_string = NULL,
        .source_file = NULL,
        .destination = NULL,
        .remaining_argc = 0,
        .remaining_argv = NULL
    };

    // Easter egg: --support option (only when it's the only argument)
    if (argc == 2 && strcmp(argv[1], "--support") == 0) {
        printf("Support the dolls!\n");
        printf("BTC: bc1q3zlkpu95amtcltsk85y0eacyzzk29v68tgc5hx\n");
        printf("ETH: 0xdA6dAb526515b5cb556d20269207D43fcc760E51\n");
        return 0;
    }

    // Handle --help before parsing
    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) {
            print_usage(argv[0]);
            return 0;
        }
    }

    // Set up signal handler for SIGINT
    signal(SIGINT, sigint_handler);

    // Parse wsscp arguments
    if (!parse_wsscp_args(argc, argv, &config)) {
        return 1;
    }

    // Need at least source and destination
    if (config.remaining_argc < 2) {
        fprintf(stderr, "Error: Source and destination required\n");
        print_usage(argv[0]);
        return 1;
    }

    // Parse source file
    config.source_file = strdup(config.remaining_argv[0]);

    // Parse destination
    if (!parse_target_string(config.remaining_argv[1], &config)) {
        fprintf(stderr, "Error: Invalid destination format\n");
        print_usage(argv[0]);
        free(config.source_file);
        return 1;
    }

    // Check if client_id and destination are provided
    if (!config.client_id || !config.destination) {
        fprintf(stderr, "Error: Client ID and destination path required\n");
        print_usage(argv[0]);
        free(config.source_file);
        return 1;
    }

    if (config.debug) {
        printf("[DEBUG] Source file: %s\n", config.source_file);
        printf("[DEBUG] Client ID: %s\n", config.client_id);
        printf("[DEBUG] Destination: %s\n", config.destination);
        if (config.wssshd_host) {
            printf("[DEBUG] WSSSHD Host: %s\n", config.wssshd_host);
        }
        printf("[DEBUG] WSSSHD Port: %d\n", config.wssshd_port);
    }

    // Build ProxyCommand
    char *proxy_command = build_proxy_command(&config);
    if (!proxy_command) {
        free(config.source_file);
        return 1;
    }

    // Build SCP command
    char scp_command[4096];
    sprintf(scp_command, "scp");

    // Add debug flag to SCP if enabled
    if (config.debug) {
        strcat(scp_command, " -v");
    }

    // Add ProxyCommand
    char proxy_option[3072];
    sprintf(proxy_option, " -o ProxyCommand=\"%s\"", proxy_command);
    strcat(scp_command, proxy_option);

    // Add source file
    char source_part[1024];
    sprintf(source_part, " %s", config.source_file);
    strcat(scp_command, source_part);

    // Add destination
    char dest_part[1024];
    if (config.user) {
        sprintf(dest_part, " %s@%s:%s", config.user, config.client_id, config.destination);
    } else {
        sprintf(dest_part, " %s:%s", config.client_id, config.destination);
    }
    strcat(scp_command, dest_part);

    // Add any additional SCP arguments
    for (int i = 2; i < config.remaining_argc; i++) {
        char additional_arg[256];
        sprintf(additional_arg, " %s", config.remaining_argv[i]);
        strcat(scp_command, additional_arg);
    }

    // Execute SCP command
    int result;
    if (config.debug) {
        printf("[DEBUG] Executing: %s\n", scp_command);
        result = execute_scp_command(scp_command, config.debug);
    } else {
        result = execute_scp_command(scp_command, config.debug);
    }

    // Check if we were interrupted
    if (sigint_received) {
        if (config.debug) {
            printf("[DEBUG] SCP command was interrupted by SIGINT\n");
        }
        result = 130; // Standard exit code for SIGINT (128 + 2)
    }

    // Cleanup
    free(proxy_command);
    free(config.source_file);
    free(config.client_id);
    free(config.wssshd_host);
    free(config.tunnel);
    free(config.tunnel_control);
    free(config.user);
    free(config.destination);

    return result;
}