/*
 * Network Monitoring Suite - Main Monitor
 * Copyright (C) 2024 Stefy Lanza <stefy@sexhack.me>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include <winsock2.h>
#include <windows.h>
#include <iphlpapi.h>
#include <commdlg.h>
#include <shlobj.h>
#include <tlhelp32.h>
#include <psapi.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Structure to track monitored processes: one entry per process this monitor
// launched itself (the root) or adopted as a child of a monitored process.
typedef struct {
    // Process ID; matched against TCP table owners and toolhelp snapshot entries.
    DWORD pid;
    // Process handle owned by the list; closed by RemoveMonitoredProcess or main's final cleanup.
    HANDLE hProcess;
    // Full path for the root process (from the file dialog), executable name for children.
    char processName[MAX_PATH];
} MonitoredProcess;

// Capacity of the monitoring list; AddMonitoredProcess refuses entries beyond it.
#define MAX_PROCESSES 256
// Fixed-size monitoring list; entries [0, processCount) are valid and kept contiguous.
MonitoredProcess monitoredProcesses[MAX_PROCESSES];
int processCount = 0;

// #pragma comment(lib, "ws2_32.lib")
// #pragma comment(lib, "iphlpapi.lib")

/*
 * Classify an IPv4 address supplied in host byte order (first octet in the most
 * significant byte).
 *
 * Returns 1 for the RFC 1918 private ranges 10/8, 172.16/12 and 192.168/16 and for
 * loopback 127/8; returns 0 for every other address.
 */
int is_internal_ip(DWORD ip) {
    const unsigned first = (unsigned)((ip >> 24) & 0xFFu);
    const unsigned second = (unsigned)((ip >> 16) & 0xFFu);

    switch (first) {
    case 10:
    case 127: /* loopback */
        return 1;
    case 172:
        return second >= 16 && second <= 31;
    case 192:
        return second == 168;
    default:
        return 0;
    }
}

/*
 * Append a process to the monitoring list.
 *
 * Ownership of hProcess passes to the list: it is later closed by
 * RemoveMonitoredProcess, or closed immediately here if the list is already full
 * (so the caller never leaks it). processName is copied, truncated to fit
 * MAX_PATH - 1 characters; a NULL name is stored as an empty string.
 */
void AddMonitoredProcess(DWORD pid, HANDLE hProcess, const char* processName) {
    if (processCount >= MAX_PROCESSES) {
        /* No slot can hold the handle, so release it instead of leaking it. */
        if (hProcess) {
            CloseHandle(hProcess);
        }
        return;
    }

    MonitoredProcess *entry = &monitoredProcesses[processCount];
    entry->pid = pid;
    entry->hProcess = hProcess;
    /* Bounded copy: snprintf always NUL-terminates and never writes past the field. */
    snprintf(entry->processName, sizeof(entry->processName), "%s",
             processName ? processName : "");
    processCount++;
}

/*
 * Drop the first entry whose pid matches from the monitoring list and close its
 * process handle. Later entries slide down one slot, so the list stays contiguous
 * and keeps its insertion order. Does nothing if pid is not monitored.
 */
void RemoveMonitoredProcess(DWORD pid) {
    int idx = 0;
    while (idx < processCount && monitoredProcesses[idx].pid != pid) {
        idx++;
    }
    if (idx == processCount) {
        return; /* not in the list */
    }

    CloseHandle(monitoredProcesses[idx].hProcess);
    for (int k = idx + 1; k < processCount; k++) {
        monitoredProcesses[k - 1] = monitoredProcesses[k];
    }
    processCount--;
}

/* Return 1 when pid is present in the monitoring list, 0 otherwise (linear scan). */
int IsProcessMonitored(DWORD pid) {
    int found = 0;
    for (int idx = 0; idx < processCount && !found; idx++) {
        found = (monitoredProcesses[idx].pid == pid);
    }
    return found;
}

/* InjectDLL is defined further down in this file. Declare it here so the call below has
   a real prototype (implicit function declarations are invalid since C99). */
BOOL InjectDLL(HANDLE hProcess, const char* dllPath);

/*
 * Open process pid and load "ssl_hook.dll" into it via InjectDLL.
 *
 * Returns TRUE on success, and FALSE if the process cannot be opened or the
 * injection fails. The temporary process handle is always closed before returning.
 */
BOOL InjectDLLIntoProcess(DWORD pid) {
    HANDLE hProcess = OpenProcess(PROCESS_ALL_ACCESS, FALSE, pid);
    if (!hProcess) {
        return FALSE;
    }

    BOOL injected = InjectDLL(hProcess, "ssl_hook.dll");
    CloseHandle(hProcess);
    return injected ? TRUE : FALSE;
}

/*
 * Scan a snapshot of all processes for direct children of parentPid that are not
 * yet monitored. Each one gets ssl_hook.dll injected (via InjectDLLIntoProcess) and
 * is added to the monitoring list.
 *
 * The tracking handle is opened BEFORE injecting. A child that cannot be tracked
 * never shows up in IsProcessMonitored, so injecting it anyway would re-inject it on
 * every later scan.
 */
void MonitorChildProcesses(DWORD parentPid) {
    HANDLE hSnapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
    if (hSnapshot == INVALID_HANDLE_VALUE) {
        return;
    }

    PROCESSENTRY32 pe32;
    pe32.dwSize = sizeof(pe32);

    if (Process32First(hSnapshot, &pe32)) {
        do {
            if (pe32.th32ParentProcessID != parentPid || IsProcessMonitored(pe32.th32ProcessID)) {
                continue;
            }
            /* With the list full the child could not be tracked; stop adopting. */
            if (processCount >= MAX_PROCESSES) {
                break;
            }

            /* QUERY_INFORMATION + VM_READ covers the later GetExitCodeProcess and
               GetProcessMemoryInfo calls made on this handle. */
            HANDLE hProcess = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, FALSE, pe32.th32ProcessID);
            if (!hProcess) {
                continue;
            }

            if (!InjectDLLIntoProcess(pe32.th32ProcessID)) {
                CloseHandle(hProcess);
                continue;
            }

            AddMonitoredProcess(pe32.th32ProcessID, hProcess, pe32.szExeFile);
            printf("Now monitoring child process: %s (PID: %lu)\n", pe32.szExeFile, pe32.th32ProcessID);
        } while (Process32Next(hSnapshot, &pe32));
    }

    CloseHandle(hSnapshot);
}

/*
 * Load dllPath into the process behind hProcess. The path string is written into the
 * target's memory, and LoadLibraryA is run on a remote thread with that string as its
 * argument.
 *
 * hProcess must allow thread creation and virtual-memory operations/writes in the target.
 * The path is resolved by the loader inside the target, so a relative name is searched
 * using the target's DLL search order, not this program's.
 * NOTE(review): the LoadLibraryA address is taken from this process and assumed valid in
 * the target; this presumably requires both to have the same architecture. Confirm before
 * relying on cross-bitness use.
 *
 * Returns TRUE only when the remote LoadLibraryA call reported a non-zero module handle.
 * The thread exit code carries just the low 32 bits of that handle, so a module loaded
 * at an exact multiple of 4 GB would be reported as a failure (rare).
 */
BOOL InjectDLL(HANDLE hProcess, const char* dllPath) {
    if (!hProcess || !dllPath) {
        return FALSE;
    }
    SIZE_T pathSize = strlen(dllPath) + 1; /* include the terminating NUL */

    /* Resolve the entry point before allocating anything, so this failure path leaks nothing. */
    HMODULE hKernel32 = GetModuleHandle("kernel32.dll");
    FARPROC pLoadLibrary = hKernel32 ? GetProcAddress(hKernel32, "LoadLibraryA") : NULL;
    if (!pLoadLibrary) {
        return FALSE;
    }

    LPVOID pRemoteBuf = VirtualAllocEx(hProcess, NULL, pathSize, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
    if (!pRemoteBuf) {
        return FALSE;
    }

    BOOL loaded = FALSE;
    if (WriteProcessMemory(hProcess, pRemoteBuf, dllPath, pathSize, NULL)) {
        HANDLE hThread = CreateRemoteThread(hProcess, NULL, 0, (LPTHREAD_START_ROUTINE)pLoadLibrary, pRemoteBuf, 0, NULL);
        if (hThread) {
            /* Wait before freeing the buffer: LoadLibraryA reads the path from it. */
            WaitForSingleObject(hThread, INFINITE);
            DWORD exitCode = 0;
            loaded = GetExitCodeThread(hThread, &exitCode) && exitCode != 0;
            CloseHandle(hThread);
        }
    }

    VirtualFreeEx(hProcess, pRemoteBuf, 0, MEM_RELEASE);
    return loaded;
}

/*
 * Return a heap-allocated snapshot of the system IPv4 TCP table (the caller frees it), or
 * NULL on failure. The table can grow between the sizing call and the fill call, so an
 * ERROR_INSUFFICIENT_BUFFER result is retried a few times with the updated size.
 */
static PMIB_TCPTABLE_OWNER_PID nm_fetch_tcp_table(void) {
    DWORD size = 0;
    if (GetExtendedTcpTable(NULL, &size, FALSE, AF_INET, TCP_TABLE_OWNER_PID_ALL, 0) != ERROR_INSUFFICIENT_BUFFER) {
        return NULL;
    }
    for (int attempt = 0; attempt < 3; attempt++) {
        PMIB_TCPTABLE_OWNER_PID table = malloc(size);
        if (!table) {
            return NULL;
        }
        DWORD rc = GetExtendedTcpTable(table, &size, FALSE, AF_INET, TCP_TABLE_OWNER_PID_ALL, 0);
        if (rc == NO_ERROR) {
            return table;
        }
        free(table);
        if (rc != ERROR_INSUFFICIENT_BUFFER) {
            return NULL;
        }
        /* size now holds the larger requirement; try again. */
    }
    return NULL;
}

/*
 * Print every IPv4 TCP connection owned by a monitored process to the console. Each one is
 * also appended to internalLog or externalLog, depending on whether its remote address is
 * internal. A NULL log (the file could not be opened) is skipped.
 */
static void nm_log_connections(FILE *internalLog, FILE *externalLog) {
    PMIB_TCPTABLE_OWNER_PID tcpTable = nm_fetch_tcp_table();
    if (!tcpTable) {
        return;
    }

    for (DWORD i = 0; i < tcpTable->dwNumEntries; i++) {
        const MIB_TCPROW_OWNER_PID *row = &tcpTable->table[i];
        for (int p = 0; p < processCount; p++) {
            if (row->dwOwningPid != monitoredProcesses[p].pid) {
                continue;
            }

            /* Addresses are stored in network byte order. Convert them to host order so
               the octets print correctly and is_internal_ip sees the layout it expects. */
            DWORD localIP = ntohl(row->dwLocalAddr);
            DWORD remoteIP = ntohl(row->dwRemoteAddr);
            /* Ports are network-order values in the low 16 bits of the DWORD. */
            unsigned localPort = ntohs((u_short)row->dwLocalPort);
            unsigned remotePort = ntohs((u_short)row->dwRemotePort);

            /* Large enough for a MAX_PATH name plus the fixed text; snprintf truncates otherwise. */
            char logEntry[MAX_PATH + 128];
            snprintf(logEntry, sizeof(logEntry),
                     "[%s] Connection: Local %lu.%lu.%lu.%lu:%u -> Remote %lu.%lu.%lu.%lu:%u State:%lu\n",
                     monitoredProcesses[p].processName,
                     (localIP >> 24) & 0xFF, (localIP >> 16) & 0xFF, (localIP >> 8) & 0xFF, localIP & 0xFF, localPort,
                     (remoteIP >> 24) & 0xFF, (remoteIP >> 16) & 0xFF, (remoteIP >> 8) & 0xFF, remoteIP & 0xFF, remotePort,
                     row->dwState);

            FILE *target = is_internal_ip(remoteIP) ? internalLog : externalLog;
            if (target) {
                fputs(logEntry, target);
            }
            printf("%s", logEntry);
            break;
        }
    }
    free(tcpTable);
}

/* Count IPv4 TCP table rows owned by any monitored process (0 if the table is unavailable). */
static DWORD nm_count_monitored_sockets(void) {
    PMIB_TCPTABLE_OWNER_PID tcpTable = nm_fetch_tcp_table();
    if (!tcpTable) {
        return 0;
    }
    DWORD count = 0;
    for (DWORD i = 0; i < tcpTable->dwNumEntries; i++) {
        if (IsProcessMonitored(tcpTable->table[i].dwOwningPid)) {
            count++;
        }
    }
    free(tcpTable);
    return count;
}

/* Count threads belonging to any monitored process, using a single toolhelp snapshot. */
static DWORD nm_count_monitored_threads(void) {
    HANDLE hSnapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0);
    if (hSnapshot == INVALID_HANDLE_VALUE) {
        return 0;
    }
    DWORD count = 0;
    THREADENTRY32 te32;
    te32.dwSize = sizeof(te32);
    if (Thread32First(hSnapshot, &te32)) {
        do {
            if (IsProcessMonitored(te32.th32OwnerProcessID)) {
                count++;
            }
        } while (Thread32Next(hSnapshot, &te32));
    }
    CloseHandle(hSnapshot);
    return count;
}

/* Sum the working-set sizes (in KB) of all monitored processes that can be queried. */
static DWORD nm_total_working_set_kb(void) {
    SIZE_T totalKb = 0;
    for (int p = 0; p < processCount; p++) {
        PROCESS_MEMORY_COUNTERS pmc;
        if (GetProcessMemoryInfo(monitoredProcesses[p].hProcess, &pmc, sizeof(pmc))) {
            totalKb += pmc.WorkingSetSize / 1024;
        }
    }
    return (DWORD)totalKb;
}

/* Show an "open file" dialog for the executable to monitor. Returns FALSE if cancelled. */
static BOOL nm_select_program(char *program, DWORD programSize) {
    OPENFILENAME ofn;
    ZeroMemory(&ofn, sizeof(ofn));
    ofn.lStructSize = sizeof(ofn);
    ofn.hwndOwner = NULL;
    ofn.lpstrFile = program;
    ofn.nMaxFile = programSize;
    ofn.lpstrFilter = "Executable Files\0*.exe\0All Files\0*.*\0";
    ofn.nFilterIndex = 1;
    ofn.Flags = OFN_PATHMUSTEXIST | OFN_FILEMUSTEXIST;
    return GetOpenFileName(&ofn);
}

/*
 * Show a folder picker and store the chosen file-system path in outputDir. outputDir
 * must hold at least MAX_PATH characters (required by SHBrowseForFolder and
 * SHGetPathFromIDList). Returns FALSE if the dialog is cancelled or the selection has
 * no file-system path.
 */
static BOOL nm_select_output_dir(char *outputDir) {
    /* SHBrowseForFolder (especially with BIF_NEWDIALOGSTYLE) requires COM in an
       apartment-threaded mode; CoInitialize always uses that mode. */
    HRESULT comInit = CoInitialize(NULL);

    BROWSEINFO bi;
    ZeroMemory(&bi, sizeof(bi));
    bi.pszDisplayName = outputDir;
    bi.lpszTitle = "Select output directory for logs";
    bi.ulFlags = BIF_RETURNONLYFSDIRS | BIF_NEWDIALOGSTYLE;

    BOOL ok = FALSE;
    LPITEMIDLIST pidl = SHBrowseForFolder(&bi);
    if (pidl) {
        ok = SHGetPathFromIDList(pidl, outputDir);
        CoTaskMemFree(pidl);
    }

    if (SUCCEEDED(comInit)) {
        CoUninitialize();
    }
    return ok;
}

/*
 * Entry point. Asks for a program and a log directory, then launches the program
 * suspended and injects ssl_hook.dll before resuming it. It logs the program's initial
 * TCP connections and keeps monitoring it, plus any children it spawns, until all of
 * them have exited.
 */
int main(void) {
    char program[MAX_PATH] = {0};
    char outputDir[MAX_PATH] = {0};

    // Select program to monitor
    if (!nm_select_program(program, sizeof(program))) {
        printf("No program selected\n");
        return 1;
    }

    // Select output directory
    if (!nm_select_output_dir(outputDir)) {
        printf("No output directory selected\n");
        return 1;
    }

    /* Pass the path as lpApplicationName and a quoted copy as the command line. This keeps
       a path containing spaces from being split and resolved to a different executable. */
    char commandLine[MAX_PATH + 3];
    snprintf(commandLine, sizeof(commandLine), "\"%s\"", program);

    STARTUPINFO si = {0};
    si.cb = sizeof(si);
    PROCESS_INFORMATION pi;
    if (!CreateProcess(program, commandLine, NULL, NULL, FALSE, CREATE_SUSPENDED, NULL, NULL, &si, &pi)) {
        printf("Failed to start process\n");
        return 1;
    }

    /* Inject while the primary thread is still suspended. InjectDLL waits for the remote
       LoadLibraryA to finish, so the DLL is loaded before the program resumes. */
    if (!InjectDLL(pi.hProcess, "ssl_hook.dll")) {
        printf("Failed to inject DLL\n");
        TerminateProcess(pi.hProcess, 1);
        CloseHandle(pi.hProcess);
        CloseHandle(pi.hThread);
        return 1;
    }

    ResumeThread(pi.hThread);
    /* Only the process handle is tracked; the primary-thread handle is not needed again. */
    CloseHandle(pi.hThread);

    DWORD rootPid = pi.dwProcessId;

    /* The monitoring list takes ownership of pi.hProcess. */
    AddMonitoredProcess(rootPid, pi.hProcess, program);
    printf("Now monitoring root process: %s (PID: %lu)\n", program, rootPid);

    Sleep(2000); // Wait for process to potentially establish connections

    char internal_log_path[512];
    char external_log_path[512];
    snprintf(internal_log_path, sizeof(internal_log_path), "%s\\internal_traffic.log", outputDir);
    snprintf(external_log_path, sizeof(external_log_path), "%s\\external_traffic.log", outputDir);

    FILE *internal_log = fopen(internal_log_path, "a");
    FILE *external_log = fopen(external_log_path, "a");
    if (!internal_log) {
        printf("Warning: cannot open %s; internal connections will not be logged\n", internal_log_path);
    }
    if (!external_log) {
        printf("Warning: cannot open %s; external connections will not be logged\n", external_log_path);
    }

    // Initial snapshot of connections for all monitored processes
    nm_log_connections(internal_log, external_log);
    if (internal_log) {
        fclose(internal_log);
    }
    if (external_log) {
        fclose(external_log);
    }

    // Continuous monitoring loop
    const DWORD STATUS_INTERVAL = 5000;      // 5 seconds
    const DWORD CHILD_CHECK_INTERVAL = 2000; // 2 seconds
    DWORD lastStatusTime = GetTickCount();
    DWORD lastChildCheckTime = lastStatusTime;

    for (;;) {
        DWORD currentTime = GetTickCount();

        /* Unsigned subtraction remains correct across the GetTickCount wraparound. */
        if (currentTime - lastChildCheckTime >= CHILD_CHECK_INTERVAL) {
            /* processCount may grow during the scan; newly adopted children get scanned too. */
            for (int p = 0; p < processCount; p++) {
                MonitorChildProcesses(monitoredProcesses[p].pid);
            }
            lastChildCheckTime = currentTime;
        }

        // Periodic status logging aggregated over all monitored processes
        if (currentTime - lastStatusTime >= STATUS_INTERVAL) {
            DWORD totalSocketCount = nm_count_monitored_sockets();
            DWORD totalThreadCount = nm_count_monitored_threads();
            DWORD totalMemoryUsage = nm_total_working_set_kb();

            SYSTEMTIME st;
            GetSystemTime(&st); /* UTC */

            printf("[%04d-%02d-%02d %02d:%02d:%02d] STATUS - Processes: %d, Sockets: %lu, Threads: %lu, Memory: %lu KB\n",
                   st.wYear, st.wMonth, st.wDay, st.wHour, st.wMinute, st.wSecond,
                   processCount, totalSocketCount, totalThreadCount, totalMemoryUsage);

            lastStatusTime = currentTime;
        }

        /* Iterate backwards so removals do not skip entries. NOTE(review): a process
           that exits with code 259 (STILL_ACTIVE) is indistinguishable from a running one. */
        for (int p = processCount - 1; p >= 0; p--) {
            DWORD exitCode;
            if (GetExitCodeProcess(monitoredProcesses[p].hProcess, &exitCode) && exitCode != STILL_ACTIVE) {
                printf("Process %s (PID: %lu) has terminated.\n", monitoredProcesses[p].processName, monitoredProcesses[p].pid);
                RemoveMonitoredProcess(monitoredProcesses[p].pid);
            }
        }

        // Exit if no processes are left to monitor
        if (processCount == 0) {
            printf("All monitored processes have terminated.\n");
            break;
        }

        Sleep(1000); // Check every second
    }

    // Clean up any remaining process handles
    for (int p = 0; p < processCount; p++) {
        CloseHandle(monitoredProcesses[p].hProcess);
    }
    return 0;
}