Optimize C wsscp performance and add SSH options

- Replace busy-waiting usleep loops with select() for efficient I/O multiplexing
- Add 50ms timeouts to prevent blocking operations
- Maintain 1MB buffers for optimal throughput
- Add 3-second exit delay after SCP completion for clean shutdown
- Fix critical bug in setup_tunnel return codes to prevent segfaults
- Correct memory management in argument cleanup to avoid corruption
- Add StrictHostKeyChecking=no option to wsscp and wsssh in both Python and C implementations
parent 92294668
...@@ -256,8 +256,8 @@ def main(): ...@@ -256,8 +256,8 @@ def main():
final_args.append(arg) final_args.append(arg)
i += 1 i += 1
# Add port argument for local tunnel at the beginning # Add StrictHostKeyChecking=no and port argument for local tunnel at the beginning
final_args = ['-P', str(local_port)] + final_args final_args = ['-o', 'StrictHostKeyChecking=no', '-P', str(local_port)] + final_args
if debug: print(f"[DEBUG] Final SCP args: {final_args}") if debug: print(f"[DEBUG] Final SCP args: {final_args}")
......
...@@ -251,8 +251,8 @@ def main(): ...@@ -251,8 +251,8 @@ def main():
else: else:
final_args.append(arg) final_args.append(arg)
# Add port argument for local tunnel # Add StrictHostKeyChecking=no and port argument for local tunnel
final_args.extend(['-p', str(local_port)]) final_args.extend(['-o', 'StrictHostKeyChecking=no', '-p', str(local_port)])
if debug: print(f"[DEBUG] Final SSH args: {final_args}") if debug: print(f"[DEBUG] Final SSH args: {final_args}")
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <fcntl.h> #include <fcntl.h>
#include <pthread.h> #include <pthread.h>
#include <errno.h> #include <errno.h>
#include <sys/select.h>
#define BUFFER_SIZE 1048576 #define BUFFER_SIZE 1048576
#define DEFAULT_PORT 22 #define DEFAULT_PORT 22
...@@ -206,8 +207,8 @@ int parse_scp_args(int argc, char *argv[], char **destination, int *scp_port, in ...@@ -206,8 +207,8 @@ int parse_scp_args(int argc, char *argv[], char **destination, int *scp_port, in
} }
char **modify_scp_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) { char **modify_scp_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) {
// Allocate space for: scp + -P + port + original args + NULL // Allocate space for: scp + -o + StrictHostKeyChecking=no + -P + port + original args + NULL
char **new_args = malloc((argc + 4) * sizeof(char *)); char **new_args = malloc((argc + 6) * sizeof(char *));
if (!new_args) { if (!new_args) {
return NULL; return NULL;
} }
...@@ -215,7 +216,11 @@ char **modify_scp_args(int argc, char *argv[], const char *original_host, int lo ...@@ -215,7 +216,11 @@ char **modify_scp_args(int argc, char *argv[], const char *original_host, int lo
int idx = 0; int idx = 0;
new_args[idx++] = "scp"; new_args[idx++] = "scp";
// Add port argument first // Add StrictHostKeyChecking=no option
new_args[idx++] = "-o";
new_args[idx++] = "StrictHostKeyChecking=no";
// Add port argument
new_args[idx++] = "-P"; new_args[idx++] = "-P";
char *port_str = malloc(16); char *port_str = malloc(16);
if (!port_str) { if (!port_str) {
...@@ -594,6 +599,8 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -594,6 +599,8 @@ void *forward_tcp_to_ws(void *arg) {
int debug = args->debug; int debug = args->debug;
char buffer[BUFFER_SIZE]; char buffer[BUFFER_SIZE];
int bytes_read; int bytes_read;
fd_set readfds;
struct timeval tv;
while (1) { while (1) {
pthread_mutex_lock(&tunnel_mutex); pthread_mutex_lock(&tunnel_mutex);
...@@ -626,43 +633,59 @@ void *forward_tcp_to_ws(void *arg) { ...@@ -626,43 +633,59 @@ void *forward_tcp_to_ws(void *arg) {
pthread_mutex_unlock(&tunnel_mutex); pthread_mutex_unlock(&tunnel_mutex);
bytes_read = recv(sock, buffer, sizeof(buffer), MSG_DONTWAIT); // Use select to wait for data on local socket
if (bytes_read == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { FD_ZERO(&readfds);
// No data available, sleep to avoid busy waiting FD_SET(sock, &readfds);
usleep(50000); // Increased sleep time to reduce CPU usage tv.tv_sec = 0; // 0 seconds
continue; tv.tv_usec = 50000; // 50ms timeout
} else if (bytes_read <= 0) {
int retval = select(sock + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (debug) { if (debug) {
printf("[DEBUG] TCP connection closed or error\n"); perror("[DEBUG] select failed");
fflush(stdout); fflush(stdout);
} }
break; break;
} else if (retval == 0) {
// Timeout, continue loop
continue;
} }
if (debug) { if (FD_ISSET(sock, &readfds)) {
printf("[DEBUG] Forwarding %d bytes from TCP to WebSocket\n", bytes_read); bytes_read = recv(sock, buffer, sizeof(buffer), 0);
fflush(stdout); if (bytes_read <= 0) {
} if (debug) {
printf("[DEBUG] TCP connection closed or error\n");
// Convert to hex fflush(stdout);
char hex_data[bytes_read * 2 + 1]; }
for (int i = 0; i < bytes_read; i++) { break;
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]); }
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (!send_websocket_frame(ssl, message)) {
if (debug) { if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n"); printf("[DEBUG] Forwarding %d bytes from TCP to WebSocket\n", bytes_read);
fflush(stdout); fflush(stdout);
} }
break;
// Convert to hex
char hex_data[bytes_read * 2 + 1];
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]);
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (!send_websocket_frame(ssl, message)) {
if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n");
fflush(stdout);
}
break;
}
} }
} }
...@@ -735,13 +758,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -735,13 +758,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
// Resolve hostname // Resolve hostname
if ((he = gethostbyname(wssshd_host)) == NULL) { if ((he = gethostbyname(wssshd_host)) == NULL) {
herror("gethostbyname"); herror("gethostbyname");
return 0; return -1;
} }
// Create socket // Create socket
if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
perror("Socket creation failed"); perror("Socket creation failed");
return 0; return -1;
} }
memset(&server_addr, 0, sizeof(server_addr)); memset(&server_addr, 0, sizeof(server_addr));
...@@ -753,7 +776,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -753,7 +776,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) {
perror("Connection failed"); perror("Connection failed");
close(sock); close(sock);
return 0; return -1;
} }
// Initialize SSL // Initialize SSL
...@@ -784,7 +807,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -784,7 +807,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
printf("[DEBUG] SSL connection established\n"); printf("[DEBUG] SSL connection established\n");
...@@ -801,7 +824,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -801,7 +824,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
printf("[DEBUG] WebSocket handshake successful\n"); printf("[DEBUG] WebSocket handshake successful\n");
...@@ -813,7 +836,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -813,7 +836,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -829,7 +852,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -829,7 +852,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -863,7 +886,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -863,7 +886,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
// Parse WebSocket frame // Parse WebSocket frame
...@@ -882,7 +905,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -882,7 +905,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
// Null terminate payload // Null terminate payload
...@@ -908,7 +931,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -908,7 +931,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -922,7 +945,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -922,7 +945,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
active_tunnel->outgoing_buffer = frame_buffer_init(); active_tunnel->outgoing_buffer = frame_buffer_init();
...@@ -932,7 +955,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -932,7 +955,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
strcpy(active_tunnel->request_id, request_id); strcpy(active_tunnel->request_id, request_id);
...@@ -949,7 +972,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -949,7 +972,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
struct sockaddr_in local_addr; struct sockaddr_in local_addr;
...@@ -967,7 +990,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -967,7 +990,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (listen(listen_sock, 1) < 0) { if (listen(listen_sock, 1) < 0) {
...@@ -979,7 +1002,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id ...@@ -979,7 +1002,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl); SSL_free(ssl);
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
close(sock); close(sock);
return 0; return -1;
} }
if (debug) { if (debug) {
...@@ -1234,6 +1257,10 @@ int main(int argc, char *argv[]) { ...@@ -1234,6 +1257,10 @@ int main(int argc, char *argv[]) {
} }
char read_buffer[BUFFER_SIZE]; char read_buffer[BUFFER_SIZE];
fd_set readfds;
struct timeval tv;
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
while (active_tunnel && active_tunnel->active) { while (active_tunnel && active_tunnel->active) {
// Check if SCP process has finished // Check if SCP process has finished
int status; int status;
...@@ -1246,235 +1273,255 @@ int main(int argc, char *argv[]) { ...@@ -1246,235 +1273,255 @@ int main(int argc, char *argv[]) {
break; break;
} }
int bytes_read = SSL_read(active_tunnel->ssl, read_buffer, sizeof(read_buffer)); // Use select to wait for data on SSL socket with timeout
if (bytes_read <= 0) { FD_ZERO(&readfds);
if (config.debug) { FD_SET(ssl_fd, &readfds);
printf("[DEBUG] WebSocket connection closed or SSL_read error: %d\n", bytes_read); tv.tv_sec = 0; // 0 seconds
fflush(stdout); tv.tv_usec = 50000; // 50ms timeout
}
break;
}
if (config.debug) { int retval = select(ssl_fd + 1, &readfds, NULL, NULL, &tv);
printf("[DEBUG] SSL_read returned %d bytes\n", bytes_read); if (retval == -1) {
fflush(stdout);
}
// Append new data to frame buffer
if (!frame_buffer_append(frame_buffer, read_buffer, bytes_read)) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] Failed to append data to frame buffer\n"); perror("[DEBUG] select on SSL fd failed");
fflush(stdout); fflush(stdout);
} }
break;
} else if (retval == 0) {
// Timeout, continue loop
continue; continue;
} }
// Process complete frames from the buffer if (FD_ISSET(ssl_fd, &readfds)) {
int processed_frames = 0; int bytes_read = SSL_read(active_tunnel->ssl, read_buffer, sizeof(read_buffer));
while (frame_buffer->used >= 2 && processed_frames < 100) { // Limit to prevent infinite loops if (bytes_read <= 0) {
// Check frame type first if (config.debug) {
unsigned char frame_byte = frame_buffer->buffer[0]; printf("[DEBUG] WebSocket connection closed or SSL_read error: %d\n", bytes_read);
unsigned char frame_type = frame_byte & 0x8F; fflush(stdout);
int fin = (frame_byte & 0x80) != 0; }
int masked = (frame_buffer->buffer[1] & 0x80) != 0; break;
}
if (config.debug) { if (config.debug) {
printf("[DEBUG] Processing frame: type=0x%02x, fin=%d, masked=%d, buffer_used=%zu\n", printf("[DEBUG] SSL_read returned %d bytes\n", bytes_read);
frame_type, fin, masked, frame_buffer->used);
fflush(stdout); fflush(stdout);
} }
// Handle close frame // Append new data to frame buffer
if (frame_type == 0x88) { if (!frame_buffer_append(frame_buffer, read_buffer, bytes_read)) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] Received close frame from server\n"); printf("[DEBUG] Failed to append data to frame buffer\n");
fflush(stdout); fflush(stdout);
} }
goto cleanup; continue;
} }
// Handle ping frame // Process complete frames from the buffer
if (frame_type == 0x89) { int processed_frames = 0;
while (frame_buffer->used >= 2 && processed_frames < 100) { // Limit to prevent infinite loops
// Check frame type first
unsigned char frame_byte = frame_buffer->buffer[0];
unsigned char frame_type = frame_byte & 0x8F;
int fin = (frame_byte & 0x80) != 0;
int masked = (frame_buffer->buffer[1] & 0x80) != 0;
if (config.debug) { if (config.debug) {
printf("[DEBUG] Received ping frame, sending pong\n"); printf("[DEBUG] Processing frame: type=0x%02x, fin=%d, masked=%d, buffer_used=%zu\n",
frame_type, fin, masked, frame_buffer->used);
fflush(stdout); fflush(stdout);
} }
// Parse the ping frame to get payload
char *ping_payload; // Handle close frame
int ping_payload_len; if (frame_type == 0x88) {
if (parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &ping_payload, &ping_payload_len)) { if (config.debug) {
// Send pong with same payload printf("[DEBUG] Received close frame from server\n");
if (!send_pong_frame(active_tunnel->ssl, ping_payload, ping_payload_len)) { fflush(stdout);
if (config.debug) {
printf("[DEBUG] Failed to send pong frame\n");
fflush(stdout);
}
}
// Calculate frame size and consume it
int frame_size = (ping_payload - frame_buffer->buffer) + ping_payload_len;
if (frame_type & 0x80) { // FIN bit set
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
} }
} else { goto cleanup;
// Incomplete ping frame, wait for more data
break;
} }
continue;
}
// Handle pong frame // Handle ping frame
if (frame_type == 0x8A) { if (frame_type == 0x89) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] Received pong frame\n"); printf("[DEBUG] Received ping frame, sending pong\n");
fflush(stdout); fflush(stdout);
}
// Parse to find frame boundaries
char *pong_payload;
int pong_payload_len;
if (parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &pong_payload, &pong_payload_len)) {
int frame_size = (pong_payload - frame_buffer->buffer) + pong_payload_len;
if (frame_type & 0x80) { // FIN bit set
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
} }
} else { // Parse the ping frame to get payload
break; // Incomplete pong frame char *ping_payload;
int ping_payload_len;
if (parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &ping_payload, &ping_payload_len)) {
// Send pong with same payload
if (!send_pong_frame(active_tunnel->ssl, ping_payload, ping_payload_len)) {
if (config.debug) {
printf("[DEBUG] Failed to send pong frame\n");
fflush(stdout);
}
}
// Calculate frame size and consume it
int frame_size = (ping_payload - frame_buffer->buffer) + ping_payload_len;
if (frame_type & 0x80) { // FIN bit set
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
}
} else {
// Incomplete ping frame, wait for more data
break;
}
continue;
} }
continue;
}
// Parse regular frame (text or binary) // Handle pong frame
char *payload; if (frame_type == 0x8A) {
int payload_len; if (config.debug) {
if (!parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &payload, &payload_len)) { printf("[DEBUG] Received pong frame\n");
// Incomplete frame, wait for more data fflush(stdout);
if (config.debug) {
printf("[DEBUG] Incomplete frame, buffer used: %zu, first few bytes: ", frame_buffer->used);
for (size_t i = 0; i < frame_buffer->used && i < 10; i++) {
printf("%02x ", (unsigned char)frame_buffer->buffer[i]);
} }
printf("\n"); // Parse to find frame boundaries
fflush(stdout); char *pong_payload;
int pong_payload_len;
if (parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &pong_payload, &pong_payload_len)) {
int frame_size = (pong_payload - frame_buffer->buffer) + pong_payload_len;
if (frame_type & 0x80) { // FIN bit set
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
}
} else {
break; // Incomplete pong frame
}
continue;
} }
break;
}
// Calculate total frame size (header + payload) // Parse regular frame (text or binary)
// payload points to start of payload data, so header_len = payload - buffer char *payload;
int header_len = payload - frame_buffer->buffer; int payload_len;
int frame_size = header_len + payload_len; if (!parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &payload, &payload_len)) {
// Incomplete frame, wait for more data
if (config.debug) {
printf("[DEBUG] Incomplete frame, buffer used: %zu, first few bytes: ", frame_buffer->used);
for (size_t i = 0; i < frame_buffer->used && i < 10; i++) {
printf("%02x ", (unsigned char)frame_buffer->buffer[i]);
}
printf("\n");
fflush(stdout);
}
break;
}
if (config.debug) { // Calculate total frame size (header + payload)
printf("[DEBUG] Frame details: header_len=%d, payload_len=%d, frame_size=%d, buffer_used=%zu\n", // payload points to start of payload data, so header_len = payload - buffer
header_len, payload_len, frame_size, frame_buffer->used); int header_len = payload - frame_buffer->buffer;
fflush(stdout); int frame_size = header_len + payload_len;
}
// Validate frame size before consuming
if (frame_size <= 0 || frame_size > (int)frame_buffer->used) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] Invalid frame size %d, skipping frame\n", frame_size); printf("[DEBUG] Frame details: header_len=%d, payload_len=%d, frame_size=%d, buffer_used=%zu\n",
header_len, payload_len, frame_size, frame_buffer->used);
fflush(stdout); fflush(stdout);
} }
// Skip this frame by consuming just the header
if (header_len > 0 && header_len <= (int)frame_buffer->used) {
frame_buffer_consume(frame_buffer, header_len);
} else {
// Can't even consume header, break to avoid infinite loop
break;
}
processed_frames++;
continue;
}
frame_buffer_consume(frame_buffer, frame_size); // Validate frame size before consuming
processed_frames++; if (frame_size <= 0 || frame_size > (int)frame_buffer->used) {
if (config.debug) {
printf("[DEBUG] Invalid frame size %d, skipping frame\n", frame_size);
fflush(stdout);
}
// Skip this frame by consuming just the header
if (header_len > 0 && header_len <= (int)frame_buffer->used) {
frame_buffer_consume(frame_buffer, header_len);
} else {
// Can't even consume header, break to avoid infinite loop
break;
}
processed_frames++;
continue;
}
payload[payload_len] = '\0'; frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
if (config.debug) { payload[payload_len] = '\0';
printf("[DEBUG] Received: %s\n", payload);
fflush(stdout);
}
// Handle messages
if (strstr(payload, "tunnel_data") || strstr(payload, "tunnel_response") ||
strstr(payload, "tunnel_request") || strstr(payload, "tunnel_ack")) {
if (config.debug) { if (config.debug) {
if (strstr(payload, "tunnel_data")) { printf("[DEBUG] Received: %s\n", payload);
printf("[DEBUG] Received tunnel_data message\n");
} else if (strstr(payload, "tunnel_response")) {
printf("[DEBUG] Received tunnel_response message\n");
} else if (strstr(payload, "tunnel_request")) {
printf("[DEBUG] Received tunnel_request message\n");
} else if (strstr(payload, "tunnel_ack")) {
printf("[DEBUG] Received tunnel_ack message\n");
}
fflush(stdout); fflush(stdout);
} }
// Extract request_id and data if present
char *id_start = strstr(payload, "\"request_id\""); // Handle messages
char *data_start = strstr(payload, "\"data\""); if (strstr(payload, "tunnel_data") || strstr(payload, "tunnel_response") ||
if (id_start && data_start) { strstr(payload, "tunnel_request") || strstr(payload, "tunnel_ack")) {
char *colon = strchr(id_start, ':'); if (config.debug) {
if (colon) { if (strstr(payload, "tunnel_data")) {
char *open_quote = strchr(colon, '"'); printf("[DEBUG] Received tunnel_data message\n");
if (open_quote) { } else if (strstr(payload, "tunnel_response")) {
id_start = open_quote + 1; printf("[DEBUG] Received tunnel_response message\n");
char *close_quote = strchr(id_start, '"'); } else if (strstr(payload, "tunnel_request")) {
if (close_quote) { printf("[DEBUG] Received tunnel_request message\n");
*close_quote = '\0'; } else if (strstr(payload, "tunnel_ack")) {
char *data_colon = strchr(data_start, ':'); printf("[DEBUG] Received tunnel_ack message\n");
if (data_colon) { }
char *data_quote = strchr(data_colon, '"'); fflush(stdout);
if (data_quote) { }
data_start = data_quote + 1; // Extract request_id and data if present
char *data_end = strchr(data_start, '"'); char *id_start = strstr(payload, "\"request_id\"");
if (data_end) { char *data_start = strstr(payload, "\"data\"");
*data_end = '\0'; if (id_start && data_start) {
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug); char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
char *data_colon = strchr(data_start, ':');
if (data_colon) {
char *data_quote = strchr(data_colon, '"');
if (data_quote) {
data_start = data_quote + 1;
char *data_end = strchr(data_start, '"');
if (data_end) {
*data_end = '\0';
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug);
}
} }
} }
} }
} }
} }
} }
} } else if (strstr(payload, "tunnel_close")) {
} else if (strstr(payload, "tunnel_close")) { if (config.debug) {
if (config.debug) { printf("[DEBUG] Received tunnel_close message\n");
printf("[DEBUG] Received tunnel_close message\n"); fflush(stdout);
fflush(stdout); }
} char *id_start = strstr(payload, "\"request_id\"");
char *id_start = strstr(payload, "\"request_id\""); if (id_start) {
if (id_start) { char *colon = strchr(id_start, ':');
char *colon = strchr(id_start, ':'); if (colon) {
if (colon) { char *open_quote = strchr(colon, '"');
char *open_quote = strchr(colon, '"'); if (open_quote) {
if (open_quote) { id_start = open_quote + 1;
id_start = open_quote + 1; char *close_quote = strchr(id_start, '"');
char *close_quote = strchr(id_start, '"'); if (close_quote) {
if (close_quote) { *close_quote = '\0';
*close_quote = '\0'; handle_tunnel_close(active_tunnel->ssl, id_start, config.debug);
handle_tunnel_close(active_tunnel->ssl, id_start, config.debug); }
} }
} }
} }
} else {
if (config.debug) {
printf("[DEBUG] Received unknown message type: %s\n", payload);
fflush(stdout);
}
} }
} else { }
if (processed_frames >= 100) {
if (config.debug) { if (config.debug) {
printf("[DEBUG] Received unknown message type: %s\n", payload); printf("[DEBUG] Processed 100 frames in one iteration, possible infinite loop\n");
fflush(stdout); fflush(stdout);
} }
} }
} }
if (processed_frames >= 100) {
if (config.debug) {
printf("[DEBUG] Processed 100 frames in one iteration, possible infinite loop\n");
fflush(stdout);
}
}
} }
frame_buffer_free(frame_buffer); frame_buffer_free(frame_buffer);
...@@ -1494,6 +1541,13 @@ cleanup: ...@@ -1494,6 +1541,13 @@ cleanup:
fflush(stdout); fflush(stdout);
} }
// Wait a few seconds before exiting cleanly
if (config.debug) {
printf("[DEBUG] Waiting 3 seconds before exit...\n");
fflush(stdout);
}
sleep(3);
// Cleanup // Cleanup
if (active_tunnel) { if (active_tunnel) {
close(active_tunnel->local_sock); close(active_tunnel->local_sock);
...@@ -1509,9 +1563,9 @@ cleanup: ...@@ -1509,9 +1563,9 @@ cleanup:
// Free allocated strings in new_scp_args // Free allocated strings in new_scp_args
for (int i = 0; i < new_scp_argc; i++) { for (int i = 0; i < new_scp_argc; i++) {
// Free strings that were allocated with malloc/strdup // Free strings that were allocated with malloc/strdup
if (i == 2) { // The port string if (i == 4) { // The port string
free(new_scp_args[i]); free(new_scp_args[i]);
} else if (i == 3) { // The user@localhost:path string (if allocated) } else if (i == 5) { // The user@localhost:path string (if allocated)
// Check if this was allocated (contains @localhost) // Check if this was allocated (contains @localhost)
if (strstr(new_scp_args[i], "@localhost")) { if (strstr(new_scp_args[i], "@localhost")) {
free(new_scp_args[i]); free(new_scp_args[i]);
......
...@@ -194,8 +194,8 @@ int parse_ssh_args(int argc, char *argv[], char **host, int *ssh_port, int debug ...@@ -194,8 +194,8 @@ int parse_ssh_args(int argc, char *argv[], char **host, int *ssh_port, int debug
} }
char **modify_ssh_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) { char **modify_ssh_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) {
// Allocate space for: ssh + -p + port + original args + NULL // Allocate space for: ssh + -o + StrictHostKeyChecking=no + -p + port + original args + NULL
char **new_args = malloc((argc + 4) * sizeof(char *)); char **new_args = malloc((argc + 6) * sizeof(char *));
if (!new_args) { if (!new_args) {
return NULL; return NULL;
} }
...@@ -203,6 +203,10 @@ char **modify_ssh_args(int argc, char *argv[], const char *original_host, int lo ...@@ -203,6 +203,10 @@ char **modify_ssh_args(int argc, char *argv[], const char *original_host, int lo
int idx = 0; int idx = 0;
new_args[idx++] = "ssh"; new_args[idx++] = "ssh";
// Add StrictHostKeyChecking=no option
new_args[idx++] = "-o";
new_args[idx++] = "StrictHostKeyChecking=no";
// Add port argument for local tunnel // Add port argument for local tunnel
new_args[idx++] = "-p"; new_args[idx++] = "-p";
char *port_str = malloc(16); char *port_str = malloc(16);
...@@ -1156,9 +1160,9 @@ int main(int argc, char *argv[]) { ...@@ -1156,9 +1160,9 @@ int main(int argc, char *argv[]) {
// Free allocated strings in new_ssh_args // Free allocated strings in new_ssh_args
for (int i = 0; i < new_ssh_argc; i++) { for (int i = 0; i < new_ssh_argc; i++) {
// Free strings that were allocated with malloc/strdup // Free strings that were allocated with malloc/strdup
if (i == 2) { // The port string if (i == 4) { // The port string
free(new_ssh_args[i]); free(new_ssh_args[i]);
} else if (i == 3) { // The user@localhost string (if allocated) } else if (i == 5) { // The user@localhost string (if allocated)
// Check if this was allocated (contains @localhost) // Check if this was allocated (contains @localhost)
if (strstr(new_ssh_args[i], "@localhost")) { if (strstr(new_ssh_args[i], "@localhost")) {
free(new_ssh_args[i]); free(new_ssh_args[i]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment