Fix wssshc tunnel handling - remove SSH protocol logic, implement raw TCP forwarding

parent 1114c673
......@@ -95,7 +95,7 @@ int frame_buffer_consume(frame_buffer_t *fb, size_t len) {
return 1;
}
void handle_tunnel_request(SSL *ssl, const char *request_id, int debug) {
void handle_tunnel_request(SSL *ssl, const char *request_id, int debug, const char *ssh_host, int ssh_port) {
pthread_mutex_lock(&tunnel_mutex);
if (active_tunnel) {
if (active_tunnel->sock >= 0) {
......@@ -110,11 +110,11 @@ void handle_tunnel_request(SSL *ssl, const char *request_id, int debug) {
return;
}
// For wssshc: Act as SSH client - connect to target SSH server
// For wssshc: Connect to target TCP endpoint and forward raw TCP data
struct sockaddr_in target_addr;
int ssh_sock = socket(AF_INET, SOCK_STREAM, 0);
if (ssh_sock < 0) {
perror("SSH socket creation failed");
int target_sock = socket(AF_INET, SOCK_STREAM, 0);
if (target_sock < 0) {
perror("Target socket creation failed");
free(active_tunnel);
active_tunnel = NULL;
pthread_mutex_unlock(&tunnel_mutex);
......@@ -123,19 +123,30 @@ void handle_tunnel_request(SSL *ssl, const char *request_id, int debug) {
memset(&target_addr, 0, sizeof(target_addr));
target_addr.sin_family = AF_INET;
target_addr.sin_port = htons(22); // Target SSH port
target_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); // Target SSH host
target_addr.sin_port = htons(ssh_port); // Target port
// Resolve target host
struct hostent *target_he;
if ((target_he = gethostbyname(ssh_host)) == NULL) {
herror("Target host resolution failed");
close(target_sock);
free(active_tunnel);
active_tunnel = NULL;
pthread_mutex_unlock(&tunnel_mutex);
return;
}
target_addr.sin_addr = *((struct in_addr *)target_he->h_addr); // Target host
if (connect(ssh_sock, (struct sockaddr *)&target_addr, sizeof(target_addr)) < 0) {
perror("Connection to target SSH server failed");
close(ssh_sock);
if (connect(target_sock, (struct sockaddr *)&target_addr, sizeof(target_addr)) < 0) {
perror("Connection to target endpoint failed");
close(target_sock);
free(active_tunnel);
active_tunnel = NULL;
pthread_mutex_unlock(&tunnel_mutex);
return;
}
active_tunnel->sock = ssh_sock; // SSH client connection to target
active_tunnel->sock = target_sock; // TCP connection to target
active_tunnel->local_sock = -1; // Not used in wssshc
strcpy(active_tunnel->request_id, request_id);
active_tunnel->active = 1;
......@@ -143,31 +154,11 @@ void handle_tunnel_request(SSL *ssl, const char *request_id, int debug) {
active_tunnel->ssl = ssl;
active_tunnel->outgoing_buffer = NULL; // wssshc doesn't use buffer
active_tunnel->incoming_buffer = NULL; // wssshc doesn't need incoming buffer
active_tunnel->server_version_sent = 0; // Not used for raw TCP
pthread_mutex_unlock(&tunnel_mutex);
if (debug) {
printf("[DEBUG - Tunnel] wssshc connected to target SSH server\n");
fflush(stdout);
}
// Send client version to target SSH server
const char *client_version = "SSH-2.0-OpenSSH_9.9p1 Debian-3\r\n";
size_t version_len = strlen(client_version);
if (send(ssh_sock, client_version, version_len, 0) < 0) {
perror("Send client version failed");
return;
}
// Receive server version
char server_version_buf[256];
int bytes_read = recv(ssh_sock, server_version_buf, sizeof(server_version_buf) - 1, 0);
if (bytes_read <= 0) {
perror("Receive server version failed");
return;
}
server_version_buf[bytes_read] = '\0';
if (debug) {
printf("[DEBUG - Tunnel] Received server version: %s", server_version_buf);
printf("[DEBUG - Tunnel] wssshc connected to target %s:%d\n", ssh_host, ssh_port);
fflush(stdout);
}
......@@ -185,30 +176,7 @@ void handle_tunnel_request(SSL *ssl, const char *request_id, int debug) {
return;
}
// Send server version as tunnel_response immediately
// Convert to hex
size_t hex_size = (size_t)bytes_read * 2 + 1;
if (hex_size > 256) hex_size = 256;
char hex_data[256];
for (int i = 0; i < bytes_read && (size_t)i * 2 < hex_size - 1; i++) {
sprintf(hex_data + i * 2, "%02x", (unsigned char)server_version_buf[i]);
}
hex_data[bytes_read * 2] = '\0';
char response_msg[512];
snprintf(response_msg, sizeof(response_msg), "{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"data\":\"%s\"}", request_id, hex_data);
if (debug) {
printf("[DEBUG - WebSockets] Sending server version immediately: %s\n", response_msg);
fflush(stdout);
}
if (!send_websocket_frame(ssl, response_msg)) {
fprintf(stderr, "Send server version failed\n");
return;
}
// Start bidirectional forwarding between WebSocket and SSH server
// Start bidirectional forwarding between WebSocket and target TCP endpoint
thread_args_t *thread_args = malloc(sizeof(thread_args_t));
if (thread_args) {
thread_args->ssl = ssl;
......@@ -224,19 +192,19 @@ void cleanup_tunnel(int debug) {
pthread_mutex_lock(&tunnel_mutex);
if (active_tunnel) {
if (active_tunnel->sock >= 0) {
// Check if SSH connection is still valid before closing
// Check if TCP connection is still valid before closing
char test_buf[1];
int result = recv(active_tunnel->sock, test_buf, 1, MSG_PEEK | MSG_DONTWAIT);
if (result == 0 || (result < 0 && (errno == ECONNRESET || errno == EPIPE))) {
// SSH connection is closed or broken, safe to close
// TCP connection is closed or broken, safe to close
close(active_tunnel->sock);
if (debug) {
printf("[DEBUG] [TCP Tunnel] Closed broken SSH connection\n");
printf("[DEBUG] [TCP Tunnel] Closed broken TCP connection\n");
}
} else {
// SSH connection appears valid, don't close it
// TCP connection appears valid, don't close it
if (debug) {
printf("[DEBUG] [TCP Tunnel] Keeping SSH connection alive for potential reuse\n");
printf("[DEBUG] [TCP Tunnel] Keeping TCP connection alive for potential reuse\n");
}
// Reset socket to -1 so it will be reconnected if needed
active_tunnel->sock = -1;
......@@ -393,9 +361,6 @@ void *forward_tcp_to_ws(void *arg) {
fflush(stdout);
}
// Check if this is the SSH client version string
int is_client_version = (bytes_read >= 8 && strncmp(buffer, "SSH-2.0-", 8) == 0);
// Convert to hex with bounds checking
size_t hex_size = (size_t)bytes_read * 2 + 1;
if (hex_size > BUFFER_SIZE) {
......@@ -424,29 +389,6 @@ void *forward_tcp_to_ws(void *arg) {
}
break;
}
// If this was the client SSH version, immediately send the server version to prevent timeout
if (is_client_version) {
const char *server_version = "SSH-2.0-OpenSSH_10.0p2 Debian-8\r\n";
size_t version_len = strlen(server_version);
if (debug) {
printf("[DEBUG - TCPConnection] Sending server version immediately after client version\n");
fflush(stdout);
}
ssize_t sent = send(client_sock, server_version, version_len, 0);
if (sent > 0) {
if (debug) {
printf("[DEBUG] Sent %zd bytes of server version\n", sent);
fflush(stdout);
}
// Set flag to skip duplicate server version
pthread_mutex_lock(&tunnel_mutex);
if (active_tunnel) {
active_tunnel->server_version_sent = 1;
}
pthread_mutex_unlock(&tunnel_mutex);
}
}
}
}
......@@ -490,21 +432,21 @@ void *forward_ws_to_ssh_server(void *arg) {
pthread_mutex_unlock(&tunnel_mutex);
break;
}
int ssh_sock = active_tunnel->sock; // SSH server connection
int target_sock = active_tunnel->sock; // Target TCP connection
char request_id[37];
strcpy(request_id, active_tunnel->request_id);
pthread_mutex_unlock(&tunnel_mutex);
// Use select to wait for data on SSH server connection
// Use select to wait for data on target TCP connection
FD_ZERO(&readfds);
FD_SET(ssh_sock, &readfds);
FD_SET(target_sock, &readfds);
tv.tv_sec = 0; // 0 seconds
tv.tv_usec = 50000; // 50ms timeout
int retval = select(ssh_sock + 1, &readfds, NULL, NULL, &tv);
int retval = select(target_sock + 1, &readfds, NULL, NULL, &tv);
if (retval == -1) {
if (debug) {
perror("[DEBUG] select on SSH server socket failed");
perror("[DEBUG] select on target socket failed");
fflush(stdout);
}
break;
......@@ -513,18 +455,18 @@ void *forward_ws_to_ssh_server(void *arg) {
continue;
}
if (FD_ISSET(ssh_sock, &readfds)) {
bytes_read = recv(ssh_sock, buffer, sizeof(buffer), 0);
if (FD_ISSET(target_sock, &readfds)) {
bytes_read = recv(target_sock, buffer, sizeof(buffer), 0);
if (bytes_read <= 0) {
if (debug) {
printf("[DEBUG - TCPConnection] SSH server connection closed or error\n");
printf("[DEBUG - TCPConnection] Target connection closed or error\n");
fflush(stdout);
}
break;
}
if (debug) {
printf("[DEBUG - TCPConnection] Forwarding %d bytes from SSH server to WebSocket\n", bytes_read);
printf("[DEBUG - TCPConnection] Forwarding %d bytes from target to WebSocket\n", bytes_read);
fflush(stdout);
}
......@@ -543,7 +485,7 @@ void *forward_ws_to_ssh_server(void *arg) {
}
hex_data[bytes_read * 2] = '\0';
// Send as tunnel_response (from SSH server back to wsssh/wsscp)
// Send as tunnel_response (from target back to WebSocket)
char message[BUFFER_SIZE];
snprintf(message, sizeof(message),
"{\"type\":\"tunnel_response\",\"request_id\":\"%s\",\"data\":\"%s\"}",
......@@ -560,7 +502,7 @@ void *forward_ws_to_ssh_server(void *arg) {
}
if (debug) {
printf("[DEBUG - TCPConnection] SSH server to WebSocket forwarding thread exiting\n");
printf("[DEBUG - TCPConnection] Target to WebSocket forwarding thread exiting\n");
fflush(stdout);
}
......@@ -670,16 +612,8 @@ void handle_tunnel_data(SSL *ssl __attribute__((unused)), const char *request_id
pthread_mutex_unlock(&tunnel_mutex);
} else {
// wsssh/wssshc: Send directly to target socket
// Skip if server version was already sent early
if (active_tunnel->server_version_sent && strstr(data, "SSH-2.0-OpenSSH_10.0p2 Debian-8")) {
if (debug) {
printf("[DEBUG] Skipping duplicate server version send\n");
fflush(stdout);
}
} else {
if (debug) {
printf("[DEBUG] Attempting to send %zu bytes to socket %d\n", data_len, target_sock);
printf("[DEBUG - TOREMOVE] Target socket is %d\n", target_sock);
fflush(stdout);
}
ssize_t sent = send(target_sock, data, data_len, 0);
......@@ -692,7 +626,7 @@ void handle_tunnel_data(SSL *ssl __attribute__((unused)), const char *request_id
// Check if this is a recoverable error
if (errno == EPIPE || errno == ECONNRESET) {
if (debug) {
printf("[DEBUG] SSH client disconnected, marking tunnel inactive\n");
printf("[DEBUG] Target disconnected, marking tunnel inactive\n");
fflush(stdout);
}
// If send fails due to disconnection, mark tunnel as inactive
......@@ -713,11 +647,10 @@ void handle_tunnel_data(SSL *ssl __attribute__((unused)), const char *request_id
// Don't mark tunnel as inactive
} else if (errno == EBADF) {
if (debug) {
printf("[DEBUG] Bad file descriptor - socket may have been closed by SSH client\n");
printf("[DEBUG - TOREMOVE] EBADF occurred, marking tunnel inactive\n");
printf("[DEBUG] Bad file descriptor - socket may have been closed\n");
fflush(stdout);
}
// EBADF indicates the socket is invalid - SSH client likely disconnected
// EBADF indicates the socket is invalid - target likely disconnected
pthread_mutex_lock(&tunnel_mutex);
if (active_tunnel) {
active_tunnel->active = 0;
......@@ -747,7 +680,6 @@ void handle_tunnel_data(SSL *ssl __attribute__((unused)), const char *request_id
fflush(stdout);
}
}
}
free(data);
}
......
......@@ -58,7 +58,7 @@ void *forward_tcp_to_ws(void *arg);
void *forward_ws_to_local(void *arg);
void *forward_ws_to_ssh_server(void *arg);
void *tunnel_thread(void *arg);
void handle_tunnel_request(SSL *ssl, const char *request_id, int debug);
void handle_tunnel_request(SSL *ssl, const char *request_id, int debug, const char *ssh_host, int ssh_port);
void handle_tunnel_data(SSL *ssl, const char *request_id, const char *data_hex, int debug);
void handle_tunnel_close(SSL *ssl, const char *request_id, int debug);
void send_tunnel_close(SSL *ssl, const char *request_id, int debug);
......
......@@ -236,6 +236,7 @@ char **modify_ssh_args(int argc, char *argv[], const char *original_host, int lo
int main(int argc, char *argv[]) {
// Read config
char *config_domain = read_config_value("domain");
......
......@@ -41,8 +41,10 @@ int global_debug = 0;
typedef struct {
char *server_ip;
int port;
char *wssshd_server;
int wssshd_port;
char *ssh_host;
int ssh_port;
char *client_id;
char *password;
int interval;
......@@ -73,12 +75,18 @@ void load_config_file(const char *config_path, wssshc_config_t *config) {
while (end > value && *end == ' ') *end-- = 0;
if (strcmp(key, "password") == 0 && !config->password) {
config->password = strdup(value);
} else if (strcmp(key, "server-ip") == 0 && !config->server_ip) {
config->server_ip = strdup(value);
} else if (strcmp(key, "domain") == 0 && !config->server_ip) {
config->server_ip = strdup(value);
} else if (strcmp(key, "port") == 0) {
config->port = atoi(value);
} else if (strcmp(key, "wssshd-server") == 0 && !config->wssshd_server) {
config->wssshd_server = strdup(value);
} else if (strcmp(key, "domain") == 0 && !config->wssshd_server) {
config->wssshd_server = strdup(value);
} else if (strcmp(key, "wssshd-port") == 0) {
config->wssshd_port = atoi(value);
} else if (strcmp(key, "port") == 0 && config->wssshd_port == 9898) {
config->wssshd_port = atoi(value);
} else if (strcmp(key, "ssh-host") == 0 && !config->ssh_host) {
config->ssh_host = strdup(value);
} else if (strcmp(key, "ssh-port") == 0) {
config->ssh_port = atoi(value);
} else if (strcmp(key, "id") == 0 && !config->client_id) {
config->client_id = strdup(value);
} else if (strcmp(key, "interval") == 0) {
......@@ -117,8 +125,10 @@ void print_usage(const char *program_name) {
fprintf(stderr, "Protect the dolls!\n\n");
fprintf(stderr, "Options:\n");
fprintf(stderr, " --config FILE Configuration file path (overrides default hierarchy)\n");
fprintf(stderr, " --server-ip IP Server IP address\n");
fprintf(stderr, " --port PORT Server port (default: %d)\n", DEFAULT_PORT);
fprintf(stderr, " --wssshd-server HOST WSSSHD server host (default: mbeted.nexlab.net)\n");
fprintf(stderr, " --wssshd-port PORT WSSSHD server port (default: 9898)\n");
fprintf(stderr, " --ssh-host HOST SSH host to forward tunnel data to (default: 127.0.0.1)\n");
fprintf(stderr, " --ssh-port PORT SSH port to forward tunnel data to (default: 22)\n");
fprintf(stderr, " --id ID Client identifier\n");
fprintf(stderr, " --password PASS Registration password\n");
fprintf(stderr, " --interval SEC Reconnection interval (default: 30)\n");
......@@ -136,8 +146,10 @@ void print_usage(const char *program_name) {
int parse_args(int argc, char *argv[], wssshc_config_t *config) {
static struct option long_options[] = {
{"config", required_argument, 0, 'c'},
{"server-ip", required_argument, 0, 's'},
{"port", required_argument, 0, 'p'},
{"wssshd-server", required_argument, 0, 's'},
{"wssshd-port", required_argument, 0, 'p'},
{"ssh-host", required_argument, 0, 'H'},
{"ssh-port", required_argument, 0, 'P'},
{"id", required_argument, 0, 'i'},
{"password", required_argument, 0, 'w'},
{"interval", required_argument, 0, 't'},
......@@ -149,17 +161,24 @@ int parse_args(int argc, char *argv[], wssshc_config_t *config) {
int opt;
char *custom_config = NULL;
while ((opt = getopt_long(argc, argv, "c:s:p:i:w:t:dh", long_options, NULL)) != -1) {
while ((opt = getopt_long(argc, argv, "c:s:p:H:P:i:w:t:dh", long_options, NULL)) != -1) {
switch (opt) {
case 'c':
custom_config = optarg;
break;
case 's':
if (config->server_ip) free(config->server_ip);
config->server_ip = strdup(optarg);
if (config->wssshd_server) free(config->wssshd_server);
config->wssshd_server = strdup(optarg);
break;
case 'p':
config->port = atoi(optarg);
config->wssshd_port = atoi(optarg);
break;
case 'H':
if (config->ssh_host) free(config->ssh_host);
config->ssh_host = strdup(optarg);
break;
case 'P':
config->ssh_port = atoi(optarg);
break;
case 'i':
if (config->client_id) free(config->client_id);
......@@ -215,7 +234,7 @@ int connect_to_server(const wssshc_config_t *config) {
cleanup_tunnel(config->debug);
// Resolve hostname
if ((he = gethostbyname(config->server_ip)) == NULL) {
if ((he = gethostbyname(config->wssshd_server)) == NULL) {
herror("gethostbyname");
return 1;
}
......@@ -228,7 +247,7 @@ int connect_to_server(const wssshc_config_t *config) {
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(config->port);
server_addr.sin_port = htons(config->wssshd_port);
server_addr.sin_addr = *((struct in_addr *)he->h_addr);
// Connect to server
......@@ -257,7 +276,7 @@ int connect_to_server(const wssshc_config_t *config) {
}
// Perform WebSocket handshake
if (!websocket_handshake(ssl, config->server_ip, config->port, "/")) {
if (!websocket_handshake(ssl, config->wssshd_server, config->wssshd_port, "/")) {
SSL_free(ssl);
SSL_CTX_free(ssl_ctx);
close(sock);
......@@ -456,7 +475,7 @@ int connect_to_server(const wssshc_config_t *config) {
printf("[DEBUG - WebSockets] Received tunnel_request for ID: %s\n", id_start);
fflush(stdout);
}
handle_tunnel_request(ssl, id_start, config->debug);
handle_tunnel_request(ssl, id_start, config->debug, config->ssh_host, config->ssh_port);
}
}
}
......@@ -648,8 +667,10 @@ int connect_to_server(const wssshc_config_t *config) {
int main(int argc, char *argv[]) {
wssshc_config_t config = {
.server_ip = NULL,
.port = DEFAULT_PORT,
.wssshd_server = NULL,
.wssshd_port = 9898,
.ssh_host = NULL,
.ssh_port = 22,
.client_id = NULL,
.password = NULL,
.interval = 30,
......@@ -667,11 +688,20 @@ int main(int argc, char *argv[]) {
return 1;
}
// Set defaults for optional fields
if (!config.wssshd_server) {
config.wssshd_server = strdup("mbeted.nexlab.net");
}
if (!config.ssh_host) {
config.ssh_host = strdup("127.0.0.1");
}
// Validate required arguments
if (!config.server_ip || !config.client_id || !config.password) {
fprintf(stderr, "Error: --server-ip, --id, and --password are required\n");
if (!config.client_id || !config.password) {
fprintf(stderr, "Error: --id and --password are required\n");
print_usage(argv[0]);
if (config.server_ip) free(config.server_ip);
if (config.wssshd_server) free(config.wssshd_server);
if (config.ssh_host) free(config.ssh_host);
if (config.client_id) free(config.client_id);
if (config.password) free(config.password);
pthread_mutex_destroy(&tunnel_mutex);
......@@ -679,7 +709,19 @@ int main(int argc, char *argv[]) {
}
global_debug = config.debug;
// Print configured options
printf("WebSocket SSH Client starting...\n");
printf("Configuration:\n");
printf(" WSSSHD Server: %s\n", config.wssshd_server ? config.wssshd_server : "(null)");
printf(" WSSSHD Port: %d\n", config.wssshd_port);
printf(" SSH Host: %s\n", config.ssh_host ? config.ssh_host : "(null)");
printf(" SSH Port: %d\n", config.ssh_port);
printf(" Client ID: %s\n", config.client_id ? config.client_id : "(null)");
printf(" Password: %s\n", config.password ? "***" : "(null)");
printf(" Reconnection Interval: %d seconds\n", config.interval);
printf(" Debug Mode: %s\n", config.debug ? "enabled" : "disabled");
printf("\n");
while (1) {
int result = connect_to_server(&config);
......@@ -699,7 +741,8 @@ int main(int argc, char *argv[]) {
}
// Cleanup
free(config.server_ip);
free(config.wssshd_server);
free(config.ssh_host);
free(config.client_id);
free(config.password);
pthread_mutex_destroy(&tunnel_mutex);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment