Optimize C wsscp performance and add SSH options

- Replace busy-waiting usleep loops with select() for efficient I/O multiplexing
- Add 50ms timeouts to prevent blocking operations
- Maintain 1MB buffers for optimal throughput
- Add 3-second exit delay after SCP completion for clean shutdown
- Fix critical bug in setup_tunnel return codes to prevent segfaults
- Correct memory management in argument cleanup to avoid corruption
- Add StrictHostKeyChecking=no option to wsscp and wsssh in both Python and C implementations
parent 92294668
...@@ -256,8 +256,8 @@ def main(): ...@@ -256,8 +256,8 @@ def main():
final_args.append(arg) final_args.append(arg)
i += 1 i += 1
# Add port argument for local tunnel at the beginning # Add StrictHostKeyChecking=no and port argument for local tunnel at the beginning
final_args = ['-P', str(local_port)] + final_args final_args = ['-o', 'StrictHostKeyChecking=no', '-P', str(local_port)] + final_args
if debug: print(f"[DEBUG] Final SCP args: {final_args}") if debug: print(f"[DEBUG] Final SCP args: {final_args}")
......
...@@ -251,8 +251,8 @@ def main(): ...@@ -251,8 +251,8 @@ def main():
else: else:
final_args.append(arg) final_args.append(arg)
# Add port argument for local tunnel # Add StrictHostKeyChecking=no and port argument for local tunnel
final_args.extend(['-p', str(local_port)]) final_args.extend(['-o', 'StrictHostKeyChecking=no', '-p', str(local_port)])
if debug: print(f"[DEBUG] Final SSH args: {final_args}") if debug: print(f"[DEBUG] Final SSH args: {final_args}")
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <fcntl.h> #include <fcntl.h>
#include <pthread.h> #include <pthread.h>
#include <errno.h> #include <errno.h>
#include <sys/select.h>
#define BUFFER_SIZE 1048576 #define BUFFER_SIZE 1048576
#define DEFAULT_PORT 22 #define DEFAULT_PORT 22
...@@ -206,8 +207,8 @@ int parse_scp_args(int argc, char *argv[], char **destination, int *scp_port, in ...@@ -206,8 +207,8 @@ int parse_scp_args(int argc, char *argv[], char **destination, int *scp_port, in
} }
char **modify_scp_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) { char **modify_scp_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) {
// Allocate space for: scp + -P + port + original args + NULL // Allocate space for: scp + -o + StrictHostKeyChecking=no + -P + port + original args + NULL
char **new_args = malloc((argc + 4) * sizeof(char *)); char **new_args = malloc((argc + 6) * sizeof(char *));
if (!new_args) { if (!new_args) {
return NULL; return NULL;
} }
...@@ -215,7 +216,11 @@ char **modify_scp_args(int argc, char *argv[], const char *original_host, int lo ...@@ -215,7 +216,11 @@ char **modify_scp_args(int argc, char *argv[], const char *original_host, int lo
int idx = 0; int idx = 0;
new_args[idx++] = "scp"; new_args[idx++] = "scp";
// Add port argument first // Add StrictHostKeyChecking=no option
new_args[idx++] = "-o";
new_args[idx++] = "StrictHostKeyChecking=no";
// Add port argument
new_args[idx++] = "-P"; new_args[idx++] = "-P";
char *port_str = malloc(16); char *port_str = malloc(16);
if (!port_str) { if (!port_str) {
...@@ -594,6 +599,8 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -594,6 +599,8 @@ void *forward_tcp_to_ws(void *arg) {
int debug = args->debug; int debug = args->debug;
char buffer[BUFFER_SIZE]; char buffer[BUFFER_SIZE];
int bytes_read; int bytes_read;
fd_set readfds;
struct timeval tv;
while (1) { while (1) {
pthread_mutex_lock(&tunnel_mutex); pthread_mutex_lock(&tunnel_mutex);
...@@ -626,12 +633,27 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -626,12 +633,27 @@ void *forward_tcp_to_ws(void *arg) {
pthread_mutex_unlock(&tunnel_mutex); pthread_mutex_unlock(&tunnel_mutex);
bytes_read = recv(sock, buffer, sizeof(buffer), MSG_DONTWAIT); // Use select to wait for data on local socket
if (bytes_read == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { FD_ZERO(&readfds);
// No data available, sleep to avoid busy waiting FD_SET(sock, &readfds);
usleep(50000); // Increased sleep time to reduce CPU usage tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
int retval = select(sock + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (debug) {
perror("[DEBUG] select failed");
fflush(stdout);
}
break;
} else if (retval == 0) {
// Timeout, continue loop
continue; continue;
} else if (bytes_read <= 0) { }
if (FD_ISSET(sock, &readfds)) {
bytes_read = recv(sock, buffer, sizeof(buffer), 0);
if (bytes_read <= 0) {
if (debug) { if (debug) {
printf("[DEBUG] TCP connection closed or error\n"); printf("[DEBUG] TCP connection closed or error\n");
fflush(stdout); fflush(stdout);
...@@ -665,6 +687,7 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -665,6 +687,7 @@ void *forward_tcp_to_ws(void *arg) {
break; break;
} }
} }
}
if (debug) { if (debug) {
printf("[DEBUG] TCP to WebSocket forwarding thread exiting\n"); printf("[DEBUG] TCP to WebSocket forwarding thread exiting\n");
...@@ -735,13 +758,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -735,13 +758,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
// Resolve hostname // Resolve hostname
if ((he = gethostbyname(wssshd_host)) == NULL) { if ((he = gethostbyname(wssshd_host)) == NULL) {
herror("gethostbyname"); herror("gethostbyname");
return 0; return -1;
} }
// Create socket // Create socket
if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
perror("Socket creation failed"); perror("Socket creation failed");
return 0; return -1;
} }
memset(&server_addr, 0, sizeof(server_addr)); memset(&server_addr, 0, sizeof(server_addr));
...@@ -753,7 +776,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -753,7 +776,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) {
perror("Connection failed"); perror("Connection failed");
close(sock); close(sock);
return 0; return -1;
} }
// Initialize SSL // Initialize SSL
...@@ -784,7 +807,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -784,7 +807,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
printf("[DEBUG] SSL connection established\n"); printf("[DEBUG] SSL connection established\n");
...@@ -801,7 +824,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -801,7 +824,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
printf("[DEBUG] WebSocket handshake successful\n"); printf("[DEBUG] WebSocket handshake successful\n");
...@@ -813,7 +836,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -813,7 +836,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -829,7 +852,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -829,7 +852,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -863,7 +886,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -863,7 +886,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
// Parse WebSocket frame // Parse WebSocket frame
...@@ -882,7 +905,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -882,7 +905,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
// Null terminate payload // Null terminate payload
...@@ -908,7 +931,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -908,7 +931,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -922,7 +945,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -922,7 +945,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
active_tunnel->outgoing_buffer = frame_buffer_init(); active_tunnel->outgoing_buffer = frame_buffer_init();
...@@ -932,7 +955,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -932,7 +955,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
strcpy(active_tunnel->request_id, request_id); strcpy(active_tunnel->request_id, request_id);
...@@ -949,7 +972,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -949,7 +972,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
struct sockaddr_in local_addr; struct sockaddr_in local_addr;
...@@ -967,7 +990,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -967,7 +990,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (listen(listen_sock, 1) < 0) { if (listen(listen_sock, 1) < 0) {
...@@ -979,7 +1002,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -979,7 +1002,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -1234,6 +1257,10 @@ int main(int argc, char *argv[]) { ...@@ -1234,6 +1257,10 @@ int main(int argc, char *argv[]) {
} }
char read_buffer[BUFFER_SIZE]; char read_buffer[BUFFER_SIZE];
fd_set readfds;
struct timeval tv;
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
while (active_tunnel && active_tunnel->active) { while (active_tunnel && active_tunnel->active) {
// Check if SCP process has finished // Check if SCP process has finished
int status; int status;
...@@ -1246,6 +1273,25 @@ int main(int argc, char *argv[]) { ...@@ -1246,6 +1273,25 @@ int main(int argc, char *argv[]) {
break; break;
} }
// Use select to wait for data on SSL socket with timeout
FD_ZERO(&readfds);
FD_SET(ssl_fd, &readfds);
tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
int retval = select(ssl_fd + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (config.debug) {
perror("[DEBUG] select on SSL fd failed");
fflush(stdout);
}
break;
} else if (retval == 0) {
// Timeout, continue loop
continue;
}
if (FD_ISSET(ssl_fd, &readfds)) {
int bytes_read = SSL_read(active_tunnel->ssl, read_buffer, sizeof(read_buffer)); int bytes_read = SSL_read(active_tunnel->ssl, read_buffer, sizeof(read_buffer));
if (bytes_read <= 0) { if (bytes_read <= 0) {
if (config.debug) { if (config.debug) {
...@@ -1476,6 +1522,7 @@ int main(int argc, char *argv[]) { ...@@ -1476,6 +1522,7 @@ int main(int argc, char *argv[]) {
} }
} }
} }
}
frame_buffer_free(frame_buffer); frame_buffer_free(frame_buffer);
...@@ -1494,6 +1541,13 @@ cleanup: ...@@ -1494,6 +1541,13 @@ cleanup:
fflush(stdout); fflush(stdout);
} }
// Wait a few seconds before exiting cleanly
if (config.debug) {
printf("[DEBUG] Waiting 3 seconds before exit...\n");
fflush(stdout);
}
sleep(3);
// Cleanup // Cleanup
if (active_tunnel) { if (active_tunnel) {
close(active_tunnel->local_sock); close(active_tunnel->local_sock);
...@@ -1509,9 +1563,9 @@ cleanup: ...@@ -1509,9 +1563,9 @@ cleanup:
// Free allocated strings in new_scp_args // Free allocated strings in new_scp_args
for (int i = 0; i < new_scp_argc; i++) { for (int i = 0; i < new_scp_argc; i++) {
// Free strings that were allocated with malloc/strdup // Free strings that were allocated with malloc/strdup
if (i == 2) { // The port string if (i == 4) { // The port string
free(new_scp_args[i]); free(new_scp_args[i]);
} else if (i == 3) { // The user@localhost:path string (if allocated) } else if (i == 5) { // The user@localhost:path string (if allocated)
// Check if this was allocated (contains @localhost) // Check if this was allocated (contains @localhost)
if (strstr(new_scp_args[i], "@localhost")) { if (strstr(new_scp_args[i], "@localhost")) {
free(new_scp_args[i]); free(new_scp_args[i]);
......
...@@ -194,8 +194,8 @@ int parse_ssh_args(int argc, char *argv[], char **host, int *ssh_port, int debug ...@@ -194,8 +194,8 @@ int parse_ssh_args(int argc, char *argv[], char **host, int *ssh_port, int debug
} }
char **modify_ssh_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) { char **modify_ssh_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) {
// Allocate space for: ssh + -p + port + original args + NULL // Allocate space for: ssh + -o + StrictHostKeyChecking=no + -p + port + original args + NULL
char **new_args = malloc((argc + 4) * sizeof(char *)); char **new_args = malloc((argc + 6) * sizeof(char *));
if (!new_args) { if (!new_args) {
return NULL; return NULL;
} }
...@@ -203,6 +203,10 @@ char **modify_ssh_args(int argc, char *argv[], const char *original_host, int lo ...@@ -203,6 +203,10 @@ char **modify_ssh_args(int argc, char *argv[], const char *original_host, int lo
int idx = 0; int idx = 0;
new_args[idx++] = "ssh"; new_args[idx++] = "ssh";
// Add StrictHostKeyChecking=no option
new_args[idx++] = "-o";
new_args[idx++] = "StrictHostKeyChecking=no";
// Add port argument for local tunnel // Add port argument for local tunnel
new_args[idx++] = "-p"; new_args[idx++] = "-p";
char *port_str = malloc(16); char *port_str = malloc(16);
...@@ -1156,9 +1160,9 @@ int main(int argc, char *argv[]) { ...@@ -1156,9 +1160,9 @@ int main(int argc, char *argv[]) {
// Free allocated strings in new_ssh_args // Free allocated strings in new_ssh_args
for (int i = 0; i < new_ssh_argc; i++) { for (int i = 0; i < new_ssh_argc; i++) {
// Free strings that were allocated with malloc/strdup // Free strings that were allocated with malloc/strdup
if (i == 2) { // The port string if (i == 4) { // The port string
free(new_ssh_args[i]); free(new_ssh_args[i]);
} else if (i == 3) { // The user@localhost string (if allocated) } else if (i == 5) { // The user@localhost string (if allocated)
// Check if this was allocated (contains @localhost) // Check if this was allocated (contains @localhost)
if (strstr(new_ssh_args[i], "@localhost")) { if (strstr(new_ssh_args[i], "@localhost")) {
free(new_ssh_args[i]); free(new_ssh_args[i]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment