Apply same performance optimizations to wsssh C version as wsscp

- Replace blocking recv() with select() for efficient I/O multiplexing
- Add 50ms timeouts to prevent indefinite blocking operations
- Increase buffer size from 4KB to 1MB for optimal throughput
- Add 3-second exit delay after SSH completion for clean shutdown
- Fix setup_tunnel return codes to prevent segmentation faults
- Correct memory management in argument cleanup
parent 7eb01b2a
......@@ -32,8 +32,9 @@
#include <sys/wait.h>
#include <fcntl.h>
#include <pthread.h>
#include <sys/select.h>
#define BUFFER_SIZE 4096
#define BUFFER_SIZE 1048576
#define DEFAULT_PORT 22
typedef struct {
......@@ -450,6 +451,8 @@ void *forward_tcp_to_ws(void *arg) {
int debug = args->debug;
char buffer[BUFFER_SIZE];
int bytes_read;
fd_set readfds;
struct timeval tv;
while (1) {
pthread_mutex_lock(&tunnel_mutex);
......@@ -462,39 +465,59 @@ void *forward_tcp_to_ws(void *arg) {
strcpy(request_id, active_tunnel->request_id);
pthread_mutex_unlock(&tunnel_mutex);
bytes_read = recv(sock, buffer, sizeof(buffer), 0);
if (bytes_read <= 0) {
// Use select to wait for data on local socket
FD_ZERO(&readfds);
FD_SET(sock, &readfds);
tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
int retval = select(sock + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (debug) {
printf("[DEBUG] TCP connection closed or error\n");
perror("[DEBUG] select failed");
fflush(stdout);
}
break;
} else if (retval == 0) {
// Timeout, continue loop
continue;
}
if (debug) {
printf("[DEBUG] Forwarding %d bytes from TCP to WebSocket\n", bytes_read);
fflush(stdout);
}
// Convert to hex
char hex_data[bytes_read * 2 + 1];
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]);
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (FD_ISSET(sock, &readfds)) {
bytes_read = recv(sock, buffer, sizeof(buffer), 0);
if (bytes_read <= 0) {
if (debug) {
printf("[DEBUG] TCP connection closed or error\n");
fflush(stdout);
}
break;
}
if (!send_websocket_frame(ssl, message)) {
if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n");
printf("[DEBUG] Forwarding %d bytes from TCP to WebSocket\n", bytes_read);
fflush(stdout);
}
break;
// Convert to hex
char hex_data[bytes_read * 2 + 1];
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]);
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (!send_websocket_frame(ssl, message)) {
if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n");
fflush(stdout);
}
break;
}
}
}
......@@ -560,13 +583,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
// Resolve hostname
if ((he = gethostbyname(wssshd_host)) == NULL) {
herror("gethostbyname");
return 0;
return -1;
}
// Create socket
if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
perror("Socket creation failed");
return 0;
return -1;
}
memset(&server_addr, 0, sizeof(server_addr));
......@@ -578,7 +601,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) {
perror("Connection failed");
close(sock);
return 0;
return -1;
}
// Initialize SSL
......@@ -590,7 +613,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (!ssl_ctx) {
ERR_print_errors_fp(stderr);
close(sock);
return 0;
return -1;
}
// Allow self-signed certificates
......@@ -609,7 +632,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
printf("[DEBUG] SSL connection established\n");
......@@ -626,7 +649,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
printf("[DEBUG] WebSocket handshake successful\n");
......@@ -638,7 +661,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
......@@ -654,7 +677,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
......@@ -688,7 +711,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
// Parse WebSocket frame
......@@ -707,7 +730,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
// Null terminate payload
......@@ -733,7 +756,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
......@@ -747,7 +770,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
strcpy(active_tunnel->request_id, request_id);
......@@ -764,7 +787,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
struct sockaddr_in local_addr;
......@@ -781,7 +804,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (listen(listen_sock, 1) < 0) {
......@@ -792,7 +815,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
......@@ -1034,100 +1057,124 @@ int main(int argc, char *argv[]) {
// Main tunnel loop - handle WebSocket messages
char buffer[BUFFER_SIZE];
int bytes_read;
fd_set readfds;
struct timeval tv;
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
while (active_tunnel && active_tunnel->active) {
bytes_read = SSL_read(active_tunnel->ssl, buffer, sizeof(buffer));
if (bytes_read <= 0) {
// Use select to wait for data on SSL socket with timeout
FD_ZERO(&readfds);
FD_SET(ssl_fd, &readfds);
tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
int retval = select(ssl_fd + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (config.debug) {
printf("[DEBUG] WebSocket connection closed\n");
perror("[DEBUG] select on SSL fd failed");
fflush(stdout);
}
break;
} else if (retval == 0) {
// Timeout, continue loop
continue;
}
// Check if it's a close frame
if (bytes_read >= 2 && (buffer[0] & 0x8F) == 0x88) {
if (config.debug) {
printf("[DEBUG] Received close frame from server\n");
fflush(stdout);
if (FD_ISSET(ssl_fd, &readfds)) {
bytes_read = SSL_read(active_tunnel->ssl, buffer, sizeof(buffer));
if (bytes_read <= 0) {
if (config.debug) {
printf("[DEBUG] WebSocket connection closed\n");
fflush(stdout);
}
break;
}
break;
}
char *payload;
int payload_len;
if (!parse_websocket_frame(buffer, bytes_read, &payload, &payload_len)) {
if (config.debug) {
printf("[DEBUG] Failed to parse WebSocket frame\n");
fflush(stdout);
// Check if it's a close frame
if (bytes_read >= 2 && (buffer[0] & 0x8F) == 0x88) {
if (config.debug) {
printf("[DEBUG] Received close frame from server\n");
fflush(stdout);
}
break;
}
continue;
}
payload[payload_len] = '\0';
char *payload;
int payload_len;
if (!parse_websocket_frame(buffer, bytes_read, &payload, &payload_len)) {
if (config.debug) {
printf("[DEBUG] Failed to parse WebSocket frame\n");
fflush(stdout);
}
continue;
}
if (config.debug) {
printf("[DEBUG] Received: %s\n", payload);
fflush(stdout);
}
payload[payload_len] = '\0';
// Handle messages
if (strstr(payload, "tunnel_data")) {
if (config.debug) {
printf("[DEBUG] Received tunnel_data message\n");
printf("[DEBUG] Received: %s\n", payload);
fflush(stdout);
}
// Extract request_id and data
char *id_start = strstr(payload, "\"request_id\"");
char *data_start = strstr(payload, "\"data\"");
if (id_start && data_start) {
char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
char *data_colon = strchr(data_start, ':');
if (data_colon) {
char *data_quote = strchr(data_colon, '"');
if (data_quote) {
data_start = data_quote + 1;
char *data_end = strchr(data_start, '"');
if (data_end) {
*data_end = '\0';
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug);
// Handle messages
if (strstr(payload, "tunnel_data")) {
if (config.debug) {
printf("[DEBUG] Received tunnel_data message\n");
fflush(stdout);
}
// Extract request_id and data
char *id_start = strstr(payload, "\"request_id\"");
char *data_start = strstr(payload, "\"data\"");
if (id_start && data_start) {
char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
char *data_colon = strchr(data_start, ':');
if (data_colon) {
char *data_quote = strchr(data_colon, '"');
if (data_quote) {
data_start = data_quote + 1;
char *data_end = strchr(data_start, '"');
if (data_end) {
*data_end = '\0';
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug);
}
}
}
}
}
}
}
}
} else if (strstr(payload, "tunnel_close")) {
if (config.debug) {
printf("[DEBUG] Received tunnel_close message\n");
fflush(stdout);
}
char *id_start = strstr(payload, "\"request_id\"");
if (id_start) {
char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
handle_tunnel_close(active_tunnel->ssl, id_start, config.debug);
} else if (strstr(payload, "tunnel_close")) {
if (config.debug) {
printf("[DEBUG] Received tunnel_close message\n");
fflush(stdout);
}
char *id_start = strstr(payload, "\"request_id\"");
if (id_start) {
char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
handle_tunnel_close(active_tunnel->ssl, id_start, config.debug);
}
}
}
}
}
} else {
if (config.debug) {
printf("[DEBUG] Received unknown message type: %s\n", payload);
fflush(stdout);
} else {
if (config.debug) {
printf("[DEBUG] Received unknown message type: %s\n", payload);
fflush(stdout);
}
}
}
}
......@@ -1146,6 +1193,13 @@ int main(int argc, char *argv[]) {
fflush(stdout);
}
// Wait a few seconds before exiting cleanly
if (config.debug) {
printf("[DEBUG] Waiting 3 seconds before exit...\n");
fflush(stdout);
}
sleep(3);
// Cleanup
if (active_tunnel) {
close(active_tunnel->local_sock);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment