/*
 * WSSSH Library - WebSocket functions implementation
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "websocket.h"
#include "wssshlib.h"
#include "tunnel.h"
#include "control_messages.h"
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <netdb.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/**
 * Perform the client side of the HTTP-to-WebSocket upgrade on an established
 * TLS session.
 *
 * Sends a "GET <path>" upgrade request carrying a "Host: <host>:<port>"
 * header, then reads the reply and checks that it contains
 * "101 Switching Protocols". The write and the read both happen with
 * ssl_mutex held.
 *
 * @param ssl    TLS session to write the request to and read the reply from.
 * @param host   Host name placed in the Host header and in debug output.
 * @param port   Port placed in the Host header and in debug output.
 * @param path   Request target placed on the GET line.
 * @param debug  Non-zero enables "[DEBUG]" progress messages on stderr.
 * @return 1 when the reply contains "101 Switching Protocols", 0 otherwise.
 *
 * NOTE(review): the Sec-WebSocket-Key below is a fixed value and the reply's
 * Sec-WebSocket-Accept header is never checked, so this function does not
 * prove that the peer actually processed the key.
 */
int websocket_handshake(SSL *ssl, const char *host, int port, const char *path, int debug) {
    char request[1024];
    char response[BUFFER_SIZE];

    // Passing NULL to "%s" is undefined, so reject it up front.
    if (!ssl || !host || !path) {
        fprintf(stderr, "WebSocket handshake failed - invalid arguments\n");
        return 0;
    }

    if (debug) {
        fprintf(stderr, "[DEBUG] Starting WebSocket handshake to %s:%d\n", host, port);
    }

    // Build the upgrade request; refuse to send it if it did not fit, because
    // a truncated request would be missing its terminating blank line.
    int request_len = snprintf(request, sizeof(request),
                               "GET %s HTTP/1.1\r\n"
                               "Host: %s:%d\r\n"
                               "Upgrade: websocket\r\n"
                               "Connection: upgrade\r\n"
                               "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
                               "Sec-WebSocket-Version: 13\r\n"
                               "\r\n",
                               path, host, port);
    if (request_len < 0 || (size_t)request_len >= sizeof(request)) {
        fprintf(stderr, "WebSocket handshake request too long\n");
        return 0;
    }

    if (debug) {
        fprintf(stderr, "[DEBUG] Sending WebSocket handshake request...\n");
    }

    // Hold the SSL mutex across the whole request/response exchange
    pthread_mutex_lock(&ssl_mutex);
    if (SSL_write(ssl, request, request_len) <= 0) {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "WebSocket handshake send failed\n");
        pthread_mutex_unlock(&ssl_mutex);
        return 0;
    }

    if (debug) {
        fprintf(stderr, "[DEBUG] WebSocket handshake request sent, waiting for response...\n");
    }

    // One SSL_read() may return only part of the HTTP response. Keep reading
    // until the blank line that ends the headers shows up (or the buffer is
    // full, or the read fails), so header bytes are not left in the stream.
    const int capacity = (int)sizeof(response) - 1;
    int bytes_read = 0;
    int last_read = 0;
    response[0] = '\0';
    while (bytes_read < capacity) {
        last_read = SSL_read(ssl, response + bytes_read, capacity - bytes_read);
        if (last_read <= 0) {
            break;
        }
        bytes_read += last_read;
        response[bytes_read] = '\0';
        if (strstr(response, "\r\n\r\n") != NULL) {
            break;
        }
    }
    pthread_mutex_unlock(&ssl_mutex);

    if (bytes_read <= 0) {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "WebSocket handshake recv failed (bytes_read=%d)\n", last_read);
        return 0;
    }

    if (debug) {
        fprintf(stderr, "[DEBUG] Received WebSocket handshake response (%d bytes)\n", bytes_read);
    }

    // Check for successful handshake
    if (strstr(response, "101 Switching Protocols") == NULL) {
        fprintf(stderr, "WebSocket handshake failed - no 101 response\n");
        fprintf(stderr, "[DEBUG] Response: %.200s\n", response);
        return 0;
    }

    if (debug) {
        fprintf(stderr, "[DEBUG] WebSocket handshake successful\n");
    }
    return 1;
}


/**
 * Build one masked client frame around `payload` and write it to `ssl`.
 *
 * Shared implementation behind send_websocket_frame() and
 * send_websocket_binary_frame(). The frame is assembled in a heap buffer and
 * written in chunks of at most BUFFER_SIZE bytes, all with ssl_mutex held.
 *
 * @param ssl          TLS session to write to.
 * @param first_byte   First frame byte (FIN bit plus opcode).
 * @param payload      Payload bytes; may be NULL only when payload_len is 0.
 * @param payload_len  Number of payload bytes.
 * @param label        Text inserted into the failure message ("frame" or
 *                     "binary frame").
 * @return 1 when the whole frame was written, 0 on allocation failure,
 *         write failure, or when sigint_received became non-zero.
 */
static int ws_send_masked_frame(SSL *ssl, unsigned char first_byte,
                                const unsigned char *payload, size_t payload_len,
                                const char *label) {
    // Header = 2 base bytes + 0/2/8 extended length bytes + 4 mask bytes
    size_t header_len;
    if (payload_len <= 125) {
        header_len = 6;
    } else if (payload_len <= 65535) {
        header_len = 8;
    } else {
        header_len = 14;
    }

    // Guard the size computation against wrap-around
    if (payload_len > SIZE_MAX - header_len) {
        return 0;
    }
    const size_t frame_len = header_len + payload_len;

    // Lock SSL mutex to prevent concurrent SSL operations
    pthread_mutex_lock(&ssl_mutex);

    unsigned char *frame = malloc(frame_len);
    if (!frame) {
        pthread_mutex_unlock(&ssl_mutex);
        return 0;
    }

    frame[0] = first_byte;
    if (payload_len <= 125) {
        frame[1] = (unsigned char)(0x80 | payload_len); // MASK + length
    } else if (payload_len <= 65535) {
        frame[1] = 0x80 | 126; // MASK + 16-bit extended length
        frame[2] = (unsigned char)((payload_len >> 8) & 0xFF);
        frame[3] = (unsigned char)(payload_len & 0xFF);
    } else {
        // MASK + 64-bit extended length, most significant byte first.
        // Widen first so the shifts are defined even if size_t is 32 bits.
        const unsigned long long wide_len = payload_len;
        frame[1] = 0x80 | 127;
        for (int i = 0; i < 8; i++) {
            frame[2 + i] = (unsigned char)((wide_len >> (8 * (7 - i))) & 0xFF);
        }
    }

    // The mask key occupies the last four header bytes
    unsigned char *mask_key = frame + header_len - 4;
    for (int i = 0; i < 4; i++) {
        mask_key[i] = (unsigned char)(rand() % 256);
    }

    // Mask payload
    unsigned char *masked = frame + header_len;
    for (size_t i = 0; i < payload_len; i++) {
        masked[i] = payload[i] ^ mask_key[i % 4];
    }

    // SSL_write() takes an int length, so cap each chunk accordingly
    size_t max_chunk = BUFFER_SIZE;
    if (max_chunk > (size_t)INT_MAX) {
        max_chunk = (size_t)INT_MAX;
    }

    // Handle partial writes for large frames with SIGINT checking
    const int max_retries = 3;
    int retry_count = 0;
    size_t total_written = 0;
    int ok = 1;

    while (total_written < frame_len) {
        // Check for SIGINT to allow interruption
        if (sigint_received) {
            fprintf(stderr, "[DEBUG] SIGINT received during WebSocket send, aborting\n");
            fflush(stderr);
            ok = 0;
            break;
        }

        size_t remaining = frame_len - total_written;
        int to_write = (int)(remaining > max_chunk ? max_chunk : remaining);

        int written = SSL_write(ssl, frame + total_written, to_write);
        if (written <= 0) {
            int ssl_error = SSL_get_error(ssl, written);

            // Retry WANT_READ / WANT_WRITE a limited number of times
            if ((ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) &&
                retry_count < max_retries - 1) {
                retry_count++;
                usleep(10000); // Wait 10ms before retry
                continue;
            }

            fprintf(stderr, "WebSocket %s SSL_write failed: %d (after %d retries)\n",
                    label, ssl_error, retry_count);

            // SSL_get_error() results are not error-queue codes, so the
            // detail text has to come from the queue itself.
            unsigned long queued = ERR_get_error();
            if (queued != 0) {
                char error_buf[256];
                ERR_error_string_n(queued, error_buf, sizeof(error_buf));
                fprintf(stderr, "SSL write error details: %s\n", error_buf);
            } else {
                fprintf(stderr, "SSL write error details: no OpenSSL error queued\n");
            }
            ok = 0;
            break;
        }

        total_written += (size_t)written;
        retry_count = 0; // Reset retry count on successful write
    }

    free(frame);
    pthread_mutex_unlock(&ssl_mutex);
    return ok;
}

/**
 * Send a NUL-terminated string as a single masked text frame (first byte
 * 0x81) over `ssl`.
 *
 * @param ssl   TLS session to write to.
 * @param data  NUL-terminated payload; its length is taken with strlen().
 * @return 1 when the whole frame was written, 0 otherwise (including when
 *         `data` is NULL).
 */
int send_websocket_frame(SSL *ssl, const char *data) {
    if (!data) {
        return 0;
    }
    return ws_send_masked_frame(ssl, 0x81, (const unsigned char *)data, strlen(data), "frame");
}

/**
 * Send `data_len` bytes as a single masked binary frame (first byte 0x82)
 * over `ssl`.
 *
 * @param ssl       TLS session to write to.
 * @param data      Payload bytes; may be NULL only when data_len is 0.
 * @param data_len  Number of payload bytes.
 * @return 1 when the whole frame was written, 0 otherwise.
 */
int send_websocket_binary_frame(SSL *ssl, const unsigned char *data, size_t data_len) {
    if (!data && data_len > 0) {
        return 0;
    }
    return ws_send_masked_frame(ssl, 0x82, data, data_len, "binary frame");
}

// Bridge mode transport layer functions - Pure WebSocket connection without tunnel setup
/**
 * Open a TCP connection to host:port, run the TLS handshake on it and then
 * the WebSocket upgrade for the path "/ws/<client_id>".
 *
 * On success the SSL object is stored in active_tunnel->ssl (active_tunnel
 * is allocated and zeroed here if it is NULL) and the connected socket
 * descriptor is returned.
 *
 * @param host       Server name or address to resolve and connect to.
 * @param port       TCP port, 1..65535.
 * @param client_id  Identifier appended to "/ws/" to form the request path.
 * @param debug      Non-zero enables "[DEBUG]" progress messages on stderr.
 * @param ctx_out    If non-NULL, receives the SSL_CTX created here; the
 *                   caller is then responsible for SSL_CTX_free(). If NULL,
 *                   this function drops its own reference before returning.
 * @return The connected socket descriptor, or -1 on any failure (all
 *         resources created here are released in that case).
 *
 * NOTE(review): nothing here calls SSL_CTX_set_verify() or
 * SSL_set_tlsext_host_name(), so whether the peer certificate is verified
 * and whether SNI is sent is left at OpenSSL's defaults - confirm that is
 * what is wanted.
 * NOTE(review): an existing active_tunnel->ssl value is overwritten without
 * being freed; confirm that callers never reach this with a live session.
 */
int setup_websocket_connection(const char *host, int port, const char *client_id, int debug, SSL_CTX **ctx_out) {
    int sock = -1;
    SSL_CTX *ctx = NULL;
    SSL *ssl = NULL;

    if (!host || !client_id) {
        fprintf(stderr, "WebSocket connection setup failed - invalid arguments\n");
        return -1;
    }
    if (port <= 0 || port > 65535) {
        fprintf(stderr, "Invalid port: %d\n", port);
        return -1;
    }

    if (debug) {
        fprintf(stderr, "[DEBUG] Setting up pure WebSocket connection to %s:%d for client %s\n", host, port, client_id);
    }

    // Build the request path first so an over-long client id fails before
    // any network activity takes place.
    char path[256];
    int path_len = snprintf(path, sizeof(path), "/ws/%s", client_id);
    if (path_len < 0 || (size_t)path_len >= sizeof(path)) {
        fprintf(stderr, "Client id too long for WebSocket path\n");
        return -1;
    }

    // Resolve the host; getaddrinfo() is reentrant and can return both
    // IPv4 and IPv6 candidates.
    char port_str[16];
    snprintf(port_str, sizeof(port_str), "%d", port);

    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;

    struct addrinfo *candidates = NULL;
    if (getaddrinfo(host, port_str, &hints, &candidates) != 0) {
        fprintf(stderr, "Failed to resolve hostname: %s\n", host);
        return -1;
    }

    // Try each candidate address until one connects
    int last_errno = 0;
    for (const struct addrinfo *ai = candidates; ai != NULL; ai = ai->ai_next) {
        sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (sock < 0) {
            last_errno = errno;
            continue;
        }
        if (connect(sock, ai->ai_addr, ai->ai_addrlen) == 0) {
            break;
        }
        last_errno = errno; // close() below may overwrite errno
        close(sock);
        sock = -1;
    }
    freeaddrinfo(candidates);

    if (sock < 0) {
        errno = last_errno;
        perror("Connection failed");
        return -1;
    }

    // Set up SSL context
    SSL_library_init();
    OpenSSL_add_all_algorithms();
    SSL_load_error_strings();
    ctx = SSL_CTX_new(TLS_client_method());
    if (!ctx) {
        fprintf(stderr, "SSL context creation failed\n");
        goto fail;
    }

    // Create SSL connection
    ssl = SSL_new(ctx);
    if (!ssl) {
        fprintf(stderr, "SSL creation failed\n");
        goto fail;
    }

    if (SSL_set_fd(ssl, sock) != 1) {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL_set_fd failed\n");
        goto fail;
    }

    // Perform SSL handshake
    if (SSL_connect(ssl) != 1) {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL handshake failed\n");
        goto fail;
    }

    // Perform WebSocket handshake
    if (!websocket_handshake(ssl, host, port, path, debug)) {
        fprintf(stderr, "WebSocket handshake failed\n");
        goto fail;
    }

    // Store SSL in global tunnel for cleanup (bridge mode doesn't use full tunnel struct)
    if (!active_tunnel) {
        active_tunnel = calloc(1, sizeof(tunnel_t));
        if (!active_tunnel) {
            fprintf(stderr, "Memory allocation failed for tunnel\n");
            goto fail;
        }
    }
    active_tunnel->ssl = ssl;

    if (debug) {
        fprintf(stderr, "[DEBUG] Pure WebSocket connection established successfully\n");
    }

    // The context is not stored in the tunnel. Hand it to the caller when
    // asked; otherwise release our reference so it is not leaked (the SSL
    // object holds its own reference to the context).
    if (ctx_out) {
        *ctx_out = ctx;
    } else {
        SSL_CTX_free(ctx);
    }

    return sock;

fail:
    // SSL_free() and SSL_CTX_free() accept NULL
    SSL_free(ssl);
    SSL_CTX_free(ctx);
    close(sock);
    return -1;
}

/**
 * Wrap `message` in one masked text frame (first byte 0x81) and write it to
 * `sock` with send().
 *
 * The frame is assembled in a BUFFER_SIZE-byte stack buffer, so header plus
 * payload must fit in BUFFER_SIZE bytes.
 *
 * @param sock     Socket descriptor passed to send().
 * @param message  Payload bytes; may be NULL only when len is 0.
 * @param len      Number of payload bytes; must not be negative.
 * @param channel  Only used in the debug message.
 * @param debug    Non-zero prints a "[DEBUG]" line on stdout after sending.
 * @return 1 when the whole frame was sent, 0 on invalid arguments, when the
 *         frame does not fit the buffer, or when send() fails.
 *
 * NOTE(review): this writes straight to the descriptor. If `sock` is the
 * descriptor returned by setup_websocket_connection(), a TLS session is
 * layered on it and these bytes would bypass that session - confirm which
 * kind of socket callers pass.
 */
int send_websocket_message(int sock, const char *message, int len, const char *channel, int debug) {
    // For now, send all messages as text frames
    // In a full implementation, you might want to route based on channel
    unsigned char frame[BUFFER_SIZE];

    if (len < 0 || (!message && len > 0)) {
        fprintf(stderr, "Invalid WebSocket message arguments\n");
        return 0;
    }

    // Header = 2 base bytes + 0/2/8 extended length bytes + 4 mask bytes
    int header_len;
    if (len <= 125) {
        header_len = 6;
    } else if (len <= 65535) {
        header_len = 8;
    } else {
        header_len = 14;
    }

    // Compare by subtraction so header_len + len cannot overflow an int
    if (len > (int)sizeof(frame) - header_len) {
        fprintf(stderr, "Message too large for buffer\n");
        return 0;
    }
    const int frame_len = header_len + len;

    frame[0] = 0x81; // FIN + text opcode

    if (len <= 125) {
        frame[1] = (unsigned char)(0x80 | len); // MASK + length
    } else if (len <= 65535) {
        frame[1] = 0x80 | 126; // MASK + 16-bit extended length
        frame[2] = (unsigned char)((len >> 8) & 0xFF);
        frame[3] = (unsigned char)(len & 0xFF);
    } else {
        frame[1] = 0x80 | 127; // MASK + 64-bit extended length
        frame[2] = 0;
        frame[3] = 0;
        frame[4] = 0;
        frame[5] = 0;
        frame[6] = (unsigned char)((len >> 24) & 0xFF);
        frame[7] = (unsigned char)((len >> 16) & 0xFF);
        frame[8] = (unsigned char)((len >> 8) & 0xFF);
        frame[9] = (unsigned char)(len & 0xFF);
    }

    // The mask key occupies the last four header bytes
    unsigned char *mask_key = frame + header_len - 4;
    for (int i = 0; i < 4; i++) {
        mask_key[i] = (unsigned char)(rand() % 256);
    }

    // Mask payload
    for (int i = 0; i < len; i++) {
        frame[header_len + i] = (unsigned char)message[i] ^ mask_key[i % 4];
    }

    // Send frame, continuing after short writes
    int total_sent = 0;
    while (total_sent < frame_len) {
        ssize_t sent = send(sock, frame + total_sent, (size_t)(frame_len - total_sent), 0);
        if (sent <= 0) {
            perror("Send failed");
            return 0;
        }
        total_sent += (int)sent;
    }

    if (debug) {
        printf("[DEBUG] Sent %d bytes to %s channel\n", len, channel ? channel : "(none)");
        fflush(stdout);
    }

    return 1;
}

/**
 * Send a masked pong frame (first byte 0x8A) carrying a copy of
 * `ping_payload` on `sock` using send().
 *
 * The frame is assembled in a BUFFER_SIZE-byte stack buffer; payloads that
 * would not fit together with the header are rejected instead of being
 * written past the end of that buffer.
 *
 * @param sock          Socket descriptor passed to send().
 * @param ping_payload  Bytes to echo; may be NULL only when payload_len is 0.
 * @param payload_len   Number of payload bytes; must not be negative.
 * @return 1 when the whole frame was sent, 0 on invalid arguments, when the
 *         frame does not fit the buffer, or when send() fails.
 *
 * NOTE(review): RFC 6455 section 5.5 limits control-frame payloads to 125
 * bytes. The extended-length encodings are kept here for compatibility
 * with existing callers - confirm whether longer payloads should be
 * rejected outright.
 */
int send_pong_frame_ws(int sock, const char *ping_payload, int payload_len) {
    unsigned char frame[BUFFER_SIZE];

    if (payload_len < 0 || (!ping_payload && payload_len > 0)) {
        fprintf(stderr, "Invalid pong frame arguments\n");
        return 0;
    }

    // Header = 2 base bytes + 0/2/8 extended length bytes + 4 mask bytes
    int header_len;
    if (payload_len <= 125) {
        header_len = 2 + 4;
    } else if (payload_len <= 65535) {
        header_len = 4 + 4;
    } else {
        header_len = 10 + 4;
    }

    // Bounds check before anything is written into the stack buffer
    if (payload_len > (int)sizeof(frame) - header_len) {
        fprintf(stderr, "Pong payload too large for buffer\n");
        return 0;
    }
    const int frame_len = header_len + payload_len;

    frame[0] = 0x8A; // FIN + pong opcode

    if (payload_len <= 125) {
        frame[1] = (unsigned char)(0x80 | payload_len); // MASK + length
    } else if (payload_len <= 65535) {
        frame[1] = 0x80 | 126; // MASK + 16-bit extended length
        frame[2] = (unsigned char)((payload_len >> 8) & 0xFF);
        frame[3] = (unsigned char)(payload_len & 0xFF);
    } else {
        frame[1] = 0x80 | 127; // MASK + 64-bit extended length
        frame[2] = 0;
        frame[3] = 0;
        frame[4] = 0;
        frame[5] = 0;
        frame[6] = (unsigned char)((payload_len >> 24) & 0xFF);
        frame[7] = (unsigned char)((payload_len >> 16) & 0xFF);
        frame[8] = (unsigned char)((payload_len >> 8) & 0xFF);
        frame[9] = (unsigned char)(payload_len & 0xFF);
    }

    // The mask key occupies the last four header bytes
    unsigned char *mask_key = frame + header_len - 4;
    for (int i = 0; i < 4; i++) {
        mask_key[i] = (unsigned char)(rand() % 256);
    }

    // Mask payload
    for (int i = 0; i < payload_len; i++) {
        frame[header_len + i] = (unsigned char)ping_payload[i] ^ mask_key[i % 4];
    }

    // Send frame, continuing after short writes
    int total_sent = 0;
    while (total_sent < frame_len) {
        ssize_t sent = send(sock, frame + total_sent, (size_t)(frame_len - total_sent), 0);
        if (sent <= 0) {
            perror("Pong send failed");
            return 0;
        }
        total_sent += (int)sent;
    }

    return 1;
}


/**
 * Parse one complete WebSocket frame located at the start of `buffer`.
 *
 * Only frames with the FIN bit set and opcode text (0x1), binary (0x2),
 * close (0x8), ping (0x9) or pong (0xA) are accepted; the three reserved
 * bits are ignored. If the frame is masked, the payload is unmasked IN
 * PLACE, i.e. the memory behind `buffer` is modified even though the
 * parameter is declared const. Parsing the same masked buffer a second
 * time would therefore XOR the payload again.
 *
 * @param buffer       Received bytes, starting at a frame boundary. Must
 *                     refer to writable memory when the frame is masked.
 * @param bytes_read   Number of valid bytes in `buffer`.
 * @param payload      Receives a pointer into `buffer` at the first payload
 *                     byte (set only on success).
 * @param payload_len  Receives the declared payload length. It is written
 *                     as soon as the length field has been decoded, so it
 *                     may be updated even when 0 is returned for an
 *                     incomplete or oversized frame.
 * @return 1 when a complete supported frame was parsed, 0 otherwise
 *         (NULL argument, too little data, unsupported first byte, or a
 *         declared length above INT_MAX or above the 50MB safety limit).
 */
int parse_websocket_frame(const char *buffer, int bytes_read, char **payload, int *payload_len) {
    if (!buffer || !payload || !payload_len) {
        return 0;
    }
    if (bytes_read < 2) {
        return 0; // Not enough data for a frame
    }

    // Work on unsigned bytes so no value is sign-extended
    const unsigned char *bytes = (const unsigned char *)buffer;

    // Keep FIN and the opcode, drop the reserved bits
    switch (bytes[0] & 0x8F) {
    case 0x81: // text
    case 0x82: // binary
    case 0x88: // close
    case 0x89: // ping
    case 0x8A: // pong
        break;
    default:
        return 0; // Not a supported frame type
    }

    const int masked = bytes[1] & 0x80;
    const int len_indicator = bytes[1] & 0x7F;
    int header_len = 2;

    if (len_indicator <= 125) {
        *payload_len = len_indicator;
    } else if (len_indicator == 126) {
        // 16-bit big-endian length follows
        if (bytes_read < 4) {
            return 0;
        }
        *payload_len = (bytes[2] << 8) | bytes[3];
        header_len = 4;
    } else {
        // len_indicator == 127: 64-bit big-endian length follows
        if (bytes_read < 10) {
            return 0;
        }
        unsigned long long full_len = 0;
        for (int i = 2; i < 10; i++) {
            full_len = (full_len << 8) | bytes[i];
        }
        if (full_len > INT_MAX) {
            return 0; // Payload too large
        }
        *payload_len = (int)full_len;
        header_len = 10;
    }

    if (masked) {
        header_len += 4; // the mask key sits between length and payload
    }

    // Ensure the complete frame is present. The comparison is done by
    // subtraction so header_len + *payload_len cannot overflow an int.
    if (bytes_read < header_len || *payload_len > bytes_read - header_len) {
        return 0; // Incomplete frame
    }

    // Ensure payload length is reasonable (prevent potential DoS)
    const size_t MAX_SAFE_PAYLOAD = 50 * 1024 * 1024; // 50MB safety limit
    if (*payload_len < 0 || (size_t)*payload_len > MAX_SAFE_PAYLOAD) {
        printf("[DEBUG] parse_websocket_frame: Payload too large: %d bytes (max: %zu)\n",
               *payload_len, MAX_SAFE_PAYLOAD);
        return 0;
    }

    *payload = (char *)buffer + header_len;
    if (masked) {
        // The mask key is the four bytes immediately before the payload
        const char *mask_key = buffer + (header_len - 4);
        for (int i = 0; i < *payload_len; i++) {
            (*payload)[i] ^= mask_key[i % 4];
        }
    }

    return 1;
}