/*
 * Network Monitoring Suite - SSL Hook Implementation
 * Copyright (C) 2024 Stefy Lanza <stefy@sexhack.me>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "ssl_hook.h"
#ifndef NO_DETOURS
#include <detours.h>
#endif
#define SECURITY_WIN32
#include <sspi.h>

// Pointers to the original (un-hooked) implementations of every hooked API.
// DllMain fills them with GetProcAddress and passes their addresses to
// DetourAttach/DetourDetach; the hooked_* wrappers call through them.
// A pointer left NULL means the export was not found, or its module was not
// loaded when this DLL attached. Presumably declared extern in ssl_hook.h —
// TODO confirm.
// OpenSSL libssl exports.
SSL_write_t original_SSL_write = NULL;
SSL_read_t original_SSL_read = NULL;
// Schannel / SSPI message encryption and decryption (secur32.dll).
EncryptMessage_t original_EncryptMessage = NULL;
DecryptMessage_t original_DecryptMessage = NULL;
// kernel32.dll module-loading APIs, hooked only to record calls via log_syscall.
LoadLibraryA_t original_LoadLibraryA = NULL;
LoadLibraryW_t original_LoadLibraryW = NULL;
GetProcAddress_t original_GetProcAddress = NULL;

// Hooked functions
/*
 * Replacement for OpenSSL SSL_write.
 *
 * Forwards to the original first, then logs exactly the plaintext bytes the
 * original reports as written (result > 0). Logging before the call would
 * record the same data again on every non-blocking retry (SSL_write must be
 * repeated with the same buffer after SSL_ERROR_WANT_READ/WRITE). It would
 * also over-report partial writes and hand a non-positive num to the PCAP
 * logger as a length.
 *
 * Logging only on success also means the error state that SSL_get_error()
 * inspects after a failing call is not disturbed by the log file I/O.
 *
 * Returns the original's result, or -1 if the original was never resolved.
 */
int hooked_SSL_write(void* ssl, const void* buf, int num) {
    int result = original_SSL_write ? original_SSL_write(ssl, buf, num) : -1;
    if (result > 0) {
        log_data("SSL_WRITE", buf, result);
#ifndef NO_DETOURS
        // For PCAP logging, we'd need to track connection details.
        // This is simplified - in practice you'd need to maintain connection
        // state; the 0/443 port arguments are placeholders ("assume external").
        log_unencrypted_traffic(0, 0, 0, 443, (const BYTE*)buf, result, 0);
#endif
    }
    return result;
}

/*
 * Replacement for OpenSSL SSL_read.
 *
 * Forwards to the original. When it returns a positive byte count, that many
 * decrypted bytes from buf are logged. Returns the original's result, or -1
 * (without logging) if the original was never resolved.
 */
int hooked_SSL_read(void* ssl, void* buf, int num) {
    if (!original_SSL_read) {
        return -1;
    }

    int bytesRead = original_SSL_read(ssl, buf, num);
    if (bytesRead <= 0) {
        // Nothing was decrypted; hand the status straight back to the caller.
        return bytesRead;
    }

    log_data("SSL_READ", buf, bytesRead);
#ifndef NO_DETOURS
    // Placeholder ports: 443 is passed as the source port ("assume external").
    log_unencrypted_traffic(0, 0, 443, 0, (const BYTE*)buf, bytesRead, 0);
#endif
    return bytesRead;
}

/*
 * Replacement for SSPI EncryptMessage (used by Schannel).
 *
 * Logs the first SECBUFFER_DATA buffer and then forwards to the original.
 * The log must happen before the call because EncryptMessage encrypts the
 * data buffer in place, so afterwards it holds ciphertext.
 *
 * A NULL descriptor, a NULL buffer array, or a data buffer whose pvBuffer is
 * NULL is not inspected. It is forwarded unchanged so the original can reject
 * it with its own status code, instead of this hook crashing the process.
 *
 * Returns the original's status, or SEC_E_UNSUPPORTED_FUNCTION if the
 * original was never resolved.
 *
 * NOTE(review): the real API is SEC_ENTRY (__stdcall on 32-bit x86). Confirm
 * that ssl_hook.h declares this hook with the same calling convention, or
 * 32-bit builds will unbalance the stack.
 */
SECURITY_STATUS hooked_EncryptMessage(PCtxtHandle phContext, ULONG fQOP, PSecBufferDesc pMessage, ULONG MessageSeqNo) {
    if (pMessage != NULL && pMessage->pBuffers != NULL) {
        for (ULONG i = 0; i < pMessage->cBuffers; i++) {
            PSecBuffer buffer = &pMessage->pBuffers[i];
            if (buffer->BufferType != SECBUFFER_DATA) {
                continue;
            }
            if (buffer->pvBuffer != NULL) {
                log_data("SCHANNEL_ENCRYPT", buffer->pvBuffer, buffer->cbBuffer);
#ifndef NO_DETOURS
                log_unencrypted_traffic(0, 0, 0, 443, (const BYTE*)buffer->pvBuffer, buffer->cbBuffer, 0);
#endif
            }
            // Only the first data buffer is logged.
            break;
        }
    }
    return original_EncryptMessage ? original_EncryptMessage(phContext, fQOP, pMessage, MessageSeqNo) : SEC_E_UNSUPPORTED_FUNCTION;
}

/*
 * Replacement for SSPI DecryptMessage (used by Schannel).
 *
 * Forwards to the original. Only when it returns SEC_E_OK is the first
 * SECBUFFER_DATA buffer (now holding plaintext) logged. Returns the
 * original's status, or SEC_E_UNSUPPORTED_FUNCTION if the original was
 * never resolved.
 *
 * NOTE(review): the real API is SEC_ENTRY (__stdcall on 32-bit x86). Confirm
 * that ssl_hook.h declares this hook with the same calling convention.
 */
SECURITY_STATUS hooked_DecryptMessage(PCtxtHandle phContext, PSecBufferDesc pMessage, ULONG MessageSeqNo, PULONG pfQOP) {
    if (!original_DecryptMessage) {
        return SEC_E_UNSUPPORTED_FUNCTION;
    }

    SECURITY_STATUS status = original_DecryptMessage(phContext, pMessage, MessageSeqNo, pfQOP);
    if (status != SEC_E_OK) {
        return status;
    }

    // Walk the buffer array until the first data buffer is found.
    ULONG remaining = pMessage->cBuffers;
    PSecBuffer current = pMessage->pBuffers;
    while (remaining-- > 0) {
        if (current->BufferType == SECBUFFER_DATA) {
            log_data("SCHANNEL_DECRYPT", current->pvBuffer, current->cbBuffer);
#ifndef NO_DETOURS
            log_unencrypted_traffic(0, 0, 443, 0, (const BYTE*)current->pvBuffer, current->cbBuffer, 0);
#endif
            break;
        }
        ++current;
    }
    return status;
}

/*
 * Replacement for kernel32 LoadLibraryA.
 *
 * Records the requested module name via log_syscall, then returns whatever
 * the original LoadLibraryA returns.
 *
 * NOTE(review): LoadLibraryA is WINAPI (__stdcall on 32-bit x86). Confirm
 * that ssl_hook.h declares this hook the same way.
 */
HMODULE hooked_LoadLibraryA(LPCSTR lpLibFileName) {
    const char* const apiName = "LoadLibraryA";
    log_syscall(apiName, lpLibFileName);

    HMODULE loaded = original_LoadLibraryA(lpLibFileName);
    return loaded;
}

/*
 * Replacement for kernel32 LoadLibraryW.
 *
 * Logs the requested module name, converted to the ANSI code page, via
 * log_syscall, then returns whatever the original LoadLibraryW returns.
 *
 * NOTE(review): LoadLibraryW is WINAPI (__stdcall on 32-bit x86). Confirm
 * that ssl_hook.h declares this hook the same way.
 */
HMODULE hooked_LoadLibraryW(LPCWSTR lpLibFileName) {
    char libName[MAX_PATH];
    const char* loggedName = "<unconvertible name>";

    // WideCharToMultiByte returns 0 on failure, e.g. for a NULL argument or a
    // name that does not fit in MAX_PATH bytes. In that case libName is not
    // guaranteed to be NUL-terminated, so it is only logged on success.
    // With cchWideChar == -1, a successful result includes the terminator.
    if (lpLibFileName != NULL &&
        WideCharToMultiByte(CP_ACP, 0, lpLibFileName, -1, libName, MAX_PATH, NULL, NULL) > 0) {
        loggedName = libName;
    }

    log_syscall("LoadLibraryW", loggedName);
    return original_LoadLibraryW(lpLibFileName);
}

/*
 * Replacement for kernel32 GetProcAddress.
 *
 * Logs the module handle and the requested export via log_syscall, then
 * returns whatever the original GetProcAddress returns.
 *
 * lpProcName is either a pointer to a NUL-terminated name or an ordinal
 * stored in the low-order word with a zero high-order word. Ordinals are
 * logged numerically, because formatting them with %s would dereference an
 * invalid address.
 *
 * NOTE(review): GetProcAddress is WINAPI (__stdcall on 32-bit x86). Confirm
 * that ssl_hook.h declares this hook the same way.
 */
FARPROC hooked_GetProcAddress(HMODULE hModule, LPCSTR lpProcName) {
    char details[256];

    if (((ULONG_PTR)lpProcName >> 16) == 0) {
        snprintf(details, sizeof details, "Module: 0x%p, Ordinal: %u",
                 (void*)hModule, (unsigned)(ULONG_PTR)lpProcName);
    } else {
        // snprintf truncates over-long export names instead of overflowing details.
        snprintf(details, sizeof details, "Module: 0x%p, Function: %s",
                 (void*)hModule, lpProcName);
    }

    log_syscall("GetProcAddress", details);
    return original_GetProcAddress(hModule, lpProcName);
}

// Helper functions
/*
 * Appends one line to ssl_log.txt of the form
 *     "[<direction>] <num> bytes: xx xx ... \n"
 * with at most the first 100 bytes of buf rendered as lowercase hex, each
 * followed by a space.
 *
 * direction: label for the record; NULL is logged as "(null)".
 * buf:       payload. If NULL, only the header is written (no bytes read).
 * num:       payload length as reported by the caller. It is printed as-is;
 *            values <= 0 dump no bytes.
 *
 * Best-effort: if the log file cannot be opened, nothing is written.
 */
void log_data(const char* direction, const void* buf, int num) {
    enum { LOG_DATA_MAX_BYTES = 100 };  // only this many bytes are hex-dumped
    static const char hexDigits[] = "0123456789abcdef";
    char hex[LOG_DATA_MAX_BYTES * 3 + 1];  // "xx " per byte plus terminator
    size_t pos = 0;
    const unsigned char* bytes = (const unsigned char*)buf;

    if (bytes != NULL) {
        for (int i = 0; i < num && i < LOG_DATA_MAX_BYTES; i++) {
            hex[pos++] = hexDigits[bytes[i] >> 4];
            hex[pos++] = hexDigits[bytes[i] & 0x0F];
            hex[pos++] = ' ';
        }
    }
    hex[pos] = '\0';

    FILE* logFile = fopen("ssl_log.txt", "a");
    if (logFile) {
        // One write per record instead of one fprintf per byte.
        fprintf(logFile, "[%s] %d bytes: %s\n",
                direction ? direction : "(null)", num, hex);
        fclose(logFile);
    }
}

/*
 * Appends one timestamped line to syscall_log.txt:
 *     "[YYYY-MM-DD HH:MM:SS] <syscall_name>: <details>\n"
 * The timestamp comes from GetSystemTime, i.e. UTC.
 *
 * Either string may be NULL (e.g. a hooked LoadLibraryA(NULL) call). NULL is
 * logged as "(null)" rather than handed to %s, which is undefined behavior.
 *
 * Best-effort: if the log file cannot be opened, nothing is written and the
 * hooked call proceeds normally.
 */
void log_syscall(const char* syscall_name, const char* details) {
    FILE* logFile = fopen("syscall_log.txt", "a");
    if (!logFile) {
        return;
    }

    SYSTEMTIME st;
    GetSystemTime(&st);
    fprintf(logFile, "[%04d-%02d-%02d %02d:%02d:%02d] %s: %s\n",
            st.wYear, st.wMonth, st.wDay, st.wHour, st.wMinute, st.wSecond,
            syscall_name ? syscall_name : "(null)",
            details ? details : "(null)");
    fclose(logFile);
}

// DLL entry point
BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserved) {
    if (ul_reason_for_call == DLL_PROCESS_ATTACH) {
        // Get handles to SSL libraries (this is simplified - in reality you'd need to handle multiple SSL implementations)
        HMODULE hLibSSL = GetModuleHandle("libssl-1_1.dll"); // OpenSSL
        if (!hLibSSL) hLibSSL = GetModuleHandle("ssleay32.dll"); // Older OpenSSL
        if (!hLibSSL) hLibSSL = GetModuleHandle("libssl.dll");

        if (hLibSSL) {
            original_SSL_write = (SSL_write_t)GetProcAddress(hLibSSL, "SSL_write");
            original_SSL_read = (SSL_read_t)GetProcAddress(hLibSSL, "SSL_read");

            if (original_SSL_write && original_SSL_read) {
#ifndef NO_DETOURS
                DetourTransactionBegin();
                DetourUpdateThread(GetCurrentThread());
                DetourAttach(&(PVOID&)original_SSL_write, hooked_SSL_write);
                DetourAttach(&(PVOID&)original_SSL_read, hooked_SSL_read);
                DetourTransactionCommit();
#endif
            }
        }

        // Also try to hook Windows Schannel if available
        HMODULE hSecur32 = GetModuleHandle("secur32.dll");
        if (hSecur32) {
            original_EncryptMessage = (EncryptMessage_t)GetProcAddress(hSecur32, "EncryptMessage");
            original_DecryptMessage = (DecryptMessage_t)GetProcAddress(hSecur32, "DecryptMessage");

            if (original_EncryptMessage && original_DecryptMessage) {
#ifndef NO_DETOURS
                DetourTransactionBegin();
                DetourUpdateThread(GetCurrentThread());
                DetourAttach(&(PVOID&)original_EncryptMessage, hooked_EncryptMessage);
                DetourAttach(&(PVOID&)original_DecryptMessage, hooked_DecryptMessage);
                DetourTransactionCommit();
#endif
            }
        }

        // Hook system DLL loading functions
        HMODULE hKernel32 = GetModuleHandle("kernel32.dll");
        if (hKernel32) {
            original_LoadLibraryA = (LoadLibraryA_t)GetProcAddress(hKernel32, "LoadLibraryA");
            original_LoadLibraryW = (LoadLibraryW_t)GetProcAddress(hKernel32, "LoadLibraryW");
            original_GetProcAddress = (GetProcAddress_t)GetProcAddress(hKernel32, "GetProcAddress");

            if (original_LoadLibraryA && original_LoadLibraryW && original_GetProcAddress) {
#ifndef NO_DETOURS
                DetourTransactionBegin();
                DetourUpdateThread(GetCurrentThread());
                DetourAttach(&(PVOID&)original_LoadLibraryA, hooked_LoadLibraryA);
                DetourAttach(&(PVOID&)original_LoadLibraryW, hooked_LoadLibraryW);
                DetourAttach(&(PVOID&)original_GetProcAddress, hooked_GetProcAddress);
                DetourTransactionCommit();
#endif
            }
        }
    } else if (ul_reason_for_call == DLL_PROCESS_DETACH) {
        if (original_SSL_write && original_SSL_read) {
#ifndef NO_DETOURS
            DetourTransactionBegin();
            DetourUpdateThread(GetCurrentThread());
            DetourDetach(&(PVOID&)original_SSL_write, hooked_SSL_write);
            DetourDetach(&(PVOID&)original_SSL_read, hooked_SSL_read);
            DetourTransactionCommit();
#endif
        }
        if (original_EncryptMessage && original_DecryptMessage) {
#ifndef NO_DETOURS
            DetourTransactionBegin();
            DetourUpdateThread(GetCurrentThread());
            DetourDetach(&(PVOID&)original_EncryptMessage, hooked_EncryptMessage);
            DetourDetach(&(PVOID&)original_DecryptMessage, hooked_DecryptMessage);
            DetourTransactionCommit();
#endif
        }
        if (original_LoadLibraryA && original_LoadLibraryW && original_GetProcAddress) {
#ifndef NO_DETOURS
            DetourTransactionBegin();
            DetourUpdateThread(GetCurrentThread());
            DetourDetach(&(PVOID&)original_LoadLibraryA, hooked_LoadLibraryA);
            DetourDetach(&(PVOID&)original_LoadLibraryW, hooked_LoadLibraryW);
            DetourDetach(&(PVOID&)original_GetProcAddress, hooked_GetProcAddress);
            DetourTransactionCommit();
#endif
        }
    }
    return TRUE;
}