Commit f7dc5d7f authored by nextime's avatar nextime

Enhance Python implementations with reconnection logic

- Add --interval option to both wsssh.py and wsscp.py (default: 30s)
- Implement automatic WebSocket reconnection during active tunnels
- Add retry logic for initial connection failures
- Smart timing: 1-second WebSocket reconnections vs configurable initial setup retries
- Enhanced connection resilience for both SSH and SCP operations
- Maintain backward compatibility with existing functionality
parent d803985a
......@@ -92,8 +92,8 @@ async def handle_local_connection(reader, writer, ws, request_id):
writer.close()
await writer.wait_closed()
async def run_scp(server_ip, server_port, client_id, local_port, scp_args):
"""Connect to wssshd and run SCP"""
async def run_scp(server_ip, server_port, client_id, local_port, scp_args, interval=30):
"""Connect to wssshd and run SCP with reconnection support"""
uri = f"wss://{server_ip}:{server_port}"
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
......@@ -101,68 +101,161 @@ async def run_scp(server_ip, server_port, client_id, local_port, scp_args):
request_id = str(uuid.uuid4())
# Initial connection with retry logic
websocket = None
max_initial_attempts = 3
initial_attempt = 0
while initial_attempt < max_initial_attempts and websocket is None:
try:
if debug and initial_attempt > 0:
print(f"[DEBUG] Initial connection attempt {initial_attempt + 1}/{max_initial_attempts}")
websocket = await websockets.connect(uri, ssl=ssl_context)
if debug: print("[DEBUG] Initial WebSocket connection established")
except Exception as e:
initial_attempt += 1
if initial_attempt < max_initial_attempts:
if debug: print(f"[DEBUG] Initial connection failed, waiting {interval} seconds...")
await asyncio.sleep(interval)
else:
print(f"Connection failed after {max_initial_attempts} attempts: {e}")
return 1
try:
async with websockets.connect(uri, ssl=ssl_context) as websocket:
# Set up ping/pong handlers for explicit handling
async def ping_handler(payload):
if debug: print(f"[DEBUG] Received ping: {payload}")
# Pong is sent automatically by websockets library
async def pong_handler(payload):
if debug: print(f"[DEBUG] Received pong: {payload}")
websocket.ping_handler = ping_handler
websocket.pong_handler = pong_handler
# Request tunnel
await websocket.send(json.dumps({
"type": "tunnel_request",
"client_id": client_id,
"request_id": request_id
}))
# Wait for acknowledgment
response = await websocket.recv()
data = json.loads(response)
if data.get('type') == 'tunnel_ack':
if debug: print(f"[DEBUG] Tunnel request acknowledged: {request_id}")
# Start tunnel handler (opens listening port)
ready_event = asyncio.Event()
tunnel_task = asyncio.create_task(handle_tunnel(websocket, local_port, request_id, ready_event))
# Wait for tunnel to be ready (listening port opened)
await ready_event.wait()
# Launch SCP with modified arguments
scp_cmd = ['scp'] + scp_args
if debug: print(f"[DEBUG] Launching: {' '.join(scp_cmd)}")
# Run SCP process
process = await asyncio.create_subprocess_exec(
*scp_cmd,
stdin=sys.stdin,
stdout=sys.stdout,
stderr=sys.stderr
)
# Wait for SCP to complete
# Set up ping/pong handlers for explicit handling
async def ping_handler(payload):
if debug: print(f"[DEBUG] Received ping: {payload}")
# Pong is sent automatically by websockets library
async def pong_handler(payload):
if debug: print(f"[DEBUG] Received pong: {payload}")
websocket.ping_handler = ping_handler
websocket.pong_handler = pong_handler
# Request tunnel
await websocket.send(json.dumps({
"type": "tunnel_request",
"client_id": client_id,
"request_id": request_id
}))
# Wait for acknowledgment
response = await websocket.recv()
data = json.loads(response)
if data.get('type') == 'tunnel_ack':
if debug: print(f"[DEBUG] Tunnel request acknowledged: {request_id}")
# Start tunnel handler (opens listening port)
ready_event = asyncio.Event()
tunnel_task = asyncio.create_task(handle_tunnel(websocket, local_port, request_id, ready_event))
# Wait for tunnel to be ready (listening port opened)
await ready_event.wait()
# Launch SCP with modified arguments
scp_cmd = ['scp'] + scp_args
if debug: print(f"[DEBUG] Launching: {' '.join(scp_cmd)}")
# Run SCP process
process = await asyncio.create_subprocess_exec(
*scp_cmd,
stdin=sys.stdin,
stdout=sys.stdout,
stderr=sys.stderr
)
# Monitor both SCP process and WebSocket connection
tunnel_active = True
while tunnel_active:
# Check if SCP process is still running
if process.returncode is not None:
if debug: print(f"[DEBUG] SCP process finished with code: {process.returncode}")
tunnel_active = False
break
# Check WebSocket connection and attempt reconnection if needed
try:
# Try to send a small message to test connection
await asyncio.wait_for(websocket.ping(), timeout=5.0)
except Exception as e:
if debug: print(f"[DEBUG] WebSocket connection lost, attempting reconnection...")
# Attempt WebSocket reconnection
reconnect_attempts = 0
max_reconnect_attempts = 3
reconnected = False
while reconnect_attempts < max_reconnect_attempts and not reconnected:
try:
if debug: print(f"[DEBUG] WebSocket reconnection attempt {reconnect_attempts + 1}/{max_reconnect_attempts}")
# Create new WebSocket connection
new_websocket = await websockets.connect(uri, ssl=ssl_context)
new_websocket.ping_handler = ping_handler
new_websocket.pong_handler = pong_handler
# Re-request tunnel
await new_websocket.send(json.dumps({
"type": "tunnel_request",
"client_id": client_id,
"request_id": request_id
}))
# Wait for acknowledgment
response = await new_websocket.recv()
data = json.loads(response)
if data.get('type') == 'tunnel_ack':
# Update tunnel task with new WebSocket
tunnel_task.cancel()
tunnel_task = asyncio.create_task(handle_tunnel(new_websocket, local_port, request_id, None))
websocket = new_websocket
reconnected = True
if debug: print("[DEBUG] WebSocket reconnection successful")
else:
await new_websocket.close()
except Exception as reconnect_error:
reconnect_attempts += 1
if reconnect_attempts < max_reconnect_attempts:
if debug: print(f"[DEBUG] WebSocket reconnection failed, waiting 1 second...")
await asyncio.sleep(1) # Fast reconnection for WebSocket
else:
if debug: print(f"[DEBUG] All reconnection attempts failed: {reconnect_error}")
tunnel_active = False
break
if tunnel_active:
await asyncio.sleep(2) # Check every 2 seconds
# Wait for SCP to complete if still running
if process.returncode is None:
await process.wait()
# Close tunnel
# Close tunnel
try:
await websocket.send(json.dumps({
"type": "tunnel_close",
"request_id": request_id
}))
except Exception:
pass # WebSocket might already be closed
if tunnel_task and not tunnel_task.done():
tunnel_task.cancel()
elif data.get('type') == 'tunnel_error':
print(f"Error: {data.get('error', 'Unknown error')}")
return 1
elif data.get('type') == 'tunnel_error':
print(f"Error: {data.get('error', 'Unknown error')}")
return 1
except Exception as e:
print(f"Connection failed: {e}")
return 1
finally:
if websocket:
try:
await websocket.close()
except Exception:
pass
return 0
......@@ -199,6 +292,7 @@ def main():
parser = argparse.ArgumentParser(description='WebSocket SCP (wsscp)', add_help=False)
parser.add_argument('--local-port', type=int, default=0, help='Local port for tunnel (0 = auto)')
parser.add_argument('--interval', type=int, default=30, help='Connection retry interval in seconds (default: 30)')
parser.add_argument('--debug', action='store_true', help='Enable debug output')
# Parse our arguments first
......@@ -298,7 +392,8 @@ def main():
scp_port, # Use -P port as wssshd_port
client_id,
local_port,
final_args
final_args,
args.interval # Pass the interval parameter
))
sys.exit(exit_code)
......
......@@ -92,8 +92,8 @@ async def handle_local_connection(reader, writer, ws, request_id):
writer.close()
await writer.wait_closed()
async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args):
"""Connect to wssshd and run SSH"""
async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, interval=30):
"""Connect to wssshd and run SSH with reconnection support"""
uri = f"wss://{server_ip}:{server_port}"
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
......@@ -101,68 +101,161 @@ async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args):
request_id = str(uuid.uuid4())
# Initial connection with retry logic
websocket = None
max_initial_attempts = 3
initial_attempt = 0
while initial_attempt < max_initial_attempts and websocket is None:
try:
if debug and initial_attempt > 0:
print(f"[DEBUG] Initial connection attempt {initial_attempt + 1}/{max_initial_attempts}")
websocket = await websockets.connect(uri, ssl=ssl_context)
if debug: print("[DEBUG] Initial WebSocket connection established")
except Exception as e:
initial_attempt += 1
if initial_attempt < max_initial_attempts:
if debug: print(f"[DEBUG] Initial connection failed, waiting {interval} seconds...")
await asyncio.sleep(interval)
else:
print(f"Connection failed after {max_initial_attempts} attempts: {e}")
return 1
try:
async with websockets.connect(uri, ssl=ssl_context) as websocket:
# Set up ping/pong handlers for explicit handling
async def ping_handler(payload):
if debug: print(f"[DEBUG] Received ping: {payload}")
# Pong is sent automatically by websockets library
async def pong_handler(payload):
if debug: print(f"[DEBUG] Received pong: {payload}")
websocket.ping_handler = ping_handler
websocket.pong_handler = pong_handler
# Request tunnel
await websocket.send(json.dumps({
"type": "tunnel_request",
"client_id": client_id,
"request_id": request_id
}))
# Wait for acknowledgment
response = await websocket.recv()
data = json.loads(response)
if data.get('type') == 'tunnel_ack':
if debug: print(f"[DEBUG] Tunnel request acknowledged: {request_id}")
# Start tunnel handler (opens listening port)
ready_event = asyncio.Event()
tunnel_task = asyncio.create_task(handle_tunnel(websocket, local_port, request_id, ready_event))
# Wait for tunnel to be ready (listening port opened)
await ready_event.wait()
# Launch SSH with modified arguments
ssh_cmd = ['ssh'] + ssh_args
if debug: print(f"[DEBUG] Launching: {' '.join(ssh_cmd)}")
# Run SSH process
process = await asyncio.create_subprocess_exec(
*ssh_cmd,
stdin=sys.stdin,
stdout=sys.stdout,
stderr=sys.stderr
)
# Wait for SSH to complete
# Set up ping/pong handlers for explicit handling
async def ping_handler(payload):
if debug: print(f"[DEBUG] Received ping: {payload}")
# Pong is sent automatically by websockets library
async def pong_handler(payload):
if debug: print(f"[DEBUG] Received pong: {payload}")
websocket.ping_handler = ping_handler
websocket.pong_handler = pong_handler
# Request tunnel
await websocket.send(json.dumps({
"type": "tunnel_request",
"client_id": client_id,
"request_id": request_id
}))
# Wait for acknowledgment
response = await websocket.recv()
data = json.loads(response)
if data.get('type') == 'tunnel_ack':
if debug: print(f"[DEBUG] Tunnel request acknowledged: {request_id}")
# Start tunnel handler (opens listening port)
ready_event = asyncio.Event()
tunnel_task = asyncio.create_task(handle_tunnel(websocket, local_port, request_id, ready_event))
# Wait for tunnel to be ready (listening port opened)
await ready_event.wait()
# Launch SSH with modified arguments
ssh_cmd = ['ssh'] + ssh_args
if debug: print(f"[DEBUG] Launching: {' '.join(ssh_cmd)}")
# Run SSH process
process = await asyncio.create_subprocess_exec(
*ssh_cmd,
stdin=sys.stdin,
stdout=sys.stdout,
stderr=sys.stderr
)
# Monitor both SSH process and WebSocket connection
tunnel_active = True
while tunnel_active:
# Check if SSH process is still running
if process.returncode is not None:
if debug: print(f"[DEBUG] SSH process finished with code: {process.returncode}")
tunnel_active = False
break
# Check WebSocket connection and attempt reconnection if needed
try:
# Try to send a small message to test connection
await asyncio.wait_for(websocket.ping(), timeout=5.0)
except Exception as e:
if debug: print(f"[DEBUG] WebSocket connection lost, attempting reconnection...")
# Attempt WebSocket reconnection
reconnect_attempts = 0
max_reconnect_attempts = 3
reconnected = False
while reconnect_attempts < max_reconnect_attempts and not reconnected:
try:
if debug: print(f"[DEBUG] WebSocket reconnection attempt {reconnect_attempts + 1}/{max_reconnect_attempts}")
# Create new WebSocket connection
new_websocket = await websockets.connect(uri, ssl=ssl_context)
new_websocket.ping_handler = ping_handler
new_websocket.pong_handler = pong_handler
# Re-request tunnel
await new_websocket.send(json.dumps({
"type": "tunnel_request",
"client_id": client_id,
"request_id": request_id
}))
# Wait for acknowledgment
response = await new_websocket.recv()
data = json.loads(response)
if data.get('type') == 'tunnel_ack':
# Update tunnel task with new WebSocket
tunnel_task.cancel()
tunnel_task = asyncio.create_task(handle_tunnel(new_websocket, local_port, request_id, None))
websocket = new_websocket
reconnected = True
if debug: print("[DEBUG] WebSocket reconnection successful")
else:
await new_websocket.close()
except Exception as reconnect_error:
reconnect_attempts += 1
if reconnect_attempts < max_reconnect_attempts:
if debug: print(f"[DEBUG] WebSocket reconnection failed, waiting 1 second...")
await asyncio.sleep(1) # Fast reconnection for WebSocket
else:
if debug: print(f"[DEBUG] All reconnection attempts failed: {reconnect_error}")
tunnel_active = False
break
if tunnel_active:
await asyncio.sleep(2) # Check every 2 seconds
# Wait for SSH to complete if still running
if process.returncode is None:
await process.wait()
# Close tunnel
# Close tunnel
try:
await websocket.send(json.dumps({
"type": "tunnel_close",
"request_id": request_id
}))
except Exception:
pass # WebSocket might already be closed
if tunnel_task and not tunnel_task.done():
tunnel_task.cancel()
elif data.get('type') == 'tunnel_error':
print(f"Error: {data.get('error', 'Unknown error')}")
return 1
elif data.get('type') == 'tunnel_error':
print(f"Error: {data.get('error', 'Unknown error')}")
return 1
except Exception as e:
print(f"Connection failed: {e}")
return 1
finally:
if websocket:
try:
await websocket.close()
except Exception:
pass
return 0
......@@ -199,6 +292,7 @@ def main():
parser = argparse.ArgumentParser(description='WebSocket SSH (wsssh)', add_help=False)
parser.add_argument('--local-port', type=int, default=0, help='Local port for tunnel (0 = auto)')
parser.add_argument('--interval', type=int, default=30, help='Connection retry interval in seconds (default: 30)')
parser.add_argument('--debug', action='store_true', help='Enable debug output')
# Parse our arguments first
......@@ -293,7 +387,8 @@ def main():
ssh_port, # Use -p port as wssshd_port
client_id,
local_port,
final_args
final_args,
args.interval # Pass the interval parameter
))
sys.exit(exit_code)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment