Optimize C wsscp performance and add SSH options

- Replace busy-waiting usleep loops with select() for efficient I/O multiplexing
- Add 50ms timeouts to prevent blocking operations
- Maintain 1MB buffers for optimal throughput
- Add 3-second exit delay after SCP completion for clean shutdown
- Fix critical bug in setup_tunnel return codes to prevent segfaults
- Correct memory management in argument cleanup to avoid corruption
- Add StrictHostKeyChecking=no option to wsscp and wsssh in both Python and C implementations
parent 92294668
......@@ -256,8 +256,8 @@ def main():
final_args.append(arg)
i += 1
# Add port argument for local tunnel at the beginning
final_args = ['-P', str(local_port)] + final_args
# Add StrictHostKeyChecking=no and port argument for local tunnel at the beginning
final_args = ['-o', 'StrictHostKeyChecking=no', '-P', str(local_port)] + final_args
if debug: print(f"[DEBUG] Final SCP args: {final_args}")
......
......@@ -251,8 +251,8 @@ def main():
else:
final_args.append(arg)
# Add port argument for local tunnel
final_args.extend(['-p', str(local_port)])
# Add StrictHostKeyChecking=no and port argument for local tunnel
final_args.extend(['-o', 'StrictHostKeyChecking=no', '-p', str(local_port)])
if debug: print(f"[DEBUG] Final SSH args: {final_args}")
......
......@@ -33,6 +33,7 @@
#include <fcntl.h>
#include <pthread.h>
#include <errno.h>
#include <sys/select.h>
#define BUFFER_SIZE 1048576
#define DEFAULT_PORT 22
......@@ -206,8 +207,8 @@ int parse_scp_args(int argc, char *argv[], char **destination, int *scp_port, in
}
char **modify_scp_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) {
// Allocate space for: scp + -P + port + original args + NULL
char **new_args = malloc((argc + 4) * sizeof(char *));
// Allocate space for: scp + -o + StrictHostKeyChecking=no + -P + port + original args + NULL
char **new_args = malloc((argc + 6) * sizeof(char *));
if (!new_args) {
return NULL;
}
......@@ -215,7 +216,11 @@ char **modify_scp_args(int argc, char *argv[], const char *original_host, int lo
int idx = 0;
new_args[idx++] = "scp";
// Add port argument first
// Add StrictHostKeyChecking=no option
new_args[idx++] = "-o";
new_args[idx++] = "StrictHostKeyChecking=no";
// Add port argument
new_args[idx++] = "-P";
char *port_str = malloc(16);
if (!port_str) {
......@@ -594,6 +599,8 @@ void *forward_tcp_to_ws(void *arg) {
int debug = args->debug;
char buffer[BUFFER_SIZE];
int bytes_read;
fd_set readfds;
struct timeval tv;
while (1) {
pthread_mutex_lock(&tunnel_mutex);
......@@ -626,43 +633,59 @@ void *forward_tcp_to_ws(void *arg) {
pthread_mutex_unlock(&tunnel_mutex);
bytes_read = recv(sock, buffer, sizeof(buffer), MSG_DONTWAIT);
if (bytes_read == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
// No data available, sleep to avoid busy waiting
usleep(50000); // Increased sleep time to reduce CPU usage
continue;
} else if (bytes_read <= 0) {
// Use select to wait for data on local socket
FD_ZERO(&readfds);
FD_SET(sock, &readfds);
tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
int retval = select(sock + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (debug) {
printf("[DEBUG] TCP connection closed or error\n");
perror("[DEBUG] select failed");
fflush(stdout);
}
break;
} else if (retval == 0) {
// Timeout, continue loop
continue;
}
if (debug) {
printf("[DEBUG] Forwarding %d bytes from TCP to WebSocket\n", bytes_read);
fflush(stdout);
}
// Convert to hex
char hex_data[bytes_read * 2 + 1];
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]);
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (FD_ISSET(sock, &readfds)) {
bytes_read = recv(sock, buffer, sizeof(buffer), 0);
if (bytes_read <= 0) {
if (debug) {
printf("[DEBUG] TCP connection closed or error\n");
fflush(stdout);
}
break;
}
if (!send_websocket_frame(ssl, message)) {
if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n");
printf("[DEBUG] Forwarding %d bytes from TCP to WebSocket\n", bytes_read);
fflush(stdout);
}
break;
// Convert to hex
char hex_data[bytes_read * 2 + 1];
for (int i = 0; i < bytes_read; i++) {
sprintf(hex_data + i * 2, "%02x", (unsigned char)buffer[i]);
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_data
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_data\",\"request_id\":\"%s\",\"data\":\"%s\"}",
request_id, hex_data);
if (!send_websocket_frame(ssl, message)) {
if (debug) {
printf("[DEBUG] Failed to send WebSocket frame\n");
fflush(stdout);
}
break;
}
}
}
......@@ -735,13 +758,13 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
// Resolve hostname
if ((he = gethostbyname(wssshd_host)) == NULL) {
herror("gethostbyname");
return 0;
return -1;
}
// Create socket
if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
perror("Socket creation failed");
return 0;
return -1;
}
memset(&server_addr, 0, sizeof(server_addr));
......@@ -753,7 +776,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) {
perror("Connection failed");
close(sock);
return 0;
return -1;
}
// Initialize SSL
......@@ -784,7 +807,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
printf("[DEBUG] SSL connection established\n");
......@@ -801,7 +824,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
printf("[DEBUG] WebSocket handshake successful\n");
......@@ -813,7 +836,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
......@@ -829,7 +852,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
......@@ -863,7 +886,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
// Parse WebSocket frame
......@@ -882,7 +905,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
// Null terminate payload
......@@ -908,7 +931,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
......@@ -922,7 +945,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
active_tunnel->outgoing_buffer = frame_buffer_init();
......@@ -932,7 +955,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
strcpy(active_tunnel->request_id, request_id);
......@@ -949,7 +972,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
struct sockaddr_in local_addr;
......@@ -967,7 +990,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (listen(listen_sock, 1) < 0) {
......@@ -979,7 +1002,7 @@ int setup_tunnel(const char *wssshd_host, int wssshd_port, const char *client_id
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
return 0;
return -1;
}
if (debug) {
......@@ -1234,6 +1257,10 @@ int main(int argc, char *argv[]) {
}
char read_buffer[BUFFER_SIZE];
fd_set readfds;
struct timeval tv;
int ssl_fd = SSL_get_fd(active_tunnel->ssl);
while (active_tunnel && active_tunnel->active) {
// Check if SCP process has finished
int status;
......@@ -1246,235 +1273,255 @@ int main(int argc, char *argv[]) {
break;
}
int bytes_read = SSL_read(active_tunnel->ssl, read_buffer, sizeof(read_buffer));
if (bytes_read <= 0) {
if (config.debug) {
printf("[DEBUG] WebSocket connection closed or SSL_read error: %d\n", bytes_read);
fflush(stdout);
}
break;
}
// Use select to wait for data on SSL socket with timeout
FD_ZERO(&readfds);
FD_SET(ssl_fd, &readfds);
tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
if (config.debug) {
printf("[DEBUG] SSL_read returned %d bytes\n", bytes_read);
fflush(stdout);
}
// Append new data to frame buffer
if (!frame_buffer_append(frame_buffer, read_buffer, bytes_read)) {
int retval = select(ssl_fd + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (config.debug) {
printf("[DEBUG] Failed to append data to frame buffer\n");
perror("[DEBUG] select on SSL fd failed");
fflush(stdout);
}
break;
} else if (retval == 0) {
// Timeout, continue loop
continue;
}
// Process complete frames from the buffer
int processed_frames = 0;
while (frame_buffer->used >= 2 && processed_frames < 100) { // Limit to prevent infinite loops
// Check frame type first
unsigned char frame_byte = frame_buffer->buffer[0];
unsigned char frame_type = frame_byte & 0x8F;
int fin = (frame_byte & 0x80) != 0;
int masked = (frame_buffer->buffer[1] & 0x80) != 0;
if (FD_ISSET(ssl_fd, &readfds)) {
int bytes_read = SSL_read(active_tunnel->ssl, read_buffer, sizeof(read_buffer));
if (bytes_read <= 0) {
if (config.debug) {
printf("[DEBUG] WebSocket connection closed or SSL_read error: %d\n", bytes_read);
fflush(stdout);
}
break;
}
if (config.debug) {
printf("[DEBUG] Processing frame: type=0x%02x, fin=%d, masked=%d, buffer_used=%zu\n",
frame_type, fin, masked, frame_buffer->used);
printf("[DEBUG] SSL_read returned %d bytes\n", bytes_read);
fflush(stdout);
}
// Handle close frame
if (frame_type == 0x88) {
// Append new data to frame buffer
if (!frame_buffer_append(frame_buffer, read_buffer, bytes_read)) {
if (config.debug) {
printf("[DEBUG] Received close frame from server\n");
printf("[DEBUG] Failed to append data to frame buffer\n");
fflush(stdout);
}
goto cleanup;
continue;
}
// Handle ping frame
if (frame_type == 0x89) {
// Process complete frames from the buffer
int processed_frames = 0;
while (frame_buffer->used >= 2 && processed_frames < 100) { // Limit to prevent infinite loops
// Check frame type first
unsigned char frame_byte = frame_buffer->buffer[0];
unsigned char frame_type = frame_byte & 0x8F;
int fin = (frame_byte & 0x80) != 0;
int masked = (frame_buffer->buffer[1] & 0x80) != 0;
if (config.debug) {
printf("[DEBUG] Received ping frame, sending pong\n");
printf("[DEBUG] Processing frame: type=0x%02x, fin=%d, masked=%d, buffer_used=%zu\n",
frame_type, fin, masked, frame_buffer->used);
fflush(stdout);
}
// Parse the ping frame to get payload
char *ping_payload;
int ping_payload_len;
if (parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &ping_payload, &ping_payload_len)) {
// Send pong with same payload
if (!send_pong_frame(active_tunnel->ssl, ping_payload, ping_payload_len)) {
if (config.debug) {
printf("[DEBUG] Failed to send pong frame\n");
fflush(stdout);
}
}
// Calculate frame size and consume it
int frame_size = (ping_payload - frame_buffer->buffer) + ping_payload_len;
if (frame_type & 0x80) { // FIN bit set
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
// Handle close frame
if (frame_type == 0x88) {
if (config.debug) {
printf("[DEBUG] Received close frame from server\n");
fflush(stdout);
}
} else {
// Incomplete ping frame, wait for more data
break;
goto cleanup;
}
continue;
}
// Handle pong frame
if (frame_type == 0x8A) {
if (config.debug) {
printf("[DEBUG] Received pong frame\n");
fflush(stdout);
}
// Parse to find frame boundaries
char *pong_payload;
int pong_payload_len;
if (parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &pong_payload, &pong_payload_len)) {
int frame_size = (pong_payload - frame_buffer->buffer) + pong_payload_len;
if (frame_type & 0x80) { // FIN bit set
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
// Handle ping frame
if (frame_type == 0x89) {
if (config.debug) {
printf("[DEBUG] Received ping frame, sending pong\n");
fflush(stdout);
}
} else {
break; // Incomplete pong frame
// Parse the ping frame to get payload
char *ping_payload;
int ping_payload_len;
if (parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &ping_payload, &ping_payload_len)) {
// Send pong with same payload
if (!send_pong_frame(active_tunnel->ssl, ping_payload, ping_payload_len)) {
if (config.debug) {
printf("[DEBUG] Failed to send pong frame\n");
fflush(stdout);
}
}
// Calculate frame size and consume it
int frame_size = (ping_payload - frame_buffer->buffer) + ping_payload_len;
if (frame_type & 0x80) { // FIN bit set
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
}
} else {
// Incomplete ping frame, wait for more data
break;
}
continue;
}
continue;
}
// Parse regular frame (text or binary)
char *payload;
int payload_len;
if (!parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &payload, &payload_len)) {
// Incomplete frame, wait for more data
if (config.debug) {
printf("[DEBUG] Incomplete frame, buffer used: %zu, first few bytes: ", frame_buffer->used);
for (size_t i = 0; i < frame_buffer->used && i < 10; i++) {
printf("%02x ", (unsigned char)frame_buffer->buffer[i]);
// Handle pong frame
if (frame_type == 0x8A) {
if (config.debug) {
printf("[DEBUG] Received pong frame\n");
fflush(stdout);
}
printf("\n");
fflush(stdout);
// Parse to find frame boundaries
char *pong_payload;
int pong_payload_len;
if (parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &pong_payload, &pong_payload_len)) {
int frame_size = (pong_payload - frame_buffer->buffer) + pong_payload_len;
if (frame_type & 0x80) { // FIN bit set
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
}
} else {
break; // Incomplete pong frame
}
continue;
}
break;
}
// Calculate total frame size (header + payload)
// payload points to start of payload data, so header_len = payload - buffer
int header_len = payload - frame_buffer->buffer;
int frame_size = header_len + payload_len;
// Parse regular frame (text or binary)
char *payload;
int payload_len;
if (!parse_websocket_frame(frame_buffer->buffer, frame_buffer->used, &payload, &payload_len)) {
// Incomplete frame, wait for more data
if (config.debug) {
printf("[DEBUG] Incomplete frame, buffer used: %zu, first few bytes: ", frame_buffer->used);
for (size_t i = 0; i < frame_buffer->used && i < 10; i++) {
printf("%02x ", (unsigned char)frame_buffer->buffer[i]);
}
printf("\n");
fflush(stdout);
}
break;
}
if (config.debug) {
printf("[DEBUG] Frame details: header_len=%d, payload_len=%d, frame_size=%d, buffer_used=%zu\n",
header_len, payload_len, frame_size, frame_buffer->used);
fflush(stdout);
}
// Calculate total frame size (header + payload)
// payload points to start of payload data, so header_len = payload - buffer
int header_len = payload - frame_buffer->buffer;
int frame_size = header_len + payload_len;
// Validate frame size before consuming
if (frame_size <= 0 || frame_size > (int)frame_buffer->used) {
if (config.debug) {
printf("[DEBUG] Invalid frame size %d, skipping frame\n", frame_size);
printf("[DEBUG] Frame details: header_len=%d, payload_len=%d, frame_size=%d, buffer_used=%zu\n",
header_len, payload_len, frame_size, frame_buffer->used);
fflush(stdout);
}
// Skip this frame by consuming just the header
if (header_len > 0 && header_len <= (int)frame_buffer->used) {
frame_buffer_consume(frame_buffer, header_len);
} else {
// Can't even consume header, break to avoid infinite loop
break;
}
processed_frames++;
continue;
}
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
// Validate frame size before consuming
if (frame_size <= 0 || frame_size > (int)frame_buffer->used) {
if (config.debug) {
printf("[DEBUG] Invalid frame size %d, skipping frame\n", frame_size);
fflush(stdout);
}
// Skip this frame by consuming just the header
if (header_len > 0 && header_len <= (int)frame_buffer->used) {
frame_buffer_consume(frame_buffer, header_len);
} else {
// Can't even consume header, break to avoid infinite loop
break;
}
processed_frames++;
continue;
}
payload[payload_len] = '\0';
frame_buffer_consume(frame_buffer, frame_size);
processed_frames++;
if (config.debug) {
printf("[DEBUG] Received: %s\n", payload);
fflush(stdout);
}
payload[payload_len] = '\0';
// Handle messages
if (strstr(payload, "tunnel_data") || strstr(payload, "tunnel_response") ||
strstr(payload, "tunnel_request") || strstr(payload, "tunnel_ack")) {
if (config.debug) {
if (strstr(payload, "tunnel_data")) {
printf("[DEBUG] Received tunnel_data message\n");
} else if (strstr(payload, "tunnel_response")) {
printf("[DEBUG] Received tunnel_response message\n");
} else if (strstr(payload, "tunnel_request")) {
printf("[DEBUG] Received tunnel_request message\n");
} else if (strstr(payload, "tunnel_ack")) {
printf("[DEBUG] Received tunnel_ack message\n");
}
printf("[DEBUG] Received: %s\n", payload);
fflush(stdout);
}
// Extract request_id and data if present
char *id_start = strstr(payload, "\"request_id\"");
char *data_start = strstr(payload, "\"data\"");
if (id_start && data_start) {
char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
char *data_colon = strchr(data_start, ':');
if (data_colon) {
char *data_quote = strchr(data_colon, '"');
if (data_quote) {
data_start = data_quote + 1;
char *data_end = strchr(data_start, '"');
if (data_end) {
*data_end = '\0';
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug);
// Handle messages
if (strstr(payload, "tunnel_data") || strstr(payload, "tunnel_response") ||
strstr(payload, "tunnel_request") || strstr(payload, "tunnel_ack")) {
if (config.debug) {
if (strstr(payload, "tunnel_data")) {
printf("[DEBUG] Received tunnel_data message\n");
} else if (strstr(payload, "tunnel_response")) {
printf("[DEBUG] Received tunnel_response message\n");
} else if (strstr(payload, "tunnel_request")) {
printf("[DEBUG] Received tunnel_request message\n");
} else if (strstr(payload, "tunnel_ack")) {
printf("[DEBUG] Received tunnel_ack message\n");
}
fflush(stdout);
}
// Extract request_id and data if present
char *id_start = strstr(payload, "\"request_id\"");
char *data_start = strstr(payload, "\"data\"");
if (id_start && data_start) {
char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
char *data_colon = strchr(data_start, ':');
if (data_colon) {
char *data_quote = strchr(data_colon, '"');
if (data_quote) {
data_start = data_quote + 1;
char *data_end = strchr(data_start, '"');
if (data_end) {
*data_end = '\0';
handle_tunnel_data(active_tunnel->ssl, id_start, data_start, config.debug);
}
}
}
}
}
}
}
}
} else if (strstr(payload, "tunnel_close")) {
if (config.debug) {
printf("[DEBUG] Received tunnel_close message\n");
fflush(stdout);
}
char *id_start = strstr(payload, "\"request_id\"");
if (id_start) {
char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
handle_tunnel_close(active_tunnel->ssl, id_start, config.debug);
} else if (strstr(payload, "tunnel_close")) {
if (config.debug) {
printf("[DEBUG] Received tunnel_close message\n");
fflush(stdout);
}
char *id_start = strstr(payload, "\"request_id\"");
if (id_start) {
char *colon = strchr(id_start, ':');
if (colon) {
char *open_quote = strchr(colon, '"');
if (open_quote) {
id_start = open_quote + 1;
char *close_quote = strchr(id_start, '"');
if (close_quote) {
*close_quote = '\0';
handle_tunnel_close(active_tunnel->ssl, id_start, config.debug);
}
}
}
}
} else {
if (config.debug) {
printf("[DEBUG] Received unknown message type: %s\n", payload);
fflush(stdout);
}
}
} else {
}
if (processed_frames >= 100) {
if (config.debug) {
printf("[DEBUG] Received unknown message type: %s\n", payload);
printf("[DEBUG] Processed 100 frames in one iteration, possible infinite loop\n");
fflush(stdout);
}
}
}
if (processed_frames >= 100) {
if (config.debug) {
printf("[DEBUG] Processed 100 frames in one iteration, possible infinite loop\n");
fflush(stdout);
}
}
}
frame_buffer_free(frame_buffer);
......@@ -1494,6 +1541,13 @@ cleanup:
fflush(stdout);
}
// Wait a few seconds before exiting cleanly
if (config.debug) {
printf("[DEBUG] Waiting 3 seconds before exit...\n");
fflush(stdout);
}
sleep(3);
// Cleanup
if (active_tunnel) {
close(active_tunnel->local_sock);
......@@ -1509,9 +1563,9 @@ cleanup:
// Free allocated strings in new_scp_args
for (int i = 0; i < new_scp_argc; i++) {
// Free strings that were allocated with malloc/strdup
if (i == 2) { // The port string
if (i == 4) { // The port string
free(new_scp_args[i]);
} else if (i == 3) { // The user@localhost:path string (if allocated)
} else if (i == 5) { // The user@localhost:path string (if allocated)
// Check if this was allocated (contains @localhost)
if (strstr(new_scp_args[i], "@localhost")) {
free(new_scp_args[i]);
......
......@@ -194,8 +194,8 @@ int parse_ssh_args(int argc, char *argv[], char **host, int *ssh_port, int debug
}
char **modify_ssh_args(int argc, char *argv[], const char *original_host, int local_port, int *new_argc) {
// Allocate space for: ssh + -p + port + original args + NULL
char **new_args = malloc((argc + 4) * sizeof(char *));
// Allocate space for: ssh + -o + StrictHostKeyChecking=no + -p + port + original args + NULL
char **new_args = malloc((argc + 6) * sizeof(char *));
if (!new_args) {
return NULL;
}
......@@ -203,6 +203,10 @@ char **modify_ssh_args(int argc, char *argv[], const char *original_host, int lo
int idx = 0;
new_args[idx++] = "ssh";
// Add StrictHostKeyChecking=no option
new_args[idx++] = "-o";
new_args[idx++] = "StrictHostKeyChecking=no";
// Add port argument for local tunnel
new_args[idx++] = "-p";
char *port_str = malloc(16);
......@@ -1156,9 +1160,9 @@ int main(int argc, char *argv[]) {
// Free allocated strings in new_ssh_args
for (int i = 0; i < new_ssh_argc; i++) {
// Free strings that were allocated with malloc/strdup
if (i == 2) { // The port string
if (i == 4) { // The port string
free(new_ssh_args[i]);
} else if (i == 3) { // The user@localhost string (if allocated)
} else if (i == 5) { // The user@localhost string (if allocated)
// Check if this was allocated (contains @localhost)
if (strstr(new_ssh_args[i], "@localhost")) {
free(new_ssh_args[i]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment