Apply same performance optimizations to wsssh C version as wsscp

- Replace blocking recv() with select() for efficient I/O multiplexing
- Add 50ms timeouts to prevent indefinite blocking operations
- Increase buffer size from 4KB to 1MB for optimal throughput
- Add 3-second exit delay after SSH completion for clean shutdown
- Fix setup_tunnel return codes to prevent segmentation faults
- Correct memory management in argument cleanup
parent 7eb01b2a
...@@ -32,8 +32,9 @@ ...@@ -32,8 +32,9 @@
#include <sys/wait.h> #include <sys/wait.h>
#include <fcntl.h> #include <fcntl.h>
#include <pthread.h> #include <pthread.h>
#include <sys/select.h>
#define BUFFER_SIZE 4096 #define BUFFER_SIZE 1048576
#define DEFAULT_PORT 22 #define DEFAULT_PORT 22
typedef struct { typedef struct {
...@@ -450,6 +451,8 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -450,6 +451,8 @@ void *forward_tcp_to_ws(void *arg) {
int debug = args->debug; int debug = args->debug;
char buffer[BUFFER_SIZE]; char buffer[BUFFER_SIZE];
int bytes_read; int bytes_read;
fd_set readfds;
struct timeval tv;
while (1) { while (1) {
pthread_mutex_lock(&tunnel_mutex); pthread_mutex_lock(&tunnel_mutex);
...@@ -462,39 +465,59 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -462,39 +465,59 @@ void *forward_tcp_to_ws(void *arg) {
strcpy(request_id, active_tunnel->request_id); strcpy(request_id, active_tunnel->request_id);
pthread_mutex_unlock(&tunnel_mutex); pthread_mutex_unlock(&tunnel_mutex);
bytes_read = recv(sock, buffer, sizeof(buffer), 0); // Use select to wait for data on local socket
if (bytes_read <= 0) { FD_ZERO(&readfds);
FD_SET(sock, &readfds);
tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
int retval = select(sock + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (debug) { if (debug) {
printf("[DEBUG] TCP connection closed or error\n"); perror("[DEBUG] select failed");
fflush(stdout); fflush(stdout);
} }
break; break;
} else if (retval == 0) {
// Timeout, continue loop
continue;
} }
if (debug) { if (FD_ISSET(sock, &readfds)) {
printf("[DEBUG] Forwarding %d bytes from TCP to WebSocket\n", bytes_read); bytes_read = recv(sock, buffer, sizeof(buffer), 0);
fflush(stdout); if (bytes_read <= 0) {
} if (debug) {
printf("[DEBUG] TCP connection closed or error\n");
// Convert to hex fflush(stdout);
char hex_data[bytes_read * 2 + 1]; }
for (int i = 0; i < bytes_read; i++) { break;
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]); }
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (!send_websocket_frame(ssl, message)) {
if (debug) { if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n"); printf("[DEBUG] Forwarding %d bytes from TCP to WebSocket\n", bytes_read);
fflush(stdout); fflush(stdout);
} }
break;
// Convert to hex
char hex_data[bytes_read * 2 + 1];
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]);
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (!send_websocket_frame(ssl, message)) {
if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n");
fflush(stdout);
}
break;
}
} }
} }
...@@ -560,13 +583,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -560,13 +583,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
// Resolve hostname // Resolve hostname
if ((he = gethostbyname(wssshd_host)) == NULL) { if ((he = gethostbyname(wssshd_host)) == NULL) {
herror("gethostbyname"); herror("gethostbyname");
return 0; return -1;
} }
// Create socket // Create socket
if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
perror("Socket creation failed"); perror("Socket creation failed");
return 0; return -1;
} }
memset(&server_addr, 0, sizeof(server_addr)); memset(&server_addr, 0, sizeof(server_addr));
...@@ -578,7 +601,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -578,7 +601,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) {
perror("Connection failed"); perror("Connection failed");
close(sock); close(sock);
return 0; return -1;
} }
// Initialize SSL // Initialize SSL
...@@ -590,7 +613,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -590,7 +613,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (!ssl_ctx) { if (!ssl_ctx) {
ERR_print_errors_fp(stderr); ERR_print_errors_fp(stderr);
close(sock); close(sock);
return 0; return -1;
} }
// Allow self-signed certificates // Allow self-signed certificates
...@@ -609,7 +632,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -609,7 +632,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
printf("[DEBUG] SSL connection established\n"); printf("[DEBUG] SSL connection established\n");
...@@ -626,7 +649,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -626,7 +649,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
printf("[DEBUG] WebSocket handshake successful\n"); printf("[DEBUG] WebSocket handshake successful\n");
...@@ -638,7 +661,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -638,7 +661,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -654,7 +677,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -654,7 +677,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -688,7 +711,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -688,7 +711,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
// Parse WebSocket frame // Parse WebSocket frame
...@@ -707,7 +730,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -707,7 +730,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
// Null terminate payload // Null terminate payload
...@@ -733,7 +756,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -733,7 +756,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -747,7 +770,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -747,7 +770,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
strcpy(active_tunnel->request_id, request_id); strcpy(active_tunnel->request_id, request_id);
...@@ -764,7 +787,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -764,7 +787,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
struct sockaddr_in local_addr; struct sockaddr_in local_addr;
...@@ -781,7 +804,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -781,7 +804,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (listen(listen_sock, 1) < 0) { if (listen(listen_sock, 1) < 0) {
...@@ -792,7 +815,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -792,7 +815,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -1034,100 +1057,124 @@ int main(int argc, char *argv[]) { ...@@ -1034,100 +1057,124 @@ int main(int argc, char *argv[]) {
// Main tunnel loop - handle WebSocket messages // Main tunnel loop - handle WebSocket messages
char buffer[BUFFER_SIZE]; char buffer[BUFFER_SIZE];
int bytes_read; int bytes_read;
fd_set readfds;
struct timeval tv;
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
while (active_tunnel && active_tunnel->active) { while (active_tunnel && active_tunnel->active) {
bytes_read = SSL_read(active_tunnel->ssl, buffer, sizeof(buffer)); // Use select to wait for data on SSL socket with timeout
if (bytes_read <= 0) { FD_ZERO(&readfds);
FD_SET(ssl_fd, &readfds);
tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
int retval = select(ssl_fd + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] WebSocket connection closed\n"); perror("[DEBUG] select on SSL fd failed");
fflush(stdout); fflush(stdout);
} }
break; break;
} else if (retval == 0) {
// Timeout, continue loop
continue;
} }
// Check if it's a close frame if (FD_ISSET(ssl_fd, &readfds)) {
if (bytes_read >= 2 && (buffer[0] & 0x8F) == 0x88) { bytes_read = SSL_read(active_tunnel->ssl, buffer, sizeof(buffer));
if (config.debug) { if (bytes_read <= 0) {
printf("[DEBUG] Received close frame from server\n"); if (config.debug) {
fflush(stdout); printf("[DEBUG] WebSocket connection closed\n");
fflush(stdout);
}
break;
} }
break;
}
char *payload; // Check if it's a close frame
int payload_len; if (bytes_read >= 2 && (buffer[0] & 0x8F) == 0x88) {
if (!parse_websocket_frame(buffer, bytes_read, &payload, &payload_len)) { if (config.debug) {
if (config.debug) { printf("[DEBUG] Received close frame from server\n");
printf("[DEBUG] Failed to parse WebSocket frame\n"); fflush(stdout);
fflush(stdout); }
break;
} }
continue;
}
payload[payload_len] = '\0'; char *payload;
int payload_len;
if (!parse_websocket_frame(buffer, bytes_read, &payload, &payload_len)) {
if (config.debug) {
printf("[DEBUG] Failed to parse WebSocket frame\n");
fflush(stdout);
}
continue;
}
if (config.debug) { payload[payload_len] = '\0';
printf("[DEBUG] Received: %s\n", payload);
fflush(stdout);
}
// Handle messages
if (strstr(payload, "tunnel_data")) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] Received tunnel_data message\n"); printf("[DEBUG] Received: %s\n", payload);
fflush(stdout); fflush(stdout);
} }
// Extract request_id and data
char *id_start = strstr(payload, "\"request_id\""); // Handle messages
char *data_start = strstr(payload, "\"data\""); if (strstr(payload, "tunnel_data")) {
if (id_start && data_start) { if (config.debug) {
char *colon = strchr(id_start, ':'); printf("[DEBUG] Received tunnel_data message\n");
if (colon) { fflush(stdout);
char *open_quote = strchr(colon, '"'); }
if (open_quote) { // Extract request_id and data
id_start = open_quote + 1; char *id_start = strstr(payload, "\"request_id\"");
char *close_quote = strchr(id_start, '"'); char *data_start = strstr(payload, "\"data\"");
if (close_quote) { if (id_start && data_start) {
*close_quote = '\0'; char *colon = strchr(id_start, ':');
char *data_colon = strchr(data_start, ':'); if (colon) {
if (data_colon) { char *open_quote = strchr(colon, '"');
char *data_quote = strchr(data_colon, '"'); if (open_quote) {
if (data_quote) { id_start = open_quote + 1;
data_start = data_quote + 1; char *close_quote = strchr(id_start, '"');
char *data_end = strchr(data_start, '"'); if (close_quote) {
if (data_end) { *close_quote = '\0';
*data_end = '\0'; char *data_colon = strchr(data_start, ':');
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug); if (data_colon) {
char *data_quote = strchr(data_colon, '"');
if (data_quote) {
data_start = data_quote + 1;
char *data_end = strchr(data_start, '"');
if (data_end) {
*data_end = '\0';
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug);
}
} }
} }
} }
} }
} }
} }
} } else if (strstr(payload, "tunnel_close")) {
} else if (strstr(payload, "tunnel_close")) { if (config.debug) {
if (config.debug) { printf("[DEBUG] Received tunnel_close message\n");
printf("[DEBUG] Received tunnel_close message\n"); fflush(stdout);
fflush(stdout); }
} char *id_start = strstr(payload, "\"request_id\"");
char *id_start = strstr(payload, "\"request_id\""); if (id_start) {
if (id_start) { char *colon = strchr(id_start, ':');
char *colon = strchr(id_start, ':'); if (colon) {
if (colon) { char *open_quote = strchr(colon, '"');
char *open_quote = strchr(colon, '"'); if (open_quote) {
if (open_quote) { id_start = open_quote + 1;
id_start = open_quote + 1; char *close_quote = strchr(id_start, '"');
char *close_quote = strchr(id_start, '"'); if (close_quote) {
if (close_quote) { *close_quote = '\0';
*close_quote = '\0'; handle_tunnel_close(active_tunnel->ssl, id_start, config.debug);
handle_tunnel_close(active_tunnel->ssl, id_start, config.debug); }
} }
} }
} }
} } else {
} else { if (config.debug) {
if (config.debug) { printf("[DEBUG] Received unknown message type: %s\n", payload);
printf("[DEBUG] Received unknown message type: %s\n", payload); fflush(stdout);
fflush(stdout); }
} }
} }
} }
...@@ -1146,6 +1193,13 @@ int main(int argc, char *argv[]) { ...@@ -1146,6 +1193,13 @@ int main(int argc, char *argv[]) {
fflush(stdout); fflush(stdout);
} }
// Wait a few seconds before exiting cleanly
if (config.debug) {
printf("[DEBUG] Waiting 3 seconds before exit...\n");
fflush(stdout);
}
sleep(3);
// Cleanup // Cleanup
if (active_tunnel) { if (active_tunnel) {
close(active_tunnel->local_sock); close(active_tunnel->local_sock);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment