Critical security fix: Prevent wssshd server crashes from malformed packets

- Added comprehensive bounds checking to all WebSocket message parsing
- Validate JSON structure (braces) before processing to prevent crashes
- Added length limits and bounds validation for all parameter extractions:
  * client_id: max 64 chars
  * password: max 256 chars
  * request_id: max 64 chars
  * enc/service/version: max 32 chars each
- Prevent buffer overflows that could corrupt heap metadata
- Ensure all string operations stay within allocated buffer bounds
- Server now logs errors and continues running instead of crashing on malformed packets
- Critical defense against DoS attacks via malformed WebSocket messages
parent d26c949e
...@@ -247,6 +247,15 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -247,6 +247,15 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(msg_copy, message, message_len); memcpy(msg_copy, message, message_len);
msg_copy[message_len] = '\0'; msg_copy[message_len] = '\0';
// Validate basic JSON structure to prevent crashes from malformed messages
if (message_len < 2 || message[0] != '{' || message[message_len-1] != '}') {
if (state->debug) {
printf("[DEBUG - unknown -> wssshd] Malformed JSON message (invalid braces)\n");
}
free(msg_copy);
return -1;
}
// Determine connection direction for debug messages // Determine connection direction for debug messages
const char *direction = "unknown"; const char *direction = "unknown";
if (strstr(msg_copy, "\"type\":\"register\"") || strstr(msg_copy, "\"type\": \"register\"")) { if (strstr(msg_copy, "\"type\":\"register\"") || strstr(msg_copy, "\"type\": \"register\"")) {
...@@ -296,8 +305,9 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -296,8 +305,9 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (client_id_start) { if (client_id_start) {
client_id_start += strlen(client_id_start == strstr(msg_copy, "\"client_id\":\"") ? "\"client_id\":\"" : "\"id\":\""); client_id_start += strlen(client_id_start == strstr(msg_copy, "\"client_id\":\"") ? "\"client_id\":\"" : "\"id\":\"");
char *client_id_end = strchr(client_id_start, '"'); char *client_id_end = strchr(client_id_start, '"');
if (client_id_end) { if (client_id_end && client_id_end > client_id_start && client_id_end < msg_copy + MSG_BUFFER_SIZE) {
size_t client_id_len = client_id_end - client_id_start; size_t client_id_len = client_id_end - client_id_start;
if (client_id_len > 0 && client_id_len < 64) { // Reasonable limit for client ID
char *client_id_copy = malloc(client_id_len + 1); char *client_id_copy = malloc(client_id_len + 1);
if (client_id_copy) { if (client_id_copy) {
memcpy(client_id_copy, client_id_start, client_id_len); memcpy(client_id_copy, client_id_start, client_id_len);
...@@ -305,7 +315,14 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -305,7 +315,14 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
client_id = client_id_copy; client_id = client_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted client_id: '%s'\n", direction, client_id); if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted client_id: '%s'\n", direction, client_id);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID length invalid: %zu\n", direction, client_id_len);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID end quote not found or invalid\n", direction);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID start not found\n", direction);
} }
// Extract password (search in original unmodified string) // Extract password (search in original unmodified string)
...@@ -313,8 +330,9 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -313,8 +330,9 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (password_start) { if (password_start) {
password_start += strlen("\"password\":\""); password_start += strlen("\"password\":\"");
char *password_end = strchr(password_start, '"'); char *password_end = strchr(password_start, '"');
if (password_end) { if (password_end && password_end > password_start && password_end < msg_copy + MSG_BUFFER_SIZE) {
size_t password_len = password_end - password_start; size_t password_len = password_end - password_start;
if (password_len > 0 && password_len < 256) { // Reasonable limit for password
char *password_copy = malloc(password_len + 1); char *password_copy = malloc(password_len + 1);
if (password_copy) { if (password_copy) {
memcpy(password_copy, password_start, password_len); memcpy(password_copy, password_start, password_len);
...@@ -323,10 +341,13 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -323,10 +341,13 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted password: '%s'\n", direction, password); if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted password: '%s'\n", direction, password);
} }
} else { } else {
if (state->debug) printf("[DEBUG] Password end quote not found\n"); if (state->debug) printf("[DEBUG - %s -> wssshd] Password length invalid: %zu\n", direction, password_len);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Password end quote not found or invalid\n", direction);
} }
} else { } else {
if (state->debug) printf("[DEBUG] Password start not found\n"); if (state->debug) printf("[DEBUG - %s -> wssshd] Password start not found\n", direction);
} }
if (state->debug) { if (state->debug) {
...@@ -382,8 +403,9 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -382,8 +403,9 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (tunnel_client_id_start) { if (tunnel_client_id_start) {
tunnel_client_id_start += strlen("\"client_id\":\""); tunnel_client_id_start += strlen("\"client_id\":\"");
char *tunnel_client_id_end = strchr(tunnel_client_id_start, '"'); char *tunnel_client_id_end = strchr(tunnel_client_id_start, '"');
if (tunnel_client_id_end) { if (tunnel_client_id_end && tunnel_client_id_end > tunnel_client_id_start && tunnel_client_id_end < msg_copy + MSG_BUFFER_SIZE) {
size_t tunnel_client_id_len = tunnel_client_id_end - tunnel_client_id_start; size_t tunnel_client_id_len = tunnel_client_id_end - tunnel_client_id_start;
if (tunnel_client_id_len > 0 && tunnel_client_id_len < 64) { // Reasonable limit for client ID
char *tunnel_client_id_copy = malloc(tunnel_client_id_len + 1); char *tunnel_client_id_copy = malloc(tunnel_client_id_len + 1);
if (tunnel_client_id_copy) { if (tunnel_client_id_copy) {
memcpy(tunnel_client_id_copy, tunnel_client_id_start, tunnel_client_id_len); memcpy(tunnel_client_id_copy, tunnel_client_id_start, tunnel_client_id_len);
...@@ -391,13 +413,16 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -391,13 +413,16 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
client_id = tunnel_client_id_copy; client_id = tunnel_client_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted tunnel client_id: '%s' (ptr=%p)\n", direction, client_id, client_id); if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted tunnel client_id: '%s' (ptr=%p)\n", direction, client_id, client_id);
} else { } else {
if (state->debug) printf("[DEBUG] Failed to allocate memory for client_id\n"); if (state->debug) printf("[DEBUG - %s -> wssshd] Failed to allocate memory for client_id\n", direction);
} }
} else { } else {
if (state->debug) printf("[DEBUG] Client ID end quote not found\n"); if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID length invalid: %zu\n", direction, tunnel_client_id_len);
} }
} else { } else {
if (state->debug) printf("[DEBUG] Client ID start not found\n"); if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID end quote not found or invalid\n", direction);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID start not found\n", direction);
} }
// Extract request_id (search in original unmodified string) // Extract request_id (search in original unmodified string)
...@@ -405,8 +430,9 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -405,8 +430,9 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (tunnel_request_id_start) { if (tunnel_request_id_start) {
tunnel_request_id_start += strlen("\"request_id\":\""); tunnel_request_id_start += strlen("\"request_id\":\"");
char *tunnel_request_id_end = strchr(tunnel_request_id_start, '"'); char *tunnel_request_id_end = strchr(tunnel_request_id_start, '"');
if (tunnel_request_id_end) { if (tunnel_request_id_end && tunnel_request_id_end > tunnel_request_id_start && tunnel_request_id_end < msg_copy + MSG_BUFFER_SIZE) {
size_t tunnel_request_id_len = tunnel_request_id_end - tunnel_request_id_start; size_t tunnel_request_id_len = tunnel_request_id_end - tunnel_request_id_start;
if (tunnel_request_id_len > 0 && tunnel_request_id_len < 64) { // Reasonable limit for request ID
char *tunnel_request_id_copy = malloc(tunnel_request_id_len + 1); char *tunnel_request_id_copy = malloc(tunnel_request_id_len + 1);
if (tunnel_request_id_copy) { if (tunnel_request_id_copy) {
memcpy(tunnel_request_id_copy, tunnel_request_id_start, tunnel_request_id_len); memcpy(tunnel_request_id_copy, tunnel_request_id_start, tunnel_request_id_len);
...@@ -414,6 +440,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -414,6 +440,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
request_id = tunnel_request_id_copy; request_id = tunnel_request_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted request_id: '%s'\n", direction, request_id); if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted request_id: '%s'\n", direction, request_id);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Request ID length invalid: %zu\n", direction, tunnel_request_id_len);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Request ID end quote not found or invalid\n", direction);
} }
} }
...@@ -422,7 +453,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -422,7 +453,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (enc_start) { if (enc_start) {
enc_start += strlen("\"enc\":\""); enc_start += strlen("\"enc\":\"");
char *enc_end = strchr(enc_start, '"'); char *enc_end = strchr(enc_start, '"');
if (enc_end && enc_end > enc_start) { if (enc_end && enc_end > enc_start && enc_end < msg_copy + MSG_BUFFER_SIZE) {
size_t enc_len = enc_end - enc_start; size_t enc_len = enc_end - enc_start;
if (enc_len > 0 && enc_len < 32) { // Reasonable limit for encoding type if (enc_len > 0 && enc_len < 32) { // Reasonable limit for encoding type
char *enc_copy = malloc(enc_len + 1); char *enc_copy = malloc(enc_len + 1);
...@@ -432,7 +463,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -432,7 +463,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
enc = enc_copy; enc = enc_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted enc: '%s'\n", direction, enc); if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted enc: '%s'\n", direction, enc);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Enc length invalid: %zu\n", direction, enc_len);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Enc end quote not found or invalid\n", direction);
} }
} }
...@@ -441,7 +476,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -441,7 +476,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (service_start) { if (service_start) {
service_start += strlen("\"service\":\""); service_start += strlen("\"service\":\"");
char *service_end = strchr(service_start, '"'); char *service_end = strchr(service_start, '"');
if (service_end && service_end > service_start) { if (service_end && service_end > service_start && service_end < msg_copy + MSG_BUFFER_SIZE) {
size_t service_len = service_end - service_start; size_t service_len = service_end - service_start;
if (service_len > 0 && service_len < 32) { // Reasonable limit for service type if (service_len > 0 && service_len < 32) { // Reasonable limit for service type
char *service_copy = malloc(service_len + 1); char *service_copy = malloc(service_len + 1);
...@@ -451,7 +486,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -451,7 +486,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
service = service_copy; service = service_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted service: '%s'\n", direction, service); if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted service: '%s'\n", direction, service);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Service length invalid: %zu\n", direction, service_len);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Service end quote not found or invalid\n", direction);
} }
} }
...@@ -460,7 +499,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -460,7 +499,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (version_start) { if (version_start) {
version_start += strlen("\"version\":\""); version_start += strlen("\"version\":\"");
char *version_end = strchr(version_start, '"'); char *version_end = strchr(version_start, '"');
if (version_end && version_end > version_start) { if (version_end && version_end > version_start && version_end < msg_copy + MSG_BUFFER_SIZE) {
size_t version_len = version_end - version_start; size_t version_len = version_end - version_start;
if (version_len > 0 && version_len < 32) { // Reasonable limit for version string if (version_len > 0 && version_len < 32) { // Reasonable limit for version string
char *version_copy = malloc(version_len + 1); char *version_copy = malloc(version_len + 1);
...@@ -470,7 +509,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr ...@@ -470,7 +509,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
version = version_copy; version = version_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted version: '%s'\n", direction, version); if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted version: '%s'\n", direction, version);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Version length invalid: %zu\n", direction, version_len);
} }
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Version end quote not found or invalid\n", direction);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment