Critical security fix: Prevent wssshd server crashes from malformed packets

- Added comprehensive bounds checking to all WebSocket message parsing
- Validate JSON structure (braces) before processing to prevent crashes
- Added length limits and bounds validation for all parameter extractions:
  * client_id: max 64 chars
  * password: max 256 chars
  * request_id: max 64 chars
  * enc/service/version: max 32 chars each
- Prevent buffer overflows that could corrupt heap metadata
- Ensure all string operations stay within allocated buffer bounds
- Server now logs errors and continues running instead of crashing on malformed packets
- Critical defense against DoS attacks via malformed WebSocket messages
parent d26c949e
......@@ -247,6 +247,15 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(msg_copy, message, message_len);
msg_copy[message_len] = '\0';
// Validate basic JSON structure to prevent crashes from malformed messages
if (message_len < 2 || message[0] != '{' || message[message_len-1] != '}') {
if (state->debug) {
printf("[DEBUG - unknown -> wssshd] Malformed JSON message (invalid braces)\n");
}
free(msg_copy);
return -1;
}
// Determine connection direction for debug messages
const char *direction = "unknown";
if (strstr(msg_copy, "\"type\":\"register\"") || strstr(msg_copy, "\"type\": \"register\"")) {
......@@ -296,16 +305,24 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (client_id_start) {
client_id_start += strlen(client_id_start == strstr(msg_copy, "\"client_id\":\"") ? "\"client_id\":\"" : "\"id\":\"");
char *client_id_end = strchr(client_id_start, '"');
if (client_id_end) {
if (client_id_end && client_id_end > client_id_start && client_id_end < msg_copy + MSG_BUFFER_SIZE) {
size_t client_id_len = client_id_end - client_id_start;
char *client_id_copy = malloc(client_id_len + 1);
if (client_id_copy) {
memcpy(client_id_copy, client_id_start, client_id_len);
client_id_copy[client_id_len] = '\0';
client_id = client_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted client_id: '%s'\n", direction, client_id);
if (client_id_len > 0 && client_id_len < 64) { // Reasonable limit for client ID
char *client_id_copy = malloc(client_id_len + 1);
if (client_id_copy) {
memcpy(client_id_copy, client_id_start, client_id_len);
client_id_copy[client_id_len] = '\0';
client_id = client_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted client_id: '%s'\n", direction, client_id);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID length invalid: %zu\n", direction, client_id_len);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID end quote not found or invalid\n", direction);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID start not found\n", direction);
}
// Extract password (search in original unmodified string)
......@@ -313,20 +330,24 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (password_start) {
password_start += strlen("\"password\":\"");
char *password_end = strchr(password_start, '"');
if (password_end) {
if (password_end && password_end > password_start && password_end < msg_copy + MSG_BUFFER_SIZE) {
size_t password_len = password_end - password_start;
char *password_copy = malloc(password_len + 1);
if (password_copy) {
memcpy(password_copy, password_start, password_len);
password_copy[password_len] = '\0';
password = password_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted password: '%s'\n", direction, password);
if (password_len > 0 && password_len < 256) { // Reasonable limit for password
char *password_copy = malloc(password_len + 1);
if (password_copy) {
memcpy(password_copy, password_start, password_len);
password_copy[password_len] = '\0';
password = password_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted password: '%s'\n", direction, password);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Password length invalid: %zu\n", direction, password_len);
}
} else {
if (state->debug) printf("[DEBUG] Password end quote not found\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Password end quote not found or invalid\n", direction);
}
} else {
if (state->debug) printf("[DEBUG] Password start not found\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Password start not found\n", direction);
}
if (state->debug) {
......@@ -382,22 +403,26 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (tunnel_client_id_start) {
tunnel_client_id_start += strlen("\"client_id\":\"");
char *tunnel_client_id_end = strchr(tunnel_client_id_start, '"');
if (tunnel_client_id_end) {
if (tunnel_client_id_end && tunnel_client_id_end > tunnel_client_id_start && tunnel_client_id_end < msg_copy + MSG_BUFFER_SIZE) {
size_t tunnel_client_id_len = tunnel_client_id_end - tunnel_client_id_start;
char *tunnel_client_id_copy = malloc(tunnel_client_id_len + 1);
if (tunnel_client_id_copy) {
memcpy(tunnel_client_id_copy, tunnel_client_id_start, tunnel_client_id_len);
tunnel_client_id_copy[tunnel_client_id_len] = '\0';
client_id = tunnel_client_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted tunnel client_id: '%s' (ptr=%p)\n", direction, client_id, client_id);
if (tunnel_client_id_len > 0 && tunnel_client_id_len < 64) { // Reasonable limit for client ID
char *tunnel_client_id_copy = malloc(tunnel_client_id_len + 1);
if (tunnel_client_id_copy) {
memcpy(tunnel_client_id_copy, tunnel_client_id_start, tunnel_client_id_len);
tunnel_client_id_copy[tunnel_client_id_len] = '\0';
client_id = tunnel_client_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted tunnel client_id: '%s' (ptr=%p)\n", direction, client_id, client_id);
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Failed to allocate memory for client_id\n", direction);
}
} else {
if (state->debug) printf("[DEBUG] Failed to allocate memory for client_id\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID length invalid: %zu\n", direction, tunnel_client_id_len);
}
} else {
if (state->debug) printf("[DEBUG] Client ID end quote not found\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID end quote not found or invalid\n", direction);
}
} else {
if (state->debug) printf("[DEBUG] Client ID start not found\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Client ID start not found\n", direction);
}
// Extract request_id (search in original unmodified string)
......@@ -405,15 +430,21 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (tunnel_request_id_start) {
tunnel_request_id_start += strlen("\"request_id\":\"");
char *tunnel_request_id_end = strchr(tunnel_request_id_start, '"');
if (tunnel_request_id_end) {
if (tunnel_request_id_end && tunnel_request_id_end > tunnel_request_id_start && tunnel_request_id_end < msg_copy + MSG_BUFFER_SIZE) {
size_t tunnel_request_id_len = tunnel_request_id_end - tunnel_request_id_start;
char *tunnel_request_id_copy = malloc(tunnel_request_id_len + 1);
if (tunnel_request_id_copy) {
memcpy(tunnel_request_id_copy, tunnel_request_id_start, tunnel_request_id_len);
tunnel_request_id_copy[tunnel_request_id_len] = '\0';
request_id = tunnel_request_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted request_id: '%s'\n", direction, request_id);
if (tunnel_request_id_len > 0 && tunnel_request_id_len < 64) { // Reasonable limit for request ID
char *tunnel_request_id_copy = malloc(tunnel_request_id_len + 1);
if (tunnel_request_id_copy) {
memcpy(tunnel_request_id_copy, tunnel_request_id_start, tunnel_request_id_len);
tunnel_request_id_copy[tunnel_request_id_len] = '\0';
request_id = tunnel_request_id_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted request_id: '%s'\n", direction, request_id);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Request ID length invalid: %zu\n", direction, tunnel_request_id_len);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Request ID end quote not found or invalid\n", direction);
}
}
......@@ -422,7 +453,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (enc_start) {
enc_start += strlen("\"enc\":\"");
char *enc_end = strchr(enc_start, '"');
if (enc_end && enc_end > enc_start) {
if (enc_end && enc_end > enc_start && enc_end < msg_copy + MSG_BUFFER_SIZE) {
size_t enc_len = enc_end - enc_start;
if (enc_len > 0 && enc_len < 32) { // Reasonable limit for encoding type
char *enc_copy = malloc(enc_len + 1);
......@@ -432,7 +463,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
enc = enc_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted enc: '%s'\n", direction, enc);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Enc length invalid: %zu\n", direction, enc_len);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Enc end quote not found or invalid\n", direction);
}
}
......@@ -441,7 +476,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (service_start) {
service_start += strlen("\"service\":\"");
char *service_end = strchr(service_start, '"');
if (service_end && service_end > service_start) {
if (service_end && service_end > service_start && service_end < msg_copy + MSG_BUFFER_SIZE) {
size_t service_len = service_end - service_start;
if (service_len > 0 && service_len < 32) { // Reasonable limit for service type
char *service_copy = malloc(service_len + 1);
......@@ -451,7 +486,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
service = service_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted service: '%s'\n", direction, service);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Service length invalid: %zu\n", direction, service_len);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Service end quote not found or invalid\n", direction);
}
}
......@@ -460,7 +499,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (version_start) {
version_start += strlen("\"version\":\"");
char *version_end = strchr(version_start, '"');
if (version_end && version_end > version_start) {
if (version_end && version_end > version_start && version_end < msg_copy + MSG_BUFFER_SIZE) {
size_t version_len = version_end - version_start;
if (version_len > 0 && version_len < 32) { // Reasonable limit for version string
char *version_copy = malloc(version_len + 1);
......@@ -470,7 +509,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
version = version_copy;
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted version: '%s'\n", direction, version);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Version length invalid: %zu\n", direction, version_len);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Version end quote not found or invalid\n", direction);
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment