/*
 * WSSSH Library - Control Channel Messages Implementation
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "control_messages.h"
#include "websocket.h"
#include "tunnel.h"
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

/* Generic buffer size: 1 MiB (1048576 bytes). */
#define BUFFER_SIZE 1048576

/*
 * Send a minimal JSON control message over the WebSocket connection.
 *
 * The payload is {"type":<type>,"client_id":<client_id>} plus a
 * "request_id" member when request_id is non-NULL. Values are inserted
 * verbatim (no JSON escaping) and the text is clipped to a 1024-byte buffer.
 *
 * Returns the result of send_websocket_frame().
 */
int send_json_message(SSL *ssl, const char *type, const char *client_id, const char *request_id) {
    char json[1024];

    if (request_id == NULL) {
        snprintf(json, sizeof json,
                 "{\"type\":\"%s\",\"client_id\":\"%s\"}",
                 type, client_id);
        return send_websocket_frame(ssl, json);
    }

    snprintf(json, sizeof json,
             "{\"type\":\"%s\",\"client_id\":\"%s\",\"request_id\":\"%s\"}",
             type, client_id, request_id);
    return send_websocket_frame(ssl, json);
}

/*
 * Register this client with the server without advertising any services.
 * Forwards every argument to send_registration_message_with_services()
 * with an empty (NULL, 0) service list and returns its result.
 */
int send_registration_message(SSL *ssl, const char *client_id, const char *password, const char *tunnel, const char *tunnel_control, const char *wssshd_private_ip) {
    service_config_t **no_services = NULL;
    const int no_service_count = 0;

    return send_registration_message_with_services(ssl, client_id, password, tunnel,
                                                   tunnel_control, wssshd_private_ip,
                                                   no_services, no_service_count);
}

/*
 * Append printf-style text to buf (capacity cap bytes) at offset *len.
 *
 * Returns 1 and advances *len when the whole text fits. Returns 0 when it
 * does not fit (or vsnprintf reports an error); *len is then set to cap so
 * every subsequent call also fails, making failures "sticky".
 */
static int cm_appendf(char *buf, size_t cap, size_t *len, const char *fmt, ...) {
    if (*len >= cap) {
        return 0;
    }

    va_list ap;
    va_start(ap, fmt);
    int n = vsnprintf(buf + *len, cap - *len, fmt, ap);
    va_end(ap);

    if (n < 0 || (size_t)n >= cap - *len) {
        *len = cap;
        return 0;
    }
    *len += (size_t)n;
    return 1;
}

/*
 * Send the "register" control message, optionally advertising services.
 *
 * Key order of the emitted JSON:
 *   type, client_id, [password], tunnel, tunnel_control,
 *   [wssshd_private_ip], [services], version
 * "password" / "wssshd_private_ip" are included only when non-NULL and
 * non-empty; "services" only when services != NULL and num_services > 0.
 * Each service is {"name","host","port","proto"}; a NULL proto becomes "tcp".
 * String values are inserted verbatim (no JSON escaping).
 *
 * The message is built incrementally into a fixed 4096-byte buffer. If it
 * does not fit, nothing is sent (a truncated message would be invalid JSON).
 *
 * Returns the send_websocket_frame() result, or 0 if the message is too large.
 */
int send_registration_message_with_services(SSL *ssl, const char *client_id, const char *password, const char *tunnel, const char *tunnel_control, const char *wssshd_private_ip, service_config_t **services, int num_services) {
    char message[4096];
    const size_t cap = sizeof(message);
    size_t len = 0;

    int ok = cm_appendf(message, cap, &len,
                        "{\"type\":\"register\",\"client_id\":\"%s\"", client_id);

    if (password && password[0] != '\0') {
        ok = ok && cm_appendf(message, cap, &len, ",\"password\":\"%s\"", password);
    }

    ok = ok && cm_appendf(message, cap, &len,
                          ",\"tunnel\":\"%s\",\"tunnel_control\":\"%s\"",
                          tunnel, tunnel_control);

    if (wssshd_private_ip && wssshd_private_ip[0] != '\0') {
        ok = ok && cm_appendf(message, cap, &len,
                              ",\"wssshd_private_ip\":\"%s\"", wssshd_private_ip);
    }

    if (services && num_services > 0) {
        ok = ok && cm_appendf(message, cap, &len, ",\"services\":[");
        for (int i = 0; ok && i < num_services; i++) {
            const service_config_t *svc = services[i];
            ok = cm_appendf(message, cap, &len,
                            "%s{\"name\":\"%s\",\"host\":\"%s\",\"port\":%d,\"proto\":\"%s\"}",
                            i > 0 ? "," : "",
                            svc->name, svc->tunnel_host, svc->tunnel_port,
                            svc->proto ? svc->proto : "tcp");
        }
        ok = ok && cm_appendf(message, cap, &len, "]");
    }

    ok = ok && cm_appendf(message, cap, &len, ",\"version\":\"%s\"}", WSSSH_VERSION);

    if (!ok) {
        fprintf(stderr, "Registration message exceeds %zu bytes; not sent\n", cap);
        return 0;
    }

    // NOTE(review): this unconditional log includes the plaintext password
    // when one is set; consider redacting it.
    printf("[DEBUG] Sending registration message: %s\n", message);
    fflush(stdout);

    // Send as WebSocket frame
    int result = send_websocket_frame(ssl, message);
    if (result) {
        printf("[DEBUG] Registration message sent successfully\n");
    } else {
        printf("[DEBUG] Failed to send registration message\n");
    }
    fflush(stdout);
    return result;
}

/*
 * Ask the server to open a tunnel using the default hex data encoding.
 * Forwards to send_tunnel_request_message_with_enc() and returns its result.
 */
int send_tunnel_request_message(SSL *ssl, const char *client_id, const char *request_id, const char *tunnel, const char *tunnel_control, const char *service) {
    const wsssh_encoding_t default_encoding = ENCODING_HEX;

    return send_tunnel_request_message_with_enc(ssl, client_id, request_id, tunnel,
                                                tunnel_control, service, default_encoding);
}

/*
 * Send a "tunnel_request" control message with an explicit data encoding.
 *
 * The "enc" member is "base64" for ENCODING_BASE64, "bin" for
 * ENCODING_BINARY, and "hex" for ENCODING_HEX or any other value.
 * A "service" member is included only when service is non-NULL.
 * Values are inserted verbatim and the text is clipped to 1024 bytes.
 *
 * Returns the result of send_websocket_frame().
 */
int send_tunnel_request_message_with_enc(SSL *ssl, const char *client_id, const char *request_id, const char *tunnel, const char *tunnel_control, const char *service, wsssh_encoding_t encoding) {
    /* Unknown encodings fall back to hex, like ENCODING_HEX itself. */
    const char *enc_name = "hex";
    if (encoding == ENCODING_BASE64) {
        enc_name = "base64";
    } else if (encoding == ENCODING_BINARY) {
        enc_name = "bin";
    }

    char payload[1024];

    if (service == NULL) {
        snprintf(payload, sizeof payload,
                 "{\"type\":\"tunnel_request\",\"client_id\":\"%s\",\"request_id\":\"%s\",\"tunnel\":\"%s\",\"tunnel_control\":\"%s\",\"enc\":\"%s\",\"version\":\"%s\"}",
                 client_id, request_id, tunnel, tunnel_control, enc_name, WSSSH_VERSION);
    } else {
        snprintf(payload, sizeof payload,
                 "{\"type\":\"tunnel_request\",\"client_id\":\"%s\",\"request_id\":\"%s\",\"tunnel\":\"%s\",\"tunnel_control\":\"%s\",\"service\":\"%s\",\"enc\":\"%s\",\"version\":\"%s\"}",
                 client_id, request_id, tunnel, tunnel_control, service, enc_name, WSSSH_VERSION);
    }

    return send_websocket_frame(ssl, payload);
}


/*
 * Send {"type":"tunnel_close","request_id":<request_id>} over the WebSocket.
 * When debug is non-zero, the message and any send failure are logged to stdout.
 *
 * Returns 1 if the frame was sent, 0 otherwise.
 */
int send_tunnel_close_message(SSL *ssl, const char *request_id, int debug) {
    char msg[256];
    snprintf(msg, sizeof msg, "{\"type\":\"tunnel_close\",\"request_id\":\"%s\"}", request_id);

    if (debug) {
        printf("[DEBUG - Tunnel] Sending tunnel_close: %s\n", msg);
        fflush(stdout);
    }

    int sent = send_websocket_frame(ssl, msg);
    if (!sent && debug) {
        printf("[DEBUG - Tunnel] Failed to send tunnel_close message\n");
        fflush(stdout);
    }
    return sent ? 1 : 0;
}

/*
 * Send a tunnel_keepalive message carrying traffic statistics for `tunnel`.
 *
 * Payload:
 *   {"type":"tunnel_keepalive","request_id":...,"total_bytes":N,"rate_bps":R}
 * where N = total_bytes_sent + total_bytes_received and R is
 * bytes_last_period divided by the seconds elapsed since last_stats_reset.
 *
 * The rate is computed before the statistics window is restarted. Otherwise
 * the bytes of a just-completed window would be discarded and reported as 0.
 * The window (bytes_last_period / last_stats_reset) restarts once it is at
 * least 30 seconds old. On a successful send, last_keepalive_sent is set to
 * the current time.
 *
 * Returns 1 if the frame was sent, 0 if tunnel is NULL/inactive or the send failed.
 */
int send_tunnel_keepalive_message(SSL *ssl, tunnel_t *tunnel, int debug) {
    if (!tunnel || !tunnel->active) return 0;

    time_t current_time = time(NULL);
    double elapsed = (double)(current_time - tunnel->last_stats_reset);

    // Rate over the current window, taken before any reset below.
    double rate_bps = 0.0;
    if (elapsed > 0) {
        rate_bps = (double)tunnel->bytes_last_period / elapsed;
    }

    // Start a new statistics window every 30 seconds.
    if (elapsed >= 30) {
        tunnel->bytes_last_period = 0;
        tunnel->last_stats_reset = current_time;
    }

    char keepalive_msg[512];
    unsigned long long total_bytes = tunnel->total_bytes_sent + tunnel->total_bytes_received;
    snprintf(keepalive_msg, sizeof(keepalive_msg),
             "{\"type\":\"tunnel_keepalive\",\"request_id\":\"%s\",\"total_bytes\":%llu,\"rate_bps\":%.2f}",
             tunnel->request_id, total_bytes, rate_bps);

    if (debug) {
        printf("[DEBUG - Tunnel] Sending keep-alive for tunnel %s: %s\n", tunnel->request_id, keepalive_msg);
        fflush(stdout);
    }

    if (!send_websocket_frame(ssl, keepalive_msg)) {
        if (debug) {
            printf("[DEBUG - Tunnel] Failed to send keep-alive for tunnel %s\n", tunnel->request_id);
            fflush(stdout);
        }
        return 0;
    }

    tunnel->last_keepalive_sent = current_time;
    return 1;
}

/*
 * Acknowledge a tunnel keep-alive with
 * {"type":"tunnel_keepalive_ack","request_id":<request_id>}.
 * When debug is non-zero, the message and any send failure are logged to stdout.
 *
 * Returns 1 if the frame was sent, 0 otherwise.
 */
int send_tunnel_keepalive_ack_message(SSL *ssl, const char *request_id, int debug) {
    char reply[256];
    snprintf(reply, sizeof reply, "{\"type\":\"tunnel_keepalive_ack\",\"request_id\":\"%s\"}", request_id);

    if (debug) {
        printf("[DEBUG - WebSockets] Sending tunnel_keepalive_ack: %s\n", reply);
        fflush(stdout);
    }

    if (send_websocket_frame(ssl, reply)) {
        return 1;
    }

    if (debug) {
        printf("[DEBUG - Tunnel] Failed to send keep-alive ACK for tunnel %s\n", request_id);
        fflush(stdout);
    }
    return 0;
}

/*
 * RFC 6455 section 5.5: control frames (ping/pong) must not be fragmented
 * and carry at most 125 payload bytes, so the 7-bit length form always suffices.
 */
enum {
    WS_OPCODE_PING = 0x9,
    WS_OPCODE_PONG = 0xA,
    WS_CONTROL_MAX_PAYLOAD = 125,
    /* 2-byte header + 4-byte masking key + maximum payload */
    WS_CONTROL_FRAME_MAX = 2 + 4 + WS_CONTROL_MAX_PAYLOAD
};

/*
 * Build a masked, final (FIN set) client-to-server control frame in `out`,
 * which must hold WS_CONTROL_FRAME_MAX bytes.
 *
 * Returns the frame length. Returns -1 if payload_len is outside
 * 0..WS_CONTROL_MAX_PAYLOAD, or if payload is NULL with payload_len > 0.
 */
static int build_masked_control_frame(unsigned char opcode, const char *payload, int payload_len,
                                      unsigned char *out) {
    if (payload_len < 0 || payload_len > WS_CONTROL_MAX_PAYLOAD) {
        return -1;
    }
    if (payload_len > 0 && payload == NULL) {
        return -1;
    }

    out[0] = (unsigned char)(0x80 | opcode);       // FIN + opcode
    out[1] = (unsigned char)(0x80 | payload_len);  // MASK bit + 7-bit length

    // NOTE(review): rand() is not the unpredictable source RFC 6455 asks for
    // when choosing masking keys; consider RAND_bytes() from <openssl/rand.h>.
    unsigned char *mask = out + 2;
    for (int i = 0; i < 4; i++) {
        mask[i] = (unsigned char)(rand() % 256);
    }

    unsigned char *data = out + 6;
    for (int i = 0; i < payload_len; i++) {
        data[i] = (unsigned char)((unsigned char)payload[i] ^ mask[i % 4]);
    }
    return 6 + payload_len;
}

/*
 * Write a complete frame to the TLS connection while holding ssl_mutex.
 *
 * SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE are retried (10 ms apart) up to
 * two consecutive times; the retry count resets after any successful write.
 * Any other failure aborts. `label` ("Ping"/"Pong") prefixes the messages
 * written to stderr.
 *
 * Returns 1 when all bytes were written, 0 on failure.
 */
static int write_control_frame(SSL *ssl, const unsigned char *frame, int frame_len, const char *label) {
    const int max_retries = 3;
    int total_written = 0;
    int retry_count = 0;

    // Serialize with other SSL operations on this connection.
    pthread_mutex_lock(&ssl_mutex);

    while (total_written < frame_len) {
        // SSL_get_error() is only reliable if the thread's error queue is empty first.
        ERR_clear_error();
        int written = SSL_write(ssl, frame + total_written, frame_len - total_written);
        if (written > 0) {
            total_written += written;
            retry_count = 0;
            continue;
        }

        int ssl_error = SSL_get_error(ssl, written);
        if ((ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) && retry_count < max_retries - 1) {
            retry_count++;
            usleep(10000); // Wait 10ms before retrying with the same buffer
            continue;
        }

        fprintf(stderr, "%s frame SSL_write failed: %d (after %d retries)\n", label, ssl_error, retry_count);
        char error_buf[256];
        // Decode the queued OpenSSL error, not the SSL_get_error() category code.
        ERR_error_string_n(ERR_get_error(), error_buf, sizeof(error_buf));
        fprintf(stderr, "SSL write error details: %s\n", error_buf);
        pthread_mutex_unlock(&ssl_mutex);
        return 0;
    }

    pthread_mutex_unlock(&ssl_mutex);
    return 1;
}

/*
 * Send a WebSocket ping carrying `payload_len` bytes of `ping_payload`.
 * payload_len must be 0..125 (RFC 6455 control-frame limit).
 * Returns 1 on success, 0 on invalid arguments or write failure.
 */
int send_ping_frame(SSL *ssl, const char *ping_payload, int payload_len) {
    unsigned char frame[WS_CONTROL_FRAME_MAX];
    int frame_len = build_masked_control_frame(WS_OPCODE_PING, ping_payload, payload_len, frame);
    if (frame_len < 0) {
        fprintf(stderr, "Ping frame rejected: invalid payload (length %d, control frames allow 0-%d bytes)\n",
                payload_len, WS_CONTROL_MAX_PAYLOAD);
        return 0;
    }
    return write_control_frame(ssl, frame, frame_len, "Ping");
}

/*
 * Send a WebSocket pong echoing `payload_len` bytes of `ping_payload`
 * (normally the payload of the ping being answered).
 * payload_len must be 0..125 (RFC 6455 control-frame limit).
 * Returns 1 on success, 0 on invalid arguments or write failure.
 */
int send_pong_frame(SSL *ssl, const char *ping_payload, int payload_len) {
    unsigned char frame[WS_CONTROL_FRAME_MAX];
    int frame_len = build_masked_control_frame(WS_OPCODE_PONG, ping_payload, payload_len, frame);
    if (frame_len < 0) {
        fprintf(stderr, "Pong frame rejected: invalid payload (length %d, control frames allow 0-%d bytes)\n",
                payload_len, WS_CONTROL_MAX_PAYLOAD);
        return 0;
    }
    return write_control_frame(ssl, frame, frame_len, "Pong");
}