Fix large WebSocket frame truncation in wssshc by reducing BUFFER_SIZE to 4096...

Fix large WebSocket frame truncation in wssshc by reducing BUFFER_SIZE to 4096 to prevent partial reception on server
parent b3e5e3d2
......@@ -31,7 +31,7 @@
#include <getopt.h>
#include <pthread.h>
#define BUFFER_SIZE 65536
#define BUFFER_SIZE 4096
#define DEFAULT_PORT 9898
int global_debug = 0;
......@@ -97,7 +97,9 @@ void handle_tunnel_request(SSL *ssl, const char *request_id, int debug) {
frame[6 + i] = ack_msg[i] ^ mask_key[i % 4];
}
int frame_len = 6 + msg_len;
SSL_write(ssl, frame, frame_len);
if (!send_all(ssl, frame, frame_len)) {
fprintf(stderr, "Send tunnel_ack failed\n");
}
}
void handle_tunnel_data(SSL *ssl, const char *request_id, const char *data_hex, int debug) {
......@@ -234,7 +236,11 @@ void *tunnel_thread(void *arg) {
}
// Send as tunnel_response
char hex_data[bytes_read * 2 + 1];
char *hex_data = malloc(bytes_read * 2 + 1);
if (!hex_data) {
perror("Memory allocation failed for hex_data");
break;
}
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]);
}
......@@ -244,8 +250,14 @@ void *tunnel_thread(void *arg) {
printf("[DEBUG] Sending tunnel_response, hex len: %zu, hex: %.100s...\n", strlen(hex_data), hex_data);
}
char response[4096];
snprintf(response, sizeof(response),
size_t response_size = strlen("{\"type\":\"tunnel_response\",\"request_id\":\"\",\"data\":\"\"}") + strlen(request_id) + strlen(hex_data) + 1;
char *response = malloc(response_size);
if (!response) {
perror("Memory allocation failed for response");
free(hex_data);
break;
}
snprintf(response, response_size,
"{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
......@@ -254,19 +266,34 @@ void *tunnel_thread(void *arg) {
}
// Send as WebSocket frame with proper length encoding
char frame[4096];
frame[0] = 0x81; // FIN + text opcode
int msg_len = strlen(response);
int header_len = 2;
char mask_key[4];
if (msg_len <= 125) {
header_len = 6; // 2 + 4 for mask
} else if (msg_len <= 65535) {
header_len = 8; // 4 + 4 for mask
} else {
header_len = 14; // 10 + 4 for mask
}
char *frame = malloc(header_len + msg_len);
if (!frame) {
perror("Memory allocation failed for frame");
free(response);
free(hex_data);
break;
}
frame[0] = 0x81; // FIN + text opcode
if (msg_len <= 125) {
frame[1] = 0x80 | msg_len; // MASK + length
} else if (msg_len <= 65535) {
frame[1] = 0x80 | 126; // MASK + extended length
frame[2] = (msg_len >> 8) & 0xFF;
frame[3] = msg_len & 0xFF;
header_len = 4;
} else {
frame[1] = 0x80 | 127; // MASK + extended length
// For simplicity, assume length fits in 32 bits
......@@ -278,15 +305,13 @@ void *tunnel_thread(void *arg) {
frame[7] = (msg_len >> 16) & 0xFF;
frame[8] = (msg_len >> 8) & 0xFF;
frame[9] = msg_len & 0xFF;
header_len = 10;
}
// Add mask key
for (int i = 0; i < 4; i++) {
mask_key[i] = rand() % 256;
frame[header_len + i] = mask_key[i];
frame[header_len - 4 + i] = mask_key[i];
}
header_len += 4;
// Mask payload
for (int i = 0; i < msg_len; i++) {
......@@ -294,7 +319,13 @@ void *tunnel_thread(void *arg) {
}
int frame_len = header_len + msg_len;
SSL_write(ssl, frame, frame_len);
if (!send_all(ssl, frame, frame_len)) {
fprintf(stderr, "Send tunnel_response failed\n");
}
free(frame);
free(response);
free(hex_data);
}
return NULL;
......@@ -407,14 +438,13 @@ int websocket_handshake(SSL *ssl, const char *host, int port, const char *path)
"GET %s HTTP/1.1\r\n"
"Host: %s:%d\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Connection: upgrade\r\n"
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Sec-WebSocket-Version: 13\r\n"
"\r\n",
path, host, port);
if (SSL_write(ssl, request, strlen(request)) <= 0) {
ERR_print_errors_fp(stderr);
if (!send_all(ssl, request, strlen(request))) {
fprintf(stderr, "WebSocket handshake send failed\n");
return 0;
}
......@@ -438,6 +468,19 @@ int websocket_handshake(SSL *ssl, const char *host, int port, const char *path)
return 1;
}
int send_all(SSL *ssl, const char *data, int len) {
int total_sent = 0;
while (total_sent < len) {
int sent = SSL_write(ssl, data + total_sent, len - total_sent);
if (sent <= 0) {
ERR_print_errors_fp(stderr);
return 0;
}
total_sent += sent;
}
return 1;
}
int send_json_message(SSL *ssl, const char *type, const char *id, const char *password) {
char message[1024];
char frame[1024];
......@@ -492,8 +535,7 @@ int send_json_message(SSL *ssl, const char *type, const char *id, const char *pa
int frame_len = header_len + message_len;
if (SSL_write(ssl, frame, frame_len) <= 0) {
ERR_print_errors_fp(stderr);
if (!send_all(ssl, frame, frame_len)) {
fprintf(stderr, "Send failed\n");
return 0;
}
......@@ -800,7 +842,7 @@ void connect_to_server(const wssshc_config_t *config) {
int masked = buffer[1] & 0x80;
int len_indicator = buffer[1] & 0x7F;
int header_len = 2;
int payload_len;
size_t payload_len;
if (len_indicator <= 125) {
payload_len = len_indicator;
......@@ -814,7 +856,11 @@ void connect_to_server(const wssshc_config_t *config) {
for (int i = 0; i < 8; i++) {
len = (len << 8) | (unsigned char)buffer[2 + i];
}
payload_len = (int)len;
if (len > SIZE_MAX) {
fprintf(stderr, "Ping payload too large\n");
continue;
}
payload_len = (size_t)len;
header_len = 10;
}
......@@ -822,7 +868,7 @@ void connect_to_server(const wssshc_config_t *config) {
header_len += 4;
}
if (payload_len >= 0 && payload_len <= 1024 && bytes_read >= header_len + payload_len) {
if (payload_len <= 1024 && bytes_read >= header_len + payload_len) {
char *payload = buffer + header_len;
if (masked) {
char *mask_key = buffer + header_len - 4;
......@@ -873,7 +919,9 @@ void connect_to_server(const wssshc_config_t *config) {
if (config->debug) {
printf("[DEBUG] Sending pong frame, len: %d\n", pong_frame_len);
}
SSL_write(ssl, pong_frame, pong_frame_len);
if (!send_all(ssl, pong_frame, pong_frame_len)) {
fprintf(stderr, "Send pong failed\n");
}
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment