/*
 * WSSSH Tunnel (wsssht) - Thread Functions Implementation
 * Thread-related functions for wsssht
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <getopt.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <pthread.h>
#include <sys/select.h>
#include <time.h>
#include <errno.h>

#include "wssshlib.h"
#include "websocket.h"
#include "wssh_ssl.h"
#include "tunnel.h"
#include "wsssht.h"
#include "threads.h"
#include "control_messages.h"

void *run_tunnel_thread(void *arg) {
    tunnel_thread_args_t *args = (tunnel_thread_args_t *)arg;

    // Establish the tunnel for this connection
    tunnel_setup_result_t setup_result = setup_tunnel(args->wssshd_host, args->wssshd_port, args->client_id, 0, args->config->debug, 0, args->tunnel_host, args->config->encoding, 1);
    int tunnel_sock = setup_result.listen_sock;
    if (tunnel_sock < 0) {
        fprintf(stderr, "Failed to establish tunnel for connection\n");
        close(args->accepted_sock);
        free(args);
        return NULL;
    }
    // Close the dummy listening socket created by setup_tunnel
    close(tunnel_sock);

    // Print tunnel information (unless silent mode)
    if (args->config->mode != MODE_SILENT) {
        printf("\n");
        printf("========================================\n");
        printf("        WEBSSH TUNNEL READY\n");
        printf("========================================\n");
        printf("Tunnel established successfully!\n");
        printf("Local port: %d\n", atoi(args->config->local_port));
        printf("Target: %s@%s\n", args->client_id, args->wssshd_host);
        printf("\n");
        printf("Connect manually using one of these commands:\n");
        printf("\n");
        printf("  Telnet:\n");
        printf("    telnet localhost %d\n", atoi(args->config->local_port));
        printf("\n");
        printf("  Netcat:\n");
        printf("    nc localhost %d\n", atoi(args->config->local_port));
        printf("\n");
        printf("  SSH (if connecting to SSH server):\n");
        printf("    ssh -p %d user@localhost\n", atoi(args->config->local_port));
        printf("\n");
        printf("  SCP (if connecting to SSH server):\n");
        printf("    scp -P %d user@localhost:/remote/path ./local/path\n", atoi(args->config->local_port));
        printf("\n");
        printf("  Any TCP client:\n");
        printf("    Connect to localhost:%d\n", atoi(args->config->local_port));
        printf("\n");
        printf("Press Ctrl+C to close the tunnel and exit.\n");
        printf("========================================\n");
        printf("\n");
    }

    // Set the accepted socket with mutex protection
    pthread_mutex_lock(&tunnel_mutex);
    active_tunnel->local_sock = args->accepted_sock;

    // Send any buffered data to the client immediately
    if (active_tunnel->incoming_buffer && active_tunnel->incoming_buffer->used > 0) {
        if (args->config->debug) {
            printf("[DEBUG - Tunnel] Sending %zu bytes of buffered server response to client\n", active_tunnel->incoming_buffer->used);
            fflush(stdout);
        }
        ssize_t sent = send(args->accepted_sock, active_tunnel->incoming_buffer->buffer, active_tunnel->incoming_buffer->used, 0);
        if (sent > 0) {
            frame_buffer_consume(active_tunnel->incoming_buffer, sent);
            if (args->config->debug) {
                printf("[DEBUG] Sent %zd bytes of buffered server response to client\n", sent);
                fflush(stdout);
            }
        }
    }

    pthread_mutex_unlock(&tunnel_mutex);

    if (args->config->debug) {
        printf("[DEBUG - Tunnel] Local connection accepted! Starting data forwarding...\n");
        fflush(stdout);
    }

    // Get initial SSL connection for thread
    pthread_mutex_lock(&tunnel_mutex);
    SSL *current_ssl = active_tunnel ? active_tunnel->ssl : NULL;
    pthread_mutex_unlock(&tunnel_mutex);

    // Start forwarding thread
    thread_args_t *thread_args = malloc(sizeof(thread_args_t));
    if (!thread_args) {
        perror("Memory allocation failed for thread args");
        close(active_tunnel->local_sock);
        free(active_tunnel);
        active_tunnel = NULL;
        free(args);
        return NULL;
    }
    thread_args->ssl = current_ssl;
    thread_args->tunnel = active_tunnel;
    thread_args->debug = args->config->debug;

    pthread_t forwarding_thread;
    pthread_create(&forwarding_thread, NULL, forward_tcp_to_ws, thread_args);
    pthread_detach(forwarding_thread);

    // Main tunnel loop - handle WebSocket messages
    char buffer[BUFFER_SIZE];
    int bytes_read;
    fd_set tunnel_readfds;
    struct timeval tunnel_tv;

    // Frame accumulation buffer for handling partial WebSocket frames
    char frame_buffer[BUFFER_SIZE * 4];
    int frame_buffer_used = 0;

    while (1) {
        // Get SSL fd with mutex protection
        pthread_mutex_lock(&tunnel_mutex);
        if (!active_tunnel || !active_tunnel->active) {
            if (active_tunnel && active_tunnel->broken) {
                pthread_mutex_unlock(&tunnel_mutex);
                goto thread_cleanup;
            } else {
                // normal closure
                pthread_mutex_unlock(&tunnel_mutex);
                if (args->config->debug) {
                    printf("[DEBUG - Tunnel] Tunnel is no longer active, exiting main loop\n");
                    fflush(stdout);
                }
                break;
            }
        }
        int ssl_fd = SSL_get_fd(active_tunnel->ssl);
        current_ssl = active_tunnel->ssl;

        // Check if local socket is still valid
        if (active_tunnel->local_sock < 0) {
            if (args->config->debug) {
                printf("[DEBUG - Tunnel] Local socket is invalid, tunnel broken\n");
                fflush(stdout);
            }
            active_tunnel->broken = 1;
            // Send tunnel_close notification
            if (args->config->debug) {
                printf("[DEBUG - Tunnel] Sending tunnel_close notification due to invalid local socket...\n");
                fflush(stdout);
            }
            send_tunnel_close(current_ssl, active_tunnel->request_id, args->config->debug);
            pthread_mutex_unlock(&tunnel_mutex);
            goto thread_cleanup;
        }

        // Check if the local socket connection is broken
        char test_buf[1];
        int result = recv(active_tunnel->local_sock, test_buf, 1, MSG_PEEK | MSG_DONTWAIT);
        if (result == 0 || (result < 0 && (errno == ECONNRESET || errno == EPIPE || errno == EBADF))) {
            if (args->config->debug) {
                printf("[DEBUG - Tunnel] Local socket connection is broken (errno=%d), sending tunnel_close\n", errno);
                fflush(stdout);
            }
            active_tunnel->broken = 1;
            // Send tunnel_close notification
            send_tunnel_close(current_ssl, active_tunnel->request_id, args->config->debug);
            pthread_mutex_unlock(&tunnel_mutex);
            goto thread_cleanup;
        }

        pthread_mutex_unlock(&tunnel_mutex);

        // Use select to wait for data on SSL socket with timeout
        FD_ZERO(&tunnel_readfds);
        FD_SET(ssl_fd, &tunnel_readfds);
        tunnel_tv.tv_sec = 0;
        tunnel_tv.tv_usec = 50000;  // 50ms timeout

        int retval = select(ssl_fd + 1, &tunnel_readfds, NULL, NULL, &tunnel_tv);
        if (retval == -1) {
            if (args->config->debug) {
                perror("[DEBUG - WebSockets] select on SSL fd failed");
                fflush(stdout);
            }
            // Send tunnel_close notification
            if (args->config->debug) {
                printf("[DEBUG - Tunnel] Sending tunnel_close notification due to select failure...\n");
                fflush(stdout);
            }
            send_tunnel_close(current_ssl, active_tunnel->request_id, args->config->debug);
            goto thread_cleanup;
        } else if (retval == 0) {
            // Timeout, check if tunnel became inactive
            pthread_mutex_lock(&tunnel_mutex);
            if (!active_tunnel || !active_tunnel->active) {
                pthread_mutex_unlock(&tunnel_mutex);
                if (args->config->debug) {
                    printf("[DEBUG - Tunnel] Tunnel became inactive during timeout, exiting\n");
                    fflush(stdout);
                }
                goto thread_cleanup;
            }
            pthread_mutex_unlock(&tunnel_mutex);
            continue;
        }

        // Read more data if we don't have a complete frame
        if (FD_ISSET(ssl_fd, &tunnel_readfds)) {
            if ((size_t)frame_buffer_used < sizeof(frame_buffer)) {
                // Validate SSL connection state
                if (SSL_get_shutdown(current_ssl) & SSL_RECEIVED_SHUTDOWN) {
                    if (args->config->debug) {
                        printf("[DEBUG - WebSockets] SSL connection has received shutdown\n");
                        fflush(stdout);
                    }
                    cleanup_tunnel(args->config->debug);
                    break;
                }

                // Set up timeout for SSL read
                fd_set readfds_timeout;
                struct timeval tv_timeout;
                int sock_fd = SSL_get_fd(current_ssl);

                FD_ZERO(&readfds_timeout);
                FD_SET(sock_fd, &readfds_timeout);
                tv_timeout.tv_sec = 5;
                tv_timeout.tv_usec = 0;

                int select_result = select(sock_fd + 1, &readfds_timeout, NULL, NULL, &tv_timeout);
                if (select_result == -1) {
                    if (args->config->debug) {
                        perror("[DEBUG - WebSockets] select failed");
                        fflush(stdout);
                    }
                    cleanup_tunnel(args->config->debug);
                    break;
                } else if (select_result == 0) {
                    if (args->config->debug) {
                        printf("[DEBUG - WebSockets] SSL read timeout\n");
                        fflush(stdout);
                    }
                    continue;
                }

                bytes_read = SSL_read(current_ssl, frame_buffer + frame_buffer_used, sizeof(frame_buffer) - frame_buffer_used);
                if (bytes_read <= 0) {
                    if (bytes_read < 0) {
                        int ssl_error = SSL_get_error(current_ssl, bytes_read);
                        if (args->config->debug) {
                            printf("[DEBUG - WebSockets] SSL read error: %d\n", ssl_error);
                            fflush(stdout);
                        }

                        // Handle transient SSL errors
                        if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) {
                            if (args->config->debug) {
                                printf("[DEBUG - WebSockets] Transient SSL error, retrying...\n");
                                fflush(stdout);
                            }
                            usleep(10000);
                            continue;
                        }

                        // Print SSL error details
                        char error_buf[256];
                        ERR_error_string_n(ssl_error, error_buf, sizeof(error_buf));
                        if (args->config->debug) {
                            printf("[DEBUG - WebSockets] SSL error details: %s\n", error_buf);
                            fflush(stdout);
                        }
                        fprintf(stderr, "SSL read error (%d): %s\n", ssl_error, error_buf);
                    } else {
                        if (args->config->debug) {
                            printf("[DEBUG - WebSockets] Connection closed by server (EOF)\n");
                            fflush(stdout);
                        }
                    }

                    if (args->config->debug) {
                        printf("[DEBUG - WebSockets] WebSocket connection lost, attempting reconnection...\n");
                        fflush(stdout);
                    }

                    // Attempt reconnection
                    int reconnect_attempts = 0;
                    int max_reconnect_attempts = 3;
                    int reconnected = 0;

                    while (reconnect_attempts < max_reconnect_attempts && !reconnected) {
                        if (args->config->debug) {
                            printf("[DEBUG - WebSockets] WebSocket reconnection attempt %d/%d\n", reconnect_attempts + 1, max_reconnect_attempts);
                            fflush(stdout);
                        }

                        pthread_mutex_lock(&tunnel_mutex);
                        if (!active_tunnel) {
                            pthread_mutex_unlock(&tunnel_mutex);
                            break;
                        }
                        if (reconnect_websocket(active_tunnel, args->wssshd_host, args->wssshd_port, args->client_id, active_tunnel->request_id, args->config->debug) == 0) {
                            reconnected = 1;
                            if (args->config->debug) {
                                printf("[DEBUG - WebSockets] WebSocket reconnection successful, continuing tunnel\n");
                                fflush(stdout);
                            }
                            // Update ssl_fd for select
                            ssl_fd = SSL_get_fd(active_tunnel->ssl);
                            current_ssl = active_tunnel->ssl;
                        }
                        pthread_mutex_unlock(&tunnel_mutex);

                        if (!reconnected) {
                            reconnect_attempts++;
                            if (reconnect_attempts < max_reconnect_attempts) {
                                if (args->config->debug) {
                                    printf("[DEBUG - WebSockets] WebSocket reconnection failed, waiting 1 second...\n");
                                    fflush(stdout);
                                }
                                sleep(1);
                            }
                        }
                    }

                    if (!reconnected) {
                        if (args->config->debug) {
                            printf("[DEBUG - WebSockets] All reconnection attempts failed, exiting\n");
                            fflush(stdout);
                        }

                        // Send tunnel_close notification
                        if (args->config->debug) {
                            printf("[DEBUG - Tunnel] Sending tunnel_close notification due to connection failure...\n");
                            fflush(stdout);
                        }
                        send_tunnel_close(current_ssl, active_tunnel->request_id, args->config->debug);
                        goto thread_cleanup;
                    }

                    // Skip processing this iteration since we just reconnected
                    continue;
                }

                frame_buffer_used += bytes_read;

                if (args->config->debug) {
                    printf("[DEBUG - WebSockets] Accumulated %d bytes, frame: 0x%02x 0x%02x 0x%02x 0x%02x\n", frame_buffer_used, frame_buffer[0], frame_buffer[1], frame_buffer[2], frame_buffer[3]);
                    fflush(stdout);
                }
            }

            // Try to parse WebSocket frame
            char *payload;
            int payload_len;
            if (parse_websocket_frame(frame_buffer, frame_buffer_used, &payload, &payload_len)) {
                // Frame is complete, determine frame type
                unsigned char frame_type = frame_buffer[0] & 0x8F;

                if (frame_type == 0x88) { // Close frame
                    if (args->config->debug) {
                        printf("[DEBUG - WebSockets] Received close frame from server\n");
                        fflush(stdout);
                    }
                    // Send tunnel_close notification
                    if (args->config->debug) {
                        printf("[DEBUG - Tunnel] Sending tunnel_close notification due to server close frame...\n");
                        fflush(stdout);
                    }
                    send_tunnel_close(current_ssl, active_tunnel->request_id, args->config->debug);
                    goto thread_cleanup;
                } else if (frame_type == 0x89) { // Ping frame
                    if (args->config->debug) {
                        printf("[DEBUG - WebSockets] Received ping frame, sending pong\n");
                        fflush(stdout);
                    }
                    // Send pong
                    if (!send_pong_frame(current_ssl, payload, payload_len)) {
                        if (args->config->debug) {
                            printf("[DEBUG - WebSockets] Failed to send pong frame\n");
                            fflush(stdout);
                        }
                    }
                } else if (frame_type == 0x8A) { // Pong frame
                    if (args->config->debug) {
                        printf("[DEBUG - WebSockets] Received pong frame\n");
                        fflush(stdout);
                    }
                } else if (frame_type == 0x81 || frame_type == 0x82) { // Text or binary frame
                    // Copy payload to buffer
                    if ((size_t)payload_len < sizeof(buffer)) {
                        memcpy(buffer, payload, payload_len);
                        buffer[payload_len] = '\0';
                    } else {
                        fprintf(stderr, "Payload too large for processing buffer\n");
                        frame_buffer_used = 0;
                        continue;
                    }

                    // Check if this is a data message to suppress verbose logging
                    int is_data_message = (strstr(buffer, "\"type\":\"tunnel_data\"") != NULL ||
                                          strstr(buffer, "\"type\":\"tunnel_response\"") != NULL);

                    if (args->config->debug && !is_data_message) {
                        printf("[DEBUG - WebSockets] Received message: %.*s\n", payload_len, payload);
                        fflush(stdout);
                    }

                    // Handle message
                    if (args->config->debug && !is_data_message) {
                        printf("[DEBUG - WebSockets] Processing message: %s\n", buffer);
                        fflush(stdout);
                    }

                    // Handle tunnel messages
                    if (strstr(buffer, "tunnel_data") || strstr(buffer, "tunnel_response")) {
                        if (args->config->debug) {
                            // Suppress tunnel_data debug messages in debug mode
                            if (!strstr(buffer, "tunnel_data")) {
                                printf("[DEBUG - Tunnel] Received tunnel_response message\n");
                                fflush(stdout);
                            }
                        }
                        // Extract request_id and data
                        char *id_start = strstr(buffer, "\"request_id\"");
                        char *data_start = strstr(buffer, "\"data\"");
                        if (id_start && data_start) {
                            char *colon = strchr(id_start, ':');
                            if (colon) {
                                char *open_quote = strchr(colon, '"');
                                if (open_quote) {
                                    id_start = open_quote + 1;
                                    char *close_quote = strchr(id_start, '"');
                                    if (close_quote) {
                                        *close_quote = '\0';
                                        char *data_colon = strchr(data_start, ':');
                                        if (data_colon) {
                                            char *data_quote = strchr(data_colon, '"');
                                            if (data_quote) {
                                                data_start = data_quote + 1;
                                                char *data_end = strchr(data_start, '"');
                                                if (data_end) {
                                                    *data_end = '\0';
                                                    handle_tunnel_data(current_ssl, id_start, data_start, args->config->debug);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    } else if (strstr(buffer, "tunnel_close")) {
                        if (args->config->debug) {
                            printf("[DEBUG - Tunnel] Received tunnel_close message\n");
                            fflush(stdout);
                        }
                        char *id_start = strstr(buffer, "\"request_id\"");
                        if (id_start) {
                            char *colon = strchr(id_start, ':');
                            if (colon) {
                                char *open_quote = strchr(colon, '"');
                                if (open_quote) {
                                    id_start = open_quote + 1;
                                    char *close_quote = strchr(id_start, '"');
                                    if (close_quote) {
                                        *close_quote = '\0';
                                        handle_tunnel_close(current_ssl, id_start, args->config->debug);
                                    }
                                }
                            }
                        }
                    } else {
                        if (args->config->debug) {
                            printf("[DEBUG - WebSockets] Received unknown message type: %s\n", buffer);
                            fflush(stdout);
                        }
                    }
                }

                // Remove processed frame from buffer
                // Calculate the actual frame size consumed from the buffer
                int frame_size = (payload - frame_buffer) + payload_len;
                if (frame_size <= frame_buffer_used) {
                    if (frame_size < frame_buffer_used) {
                        memmove(frame_buffer, frame_buffer + frame_size, frame_buffer_used - frame_size);
                        frame_buffer_used -= frame_size;
                    } else {
                        // Frame consumed entire buffer
                        frame_buffer_used = 0;
                    }
                } else {
                    // Safety check: if calculated frame_size is larger than buffer_used,
                    // something went wrong in parsing, reset buffer to be safe
                    if (args->config->debug) {
                        printf("[DEBUG] Frame size calculation error: frame_size=%d, buffer_used=%d\n", frame_size, frame_buffer_used);
                        fflush(stdout);
                    }
                    frame_buffer_used = 0;
                }
            } else {
                // Frame not complete yet, continue reading
                continue;
            }
        }
    }

thread_cleanup:
    // Cleanup section
    if (args->config->debug) {
        printf("[DEBUG - Tunnel] Performing cleanup and exiting\n");
        fflush(stdout);
    }

    // Cleanup
    if (active_tunnel) {
        if (active_tunnel->local_sock >= 0) {
            close(active_tunnel->local_sock);
        }
        if (active_tunnel->ssl) {
            SSL_free(active_tunnel->ssl);
        }
        free(active_tunnel);
        active_tunnel = NULL;
    }

    free(args);
    return NULL;
}