/*
 * WSSSH Library - Data Channel Messages Implementation
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "data_messages.h"
#include "wssshlib.h"
#include "websocket.h"
#include "tunnel.h"
#include <openssl/ssl.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <zlib.h>
#include <zlib.h>  // For CRC32
#include <time.h>
#include <pthread.h>

/*
 * Wrap an already hex-encoded payload in a tunnel_data JSON frame and send it.
 * The "size" field carries the decoded byte count (half the hex length).
 * Returns 1 on success, 0 if allocation, formatting, or sending fails.
 */
int send_tunnel_data_message(SSL *ssl, const char *request_id, const char *data_hex, int debug) {
    static const char skeleton[] = "{\"type\":\"tunnel_data\",\"request_id\":\"\",\"size\":,\"data\":\"\"}";
    const size_t hex_chars = strlen(data_hex);
    const size_t payload_bytes = hex_chars / 2;
    // 32 spare bytes comfortably hold the decimal "size" value.
    const size_t capacity = (sizeof(skeleton) - 1) + strlen(request_id) + 32 + hex_chars + 1;

    char *frame = malloc(capacity);
    if (frame == NULL) {
        if (debug) {
            printf("[DEBUG] Failed to allocate memory for tunnel_data message (%zu bytes)\n", capacity);
            fflush(stdout);
        }
        return 0;
    }

    const int written = snprintf(frame, capacity,
                                 "{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"size\":%zu,\"data\":\"%s\"}",
                                 request_id, payload_bytes, data_hex);
    int ok = 1;
    if (written < 0 || (size_t)written >= capacity) {
        if (debug) {
            printf("[DEBUG] Failed to format tunnel_data message (msg_len=%d, msg_size=%zu)\n", written, capacity);
            fflush(stdout);
        }
        ok = 0;
    } else if (!send_websocket_frame(ssl, frame)) {
        if (debug) {
            printf("[DEBUG] Failed to send tunnel_data WebSocket frame\n");
            fflush(stdout);
        }
        ok = 0;
    }

    free(frame);
    return ok;
}

/*
 * Base64-encode data_len raw bytes (standard alphabet, '=' padding) and send
 * them as the "data" field of a tunnel_data JSON frame.
 * Returns 1 on success, 0 if an allocation or the WebSocket send fails.
 */
int send_tunnel_data_binary_message(SSL *ssl, const char *request_id, const unsigned char *data, size_t data_len, int debug) {
    static const char alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Four output characters per (possibly partial) 3-byte group, plus NUL.
    const size_t b64_cap = ((data_len + 2) / 3) * 4 + 1;

    char *encoded = malloc(b64_cap);
    if (encoded == NULL) {
        if (debug) {
            printf("[DEBUG] Failed to allocate memory for base64 encoding (%zu bytes)\n", b64_cap);
            fflush(stdout);
        }
        return 0;
    }

    size_t out = 0;
    for (size_t in = 0; in < data_len; in += 3) {
        const size_t remaining = data_len - in;
        uint32_t group = (uint32_t)data[in] << 16;
        if (remaining > 1) {
            group |= (uint32_t)data[in + 1] << 8;
        }
        if (remaining > 2) {
            group |= (uint32_t)data[in + 2];
        }

        encoded[out++] = alphabet[(group >> 18) & 0x3F];
        encoded[out++] = alphabet[(group >> 12) & 0x3F];
        // A short final group is padded with '=' in place of missing sextets.
        encoded[out++] = (remaining > 1) ? alphabet[(group >> 6) & 0x3F] : '=';
        encoded[out++] = (remaining > 2) ? alphabet[group & 0x3F] : '=';
    }
    encoded[out] = '\0';

    int ok = 0;
    char *json = NULL;
    const size_t json_cap = strlen("{\"type\":\"tunnel_data\",\"request_id\":\"\",\"data\":\"\"}") + strlen(request_id) + b64_cap + 1;

    json = malloc(json_cap);
    if (json == NULL) {
        if (debug) {
            printf("[DEBUG] Failed to allocate memory for binary tunnel_data message (%zu bytes)\n", json_cap);
            fflush(stdout);
        }
        goto done;
    }

    snprintf(json, json_cap, "{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}", request_id, encoded);

    if (!send_websocket_frame(ssl, json)) {
        if (debug) {
            printf("[DEBUG] Failed to send binary tunnel_data WebSocket frame\n");
            fflush(stdout);
        }
        goto done;
    }
    ok = 1;

done:
    free(json);
    free(encoded);
    return ok;
}

// CRC32 checksum calculation

/*
 * Compute the zlib CRC-32 of len bytes starting at data (0 for len == 0).
 *
 * zlib's crc32() takes its length as a uInt (unsigned int), which is narrower
 * than size_t on 64-bit platforms. A single call would silently truncate the
 * length of very large buffers, so the input is fed to zlib in bounded chunks.
 * Chaining the running CRC across calls gives the same result as one call.
 */
uint32_t crc32_checksum(const unsigned char *data, size_t len) {
    // 1 GiB per call: always representable in a 32-bit unsigned int.
    const size_t max_chunk = (size_t)1 << 30;
    unsigned long crc = 0UL;

    while (len > 0) {
        unsigned int chunk = (unsigned int)(len < max_chunk ? len : max_chunk);
        crc = crc32(crc, data, chunk);
        data += chunk;
        len -= chunk;
    }
    return (uint32_t)crc;
}

// Retransmission buffer functions

/*
 * Allocate an empty retransmission buffer.
 *
 * Every slot starts empty (frame_id 0 marks a free slot), and frame IDs are
 * handed out starting from 1. Returns NULL if either the allocation or the
 * mutex initialisation fails. Release with retransmission_buffer_free().
 */
retransmission_buffer_t *retransmission_buffer_init(void) {
    retransmission_buffer_t *buffer = malloc(sizeof *buffer);
    if (!buffer) return NULL;

    // A buffer whose mutex failed to initialise cannot be used safely by any
    // of the locking functions, so report failure instead of returning it.
    if (pthread_mutex_init(&buffer->mutex, NULL) != 0) {
        free(buffer);
        return NULL;
    }

    buffer->count = 0;
    buffer->next_frame_id = 1;  // 0 is reserved as the empty-slot marker

    for (int i = 0; i < RETRANSMISSION_BUFFER_SIZE; i++) {
        buffer->entries[i].frame_id = 0;
        buffer->entries[i].message = NULL;
        buffer->entries[i].message_len = 0;
        buffer->entries[i].timestamp = 0;
        buffer->entries[i].retries = 0;
    }

    return buffer;
}

/*
 * Release every buffered message, destroy the mutex, and free the buffer.
 * A NULL buffer is accepted and ignored.
 */
void retransmission_buffer_free(retransmission_buffer_t *buffer) {
    if (buffer == NULL) {
        return;
    }

    pthread_mutex_lock(&buffer->mutex);
    for (int slot = 0; slot < RETRANSMISSION_BUFFER_SIZE; ++slot) {
        free(buffer->entries[slot].message);  // empty slots hold NULL; free(NULL) is a no-op
    }
    pthread_mutex_unlock(&buffer->mutex);

    pthread_mutex_destroy(&buffer->mutex);
    free(buffer);
}

/*
 * Store a private copy of message under frame_id so it can be looked up later.
 *
 * The copy goes into the first empty slot (frame_id == 0). If no slot is
 * empty, the entry with the oldest timestamp is evicted and its message freed.
 * Exactly message_len bytes are copied and a NUL is appended, so message
 * need not be NUL-terminated.
 *
 * Returns 1 on success. Returns 0 if buffer or message is NULL, if frame_id is
 * 0 (the empty-slot marker, which could never be found again), or if memory
 * cannot be allocated. On failure the buffer is left unchanged.
 */
int retransmission_buffer_add(retransmission_buffer_t *buffer, uint32_t frame_id, const char *message, size_t message_len) {
    if (!buffer || !message || frame_id == 0) return 0;

    // Copy before locking. The allocation is the only step that can fail, so
    // doing it first means a failure never leaves a half-initialised slot
    // (new frame_id, no message) behind to corrupt the count bookkeeping.
    char *copy = malloc(message_len + 1);
    if (!copy) return 0;
    memcpy(copy, message, message_len);
    copy[message_len] = '\0';

    pthread_mutex_lock(&buffer->mutex);

    int slot = -1;    // first empty slot, if any
    int oldest = -1;  // occupied slot with the smallest timestamp
    for (int i = 0; i < RETRANSMISSION_BUFFER_SIZE; i++) {
        if (buffer->entries[i].frame_id == 0) {
            slot = i;
            break;
        }
        // Compare entries with each other rather than with "now", so eviction
        // still works if the wall clock has jumped backwards.
        if (oldest == -1 || buffer->entries[i].timestamp < buffer->entries[oldest].timestamp) {
            oldest = i;
        }
    }

    int filling_empty_slot = (slot != -1);
    if (!filling_empty_slot) {
        slot = oldest;
    }
    if (slot == -1) {
        // Only reachable when RETRANSMISSION_BUFFER_SIZE is 0.
        pthread_mutex_unlock(&buffer->mutex);
        free(copy);
        return 0;
    }

    // Drop the evicted entry's message (NULL for an empty slot).
    free(buffer->entries[slot].message);

    buffer->entries[slot].frame_id = frame_id;
    buffer->entries[slot].message = copy;
    buffer->entries[slot].message_len = message_len;
    buffer->entries[slot].timestamp = time(NULL);
    buffer->entries[slot].retries = 0;

    // Eviction replaces an entry, so only a previously empty slot grows count.
    if (filling_empty_slot && buffer->count < RETRANSMISSION_BUFFER_SIZE) {
        buffer->count++;
    }

    pthread_mutex_unlock(&buffer->mutex);
    return 1;
}

/*
 * Drop the buffered copy of frame_id (the peer acknowledged it).
 *
 * Unknown frame IDs are ignored. frame_id 0 is also ignored because it is the
 * empty-slot marker: matching on it would "remove" an empty slot and wrongly
 * decrement count.
 */
void retransmission_buffer_ack(retransmission_buffer_t *buffer, uint32_t frame_id) {
    if (!buffer || frame_id == 0) return;

    pthread_mutex_lock(&buffer->mutex);
    for (int i = 0; i < RETRANSMISSION_BUFFER_SIZE; i++) {
        if (buffer->entries[i].frame_id != frame_id) {
            continue;
        }

        free(buffer->entries[i].message);
        buffer->entries[i].frame_id = 0;
        buffer->entries[i].message = NULL;
        buffer->entries[i].message_len = 0;
        buffer->entries[i].timestamp = 0;
        buffer->entries[i].retries = 0;
        if (buffer->count > 0) {
            buffer->count--;
        }
        break;
    }
    pthread_mutex_unlock(&buffer->mutex);
}

/*
 * Record a negative acknowledgement for frame_id.
 *
 * Increments the entry's retry counter. Once the counter reaches MAX_RETRIES
 * the entry is dropped. Otherwise its timestamp is refreshed to the current
 * time, which also postpones its expiry in retransmission_buffer_gc().
 * Unknown frame IDs are ignored, as is frame_id 0 (the empty-slot marker).
 */
void retransmission_buffer_ko(retransmission_buffer_t *buffer, uint32_t frame_id) {
    if (!buffer || frame_id == 0) return;

    pthread_mutex_lock(&buffer->mutex);
    for (int i = 0; i < RETRANSMISSION_BUFFER_SIZE; i++) {
        if (buffer->entries[i].frame_id != frame_id) {
            continue;
        }

        buffer->entries[i].retries++;
        if (buffer->entries[i].retries >= MAX_RETRIES) {
            // Retry budget exhausted: give up on this frame.
            free(buffer->entries[i].message);
            buffer->entries[i].frame_id = 0;
            buffer->entries[i].message = NULL;
            buffer->entries[i].message_len = 0;
            buffer->entries[i].timestamp = 0;
            buffer->entries[i].retries = 0;
            if (buffer->count > 0) {
                buffer->count--;
            }
        } else {
            buffer->entries[i].timestamp = time(NULL);
        }
        break;
    }
    pthread_mutex_unlock(&buffer->mutex);
}

/*
 * Expire stale entries. Any occupied slot whose timestamp is more than
 * RETRANSMISSION_TIMEOUT older than time(NULL) (seconds on POSIX) has its
 * message freed and the slot cleared. A NULL buffer is ignored.
 */
void retransmission_buffer_gc(retransmission_buffer_t *buffer) {
    if (buffer == NULL) {
        return;
    }

    const time_t now = time(NULL);
    pthread_mutex_lock(&buffer->mutex);

    for (int idx = 0; idx < RETRANSMISSION_BUFFER_SIZE; ++idx) {
        if (buffer->entries[idx].frame_id == 0) {
            continue;  // empty slot
        }
        if (now - buffer->entries[idx].timestamp <= RETRANSMISSION_TIMEOUT) {
            continue;  // still within its window
        }

        free(buffer->entries[idx].message);
        buffer->entries[idx].frame_id = 0;
        buffer->entries[idx].message = NULL;
        buffer->entries[idx].message_len = 0;
        buffer->entries[idx].timestamp = 0;
        buffer->entries[idx].retries = 0;
        buffer->count--;
    }

    pthread_mutex_unlock(&buffer->mutex);
}

/*
 * Hand out the next frame ID under the buffer's lock.
 * IDs wrap back to 1 rather than 0, because 0 marks an empty slot.
 * Returns 0 only for a NULL buffer.
 */
uint32_t retransmission_buffer_get_next_frame_id(retransmission_buffer_t *buffer) {
    uint32_t issued = 0;

    if (buffer != NULL) {
        pthread_mutex_lock(&buffer->mutex);
        issued = buffer->next_frame_id;
        if (++buffer->next_frame_id == 0) {
            buffer->next_frame_id = 1;
        }
        pthread_mutex_unlock(&buffer->mutex);
    }

    return issued;
}

// Reliable tunnel data message with frame_id and checksum

// Base64-encode data_len bytes (standard alphabet, '=' padding) into a newly
// malloc'd NUL-terminated string owned by the caller. Returns NULL on
// allocation failure.
static char *reliable_encode_base64(const unsigned char *data, size_t data_len, int debug) {
    static const char base64_chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    size_t b64_len = ((data_len + 2) / 3) * 4 + 1;
    char *b64_data = malloc(b64_len);
    if (!b64_data) {
        if (debug) {
            printf("[DEBUG] Failed to allocate memory for base64 encoding (%zu bytes)\n", b64_len);
            fflush(stdout);
        }
        return NULL;
    }

    size_t i = 0, j = 0;
    while (i < data_len) {
        uint32_t octet_a = data[i++];
        uint32_t octet_b = i < data_len ? data[i++] : 0;
        uint32_t octet_c = i < data_len ? data[i++] : 0;
        uint32_t triple = (octet_a << 16) | (octet_b << 8) | octet_c;

        b64_data[j++] = base64_chars[(triple >> 18) & 0x3F];
        b64_data[j++] = base64_chars[(triple >> 12) & 0x3F];
        b64_data[j++] = base64_chars[(triple >> 6) & 0x3F];
        b64_data[j++] = base64_chars[triple & 0x3F];
    }

    // Overwrite the sextets produced from the zero-filled missing bytes.
    size_t padding = (3 - (data_len % 3)) % 3;
    for (size_t p = 0; p < padding; p++) {
        b64_data[j - 1 - p] = '=';
    }
    b64_data[j] = '\0';
    return b64_data;
}

// Lowercase-hex-encode data_len bytes into a newly malloc'd NUL-terminated
// string owned by the caller. Returns NULL on allocation failure.
static char *reliable_encode_hex(const unsigned char *data, size_t data_len, int debug) {
    static const char hex_digits[] = "0123456789abcdef";
    size_t hex_len = data_len * 2 + 1;
    char *hex_data = malloc(hex_len);
    if (!hex_data) {
        if (debug) {
            printf("[DEBUG] Failed to allocate memory for hex encoding (%zu bytes)\n", hex_len);
            fflush(stdout);
        }
        return NULL;
    }

    for (size_t i = 0; i < data_len; i++) {
        hex_data[2 * i] = hex_digits[data[i] >> 4];
        hex_data[2 * i + 1] = hex_digits[data[i] & 0x0F];
    }
    hex_data[data_len * 2] = '\0';
    return hex_data;
}

/*
 * Send data as a tunnel_data frame carrying a frame_id and the CRC-32 of the
 * raw (pre-encoding) bytes. A copy is stored in buffer under that frame_id
 * before the frame is sent.
 *
 * encoding selects the payload text encoding: ENCODING_BASE64 or ENCODING_HEX.
 * ENCODING_BINARY currently falls back to base64. Any other value is rejected
 * with 0 instead of formatting a NULL payload.
 *
 * Returns 1 on success, 0 on failure. If the send itself fails, the message
 * has already been added to buffer and remains there.
 */
int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const unsigned char *data, size_t data_len, retransmission_buffer_t *buffer, wsssh_encoding_t encoding, int debug) {
    if (!buffer) return 0;

    // TODO: Implement binary WebSocket frames for reliable transmission;
    // until then binary requests are sent base64-encoded.
    if (encoding == ENCODING_BINARY) {
        encoding = ENCODING_BASE64;
    }

    char *encoded_data;
    if (encoding == ENCODING_BASE64) {
        encoded_data = reliable_encode_base64(data, data_len, debug);
    } else if (encoding == ENCODING_HEX) {
        encoded_data = reliable_encode_hex(data, data_len, debug);
    } else {
        if (debug) {
            printf("[DEBUG] Unsupported encoding %d for reliable tunnel_data message\n", (int)encoding);
            fflush(stdout);
        }
        return 0;
    }
    if (!encoded_data) return 0;

    uint32_t checksum = crc32_checksum(data, data_len);
    // Taken only after encoding succeeded, so local failures don't burn IDs.
    uint32_t frame_id = retransmission_buffer_get_next_frame_id(buffer);

    // Each uint32_t prints as at most 10 decimal digits; +1 for the NUL.
    size_t msg_size = strlen("{\"type\":\"tunnel_data\",\"request_id\":\"\",\"frame_id\":,\"checksum\":,\"data\":\"\"}") +
                      strlen(request_id) + 2 * 10 + strlen(encoded_data) + 1;
    char *message = malloc(msg_size);
    if (!message) {
        if (debug) {
            printf("[DEBUG] Failed to allocate memory for reliable tunnel_data message (%zu bytes)\n", msg_size);
            fflush(stdout);
        }
        free(encoded_data);
        return 0;
    }

    int msg_len = snprintf(message, msg_size,
                           "{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"frame_id\":%u,\"checksum\":%u,\"data\":\"%s\"}",
                           request_id, frame_id, checksum, encoded_data);
    free(encoded_data);  // fully copied into message (or formatting failed)
    if (msg_len < 0 || (size_t)msg_len >= msg_size) {
        if (debug) {
            printf("[DEBUG] Failed to format reliable tunnel_data message (msg_len=%d, msg_size=%zu)\n", msg_len, msg_size);
            fflush(stdout);
        }
        free(message);
        return 0;
    }

    if (!retransmission_buffer_add(buffer, frame_id, message, (size_t)msg_len)) {
        if (debug) {
            printf("[DEBUG] Failed to add message to retransmission buffer\n");
            fflush(stdout);
        }
        free(message);
        return 0;
    }

    if (!send_websocket_frame(ssl, message)) {
        if (debug) {
            printf("[DEBUG] Failed to send reliable tunnel_data WebSocket frame\n");
            fflush(stdout);
        }
        free(message);
        return 0;
    }

    free(message);
    return 1;
}

// ACK message

/*
 * Send a tunnel_ack frame acknowledging frame_id for request_id.
 *
 * The JSON is built in a fixed 256-byte stack buffer. If request_id is too
 * long to fit, nothing is sent (a truncated frame would be invalid JSON) and
 * 0 is returned. Otherwise returns send_websocket_frame()'s result.
 */
int send_tunnel_ack_message(SSL *ssl, const char *request_id, uint32_t frame_id, int debug) {
    char message[256];
    int msg_len = snprintf(message, sizeof(message), "{\"type\":\"tunnel_ack\",\"request_id\":\"%s\",\"frame_id\":%u}", request_id, frame_id);
    if (msg_len < 0 || (size_t)msg_len >= sizeof(message)) {
        if (debug) {
            printf("[DEBUG] Failed to format tunnel_ack message (msg_len=%d, msg_size=%zu)\n", msg_len, sizeof(message));
            fflush(stdout);
        }
        return 0;
    }

    if (debug) {
        printf("[DEBUG] Sending tunnel_ack: %s\n", message);
        fflush(stdout);
    }

    return send_websocket_frame(ssl, message);
}

// KO (error) message

/*
 * Send a tunnel_ko frame (negative acknowledgement) for frame_id of request_id.
 *
 * The JSON is built in a fixed 256-byte stack buffer. If request_id is too
 * long to fit, nothing is sent (a truncated frame would be invalid JSON) and
 * 0 is returned. Otherwise returns send_websocket_frame()'s result.
 */
int send_tunnel_ko_message(SSL *ssl, const char *request_id, uint32_t frame_id, int debug) {
    char message[256];
    int msg_len = snprintf(message, sizeof(message), "{\"type\":\"tunnel_ko\",\"request_id\":\"%s\",\"frame_id\":%u}", request_id, frame_id);
    if (msg_len < 0 || (size_t)msg_len >= sizeof(message)) {
        if (debug) {
            printf("[DEBUG] Failed to format tunnel_ko message (msg_len=%d, msg_size=%zu)\n", msg_len, sizeof(message));
            fflush(stdout);
        }
        return 0;
    }

    if (debug) {
        printf("[DEBUG] Sending tunnel_ko: %s\n", message);
        fflush(stdout);
    }

    return send_websocket_frame(ssl, message);
}



/*
 * Send a hex-encoded payload as a tunnel_response frame (from the target back
 * over the WebSocket). "size" is the decoded byte count, i.e. half the hex
 * string length.
 * Returns 1 on success, 0 if allocation, formatting, or sending fails.
 */
int send_tunnel_response_message(SSL *ssl, const char *request_id, const char *data_hex, int debug) {
    static const char template_skeleton[] =
        "{\"type\":\"tunnel_response\",\"request_id\":\"\",\"size\":,\"data\":\"\"}";
    size_t hex_chars = strlen(data_hex);
    size_t id_chars = strlen(request_id);
    // 32 spare bytes comfortably hold the decimal "size" value.
    size_t capacity = (sizeof template_skeleton - 1) + id_chars + 32 + hex_chars + 1;
    int result = 0;

    char *frame = malloc(capacity);
    if (frame == NULL) {
        if (debug) {
            printf("[DEBUG] Failed to allocate memory for tunnel_response message (%zu bytes)\n", capacity);
            fflush(stdout);
        }
        return 0;
    }

    int written = snprintf(frame, capacity,
                           "{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"size\":%zu,\"data\":\"%s\"}",
                           request_id, hex_chars / 2, data_hex);
    if (written < 0 || (size_t)written >= capacity) {
        if (debug) {
            printf("[DEBUG] Failed to format tunnel_response message (msg_len=%d, msg_size=%zu)\n", written, capacity);
            fflush(stdout);
        }
        goto cleanup;
    }

    if (!send_websocket_frame(ssl, frame)) {
        if (debug) {
            printf("[DEBUG] Failed to send tunnel_response WebSocket frame\n");
            fflush(stdout);
        }
        goto cleanup;
    }
    result = 1;

cleanup:
    free(frame);
    return result;
}
