Implement comprehensive tunnel status tracking and IP autodetection

- Add IP autodetection function in wssshlib.c that detects local IP excluding loopback
- Create comprehensive Tunnel class in wsssd/tunnel.py with all required attributes:
  * client_id, tunnel_id, status, protocol, tunnel_type
  * dst_public_ip, dst_public_port, dst_private_ip, dst_private_port
  * src_public_ip, src_private_ip
- Update WebSocket handling to use Tunnel objects throughout lifecycle
- Add IP detection utilities for public/private IPs
- Maintain original tunnel binding behavior (127.0.0.1)
- Update server shutdown process for proper tunnel cleanup
- Test implementation with virtual environment
parent e7dbe71f
...@@ -137,15 +137,16 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread): ...@@ -137,15 +137,16 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
# Create close tasks for all active tunnels # Create close tasks for all active tunnels
close_tasks = [] close_tasks = []
for request_id, tunnel in active_tunnels.items(): for request_id, tunnel in active_tunnels.items():
client_info = clients.get(tunnel['client_id']) if tunnel.status == 'active': # Check tunnel status
if client_info and client_info['status'] == 'active': client_info = clients.get(tunnel.client_id)
try: if client_info and client_info['status'] == 'active':
close_task = asyncio.create_task( try:
tunnel['client_ws'].send(SERVER_SHUTDOWN_MSG) close_task = asyncio.create_task(
) tunnel.client_ws.send(SERVER_SHUTDOWN_MSG)
close_tasks.append((request_id, close_task)) )
except Exception as e: close_tasks.append((request_id, close_task))
if debug: print(f"[DEBUG] Failed to create close task for {request_id}: {e}") except Exception as e:
if debug: print(f"[DEBUG] Failed to create close task for {request_id}: {e}")
# Wait for all close tasks with timeout # Wait for all close tasks with timeout
if close_tasks: if close_tasks:
...@@ -160,7 +161,11 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread): ...@@ -160,7 +161,11 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
except Exception as e: except Exception as e:
if debug: print(f"[DEBUG] Error during tunnel close notifications: {e}") if debug: print(f"[DEBUG] Error during tunnel close notifications: {e}")
# Clean up all tunnels # Update tunnel statuses and clean up all tunnels
for request_id, tunnel in active_tunnels.items():
tunnel.update_status('closed', 'Server shutdown')
if debug: print(f"[DEBUG] Tunnel {request_id} status updated: {tunnel}")
active_tunnels.clear() active_tunnels.clear()
if debug: print(f"[DEBUG] Cleaned up {len(close_tasks)} tunnels") if debug: print(f"[DEBUG] Cleaned up {len(close_tasks)} tunnels")
......
"""
Tunnel object management for wssshd
"""
import time
import socket
import ipaddress
class TunnelStatus:
"""Enumeration of tunnel statuses"""
CREATING = "creating"
ACTIVE = "active"
CLOSING = "closing"
CLOSED = "closed"
ERROR = "error"
class Tunnel:
"""Comprehensive tunnel object that tracks all tunnel attributes"""
def __init__(self, request_id, client_id):
self.request_id = request_id
self.client_id = client_id
self.tunnel_id = request_id # Use request_id as tunnel_id for now
# Status and lifecycle
self.status = TunnelStatus.CREATING
self.created_at = time.time()
self.updated_at = time.time()
# Protocol and type
self.protocol = "ssh" # default
self.tunnel_type = "any" # default
# Destination (wssshc) information
self.dst_public_ip = None
self.dst_public_port = None
self.dst_private_ip = None
self.dst_private_port = None
# Source (wsssh/wsscp) information
self.src_public_ip = None
self.src_private_ip = None
# WebSocket connections
self.client_ws = None # wssshc WebSocket
self.wsssh_ws = None # wsssh/wsscp WebSocket
# Additional metadata
self.error_message = None
self.metadata = {}
def update_status(self, new_status, error_message=None):
"""Update tunnel status and timestamp"""
self.status = new_status
self.updated_at = time.time()
if error_message:
self.error_message = error_message
def set_destination_info(self, public_ip=None, public_port=None, private_ip=None, private_port=None):
"""Set destination (wssshc) connection information"""
if public_ip:
self.dst_public_ip = public_ip
if public_port:
self.dst_public_port = public_port
if private_ip:
self.dst_private_ip = private_ip
if private_port:
self.dst_private_port = private_port
self.updated_at = time.time()
def set_source_info(self, public_ip=None, private_ip=None):
"""Set source (wsssh/wsscp) connection information"""
if public_ip:
self.src_public_ip = public_ip
if private_ip:
self.src_private_ip = private_ip
self.updated_at = time.time()
def set_websockets(self, client_ws, wsssh_ws):
"""Set WebSocket connections"""
self.client_ws = client_ws
self.wsssh_ws = wsssh_ws
self.updated_at = time.time()
def to_dict(self):
"""Convert tunnel object to dictionary for serialization"""
return {
'request_id': self.request_id,
'client_id': self.client_id,
'tunnel_id': self.tunnel_id,
'status': self.status,
'created_at': self.created_at,
'updated_at': self.updated_at,
'protocol': self.protocol,
'tunnel_type': self.tunnel_type,
'dst_public_ip': self.dst_public_ip,
'dst_public_port': self.dst_public_port,
'dst_private_ip': self.dst_private_ip,
'dst_private_port': self.dst_private_port,
'src_public_ip': self.src_public_ip,
'src_private_ip': self.src_private_ip,
'error_message': self.error_message
}
def __str__(self):
return f"Tunnel(id={self.tunnel_id}, client={self.client_id}, status={self.status})"
def __repr__(self):
return self.__str__()
def detect_client_public_ip(websocket):
"""Detect the public IP address of a client from WebSocket connection"""
try:
# Get the remote address from WebSocket
remote_addr = websocket.remote_address
if remote_addr and len(remote_addr) >= 2:
ip = remote_addr[0]
# Check if it's a valid public IP
ip_obj = ipaddress.ip_address(ip)
if not ip_obj.is_private and not ip_obj.is_loopback:
return ip
except Exception:
pass
return None
def detect_client_private_ip(websocket):
"""Detect the private IP address of a client from WebSocket connection"""
try:
# Get the remote address from WebSocket
remote_addr = websocket.remote_address
if remote_addr and len(remote_addr) >= 2:
ip = remote_addr[0]
# Check if it's a valid private IP
ip_obj = ipaddress.ip_address(ip)
if ip_obj.is_private:
return ip
except Exception:
pass
return None
def get_server_public_ip():
"""Get the server's public IP address"""
try:
# Create a socket to connect to an external service
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80)) # Connect to Google DNS
public_ip = s.getsockname()[0]
s.close()
return public_ip
except Exception:
return None
def get_server_private_ip():
"""Get the server's private IP address"""
try:
# Create a socket and connect to get local IP
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80)) # Connect to Google DNS
private_ip = s.getsockname()[0]
s.close()
return private_ip
except Exception:
return None
\ No newline at end of file
...@@ -7,6 +7,7 @@ import json ...@@ -7,6 +7,7 @@ import json
import time import time
import websockets import websockets
from .terminal import openpty_with_fallback from .terminal import openpty_with_fallback
from .tunnel import Tunnel, TunnelStatus, detect_client_public_ip, detect_client_private_ip
# Client registry: id -> {'websocket': ws, 'last_seen': timestamp, 'status': 'active'|'disconnected'} # Client registry: id -> {'websocket': ws, 'last_seen': timestamp, 'status': 'active'|'disconnected'}
clients = {} clients = {}
...@@ -52,7 +53,8 @@ def print_status(): ...@@ -52,7 +53,8 @@ def print_status():
uptime = time.time() - start_time uptime = time.time() - start_time
active_clients = sum(1 for c in clients.values() if c['status'] == 'active') active_clients = sum(1 for c in clients.values() if c['status'] == 'active')
total_clients = len(clients) total_clients = len(clients)
active_tunnels_count = len(active_tunnels) active_tunnels_count = sum(1 for t in active_tunnels.values() if t.status == 'active')
total_tunnels = len(active_tunnels)
hours = int(uptime // 3600) hours = int(uptime // 3600)
minutes = int((uptime % 3600) // 60) minutes = int((uptime % 3600) // 60)
...@@ -60,7 +62,7 @@ def print_status(): ...@@ -60,7 +62,7 @@ def print_status():
print(f"[STATUS] Uptime: {hours:02d}:{minutes:02d}:{seconds:02d} | " print(f"[STATUS] Uptime: {hours:02d}:{minutes:02d}:{seconds:02d} | "
f"Clients: {active_clients}/{total_clients} active | " f"Clients: {active_clients}/{total_clients} active | "
f"Tunnels: {active_tunnels_count} active") f"Tunnels: {active_tunnels_count}/{total_tunnels} active")
async def handle_websocket(websocket, path=None, *, server_password=None): async def handle_websocket(websocket, path=None, *, server_password=None):
...@@ -134,19 +136,34 @@ async def handle_websocket(websocket, path=None, *, server_password=None): ...@@ -134,19 +136,34 @@ async def handle_websocket(websocket, path=None, *, server_password=None):
request_id = data['request_id'] request_id = data['request_id']
client_info = clients.get(client_id) client_info = clients.get(client_id)
if client_info and client_info['status'] == 'active': if client_info and client_info['status'] == 'active':
# Store tunnel mapping with optimized structure # Create comprehensive tunnel object
active_tunnels[request_id] = { tunnel = Tunnel(request_id, client_id)
'client_ws': client_info['websocket'],
'wsssh_ws': websocket, # Set WebSocket connections
'client_id': client_id tunnel.set_websockets(client_info['websocket'], websocket)
}
# Detect and set IP information
tunnel.set_destination_info(
public_ip=detect_client_public_ip(client_info['websocket']),
private_ip=detect_client_private_ip(client_info['websocket'])
)
# Store tunnel object
active_tunnels[request_id] = tunnel
# Update tunnel status to active
tunnel.update_status(TunnelStatus.ACTIVE)
# Forward tunnel request to client # Forward tunnel request to client
try: try:
await client_info['websocket'].send(TUNNEL_REQUEST_MSG % request_id) await client_info['websocket'].send(TUNNEL_REQUEST_MSG % request_id)
await websocket.send(TUNNEL_ACK_MSG % request_id) await websocket.send(TUNNEL_ACK_MSG % request_id)
if not debug: if not debug:
print(f"[EVENT] New tunnel {request_id} for client {client_id}") print(f"[EVENT] New tunnel {request_id} for client {client_id}")
except Exception: else:
print(f"[DEBUG] Created tunnel object: {tunnel}")
except Exception as e:
tunnel.update_status(TunnelStatus.ERROR, str(e))
# Send error response for tunnel request failures # Send error response for tunnel request failures
try: try:
await websocket.send(TUNNEL_ERROR_MSG % (request_id, "Failed to forward request")) await websocket.send(TUNNEL_ERROR_MSG % (request_id, "Failed to forward request"))
...@@ -162,23 +179,24 @@ async def handle_websocket(websocket, path=None, *, server_password=None): ...@@ -162,23 +179,24 @@ async def handle_websocket(websocket, path=None, *, server_password=None):
request_id = data['request_id'] request_id = data['request_id']
if request_id in active_tunnels: if request_id in active_tunnels:
tunnel = active_tunnels[request_id] tunnel = active_tunnels[request_id]
# Check client status first (faster lookup) # Check if tunnel is active and client is connected
client_info = clients.get(tunnel['client_id']) if tunnel.status == TunnelStatus.ACTIVE:
if client_info and client_info['status'] == 'active': client_info = clients.get(tunnel.client_id)
# Use pre-formatted JSON template for better performance if client_info and client_info['status'] == 'active':
try: # Use pre-formatted JSON template for better performance
await tunnel['client_ws'].send(TUNNEL_DATA_MSG % (request_id, data['data'])) try:
except Exception: await tunnel.client_ws.send(TUNNEL_DATA_MSG % (request_id, data['data']))
# Silent failure for performance - connection issues will be handled by cleanup except Exception:
pass # Silent failure for performance - connection issues will be handled by cleanup
pass
# No debug logging for performance - tunnel_data messages are too frequent # No debug logging for performance - tunnel_data messages are too frequent
elif data.get('type') == 'tunnel_response': elif data.get('type') == 'tunnel_response':
# Optimized tunnel response forwarding # Optimized tunnel response forwarding
request_id = data['request_id'] request_id = data['request_id']
tunnel = active_tunnels.get(request_id) tunnel = active_tunnels.get(request_id)
if tunnel: if tunnel and tunnel.status == TunnelStatus.ACTIVE:
try: try:
await tunnel['wsssh_ws'].send(TUNNEL_DATA_MSG % (request_id, data['data'])) await tunnel.wsssh_ws.send(TUNNEL_DATA_MSG % (request_id, data['data']))
except Exception: except Exception:
# Silent failure for performance - connection issues will be handled by cleanup # Silent failure for performance - connection issues will be handled by cleanup
pass pass
...@@ -186,18 +204,25 @@ async def handle_websocket(websocket, path=None, *, server_password=None): ...@@ -186,18 +204,25 @@ async def handle_websocket(websocket, path=None, *, server_password=None):
request_id = data['request_id'] request_id = data['request_id']
tunnel = active_tunnels.get(request_id) tunnel = active_tunnels.get(request_id)
if tunnel: if tunnel:
# Update tunnel status to closing
tunnel.update_status(TunnelStatus.CLOSING)
# Forward close to client if still active # Forward close to client if still active
client_info = clients.get(tunnel['client_id']) client_info = clients.get(tunnel.client_id)
if client_info and client_info['status'] == 'active': if client_info and client_info['status'] == 'active':
try: try:
await tunnel['client_ws'].send(TUNNEL_CLOSE_MSG % request_id) await tunnel.client_ws.send(TUNNEL_CLOSE_MSG % request_id)
except Exception: except Exception:
# Silent failure for performance # Silent failure for performance
pass pass
# Clean up tunnel
# Update tunnel status to closed and clean up
tunnel.update_status(TunnelStatus.CLOSED)
del active_tunnels[request_id] del active_tunnels[request_id]
if debug: if debug:
print(f"[DEBUG] [WebSocket] Tunnel {request_id} closed") print(f"[DEBUG] [WebSocket] Tunnel {request_id} closed")
print(f"[DEBUG] Tunnel object: {tunnel}")
else: else:
print(f"[EVENT] Tunnel {request_id} closed") print(f"[EVENT] Tunnel {request_id} closed")
except websockets.exceptions.ConnectionClosed: except websockets.exceptions.ConnectionClosed:
...@@ -215,10 +240,14 @@ async def handle_websocket(websocket, path=None, *, server_password=None): ...@@ -215,10 +240,14 @@ async def handle_websocket(websocket, path=None, *, server_password=None):
if disconnected_client: if disconnected_client:
# Use list comprehension for better performance # Use list comprehension for better performance
tunnels_to_remove = [rid for rid, tunnel in active_tunnels.items() tunnels_to_remove = [rid for rid, tunnel in active_tunnels.items()
if tunnel['client_id'] == disconnected_client] if tunnel.client_id == disconnected_client]
for request_id in tunnels_to_remove: for request_id in tunnels_to_remove:
tunnel = active_tunnels[request_id]
tunnel.update_status(TunnelStatus.ERROR, "Client disconnected")
del active_tunnels[request_id] del active_tunnels[request_id]
if debug: print(f"[DEBUG] [WebSocket] Tunnel {request_id} cleaned up due to client disconnect") if debug:
print(f"[DEBUG] [WebSocket] Tunnel {request_id} cleaned up due to client disconnect")
print(f"[DEBUG] Tunnel object: {tunnel}")
async def cleanup_task(): async def cleanup_task():
......
...@@ -21,6 +21,14 @@ ...@@ -21,6 +21,14 @@
#include <stdlib.h> #include <stdlib.h>
#include <time.h> #include <time.h>
#include <unistd.h> #include <unistd.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <linux/rtnetlink.h>
#include <linux/if.h>
#include <linux/route.h>
#include <sys/ioctl.h>
#include <string.h>
#include <errno.h>
// Global signal flag // Global signal flag
volatile sig_atomic_t sigint_received = 0; volatile sig_atomic_t sigint_received = 0;
...@@ -170,4 +178,86 @@ void generate_request_id(char *request_id, size_t size) { ...@@ -170,4 +178,86 @@ void generate_request_id(char *request_id, size_t size) {
request_id[i] = charset[rand() % (sizeof(charset) - 1)]; request_id[i] = charset[rand() % (sizeof(charset) - 1)];
} }
request_id[size - 1] = '\0'; request_id[size - 1] = '\0';
}
// Helper function to get default gateway interface
static char *get_default_gateway_interface() {
FILE *fp = fopen("/proc/net/route", "r");
if (!fp) {
return NULL;
}
char line[256];
char *interface = NULL;
// Skip header line
if (!fgets(line, sizeof(line), fp)) {
fclose(fp);
return NULL;
}
while (fgets(line, sizeof(line), fp)) {
char iface[16];
unsigned long dest, gateway, flags;
if (sscanf(line, "%15s %lx %lx %lx", iface, &dest, &gateway, &flags) == 4) {
// Check if destination is 0.0.0.0 (default route) and gateway is not 0.0.0.0
if (dest == 0 && gateway != 0 && (flags & RTF_UP) && (flags & RTF_GATEWAY)) {
interface = strdup(iface);
break;
}
}
}
fclose(fp);
return interface;
}
char *autodetect_local_ip() {
struct ifaddrs *ifaddr, *ifa;
char *selected_ip = NULL;
char *gateway_iface = get_default_gateway_interface();
if (getifaddrs(&ifaddr) == -1) {
perror("getifaddrs");
free(gateway_iface);
return strdup("127.0.0.1"); // Fallback to localhost
}
// First priority: interface with default gateway
if (gateway_iface) {
for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == NULL) continue;
if (ifa->ifa_addr->sa_family != AF_INET) continue;
if (strcmp(ifa->ifa_name, "lo") == 0) continue; // Skip loopback
if (strcmp(ifa->ifa_name, gateway_iface) == 0) {
struct sockaddr_in *addr = (struct sockaddr_in *)ifa->ifa_addr;
selected_ip = strdup(inet_ntoa(addr->sin_addr));
break;
}
}
}
// Second priority: first non-loopback interface
if (!selected_ip) {
for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == NULL) continue;
if (ifa->ifa_addr->sa_family != AF_INET) continue;
if (strcmp(ifa->ifa_name, "lo") == 0) continue; // Skip loopback
struct sockaddr_in *addr = (struct sockaddr_in *)ifa->ifa_addr;
selected_ip = strdup(inet_ntoa(addr->sin_addr));
break;
}
}
freeifaddrs(ifaddr);
free(gateway_iface);
// Final fallback
if (!selected_ip) {
selected_ip = strdup("127.0.0.1");
}
return selected_ip;
} }
\ No newline at end of file
...@@ -36,6 +36,11 @@ ...@@ -36,6 +36,11 @@
#include <pthread.h> #include <pthread.h>
#include <sys/select.h> #include <sys/select.h>
#include <signal.h> #include <signal.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <linux/rtnetlink.h>
#include <linux/if.h>
#include <linux/route.h>
#define BUFFER_SIZE 1048576 #define BUFFER_SIZE 1048576
#define MAX_CHUNK_SIZE 65536 #define MAX_CHUNK_SIZE 65536
...@@ -74,5 +79,6 @@ void print_trans_flag(void); ...@@ -74,5 +79,6 @@ void print_trans_flag(void);
void print_palestinian_flag(void); void print_palestinian_flag(void);
int find_available_port(void); int find_available_port(void);
void generate_request_id(char *request_id, size_t size); void generate_request_id(char *request_id, size_t size);
char *autodetect_local_ip(void);
#endif // WSSH_LIB_H #endif // WSSH_LIB_H
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment