Fix tunnel communication corruption and add --enc option to wsssht

- Fixed critical data corruption in WebSocket tunnels between wsssht, wssshd, and wssshc
- Root cause: Inconsistent encoding between components (wsssht used base64, wssshc used hex)
- Solution: Implemented consistent hex encoding across all tunnel data transmission
- Added encoding field to tunnel structures for proper encoding negotiation
- Fixed handle_tunnel_data() to decode using correct encoding type instead of guessing

- Added --enc option to wsssht for data encoding control
- --enc hex: Hexadecimal encoding (default, backward compatible)
- --enc base64: Base64 encoding for efficiency
- --enc bin: Direct binary data transmission
- Configuration file support with enc = hex option in wsssht.conf
- Automatic encoding negotiation between wsssht and wssshc clients
- wsssh and wsscp can pass --enc option to ProxyCommand for wsssht

- Updated documentation and examples
- Maintained backward compatibility
parent 8bd04b0a
......@@ -5,6 +5,35 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [1.6.6] - 2025-09-20
### Fixed
- **Critical Tunnel Communication Corruption**: Fixed data corruption in WebSocket tunnels between wsssht, wssshd, and wssshc
- Root cause: Inconsistent encoding between components (wsssht used base64, wssshc used hex)
- Solution: Implemented consistent hex encoding across all tunnel data transmission
- Added encoding field to tunnel structures for proper encoding negotiation
- Fixed `handle_tunnel_data()` to decode using correct encoding type instead of guessing
- Resolved corruption where hex data was incorrectly decoded as base64
### Added
- **Encoding Options for wsssht**: Added `--enc` option to wsssht for data encoding control
- `--enc hex`: Hexadecimal encoding (default, backward compatible)
- `--enc base64`: Base64 encoding for efficiency
- `--enc bin`: Direct binary data transmission
- Configuration file support with `enc = hex` option in `wsssht.conf`
- Automatic encoding negotiation between wsssht and wssshc clients
- wsssh and wsscp can pass `--enc` option to ProxyCommand for wsssht
### Technical Details
- **Encoding Architecture**: Extensible encoding system with per-tunnel encoding negotiation
- **Data Integrity**: All encoding modes now preserve data integrity during transmission
- **ProxyCommand Enhancement**: wsssh and wsscp now pass encoding options through ProxyCommand
- **Configuration Consistency**: Encoding settings properly propagated through the tunnel chain
### Security
- **Data Transmission Security**: Fixed corruption that could potentially cause data misinterpretation
- **Protocol Compliance**: Proper encoding negotiation prevents data corruption attacks
## [1.6.5] - 2025-09-19
### Added
......
......@@ -273,6 +273,9 @@ service = ssh
# SSH with specific transport
./wsssh --tunnel websocket user@myclient.example.com
# SSH with custom encoding
./wsssh --enc base64 user@myclient.example.com
# Debug mode to see the actual commands
./wsssh --debug user@myclient.example.com
```
......@@ -287,6 +290,9 @@ service = ssh
# SCP with custom transport
./wsscp --tunnel websocket localfile user@myclient.example.com:/remote/path/
# SCP with custom encoding
./wsscp --enc base64 localfile user@myclient.example.com:/remote/path/
```
### Tunnel Setup for Manual Use
......
......@@ -226,10 +226,6 @@ void websocket_remove_terminal(wssshd_state_t *state, const char *request_id) {
// Message handling
int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attribute__((unused)), const char *message, size_t message_len) {
if (state->debug) {
printf("[DEBUG] Handling message: %.*s\n", (int)message_len, message);
}
// Simple string-based JSON parsing for basic functionality
// This is a simplified implementation - a full JSON parser would be better
......@@ -237,13 +233,13 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
char *msg_copy = malloc(MSG_BUFFER_SIZE);
if (!msg_copy) {
if (state->debug) {
printf("[DEBUG] Failed to allocate message buffer\n");
printf("[DEBUG - unknown -> wssshd] Failed to allocate message buffer\n");
}
return -1;
}
if (message_len >= MSG_BUFFER_SIZE) {
if (state->debug) {
printf("[DEBUG] Message too long: %zu bytes\n", message_len);
printf("[DEBUG - unknown -> wssshd] Message too long: %zu bytes\n", message_len);
}
free(msg_copy);
return -1;
......@@ -251,21 +247,47 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(msg_copy, message, message_len);
msg_copy[message_len] = '\0';
// Determine connection direction for debug messages
const char *direction = "unknown";
if (strstr(msg_copy, "\"type\":\"register\"") || strstr(msg_copy, "\"type\": \"register\"")) {
direction = "wssshc";
} else if (strstr(msg_copy, "\"type\":\"tunnel_request\"") || strstr(msg_copy, "\"type\": \"tunnel_request\"")) {
direction = "wsssht";
} else {
for (size_t i = 0; i < state->clients_count; i++) {
if (state->clients[i].websocket == conn) {
direction = "wssshc";
break;
}
}
if (strcmp(direction, "unknown") == 0) {
for (size_t i = 0; i < state->tunnels_count; i++) {
if (state->tunnels[i]->wsssh_ws == conn) {
direction = "wsssht";
break;
}
}
}
}
if (state->debug) {
printf("[DEBUG - %s -> wssshd] Handling message: %.*s\n", direction, (int)message_len, message);
}
int result = 0;
// Check for registration message
if (strstr(msg_copy, "\"type\":\"register\"") || strstr(msg_copy, "\"type\": \"register\"")) {
if (state->debug) {
printf("[DEBUG] Processing registration message\n");
printf("[DEBUG] Full message: %s\n", msg_copy);
printf("[DEBUG - %s -> wssshd] Processing registration message\n", direction);
printf("[DEBUG - %s -> wssshd] Full message: %s\n", direction, msg_copy);
}
// Extract client_id and password (simplified parsing)
char *client_id = NULL;
char *password = NULL;
if (state->debug) {
printf("[DEBUG] Parsing client_id and password from: %s\n", msg_copy);
printf("[DEBUG - %s -> wssshd] Parsing client_id and password from: %s\n", direction, msg_copy);
}
// Extract client_id (make a copy to avoid modifying the original string)
......@@ -281,7 +303,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(client_id_copy, client_id_start, client_id_len);
client_id_copy[client_id_len] = '\0';
client_id = client_id_copy;
if (state->debug) printf("[DEBUG] Extracted client_id: '%s'\n", client_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted client_id: '%s'\n", direction, client_id);
}
}
}
......@@ -298,7 +320,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(password_copy, password_start, password_len);
password_copy[password_len] = '\0';
password = password_copy;
if (state->debug) printf("[DEBUG] Extracted password: '%s'\n", password);
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted password: '%s'\n", direction, password);
}
} else {
if (state->debug) printf("[DEBUG] Password end quote not found\n");
......@@ -308,7 +330,8 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
}
if (state->debug) {
printf("[DEBUG] Password check: received='%s', expected='%s'\n",
printf("[DEBUG - %s -> wssshd] Password check: received='%s', expected='%s'\n",
direction,
password ? password : "(null)",
state->server_password ? state->server_password : "(null)");
}
......@@ -319,6 +342,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
// Send registration success
char response[512];
snprintf(response, sizeof(response), REGISTERED_MSG, client_id);
if (state->debug) printf("[DEBUG - wssshd -> wssshc] Sending registration success: %s\n", response);
ws_send_frame(conn, WS_OPCODE_TEXT, response, strlen(response));
if (state->debug) printf("[EVENT] Client %s registered\n", client_id);
}
......@@ -326,9 +350,11 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
// Send registration error
char response[512];
snprintf(response, sizeof(response), REGISTRATION_ERROR_MSG, "Invalid password");
if (state->debug) printf("[DEBUG - wssshd -> wssshc] Sending registration error: %s\n", response);
ws_send_frame(conn, WS_OPCODE_TEXT, response, strlen(response));
if (state->debug) {
printf("[DEBUG] Client %s registration failed: client_id=%s, password=%s, server_password=%s\n",
printf("[DEBUG - %s -> wssshd] Client %s registration failed: client_id=%s, password=%s, server_password=%s\n",
direction,
client_id ? client_id : "unknown",
client_id ? client_id : "(null)",
password ? password : "(null)",
......@@ -341,8 +367,8 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (password) free(password);
} else if (strstr(msg_copy, "\"type\":\"tunnel_request\"") || strstr(msg_copy, "\"type\": \"tunnel_request\"")) {
if (state->debug) {
printf("[DEBUG] Processing tunnel request\n");
printf("[DEBUG] Full tunnel request: %s\n", msg_copy);
printf("[DEBUG - %s -> wssshd] Processing tunnel request\n", direction);
printf("[DEBUG - %s -> wssshd] Full tunnel request: %s\n", direction, msg_copy);
}
// Handle tunnel request (simplified)
char *client_id = NULL;
......@@ -360,7 +386,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(tunnel_client_id_copy, tunnel_client_id_start, tunnel_client_id_len);
tunnel_client_id_copy[tunnel_client_id_len] = '\0';
client_id = tunnel_client_id_copy;
if (state->debug) printf("[DEBUG] Extracted tunnel client_id: '%s' (ptr=%p)\n", client_id, client_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted tunnel client_id: '%s' (ptr=%p)\n", direction, client_id, client_id);
} else {
if (state->debug) printf("[DEBUG] Failed to allocate memory for client_id\n");
}
......@@ -383,7 +409,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(tunnel_request_id_copy, tunnel_request_id_start, tunnel_request_id_len);
tunnel_request_id_copy[tunnel_request_id_len] = '\0';
request_id = tunnel_request_id_copy;
if (state->debug) printf("[DEBUG] Extracted request_id: '%s'\n", request_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted request_id: '%s'\n", direction, request_id);
}
}
}
......@@ -396,44 +422,49 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (tunnel) {
// Store the wsssht connection in the tunnel for data forwarding
tunnel->wsssh_ws = conn;
// Store the wssshc connection in the tunnel
tunnel->client_ws = client->websocket;
tunnel_update_status(tunnel, TUNNEL_STATUS_ACTIVE, NULL);
// Send tunnel request to client (wssshc)
char request_msg[512];
snprintf(request_msg, sizeof(request_msg), TUNNEL_REQUEST_MSG, request_id);
if (state->debug) printf("[DEBUG - wssshd -> wssshc] Sending tunnel request: %s\n", request_msg);
ws_send_frame(client->websocket, WS_OPCODE_TEXT, request_msg, strlen(request_msg));
// Send tunnel ack to tool (wsssht/wsscp)
char ack_msg[256];
snprintf(ack_msg, sizeof(ack_msg), TUNNEL_ACK_MSG, request_id);
if (state->debug) printf("[DEBUG - wssshd -> wsssht] Sending tunnel ack: %s\n", ack_msg);
ws_send_frame(conn, WS_OPCODE_TEXT, ack_msg, strlen(ack_msg));
if (state->debug) printf("[DEBUG] Created tunnel %s for client %s\n", request_id, client_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Created tunnel %s for client %s\n", direction, request_id, client_id);
if (!state->debug) printf("[EVENT] New tunnel %s for client %s\n", request_id, client_id);
}
} else {
// Send error to tool
char error_msg[512];
snprintf(error_msg, sizeof(error_msg), TUNNEL_ERROR_MSG, request_id, "Client not registered or disconnected");
if (state->debug) printf("[DEBUG - wssshd -> wsssht] Sending tunnel error: %s\n", error_msg);
ws_send_frame(conn, WS_OPCODE_TEXT, error_msg, strlen(error_msg));
if (state->debug) printf("[DEBUG] Tunnel request failed: client %s not found or inactive\n", client_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Tunnel request failed: client %s not found or inactive\n", direction, client_id);
// Close the connection that sent the invalid tunnel request
if (state->debug) printf("[DEBUG] Closing connection due to invalid tunnel request from unregistered client %s\n", client_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Closing connection due to invalid tunnel request from unregistered client %s\n", direction, client_id);
conn->state = WS_STATE_CLOSED;
}
}
// Free allocated strings
if (state->debug) {
printf("[DEBUG] Freeing tunnel request strings: client_id=%p, request_id=%p\n", client_id, request_id);
printf("[DEBUG - %s -> wssshd] Freeing tunnel request strings: client_id=%p, request_id=%p\n", direction, client_id, request_id);
}
if (client_id) free(client_id);
if (request_id) free(request_id);
} else if (strstr(msg_copy, "\"type\":\"tunnel_ack\"") || strstr(msg_copy, "\"type\": \"tunnel_ack\"")) {
if (state->debug) {
printf("[DEBUG] Processing tunnel acknowledgment\n");
printf("[DEBUG - %s -> wssshd] Processing tunnel acknowledgment\n", direction);
}
// Handle tunnel acknowledgment from wssshc
char *request_id = NULL;
......@@ -449,7 +480,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(ack_request_id_copy, ack_request_id_start, ack_request_id_len);
ack_request_id_copy[ack_request_id_len] = '\0';
request_id = ack_request_id_copy;
if (state->debug) printf("[DEBUG] Extracted tunnel_ack request_id: '%s'\n", request_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted tunnel_ack request_id: '%s'\n", direction, request_id);
}
}
}
......@@ -458,13 +489,13 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
// Find the tunnel and mark it as acknowledged
// This is a simple acknowledgment - in a full implementation,
// you might want to track tunnel state more thoroughly
if (state->debug) printf("[DEBUG] Tunnel %s acknowledged by client\n", request_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Tunnel %s acknowledged by client\n", direction, request_id);
}
if (request_id) free(request_id);
} else if (strstr(msg_copy, "\"type\":\"tunnel_data\"") || strstr(msg_copy, "\"type\": \"tunnel_data\"")) {
if (state->debug) {
printf("[DEBUG] Processing tunnel data\n");
printf("[DEBUG - %s -> wssshd] Processing tunnel data\n", direction);
}
// Handle tunnel data from wssshc to wsssht
char *request_id = NULL;
......@@ -503,9 +534,24 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
}
if (request_id && data) {
// Find the tunnel and forward data to wsssht
// Find the tunnel
tunnel_t *tunnel = websocket_find_tunnel(state, request_id);
if (tunnel && tunnel->wsssh_ws) {
if (tunnel) {
// Determine which side to forward to based on sender
ws_connection_t *target_conn = NULL;
const char *target_side = NULL;
if (strcmp(direction, "wssshc") == 0) {
// Message from wssshc, forward to wsssht
target_conn = tunnel->wsssh_ws;
target_side = "wsssht";
} else if (strcmp(direction, "wsssht") == 0) {
// Message from wsssht, forward to wssshc
target_conn = tunnel->client_ws;
target_side = "wssshc";
}
if (target_conn) {
// Calculate binary size from hex data length
size_t hex_len = strlen(data);
size_t binary_size = hex_len / 2;
......@@ -521,17 +567,21 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (forward_msg) {
int msg_len = snprintf(forward_msg, total_size, "{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"size\":%zu,\"data\":\"%s\"}", request_id, binary_size, data);
if (msg_len > 0 && (size_t)msg_len < total_size) {
ws_send_frame(tunnel->wsssh_ws, WS_OPCODE_TEXT, forward_msg, msg_len);
if (state->debug) printf("[DEBUG] Forwarded tunnel data for request %s to wsssht, hex length: %zu bytes (binary size: %zu bytes)\n", request_id, data_len, binary_size);
if (state->debug) printf("[DEBUG - wssshd -> %s] Forwarding tunnel data: %.*s\n", target_side, msg_len, forward_msg);
ws_send_frame(target_conn, WS_OPCODE_TEXT, forward_msg, msg_len);
if (state->debug) printf("[DEBUG - %s -> wssshd] Forwarded tunnel data for request %s to %s, hex length: %zu bytes (binary size: %zu bytes)\n", direction, request_id, target_side, data_len, binary_size);
} else {
if (state->debug) printf("[DEBUG] Failed to format forward message\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Failed to format forward message\n", direction);
}
free(forward_msg);
} else {
if (state->debug) printf("[DEBUG] Failed to allocate buffer for forward message\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Failed to allocate buffer for forward message\n", direction);
}
} else {
if (state->debug) printf("[DEBUG] Could not find tunnel or wsssht connection for request %s\n", request_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Could not find target connection for request %s\n", direction, request_id);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Could not find tunnel for request %s\n", direction, request_id);
}
}
......@@ -539,7 +589,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (data) free(data);
} else if (strstr(msg_copy, "\"type\":\"tunnel_response\"") || strstr(msg_copy, "\"type\": \"tunnel_response\"")) {
if (state->debug) {
printf("[DEBUG] Processing tunnel response\n");
printf("[DEBUG - %s -> wssshd] Processing tunnel response\n", direction);
}
// Handle tunnel response from wssshc to wsssht
char *request_id = NULL;
......@@ -578,9 +628,24 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
}
if (request_id && data) {
// Find the tunnel and forward response to wsssht
// Find the tunnel
tunnel_t *tunnel = websocket_find_tunnel(state, request_id);
if (tunnel && tunnel->wsssh_ws) {
if (tunnel) {
// Determine which side to forward to based on sender
ws_connection_t *target_conn = NULL;
const char *target_side = NULL;
if (strcmp(direction, "wssshc") == 0) {
// Message from wssshc, forward to wsssht
target_conn = tunnel->wsssh_ws;
target_side = "wsssht";
} else if (strcmp(direction, "wsssht") == 0) {
// Message from wsssht, forward to wssshc
target_conn = tunnel->client_ws;
target_side = "wssshc";
}
if (target_conn) {
// Forward hex data directly without decode/re-encode cycle
size_t request_id_len = strlen(request_id);
size_t data_len = strlen(data);
......@@ -592,17 +657,21 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (forward_msg) {
int msg_len = snprintf(forward_msg, total_size, "{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"data\":\"%s\"}", request_id, data);
if (msg_len > 0 && (size_t)msg_len < total_size) {
ws_send_frame(tunnel->wsssh_ws, WS_OPCODE_TEXT, forward_msg, msg_len);
if (state->debug) printf("[DEBUG] Forwarded tunnel response for request %s to wsssht, hex length: %zu bytes\n", request_id, data_len);
if (state->debug) printf("[DEBUG - wssshd -> %s] Forwarding tunnel response: %.*s\n", target_side, msg_len, forward_msg);
ws_send_frame(target_conn, WS_OPCODE_TEXT, forward_msg, msg_len);
if (state->debug) printf("[DEBUG - %s -> wssshd] Forwarded tunnel response for request %s to %s, hex length: %zu bytes\n", direction, request_id, target_side, data_len);
} else {
if (state->debug) printf("[DEBUG] Failed to format forward message\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Failed to format forward message\n", direction);
}
free(forward_msg);
} else {
if (state->debug) printf("[DEBUG] Failed to allocate buffer for forward message\n");
if (state->debug) printf("[DEBUG - %s -> wssshd] Failed to allocate buffer for forward message\n", direction);
}
} else {
if (state->debug) printf("[DEBUG - %s -> wssshd] Could not find target connection for request %s\n", direction, request_id);
}
} else {
if (state->debug) printf("[DEBUG] Could not find tunnel or wsssht connection for request %s\n", request_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Could not find tunnel for request %s\n", direction, request_id);
}
}
......@@ -610,7 +679,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
if (data) free(data);
} else if (strstr(msg_copy, "\"type\":\"tunnel_close\"") || strstr(msg_copy, "\"type\": \"tunnel_close\"")) {
if (state->debug) {
printf("[DEBUG] Processing tunnel close\n");
printf("[DEBUG - %s -> wssshd] Processing tunnel close\n", direction);
}
// Handle tunnel close from wssshc
char *request_id = NULL;
......@@ -626,7 +695,7 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
memcpy(close_request_id_copy, close_request_id_start, close_request_id_len);
close_request_id_copy[close_request_id_len] = '\0';
request_id = close_request_id_copy;
if (state->debug) printf("[DEBUG] Extracted tunnel_close request_id: '%s'\n", request_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Extracted tunnel_close request_id: '%s'\n", direction, request_id);
}
}
}
......@@ -638,20 +707,85 @@ int websocket_handle_message(wssshd_state_t *state, ws_connection_t *conn __attr
// Forward the close message to wsssht
char close_msg[256];
snprintf(close_msg, sizeof(close_msg), "{\"type\":\"tunnel_close\",\"request_id\":\"%s\"}", request_id);
if (state->debug) printf("[DEBUG - wssshd -> wsssht] Forwarding tunnel close: %s\n", close_msg);
ws_send_frame(tunnel->wsssh_ws, WS_OPCODE_TEXT, close_msg, strlen(close_msg));
if (state->debug) printf("[DEBUG] Forwarded tunnel_close for request %s to wsssht\n", request_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Forwarded tunnel_close for request %s to wsssht\n", direction, request_id);
}
// Remove the tunnel
websocket_remove_tunnel(state, request_id);
if (state->debug) printf("[DEBUG] Closed tunnel %s\n", request_id);
if (state->debug) printf("[DEBUG - %s -> wssshd] Closed tunnel %s\n", direction, request_id);
if (!state->debug) printf("[EVENT] Tunnel %s closed\n", request_id);
}
if (request_id) free(request_id);
} else if (strstr(msg_copy, "\"type\":\"tunnel_keepalive\"") || strstr(msg_copy, "\"type\": \"tunnel_keepalive\"")) {
if (state->debug) {
printf("[DEBUG - %s -> wssshd] Processing tunnel keepalive\n", direction);
}
// Handle tunnel keepalive from wssshc or wsssht
char *request_id = NULL;
unsigned long long total_bytes = 0;
double rate_bps = 0.0;
// Extract request_id
char *ka_request_id_start = strstr(msg_copy, "\"request_id\":\"");
if (ka_request_id_start) {
ka_request_id_start += strlen("\"request_id\":\"");
char *ka_request_id_end = strchr(ka_request_id_start, '"');
if (ka_request_id_end) {
size_t len = ka_request_id_end - ka_request_id_start;
request_id = malloc(len + 1);
if (request_id) {
memcpy(request_id, ka_request_id_start, len);
request_id[len] = '\0';
}
}
}
// Extract total_bytes
char *tb_start = strstr(msg_copy, "\"total_bytes\":");
if (tb_start) {
tb_start += strlen("\"total_bytes\":");
total_bytes = strtoull(tb_start, NULL, 10);
}
// Extract rate_bps
char *rb_start = strstr(msg_copy, "\"rate_bps\":");
if (rb_start) {
rb_start += strlen("\"rate_bps\":");
rate_bps = strtod(rb_start, NULL);
}
if (request_id) {
tunnel_t *tunnel = websocket_find_tunnel(state, request_id);
if (tunnel) {
// Update keepalive time
if (conn == tunnel->client_ws) {
tunnel->last_keepalive_from_client = time(NULL);
} else if (conn == tunnel->wsssh_ws) {
tunnel->last_keepalive_from_tool = time(NULL);
}
// Send ack
char ack_msg[256];
snprintf(ack_msg, sizeof(ack_msg), "{\"type\":\"tunnel_keepalive_ack\",\"request_id\":\"%s\"}", request_id);
if (state->debug) printf("[DEBUG - wssshd -> %s] Sending keepalive ack: %s\n", direction, ack_msg);
ws_send_frame(conn, WS_OPCODE_TEXT, ack_msg, strlen(ack_msg));
if (state->debug) {
printf("[DEBUG - %s -> wssshd] Tunnel keepalive for %s: total_bytes=%llu, rate_bps=%.2f\n", direction, request_id, total_bytes, rate_bps);
}
} else {
if (state->debug) {
printf("[DEBUG] Unhandled message type in: %s\n", msg_copy);
printf("[DEBUG - %s -> wssshd] Keepalive for unknown tunnel %s\n", direction, request_id);
}
}
free(request_id);
}
} else {
if (state->debug) {
printf("[DEBUG - %s -> wssshd] Unhandled message type: %s\n", direction, msg_copy);
}
}
// TODO: Handle other message types with similar string parsing
......@@ -726,6 +860,23 @@ static void *client_handler_thread(void *arg) {
printf("WebSocket connection established\n");
// Determine connection direction for debug messages
const char *direction = "unknown";
for (size_t i = 0; i < state->clients_count; i++) {
if (state->clients[i].websocket == conn) {
direction = "wssshc";
break;
}
}
if (strcmp(direction, "unknown") == 0) {
for (size_t i = 0; i < state->tunnels_count; i++) {
if (state->tunnels[i]->wsssh_ws == conn) {
direction = "wsssht";
break;
}
}
}
// Handle WebSocket messages
while (server_running && conn->state == WS_STATE_OPEN) {
uint8_t opcode;
......@@ -734,32 +885,38 @@ static void *client_handler_thread(void *arg) {
if (ws_receive_frame(conn, &opcode, &data, &len)) {
if (state->debug) {
printf("[DEBUG] Received WebSocket frame: opcode=%d, len=%zu\n", opcode, len);
printf("[DEBUG - %s -> wssshd] Received WebSocket frame: opcode=%d, len=%zu\n", direction, opcode, len);
}
if (opcode == WS_OPCODE_TEXT && len > 0) {
// Handle text message
char *message = (char *)data;
message[len] = '\0'; // Null terminate
// Update direction if register or tunnel_request message
if (strstr(message, "\"type\":\"register\"") || strstr(message, "\"type\": \"register\"")) {
direction = "wssshc";
} else if (strstr(message, "\"type\":\"tunnel_request\"") || strstr(message, "\"type\": \"tunnel_request\"")) {
direction = "wsssht";
}
if (state->debug) {
printf("[DEBUG] Received message: %s\n", message);
printf("[DEBUG - %s -> wssshd] Received message: %s\n", direction, message);
}
websocket_handle_message(state, conn, message, len);
} else if (opcode == WS_OPCODE_CLOSE) {
// Handle close frame
if (state->debug) {
printf("[DEBUG] Received close frame\n");
printf("[DEBUG - %s -> wssshd] Received close frame\n", direction);
}
conn->state = WS_STATE_CLOSED;
} else if (opcode == WS_OPCODE_PING) {
// Respond with pong
if (state->debug) {
printf("[DEBUG] Received ping, sending pong\n");
printf("[DEBUG - %s -> wssshd] Received ping, sending pong\n", direction);
}
ws_send_frame(conn, WS_OPCODE_PONG, data, len);
} else {
if (state->debug) {
printf("[DEBUG] Received unhandled opcode: %d\n", opcode);
printf("[DEBUG - %s -> wssshd] Received unhandled opcode: %d\n", direction, opcode);
}
}
......@@ -767,7 +924,7 @@ static void *client_handler_thread(void *arg) {
} else {
// Connection error
if (state->debug) {
printf("[DEBUG] WebSocket frame receive failed\n");
printf("[DEBUG - %s -> wssshd] WebSocket frame receive failed\n", direction);
}
break;
}
......
......@@ -144,6 +144,13 @@ static bool ws_parse_frame_header(const uint8_t *buffer, size_t len, ws_frame_he
header->payload_len = payload_len;
}
// Validate payload length to prevent memory exhaustion attacks
// Limit to 10MB to prevent excessive memory allocation
const size_t MAX_PAYLOAD_SIZE = 10 * 1024 * 1024; // 10MB
if (header->payload_len > MAX_PAYLOAD_SIZE) {
return false; // Reject frames with excessively large payloads
}
if (header->masked) {
if (len < header_len + 4) return false;
memcpy(header->masking_key, buffer + header_len, 4);
......@@ -217,12 +224,24 @@ bool ws_send_frame(ws_connection_t *conn, uint8_t opcode, const void *data, size
return false;
}
uint8_t frame[14 + len]; // Max header size + data
size_t frame_len = 0;
size_t header_len = 2;
if (len >= 126) {
if (len < 65536) {
header_len = 4;
} else {
header_len = 10;
}
}
size_t frame_len = header_len + len;
uint8_t *frame = malloc(frame_len);
if (!frame) {
printf("[DEBUG] ws_send_frame: Failed to allocate frame buffer\n");
return false;
}
// Frame header
frame[0] = 0x80 | opcode; // FIN bit set
frame_len = 2;
if (len < 126) {
frame[1] = len;
......@@ -230,7 +249,6 @@ bool ws_send_frame(ws_connection_t *conn, uint8_t opcode, const void *data, size
frame[1] = 126;
frame[2] = (len >> 8) & 0xFF;
frame[3] = len & 0xFF;
frame_len = 4;
} else {
frame[1] = 127;
// Only support 32-bit lengths for simplicity
......@@ -239,21 +257,30 @@ bool ws_send_frame(ws_connection_t *conn, uint8_t opcode, const void *data, size
frame[7] = (len >> 16) & 0xFF;
frame[8] = (len >> 8) & 0xFF;
frame[9] = len & 0xFF;
frame_len = 10;
}
// Copy data
if (len > 0) {
memcpy(frame + frame_len, data, len);
frame_len += len;
memcpy(frame + header_len, data, len);
}
printf("[DEBUG] ws_send_frame: Sending frame with opcode=%d, len=%zu, frame_len=%zu\n", opcode, len, frame_len);
// Send frame
int bytes_written = SSL_write(conn->ssl, frame, frame_len);
printf("[DEBUG] ws_send_frame: SSL_write returned %d (expected %zu)\n", bytes_written, frame_len);
return bytes_written == (int)frame_len;
// Send frame with partial write handling
int total_written = 0;
while (total_written < (int)frame_len) {
int to_write = frame_len - total_written;
int written = SSL_write(conn->ssl, frame + total_written, to_write);
if (written <= 0) {
printf("[DEBUG] ws_send_frame: SSL_write failed at offset %d\n", total_written);
free(frame);
return false;
}
total_written += written;
}
printf("[DEBUG] ws_send_frame: SSL_write returned %d (expected %zu)\n", total_written, frame_len);
free(frame);
return total_written == (int)frame_len;
}
// Receive WebSocket frame
......@@ -292,11 +319,15 @@ bool ws_receive_frame(ws_connection_t *conn, uint8_t *opcode, void **data, size_
// Read additional header bytes if needed
if (header_size > 2) {
bytes_read = SSL_read(conn->ssl, header + 2, header_size - 2);
if (bytes_read != (int)(header_size - 2)) {
printf("[DEBUG] ws_receive_frame: Failed to read extended header, expected %zu bytes, got %d\n", header_size - 2, bytes_read);
int total_read = 0;
while (total_read < (int)(header_size - 2)) {
bytes_read = SSL_read(conn->ssl, header + 2 + total_read, header_size - 2 - total_read);
if (bytes_read <= 0) {
printf("[DEBUG] ws_receive_frame: Failed to read extended header\n");
return false;
}
total_read += bytes_read;
}
}
ws_frame_header_t frame_header;
......@@ -311,11 +342,15 @@ bool ws_receive_frame(ws_connection_t *conn, uint8_t *opcode, void **data, size_
// Read payload
if (frame_header.payload_len > 0) {
bytes_read = SSL_read(conn->ssl, *data, frame_header.payload_len);
if (bytes_read != (int)frame_header.payload_len) {
int total_read = 0;
while (total_read < (int)frame_header.payload_len) {
bytes_read = SSL_read(conn->ssl, (char *)*data + total_read, frame_header.payload_len - total_read);
if (bytes_read <= 0) {
free(*data);
return false;
}
total_read += bytes_read;
}
// Unmask if needed
if (frame_header.masked) {
......
......@@ -307,7 +307,7 @@ uint32_t retransmission_buffer_get_next_frame_id(retransmission_buffer_t *buffer
}
// Reliable tunnel data message with frame_id and checksum
int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const unsigned char *data, size_t data_len, retransmission_buffer_t *buffer, int debug) {
int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const unsigned char *data, size_t data_len, retransmission_buffer_t *buffer, wsssh_encoding_t encoding, int debug) {
if (!buffer) return 0;
// Calculate checksum
......@@ -316,6 +316,17 @@ int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const un
// Get next frame ID
uint32_t frame_id = retransmission_buffer_get_next_frame_id(buffer);
char *encoded_data = NULL;
size_t encoded_len = 0;
if (encoding == ENCODING_BINARY) {
// Send as binary WebSocket frame
// For now, fall back to base64 for reliable transmission
// TODO: Implement binary WebSocket frames for reliable transmission
encoding = ENCODING_BASE64;
}
if (encoding == ENCODING_BASE64) {
// Base64 encode the binary data
size_t b64_len = ((data_len + 2) / 3) * 4 + 1;
char *b64_data = malloc(b64_len);
......@@ -350,21 +361,44 @@ int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const un
}
b64_data[j] = '\0';
encoded_data = b64_data;
encoded_len = b64_len;
} else if (encoding == ENCODING_HEX) {
// Hex encode the binary data
size_t hex_len = data_len * 2 + 1;
char *hex_data = malloc(hex_len);
if (!hex_data) {
if (debug) {
printf("[DEBUG] Failed to allocate memory for hex encoding (%zu bytes)\n", hex_len);
fflush(stdout);
}
return 0;
}
for (size_t i = 0; i < data_len; i++) {
sprintf(hex_data + i * 2, "%02x", data[i]);
}
hex_data[data_len * 2] = '\0';
encoded_data = hex_data;
encoded_len = hex_len;
}
// Create JSON message with frame_id and checksum
size_t msg_size = strlen("{\"type\":\"tunnel_data\",\"request_id\":\"\",\"frame_id\":,\"checksum\":,\"data\":\"\"}") +
strlen(request_id) + 20 + 10 + b64_len + 1; // frame_id (10) + checksum (10)
strlen(request_id) + 20 + 10 + encoded_len + 1; // frame_id (10) + checksum (10)
char *message = malloc(msg_size);
if (!message) {
if (debug) {
printf("[DEBUG] Failed to allocate memory for reliable tunnel_data message (%zu bytes)\n", msg_size);
fflush(stdout);
}
free(b64_data);
free(encoded_data);
return 0;
}
snprintf(message, msg_size, "{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"frame_id\":%u,\"checksum\":%u,\"data\":\"%s\"}",
request_id, frame_id, checksum, b64_data);
request_id, frame_id, checksum, encoded_data);
// Add to retransmission buffer
if (!retransmission_buffer_add(buffer, frame_id, message, strlen(message))) {
......@@ -373,7 +407,7 @@ int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const un
fflush(stdout);
}
free(message);
free(b64_data);
free(encoded_data);
return 0;
}
......@@ -383,12 +417,12 @@ int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const un
fflush(stdout);
}
free(message);
free(b64_data);
free(encoded_data);
return 0;
}
free(message);
free(b64_data);
free(encoded_data);
return 1;
}
......
......@@ -21,6 +21,7 @@
#define DATA_MESSAGES_H
#include <openssl/ssl.h>
#include "wssshlib.h"
// Retransmission buffer entry
typedef struct {
......@@ -47,7 +48,7 @@ typedef struct {
// Function declarations for data channel messages
int send_tunnel_data_message(SSL *ssl, const char *request_id, const char *data_hex, int debug);
int send_tunnel_data_binary_message(SSL *ssl, const char *request_id, const unsigned char *data, size_t data_len, int debug);
int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const unsigned char *data, size_t data_len, retransmission_buffer_t *buffer, int debug);
int send_tunnel_data_reliable_message(SSL *ssl, const char *request_id, const unsigned char *data, size_t data_len, retransmission_buffer_t *buffer, wsssh_encoding_t encoding, int debug);
int send_tunnel_response_message(SSL *ssl, const char *request_id, const char *data_hex, int debug);
int send_tunnel_ack_message(SSL *ssl, const char *request_id, uint32_t frame_id, int debug);
int send_tunnel_ko_message(SSL *ssl, const char *request_id, uint32_t frame_id, int debug);
......
......@@ -246,8 +246,9 @@ int run_script_mode(wsssh_config_t *config, const char *client_id, const char *w
time(NULL));
// Establish tunnel
int listen_sock = setup_tunnel(wssshd_host, wssshd_port, client_id, config->local_port ? atoi(config->local_port) : find_available_port(),
config->debug, 0, config->tunnel_host, config->encoding);
tunnel_setup_result_t setup_result = setup_tunnel(wssshd_host, wssshd_port, client_id, config->local_port ? atoi(config->local_port) : find_available_port(),
config->debug, 0, config->tunnel_host, config->encoding, 1);
int listen_sock = setup_result.listen_sock;
if (listen_sock < 0) {
fprintf(stderr, "{\"type\":\"script_error\",\"message\":\"Failed to establish tunnel\",\"timestamp\":%ld}\n", time(NULL));
......@@ -1806,6 +1807,14 @@ int run_daemon_mode(wsssh_config_t *config, const char *client_id, const char *w
return 1;
}
// Set SO_REUSEADDR to allow immediate reuse of the port
int opt = 1;
if (setsockopt(listen_sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
perror("setsockopt on listen_sock failed");
close(listen_sock);
return 1;
}
struct sockaddr_in local_addr;
memset(&local_addr, 0, sizeof(local_addr));
local_addr.sin_family = AF_INET;
......
......@@ -46,7 +46,8 @@ void *run_tunnel_thread(void *arg) {
tunnel_thread_args_t *args = (tunnel_thread_args_t *)arg;
// Establish the tunnel for this connection
int tunnel_sock = setup_tunnel(args->wssshd_host, args->wssshd_port, args->client_id, 0, args->config->debug, 0, args->tunnel_host, args->config->encoding);
tunnel_setup_result_t setup_result = setup_tunnel(args->wssshd_host, args->wssshd_port, args->client_id, 0, args->config->debug, 0, args->tunnel_host, args->config->encoding, 1);
int tunnel_sock = setup_result.listen_sock;
if (tunnel_sock < 0) {
fprintf(stderr, "Failed to establish tunnel for connection\n");
close(args->accepted_sock);
......
......@@ -427,6 +427,7 @@ void handle_tunnel_request_with_enc(SSL *ssl, const char *request_id, int debug,
new_tunnel->incoming_buffer = NULL; // wssshc doesn't need incoming buffer
new_tunnel->server_version_sent = 0; // Not used for raw TCP
new_tunnel->bin = (encoding == ENCODING_BINARY); // Set binary mode based on encoding
new_tunnel->encoding = encoding;
// Initialize retransmission buffer for reliable transmission
new_tunnel->retransmission_buffer = retransmission_buffer_init();
......@@ -605,6 +606,7 @@ void handle_tunnel_request_with_service_and_enc(SSL *ssl, const char *request_id
new_tunnel->incoming_buffer = NULL; // wssshc doesn't need incoming buffer
new_tunnel->server_version_sent = 0; // Not used for raw TCP/UDP
new_tunnel->bin = (encoding == ENCODING_BINARY); // Set binary mode based on encoding
new_tunnel->encoding = encoding;
// Initialize retransmission buffer for reliable transmission
new_tunnel->retransmission_buffer = retransmission_buffer_init();
......@@ -894,7 +896,7 @@ void *forward_tcp_to_ws(void *arg) {
// Use reliable transmission
if (tunnel->retransmission_buffer) {
if (!send_tunnel_data_reliable_message(ssl, request_id, (unsigned char *)buffer, bytes_read, tunnel->retransmission_buffer, debug)) {
if (!send_tunnel_data_reliable_message(ssl, request_id, (unsigned char *)buffer, bytes_read, tunnel->retransmission_buffer, tunnel->encoding, debug)) {
break;
}
} else {
......@@ -1182,22 +1184,36 @@ void handle_tunnel_data(SSL *ssl, const char *request_id, const char *data_hex,
tunnel->total_bytes_received += data_len;
tunnel->bytes_last_period += data_len;
} else {
// Legacy mode: try base64 decoding first (for compatibility with Python wssshd)
// Check if data looks like base64 (contains only valid base64 chars)
int looks_like_base64 = 1;
for (size_t i = 0; i < hex_len; i++) {
char c = data_hex[i];
if (!(isalnum(c) || c == '+' || c == '/' || c == '=')) {
looks_like_base64 = 0;
break;
// Decode based on tunnel encoding
wsssh_encoding_t encoding = tunnel->encoding;
if (encoding == ENCODING_BINARY) {
// Binary mode: use data as-is (no decoding)
data_len = hex_len;
data = malloc(data_len);
if (!data) {
if (debug) {
printf("[DEBUG] Failed to allocate memory for %zu bytes of binary data\n", data_len);
fflush(stdout);
}
pthread_mutex_unlock(&tunnel_mutex);
return;
}
memcpy(data, data_hex, data_len);
if (looks_like_base64 && hex_len % 4 == 0) {
// Try base64 decoding
if (debug) {
printf("[DEBUG] Using %zu bytes of binary data\n", data_len);
fflush(stdout);
}
// Update statistics
tunnel->total_bytes_received += data_len;
tunnel->bytes_last_period += data_len;
} else if (encoding == ENCODING_BASE64) {
// Base64 decoding
data_len = (hex_len * 3) / 4;
if (data_hex[hex_len - 1] == '=') data_len--;
if (data_hex[hex_len - 2] == '=') data_len--;
if (hex_len > 0 && data_hex[hex_len - 1] == '=') data_len--;
if (hex_len > 1 && data_hex[hex_len - 2] == '=') data_len--;
data = malloc(data_len);
if (!data) {
......@@ -1239,9 +1255,13 @@ void handle_tunnel_data(SSL *ssl, const char *request_id, const char *data_hex,
}
if (!decode_success) {
// Base64 decoding failed, free data and fall back to hex
free(data);
data = NULL;
if (debug) {
printf("[DEBUG] Base64 decoding failed\n");
fflush(stdout);
}
pthread_mutex_unlock(&tunnel_mutex);
return;
} else {
data_len = j; // Actual decoded length
......@@ -1254,14 +1274,11 @@ void handle_tunnel_data(SSL *ssl, const char *request_id, const char *data_hex,
tunnel->total_bytes_received += data_len;
tunnel->bytes_last_period += data_len;
}
}
// If base64 failed or wasn't attempted, fall back to hex decoding
if (!data) {
// Fall back to hex decoding
} else if (encoding == ENCODING_HEX) {
// Hex decoding
if (hex_len % 2 != 0) {
if (debug) {
printf("[DEBUG] Invalid hex data length: %zu (must be even), data: %.50s...\n", hex_len, data_hex);
printf("[DEBUG] Invalid hex data length: %zu (must be even)\n", hex_len);
fflush(stdout);
}
pthread_mutex_unlock(&tunnel_mutex);
......@@ -1310,6 +1327,10 @@ void handle_tunnel_data(SSL *ssl, const char *request_id, const char *data_hex,
if (!hex_decode_success) {
free(data);
if (debug) {
printf("[DEBUG] Hex decoding failed\n");
fflush(stdout);
}
pthread_mutex_unlock(&tunnel_mutex);
return;
}
......@@ -1322,6 +1343,13 @@ void handle_tunnel_data(SSL *ssl, const char *request_id, const char *data_hex,
// Update statistics
tunnel->total_bytes_received += data_len;
tunnel->bytes_last_period += data_len;
} else {
if (debug) {
printf("[DEBUG] Unknown encoding type: %d\n", encoding);
fflush(stdout);
}
pthread_mutex_unlock(&tunnel_mutex);
return;
}
}
......@@ -1769,7 +1797,7 @@ int reconnect_websocket(tunnel_t *tunnel, const char *wssshd_host, int wssshd_po
return 0;
}
int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id, int local_port, int debug, int use_buffer, const char *tunnel_host, wsssh_encoding_t encoding) {
tunnel_setup_result_t setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id, int local_port, int debug, int use_buffer, const char *tunnel_host, wsssh_encoding_t encoding, int send_tunnel_request_immediately) {
struct sockaddr_in server_addr;
struct hostent *he;
int sock;
......@@ -1785,13 +1813,15 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
// Resolve hostname
if ((he = gethostbyname(wssshd_host)) == NULL) {
herror("gethostbyname");
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
// Create socket
if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
perror("Socket creation failed");
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
memset(&server_addr, 0, sizeof(server_addr));
......@@ -1803,21 +1833,24 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) {
perror("Connection failed");
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
// Create SSL context and connection
ssl_ctx = create_ssl_context();
if (!ssl_ctx) {
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
ssl = create_ssl_connection(ssl_ctx, sock, debug);
if (!ssl) {
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
// Perform WebSocket handshake
......@@ -1830,13 +1863,15 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
if (debug) {
printf("[DEBUG] WebSocket handshake successful\n");
fflush(stdout);
}
if (send_tunnel_request_immediately) {
// Send detailed tunnel request with transport information
char *expanded_tunnel = expand_transport_list("any", 0); // Data channel transports
char *expanded_tunnel_control = expand_transport_list("any", 1); // Control channel transports
......@@ -1853,7 +1888,8 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
free(expanded_tunnel);
......@@ -1874,7 +1910,8 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
if (debug) {
......@@ -1908,7 +1945,8 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
// Parse WebSocket frame
......@@ -1927,7 +1965,8 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
// Null terminate payload
......@@ -1953,21 +1992,26 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
if (debug) {
printf("[DEBUG] Tunnel established, local port: %d\n", local_port);
}
}
tunnel_t *new_tunnel = NULL;
if (send_tunnel_request_immediately) {
// Create tunnel structure
tunnel_t *new_tunnel = malloc(sizeof(tunnel_t));
new_tunnel = malloc(sizeof(tunnel_t));
if (!new_tunnel) {
perror("Memory allocation failed");
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
if (use_buffer) {
......@@ -1978,7 +2022,8 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
} else {
new_tunnel->outgoing_buffer = NULL;
......@@ -1993,7 +2038,8 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
strcpy(new_tunnel->request_id, request_id);
......@@ -2016,7 +2062,8 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
// Add the new tunnel to the array for multiple tunnel support
......@@ -2029,25 +2076,47 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
// For backward compatibility with wsssh/wsscp that use active_tunnel global
active_tunnel = new_tunnel;
pthread_mutex_unlock(&tunnel_mutex);
}
// Start listening on local port
int listen_sock = socket(AF_INET, SOCK_STREAM, 0);
if (listen_sock < 0) {
perror("Local socket creation failed");
if (send_tunnel_request_immediately && new_tunnel) {
if (use_buffer) frame_buffer_free(new_tunnel->outgoing_buffer);
frame_buffer_free(new_tunnel->incoming_buffer);
free(new_tunnel);
}
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
// Set SO_REUSEADDR to allow immediate reuse of the port
int opt = 1;
if (setsockopt(listen_sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
perror("setsockopt on listen_sock failed");
close(listen_sock);
if (send_tunnel_request_immediately && new_tunnel) {
if (use_buffer) frame_buffer_free(new_tunnel->outgoing_buffer);
frame_buffer_free(new_tunnel->incoming_buffer);
free(new_tunnel);
}
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
struct sockaddr_in local_addr;
......@@ -2082,25 +2151,31 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (bind(listen_sock, (struct sockaddr *)&local_addr, sizeof(local_addr)) < 0) {
perror("Local bind failed");
close(listen_sock);
if (send_tunnel_request_immediately && new_tunnel) {
if (use_buffer) frame_buffer_free(new_tunnel->outgoing_buffer);
frame_buffer_free(new_tunnel->incoming_buffer);
free(new_tunnel);
}
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
if (listen(listen_sock, 1) < 0) {
perror("Local listen failed");
close(listen_sock);
if (send_tunnel_request_immediately && new_tunnel) {
if (use_buffer) frame_buffer_free(new_tunnel->outgoing_buffer);
frame_buffer_free(new_tunnel->incoming_buffer);
free(new_tunnel);
}
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return -1;
tunnel_setup_result_t result = {-1, NULL, NULL, ""};
return result;
}
if (debug) {
......@@ -2108,9 +2183,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
fflush(stdout);
}
if (send_tunnel_request_immediately) {
// Clean up SSL context
SSL_CTX_free(ssl_ctx);
}
// Return success - tunnel is set up and listening
return listen_sock;
tunnel_setup_result_t result = {listen_sock, ssl, ssl_ctx, ""};
strcpy(result.request_id, request_id);
return result;
}
\ No newline at end of file
......@@ -52,6 +52,7 @@ typedef struct {
unsigned long long bytes_last_period; // Bytes transferred in last 30-second period
time_t last_stats_reset; // When we last reset the period stats
int bin; // Binary mode flag - if true, transmit data as binary instead of hex
wsssh_encoding_t encoding; // Data encoding mode
// Reliable transmission
retransmission_buffer_t *retransmission_buffer; // Buffer for reliable message retransmission
......@@ -113,10 +114,20 @@ void send_tunnel_keepalive(SSL *ssl, tunnel_t *tunnel, int debug);
void check_keepalive_timeouts(int debug);
void cleanup_tunnel(int debug);
int reconnect_websocket(tunnel_t *tunnel, const char *wssshd_host, int wssshd_port, const char *client_id, const char *request_id, int debug);
int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id, int local_port, int debug, int use_buffer, const char *tunnel_host, wsssh_encoding_t encoding);
// CPU affinity functions
void init_cpu_affinity(void);
void set_thread_cpu_affinity(pthread_t thread);
// Structure for tunnel setup result
typedef struct {
int listen_sock;
SSL *ssl;
SSL_CTX *ssl_ctx;
char request_id[37];
} tunnel_setup_result_t;
// Function declarations
tunnel_setup_result_t setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id, int local_port, int debug, int use_buffer, const char *tunnel_host, wsssh_encoding_t encoding, int send_tunnel_request_immediately);
#endif // TUNNEL_H
\ No newline at end of file
......@@ -72,7 +72,7 @@ void print_usage(const char *program_name) {
fprintf(stderr, " ETH: 0xdA6dAb526515b5cb556d20269207D43fcc760E51\n");
}
int parse_wsscp_args(int argc, char *argv[], wsscp_config_t *config) {
int parse_wsscp_args(int argc, char *argv[], wsscp_wrapper_config_t *config) {
static struct option long_options[] = {
{"clientid", required_argument, 0, 'c'},
{"wssshd-host", required_argument, 0, 'H'},
......@@ -124,7 +124,7 @@ int parse_wsscp_args(int argc, char *argv[], wsscp_config_t *config) {
return 1;
}
int parse_target_string(const char *target, wsscp_config_t *config) {
int parse_target_string(const char *target, wsscp_wrapper_config_t *config) {
if (!target) return 0;
char *target_copy = strdup(target);
......@@ -171,7 +171,7 @@ int parse_target_string(const char *target, wsscp_config_t *config) {
return 1;
}
int parse_scp_port_from_args(wsscp_config_t *config) {
int parse_scp_port_from_args(wsscp_wrapper_config_t *config) {
if (!config->remaining_argv || config->remaining_argc < 2) {
return 0;
}
......@@ -221,7 +221,7 @@ char *find_wsssht_path() {
return NULL;
}
char *build_proxy_command(wsscp_config_t *config) {
char *build_proxy_command(wsscp_wrapper_config_t *config) {
char *wsssht_path = find_wsssht_path();
if (!wsssht_path) {
fprintf(stderr, "Error: wsssht not found in PATH or in the same directory as wsscp\n");
......@@ -325,7 +325,7 @@ int execute_scp_command(char *command, int debug) {
int main(int argc, char *argv[]) {
// Initialize configuration
wsscp_config_t config = {
wsscp_wrapper_config_t config = {
.client_id = NULL,
.wssshd_host = NULL,
.wssshd_port = 9898,
......
......@@ -51,13 +51,13 @@ typedef struct {
char *destination;
int remaining_argc;
char **remaining_argv;
} wsscp_config_t;
} wsscp_wrapper_config_t;
// Function declarations
int parse_wsscp_args(int argc, char *argv[], wsscp_config_t *config);
int parse_target_string(const char *target, wsscp_config_t *config);
int parse_scp_port_from_args(wsscp_config_t *config);
char *build_proxy_command(wsscp_config_t *config);
int parse_wsscp_args(int argc, char *argv[], wsscp_wrapper_config_t *config);
int parse_target_string(const char *target, wsscp_wrapper_config_t *config);
int parse_scp_port_from_args(wsscp_wrapper_config_t *config);
char *build_proxy_command(wsscp_wrapper_config_t *config);
char *find_wsssht_path();
int execute_scp_command(char *command, int debug);
void print_usage(const char *program_name);
......
......@@ -51,9 +51,9 @@ void print_wsssh_usage(const char *program_name) {
fprintf(stderr, " ETH: 0xdA6dAb526515b5cb556d20269207D43fcc760E51\n");
}
int parse_wsssh_args(int argc, char *argv[], wsssh_config_t *config) {
int parse_wsssh_args(int argc, char *argv[], wsssh_wrapper_config_t *config) {
// Initialize config with defaults
memset(config, 0, sizeof(wsssh_config_t));
memset(config, 0, sizeof(wsssh_wrapper_config_t));
config->wssshd_port = 9898;
// Parse options
......@@ -101,7 +101,7 @@ int parse_wsssh_args(int argc, char *argv[], wsssh_config_t *config) {
return 1;
}
int parse_target_string(const char *target, wsssh_config_t *config) {
int parse_target_string(const char *target, wsssh_wrapper_config_t *config) {
if (!target) return 0;
char *target_copy = strdup(target);
......@@ -141,7 +141,7 @@ int parse_target_string(const char *target, wsssh_config_t *config) {
return 1;
}
int parse_ssh_port_from_args(wsssh_config_t *config) {
int parse_ssh_port_from_args(wsssh_wrapper_config_t *config) {
if (!config->remaining_argv || config->remaining_argc < 2) {
return 0;
}
......@@ -191,7 +191,7 @@ char *find_wsssht_path() {
return NULL; // wsssht not found
}
char *build_proxy_command(wsssh_config_t *config) {
char *build_proxy_command(wsssh_wrapper_config_t *config) {
char *wsssht_path = find_wsssht_path();
if (!wsssht_path) {
fprintf(stderr, "Error: wsssht not found in PATH or in the same directory as wsssh\n");
......@@ -258,7 +258,7 @@ char *build_proxy_command(wsssh_config_t *config) {
return cmd;
}
char *build_ssh_command(wsssh_config_t *config, const char *proxy_command) {
char *build_ssh_command(wsssh_wrapper_config_t *config, const char *proxy_command) {
if (!config->user || !config->client_id) {
return NULL;
}
......@@ -318,7 +318,7 @@ int execute_ssh_command(const char *ssh_command, int debug) {
}
int main(int argc, char *argv[]) {
wsssh_config_t config;
wsssh_wrapper_config_t config;
// Parse arguments
if (!parse_wsssh_args(argc, argv, &config)) {
......
......@@ -40,14 +40,14 @@ typedef struct {
char *ssh_string;
int remaining_argc;
char **remaining_argv;
} wsssh_config_t;
} wsssh_wrapper_config_t;
// Function declarations
void print_wsssh_usage(const char *program_name);
int parse_wsssh_args(int argc, char *argv[], wsssh_config_t *config);
int parse_target_string(const char *target, wsssh_config_t *config);
char *build_proxy_command(wsssh_config_t *config);
char *build_ssh_command(wsssh_config_t *config, const char *proxy_command);
int parse_wsssh_args(int argc, char *argv[], wsssh_wrapper_config_t *config);
int parse_target_string(const char *target, wsssh_wrapper_config_t *config);
char *build_proxy_command(wsssh_wrapper_config_t *config);
char *build_ssh_command(wsssh_wrapper_config_t *config, const char *proxy_command);
int execute_ssh_command(const char *ssh_command, int debug);
#endif // WSSH_H
\ No newline at end of file
......@@ -303,6 +303,7 @@ int main(int argc, char *argv[]) {
int listen_sock = -1;
int setup_attempts = 0;
int max_setup_attempts = 3;
tunnel_setup_result_t setup_result = {0};
while (setup_attempts < max_setup_attempts && listen_sock < 0) {
if (config.debug && setup_attempts > 0) {
......@@ -310,7 +311,8 @@ int main(int argc, char *argv[]) {
fflush(stdout);
}
listen_sock = setup_tunnel(wssshd_host, wssshd_port, client_id, local_port, config.debug, 0, config.tunnel_host, config.encoding);
setup_result = setup_tunnel(wssshd_host, wssshd_port, client_id, local_port, config.debug, 0, config.tunnel_host, config.encoding, 0);
listen_sock = setup_result.listen_sock;
if (listen_sock < 0) {
setup_attempts++;
......@@ -376,10 +378,8 @@ int main(int argc, char *argv[]) {
if (accepted_sock < 0) {
perror("Local accept failed");
close(listen_sock);
pthread_mutex_lock(&tunnel_mutex);
free(active_tunnel);
active_tunnel = NULL;
pthread_mutex_unlock(&tunnel_mutex);
SSL_free(setup_result.ssl);
if (setup_result.ssl_ctx) SSL_CTX_free(setup_result.ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
......@@ -387,26 +387,158 @@ int main(int argc, char *argv[]) {
close(listen_sock); // No longer needed
// Set the accepted socket with mutex protection
pthread_mutex_lock(&tunnel_mutex);
active_tunnel->local_sock = accepted_sock;
// Send any buffered data to the client immediately
if (active_tunnel->incoming_buffer && active_tunnel->incoming_buffer->used > 0) {
// Now send the tunnel request
if (config.debug) {
printf("[DEBUG - Tunnel] Sending %zu bytes of buffered server response to client\n", active_tunnel->incoming_buffer->used);
printf("[DEBUG - Tunnel] Local connection accepted, sending tunnel request...\n");
fflush(stdout);
}
ssize_t sent = send(accepted_sock, active_tunnel->incoming_buffer->buffer, active_tunnel->incoming_buffer->used, 0);
if (sent > 0) {
frame_buffer_consume(active_tunnel->incoming_buffer, sent);
// Send tunnel request
char *expanded_tunnel = expand_transport_list("any", 0); // Data channel transports
char *expanded_tunnel_control = expand_transport_list("any", 1); // Control channel transports
// Select best transport based on weight (lowest weight = highest priority)
char *best_tunnel = select_best_transport(expanded_tunnel);
char *best_tunnel_control = select_best_transport(expanded_tunnel_control);
if (!send_tunnel_request_message_with_enc(setup_result.ssl, client_id, setup_result.request_id, best_tunnel ? best_tunnel : expanded_tunnel, best_tunnel_control ? best_tunnel_control : expanded_tunnel_control, "ssh", config.encoding)) {
free(expanded_tunnel);
free(expanded_tunnel_control);
if (best_tunnel) free(best_tunnel);
if (best_tunnel_control) free(best_tunnel_control);
close(accepted_sock);
SSL_free(setup_result.ssl);
// SSL_CTX_free(ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
free(expanded_tunnel);
free(expanded_tunnel_control);
if (best_tunnel) free(best_tunnel);
if (best_tunnel_control) free(best_tunnel_control);
if (config.debug) {
printf("[DEBUG] Sent %zd bytes of buffered server response to client\n", sent);
fflush(stdout);
printf("[DEBUG] Tunnel request sent for client: %s, request_id: %s\n", client_id, setup_result.request_id);
}
// Read acknowledgment
char buffer[BUFFER_SIZE];
int bytes_read = SSL_read(setup_result.ssl, buffer, sizeof(buffer));
if (bytes_read <= 0) {
if (config.debug) {
printf("[DEBUG] No acknowledgment received\n");
}
close(accepted_sock);
SSL_free(setup_result.ssl);
// SSL_CTX_free(ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
// Parse WebSocket frame and check for tunnel_ack
char *payload;
int payload_len;
if (!parse_websocket_frame(buffer, bytes_read, &payload, &payload_len)) {
fprintf(stderr, "Failed to parse WebSocket frame\n");
close(accepted_sock);
SSL_free(setup_result.ssl);
// SSL_CTX_free(ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
payload[payload_len] = '\0';
if (strstr(payload, "tunnel_ack") == NULL) {
fprintf(stderr, "Tunnel request denied or failed: %s\n", payload);
close(accepted_sock);
SSL_free(setup_result.ssl);
// SSL_CTX_free(ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
if (config.debug) {
printf("[DEBUG] Tunnel established, creating tunnel structure...\n");
}
// Create tunnel structure
tunnel_t *new_tunnel = malloc(sizeof(tunnel_t));
if (!new_tunnel) {
perror("Memory allocation failed");
close(accepted_sock);
SSL_free(setup_result.ssl);
// SSL_CTX_free(ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
new_tunnel->outgoing_buffer = NULL;
new_tunnel->incoming_buffer = frame_buffer_init();
if (!new_tunnel->incoming_buffer) {
perror("Failed to initialize incoming buffer");
free(new_tunnel);
close(accepted_sock);
SSL_free(setup_result.ssl);
if (setup_result.ssl_ctx) SSL_CTX_free(setup_result.ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
strcpy(new_tunnel->request_id, setup_result.request_id);
new_tunnel->sock = -1;
new_tunnel->local_sock = accepted_sock;
new_tunnel->active = 1;
new_tunnel->broken = 0;
new_tunnel->ssl = setup_result.ssl;
new_tunnel->server_version_sent = 0;
new_tunnel->bin = (config.encoding == ENCODING_BINARY);
new_tunnel->encoding = config.encoding;
new_tunnel->retransmission_buffer = retransmission_buffer_init();
if (!new_tunnel->retransmission_buffer) {
perror("Failed to initialize retransmission buffer");
frame_buffer_free(new_tunnel->incoming_buffer);
free(new_tunnel);
close(accepted_sock);
SSL_free(setup_result.ssl);
if (setup_result.ssl_ctx) SSL_CTX_free(setup_result.ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
// Initialize keep-alive statistics
time_t current_time = time(NULL);
new_tunnel->last_keepalive_sent = current_time;
new_tunnel->last_keepalive_received = current_time;
new_tunnel->total_bytes_sent = 0;
new_tunnel->total_bytes_received = 0;
new_tunnel->bytes_last_period = 0;
new_tunnel->last_stats_reset = current_time;
// Add the new tunnel to the array
pthread_mutex_lock(&tunnel_mutex);
if (!add_tunnel(new_tunnel)) {
frame_buffer_free(new_tunnel->incoming_buffer);
free(new_tunnel);
pthread_mutex_unlock(&tunnel_mutex);
close(accepted_sock);
SSL_free(setup_result.ssl);
if (setup_result.ssl_ctx) SSL_CTX_free(setup_result.ssl_ctx);
free(config.local_port);
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
active_tunnel = new_tunnel;
pthread_mutex_unlock(&tunnel_mutex);
if (config.debug) {
......@@ -441,8 +573,6 @@ int main(int argc, char *argv[]) {
pthread_detach(thread);
// Main tunnel loop - handle WebSocket messages
char buffer[BUFFER_SIZE];
int bytes_read;
fd_set readfds;
struct timeval tv;
int tunnel_broken = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment