/**
 * WebSocket protocol implementation from scratch
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include <errno.h>
#include <limits.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string.h>  // for strerror
#include <fcntl.h>
#include <poll.h>
#include <strings.h> // for strncasecmp
#include <unistd.h>
#include <openssl/sha.h>
#include <openssl/ssl.h>
#include "websocket_protocol.h"

// Standard Base64 alphabet (RFC 4648, section 4): the character at index v
// encodes the 6-bit value v.
static const char base64_table[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789"
    "+/";

// Allocate a connection object wrapping an already-accepted socket.
//
// ssl:     TLS session bound to the socket, or NULL for a plaintext connection.
//          ws_connection_free() shuts it down and frees it.
// sock_fd: underlying socket descriptor; ws_connection_free() closes it if >= 0.
//
// The new connection starts in WS_STATE_CONNECTING with a zeroed struct and a
// 4096-byte receive buffer. Returns NULL if either allocation fails, in which
// case ssl and sock_fd are not touched.
ws_connection_t *ws_connection_create(SSL *ssl, int sock_fd) {
    ws_connection_t *c = calloc(1, sizeof *c);
    if (c == NULL) {
        return NULL;
    }

    c->recv_buffer_size = 4096;
    c->recv_buffer = malloc(c->recv_buffer_size);
    if (c->recv_buffer == NULL) {
        free(c);
        return NULL;
    }

    c->ssl = ssl;
    c->sock_fd = sock_fd;
    c->state = WS_STATE_CONNECTING;
    return c;
}

ws_connection_t *ws_connection_create_plain(int sock_fd) {
    return ws_connection_create(NULL, sock_fd);
}

// Destroy a connection and everything it owns.
//
// Frees the receive buffer, sends a TLS close_notify via SSL_shutdown() (its
// return value is not checked) and frees the SSL object if present, closes the
// socket when sock_fd >= 0, then frees the struct. Passing NULL is a no-op.
void ws_connection_free(ws_connection_t *conn) {
    if (conn == NULL) {
        return;
    }

    free(conn->recv_buffer);

    SSL *tls = conn->ssl;
    if (tls != NULL) {
        SSL_shutdown(tls);
        SSL_free(tls);
    }

    int fd = conn->sock_fd;
    if (fd >= 0) {
        close(fd);
    }

    free(conn);
}

// Compute the SHA-1 digest of `len` bytes at `data` and store it in `hash`,
// which must have room for SHA_DIGEST_LENGTH bytes. Thin wrapper around
// OpenSSL's one-shot SHA1(); its return value is not needed.
void sha1(const unsigned char *data, size_t len, unsigned char *hash) {
    (void)SHA1(data, len, hash);
}

// Base64-encode `len` bytes of `data` using the standard alphabet with '='
// padding (RFC 4648, section 4).
//
// Returns a newly malloc'd, NUL-terminated string that the caller must free(),
// or NULL if allocation fails or the encoded size would not fit in size_t.
// len == 0 yields an empty string.
char *base64_encode(const unsigned char *data, size_t len) {
    // Output is 4 characters per 3-byte group (the last group rounded up) plus
    // the NUL. Bound len first so neither the rounding nor the multiplication
    // can wrap around and under-allocate.
    if (len > (SIZE_MAX - 1) / 4 * 3) {
        return NULL;
    }
    size_t groups = len / 3 + (len % 3 != 0);
    char *out = malloc(groups * 4 + 1);
    if (!out) return NULL;

    size_t i, j;
    for (i = 0, j = 0; i < len; i += 3, j += 4) {
        // Pack up to three bytes big-endian into 24 bits; missing bytes are zero.
        uint32_t triple = (uint32_t)data[i] << 16;
        if (i + 1 < len) triple |= (uint32_t)data[i + 1] << 8;
        if (i + 2 < len) triple |= (uint32_t)data[i + 2];

        out[j] = base64_table[(triple >> 18) & 0x3F];
        out[j + 1] = base64_table[(triple >> 12) & 0x3F];
        out[j + 2] = base64_table[(triple >> 6) & 0x3F];
        out[j + 3] = base64_table[triple & 0x3F];
    }

    // Replace the characters produced from zero-fill bytes with '=' padding.
    if (len % 3 == 1) {
        out[j - 2] = '=';
        out[j - 1] = '=';
    } else if (len % 3 == 2) {
        out[j - 1] = '=';
    }

    out[j] = '\0';
    return out;
}

// Compute the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key
// (RFC 6455, section 4.2.2): base64(SHA-1(key + GUID)).
//
// key: the client key exactly as it should be hashed (no surrounding whitespace).
//
// Returns a malloc'd string the caller must free(), or NULL if key is NULL, if
// key + GUID does not fit in the local 256-byte buffer (hashing a truncated
// string would silently produce a wrong accept value), or if allocation fails.
char *ws_compute_accept_key(const char *key) {
    if (!key) return NULL;

    // Concatenate key with the protocol's fixed GUID.
    char combined[256];
    int n = snprintf(combined, sizeof(combined), "%s%s", key, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
    if (n < 0 || (size_t)n >= sizeof(combined)) {
        return NULL;
    }

    unsigned char hash[SHA_DIGEST_LENGTH];
    sha1((const unsigned char *)combined, (size_t)n, hash);

    return base64_encode(hash, SHA_DIGEST_LENGTH);
}

// Apply the WebSocket masking transform (RFC 6455, section 5.3) in place:
// byte k of `data` is XORed with mask[k % 4]. `mask` must point to 4 bytes.
// With len == 0 nothing is touched.
void ws_mask_data(uint8_t *data, size_t len, const uint8_t *mask) {
    for (size_t pos = 0, key_idx = 0; pos < len; ++pos, key_idx = (key_idx + 1) & 3u) {
        data[pos] ^= mask[key_idx];
    }
}

// Remove WebSocket masking in place. XOR with the repeating 4-byte key is its
// own inverse, so this applies exactly the same transform as ws_mask_data().
void ws_unmask_data(uint8_t *data, size_t len, const uint8_t *mask) {
    size_t idx = 0;
    while (idx < len) {
        data[idx] = (uint8_t)(data[idx] ^ mask[idx & 3u]);
        idx++;
    }
}

// Parse a WebSocket frame header (RFC 6455, section 5.2) from `buffer`.
//
// buffer/len: raw header bytes. len must cover the 2 fixed bytes, the extended
//             payload length (2 or 8 bytes) when the 7-bit length is 126/127,
//             and the 4-byte masking key when the MASK bit is set.
// header:     receives FIN/RSV bits, opcode, mask flag, payload length and,
//             for masked frames, the masking key.
//
// Returns false if `len` is too short for the header that the first two bytes
// describe, or if the declared payload exceeds MAX_PAYLOAD_SIZE (10 MiB). Such
// frames are rejected here so callers never allocate for them.
static bool ws_parse_frame_header(const uint8_t *buffer, size_t len, ws_frame_header_t *header) {
    // Upper bound on accepted payloads, to prevent memory-exhaustion attacks.
    const uint64_t MAX_PAYLOAD_SIZE = 10 * 1024 * 1024; // 10MB

    if (len < 2) return false;

    header->fin = (buffer[0] & 0x80) != 0;
    header->rsv1 = (buffer[0] & 0x40) != 0;
    header->rsv2 = (buffer[0] & 0x20) != 0;
    header->rsv3 = (buffer[0] & 0x10) != 0;
    header->opcode = buffer[0] & 0x0F;

    header->masked = (buffer[1] & 0x80) != 0;
    uint8_t len7 = buffer[1] & 0x7F;

    size_t header_len = 2;

    // Decode into a 64-bit local so the size check sees the full on-wire value.
    // Accumulating directly into header->payload_len could drop high bytes if
    // that field is narrower than 64 bits.
    uint64_t payload_len;
    if (len7 == 126) {
        if (len < 4) return false;
        payload_len = ((uint64_t)buffer[2] << 8) | buffer[3];
        header_len = 4;
    } else if (len7 == 127) {
        if (len < 10) return false;
        payload_len = 0;
        for (size_t i = 0; i < 8; i++) {
            payload_len = (payload_len << 8) | buffer[2 + i];
        }
        header_len = 10;
    } else {
        payload_len = len7;
    }

    if (payload_len > MAX_PAYLOAD_SIZE) {
        return false; // Reject frames with excessively large payloads
    }
    header->payload_len = payload_len;

    if (header->masked) {
        // The masking key immediately follows the (extended) length field.
        if (len < header_len + 4) return false;
        memcpy(header->masking_key, buffer + header_len, 4);
    }

    return true;
}

// Read the client's opening HTTP request into buf (capacity cap, 2 <= cap <= INT_MAX).
// Reading stops when:
//   - the blank line ending the header block has arrived,
//   - the buffer is full, or
//   - the transport reports EOF, an error, or (non-blocking) no more data.
// buf is always NUL-terminated. Returns the number of bytes read (0 if none).
static size_t ws_handshake_read_request(ws_connection_t *conn, char *buf, size_t cap) {
    size_t used = 0;
    buf[0] = '\0';
    while (used < cap - 1) {
        size_t room = cap - 1 - used;
        int n;
        if (conn->ssl) {
            n = SSL_read(conn->ssl, buf + used, (int)room);
        } else {
            n = (int)read(conn->sock_fd, buf + used, room);
        }
        if (n <= 0) {
            break;
        }
        used += (size_t)n;
        buf[used] = '\0';
        // A request may arrive in several segments; stop at the end of the headers.
        if (strstr(buf, "\r\n\r\n") != NULL || strstr(buf, "\n\n") != NULL) {
            break;
        }
    }
    return used;
}

// If `line` (NUL-terminated) is an HTTP header named `name` (case-insensitive),
// return a pointer to its value with leading/trailing spaces and tabs skipped,
// and store the value's length in *value_len. Otherwise return NULL.
static const char *ws_handshake_header_value(const char *line, const char *name, size_t *value_len) {
    size_t name_len = strlen(name);
    if (strncasecmp(line, name, name_len) != 0 || line[name_len] != ':') {
        return NULL;
    }
    const char *value = line + name_len + 1;
    while (*value == ' ' || *value == '\t') value++;
    const char *end = value + strlen(value);
    while (end > value && (end[-1] == ' ' || end[-1] == '\t')) end--;
    *value_len = (size_t)(end - value);
    return value;
}

// Write all `len` bytes of buf to the connection, looping over short writes.
// Returns false as soon as a write reports failure.
static bool ws_handshake_write_all(ws_connection_t *conn, const char *buf, size_t len) {
    size_t sent = 0;
    while (sent < len) {
        int n;
        if (conn->ssl) {
            n = SSL_write(conn->ssl, buf + sent, (int)(len - sent));
        } else {
            n = (int)write(conn->sock_fd, buf + sent, len - sent);
        }
        if (n <= 0) return false;
        sent += (size_t)n;
    }
    return true;
}

// Perform the server side of the WebSocket opening handshake (RFC 6455, section 4.2).
//
// Steps:
//   1. Read the client's HTTP request, over TLS when conn->ssl is set, otherwise
//      from the raw socket. The request must fit in 4095 bytes.
//   2. Extract Sec-WebSocket-Key.
//   3. Reply with "101 Switching Protocols" and the matching Sec-WebSocket-Accept.
//
// Returns true and sets conn->state to WS_STATE_OPEN on success. Returns false if:
//   - nothing could be read,
//   - the request is not an upgrade request,
//   - no usable key is present (missing, empty or 128+ chars), or
//   - the response could not be written.
bool ws_perform_handshake(ws_connection_t *conn) {
    char buffer[4096];
    size_t bytes_read = ws_handshake_read_request(conn, buffer, sizeof(buffer));
    if (bytes_read == 0) return false;

    // The key is copied out of `buffer` while its line is NUL-terminated: once the
    // line terminator is restored, a pointer into the buffer would be a string
    // running on through the rest of the request.
    char key[128];
    bool have_key = false;
    bool is_websocket_upgrade = false;

    char *const buffer_end = buffer + bytes_read;
    char *line_start = buffer;
    while (line_start < buffer_end) {
        char *line_end = line_start;
        while (line_end < buffer_end && *line_end != '\r' && *line_end != '\n') {
            line_end++;
        }

        if (line_end > line_start) {
            // Temporarily terminate the line so string functions stop at its end.
            char saved_char = *line_end;
            *line_end = '\0';

            const char *value;
            size_t value_len;
            if (strncasecmp(line_start, "GET ", 4) == 0) {
                // Any "GET ... HTTP/1.1" request line is treated as an upgrade attempt;
                // requiring Sec-WebSocket-Key below is what distinguishes WebSocket clients.
                if (strstr(line_start, "HTTP/1.1") && strchr(line_start, '/')) {
                    is_websocket_upgrade = true;
                }
            } else if ((value = ws_handshake_header_value(line_start, "Sec-WebSocket-Key", &value_len)) != NULL) {
                if (value_len > 0 && value_len < sizeof(key)) {
                    memcpy(key, value, value_len);
                    key[value_len] = '\0';
                    have_key = true;
                }
            } else if ((value = ws_handshake_header_value(line_start, "Upgrade", &value_len)) != NULL) {
                if (value_len >= 9 && strncasecmp(value, "websocket", 9) == 0) {
                    is_websocket_upgrade = true;
                }
            }

            *line_end = saved_char;
        }

        // Skip the line terminator (CRLF, CR or LF).
        line_start = line_end;
        if (line_start < buffer_end && *line_start == '\r') line_start++;
        if (line_start < buffer_end && *line_start == '\n') line_start++;
    }

    if (!is_websocket_upgrade || !have_key) {
        return false;
    }

    char *accept_key = ws_compute_accept_key(key);
    if (!accept_key) return false;

    char response[512];
    int response_len = snprintf(response, sizeof(response),
                                "HTTP/1.1 101 Switching Protocols\r\n"
                                "Upgrade: websocket\r\n"
                                "Connection: Upgrade\r\n"
                                "Sec-WebSocket-Accept: %s\r\n"
                                "\r\n", accept_key);
    free(accept_key);
    if (response_len < 0 || (size_t)response_len >= sizeof(response)) {
        return false;
    }

    if (!ws_handshake_write_all(conn, response, (size_t)response_len)) {
        return false;
    }

    conn->state = WS_STATE_OPEN;
    return true;
}

// Send one complete, unfragmented WebSocket frame (FIN set). Server-to-client
// frames are not masked.
//
// conn:   connection in WS_STATE_OPEN; otherwise nothing is sent.
// opcode: frame opcode; only the low 4 bits are used.
// data:   payload bytes; may be NULL when len is 0.
// len:    payload length, encoded as 7-bit, 16-bit or 64-bit (RFC 6455, section 5.2).
//
// Works over TLS (conn->ssl) or a plain socket.
//   - SSL_ERROR_WANT_READ/WANT_WRITE and EAGAIN/EWOULDBLOCK are retried with
//     exponential backoff (20, 40, 80, 160 ms).
//   - EINTR is retried immediately.
//   - Any other failure, or running out of retries, marks the connection
//     WS_STATE_CLOSED, since a partially written frame leaves the stream unusable.
//     This includes the fatal SSL_ERROR_SSL / SSL_ERROR_SYSCALL, which are never retried.
//
// Returns true only if the whole frame was written.
bool ws_send_frame(ws_connection_t *conn, uint8_t opcode, const void *data, size_t len) {
    if (!conn || conn->state != WS_STATE_OPEN) {
        return false;
    }

    size_t header_len = len < 126 ? 2 : (len <= 0xFFFF ? 4 : 10);
    if (len > SIZE_MAX - header_len) {
        return false;
    }
    size_t frame_len = header_len + len;
    uint8_t *frame = malloc(frame_len);
    if (!frame) {
        return false;
    }

    frame[0] = (uint8_t)(0x80 | (opcode & 0x0F)); // FIN bit set
    if (len < 126) {
        frame[1] = (uint8_t)len;
    } else if (len <= 0xFFFF) {
        frame[1] = 126;
        frame[2] = (uint8_t)(len >> 8);
        frame[3] = (uint8_t)len;
    } else {
        // Full 64-bit big-endian length, so frames of 4 GiB and above are encoded correctly.
        frame[1] = 127;
        uint64_t len64 = (uint64_t)len;
        for (int i = 0; i < 8; i++) {
            frame[2 + i] = (uint8_t)(len64 >> (56 - 8 * i));
        }
    }

    if (len > 0) {
        memcpy(frame + header_len, data, len);
    }

    size_t total_written = 0;
    int retry_count = 0;
    const int max_retries = 5;

    while (total_written < frame_len) {
        size_t remaining = frame_len - total_written;
        // SSL_write() takes an int length; never request more than INT_MAX at once.
        int chunk = remaining > (size_t)INT_MAX ? INT_MAX : (int)remaining;
        int written;
        bool transient = false;

        if (conn->ssl) {
            written = SSL_write(conn->ssl, frame + total_written, chunk);
            if (written <= 0) {
                int ssl_error = SSL_get_error(conn->ssl, written);
                transient = (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE);
            }
        } else {
            ssize_t n = write(conn->sock_fd, frame + total_written, (size_t)chunk);
            if (n < 0 && errno == EINTR) {
                continue; // interrupted before anything was written
            }
            transient = (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK));
            written = (int)n;
        }

        if (written > 0) {
            total_written += (size_t)written;
            retry_count = 0; // progress made: reset the retry budget
            continue;
        }

        if (transient && retry_count < max_retries - 1) {
            retry_count++;
            usleep(10000 * (1 << retry_count)); // 20ms, 40ms, 80ms, 160ms
            continue; // retry with the same buffer and length, as OpenSSL requires
        }

        conn->state = WS_STATE_CLOSED;
        free(frame);
        return false;
    }

    free(frame);
    return true;
}

// Receive WebSocket frame
bool ws_receive_frame(ws_connection_t *conn, uint8_t *opcode, void **data, size_t *len) {
    if (conn->state != WS_STATE_OPEN) return false;

    // Read minimum frame header (2 bytes) to determine full header size
    uint8_t header[14];
    int bytes_read;
    if (conn->ssl) {
        bytes_read = SSL_read(conn->ssl, header, 2);
        if (bytes_read <= 0) {
            return false;
        }
    } else {
        bytes_read = read(conn->sock_fd, header, 2);
        if (bytes_read < 0) {
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                return false; // No data available, not an error
            }
            return false; // Real error
        }
    }
    if (bytes_read != 2) {
        return false;
    }

    // Determine minimum header size needed for parsing
    uint8_t payload_len_indicator = header[1] & 0x7F;
    size_t min_header_size = 2;

    if (payload_len_indicator == 126) {
        min_header_size = 4;
    } else if (payload_len_indicator == 127) {
        min_header_size = 10;
    }

    // Read additional header bytes if needed
    if (min_header_size > 2) {
        int total_read = 0;
        while (total_read < (int)(min_header_size - 2)) {
            if (conn->ssl) {
                bytes_read = SSL_read(conn->ssl, header + 2 + total_read, min_header_size - 2 - total_read);
                if (bytes_read <= 0) {
                    return false;
                }
            } else {
                bytes_read = read(conn->sock_fd, header + 2 + total_read, min_header_size - 2 - total_read);
                if (bytes_read < 0) {
                    if (errno == EAGAIN || errno == EWOULDBLOCK) {
                        return false; // No data available
                    }
                    return false; // Real error
                }
            }
            total_read += bytes_read;
        }
    }

    // Now read the masking key if present
    bool masked = (header[1] & 0x80) != 0;
    size_t total_header_size = min_header_size;
    if (masked) {
        if (conn->ssl) {
            bytes_read = SSL_read(conn->ssl, header + min_header_size, 4);
            if (bytes_read != 4) {
                return false;
            }
        } else {
            bytes_read = read(conn->sock_fd, header + min_header_size, 4);
            if (bytes_read < 0) {
                if (errno == EAGAIN || errno == EWOULDBLOCK) {
                    return false; // No data available
                }
                return false; // Real error
            }
            if (bytes_read != 4) {
                return false;
            }
        }
        total_header_size += 4;
    }

    // Validate header size
    if (total_header_size > sizeof(header)) {
        return false;
    }

    ws_frame_header_t frame_header;
    if (!ws_parse_frame_header(header, total_header_size, &frame_header)) {
        return false;
    }

    // Allocate buffer for payload with additional safety check
    if (frame_header.payload_len == 0) {
        *data = NULL;
        *len = 0;
        *opcode = frame_header.opcode;
        return true;
    }

    // Additional validation for payload length
    // Protect against memory exhaustion attacks with reasonable limit
    const size_t MAX_SAFE_PAYLOAD = 50 * 1024 * 1024; // 50MB safety limit
    if (frame_header.payload_len > MAX_SAFE_PAYLOAD) {
        return false;
    }

    *data = malloc(frame_header.payload_len + 1); // +1 for null termination
    if (!*data) {
        return false;
    }

    // Read payload with timeout protection
    size_t total_read = 0;
    while (total_read < frame_header.payload_len) {
        size_t remaining = frame_header.payload_len - total_read;
        // Limit read size to prevent excessive blocking
        size_t to_read = remaining > 8192 ? 8192 : remaining;

        if (conn->ssl) {
            bytes_read = SSL_read(conn->ssl, (char *)*data + total_read, to_read);
            if (bytes_read <= 0) {
                free(*data);
                return false;
            }
        } else {
            bytes_read = read(conn->sock_fd, (char *)*data + total_read, to_read);
            if (bytes_read < 0) {
                if (errno == EAGAIN || errno == EWOULDBLOCK) {
                    // For non-blocking sockets, if we can't read the complete payload at once,
                    // this is an error since WebSocket frames should be complete
                    free(*data);
                    return false;
                }
                free(*data);
                return false;
            }
            if (bytes_read == 0) {
                // Connection closed
                free(*data);
                return false;
            }
        }
        total_read += bytes_read;
    }

    // Verify we read the complete payload
    if (total_read != frame_header.payload_len) {
        free(*data);
        return false;
    }

    // Null terminate the payload
    ((char *)*data)[frame_header.payload_len] = '\0';

    // Unmask if needed
    if (frame_header.masked) {
        ws_unmask_data(*data, frame_header.payload_len, frame_header.masking_key);
    }

    *opcode = frame_header.opcode;
    *len = frame_header.payload_len;
    return true;
}

// Send one complete binary WebSocket frame (FIN set, opcode 0x2, unmasked as
// for all server-to-client frames). Used for VNC traffic.
//
// conn:  connection in WS_STATE_OPEN; otherwise nothing is sent.
// data:  payload bytes; may be NULL when len is 0.
// len:   payload length, encoded as 7-bit, 16-bit or 64-bit (RFC 6455, section 5.2).
// debug: when true, progress and errors are printed to stdout.
//
// Socket mode:
//   - For plain sockets, O_NONBLOCK is cleared for the duration of the send.
//   - The original flags are restored on every exit path after they were changed.
//
// Retries:
//   - SSL_ERROR_WANT_READ/WANT_WRITE are retried with exponential backoff.
//   - EAGAIN/EWOULDBLOCK get short delays.
//   - EINTR is retried immediately.
//   - The fatal SSL_ERROR_SSL / SSL_ERROR_SYSCALL and other write errors mark the
//     connection WS_STATE_CLOSED.
//
// Returns true only if the whole frame was written.
bool ws_send_binary_frame(ws_connection_t *conn, const void *data, size_t len, bool debug) {
    if (!conn || conn->state != WS_STATE_OPEN) {
        return false;
    }

    if (debug) {
        printf("[WS-BINARY] Sending binary frame, len=%zu\n", len);
    }

    size_t header_len = len < 126 ? 2 : (len <= 0xFFFF ? 4 : 10);
    size_t frame_len = header_len + len;
    // A length so large that header + payload wraps is treated like an allocation failure.
    uint8_t *frame = (len <= SIZE_MAX - header_len) ? malloc(frame_len) : NULL;
    if (!frame) {
        if (debug) printf("[WS-BINARY] Failed to allocate frame buffer\n");
        return false;
    }

    frame[0] = 0x82; // FIN + binary opcode
    if (len < 126) {
        frame[1] = (uint8_t)len; // No mask bit
        if (debug) printf("[WS-BINARY] Header: FIN=1, opcode=2, len=%zu (<126)\n", len);
    } else if (len <= 0xFFFF) {
        frame[1] = 126; // No mask bit
        frame[2] = (uint8_t)(len >> 8);
        frame[3] = (uint8_t)len;
        if (debug) printf("[WS-BINARY] Header: FIN=1, opcode=2, len=%zu (16-bit)\n", len);
    } else {
        frame[1] = 127; // No mask bit
        uint64_t len64 = (uint64_t)len;
        for (int i = 0; i < 8; i++) {
            frame[2 + i] = (uint8_t)(len64 >> (56 - 8 * i));
        }
        if (debug) printf("[WS-BINARY] Header: FIN=1, opcode=2, len=%zu (64-bit)\n", len);
    }

    if (len > 0) {
        memcpy(frame + header_len, data, len);
    }

    bool ok = false;

    // Temporarily switch plain sockets to blocking mode so the frame goes out whole.
    int original_flags = -1;
    if (!conn->ssl) {
        original_flags = fcntl(conn->sock_fd, F_GETFL, 0);
        if (original_flags != -1) {
            fcntl(conn->sock_fd, F_SETFL, original_flags & ~O_NONBLOCK);
        }
    }

    size_t total_written = 0;
    int retry_count = 0;
    const int max_retries = 10;

    if (debug) printf("[WS-BINARY] Sending frame of total length %zu\n", frame_len);

    while (total_written < frame_len && retry_count < max_retries) {
        size_t remaining = frame_len - total_written;
        int chunk = remaining > (size_t)INT_MAX ? INT_MAX : (int)remaining;
        int written;

        if (conn->ssl) {
            written = SSL_write(conn->ssl, frame + total_written, chunk);
            if (written <= 0) {
                int ssl_error = SSL_get_error(conn->ssl, written);
                if (debug) printf("[WS-BINARY] SSL write error: %d, written=%d\n", ssl_error, written);
                // Only WANT_READ/WANT_WRITE are recoverable; other errors are fatal.
                if ((ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) &&
                    retry_count < max_retries - 1) {
                    retry_count++;
                    usleep(10000 * (1 << retry_count)); // Exponential backoff
                    continue;
                }
                conn->state = WS_STATE_CLOSED;
                goto out;
            }
        } else {
            ssize_t n = write(conn->sock_fd, frame + total_written, (size_t)chunk);
            if (n < 0) {
                int err = errno;
                if (err == EINTR) {
                    continue;
                }
                if (err == EAGAIN || err == EWOULDBLOCK) {
                    // Should not happen in blocking mode, but handle anyway
                    if (debug) printf("[WS-BINARY] Socket would block in blocking mode, retrying\n");
                    retry_count++;
                    usleep(1000); // Short delay before retry
                    continue;
                }
                if (debug) printf("[WS-BINARY] Socket write error: errno=%d (%s), written=%d\n", err, strerror(err), (int)n);
                conn->state = WS_STATE_CLOSED;
                goto out;
            } else if (n == 0) {
                if (debug) printf("[WS-BINARY] Connection closed during write\n");
                conn->state = WS_STATE_CLOSED;
                goto out;
            }
            written = (int)n;
        }

        total_written += (size_t)written;
        retry_count = 0; // Reset retry count on successful write
        if (debug) printf("[WS-BINARY] Wrote %d bytes, total written: %zu/%zu\n", written, total_written, frame_len);
    }

    if (total_written != frame_len) {
        if (debug) printf("[WS-BINARY] Failed to write complete frame: %zu/%zu\n", total_written, frame_len);
        goto out;
    }

    if (debug) printf("[WS-BINARY] Successfully sent binary frame\n");
    ok = true;

out:
    // Restore original socket flags on every path, including errors.
    if (!conn->ssl && original_flags != -1) {
        fcntl(conn->sock_fd, F_SETFL, original_flags);
    }
    free(frame);
    return ok;
}

// Cheap liveness check: true when conn is non-NULL, in WS_STATE_OPEN, and has a
// transport (a TLS session or a non-negative socket descriptor). No I/O is
// performed; real failures surface from the send/receive functions.
bool ws_connection_is_healthy(ws_connection_t *conn) {
    if (conn == NULL || conn->state != WS_STATE_OPEN) {
        return false;
    }
    return conn->ssl != NULL || conn->sock_fd >= 0;
}