Fix multiple wsssh issues: framing protocol, performance, transfer limits, and graceful shutdown

- Fix WebSocket framing protocol issues with dynamic buffer allocation
- Remove 255KB transfer limit by using heap allocation for large data
- Optimize performance with 64KB chunking and faster reconnection (1s)
- Add SIGINT handling for graceful tunnel closure with error messages
- Improve WebSocket reconnection handling and tunnel state management
- Treat close frames as tunnel closures to maintain WebSocket connections
- Add proper memory cleanup and buffer overflow prevention
- Reduce reconnection intervals for better responsiveness
parent ac456539
......@@ -288,7 +288,7 @@ void *forward_tcp_to_ws(void *arg) {
}
if (FD_ISSET(client_sock, &readfds)) {
bytes_read = recv(client_sock, buffer, sizeof(buffer), 0);
bytes_read = recv(client_sock, buffer, MAX_CHUNK_SIZE, 0);
if (bytes_read <= 0) {
if (debug) {
if (bytes_read == 0) {
......@@ -319,45 +319,50 @@ void *forward_tcp_to_ws(void *arg) {
fflush(stdout);
}
// Convert to hex with bounds checking
// Reserve space for JSON overhead (about 80 characters)
size_t max_hex_size = BUFFER_SIZE - 100;
// Convert to hex with dynamic allocation
size_t hex_size = (size_t)bytes_read * 2 + 1;
int truncated = 0;
if (hex_size > max_hex_size) {
char *hex_data = malloc(hex_size);
if (!hex_data) {
if (debug) {
printf("[DEBUG] Hex data too large (%zu bytes), truncating to %zu\n", hex_size, max_hex_size);
printf("[DEBUG] Failed to allocate memory for hex data (%zu bytes)\n", hex_size);
fflush(stdout);
}
hex_size = max_hex_size;
truncated = 1;
continue;
}
char hex_data[hex_size];
size_t actual_hex_len = 0;
for (int i = 0; i < bytes_read && actual_hex_len < hex_size - 1; i++) {
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + actual_hex_len, "%02x", (unsigned char)buffer[i]);
actual_hex_len += 2;
}
hex_data[actual_hex_len] = '\0';
if (truncated && debug) {
printf("[DEBUG] Hex data truncated, sent %zu of %d bytes\n", actual_hex_len / 2, bytes_read);
fflush(stdout);
}
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
size_t msg_size = strlen("{\"type\":\"tunnel_data\",\"request_id\":\"\",\"data\":\"\"}") + strlen(request_id) + actual_hex_len + 1;
char *message = malloc(msg_size);
if (!message) {
if (debug) {
printf("[DEBUG] Failed to allocate memory for message (%zu bytes)\n", msg_size);
fflush(stdout);
}
free(hex_data);
continue;
}
snprintf(message, msg_size,
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (!send_websocket_frame(ssl, message)) {
if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n");
fflush(stdout);
}
free(hex_data);
free(message);
break;
}
free(hex_data);
free(message);
}
}
......@@ -425,7 +430,7 @@ void *forward_ws_to_ssh_server(void *arg) {
}
if (FD_ISSET(target_sock, &readfds)) {
bytes_read = recv(target_sock, buffer, sizeof(buffer), 0);
bytes_read = recv(target_sock, buffer, MAX_CHUNK_SIZE, 0);
if (bytes_read <= 0) {
if (debug) {
printf("[DEBUG - TCPConnection] Target connection closed or error\n");
......@@ -439,45 +444,50 @@ void *forward_ws_to_ssh_server(void *arg) {
fflush(stdout);
}
// Convert to hex with bounds checking
// Reserve space for JSON overhead (about 80 characters)
size_t max_hex_size = BUFFER_SIZE - 100;
// Convert to hex with dynamic allocation
size_t hex_size = (size_t)bytes_read * 2 + 1;
int truncated = 0;
if (hex_size > max_hex_size) {
char *hex_data = malloc(hex_size);
if (!hex_data) {
if (debug) {
printf("[DEBUG] Hex data too large (%zu bytes), truncating to %zu\n", hex_size, max_hex_size);
printf("[DEBUG] Failed to allocate memory for hex data (%zu bytes)\n", hex_size);
fflush(stdout);
}
hex_size = max_hex_size;
truncated = 1;
continue;
}
char hex_data[hex_size];
size_t actual_hex_len = 0;
for (int i = 0; i < bytes_read && actual_hex_len < hex_size - 1; i++) {
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + actual_hex_len, "%02x", (unsigned char)buffer[i]);
actual_hex_len += 2;
}
hex_data[actual_hex_len] = '\0';
if (truncated && debug) {
printf("[DEBUG] Hex data truncated, sent %zu of %d bytes\n", actual_hex_len / 2, bytes_read);
fflush(stdout);
}
// Send as tunnel_response (from target back to WebSocket)
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
size_t msg_size = strlen("{\"type\":\"tunnel_response\",\"request_id\":\"\",\"data\":\"\"}") + strlen(request_id) + actual_hex_len + 1;
char *message = malloc(msg_size);
if (!message) {
if (debug) {
printf("[DEBUG] Failed to allocate memory for message (%zu bytes)\n", msg_size);
fflush(stdout);
}
free(hex_data);
continue;
}
snprintf(message, msg_size,
"{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (!send_websocket_frame(ssl, message)) {
if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n");
fflush(stdout);
}
free(hex_data);
free(message);
break;
}
free(hex_data);
free(message);
}
}
......
......@@ -100,18 +100,31 @@ int send_registration_message(SSL *ssl, const char *client_id, const char *passw
}
int send_websocket_frame(SSL *ssl, const char *data) {
char frame[BUFFER_SIZE];
frame[0] = 0x81; // FIN + text opcode
int msg_len = strlen(data);
int header_len = 2;
if (msg_len <= 125) {
header_len = 6; // 2 + 4 for mask
} else if (msg_len <= 65535) {
header_len = 8; // 4 + 4 for mask
} else {
header_len = 14; // 10 + 4 for mask
}
int frame_len = header_len + msg_len;
char *frame = malloc(frame_len);
if (!frame) {
return 0;
}
frame[0] = 0x81; // FIN + text opcode
if (msg_len <= 125) {
frame[1] = 0x80 | msg_len; // MASK + length
} else if (msg_len <= 65535) {
frame[1] = 0x80 | 126; // MASK + extended length
frame[2] = (msg_len >> 8) & 0xFF;
frame[3] = msg_len & 0xFF;
header_len = 4;
} else {
frame[1] = 0x80 | 127; // MASK + extended length
frame[2] = 0;
......@@ -122,24 +135,20 @@ int send_websocket_frame(SSL *ssl, const char *data) {
frame[7] = (msg_len >> 16) & 0xFF;
frame[8] = (msg_len >> 8) & 0xFF;
frame[9] = msg_len & 0xFF;
header_len = 10;
}
// Add mask key
char mask_key[4];
for (int i = 0; i < 4; i++) {
mask_key[i] = rand() % 256;
frame[header_len + i] = mask_key[i];
frame[header_len - 4 + i] = mask_key[i];
}
header_len += 4;
// Mask payload
for (int i = 0; i < msg_len; i++) {
frame[header_len + i] = data[i] ^ mask_key[i % 4];
}
int frame_len = header_len + msg_len;
// Handle partial writes for large frames
int total_written = 0;
int retry_count = 0;
......@@ -166,6 +175,7 @@ int send_websocket_frame(SSL *ssl, const char *data) {
char error_buf[256];
ERR_error_string_n(ssl_error, error_buf, sizeof(error_buf));
fprintf(stderr, "SSL write error details: %s\n", error_buf);
free(frame);
return 0; // Write failed
}
total_written += written;
......@@ -174,9 +184,11 @@ int send_websocket_frame(SSL *ssl, const char *data) {
if (total_written < frame_len) {
fprintf(stderr, "WebSocket frame write incomplete: %d/%d bytes written\n", total_written, frame_len);
free(frame);
return 0;
}
free(frame);
return 1;
}
......
......@@ -994,6 +994,7 @@ start_forwarding_threads:
printf("[DEBUG - Tunnel] Received tunnel_close message\n");
fflush(stdout);
}
fprintf(stderr, "Error: Tunnel closed by remote client\n");
char *id_start = strstr(buffer, "\"request_id\"");
if (id_start) {
char *colon = strchr(id_start, ':');
......
......@@ -29,6 +29,7 @@
#include <getopt.h>
#include <pthread.h>
#include <errno.h>
#include <signal.h>
#include "wssshlib.h"
#include "websocket.h"
......@@ -39,6 +40,12 @@
int global_debug = 0;
volatile sig_atomic_t sigint_received = 0;
void sigint_handler(int sig __attribute__((unused))) {
sigint_received = 1;
}
typedef struct {
char *wssshd_server;
......@@ -387,6 +394,15 @@ int connect_to_server(const wssshc_config_t *config) {
static int frame_buffer_used = 0;
while (1) {
// Check for SIGINT
if (sigint_received) {
if (active_tunnel) {
send_tunnel_close(ssl, active_tunnel->request_id, config->debug);
fprintf(stderr, "Received SIGINT, sent tunnel_close and exiting\n");
}
break;
}
// Always try to read more data if there's space, even if we have a complete frame
// This ensures we can accumulate data for very large frames
if ((size_t)frame_buffer_used < sizeof(frame_buffer)) {
......@@ -481,12 +497,14 @@ int connect_to_server(const wssshc_config_t *config) {
if (frame_type == 0x88) { // Close frame
if (config->debug) {
printf("[DEBUG - WebSockets] Received close frame - cleaning up and reconnecting...\n");
printf("[DEBUG - WebSockets] Received close frame - treating as tunnel close, keeping WebSocket open\n");
fflush(stdout);
}
// Clean up tunnel resources before reconnecting
cleanup_tunnel(config->debug);
return 0;
// Treat as tunnel close, don't close WebSocket connection
if (active_tunnel) {
handle_tunnel_close(NULL, active_tunnel->request_id, config->debug);
}
// Continue processing, don't return
} else if (frame_type == 0x89) { // Ping frame
if (config->debug) {
printf("[DEBUG - WebSockets] Received ping frame\n");
......@@ -644,6 +662,9 @@ int main(int argc, char *argv[]) {
pthread_mutex_init(&tunnel_mutex, NULL);
// Set up signal handler for SIGINT
signal(SIGINT, sigint_handler);
// Load configuration files first (system and user configs)
load_config(&config);
......@@ -691,18 +712,18 @@ int main(int argc, char *argv[]) {
while (1) {
int result = connect_to_server(&config);
if (result == 1) {
// Error condition - use normal retry interval
printf("Connection lost, retrying in %d seconds...\n", config.interval);
sleep(config.interval);
// Error condition - use short retry interval for immediate reconnection
printf("Connection lost, retrying in 1 seconds...\n");
sleep(1);
} else if (result == 0) {
// Close frame received - add small delay to prevent rapid reconnection
// Close frame received - use short delay for immediate reconnection
if (config.debug) {
printf("[DEBUG - WebSockets] Server initiated disconnect, reconnecting in 2 seconds...\n");
printf("[DEBUG - WebSockets] Server initiated disconnect, reconnecting in 1 seconds...\n");
fflush(stdout);
}
sleep(2);
sleep(1);
}
// result == 2 (registration failed) also uses normal interval
// result == 2 (registration failed) also uses shorter interval
}
// Cleanup
......
......@@ -37,6 +37,7 @@
#include <sys/select.h>
#define BUFFER_SIZE 1048576
#define MAX_CHUNK_SIZE 65536
#define DEFAULT_PORT 22
// Config structures
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment