Commit f32bbbd9 authored by nextime's avatar nextime

Add SIGINT signal handling for clean Ctrl+C exit

- Add signal handling to wsssh.py, wsscp.py, and wssshc.py
- Implement asyncio Event-based shutdown mechanism
- Enable clean exit when pressing Ctrl+C during connection setup
- Maintain proper resource cleanup on signal interruption
- Preserve existing functionality while adding graceful shutdown
parent f7dc5d7f
......@@ -30,6 +30,7 @@ import sys
import uuid
import configparser
import os
import signal
debug = False
......@@ -92,7 +93,7 @@ async def handle_local_connection(reader, writer, ws, request_id):
writer.close()
await writer.wait_closed()
async def run_scp(server_ip, server_port, client_id, local_port, scp_args, interval=30):
async def run_scp(server_ip, server_port, client_id, local_port, scp_args, interval=30, shutdown_event=None):
"""Connect to wssshd and run SCP with reconnection support"""
uri = f"wss://{server_ip}:{server_port}"
ssl_context = ssl.create_default_context()
......@@ -174,6 +175,12 @@ async def run_scp(server_ip, server_port, client_id, local_port, scp_args, inter
tunnel_active = False
break
# Check for shutdown signal
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received, terminating tunnel")
tunnel_active = False
break
# Check WebSocket connection and attempt reconnection if needed
try:
# Try to send a small message to test connection
......@@ -187,6 +194,12 @@ async def run_scp(server_ip, server_port, client_id, local_port, scp_args, inter
reconnected = False
while reconnect_attempts < max_reconnect_attempts and not reconnected:
# Check for shutdown signal during reconnection
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received during reconnection")
tunnel_active = False
break
try:
if debug: print(f"[DEBUG] WebSocket reconnection attempt {reconnect_attempts + 1}/{max_reconnect_attempts}")
......@@ -386,6 +399,16 @@ def main():
if debug: print(f"[DEBUG] Final SCP args: {final_args}")
# Set up signal handling for clean exit
shutdown_event = asyncio.Event()
def signal_handler(signum, frame):
if debug: print(f"[DEBUG] Received signal {signum}, initiating shutdown")
shutdown_event.set()
# Register signal handler for SIGINT (Ctrl+C)
signal.signal(signal.SIGINT, signal_handler)
# Run the async SCP wrapper
exit_code = asyncio.run(run_scp(
wssshd_host, # Use parsed wssshd_host as server_ip
......@@ -393,7 +416,8 @@ def main():
client_id,
local_port,
final_args,
args.interval # Pass the interval parameter
args.interval, # Pass the interval parameter
shutdown_event # Pass the shutdown event
))
sys.exit(exit_code)
......
......@@ -30,6 +30,7 @@ import sys
import uuid
import configparser
import os
import signal
debug = False
......@@ -92,7 +93,7 @@ async def handle_local_connection(reader, writer, ws, request_id):
writer.close()
await writer.wait_closed()
async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, interval=30):
async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, interval=30, shutdown_event=None):
"""Connect to wssshd and run SSH with reconnection support"""
uri = f"wss://{server_ip}:{server_port}"
ssl_context = ssl.create_default_context()
......@@ -174,6 +175,12 @@ async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, inter
tunnel_active = False
break
# Check for shutdown signal
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received, terminating tunnel")
tunnel_active = False
break
# Check WebSocket connection and attempt reconnection if needed
try:
# Try to send a small message to test connection
......@@ -187,6 +194,12 @@ async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, inter
reconnected = False
while reconnect_attempts < max_reconnect_attempts and not reconnected:
# Check for shutdown signal during reconnection
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received during reconnection")
tunnel_active = False
break
try:
if debug: print(f"[DEBUG] WebSocket reconnection attempt {reconnect_attempts + 1}/{max_reconnect_attempts}")
......@@ -381,6 +394,16 @@ def main():
if debug: print(f"[DEBUG] Final SSH args: {final_args}")
# Set up signal handling for clean exit
shutdown_event = asyncio.Event()
def signal_handler(signum, frame):
if debug: print(f"[DEBUG] Received signal {signum}, initiating shutdown")
shutdown_event.set()
# Register signal handler for SIGINT (Ctrl+C)
signal.signal(signal.SIGINT, signal_handler)
# Run the async SSH wrapper
exit_code = asyncio.run(run_ssh(
wssshd_host, # Use parsed wssshd_host as server_ip
......@@ -388,7 +411,8 @@ def main():
client_id,
local_port,
final_args,
args.interval # Pass the interval parameter
args.interval, # Pass the interval parameter
shutdown_event # Pass the shutdown event
))
sys.exit(exit_code)
......
......@@ -27,6 +27,7 @@ import json
import socket
import configparser
import os
import signal
debug = False
......@@ -53,13 +54,18 @@ async def forward_tcp_to_ws(tcp_reader, websocket, request_id):
del active_tunnels[request_id]
if debug: print(f"[DEBUG] Tunnel {request_id} closed")
async def connect_to_server(server_ip, port, client_id, password, interval):
async def connect_to_server(server_ip, port, client_id, password, interval, shutdown_event=None):
uri = f"wss://{server_ip}:{port}"
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
while True:
# Check for shutdown signal
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received, exiting")
break
try:
async with websockets.connect(uri, ssl=ssl_context) as websocket:
# Register
......@@ -152,7 +158,17 @@ def main():
global debug
debug = args.debug
asyncio.run(connect_to_server(args.server_ip, args.port, args.id, args.password, args.interval))
# Set up signal handling for clean exit
shutdown_event = asyncio.Event()
def signal_handler(signum, frame):
if debug: print(f"[DEBUG] Received signal {signum}, initiating shutdown")
shutdown_event.set()
# Register signal handler for SIGINT (Ctrl+C)
signal.signal(signal.SIGINT, signal_handler)
asyncio.run(connect_to_server(args.server_ip, args.port, args.id, args.password, args.interval, shutdown_event))
if __name__ == '__main__':
main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment