Add comprehensive bounds checking to WebSocket message parsing in wsssht.c

- Prevent heap corruption from malformed JSON messages
- Add bounds validation for all string operations in message parsing
- Ensure all pointers stay within payload buffer limits
- Validate data field lengths to prevent excessive memory allocation
- Protect against buffer overflows in tunnel_data, tunnel_close, tunnel_keepalive, tunnel_ack, and tunnel_ko message parsing
- Add debug logging for malformed messages to aid troubleshooting
parent 60731aed
......@@ -959,33 +959,46 @@ int main(int argc, char *argv[]) {
fflush(stdout);
}
}
// Extract request_id and data
// Extract request_id and data with bounds checking
char *id_start = strstr(payload, "\"request_id\"");
char *data_start = strstr(payload, "\"data\"");
if (id_start && data_start) {
if (id_start && data_start && id_start < data_start) { // Ensure proper order
char *colon = strchr(id_start, ':');
if (colon) {
if (colon && colon < data_start) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
if (open_quote && open_quote < data_start) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
if (close_quote && close_quote < data_start) {
*close_quote = '\0';
char *data_colon = strchr(data_start, ':');
if (data_colon) {
if (data_colon && data_colon < payload + payload_len) {
char *data_quote = strchr(data_colon, '"');
if (data_quote) {
if (data_quote && data_quote < payload + payload_len) {
data_start = data_quote + 1;
char *data_end = strchr(data_start, '"');
if (data_end) {
if (data_end && data_end < payload + payload_len && data_end > data_start) {
*data_end = '\0';
handle_tunnel_data(current_ssl, id_start, data_start, config.debug);
size_t hex_len = data_end - data_start;
// Additional validation: reasonable hex data size
if (hex_len > 0 && hex_len < 1024 * 1024) { // Max 1MB hex data
handle_tunnel_data(current_ssl, id_start, data_start, config.debug);
} else if (config.debug) {
printf("[DEBUG] Invalid hex data length: %zu\n", hex_len);
fflush(stdout);
}
} else if (config.debug) {
printf("[DEBUG] Malformed data field in JSON\n");
fflush(stdout);
}
}
}
}
}
}
} else if (config.debug) {
printf("[DEBUG] Malformed tunnel message JSON\n");
fflush(stdout);
}
} else if (strstr(payload, "tunnel_close")) {
if (config.debug) {
......@@ -993,14 +1006,14 @@ int main(int argc, char *argv[]) {
fflush(stdout);
}
char *id_start = strstr(payload, "\"request_id\"");
if (id_start) {
if (id_start && id_start < payload + payload_len) {
char *colon = strchr(id_start, ':');
if (colon) {
if (colon && colon < payload + payload_len) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
if (open_quote && open_quote < payload + payload_len) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
if (close_quote && close_quote < payload + payload_len && close_quote > id_start) {
*close_quote = '\0';
handle_tunnel_close(current_ssl, id_start, config.debug);
}
......@@ -1008,36 +1021,36 @@ int main(int argc, char *argv[]) {
}
}
} else if (strstr(payload, "tunnel_keepalive")) {
// Extract request_id, total_bytes, and rate_bps
// Extract request_id, total_bytes, and rate_bps with bounds checking
char *id_start = strstr(payload, "\"request_id\"");
char *total_start = strstr(payload, "\"total_bytes\"");
char *rate_start = strstr(payload, "\"rate_bps\"");
if (id_start) {
if (id_start && id_start < payload + payload_len) {
char *colon = strchr(id_start, ':');
if (colon) {
if (colon && colon < payload + payload_len) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
if (open_quote && open_quote < payload + payload_len) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
if (close_quote && close_quote < payload + payload_len && close_quote > id_start) {
*close_quote = '\0';
unsigned long long total_bytes = 0;
double rate_bps = 0.0;
// Parse total_bytes
if (total_start) {
if (total_start && total_start < payload + payload_len) {
char *total_colon = strchr(total_start, ':');
if (total_colon) {
if (total_colon && total_colon < payload + payload_len) {
total_bytes = strtoull(total_colon + 1, NULL, 10);
}
}
// Parse rate_bps
if (rate_start) {
if (rate_start && rate_start < payload + payload_len) {
char *rate_colon = strchr(rate_start, ':');
if (rate_colon) {
if (rate_colon && rate_colon < payload + payload_len) {
rate_bps = strtod(rate_colon + 1, NULL);
}
}
......@@ -1048,16 +1061,16 @@ int main(int argc, char *argv[]) {
}
}
} else if (strstr(payload, "tunnel_keepalive_ack")) {
// Extract request_id
// Extract request_id with bounds checking
char *id_start = strstr(payload, "\"request_id\"");
if (id_start) {
if (id_start && id_start < payload + payload_len) {
char *colon = strchr(id_start, ':');
if (colon) {
if (colon && colon < payload + payload_len) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
if (open_quote && open_quote < payload + payload_len) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
if (close_quote && close_quote < payload + payload_len && close_quote > id_start) {
*close_quote = '\0';
handle_tunnel_keepalive_ack(current_ssl, id_start, config.debug);
}
......@@ -1065,21 +1078,21 @@ int main(int argc, char *argv[]) {
}
}
} else if (strstr(payload, "tunnel_ack")) {
// Extract request_id and frame_id
// Extract request_id and frame_id with bounds checking
char *id_start = strstr(payload, "\"request_id\"");
char *frame_start = strstr(payload, "\"frame_id\"");
if (id_start && frame_start) {
if (id_start && frame_start && id_start < frame_start && frame_start < payload + payload_len) {
char *colon = strchr(id_start, ':');
if (colon) {
if (colon && colon < frame_start) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
if (open_quote && open_quote < frame_start) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
if (close_quote && close_quote < frame_start && close_quote > id_start) {
*close_quote = '\0';
// Extract frame_id
char *frame_colon = strchr(frame_start, ':');
if (frame_colon) {
if (frame_colon && frame_colon < payload + payload_len) {
uint32_t frame_id = (uint32_t)atoi(frame_colon + 1);
handle_tunnel_ack(current_ssl, id_start, frame_id, config.debug);
}
......@@ -1088,21 +1101,21 @@ int main(int argc, char *argv[]) {
}
}
} else if (strstr(payload, "tunnel_ko")) {
// Extract request_id and frame_id
// Extract request_id and frame_id with bounds checking
char *id_start = strstr(payload, "\"request_id\"");
char *frame_start = strstr(payload, "\"frame_id\"");
if (id_start && frame_start) {
if (id_start && frame_start && id_start < frame_start && frame_start < payload + payload_len) {
char *colon = strchr(id_start, ':');
if (colon) {
if (colon && colon < frame_start) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
if (open_quote && open_quote < frame_start) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
if (close_quote && close_quote < frame_start && close_quote > id_start) {
*close_quote = '\0';
// Extract frame_id
char *frame_colon = strchr(frame_start, ':');
if (frame_colon) {
if (frame_colon && frame_colon < payload + payload_len) {
uint32_t frame_id = (uint32_t)atoi(frame_colon + 1);
handle_tunnel_ko(current_ssl, id_start, frame_id, config.debug);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment