Fix thread synchronization issues in wsssh.c

- Added proper mutex protection for active_tunnel access
- Fixed race condition between main thread and forwarding thread
- Protected SSL connection access with mutex locks
- Fixed socket setting with proper mutex synchronization
- Ensured thread-safe access to tunnel state variables
- Prevented concurrent access to shared tunnel resources
- Added proper SSL connection lifecycle management
- Fixed potential data races in WebSocket message handling
parent 35e6e78f
...@@ -443,8 +443,8 @@ int main(int argc, char *argv[]) { ...@@ -443,8 +443,8 @@ int main(int argc, char *argv[]) {
struct sockaddr_in client_addr; struct sockaddr_in client_addr;
socklen_t client_len = sizeof(client_addr); socklen_t client_len = sizeof(client_addr);
active_tunnel->local_sock = accept(listen_sock, (struct sockaddr *)&client_addr, &client_len); int accepted_sock = accept(listen_sock, (struct sockaddr *)&client_addr, &client_len);
if (active_tunnel->local_sock < 0) { if (accepted_sock < 0) {
perror("Local accept failed"); perror("Local accept failed");
kill(pid, SIGTERM); kill(pid, SIGTERM);
waitpid(pid, NULL, 0); waitpid(pid, NULL, 0);
...@@ -466,11 +466,21 @@ int main(int argc, char *argv[]) { ...@@ -466,11 +466,21 @@ int main(int argc, char *argv[]) {
close(listen_sock); // No longer needed close(listen_sock); // No longer needed
// Set the accepted socket with mutex protection
pthread_mutex_lock(&tunnel_mutex);
active_tunnel->local_sock = accepted_sock;
pthread_mutex_unlock(&tunnel_mutex);
if (config.debug) { if (config.debug) {
printf("[DEBUG] Local SSH connection accepted! Starting data forwarding...\n"); printf("[DEBUG] Local SSH connection accepted! Starting data forwarding...\n");
fflush(stdout); fflush(stdout);
} }
// Get initial SSL connection for thread
pthread_mutex_lock(&tunnel_mutex);
SSL *current_ssl = active_tunnel ? active_tunnel->ssl : NULL;
pthread_mutex_unlock(&tunnel_mutex);
// Start forwarding thread // Start forwarding thread
thread_args_t *thread_args = malloc(sizeof(thread_args_t)); thread_args_t *thread_args = malloc(sizeof(thread_args_t));
if (!thread_args) { if (!thread_args) {
...@@ -492,7 +502,7 @@ int main(int argc, char *argv[]) { ...@@ -492,7 +502,7 @@ int main(int argc, char *argv[]) {
pthread_mutex_destroy(&tunnel_mutex); pthread_mutex_destroy(&tunnel_mutex);
return 1; return 1;
} }
thread_args->ssl = active_tunnel->ssl; // Need to store SSL in tunnel struct thread_args->ssl = current_ssl; // Use the current SSL connection
thread_args->debug = config.debug; thread_args->debug = config.debug;
pthread_t thread; pthread_t thread;
...@@ -504,9 +514,17 @@ int main(int argc, char *argv[]) { ...@@ -504,9 +514,17 @@ int main(int argc, char *argv[]) {
int bytes_read; int bytes_read;
fd_set readfds; fd_set readfds;
struct timeval tv; struct timeval tv;
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
while (active_tunnel && active_tunnel->active) { while (1) {
// Get SSL fd with mutex protection
pthread_mutex_lock(&tunnel_mutex);
if (!active_tunnel || !active_tunnel->active) {
pthread_mutex_unlock(&tunnel_mutex);
break;
}
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
current_ssl = active_tunnel->ssl;
pthread_mutex_unlock(&tunnel_mutex);
// Use select to wait for data on SSL socket with timeout // Use select to wait for data on SSL socket with timeout
FD_ZERO(&readfds); FD_ZERO(&readfds);
FD_SET(ssl_fd, &readfds); FD_SET(ssl_fd, &readfds);
...@@ -526,7 +544,7 @@ int main(int argc, char *argv[]) { ...@@ -526,7 +544,7 @@ int main(int argc, char *argv[]) {
} }
if (FD_ISSET(ssl_fd, &readfds)) { if (FD_ISSET(ssl_fd, &readfds)) {
bytes_read = SSL_read(active_tunnel->ssl, buffer, sizeof(buffer)); bytes_read = SSL_read(current_ssl, buffer, sizeof(buffer));
if (bytes_read <= 0) { if (bytes_read <= 0) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] WebSocket connection lost, attempting reconnection...\n"); printf("[DEBUG] WebSocket connection lost, attempting reconnection...\n");
...@@ -544,6 +562,11 @@ int main(int argc, char *argv[]) { ...@@ -544,6 +562,11 @@ int main(int argc, char *argv[]) {
fflush(stdout); fflush(stdout);
} }
pthread_mutex_lock(&tunnel_mutex);
if (!active_tunnel) {
pthread_mutex_unlock(&tunnel_mutex);
break;
}
if (reconnect_websocket(active_tunnel, wssshd_host, wssshd_port, client_id, active_tunnel->request_id, config.debug) == 0) { if (reconnect_websocket(active_tunnel, wssshd_host, wssshd_port, client_id, active_tunnel->request_id, config.debug) == 0) {
reconnected = 1; reconnected = 1;
if (config.debug) { if (config.debug) {
...@@ -552,7 +575,11 @@ int main(int argc, char *argv[]) { ...@@ -552,7 +575,11 @@ int main(int argc, char *argv[]) {
} }
// Update ssl_fd for select // Update ssl_fd for select
ssl_fd = SSL_get_fd(active_tunnel->ssl); ssl_fd = SSL_get_fd(active_tunnel->ssl);
} else { current_ssl = active_tunnel->ssl;
}
pthread_mutex_unlock(&tunnel_mutex);
if (!reconnected) {
reconnect_attempts++; reconnect_attempts++;
if (reconnect_attempts < max_reconnect_attempts) { if (reconnect_attempts < max_reconnect_attempts) {
if (config.debug) { if (config.debug) {
...@@ -599,7 +626,7 @@ int main(int argc, char *argv[]) { ...@@ -599,7 +626,7 @@ int main(int argc, char *argv[]) {
int ping_payload_len; int ping_payload_len;
if (parse_websocket_frame(buffer, bytes_read, &ping_payload, &ping_payload_len)) { if (parse_websocket_frame(buffer, bytes_read, &ping_payload, &ping_payload_len)) {
// Send pong with same payload // Send pong with same payload
if (!send_pong_frame(active_tunnel->ssl, ping_payload, ping_payload_len)) { if (!send_pong_frame(current_ssl, ping_payload, ping_payload_len)) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] Failed to send pong frame\n"); printf("[DEBUG] Failed to send pong frame\n");
fflush(stdout); fflush(stdout);
...@@ -668,7 +695,7 @@ int main(int argc, char *argv[]) { ...@@ -668,7 +695,7 @@ int main(int argc, char *argv[]) {
char *data_end = strchr(data_start, '"'); char *data_end = strchr(data_start, '"');
if (data_end) { if (data_end) {
*data_end = '\0'; *data_end = '\0';
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug); handle_tunnel_data(current_ssl, id_start, data_start, config.debug);
} }
} }
} }
...@@ -691,7 +718,7 @@ int main(int argc, char *argv[]) { ...@@ -691,7 +718,7 @@ int main(int argc, char *argv[]) {
char *close_quote = strchr(id_start, '"'); char *close_quote = strchr(id_start, '"');
if (close_quote) { if (close_quote) {
*close_quote = '\0'; *close_quote = '\0';
handle_tunnel_close(active_tunnel->ssl, id_start, config.debug); handle_tunnel_close(current_ssl, id_start, config.debug);
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment