/**
 * Plugin system implementation for wssshd2
 *
 * Copyright (C) 2024 Stefy Lanza <stefy@nexlab.net> and SexHack.me
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <dlfcn.h>
#include <dirent.h>
#include <sys/stat.h>
#include <unistd.h>
#include "plugin.h"

// Name of the symbol that plugin_load() looks up with dlsym() in every
// shared object; it must resolve to the plugin's plugin_interface_t.
#define PLUGIN_INTERFACE_SYMBOL "plugin_interface"

// Registry of loaded plugins. The three arrays are parallel: slot i holds the
// dlopen() handle, the interface pointer obtained from that handle, and the
// context copy passed to that plugin's callbacks. Only slots
// [0, plugin_count) are valid; plugin_load() refuses to go past MAX_PLUGINS.
#define MAX_PLUGINS 10
static void *loaded_plugins[MAX_PLUGINS];
static plugin_interface_t *plugin_interfaces[MAX_PLUGINS];
static plugin_context_t plugin_contexts[MAX_PLUGINS];
static int plugin_count = 0;

// Server state and configuration recorded by plugin_system_init() and copied
// into the context of each plugin found by plugin_load_from_directory().
// global_config is NULL when the system is started through plugin_init().
static wssshd_state_t *global_state = NULL;
static const wssshd_config_t *global_config = NULL;

/**
 * Convenience entry point: start the plugin system without a configuration.
 *
 * @param state  Server state made available to plugins through their context.
 * @return 0 when plugin_system_init() reports success, -1 otherwise.
 */
int plugin_init(wssshd_state_t *state) {
    global_state = state;

    if (!plugin_system_init(state, NULL)) {
        return -1;
    }
    return 0;
}

/**
 * Counterpart of plugin_init(): shuts the plugin system down by delegating
 * to plugin_system_cleanup().
 */
void plugin_cleanup(void) {
    plugin_system_cleanup();
}

/**
 * Record the server state and configuration, reset the plugin registry and
 * load plugins from the default search directories.
 *
 * The directories are tried in order and the search stops at the first one
 * from which at least one plugin was loaded.
 *
 * @param state   Server state stored for later plugin contexts.
 * @param config  Server configuration stored for later plugin contexts
 *                (may be NULL).
 * @return Always true; finding no plugins is not treated as an error.
 */
bool plugin_system_init(wssshd_state_t *state, const wssshd_config_t *config) {
    static const char *const search_paths[] = {
        "./plugins",
        "/usr/local/lib/wssshd/plugins",
    };
    const size_t search_path_count = sizeof(search_paths) / sizeof(search_paths[0]);

    global_state = state;
    global_config = config;
    plugin_count = 0;

    for (size_t idx = 0; idx < search_path_count; idx++) {
        const char *candidate = search_paths[idx];

        if (!plugin_load_from_directory(candidate)) {
            continue;
        }
        printf("[PLUGIN] Loaded plugins from %s\n", candidate);
        break;
    }

    return true;
}

// Plugin system cleanup
void plugin_system_cleanup(void) {
    for (int i = 0; i < plugin_count; i++) {
        if (plugin_interfaces[i] && plugin_interfaces[i]->cleanup) {
            plugin_interfaces[i]->cleanup(&plugin_contexts[i]);
        }
        if (loaded_plugins[i]) {
            dlclose(loaded_plugins[i]);
        }
    }
    plugin_count = 0;
}

// Return true when a file name looks like a shared object: ".so" must be the
// end of the name or be followed by '.', so versioned names such as
// "libfoo.so.1" are accepted while "notes.source" or "x.sock" are not.
static bool plugin_name_is_shared_object(const char *name) {
    for (const char *hit = strstr(name, ".so"); hit != NULL; hit = strstr(hit + 1, ".so")) {
        if (hit[3] == '\0' || hit[3] == '.') {
            return true;
        }
    }
    return false;
}

/**
 * Load every plugin found under a directory, descending into subdirectories.
 *
 * Each regular file whose name looks like a shared object is passed to
 * plugin_load() with a fresh context built from the recorded global state
 * and configuration. Entries that cannot be stat()ed, or whose full path does
 * not fit the path buffer, are skipped.
 *
 * @param plugin_dir  Directory to scan; NULL or an unreadable directory
 *                    yields false.
 * @return true if at least one plugin was loaded from this directory or any
 *         directory below it, false otherwise.
 */
bool plugin_load_from_directory(const char *plugin_dir) {
    if (!plugin_dir) return false;

    DIR *dir = opendir(plugin_dir);
    if (!dir) return false;

    struct dirent *entry;
    bool found_plugins = false;
    while ((entry = readdir(dir)) != NULL) {
        if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) continue;

        char full_path[1024];
        int written = snprintf(full_path, sizeof(full_path), "%s/%s", plugin_dir, entry->d_name);
        if (written < 0 || (size_t)written >= sizeof(full_path)) {
            // A truncated path names a different file; never stat or load it.
            fprintf(stderr, "[PLUGIN] Path too long, skipping: %s/%s\n", plugin_dir, entry->d_name);
            continue;
        }

        struct stat st;
        if (stat(full_path, &st) != 0) {
            continue;
        }

        if (S_ISDIR(st.st_mode)) {
            // Recurse into subdirectory
            if (plugin_load_from_directory(full_path)) {
                found_plugins = true;
            }
        } else if (S_ISREG(st.st_mode) && plugin_name_is_shared_object(entry->d_name)) {
            // Load the shared object with its own context seeded from the globals
            plugin_context_t ctx = {global_state, global_config, NULL, plugin_get_data, plugin_set_data};
            if (plugin_load(full_path, &ctx)) {
                printf("[PLUGIN] Loaded plugin: %s\n", entry->d_name);
                found_plugins = true;
            }
        }
    }

    closedir(dir);
    return found_plugins;
}

/**
 * Load one plugin shared object and register it.
 *
 * The library is opened with dlopen(), its PLUGIN_INTERFACE_SYMBOL is looked
 * up, the API version is compared against PLUGIN_API_VERSION and, if the
 * plugin provides an init hook, that hook is called with the registry's copy
 * of the context. On any failure the library is closed again and nothing is
 * registered.
 *
 * @param plugin_path  Path of the shared object to load; must not be NULL.
 * @param ctx          Context template copied into the registry slot; must
 *                     not be NULL.
 * @return true when the plugin was loaded and registered, false otherwise
 *         (including when MAX_PLUGINS plugins are already loaded).
 */
bool plugin_load(const char *plugin_path, plugin_context_t *ctx) {
    // dlopen(NULL) would return the main program's handle and *ctx would
    // dereference NULL, so reject both up front.
    if (!plugin_path || !ctx) {
        fprintf(stderr, "[PLUGIN] plugin_load called with a NULL argument\n");
        return false;
    }

    if (plugin_count >= MAX_PLUGINS) {
        fprintf(stderr, "[PLUGIN] Maximum number of plugins reached\n");
        return false;
    }

    // Load the shared library
    void *handle = dlopen(plugin_path, RTLD_LAZY);
    if (!handle) {
        // dlerror() may return NULL; never hand NULL to %s.
        const char *reason = dlerror();
        fprintf(stderr, "[PLUGIN] Failed to load plugin %s: %s\n", plugin_path,
                reason ? reason : "unknown error");
        return false;
    }

    // Get the plugin interface
    plugin_interface_t *interface = (plugin_interface_t *)dlsym(handle, PLUGIN_INTERFACE_SYMBOL);
    if (!interface) {
        fprintf(stderr, "[PLUGIN] Plugin %s does not export %s\n", plugin_path, PLUGIN_INTERFACE_SYMBOL);
        dlclose(handle);
        return false;
    }

    // Check API version
    if (interface->info.api_version != PLUGIN_API_VERSION) {
        fprintf(stderr, "[PLUGIN] Plugin %s has incompatible API version %d (expected %d)\n",
                plugin_path, interface->info.api_version, PLUGIN_API_VERSION);
        dlclose(handle);
        return false;
    }

    // Copy the context into the registry slot first so the init hook receives
    // the same address that later callbacks for this plugin will be given.
    plugin_contexts[plugin_count] = *ctx;

    // Initialize the plugin (the hook is optional)
    if (interface->init && !interface->init(&plugin_contexts[plugin_count])) {
        fprintf(stderr, "[PLUGIN] Plugin %s initialization failed\n", plugin_path);
        dlclose(handle);
        return false;
    }

    // Commit the slot only after every check has passed
    loaded_plugins[plugin_count] = handle;
    plugin_interfaces[plugin_count] = interface;
    plugin_count++;

    return true;
}

/**
 * Unload the plugin that owns the given context.
 *
 * The context is identified by address: it must be the registry's own
 * context object for that plugin. When no slot matches, nothing happens.
 * The plugin's optional cleanup hook is called, its library is closed and
 * the following registry entries are moved down by one slot.
 *
 * NOTE(review): moving entries changes the context address of every plugin
 * stored after the removed one; a plugin that kept the pointer it was given
 * at init would then hold a stale address -- confirm whether plugins do so.
 *
 * @param ctx  Address of the registry context of the plugin to unload.
 */
void plugin_unload(plugin_context_t *ctx) {
    // Locate the slot whose context object is the one passed in
    int slot = 0;
    while (slot < plugin_count && &plugin_contexts[slot] != ctx) {
        slot++;
    }
    if (slot >= plugin_count) {
        return;
    }

    if (plugin_interfaces[slot]->cleanup) {
        plugin_interfaces[slot]->cleanup(ctx);
    }
    dlclose(loaded_plugins[slot]);

    // Close the gap by pulling every later entry one slot forward
    for (int next = slot + 1; next < plugin_count; next++) {
        loaded_plugins[next - 1] = loaded_plugins[next];
        plugin_interfaces[next - 1] = plugin_interfaces[next];
        plugin_contexts[next - 1] = plugin_contexts[next];
    }
    plugin_count--;
}

/**
 * Return the private data pointer stored in a plugin context.
 *
 * This function is handed to plugins through their context, so the argument
 * comes from plugin code and is checked before use.
 *
 * @param ctx  Plugin context to read from.
 * @return The stored plugin_data pointer, or NULL when ctx is NULL.
 */
void *plugin_get_data(plugin_context_t *ctx) {
    if (!ctx) {
        return NULL;
    }
    return ctx->plugin_data;
}

/**
 * Store a private data pointer in a plugin context.
 *
 * This function is handed to plugins through their context, so the argument
 * comes from plugin code and is checked before use.
 *
 * @param ctx   Plugin context to update; a NULL context is ignored.
 * @param data  Pointer to remember; it is stored as-is and not copied.
 */
void plugin_set_data(plugin_context_t *ctx, void *data) {
    if (!ctx) {
        return;
    }
    ctx->plugin_data = data;
}

// Handle web requests
int plugin_handle_web_request(int client_fd, const http_request_t *req) {
    for (int i = 0; i < plugin_count; i++) {
        if (plugin_interfaces[i]->capabilities & PLUGIN_CAP_WEB) {
            // Check web routes
            for (size_t j = 0; j < plugin_interfaces[i]->web.routes_count; j++) {
                const plugin_web_route_t *route = &plugin_interfaces[i]->web.routes[j];
                if (strcmp(req->method, route->method) == 0 &&
                    strcmp(req->path, route->path) == 0) {
                    return route->handler(client_fd, req, &plugin_contexts[i]);
                }
            }
        }
    }
    return 0; // Not handled
}

/**
 * Offer a message to the message handlers of the loaded plugins.
 *
 * Plugins advertising PLUGIN_CAP_MESSAGE are visited in registry order. A
 * handler matches when the message text contains the literal fragment
 * "type":"<message_type>" (no whitespace around the colon); the first match
 * is invoked with its plugin's context.
 *
 * The message is searched with strstr(), so it must be NUL-terminated;
 * message_len is only forwarded to the handler.
 *
 * @param state        Unused; kept for API consistency.
 * @param conn         Connection forwarded to the handler.
 * @param message      Message text; NULL is treated as "not handled".
 * @param message_len  Length forwarded to the handler.
 * @return The matching handler's return value, or 0 when nothing matched.
 */
int plugin_handle_message(wssshd_state_t *state, ws_connection_t *conn, const char *message, size_t message_len) {
    (void)state; // Parameter not used in current implementation but kept for API consistency

    if (!message) {
        return 0; // Nothing to search; strstr(NULL, ...) would be undefined
    }

    for (int i = 0; i < plugin_count; i++) {
        if (plugin_interfaces[i]->capabilities & PLUGIN_CAP_MESSAGE) {
            // Check message handlers
            for (size_t j = 0; j < plugin_interfaces[i]->message.handlers_count; j++) {
                const plugin_message_handler_t *handler = &plugin_interfaces[i]->message.handlers[j];
                // Simple check for message type in JSON
                char type_check[256];
                int needed = snprintf(type_check, sizeof(type_check), "\"type\":\"%s\"", handler->message_type);
                if (needed < 0 || (size_t)needed >= sizeof(type_check)) {
                    // A truncated pattern loses its closing quote and would
                    // match other types sharing the same prefix; skip it.
                    continue;
                }
                if (strstr(message, type_check)) {
                    return handler->handler(message, message_len, conn, &plugin_contexts[i]);
                }
            }
        }
    }
    return 0; // Not handled
}

/**
 * Create a tunnel session through the first plugin that provides the
 * requested transport protocol.
 *
 * Plugins advertising PLUGIN_CAP_TRANSPORT are visited in registry order and
 * their protocol names are compared exactly against the requested one.
 *
 * NOTE(review): the caller's ctx is forwarded to create_session unchanged,
 * whereas the web and message dispatchers pass the owning plugin's registry
 * context -- confirm this difference is intended.
 *
 * @param protocol   Protocol name to look up.
 * @param ctx        Context forwarded to the plugin's create_session.
 * @param client_id  Client identifier forwarded to create_session.
 * @param username   User name forwarded to create_session.
 * @param debug      Debug flag forwarded to create_session.
 * @return The session returned by the plugin, or NULL when no plugin
 *         provides the protocol.
 */
tunnel_session_t *plugin_create_transport_session(const char *protocol, plugin_context_t *ctx, const char *client_id, const char *username, bool debug) {
    for (int p = 0; p < plugin_count; p++) {
        plugin_interface_t *iface = plugin_interfaces[p];

        if (!(iface->capabilities & PLUGIN_CAP_TRANSPORT)) {
            continue;
        }

        for (size_t t = 0; t < iface->transport.protocols_count; t++) {
            const plugin_transport_protocol_t *candidate = &iface->transport.protocols[t];

            if (strcmp(protocol, candidate->protocol_name) != 0) {
                continue;
            }
            return candidate->create_session(ctx, client_id, username, debug);
        }
    }

    return NULL; // Not found
}