Fix thread synchronization issues in wsssh.c

- Added proper mutex protection for active_tunnel access
- Fixed race condition between main thread and forwarding thread
- Protected SSL connection access with mutex locks
- Fixed socket setting with proper mutex synchronization
- Ensured thread-safe access to tunnel state variables
- Prevented concurrent access to shared tunnel resources
- Added proper SSL connection lifecycle management
- Fixed potential data races in WebSocket message handling
parent 35e6e78f
......@@ -443,8 +443,8 @@ int main(int argc, char *argv[]) {
struct sockaddr_in client_addr;
socklen_t client_len = sizeof(client_addr);
active_tunnel->local_sock = accept(listen_sock, (struct sockaddr *)&client_addr, &client_len);
if (active_tunnel->local_sock < 0) {
int accepted_sock = accept(listen_sock, (struct sockaddr *)&client_addr, &client_len);
if (accepted_sock < 0) {
perror("Local accept failed");
kill(pid, SIGTERM);
waitpid(pid, NULL, 0);
......@@ -466,11 +466,21 @@ int main(int argc, char *argv[]) {
close(listen_sock); // No longer needed
// Set the accepted socket with mutex protection
pthread_mutex_lock(&tunnel_mutex);
active_tunnel->local_sock = accepted_sock;
pthread_mutex_unlock(&tunnel_mutex);
if (config.debug) {
printf("[DEBUG] Local SSH connection accepted! Starting data forwarding...\n");
fflush(stdout);
}
// Get initial SSL connection for thread
pthread_mutex_lock(&tunnel_mutex);
SSL *current_ssl = active_tunnel ? active_tunnel->ssl : NULL;
pthread_mutex_unlock(&tunnel_mutex);
// Start forwarding thread
thread_args_t *thread_args = malloc(sizeof(thread_args_t));
if (!thread_args) {
......@@ -492,7 +502,7 @@ int main(int argc, char *argv[]) {
pthread_mutex_destroy(&tunnel_mutex);
return 1;
}
thread_args->ssl = active_tunnel->ssl; // Need to store SSL in tunnel struct
thread_args->ssl = current_ssl; // Use the current SSL connection
thread_args->debug = config.debug;
pthread_t thread;
......@@ -504,9 +514,17 @@ int main(int argc, char *argv[]) {
int bytes_read;
fd_set readfds;
struct timeval tv;
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
while (active_tunnel && active_tunnel->active) {
while (1) {
// Get SSL fd with mutex protection
pthread_mutex_lock(&tunnel_mutex);
if (!active_tunnel || !active_tunnel->active) {
pthread_mutex_unlock(&tunnel_mutex);
break;
}
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
current_ssl = active_tunnel->ssl;
pthread_mutex_unlock(&tunnel_mutex);
// Use select to wait for data on SSL socket with timeout
FD_ZERO(&readfds);
FD_SET(ssl_fd, &readfds);
......@@ -526,7 +544,7 @@ int main(int argc, char *argv[]) {
}
if (FD_ISSET(ssl_fd, &readfds)) {
bytes_read = SSL_read(active_tunnel->ssl, buffer, sizeof(buffer));
bytes_read = SSL_read(current_ssl, buffer, sizeof(buffer));
if (bytes_read <= 0) {
if (config.debug) {
printf("[DEBUG] WebSocket connection lost, attempting reconnection...\n");
......@@ -544,6 +562,11 @@ int main(int argc, char *argv[]) {
fflush(stdout);
}
pthread_mutex_lock(&tunnel_mutex);
if (!active_tunnel) {
pthread_mutex_unlock(&tunnel_mutex);
break;
}
if (reconnect_websocket(active_tunnel, wssshd_host, wssshd_port, client_id, active_tunnel->request_id, config.debug) == 0) {
reconnected = 1;
if (config.debug) {
......@@ -552,7 +575,11 @@ int main(int argc, char *argv[]) {
}
// Update ssl_fd for select
ssl_fd = SSL_get_fd(active_tunnel->ssl);
} else {
current_ssl = active_tunnel->ssl;
}
pthread_mutex_unlock(&tunnel_mutex);
if (!reconnected) {
reconnect_attempts++;
if (reconnect_attempts < max_reconnect_attempts) {
if (config.debug) {
......@@ -599,7 +626,7 @@ int main(int argc, char *argv[]) {
int ping_payload_len;
if (parse_websocket_frame(buffer, bytes_read, &ping_payload, &ping_payload_len)) {
// Send pong with same payload
if (!send_pong_frame(active_tunnel->ssl, ping_payload, ping_payload_len)) {
if (!send_pong_frame(current_ssl, ping_payload, ping_payload_len)) {
if (config.debug) {
printf("[DEBUG] Failed to send pong frame\n");
fflush(stdout);
......@@ -668,7 +695,7 @@ int main(int argc, char *argv[]) {
char *data_end = strchr(data_start, '"');
if (data_end) {
*data_end = '\0';
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug);
handle_tunnel_data(current_ssl, id_start, data_start, config.debug);
}
}
}
......@@ -691,7 +718,7 @@ int main(int argc, char *argv[]) {
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
handle_tunnel_close(active_tunnel->ssl, id_start, config.debug);
handle_tunnel_close(current_ssl, id_start, config.debug);
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment