/*
 * WSSSH Library - Shared utilities implementation
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "wssshlib.h"
#include <stdlib.h>
#include <time.h>
#include <unistd.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <linux/rtnetlink.h>
#include <linux/if.h>
#include <linux/route.h>
#include <sys/ioctl.h>
#include <sys/time.h>
#include <string.h>
#include <errno.h>
#include <limits.h>

// Global signal flag. Presumably raised by a SIGINT handler installed
// elsewhere — volatile sig_atomic_t is the type a handler may safely write.
volatile sig_atomic_t sigint_received = 0;

// SSL mutex for thread-safe SSL operations (lock sites live outside this file).
pthread_mutex_t ssl_mutex;

/*
 * Initialize ssl_mutex with default attributes. The GCC/Clang constructor
 * attribute makes this run automatically at program/library load time, so
 * callers never need an explicit init call. The return code is ignored.
 */
__attribute__((constructor)) void init_ssl_mutex(void) {
    const pthread_mutexattr_t *default_attrs = NULL;
    (void)pthread_mutex_init(&ssl_mutex, default_attrs);
}

/*
 * Destroy ssl_mutex; the destructor attribute runs this automatically at
 * normal process exit / library unload. The return code is ignored.
 */
__attribute__((destructor)) void destroy_ssl_mutex(void) {
    pthread_mutex_t *const target = &ssl_mutex;
    (void)pthread_mutex_destroy(target);
}

/*
 * Config-file lookup shared by read_config_value() and
 * read_config_value_from_file().
 *
 * Format: one `key = value` entry per line. Blank lines, lines whose first
 * non-blank character is '#' or ';', and `[section]` headers are skipped.
 * Sections do not scope keys: the first matching key anywhere in the file
 * wins. Keys are compared case-sensitively. Spaces, tabs and CR/LF around
 * both key and value are stripped, so CRLF-edited files work and a
 * whitespace-only value yields "".
 */

/* True for characters treated as padding around lines, keys and values. */
static int config_is_pad(char c) {
    return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}

/* Strip leading and trailing padding from s in place; returns the new start. */
static char *config_trim_inplace(char *s) {
    while (config_is_pad(*s)) {
        s++;
    }
    size_t len = strlen(s);
    // Index-based so an empty string never forms a pointer before its start.
    while (len > 0 && config_is_pad(s[len - 1])) {
        s[--len] = '\0';
    }
    return s;
}

/*
 * Scan an open stream for key and return a malloc'd copy of its trimmed
 * value, or NULL if the key is absent (or strdup fails). Lines longer than
 * the read buffer are skipped whole, so the tail of such a line is never
 * misparsed as an entry of its own.
 */
static char *config_find_in_stream(FILE *f, const char *key) {
    char line[1024];

    while (fgets(line, sizeof(line), f)) {
        size_t len = strlen(line);
        if (len == sizeof(line) - 1 && line[len - 1] != '\n') {
            // Buffer filled without a newline: peek to see whether the line goes on.
            int c = fgetc(f);
            if (c != '\n' && c != EOF) {
                while (c != '\n' && c != EOF) {
                    c = fgetc(f);
                }
                continue;
            }
            // Otherwise the line was exactly buffer-sized and is complete.
        }

        char *entry = config_trim_inplace(line);
        if (*entry == '\0' || *entry == '#' || *entry == ';' || *entry == '[') {
            continue;
        }

        char *equals = strchr(entry, '=');
        if (!equals) {
            continue;
        }
        *equals = '\0';

        if (strcmp(config_trim_inplace(entry), key) == 0) {
            return strdup(config_trim_inplace(equals + 1));
        }
    }
    return NULL;
}

/* Look key up in the file at path; NULL if it cannot be opened or lacks key. */
static char *config_find_in_file(const char *path, const char *key) {
    FILE *f = fopen(path, "r");
    if (!f) {
        return NULL;
    }
    char *value = config_find_in_stream(f, key);
    fclose(f);
    return value;
}

/* True when snprintf's return value n means the full string fit in size bytes. */
static int config_path_fits(int n, size_t size) {
    return n >= 0 && (size_t)n < size;
}

/*
 * Look up key in $HOME/.config/wsssh/wsssh.conf.
 *
 * Returns a malloc'd copy of the value (caller frees), or NULL when key is
 * NULL, HOME is unset, the path does not fit in PATH_MAX, the file cannot
 * be opened, or the key is absent.
 */
char *read_config_value(const char *key) {
    if (!key) {
        return NULL;
    }
    const char *home = getenv("HOME");
    if (!home) {
        return NULL;
    }
    char path[PATH_MAX];
    int n = snprintf(path, sizeof(path), "%s/.config/wsssh/wsssh.conf", home);
    // Never open a silently truncated path.
    if (!config_path_fits(n, sizeof(path))) {
        return NULL;
    }
    return config_find_in_file(path, key);
}

/*
 * Look up key in ~/.config/wsssh/<config_file>.conf, falling back to
 * /etc/<config_file>.conf when the user file is missing, unreadable, or
 * lacks the key.
 *
 * Returns a malloc'd copy of the value (caller frees), or NULL when key or
 * config_file is NULL or neither file yields the key. A path that would be
 * truncated by PATH_MAX is not opened.
 */
char *read_config_value_from_file(const char *key, const char *config_file) {
    if (!key || !config_file) {
        return NULL;
    }

    char path[PATH_MAX];
    char *result = NULL;
    int n;

    const char *home = getenv("HOME");
    if (home) {
        n = snprintf(path, sizeof(path), "%s/.config/wsssh/%s.conf", home, config_file);
        if (config_path_fits(n, sizeof(path))) {
            result = config_find_in_file(path, key);
        }
    }

    if (!result) {
        n = snprintf(path, sizeof(path), "/etc/%s.conf", config_file);
        if (config_path_fits(n, sizeof(path))) {
            result = config_find_in_file(path, key);
        }
    }

    return result;
}

/*
 * Print the transgender pride flag to stdout: five horizontal stripes
 * (light blue, pink, white, pink, light blue), each two rows tall and 40
 * columns wide, drawn with ANSI 256-colour background escapes. A blank
 * line follows the flag.
 */
void print_trans_flag(void) {
    static const char *const stripes[] = {
        "\033[48;5;117m",  // Light blue background
        "\033[48;5;218m",  // Pink background
        "\033[48;5;231m",  // White background
        "\033[48;5;218m",  // Pink background
        "\033[48;5;117m"   // Light blue background
    };
    static const char reset_seq[] = "\033[0m";
    enum { ROWS_PER_STRIPE = 2, FLAG_WIDTH = 40 };
    const size_t stripe_count = sizeof stripes / sizeof stripes[0];

    for (size_t s = 0; s < stripe_count; s++) {
        for (int r = 0; r < ROWS_PER_STRIPE; r++) {
            // "%*s" with "" emits FLAG_WIDTH spaces in the stripe colour.
            printf("%s%*s%s\n", stripes[s], FLAG_WIDTH, "", reset_seq);
        }
    }
    putchar('\n');
}

/*
 * Print the Palestinian flag to stdout using ANSI background colours:
 * 12 rows x 40 columns, with black/white/green horizontal stripes (4 rows
 * each) and a red triangle at the left edge. The triangle is 2 columns wide
 * on the first and last rows and widens by 2 per row to 12 at the middle.
 * A blank line follows the flag.
 */
void print_palestinian_flag(void) {
    const char *reset = "\033[0m";
    const char *red = "\033[41m";
    // Stripe colours top to bottom: black, white, green.
    const char *const stripes[] = { "\033[40m", "\033[47m", "\033[42m" };
    const int height = 12;
    const int width = 40;
    const int rows_per_stripe = height / 3;

    for (int row = 0; row < height; row++) {
        // Distance to the nearer of the top/bottom edges drives the triangle width.
        int mirrored = height - 1 - row;
        int from_edge = row < mirrored ? row : mirrored;
        int red_cols = 2 * (from_edge + 1);

        printf("%s%*s%s%*s%s\n",
               red, red_cols, "",
               stripes[row / rows_per_stripe], width - red_cols, "",
               reset);
    }
    putchar('\n');
}

/*
 * Ask the kernel for a free TCP port on 127.0.0.1.
 *
 * Binds a throwaway IPv4 stream socket to loopback with port 0, reads back
 * the kernel-chosen port via getsockname(), then closes the socket.
 * Returns the port in host byte order, or 0 on any failure (a perror()
 * message is printed for the failing step).
 *
 * Note: the socket is closed before returning, so another process may
 * claim the port before the caller binds it.
 */
int find_available_port() {
    int port = 0;
    struct sockaddr_in bound;
    socklen_t bound_len = sizeof(bound);

    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        perror("Socket creation failed");
        return 0;
    }

    memset(&bound, 0, sizeof(bound));
    bound.sin_family = AF_INET;
    bound.sin_port = 0;  // 0 lets the kernel pick an unused port
    bound.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    if (bind(fd, (struct sockaddr *)&bound, sizeof(bound)) < 0) {
        perror("Bind failed");
    } else if (getsockname(fd, (struct sockaddr *)&bound, &bound_len) < 0) {
        perror("Getsockname failed");
    } else {
        port = ntohs(bound.sin_port);
    }

    close(fd);
    return port;
}

/*
 * Fill request_id with (size - 1) random lowercase hex characters followed
 * by a terminating NUL.
 *
 * Randomness is read from /dev/urandom. If that device cannot be opened or
 * returns fewer bytes than needed, the remaining characters come from
 * rand(), seeded once per process ID from the wall-clock time (seconds and
 * microseconds) and the PID. rand() is deliberately NOT reseeded on every
 * call: reseeding from a microsecond clock makes back-to-back calls produce
 * identical IDs.
 *
 * A NULL buffer or size 0 is a no-op (size 0 would otherwise underflow
 * `size - 1` and overrun the buffer).
 */
void generate_request_id(char *request_id, size_t size) {
    static const char charset[] = "0123456789abcdef";

    if (request_id == NULL || size == 0) {
        return;
    }

    const size_t len = size - 1;
    size_t filled = 0;

    FILE *urandom = fopen("/dev/urandom", "rb");
    if (urandom) {
        setvbuf(urandom, NULL, _IONBF, 0);  // read only the bytes we need
        unsigned char chunk[64];
        while (filled < len) {
            size_t want = len - filled;
            if (want > sizeof(chunk)) {
                want = sizeof(chunk);
            }
            size_t got = fread(chunk, 1, want, urandom);
            for (size_t i = 0; i < got; i++) {
                // Low nibble of each byte -> one uniformly distributed hex digit.
                request_id[filled++] = charset[chunk[i] & 0x0F];
            }
            if (got < want) {
                break;  // short read: finish with the fallback below
            }
        }
        fclose(urandom);
    }

    if (filled < len) {
        // Seed once per PID so a fork()ed child does not replay the parent's sequence.
        static pid_t seeded_pid = 0;
        pid_t pid = getpid();
        if (seeded_pid != pid) {
            struct timeval tv;
            gettimeofday(&tv, NULL);
            srand((unsigned int)(tv.tv_sec ^ tv.tv_usec ^ pid));
            seeded_pid = pid;
        }
        while (filled < len) {
            request_id[filled++] = charset[rand() % (sizeof(charset) - 1)];
        }
    }

    request_id[len] = '\0';
}

/*
 * Return the name of the interface carrying the IPv4 default route, as a
 * heap-allocated string the caller must free(). Returns NULL if
 * /proc/net/route cannot be read, lists no qualifying route, or strdup fails.
 *
 * A qualifying route has destination 0.0.0.0, a non-zero gateway, and both
 * RTF_UP and RTF_GATEWAY set. When several qualify (e.g. wired + wireless),
 * the one with the lowest Metric wins — the route the kernel normally
 * prefers — and ties keep the first one listed.
 */
static char *get_default_gateway_interface(void) {
    FILE *fp = fopen("/proc/net/route", "r");
    if (!fp) {
        return NULL;
    }

    char line[256];
    char best_iface[16] = "";
    unsigned long best_metric = ULONG_MAX;
    int found = 0;

    // The first line is the column header; no header means no routes.
    if (fgets(line, sizeof(line), fp)) {
        while (fgets(line, sizeof(line), fp)) {
            char iface[16];
            unsigned long dest, gateway, flags, metric;

            // Columns: Iface Destination Gateway Flags RefCnt Use Metric ...
            // (addresses and flags in hex; RefCnt/Use/Metric in decimal).
            if (sscanf(line, "%15s %lx %lx %lx %*d %*d %lu",
                       iface, &dest, &gateway, &flags, &metric) != 5) {
                continue;
            }
            if (dest != 0 || gateway == 0) {
                continue;
            }
            if (!(flags & RTF_UP) || !(flags & RTF_GATEWAY)) {
                continue;
            }
            if (!found || metric < best_metric) {
                snprintf(best_iface, sizeof(best_iface), "%s", iface);
                best_metric = metric;
                found = 1;
            }
        }
    }

    fclose(fp);
    return found ? strdup(best_iface) : NULL;
}

/*
 * Expand a transport specification into a heap-allocated list the caller
 * must free().
 *
 * The literal "any" expands to RELAY_TRANSPORTS when for_relay is non-zero,
 * otherwise to SUPPORTED_TRANSPORTS. Any other spec (e.g. a comma-separated
 * list) is copied unchanged. Returns NULL for a NULL spec or if strdup fails.
 */
char *expand_transport_list(const char *transport_spec, int for_relay) {
    const char *expanded;

    if (transport_spec == NULL) {
        return NULL;
    }

    if (strcmp(transport_spec, "any") != 0) {
        expanded = transport_spec;
    } else if (for_relay) {
        expanded = RELAY_TRANSPORTS;
    } else {
        expanded = SUPPORTED_TRANSPORTS;
    }

    return strdup(expanded);
}

/*
 * Return the selection weight of a transport name; a lower number means a
 * higher priority.
 *
 * "websocket" maps to WEBSOCKET_WEIGHT. NULL and unrecognized names map to
 * 999, which is presumably larger than every known transport weight —
 * confirm against the constants in the header.
 */
int get_transport_weight(const char *transport) {
    enum { UNKNOWN_TRANSPORT_WEIGHT = 999 };

    // Guard: strcmp() on NULL is undefined behavior.
    if (transport == NULL) {
        return UNKNOWN_TRANSPORT_WEIGHT;
    }
    if (strcmp(transport, "websocket") == 0) {
        return WEBSOCKET_WEIGHT;
    }
    return UNKNOWN_TRANSPORT_WEIGHT;
}

/*
 * Select the transport with the lowest weight (highest priority) from a
 * comma-separated list.
 *
 * Entries are trimmed of spaces/tabs, and empty entries (e.g. "a,,b" or
 * " , ") are ignored. On equal weights the first entry listed wins.
 * Returns a heap-allocated copy of the chosen name (caller frees), or NULL
 * when transport_list is NULL, has no non-empty entries, or allocation fails.
 *
 * Splitting is done by hand rather than with strtok(), whose hidden static
 * state makes it unsafe to call from multiple threads.
 */
char *select_best_transport(const char *transport_list) {
    if (!transport_list) {
        return NULL;
    }

    char *list_copy = strdup(transport_list);
    if (!list_copy) {
        return NULL;
    }

    const char *best_transport = NULL;
    int best_weight = INT_MAX;

    char *cursor = list_copy;
    while (cursor) {
        char *comma = strchr(cursor, ',');
        if (comma) {
            *comma = '\0';
        }
        char *token = cursor;
        cursor = comma ? comma + 1 : NULL;

        while (*token == ' ' || *token == '\t') {
            token++;
        }
        size_t len = strlen(token);
        while (len > 0 && (token[len - 1] == ' ' || token[len - 1] == '\t')) {
            token[--len] = '\0';
        }
        if (len == 0) {
            continue;  // never offer "" as a transport name
        }

        int weight = get_transport_weight(token);
        if (weight < best_weight) {
            best_weight = weight;
            best_transport = token;
        }
    }

    char *result = best_transport ? strdup(best_transport) : NULL;
    free(list_copy);
    return result;
}

/*
 * Return a heap copy of the dotted-quad IPv4 address of the first usable
 * interface in list, or NULL if none matches (or strdup fails).
 *
 * Usable means: it has an AF_INET address, it is up (IFF_UP), and it is not
 * loopback (skipped both by the name "lo" and by the IFF_LOOPBACK flag).
 * When only_iface is non-NULL, only the interface with that name qualifies.
 * inet_ntop() writes into a local buffer, unlike inet_ntoa(), whose shared
 * static buffer is not thread-safe.
 */
static char *first_usable_ipv4(const struct ifaddrs *list, const char *only_iface) {
    for (const struct ifaddrs *ifa = list; ifa != NULL; ifa = ifa->ifa_next) {
        if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != AF_INET) {
            continue;
        }
        if (strcmp(ifa->ifa_name, "lo") == 0 || (ifa->ifa_flags & IFF_LOOPBACK)) {
            continue;
        }
        if (!(ifa->ifa_flags & IFF_UP)) {
            continue;
        }
        if (only_iface && strcmp(ifa->ifa_name, only_iface) != 0) {
            continue;
        }

        char text[INET_ADDRSTRLEN];
        const struct sockaddr_in *sin = (const struct sockaddr_in *)ifa->ifa_addr;
        if (inet_ntop(AF_INET, &sin->sin_addr, text, sizeof(text)) == NULL) {
            continue;
        }
        return strdup(text);
    }
    return NULL;
}

/*
 * Pick a local IPv4 address, returned as a heap-allocated string the caller
 * must free().
 *
 * Preference order:
 *   1. the address of the interface carrying the default route;
 *   2. the first usable non-loopback interface;
 *   3. "127.0.0.1" (also used when getifaddrs() fails, after a perror()).
 * Returns NULL only if the final strdup fails.
 */
char *autodetect_local_ip(void) {
    struct ifaddrs *ifaddr;
    char *gateway_iface = get_default_gateway_interface();

    if (getifaddrs(&ifaddr) == -1) {
        perror("getifaddrs");
        free(gateway_iface);
        return strdup("127.0.0.1"); // Fallback to localhost
    }

    char *selected_ip = NULL;
    if (gateway_iface) {
        selected_ip = first_usable_ipv4(ifaddr, gateway_iface);
    }
    if (!selected_ip) {
        selected_ip = first_usable_ipv4(ifaddr, NULL);
    }

    freeifaddrs(ifaddr);
    free(gateway_iface);

    return selected_ip ? selected_ip : strdup("127.0.0.1");
}