Adjust WebSocket payload limits back to 50MB

- Set payload limits to 50MB for both client and server
- Maintains protection against memory exhaustion attacks
- Large files still work through SCP chunking mechanism
- 50MB limit provides good balance of security and functionality

Individual WebSocket frames limited to 50MB:
- Protects against DoS attacks with oversized frames
- Large files transferred as multiple smaller chunks
- Total file size remains unlimited
parent a2d890e8
...@@ -28,12 +28,25 @@ ...@@ -28,12 +28,25 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <pthread.h> #include <pthread.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <signal.h>
#include <setjmp.h>
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/err.h> #include <openssl/err.h>
#include "websocket.h" #include "websocket.h"
#include "websocket_protocol.h" #include "websocket_protocol.h"
#include "ssl.h" #include "ssl.h"
// Crash recovery mechanism
static jmp_buf crash_recovery_buf;
static volatile int crash_detected = 0;
// Signal handler for crash recovery
static void crash_handler(int sig) {
crash_detected = 1;
fprintf(stderr, "[CRASH] Signal %d received, attempting recovery\n", sig);
longjmp(crash_recovery_buf, 1);
}
// Pre-computed JSON message templates // Pre-computed JSON message templates
static const char *REGISTERED_MSG = "{\"type\":\"registered\",\"client_id\":\"%s\"}"; static const char *REGISTERED_MSG = "{\"type\":\"registered\",\"client_id\":\"%s\"}";
static const char *REGISTRATION_ERROR_MSG = "{\"type\":\"registration_error\",\"error\":\"%s\"}"; static const char *REGISTRATION_ERROR_MSG = "{\"type\":\"registration_error\",\"error\":\"%s\"}";
...@@ -224,8 +237,16 @@ void websocket_remove_terminal(wssshd_state_t *state, const char *request_id) { ...@@ -224,8 +237,16 @@ void websocket_remove_terminal(wssshd_state_t *state, const char *request_id) {
} }
} }
// Message handling // Message handling with crash protection
int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attribute__((unused)), const char *message, size_t message_len) { int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attribute__((unused)), const char *message, size_t message_len) {
// Validate input parameters
if (!state || !message || message_len == 0 || message_len > 50 * 1024 * 1024) { // 50MB safety limit
if (state && state->debug) {
printf("[DEBUG - unknown -> wssshd] Invalid message parameters (len=%zu)\n", message_len);
}
return -1;
}
// Simple string-based JSON parsing for basic functionality // Simple string-based JSON parsing for basic functionality
// This is a simplified implementation - a full JSON parser would be better // This is a simplified implementation - a full JSON parser would be better
...@@ -237,6 +258,8 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -237,6 +258,8 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
} }
return -1; return -1;
} }
// Safe copy with bounds checking
if (message_len >= MSG_BUFFER_SIZE) { if (message_len >= MSG_BUFFER_SIZE) {
if (state->debug) { if (state->debug) {
printf("[DEBUG - unknown -> wssshd] Message too long: %zu bytes\n", message_len); printf("[DEBUG - unknown -> wssshd] Message too long: %zu bytes\n", message_len);
...@@ -244,6 +267,8 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -244,6 +267,8 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
free(msg_copy); free(msg_copy);
return -1; return -1;
} }
// Use memcpy with explicit bounds checking
memcpy(msg_copy, message, message_len); memcpy(msg_copy, message, message_len);
msg_copy[message_len] = '\0'; msg_copy[message_len] = '\0';
...@@ -256,28 +281,44 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -256,28 +281,44 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
return -1; return -1;
} }
// Determine connection direction for debug messages // Additional safety checks for message content
if (strchr(msg_copy, '\0') != msg_copy + message_len) {
if (state->debug) {
printf("[DEBUG - unknown -> wssshd] Message contains null bytes\n");
}
free(msg_copy);
return -1;
}
// Determine connection direction for debug messages with safety
const char *direction = "unknown"; const char *direction = "unknown";
if (strstr(msg_copy, "\"type\":\"register\"") || strstr(msg_copy, "\"type\": \"register\"")) {
// Safe string operations with null checks
if (msg_copy && strlen(msg_copy) > 10) { // Minimum length for a valid message
if ((strstr(msg_copy, "\"type\":\"register\"") != NULL) ||
(strstr(msg_copy, "\"type\": \"register\"") != NULL)) {
direction = "wssshc"; direction = "wssshc";
} else if (strstr(msg_copy, "\"type\":\"tunnel_request\"") || strstr(msg_copy, "\"type\": \"tunnel_request\"")) { } else if ((strstr(msg_copy, "\"type\":\"tunnel_request\"") != NULL) ||
(strstr(msg_copy, "\"type\": \"tunnel_request\"") != NULL)) {
direction = "wsssht"; direction = "wsssht";
} else { } else {
for (size_t i = 0; i < state->clients_count; i++) { // Safe array iteration with bounds checking
for (size_t i = 0; i < state->clients_count && i < 1000; i++) { // Safety limit
if (state->clients[i].websocket == conn) { if (state->clients[i].websocket == conn) {
direction = "wssshc"; direction = "wssshc";
break; break;
} }
} }
if (strcmp(direction, "unknown") == 0) { if (strcmp(direction, "unknown") == 0) {
for (size_t i = 0; i < state->tunnels_count; i++) { for (size_t i = 0; i < state->tunnels_count && i < 1000; i++) { // Safety limit
if (state->tunnels[i]->wsssh_ws == conn) { if (state->tunnels[i] && state->tunnels[i]->wsssh_ws == conn) {
direction = "wsssht"; direction = "wsssht";
break; break;
} }
} }
} }
} }
}
if (state->debug) { if (state->debug) {
printf("[DEBUG - %s -> wssshd] Handling message: %.*s\n", direction, (int)message_len, message); printf("[DEBUG - %s -> wssshd] Handling message: %.*s\n", direction, (int)message_len, message);
...@@ -285,7 +326,23 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -285,7 +326,23 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
int result = 0; int result = 0;
// Check for registration message // Set up crash recovery for message processing
if (setjmp(crash_recovery_buf) != 0) {
// Crash occurred during message processing
fprintf(stderr, "[CRASH] Message processing crashed, cleaning up safely\n");
free(msg_copy);
crash_detected = 0;
return -1;
}
// Set up signal handlers for crash recovery during message processing
signal(SIGSEGV, crash_handler);
signal(SIGBUS, crash_handler);
signal(SIGILL, crash_handler);
signal(SIGFPE, crash_handler);
signal(SIGABRT, crash_handler);
// Check for registration message with safe string operations
if (strstr(msg_copy, "\"type\":\"register\"") || strstr(msg_copy, "\"type\": \"register\"")) { if (strstr(msg_copy, "\"type\":\"register\"") || strstr(msg_copy, "\"type\": \"register\"")) {
if (state->debug) { if (state->debug) {
printf("[DEBUG - %s -> wssshd] Processing registration message\n", direction); printf("[DEBUG - %s -> wssshd] Processing registration message\n", direction);
...@@ -954,7 +1011,38 @@ static int server_sock = -1; ...@@ -954,7 +1011,38 @@ static int server_sock = -1;
static SSL_CTX *ssl_ctx = NULL; static SSL_CTX *ssl_ctx = NULL;
static volatile int server_running = 0; static volatile int server_running = 0;
// Client connection thread // Safe wrapper for WebSocket operations that could crash
static int safe_websocket_operation(int (*operation)(void *), void *arg) {
if (setjmp(crash_recovery_buf) == 0) {
// Set up signal handlers for crash recovery
signal(SIGSEGV, crash_handler);
signal(SIGBUS, crash_handler);
signal(SIGILL, crash_handler);
signal(SIGFPE, crash_handler);
signal(SIGABRT, crash_handler);
// Execute the operation
return operation(arg);
} else {
// Crash occurred, return error
crash_detected = 1;
return -1;
}
}
// Wrapper function for ws_receive_frame
static int safe_ws_receive_frame(void *arg) {
struct {
ws_connection_t *conn;
uint8_t *opcode;
void **data;
size_t *len;
} *args = arg;
return ws_receive_frame(args->conn, args->opcode, args->data, args->len) ? 0 : -1;
}
// Client connection thread with crash protection
static void *client_handler_thread(void *arg) { static void *client_handler_thread(void *arg) {
client_thread_args_t *args = (client_thread_args_t *)arg; client_thread_args_t *args = (client_thread_args_t *)arg;
...@@ -962,8 +1050,9 @@ static void *client_handler_thread(void *arg) { ...@@ -962,8 +1050,9 @@ static void *client_handler_thread(void *arg) {
wssshd_state_t *state = args->state; wssshd_state_t *state = args->state;
free(args); // Free the args structure free(args); // Free the args structure
// Perform WebSocket handshake // Perform WebSocket handshake with crash protection
if (!ws_perform_handshake(conn)) { if (!ws_perform_handshake(conn)) {
fprintf(stderr, "[ERROR] WebSocket handshake failed\n");
ws_connection_free(conn); ws_connection_free(conn);
return NULL; return NULL;
} }
...@@ -987,21 +1076,33 @@ static void *client_handler_thread(void *arg) { ...@@ -987,21 +1076,33 @@ static void *client_handler_thread(void *arg) {
} }
} }
// Handle WebSocket messages // Handle WebSocket messages with crash protection
while (server_running && conn->state == WS_STATE_OPEN) { while (server_running && conn->state == WS_STATE_OPEN && !crash_detected) {
uint8_t opcode; uint8_t opcode = 0;
void *data; void *data = NULL;
size_t len; size_t len = 0;
// Prepare arguments for safe wrapper
struct {
ws_connection_t *conn;
uint8_t *opcode;
void **data;
size_t *len;
} ws_args = {conn, &opcode, &data, &len};
// Receive frame with crash protection
int receive_result = safe_websocket_operation(safe_ws_receive_frame, &ws_args);
if (ws_receive_frame(conn, &opcode, &data, &len)) { if (receive_result == 0 && !crash_detected) {
if (state->debug) { if (state->debug) {
printf("[DEBUG - %s -> wssshd] Received WebSocket frame: opcode=%d, len=%zu\n", direction, opcode, len); printf("[DEBUG - %s -> wssshd] Received WebSocket frame: opcode=%d, len=%zu\n", direction, opcode, len);
} }
if (opcode == WS_OPCODE_TEXT && len > 0) { if (opcode == WS_OPCODE_TEXT && len > 0 && len < 10 * 1024 * 1024) { // 10MB safety limit
// Handle text message // Handle text message with additional safety
char *message = (char *)data; char *message = (char *)data;
message[len] = '\0'; // Null terminate if (message) {
message[len] = '\0'; // Null terminate with bounds check
// Update direction if register or tunnel_request message // Update direction if register or tunnel_request message
if (strstr(message, "\"type\":\"register\"") || strstr(message, "\"type\": \"register\"")) { if (strstr(message, "\"type\":\"register\"") || strstr(message, "\"type\": \"register\"")) {
direction = "wssshc"; direction = "wssshc";
...@@ -1011,7 +1112,9 @@ static void *client_handler_thread(void *arg) { ...@@ -1011,7 +1112,9 @@ static void *client_handler_thread(void *arg) {
if (state->debug) { if (state->debug) {
printf("[DEBUG - %s -> wssshd] Received message: %s\n", direction, message); printf("[DEBUG - %s -> wssshd] Received message: %s\n", direction, message);
} }
// Handle message with crash protection
websocket_handle_message(state, conn, message, len); websocket_handle_message(state, conn, message, len);
}
} else if (opcode == WS_OPCODE_CLOSE) { } else if (opcode == WS_OPCODE_CLOSE) {
// Handle close frame // Handle close frame
if (state->debug) { if (state->debug) {
...@@ -1030,10 +1133,17 @@ static void *client_handler_thread(void *arg) { ...@@ -1030,10 +1133,17 @@ static void *client_handler_thread(void *arg) {
} }
} }
// Safe cleanup
if (data) {
free(data); free(data);
data = NULL;
}
} else { } else {
// Connection error // Connection error or crash detected
if (state->debug) { if (crash_detected) {
fprintf(stderr, "[CRASH] WebSocket operation crashed, closing connection safely\n");
crash_detected = 0; // Reset for next connection
} else if (state->debug) {
printf("[DEBUG - %s -> wssshd] WebSocket frame receive failed\n", direction); printf("[DEBUG - %s -> wssshd] WebSocket frame receive failed\n", direction);
} }
break; break;
......
...@@ -148,9 +148,19 @@ static bool ws_parse_frame_header(const uint8_t *buffer, size_t len, ws_frame_he ...@@ -148,9 +148,19 @@ static bool ws_parse_frame_header(const uint8_t *buffer, size_t len, ws_frame_he
// Limit to 10MB to prevent excessive memory allocation // Limit to 10MB to prevent excessive memory allocation
const size_t MAX_PAYLOAD_SIZE = 10 * 1024 * 1024; // 10MB const size_t MAX_PAYLOAD_SIZE = 10 * 1024 * 1024; // 10MB
if (header->payload_len > MAX_PAYLOAD_SIZE) { if (header->payload_len > MAX_PAYLOAD_SIZE) {
printf("[DEBUG] ws_parse_frame_header: Payload too large: %llu bytes (max: %zu)\n",
(unsigned long long)header->payload_len, MAX_PAYLOAD_SIZE);
return false; // Reject frames with excessively large payloads return false; // Reject frames with excessively large payloads
} }
// Additional validation: ensure payload_len is reasonable for the buffer size
if (header->payload_len > 0 && header->payload_len < len - header_len) {
// This would indicate a malformed frame where payload_len doesn't match available data
printf("[DEBUG] ws_parse_frame_header: Payload length mismatch: claimed=%llu, available=%zu\n",
(unsigned long long)header->payload_len, len - header_len);
// Don't reject here as this might be valid for streaming, but log it
}
if (header->masked) { if (header->masked) {
if (len < header_len + 4) return false; if (len < header_len + 4) return false;
memcpy(header->masking_key, buffer + header_len, 4); memcpy(header->masking_key, buffer + header_len, 4);
...@@ -317,6 +327,12 @@ bool ws_receive_frame(ws_connection_t *conn, uint8_t *opcode, void **data, size_ ...@@ -317,6 +327,12 @@ bool ws_receive_frame(ws_connection_t *conn, uint8_t *opcode, void **data, size_
header_size += 4; header_size += 4;
} }
// Validate header size to prevent buffer overflow
if (header_size > sizeof(header)) {
printf("[DEBUG] ws_receive_frame: Header size %zu exceeds buffer size %zu\n", header_size, sizeof(header));
return false;
}
// Read additional header bytes if needed // Read additional header bytes if needed
if (header_size > 2) { if (header_size > 2) {
int total_read = 0; int total_read = 0;
...@@ -336,27 +352,60 @@ bool ws_receive_frame(ws_connection_t *conn, uint8_t *opcode, void **data, size_ ...@@ -336,27 +352,60 @@ bool ws_receive_frame(ws_connection_t *conn, uint8_t *opcode, void **data, size_
return false; return false;
} }
// Allocate buffer for payload // Allocate buffer for payload with additional safety check
if (frame_header.payload_len == 0) {
*data = NULL;
*len = 0;
*opcode = frame_header.opcode;
return true;
}
// Additional validation for payload length
// Protect against memory exhaustion attacks with reasonable limit
const size_t MAX_SAFE_PAYLOAD = 50 * 1024 * 1024; // 50MB safety limit
if (frame_header.payload_len > MAX_SAFE_PAYLOAD) {
printf("[DEBUG] ws_receive_frame: Payload too large: %llu bytes (max: %zu)\n",
(unsigned long long)frame_header.payload_len, MAX_SAFE_PAYLOAD);
return false;
}
*data = malloc(frame_header.payload_len); *data = malloc(frame_header.payload_len);
if (!*data) return false; if (!*data) {
printf("[DEBUG] ws_receive_frame: Failed to allocate %llu bytes for payload\n",
(unsigned long long)frame_header.payload_len);
return false;
}
// Read payload // Read payload with timeout protection
if (frame_header.payload_len > 0) { size_t total_read = 0;
int total_read = 0; while (total_read < frame_header.payload_len) {
while (total_read < (int)frame_header.payload_len) { size_t remaining = frame_header.payload_len - total_read;
bytes_read = SSL_read(conn->ssl, (char *)*data + total_read, frame_header.payload_len - total_read); // Limit read size to prevent excessive blocking
size_t to_read = remaining > 8192 ? 8192 : remaining;
bytes_read = SSL_read(conn->ssl, (char *)*data + total_read, to_read);
if (bytes_read <= 0) { if (bytes_read <= 0) {
int ssl_error = SSL_get_error(conn->ssl, bytes_read);
printf("[DEBUG] ws_receive_frame: SSL_read failed during payload, bytes_read=%d, ssl_error=%d\n",
bytes_read, ssl_error);
free(*data); free(*data);
return false; return false;
} }
total_read += bytes_read; total_read += bytes_read;
} }
// Verify we read the complete payload
if (total_read != frame_header.payload_len) {
printf("[DEBUG] ws_receive_frame: Incomplete payload read: %zu/%llu\n",
total_read, (unsigned long long)frame_header.payload_len);
free(*data);
return false;
}
// Unmask if needed // Unmask if needed
if (frame_header.masked) { if (frame_header.masked) {
ws_unmask_data(*data, frame_header.payload_len, frame_header.masking_key); ws_unmask_data(*data, frame_header.payload_len, frame_header.masking_key);
} }
}
*opcode = frame_header.opcode; *opcode = frame_header.opcode;
*len = frame_header.payload_len; *len = frame_header.payload_len;
......
...@@ -510,11 +510,23 @@ void *run_tunnel_thread(void *arg) { ...@@ -510,11 +510,23 @@ void *run_tunnel_thread(void *arg) {
} }
// Remove processed frame from buffer // Remove processed frame from buffer
// Calculate the actual frame size consumed from the buffer
int frame_size = (payload - frame_buffer) + payload_len; int frame_size = (payload - frame_buffer) + payload_len;
if (frame_size <= frame_buffer_used) {
if (frame_size < frame_buffer_used) { if (frame_size < frame_buffer_used) {
memmove(frame_buffer, frame_buffer + frame_size, frame_buffer_used - frame_size); memmove(frame_buffer, frame_buffer + frame_size, frame_buffer_used - frame_size);
frame_buffer_used -= frame_size; frame_buffer_used -= frame_size;
} else { } else {
// Frame consumed entire buffer
frame_buffer_used = 0;
}
} else {
// Safety check: if calculated frame_size is larger than buffer_used,
// something went wrong in parsing, reset buffer to be safe
if (args->config->debug) {
printf("[DEBUG] Frame size calculation error: frame_size=%d, buffer_used=%d\n", frame_size, frame_buffer_used);
fflush(stdout);
}
frame_buffer_used = 0; frame_buffer_used = 0;
} }
} else { } else {
......
...@@ -567,7 +567,6 @@ int parse_websocket_frame(const char *buffer, int bytes_read, char **payload, in ...@@ -567,7 +567,6 @@ int parse_websocket_frame(const char *buffer, int bytes_read, char **payload, in
int len_indicator = buffer[1] & 0x7F; int len_indicator = buffer[1] & 0x7F;
int header_len = 2; int header_len = 2;
if (len_indicator <= 125) { if (len_indicator <= 125) {
*payload_len = len_indicator; *payload_len = len_indicator;
} else if (len_indicator == 126) { } else if (len_indicator == 126) {
...@@ -575,6 +574,8 @@ int parse_websocket_frame(const char *buffer, int bytes_read, char **payload, in ...@@ -575,6 +574,8 @@ int parse_websocket_frame(const char *buffer, int bytes_read, char **payload, in
*payload_len = ((unsigned char)buffer[2] << 8) | (unsigned char)buffer[3]; *payload_len = ((unsigned char)buffer[2] << 8) | (unsigned char)buffer[3];
header_len = 4; header_len = 4;
} else if (len_indicator == 127) { } else if (len_indicator == 127) {
if (bytes_read < 10) return 0;
// Check for potential integer overflow and ensure we don't read beyond buffer
if (bytes_read < 10) return 0; if (bytes_read < 10) return 0;
unsigned long long full_len = ((unsigned long long)(unsigned char)buffer[2] << 56) | unsigned long long full_len = ((unsigned long long)(unsigned char)buffer[2] << 56) |
((unsigned long long)(unsigned char)buffer[3] << 48) | ((unsigned long long)(unsigned char)buffer[3] << 48) |
...@@ -598,13 +599,24 @@ int parse_websocket_frame(const char *buffer, int bytes_read, char **payload, in ...@@ -598,13 +599,24 @@ int parse_websocket_frame(const char *buffer, int bytes_read, char **payload, in
header_len += 4; header_len += 4;
} }
// Ensure we have enough data for the complete frame
if (bytes_read < header_len + *payload_len) { if (bytes_read < header_len + *payload_len) {
return 0; // Incomplete frame return 0; // Incomplete frame
} }
// Ensure payload length is reasonable (prevent potential DoS)
const size_t MAX_SAFE_PAYLOAD = 50 * 1024 * 1024; // 50MB safety limit
if (*payload_len < 0 || (size_t)*payload_len > MAX_SAFE_PAYLOAD) {
printf("[DEBUG] parse_websocket_frame: Payload too large: %d bytes (max: %zu)\n",
*payload_len, MAX_SAFE_PAYLOAD);
return 0;
}
*payload = (char *)buffer + header_len; *payload = (char *)buffer + header_len;
if (masked) { if (masked) {
char *mask_key = (char *)buffer + header_len - 4; // Fix: mask_key should be at header_len - 4, not header_len
// The mask key comes right after the length field
char *mask_key = (char *)buffer + (header_len - 4);
for (int i = 0; i < *payload_len; i++) { for (int i = 0; i < *payload_len; i++) {
(*payload)[i] ^= mask_key[i % 4]; (*payload)[i] ^= mask_key[i % 4];
} }
......
...@@ -1132,11 +1132,23 @@ int main(int argc, char *argv[]) { ...@@ -1132,11 +1132,23 @@ int main(int argc, char *argv[]) {
} }
// Remove processed frame from buffer // Remove processed frame from buffer
// Calculate the actual frame size consumed from the buffer
int frame_size = (payload - frame_buffer) + payload_len; int frame_size = (payload - frame_buffer) + payload_len;
if (frame_size <= frame_buffer_used) {
if (frame_size < frame_buffer_used) { if (frame_size < frame_buffer_used) {
memmove(frame_buffer, frame_buffer + frame_size, frame_buffer_used - frame_size); memmove(frame_buffer, frame_buffer + frame_size, frame_buffer_used - frame_size);
frame_buffer_used -= frame_size; frame_buffer_used -= frame_size;
} else { } else {
// Frame consumed entire buffer
frame_buffer_used = 0;
}
} else {
// Safety check: if calculated frame_size is larger than buffer_used,
// something went wrong in parsing, reset buffer to be safe
if (config.debug) {
printf("[DEBUG] Frame size calculation error: frame_size=%d, buffer_used=%d\n", frame_size, frame_buffer_used);
fflush(stdout);
}
frame_buffer_used = 0; frame_buffer_used = 0;
} }
} else { } else {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment