Fix multiple wsssh issues: framing protocol, performance, transfer limits, and graceful shutdown

- Fix WebSocket framing protocol issues with dynamic buffer allocation
- Remove 255KB transfer limit by using heap allocation for large data
- Optimize performance with 64KB chunking and faster reconnection (1s)
- Add SIGINT handling for graceful tunnel closure with error messages
- Improve WebSocket reconnection handling and tunnel state management
- Treat close frames as tunnel closures to maintain WebSocket connections
- Add proper memory cleanup and buffer overflow prevention
- Reduce reconnection intervals for better responsiveness
parent ac456539
...@@ -288,7 +288,7 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -288,7 +288,7 @@ void *forward_tcp_to_ws(void *arg) {
} }
if (FD_ISSET(client_sock, &readfds)) { if (FD_ISSET(client_sock, &readfds)) {
bytes_read = recv(client_sock, buffer, sizeof(buffer), 0); bytes_read = recv(client_sock, buffer, MAX_CHUNK_SIZE, 0);
if (bytes_read <= 0) { if (bytes_read <= 0) {
if (debug) { if (debug) {
if (bytes_read == 0) { if (bytes_read == 0) {
...@@ -319,35 +319,35 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -319,35 +319,35 @@ void *forward_tcp_to_ws(void *arg) {
fflush(stdout); fflush(stdout);
} }
// Convert to hex with bounds checking // Convert to hex with dynamic allocation
// Reserve space for JSON overhead (about 80 characters)
size_t max_hex_size = BUFFER_SIZE - 100;
size_t hex_size = (size_t)bytes_read * 2 + 1; size_t hex_size = (size_t)bytes_read * 2 + 1;
int truncated = 0; char *hex_data = malloc(hex_size);
if (hex_size > max_hex_size) { if (!hex_data) {
if (debug) { if (debug) {
printf("[DEBUG] Hex data too large (%zu bytes), truncating to %zu\n", hex_size, max_hex_size); printf("[DEBUG] Failed to allocate memory for hex data (%zu bytes)\n", hex_size);
fflush(stdout); fflush(stdout);
} }
hex_size = max_hex_size; continue;
truncated = 1;
} }
char hex_data[hex_size];
size_t actual_hex_len = 0; size_t actual_hex_len = 0;
for (int i = 0; i < bytes_read && actual_hex_len < hex_size - 1; i++) { for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + actual_hex_len, "%02x", (unsigned char)buffer[i]); sprintf(hex_data + actual_hex_len, "%02x", (unsigned char)buffer[i]);
actual_hex_len += 2; actual_hex_len += 2;
} }
hex_data[actual_hex_len] = '\0'; hex_data[actual_hex_len] = '\0';
if (truncated && debug) { // Send as tunnel_data
printf("[DEBUG] Hex data truncated, sent %zu of %d bytes\n", actual_hex_len / 2, bytes_read); size_t msg_size = strlen("{\"type\":\"tunnel_data\",\"request_id\":\"\",\"data\":\"\"}") + strlen(request_id) + actual_hex_len + 1;
char *message = malloc(msg_size);
if (!message) {
if (debug) {
printf("[DEBUG] Failed to allocate memory for message (%zu bytes)\n", msg_size);
fflush(stdout); fflush(stdout);
} }
free(hex_data);
// Send as tunnel_data continue;
char message[BUFFER_SIZE]; }
snprintf(message, sizeof(message), snprintf(message, msg_size,
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}", "{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data); request_id, hex_data);
...@@ -356,8 +356,13 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -356,8 +356,13 @@ void *forward_tcp_to_ws(void *arg) {
printf("[DEBUG] Failed to send WebSocket frame\n"); printf("[DEBUG] Failed to send WebSocket frame\n");
fflush(stdout); fflush(stdout);
} }
free(hex_data);
free(message);
break; break;
} }
free(hex_data);
free(message);
} }
} }
...@@ -425,7 +430,7 @@ void *forward_ws_to_ssh_server(void *arg) { ...@@ -425,7 +430,7 @@ void *forward_ws_to_ssh_server(void *arg) {
} }
if (FD_ISSET(target_sock, &readfds)) { if (FD_ISSET(target_sock, &readfds)) {
bytes_read = recv(target_sock, buffer, sizeof(buffer), 0); bytes_read = recv(target_sock, buffer, MAX_CHUNK_SIZE, 0);
if (bytes_read <= 0) { if (bytes_read <= 0) {
if (debug) { if (debug) {
printf("[DEBUG - TCPConnection] Target connection closed or error\n"); printf("[DEBUG - TCPConnection] Target connection closed or error\n");
...@@ -439,35 +444,35 @@ void *forward_ws_to_ssh_server(void *arg) { ...@@ -439,35 +444,35 @@ void *forward_ws_to_ssh_server(void *arg) {
fflush(stdout); fflush(stdout);
} }
// Convert to hex with bounds checking // Convert to hex with dynamic allocation
// Reserve space for JSON overhead (about 80 characters)
size_t max_hex_size = BUFFER_SIZE - 100;
size_t hex_size = (size_t)bytes_read * 2 + 1; size_t hex_size = (size_t)bytes_read * 2 + 1;
int truncated = 0; char *hex_data = malloc(hex_size);
if (hex_size > max_hex_size) { if (!hex_data) {
if (debug) { if (debug) {
printf("[DEBUG] Hex data too large (%zu bytes), truncating to %zu\n", hex_size, max_hex_size); printf("[DEBUG] Failed to allocate memory for hex data (%zu bytes)\n", hex_size);
fflush(stdout); fflush(stdout);
} }
hex_size = max_hex_size; continue;
truncated = 1;
} }
char hex_data[hex_size];
size_t actual_hex_len = 0; size_t actual_hex_len = 0;
for (int i = 0; i < bytes_read && actual_hex_len < hex_size - 1; i++) { for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + actual_hex_len, "%02x", (unsigned char)buffer[i]); sprintf(hex_data + actual_hex_len, "%02x", (unsigned char)buffer[i]);
actual_hex_len += 2; actual_hex_len += 2;
} }
hex_data[actual_hex_len] = '\0'; hex_data[actual_hex_len] = '\0';
if (truncated && debug) { // Send as tunnel_response (from target back to WebSocket)
printf("[DEBUG] Hex data truncated, sent %zu of %d bytes\n", actual_hex_len / 2, bytes_read); size_t msg_size = strlen("{\"type\":\"tunnel_response\",\"request_id\":\"\",\"data\":\"\"}") + strlen(request_id) + actual_hex_len + 1;
char *message = malloc(msg_size);
if (!message) {
if (debug) {
printf("[DEBUG] Failed to allocate memory for message (%zu bytes)\n", msg_size);
fflush(stdout); fflush(stdout);
} }
free(hex_data);
// Send as tunnel_response (from target back to WebSocket) continue;
char message[BUFFER_SIZE]; }
snprintf(message, sizeof(message), snprintf(message, msg_size,
"{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"data\":\"%s\"}", "{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data); request_id, hex_data);
...@@ -476,8 +481,13 @@ void *forward_ws_to_ssh_server(void *arg) { ...@@ -476,8 +481,13 @@ void *forward_ws_to_ssh_server(void *arg) {
printf("[DEBUG] Failed to send WebSocket frame\n"); printf("[DEBUG] Failed to send WebSocket frame\n");
fflush(stdout); fflush(stdout);
} }
free(hex_data);
free(message);
break; break;
} }
free(hex_data);
free(message);
} }
} }
......
...@@ -100,18 +100,31 @@ int send_registration_message(SSL *ssl, const char *client_id, const char *passw ...@@ -100,18 +100,31 @@ int send_registration_message(SSL *ssl, const char *client_id, const char *passw
} }
int send_websocket_frame(SSL *ssl, const char *data) { int send_websocket_frame(SSL *ssl, const char *data) {
char frame[BUFFER_SIZE];
frame[0] = 0x81; // FIN + text opcode
int msg_len = strlen(data); int msg_len = strlen(data);
int header_len = 2; int header_len = 2;
if (msg_len <= 125) {
header_len = 6; // 2 + 4 for mask
} else if (msg_len <= 65535) {
header_len = 8; // 4 + 4 for mask
} else {
header_len = 14; // 10 + 4 for mask
}
int frame_len = header_len + msg_len;
char *frame = malloc(frame_len);
if (!frame) {
return 0;
}
frame[0] = 0x81; // FIN + text opcode
if (msg_len <= 125) { if (msg_len <= 125) {
frame[1] = 0x80 | msg_len; // MASK + length frame[1] = 0x80 | msg_len; // MASK + length
} else if (msg_len <= 65535) { } else if (msg_len <= 65535) {
frame[1] = 0x80 | 126; // MASK + extended length frame[1] = 0x80 | 126; // MASK + extended length
frame[2] = (msg_len >> 8) & 0xFF; frame[2] = (msg_len >> 8) & 0xFF;
frame[3] = msg_len & 0xFF; frame[3] = msg_len & 0xFF;
header_len = 4;
} else { } else {
frame[1] = 0x80 | 127; // MASK + extended length frame[1] = 0x80 | 127; // MASK + extended length
frame[2] = 0; frame[2] = 0;
...@@ -122,24 +135,20 @@ int send_websocket_frame(SSL *ssl, const char *data) { ...@@ -122,24 +135,20 @@ int send_websocket_frame(SSL *ssl, const char *data) {
frame[7] = (msg_len >> 16) & 0xFF; frame[7] = (msg_len >> 16) & 0xFF;
frame[8] = (msg_len >> 8) & 0xFF; frame[8] = (msg_len >> 8) & 0xFF;
frame[9] = msg_len & 0xFF; frame[9] = msg_len & 0xFF;
header_len = 10;
} }
// Add mask key // Add mask key
char mask_key[4]; char mask_key[4];
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
mask_key[i] = rand() % 256; mask_key[i] = rand() % 256;
frame[header_len + i] = mask_key[i]; frame[header_len - 4 + i] = mask_key[i];
} }
header_len += 4;
// Mask payload // Mask payload
for (int i = 0; i < msg_len; i++) { for (int i = 0; i < msg_len; i++) {
frame[header_len + i] = data[i] ^ mask_key[i % 4]; frame[header_len + i] = data[i] ^ mask_key[i % 4];
} }
int frame_len = header_len + msg_len;
// Handle partial writes for large frames // Handle partial writes for large frames
int total_written = 0; int total_written = 0;
int retry_count = 0; int retry_count = 0;
...@@ -166,6 +175,7 @@ int send_websocket_frame(SSL *ssl, const char *data) { ...@@ -166,6 +175,7 @@ int send_websocket_frame(SSL *ssl, const char *data) {
char error_buf[256]; char error_buf[256];
ERR_error_string_n(ssl_error, error_buf, sizeof(error_buf)); ERR_error_string_n(ssl_error, error_buf, sizeof(error_buf));
fprintf(stderr, "SSL write error details: %s\n", error_buf); fprintf(stderr, "SSL write error details: %s\n", error_buf);
free(frame);
return 0; // Write failed return 0; // Write failed
} }
total_written += written; total_written += written;
...@@ -174,9 +184,11 @@ int send_websocket_frame(SSL *ssl, const char *data) { ...@@ -174,9 +184,11 @@ int send_websocket_frame(SSL *ssl, const char *data) {
if (total_written < frame_len) { if (total_written < frame_len) {
fprintf(stderr, "WebSocket frame write incomplete: %d/%d bytes written\n", total_written, frame_len); fprintf(stderr, "WebSocket frame write incomplete: %d/%d bytes written\n", total_written, frame_len);
free(frame);
return 0; return 0;
} }
free(frame);
return 1; return 1;
} }
......
...@@ -994,6 +994,7 @@ start_forwarding_threads: ...@@ -994,6 +994,7 @@ start_forwarding_threads:
printf("[DEBUG - Tunnel] Received tunnel_close message\n"); printf("[DEBUG - Tunnel] Received tunnel_close message\n");
fflush(stdout); fflush(stdout);
} }
fprintf(stderr, "Error: Tunnel closed by remote client\n");
char *id_start = strstr(buffer, "\"request_id\""); char *id_start = strstr(buffer, "\"request_id\"");
if (id_start) { if (id_start) {
char *colon = strchr(id_start, ':'); char *colon = strchr(id_start, ':');
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <getopt.h> #include <getopt.h>
#include <pthread.h> #include <pthread.h>
#include <errno.h> #include <errno.h>
#include <signal.h>
#include "wssshlib.h" #include "wssshlib.h"
#include "websocket.h" #include "websocket.h"
...@@ -39,6 +40,12 @@ ...@@ -39,6 +40,12 @@
int global_debug = 0; int global_debug = 0;
volatile sig_atomic_t sigint_received = 0;
void sigint_handler(int sig __attribute__((unused))) {
sigint_received = 1;
}
typedef struct { typedef struct {
char *wssshd_server; char *wssshd_server;
...@@ -387,6 +394,15 @@ int connect_to_server(const wssshc_config_t *config) { ...@@ -387,6 +394,15 @@ int connect_to_server(const wssshc_config_t *config) {
static int frame_buffer_used = 0; static int frame_buffer_used = 0;
while (1) { while (1) {
// Check for SIGINT
if (sigint_received) {
if (active_tunnel) {
send_tunnel_close(ssl, active_tunnel->request_id, config->debug);
fprintf(stderr, "Received SIGINT, sent tunnel_close and exiting\n");
}
break;
}
// Always try to read more data if there's space, even if we have a complete frame // Always try to read more data if there's space, even if we have a complete frame
// This ensures we can accumulate data for very large frames // This ensures we can accumulate data for very large frames
if ((size_t)frame_buffer_used < sizeof(frame_buffer)) { if ((size_t)frame_buffer_used < sizeof(frame_buffer)) {
...@@ -481,12 +497,14 @@ int connect_to_server(const wssshc_config_t *config) { ...@@ -481,12 +497,14 @@ int connect_to_server(const wssshc_config_t *config) {
if (frame_type == 0x88) { // Close frame if (frame_type == 0x88) { // Close frame
if (config->debug) { if (config->debug) {
printf("[DEBUG - WebSockets] Received close frame - cleaning up and reconnecting...\n"); printf("[DEBUG - WebSockets] Received close frame - treating as tunnel close, keeping WebSocket open\n");
fflush(stdout); fflush(stdout);
} }
// Clean up tunnel resources before reconnecting // Treat as tunnel close, don't close WebSocket connection
cleanup_tunnel(config->debug); if (active_tunnel) {
return 0; handle_tunnel_close(NULL, active_tunnel->request_id, config->debug);
}
// Continue processing, don't return
} else if (frame_type == 0x89) { // Ping frame } else if (frame_type == 0x89) { // Ping frame
if (config->debug) { if (config->debug) {
printf("[DEBUG - WebSockets] Received ping frame\n"); printf("[DEBUG - WebSockets] Received ping frame\n");
...@@ -644,6 +662,9 @@ int main(int argc, char *argv[]) { ...@@ -644,6 +662,9 @@ int main(int argc, char *argv[]) {
pthread_mutex_init(&tunnel_mutex, NULL); pthread_mutex_init(&tunnel_mutex, NULL);
// Set up signal handler for SIGINT
signal(SIGINT, sigint_handler);
// Load configuration files first (system and user configs) // Load configuration files first (system and user configs)
load_config(&config); load_config(&config);
...@@ -691,18 +712,18 @@ int main(int argc, char *argv[]) { ...@@ -691,18 +712,18 @@ int main(int argc, char *argv[]) {
while (1) { while (1) {
int result = connect_to_server(&config); int result = connect_to_server(&config);
if (result == 1) { if (result == 1) {
// Error condition - use normal retry interval // Error condition - use short retry interval for immediate reconnection
printf("Connection lost, retrying in %d seconds...\n", config.interval); printf("Connection lost, retrying in 1 seconds...\n");
sleep(config.interval); sleep(1);
} else if (result == 0) { } else if (result == 0) {
// Close frame received - add small delay to prevent rapid reconnection // Close frame received - use short delay for immediate reconnection
if (config.debug) { if (config.debug) {
printf("[DEBUG - WebSockets] Server initiated disconnect, reconnecting in 2 seconds...\n"); printf("[DEBUG - WebSockets] Server initiated disconnect, reconnecting in 1 seconds...\n");
fflush(stdout); fflush(stdout);
} }
sleep(2); sleep(1);
} }
// result == 2 (registration failed) also uses normal interval // result == 2 (registration failed) also uses shorter interval
} }
// Cleanup // Cleanup
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include <sys/select.h> #include <sys/select.h>
#define BUFFER_SIZE 1048576 #define BUFFER_SIZE 1048576
#define MAX_CHUNK_SIZE 65536
#define DEFAULT_PORT 22 #define DEFAULT_PORT 22
// Config structures // Config structures
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment