/*
 * WSSSH (wsssh) - SSH Wrapper with WebSocket ProxyCommand
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "wsssh.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/wait.h>
#include <errno.h>
#include <limits.h>
#include <libgen.h>

/*
 * Write the wsssh usage/help text to stderr.
 *
 * program_name is substituted into the synopsis line and into each example
 * invocation (normally argv[0]).
 */
void print_wsssh_usage(const char *program_name) {
    /* Synopsis and banner */
    fprintf(stderr,
            "Usage: %s [options] <user>[@clientid[.wssshd-host[:sshstring]]] [ssh_options...]\n"
            "WSSSH Wrapper - SSH through WebSocket tunnels\n\n"
            "Protect the dolls!\n\n",
            program_name);

    /* Option table and target grammar */
    fputs("Options:\n"
          "  --help                    Show this help\n"
          "  --clientid ID             Client ID of the registered wssshc endpoint\n"
          "  --wssshd-host HOST        wssshd relay host\n"
          "  -p, --wssshd-port PORT    wssshd relay websocket port (default: 9898)\n"
          "  --debug                   Enable debug output\n"
          "  --tunnel TRANSPORT        Select data channel transport (comma-separated or 'any')\n"
          "  --tunnel-control TRANSPORT Select control channel transport (comma-separated or 'any')\n"
          "  --enc ENCODING            Data encoding: hex, base64, or bin\n"
          "\nTarget format:\n"
          "  user[@clientid[.wssshd-host[:sshstring]]]\n"
          "\nExamples:\n",
          stderr);

    /* Example invocations, each prefixed with the program name */
    fprintf(stderr,
            "  %s user@myclient\n"
            "  %s user@myclient.server.com\n"
            "  %s --debug user@myclient.server.com:22\n"
            "  %s --tunnel websocket user@myclient.server.com -p 2222\n",
            program_name, program_name, program_name, program_name);

    fputs("\nDonations:\n"
          "  BTC: bc1q3zlkpu95amtcltsk85y0eacyzzk29v68tgc5hx\n"
          "  ETH: 0xdA6dAb526515b5cb556d20269207D43fcc760E51\n",
          stderr);
}

/* Replace *slot with a fresh copy of value, releasing any earlier copy so a
 * repeated option (e.g. --tunnel given twice) does not leak. */
static void parse_args_store_string(char **slot, const char *value) {
    free(*slot);
    *slot = strdup(value);
}

/* Parse a TCP port number. Returns 1 and stores it in *port when text is a
 * complete decimal number in 1..65535, 0 otherwise. */
static int parse_args_port(const char *text, int *port) {
    char *end = NULL;
    errno = 0;
    long value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE || value < 1 || value > 65535) {
        return 0;
    }
    *port = (int)value;
    return 1;
}

/*
 * Parse wsssh's own leading options into *config (which is zeroed first; the
 * wssshd port defaults to 9898).
 *
 * Options are recognised only up to the first argument that is not a wsssh
 * option. That argument (the target, or an unknown option meant for ssh) and
 * everything after it are exposed via config->remaining_argc and
 * config->remaining_argv, which point into argv (not copied).
 *
 * Returns 0 when --help/-h was handled (usage already printed), 1 otherwise.
 * Exits with status 1 if --wssshd-port/-p is given an invalid port.
 */
int parse_wsssh_args(int argc, char *argv[], wsssh_wrapper_config_t *config) {
    memset(config, 0, sizeof(wsssh_wrapper_config_t));
    config->wssshd_port = 9898;

    /* Default to "nothing left over": when every argument is a wsssh option
     * there is no target, instead of the first option masquerading as one. */
    int remaining_start = argc;

    for (int i = 1; i < argc; i++) {
        const char *arg = argv[i];
        int has_value = i + 1 < argc;

        if (strcmp(arg, "--help") == 0 || strcmp(arg, "-h") == 0) {
            print_wsssh_usage(argv[0]);
            return 0;
        } else if (strcmp(arg, "--clientid") == 0 && has_value) {
            parse_args_store_string(&config->client_id, argv[++i]);
        } else if (strcmp(arg, "--wssshd-host") == 0 && has_value) {
            parse_args_store_string(&config->wssshd_host, argv[++i]);
        } else if ((strcmp(arg, "--wssshd-port") == 0 || strcmp(arg, "-p") == 0) && has_value) {
            int port = 0;
            if (!parse_args_port(argv[i + 1], &port)) {
                fprintf(stderr, "Error: Invalid wssshd port '%s' (expected 1-65535)\n", argv[i + 1]);
                exit(1);
            }
            config->wssshd_port = port;
            config->wssshd_port_explicit = 1;
            i++;
        } else if (strcmp(arg, "--debug") == 0) {
            config->debug = 1;
        } else if (strcmp(arg, "--tunnel") == 0 && has_value) {
            parse_args_store_string(&config->tunnel, argv[++i]);
        } else if (strcmp(arg, "--tunnel-control") == 0 && has_value) {
            parse_args_store_string(&config->tunnel_control, argv[++i]);
        } else if (strcmp(arg, "--enc") == 0 && has_value) {
            parse_args_store_string(&config->enc, argv[++i]);
        } else {
            /* First argument that is not a wsssh option: either the target or
             * an unknown option destined for ssh. Pass everything from here on. */
            remaining_start = i;
            break;
        }
    }

    config->remaining_argc = argc - remaining_start;
    if (config->remaining_argc > 0) {
        config->remaining_argv = &argv[remaining_start];
    }

    return 1;
}

/*
 * Split a target of the form user[@clientid[.wssshd-host[:sshstring]]] into
 * config->user, config->client_id, config->wssshd_host and config->ssh_string.
 *
 * The ':' separator is located before the '.', so dots inside the sshstring
 * part never split the host. Fields whose part is absent from the target are
 * left untouched (values set earlier, e.g. via --clientid/--wssshd-host,
 * survive); fields that are present replace and free any earlier value.
 *
 * Returns 1 on success. Returns 0 when target or config is NULL, the user or
 * client-id part is empty, or memory runs out; config is unchanged then.
 */
int parse_target_string(const char *target, wsssh_wrapper_config_t *config) {
    if (!target || !config) return 0;

    char *work = strdup(target);
    if (!work) return 0;

    char *user_part = work;
    char *client_part = NULL;
    char *host_part = NULL;
    char *ssh_part = NULL;

    char *at = strchr(work, '@');
    if (at) {
        *at = '\0';
        client_part = at + 1;

        char *colon = strchr(client_part, ':');
        if (colon) {
            *colon = '\0';
            ssh_part = colon + 1;
        }

        char *dot = strchr(client_part, '.');
        if (dot) {
            *dot = '\0';
            host_part = dot + 1;
        }
    }

    /* "@client", "user@" or "user@.host" can only yield a broken ssh command. */
    if (user_part[0] == '\0' || (client_part && client_part[0] == '\0')) {
        free(work);
        return 0;
    }

    /* Copy everything first so an allocation failure leaves config intact. */
    char *new_user = strdup(user_part);
    char *new_client = client_part ? strdup(client_part) : NULL;
    char *new_host = host_part ? strdup(host_part) : NULL;
    char *new_ssh = ssh_part ? strdup(ssh_part) : NULL;
    int failed = !new_user || (client_part && !new_client) ||
                 (host_part && !new_host) || (ssh_part && !new_ssh);
    free(work);

    if (failed) {
        free(new_user);
        free(new_client);
        free(new_host);
        free(new_ssh);
        return 0;
    }

    free(config->user);
    config->user = new_user;
    if (new_client) {
        free(config->client_id);
        config->client_id = new_client;
    }
    if (new_host) {
        free(config->wssshd_host);
        config->wssshd_host = new_host;
    }
    if (new_ssh) {
        free(config->ssh_string);
        config->ssh_string = new_ssh;
    }
    return 1;
}

/*
 * Look for an ssh-style port option among the forwarded arguments
 * (config->remaining_argv, whose element 0 is the target).
 *
 * Both "-p PORT" and the attached form "-pPORT" (also accepted by ssh) are
 * recognised. The first occurrence carrying a valid port (1..65535) wins;
 * malformed values are skipped.
 *
 * Returns the port, or 0 if no valid -p option is present.
 */
int parse_ssh_port_from_args(wsssh_wrapper_config_t *config) {
    if (!config || !config->remaining_argv || config->remaining_argc < 2) {
        return 0;
    }

    for (int i = 0; i < config->remaining_argc; i++) {
        const char *arg = config->remaining_argv[i];
        const char *value;

        if (strcmp(arg, "-p") == 0) {
            if (i + 1 >= config->remaining_argc) {
                break; /* trailing "-p" with no value */
            }
            value = config->remaining_argv[i + 1];
        } else if (strncmp(arg, "-p", 2) == 0) {
            value = arg + 2;
        } else {
            continue;
        }

        /* Range-check as long before narrowing, so huge values cannot wrap
         * into the valid port range. */
        char *end = NULL;
        errno = 0;
        long port = strtol(value, &end, 10);
        if (end != value && *end == '\0' && errno != ERANGE && port > 0 && port <= 65535) {
            return (int)port;
        }
    }

    return 0; // No valid -p option found
}

/*
 * Return non-zero when an executable named "wsssht" exists in one of the
 * directories listed in $PATH. An empty PATH entry denotes the current
 * directory, matching the shell's own command lookup.
 */
static int wsssht_is_on_path(void) {
    static const char exe_name[] = "wsssht";
    const char *path_env = getenv("PATH");
    if (!path_env || path_env[0] == '\0') {
        return 0;
    }

    int found = 0;
    const char *entry = path_env;
    while (!found) {
        const char *sep = strchr(entry, ':');
        size_t entry_len = sep ? (size_t)(sep - entry) : strlen(entry);
        const char *dir = entry_len ? entry : ".";
        size_t dir_len = entry_len ? entry_len : 1;

        /* dir + '/' + "wsssht" + NUL (sizeof exe_name counts the NUL). */
        char *candidate = malloc(dir_len + 1 + sizeof exe_name);
        if (!candidate) {
            break;
        }
        memcpy(candidate, dir, dir_len);
        candidate[dir_len] = '/';
        memcpy(candidate + dir_len + 1, exe_name, sizeof exe_name);
        found = access(candidate, X_OK) == 0;
        free(candidate);

        if (!sep) {
            break;
        }
        entry = sep + 1;
    }
    return found;
}

/*
 * Locate the wsssht helper.
 *
 * Returns a malloc'd string the caller must free:
 *  - "wsssht" (bare name, resolved later by the shell) when it is on $PATH;
 *  - otherwise "<dir of this executable>/wsssht" if that file is executable
 *    (the executable's location is read from /proc/self/exe, i.e. Linux);
 *  - NULL if neither is found.
 *
 * $PATH is scanned directly rather than via `which`, which spawns a shell and
 * is not installed on every system.
 */
char *find_wsssht_path(void) {
    if (wsssht_is_on_path()) {
        return strdup("wsssht");
    }

    char exe_path[PATH_MAX];
    ssize_t len = readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1);
    if (len < 0) {
        return NULL;
    }
    exe_path[len] = '\0';

    char *dir_end = strrchr(exe_path, '/');
    if (!dir_end) {
        return NULL;
    }
    *dir_end = '\0';

    char candidate[PATH_MAX];
    int written = snprintf(candidate, sizeof(candidate), "%s/wsssht", exe_path);
    if (written < 0 || (size_t)written >= sizeof(candidate)) {
        return NULL; /* would not fit in PATH_MAX */
    }
    if (access(candidate, X_OK) != 0) {
        return NULL; // wsssht not found
    }
    return strdup(candidate);
}

/*
 * Append text to the heap buffer *buf (length *len, capacity *cap), growing
 * it as needed and keeping it NUL-terminated. When quote is non-zero, text
 * is emitted as a single POSIX-shell word: wrapped in single quotes (an
 * embedded ' becomes '\'') unless it is non-empty and consists only of
 * characters the shell never treats specially.
 *
 * Returns 0 on success, -1 on allocation failure (*buf stays valid).
 */
static int proxy_cmd_append(char **buf, size_t *len, size_t *cap,
                            const char *text, int quote) {
    static const char shell_safe[] =
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789@%+=:,./-_";
    size_t text_len = strlen(text);
    int wrap = quote && (text_len == 0 || text[strspn(text, shell_safe)] != '\0');
    /* Worst case when quoting: every byte is a ' that expands to 4 bytes. */
    size_t needed = wrap ? text_len * 4 + 2 : text_len;

    if (*len + needed + 1 > *cap) {
        size_t new_cap = *cap ? *cap : 256;
        while (*len + needed + 1 > new_cap) {
            new_cap *= 2;
        }
        char *grown = realloc(*buf, new_cap);
        if (!grown) {
            return -1;
        }
        *buf = grown;
        *cap = new_cap;
    }

    char *out = *buf + *len;
    if (!wrap) {
        memcpy(out, text, text_len);
        out += text_len;
    } else {
        *out++ = '\'';
        for (const char *p = text; *p; p++) {
            if (*p == '\'') {
                memcpy(out, "'\\''", 4);
                out += 4;
            } else {
                *out++ = *p;
            }
        }
        *out++ = '\'';
    }
    *out = '\0';
    *len = (size_t)(out - *buf);
    return 0;
}

/*
 * Build the wsssht command line used as ssh's ProxyCommand:
 *
 *   <wsssht> --pipe [--wssshd-port N] [--debug] [--tunnel T]
 *            [--tunnel-control T] [--enc E] [--wssshd-port N]
 *            [ssh://<client_id>[@<wssshd_host>]]
 *
 * The wsssht path and every user-supplied value are shell-quoted when needed,
 * so spaces or metacharacters cannot split or inject words. The buffer grows
 * as needed instead of being a fixed size.
 *
 * Returns a malloc'd string the caller must free, or NULL if wsssht cannot be
 * found (an error is printed) or memory runs out.
 */
char *build_proxy_command(wsssh_wrapper_config_t *config) {
    char *wsssht_path = find_wsssht_path();
    if (!wsssht_path) {
        fprintf(stderr, "Error: wsssht not found in PATH or in the same directory as wsssh\n");
        fprintf(stderr, "Please install wsssht to use wsssh\n");
        return NULL;
    }

    char *cmd = NULL;
    size_t len = 0;
    size_t cap = 0;
    char port_str[32];

    /* Start with "<wsssht path> --pipe"; the path may contain spaces when it
     * was found next to the wsssh binary. */
    int rc = proxy_cmd_append(&cmd, &len, &cap, wsssht_path, 1);
    free(wsssht_path);
    if (rc != 0 || proxy_cmd_append(&cmd, &len, &cap, " --pipe", 0) != 0) {
        goto fail;
    }

    /* Without an explicit --wssshd-port, an ssh-style -p among the forwarded
     * arguments selects the wssshd port. */
    if (!config->wssshd_port_explicit) {
        int ssh_port = parse_ssh_port_from_args(config);
        if (ssh_port > 0) {
            snprintf(port_str, sizeof(port_str), " --wssshd-port %d", ssh_port);
            if (proxy_cmd_append(&cmd, &len, &cap, port_str, 0) != 0) {
                goto fail;
            }
        }
    }

    if (config->debug && proxy_cmd_append(&cmd, &len, &cap, " --debug", 0) != 0) {
        goto fail;
    }

    if (config->tunnel &&
        (proxy_cmd_append(&cmd, &len, &cap, " --tunnel ", 0) != 0 ||
         proxy_cmd_append(&cmd, &len, &cap, config->tunnel, 1) != 0)) {
        goto fail;
    }

    if (config->tunnel_control &&
        (proxy_cmd_append(&cmd, &len, &cap, " --tunnel-control ", 0) != 0 ||
         proxy_cmd_append(&cmd, &len, &cap, config->tunnel_control, 1) != 0)) {
        goto fail;
    }

    if (config->enc &&
        (proxy_cmd_append(&cmd, &len, &cap, " --enc ", 0) != 0 ||
         proxy_cmd_append(&cmd, &len, &cap, config->enc, 1) != 0)) {
        goto fail;
    }

    /* 9898 is the default port this wrapper assumes; only forward others. */
    if (config->wssshd_port != 9898) {
        snprintf(port_str, sizeof(port_str), " --wssshd-port %d", config->wssshd_port);
        if (proxy_cmd_append(&cmd, &len, &cap, port_str, 0) != 0) {
            goto fail;
        }
    }

    /* "ssh://" and "@" are shell-safe, so the adjacent quoted pieces still
     * form a single shell word. */
    if (config->client_id) {
        if (proxy_cmd_append(&cmd, &len, &cap, " ssh://", 0) != 0 ||
            proxy_cmd_append(&cmd, &len, &cap, config->client_id, 1) != 0) {
            goto fail;
        }
        if (config->wssshd_host &&
            (proxy_cmd_append(&cmd, &len, &cap, "@", 0) != 0 ||
             proxy_cmd_append(&cmd, &len, &cap, config->wssshd_host, 1) != 0)) {
            goto fail;
        }
    }

    return cmd;

fail:
    free(cmd);
    return NULL;
}

/*
 * Append text to the heap buffer *buf (length *len, capacity *cap), growing
 * it as needed and keeping it NUL-terminated. When quote is non-zero, text
 * is emitted as a single POSIX-shell word: wrapped in single quotes (an
 * embedded ' becomes '\'') unless it is non-empty and consists only of
 * characters the shell never treats specially.
 *
 * Returns 0 on success, -1 on allocation failure (*buf stays valid).
 */
static int ssh_cmd_append(char **buf, size_t *len, size_t *cap,
                          const char *text, int quote) {
    static const char shell_safe[] =
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789@%+=:,./-_";
    size_t text_len = strlen(text);
    int wrap = quote && (text_len == 0 || text[strspn(text, shell_safe)] != '\0');
    /* Worst case when quoting: every byte is a ' that expands to 4 bytes. */
    size_t needed = wrap ? text_len * 4 + 2 : text_len;

    if (*len + needed + 1 > *cap) {
        size_t new_cap = *cap ? *cap : 512;
        while (*len + needed + 1 > new_cap) {
            new_cap *= 2;
        }
        char *grown = realloc(*buf, new_cap);
        if (!grown) {
            return -1;
        }
        *buf = grown;
        *cap = new_cap;
    }

    char *out = *buf + *len;
    if (!wrap) {
        memcpy(out, text, text_len);
        out += text_len;
    } else {
        *out++ = '\'';
        for (const char *p = text; *p; p++) {
            if (*p == '\'') {
                memcpy(out, "'\\''", 4);
                out += 4;
            } else {
                *out++ = *p;
            }
        }
        *out++ = '\'';
    }
    *out = '\0';
    *len = (size_t)(out - *buf);
    return 0;
}

/*
 * Build the shell command line that runs ssh through the wsssht proxy:
 *
 *   ssh -o ProxyCommand=<proxy_command> <user>@<client_id>[:<ssh_string>] [args...]
 *
 * Every value taken from the configuration or the command line is
 * shell-quoted, so the string later passed to system() hands each one to ssh
 * verbatim: quotes, spaces, '$', ';', '>' and globs are not interpreted by the
 * local shell. config->remaining_argv[0] is the target and is skipped; the
 * remaining arguments are forwarded to ssh in order. The buffer grows as
 * needed, so long argument lists cannot overflow it.
 *
 * Returns a malloc'd string the caller must free, or NULL when user,
 * client_id or proxy_command is missing, or memory runs out.
 */
char *build_ssh_command(wsssh_wrapper_config_t *config, const char *proxy_command) {
    if (!config->user || !config->client_id || !proxy_command) {
        return NULL;
    }

    char *cmd = NULL;
    size_t len = 0;
    size_t cap = 0;

    /* ssh -o ProxyCommand=<quoted proxy> <user>@<client_id>; the unquoted
     * "=" / "@" join the quoted pieces into single shell words. */
    if (ssh_cmd_append(&cmd, &len, &cap, "ssh -o ProxyCommand=", 0) != 0 ||
        ssh_cmd_append(&cmd, &len, &cap, proxy_command, 1) != 0 ||
        ssh_cmd_append(&cmd, &len, &cap, " ", 0) != 0 ||
        ssh_cmd_append(&cmd, &len, &cap, config->user, 1) != 0 ||
        ssh_cmd_append(&cmd, &len, &cap, "@", 0) != 0 ||
        ssh_cmd_append(&cmd, &len, &cap, config->client_id, 1) != 0) {
        goto fail;
    }

    // Add ssh_string if specified
    if (config->ssh_string &&
        (ssh_cmd_append(&cmd, &len, &cap, ":", 0) != 0 ||
         ssh_cmd_append(&cmd, &len, &cap, config->ssh_string, 1) != 0)) {
        goto fail;
    }

    /* Forward the remaining arguments (index 0 is the target itself). */
    for (int i = 1; i < config->remaining_argc; i++) {
        if (ssh_cmd_append(&cmd, &len, &cap, " ", 0) != 0 ||
            ssh_cmd_append(&cmd, &len, &cap, config->remaining_argv[i], 1) != 0) {
            goto fail;
        }
    }

    return cmd;

fail:
    free(cmd);
    return NULL;
}

/*
 * Echo ssh_command to stderr (prefixed "[DEBUG] SSH command:" when debug is
 * set, "Executing:" otherwise), then run it with system().
 *
 * Messages go to stderr so they do not mix into the remote command's output
 * when stdout is redirected or piped.
 *
 * Returns the command's exit status, 128 + signal number if it was killed by
 * a signal (shell convention), or 1 if ssh_command is NULL, system() fails,
 * or the status is otherwise unrecognised.
 */
int execute_ssh_command(const char *ssh_command, int debug) {
    if (!ssh_command) return 1;

    if (debug) {
        fprintf(stderr, "[DEBUG] SSH command: %s\n", ssh_command);
    } else {
        fprintf(stderr, "Executing: %s\n", ssh_command);
    }

    /* Push out anything still buffered on stdout so it is not emitted after
     * the child's output. */
    fflush(stdout);

    int status = system(ssh_command);
    if (status == -1) {
        perror("system");
        return 1;
    }

    if (WIFEXITED(status)) {
        return WEXITSTATUS(status);
    }
    /* WEXITSTATUS is meaningless for a signal-terminated child (it would
     * report 0, i.e. success). */
    if (WIFSIGNALED(status)) {
        return 128 + WTERMSIG(status);
    }
    return 1;
}

/*
 * Entry point: parse wsssh options and the target, build the wsssht
 * ProxyCommand and the ssh command line, then run ssh.
 *
 * Returns ssh's exit status (as reported by execute_ssh_command), 0 after
 * --help, or 1 on usage/setup errors. All allocated configuration strings are
 * released on every path.
 */
int main(int argc, char *argv[]) {
    wsssh_wrapper_config_t config;
    char *proxy_command = NULL;
    char *ssh_command = NULL;
    int result = 1;

    /* parse_wsssh_args zeroes config first, so cleanup is safe on every path. */
    if (!parse_wsssh_args(argc, argv, &config)) {
        result = 0; // Help was printed
        goto cleanup;
    }

    // Need at least one argument (the target)
    if (config.remaining_argc <= 0) {
        fprintf(stderr, "Error: No target specified\n");
        print_wsssh_usage(argv[0]);
        goto cleanup;
    }

    if (!parse_target_string(config.remaining_argv[0], &config)) {
        fprintf(stderr, "Error: Invalid target format\n");
        goto cleanup;
    }

    /* Without a client id the ssh command cannot be built; say why. */
    if (!config.client_id) {
        fprintf(stderr, "Error: No client ID specified (use <user>@<clientid> or --clientid)\n");
        goto cleanup;
    }

    proxy_command = build_proxy_command(&config);
    if (!proxy_command) {
        fprintf(stderr, "Error: Failed to build proxy command\n");
        goto cleanup;
    }

    ssh_command = build_ssh_command(&config, proxy_command);
    if (!ssh_command) {
        fprintf(stderr, "Error: Failed to build SSH command\n");
        goto cleanup;
    }

    /* execute_ssh_command echoes the command itself (with a [DEBUG] prefix
     * when debug is on), so no separate echo is needed here. */
    result = execute_ssh_command(ssh_command, config.debug);

cleanup:
    free(proxy_command);
    free(ssh_command);
    free(config.client_id);
    free(config.wssshd_host);
    free(config.tunnel);
    free(config.tunnel_control);
    free(config.enc);
    free(config.user);
    free(config.ssh_string);

    return result;
}