/*
 * Network Monitoring Suite - SSL Hook Header
 * Copyright (C) 2024 Stefy Lanza <stefy@sexhack.me>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef SSL_HOOK_H
#define SSL_HOOK_H

#include <windows.h>
#define SECURITY_WIN32
#include <sspi.h>
#include <stdio.h>

// Pointer-to-function types for the SSL write/read entry points.
//   ssl : opaque connection handle (passed through as void*)
//   buf : caller buffer (const for write, writable for read)
//   num : int count that accompanies buf
// No calling convention is spelled out, so the compiler default applies.
typedef int (*SSL_write_t)(void* ssl, const void* buf, int num);
typedef int (*SSL_read_t)(void* ssl, void* buf, int num);

// Function pointer types for Schannel functions
typedef SECURITY_STATUS (*EncryptMessage_t)(PCtxtHandle phContext, ULONG fQOP, PSecBufferDesc pMessage, ULONG MessageSeqNo);
typedef SECURITY_STATUS (*DecryptMessage_t)(PCtxtHandle phContext, PSecBufferDesc pMessage, ULONG MessageSeqNo, PULONG pfQOP);

// Pointers to the original (un-hooked) SSL and Schannel functions.
// These are declarations only; each pointer has to be defined in exactly one
// source file. Presumably the hooks forward to the real implementation
// through them - confirm against the hook source.
extern SSL_write_t      original_SSL_write;
extern SSL_read_t       original_SSL_read;
extern EncryptMessage_t original_EncryptMessage;
extern DecryptMessage_t original_DecryptMessage;

// Hook entry points for the SSL write/read calls.
// Each prototype has the same parameter list and return type as the
// corresponding SSL_write_t / SSL_read_t pointer type.
//   ssl : opaque connection handle
//   buf : caller buffer (const for write, writable for read)
//   num : int count that accompanies buf
int hooked_SSL_write(void* ssl, const void* buf, int num);
int hooked_SSL_read(void* ssl, void* buf, int num);
// Hook entry points for the Schannel EncryptMessage / DecryptMessage calls.
// Parameter lists and return type match EncryptMessage_t / DecryptMessage_t.
// NOTE(review): no calling convention is given here, whereas the SSPI entry
// points use SEC_ENTRY. Confirm that the definitions and the hook installer
// agree on the convention for 32-bit x86 builds.
SECURITY_STATUS hooked_EncryptMessage(PCtxtHandle phContext,
                                      ULONG fQOP,
                                      PSecBufferDesc pMessage,
                                      ULONG MessageSeqNo);
SECURITY_STATUS hooked_DecryptMessage(PCtxtHandle phContext,
                                      PSecBufferDesc pMessage,
                                      ULONG MessageSeqNo,
                                      PULONG pfQOP);

// Function pointer types for the additional TLS libraries.
// Every type takes an opaque handle as its first argument, followed by a
// buffer and a length, and returns int. None specifies a calling convention.

// LibreSSL - same shape as the OpenSSL types (int length).
typedef int (*LibreSSL_write_t)(void* ssl, const void* buf, int num);
typedef int (*LibreSSL_read_t)(void* ssl, void* buf, int num);

// NSS (Mozilla) - write/read with an int length, plus a handshake call that
// takes only the handle.
typedef int (*NSS_SSL_Write_t)(void* ssl, const void* buf, int num);
typedef int (*NSS_SSL_Read_t)(void* ssl, void* buf, int num);
typedef int (*NSS_SSL_ForceHandshake_t)(void* ssl);

// GnuTLS - record send/receive with a size_t length.
// NOTE(review): the upstream gnutls_record_send/recv are believed to return
// ssize_t rather than int; verify on 64-bit targets before relying on the
// return value.
typedef int (*GnuTLS_record_send_t)(void* session, const void* data, size_t data_size);
typedef int (*GnuTLS_record_recv_t)(void* session, void* data, size_t data_size);

// mbed TLS - buffers are unsigned char, length is size_t.
typedef int (*mbedTLS_ssl_write_t)(void* ssl, const unsigned char* buf, size_t len);
typedef int (*mbedTLS_ssl_read_t)(void* ssl, unsigned char* buf, size_t len);

// wolfSSL - int length.
typedef int (*wolfSSL_write_t)(void* ssl, const void* data, int sz);
typedef int (*wolfSSL_read_t)(void* ssl, void* data, int sz);

// Botan (C++ library; simplified C-style signature), size_t length.
typedef int (*Botan_TLS_write_t)(void* channel, const void* buf, size_t length);
typedef int (*Botan_TLS_read_t)(void* channel, void* buf, size_t length);

// Function pointer types for syscall monitoring
typedef HMODULE (*LoadLibraryA_t)(LPCSTR lpLibFileName);
typedef HMODULE (*LoadLibraryW_t)(LPCWSTR lpLibFileName);
typedef FARPROC (*GetProcAddress_t)(HMODULE hModule, LPCSTR lpProcName);

// Pointers to the original (un-hooked) functions of the additional TLS
// libraries, one per pointer type declared for those libraries.
// These are declarations only; each pointer has to be defined in exactly one
// source file.
extern LibreSSL_write_t         original_LibreSSL_write;
extern LibreSSL_read_t          original_LibreSSL_read;
extern NSS_SSL_Write_t          original_NSS_SSL_Write;
extern NSS_SSL_Read_t           original_NSS_SSL_Read;
extern NSS_SSL_ForceHandshake_t original_NSS_SSL_ForceHandshake;
extern GnuTLS_record_send_t     original_GnuTLS_record_send;
extern GnuTLS_record_recv_t     original_GnuTLS_record_recv;
extern mbedTLS_ssl_write_t      original_mbedTLS_ssl_write;
extern mbedTLS_ssl_read_t       original_mbedTLS_ssl_read;
extern wolfSSL_write_t          original_wolfSSL_write;
extern wolfSSL_read_t           original_wolfSSL_read;
extern Botan_TLS_write_t        original_Botan_TLS_write;
extern Botan_TLS_read_t         original_Botan_TLS_read;

// Pointers to the original (un-hooked) loader APIs being monitored.
// These are declarations only; each pointer has to be defined in exactly one
// source file.
extern LoadLibraryA_t   original_LoadLibraryA;
extern LoadLibraryW_t   original_LoadLibraryW;
extern GetProcAddress_t original_GetProcAddress;

// Hook entry points for the additional TLS libraries.
// Each prototype has the same parameter list and return type as the pointer
// type of the same library, so a hook can stand in for the function it wraps.

// LibreSSL
int hooked_LibreSSL_write(void* ssl, const void* buf, int num);
int hooked_LibreSSL_read(void* ssl, void* buf, int num);

// NSS
int hooked_NSS_SSL_Write(void* ssl, const void* buf, int num);
int hooked_NSS_SSL_Read(void* ssl, void* buf, int num);
int hooked_NSS_SSL_ForceHandshake(void* ssl);

// GnuTLS
// NOTE(review): returns int; the upstream record functions are believed to
// return ssize_t - verify on 64-bit targets.
int hooked_GnuTLS_record_send(void* session, const void* data, size_t data_size);
int hooked_GnuTLS_record_recv(void* session, void* data, size_t data_size);

// mbed TLS
int hooked_mbedTLS_ssl_write(void* ssl, const unsigned char* buf, size_t len);
int hooked_mbedTLS_ssl_read(void* ssl, unsigned char* buf, size_t len);

// wolfSSL
int hooked_wolfSSL_write(void* ssl, const void* data, int sz);
int hooked_wolfSSL_read(void* ssl, void* data, int sz);

// Botan
int hooked_Botan_TLS_write(void* channel, const void* buf, size_t length);
int hooked_Botan_TLS_read(void* channel, void* buf, size_t length);

// Hook entry points for the monitored loader APIs.
// Parameter lists and return types match LoadLibraryA_t, LoadLibraryW_t and
// GetProcAddress_t respectively.
// NOTE(review): no calling convention is given here, whereas the Win32
// functions these stand in for are WINAPI. Confirm that the definitions and
// the hook installer agree on the convention for 32-bit x86 builds.
HMODULE hooked_LoadLibraryA(LPCSTR lpLibFileName);
HMODULE hooked_LoadLibraryW(LPCWSTR lpLibFileName);
FARPROC hooked_GetProcAddress(HMODULE hModule, LPCSTR lpProcName);

// Logging helpers (defined elsewhere).
//
// log_data:    takes a direction label, a buffer and an int count for that
//              buffer.
// log_syscall: takes the name of the intercepted call and a free-form
//              details string.
// Neither returns a value. Output destination and format are not visible
// from this header.
void log_data(const char* direction, const void* buf, int num);
void log_syscall(const char* syscall_name, const char* details);

// Packet capture function, imported from another module (dllimport).
//   src_ip / dest_ip     : endpoint addresses as DWORDs
//                          (byte order not visible here - TODO confirm)
//   src_port / dest_port : endpoint ports as WORDs
//   data, len            : payload bytes and their count
//   is_internal          : int flag; meaning is defined by the exporting module
__declspec(dllimport) void log_unencrypted_traffic(DWORD src_ip,
                                                   DWORD dest_ip,
                                                   WORD src_port,
                                                   WORD dest_port,
                                                   const BYTE* data,
                                                   int len,
                                                   int is_internal);

#endif