#include <errno.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <unistd.h>

/*
 * Main-loop run flag: 1 while the server should keep serving, cleared to 0
 * by sigint_handler. A flag shared with a signal handler must have type
 * volatile sig_atomic_t to be safely written from the handler (C11 7.14.1.1).
 */
volatile sig_atomic_t running = 1;

/* SIGINT handler: request a graceful shutdown by clearing `running`. */
void sigint_handler(int sig) {
    (void)sig;  /* only SIGINT is routed here; the number is not needed */
    running = 0;
}

// Simple base64 encoding function
char *base64_encode(const unsigned char *data, size_t input_length) {
    static const char base64_chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    size_t output_length = 4 * ((input_length + 2) / 3);
    char *encoded_data = malloc(output_length + 1);
    if (encoded_data == NULL) return NULL;

    for (size_t i = 0, j = 0; i < input_length;) {
        uint32_t octet_a = i < input_length ? data[i++] : 0;
        uint32_t octet_b = i < input_length ? data[i++] : 0;
        uint32_t octet_c = i < input_length ? data[i++] : 0;

        uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c;

        encoded_data[j++] = base64_chars[(triple >> 3 * 6) & 0x3F];
        encoded_data[j++] = base64_chars[(triple >> 2 * 6) & 0x3F];
        encoded_data[j++] = base64_chars[(triple >> 1 * 6) & 0x3F];
        encoded_data[j++] = base64_chars[(triple >> 0 * 6) & 0x3F];
    }

    // Add padding
    size_t padding = (3 - (input_length % 3)) % 3;
    for (size_t i = 0; i < padding; i++) {
        encoded_data[output_length - 1 - i] = '=';
    }

    encoded_data[output_length] = '\0';
    return encoded_data;
}

enum {
    /* Bytes per read() call; the last slot is reserved for a NUL. */
    READ_BUF_SIZE = 1024,
    /* Worst-case base64 length of a full (READ_BUF_SIZE - 1)-byte payload. */
    B64_MAX_LEN = 4 * ((READ_BUF_SIZE - 1 + 2) / 3),
    /* "RECEIVED: " (10) + payload + " <-> " (5) + base64 + NUL, plus slack,
     * so a full-size reply is never truncated. */
    RESPONSE_BUF_SIZE = 10 + (READ_BUF_SIZE - 1) + 5 + B64_MAX_LEN + 1 + 16,
    LISTEN_BACKLOG = 10
};

/* Parse a decimal TCP port in [1, 65535]; returns -1 for malformed input. */
static int parse_port(const char *text) {
    char *end = NULL;
    errno = 0;
    long value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE ||
        value <= 0 || value > 65535) {
        return -1;
    }
    return (int)value;
}

/* Create an IPv4 TCP socket that listens on all interfaces at `port`.
 * Returns the listening fd, or -1 after reporting the failing call. */
static int create_listener(int port) {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        perror("socket");
        return -1;
    }

    /* Set SO_REUSEADDR to allow immediate reuse of the port. */
    int opt = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
        perror("setsockopt");
        close(fd);
        return -1;
    }

    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_ANY);
    addr.sin_port = htons((uint16_t)port);

    if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
        perror("bind");
        close(fd);
        return -1;
    }

    if (listen(fd, LISTEN_BACKLOG) < 0) {
        perror("listen");
        close(fd);
        return -1;
    }
    return fd;
}

/* Write all `len` bytes of `buf` to `fd`, retrying short writes and EINTR.
 * Returns 0 on success, or -1 with errno set by write(). */
static int write_all(int fd, const char *buf, size_t len) {
    while (len > 0) {
        ssize_t written = write(fd, buf, len);
        if (written < 0) {
            if (errno == EINTR) {
                continue;
            }
            return -1;
        }
        buf += written;
        len -= (size_t)written;
    }
    return 0;
}

/* Accept one pending connection and track it in `master_set`, raising
 * *max_fd if needed. Descriptors select() cannot track are dropped. */
static void accept_client(int server_fd, fd_set *master_set, int *max_fd) {
    struct sockaddr_in client_addr;
    socklen_t client_len = sizeof(client_addr);
    int client_fd = accept(server_fd, (struct sockaddr *)&client_addr, &client_len);
    if (client_fd < 0) {
        if (errno != EINTR) {
            perror("accept");
        }
        return;
    }

    /* FD_SET on a descriptor >= FD_SETSIZE is undefined behavior. */
    if (client_fd >= FD_SETSIZE) {
        fprintf(stderr, "Too many connections; dropping new client\n");
        close(client_fd);
        return;
    }

    printf("CONNECTION FROM %s:%d\n", inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));
    FD_SET(client_fd, master_set);
    if (client_fd > *max_fd) {
        *max_fd = client_fd;
    }
}

/* Read one chunk from client `fd` and log it with its base64 form.
 * Then reply "RECEIVED: <text> <-> <base64>" to the client.
 * Returns 0 to keep the connection open, or -1 if the peer closed it or an
 * I/O error occurred. On -1 the caller closes the descriptor. */
static int handle_client(int fd) {
    char buffer[READ_BUF_SIZE];
    ssize_t n = read(fd, buffer, sizeof(buffer) - 1);
    if (n < 0 && errno == EINTR) {
        return 0;  /* interrupted; the next select() pass will retry */
    }
    if (n <= 0) {
        return -1;  /* connection closed or error */
    }

    size_t len = (size_t)n;
    buffer[len] = '\0';
    /* Remove a trailing newline so it is not logged or encoded. */
    if (buffer[len - 1] == '\n') {
        buffer[--len] = '\0';
    }

    /* Encode once; the same text feeds both the log line and the reply. */
    char *b64 = base64_encode((unsigned char *)buffer, len);
    const char *b64_text = b64 ? b64 : "ERROR";
    printf("(RECEIVED) %s <-> %s\n", buffer, b64_text);

    char response[RESPONSE_BUF_SIZE];
    int rlen = snprintf(response, sizeof(response), "RECEIVED: %s <-> %s", buffer, b64_text);
    free(b64);
    if (rlen < 0) {
        return -1;
    }

    /* The buffer is sized for the worst case, but clamp defensively anyway. */
    size_t to_send = (size_t)rlen < sizeof(response) ? (size_t)rlen : sizeof(response) - 1;
    return write_all(fd, response, to_send);
}

/*
 * Entry point: `prog <port>`.
 * Accepts TCP clients on 0.0.0.0:<port> and multiplexes them with select().
 * Each received chunk is echoed back together with its base64 encoding.
 * Runs until SIGINT, then closes every open socket.
 */
int main(int argc, char *argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s <port>\n", argv[0]);
        return 1;
    }

    int port = parse_port(argv[1]);
    if (port < 0) {
        fprintf(stderr, "Invalid port number\n");
        return 1;
    }

    signal(SIGINT, sigint_handler);
    /* A client that disconnects before we reply would otherwise terminate
     * the whole server via SIGPIPE; ignore it so write() reports EPIPE. */
    signal(SIGPIPE, SIG_IGN);

    int server_fd = create_listener(port);
    if (server_fd < 0) {
        return 1;
    }

    printf("Listening on 0.0.0.0:%d\n", port);

    fd_set master_set;
    FD_ZERO(&master_set);
    FD_SET(server_fd, &master_set);
    int max_fd = server_fd;

    while (running) {
        fd_set read_set = master_set;
        /* 1 second timeout so the running flag is re-checked regularly. */
        struct timeval timeout = { .tv_sec = 1, .tv_usec = 0 };

        int ready = select(max_fd + 1, &read_set, NULL, NULL, &timeout);
        if (ready < 0) {
            if (errno == EINTR) {
                continue;  /* loop condition re-checks running */
            }
            perror("select");
            break;
        }
        if (ready == 0) {
            continue;  /* timeout */
        }

        /* Descriptors accepted during this pass are not in read_set, so
         * growing max_fd inside the loop is harmless. */
        for (int fd = 0; fd <= max_fd; fd++) {
            if (!FD_ISSET(fd, &read_set)) {
                continue;
            }
            if (fd == server_fd) {
                accept_client(server_fd, &master_set, &max_fd);
            } else if (handle_client(fd) < 0) {
                close(fd);
                FD_CLR(fd, &master_set);
                printf("CONNECTION CLOSED\n");
            }
        }
    }

    /* Close all sockets, including the listener. */
    for (int fd = 0; fd <= max_fd; fd++) {
        if (FD_ISSET(fd, &master_set)) {
            close(fd);
        }
    }

    printf("Exiting...\n");
    return 0;
}