Commit f32bbbd9 authored by nextime's avatar nextime

Add SIGINT signal handling for clean Ctrl+C exit

- Add signal handling to wsssh.py, wsscp.py, and wssshc.py
- Implement asyncio Event-based shutdown mechanism
- Enable clean exit when pressing Ctrl+C during connection setup
- Maintain proper resource cleanup on signal interruption
- Preserve existing functionality while adding graceful shutdown
parent f7dc5d7f
...@@ -30,6 +30,7 @@ import sys ...@@ -30,6 +30,7 @@ import sys
import uuid import uuid
import configparser import configparser
import os import os
import signal
debug = False debug = False
...@@ -92,7 +93,7 @@ async def handle_local_connection(reader, writer, ws, request_id): ...@@ -92,7 +93,7 @@ async def handle_local_connection(reader, writer, ws, request_id):
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
async def run_scp(server_ip, server_port, client_id, local_port, scp_args, interval=30): async def run_scp(server_ip, server_port, client_id, local_port, scp_args, interval=30, shutdown_event=None):
"""Connect to wssshd and run SCP with reconnection support""" """Connect to wssshd and run SCP with reconnection support"""
uri = f"wss://{server_ip}:{server_port}" uri = f"wss://{server_ip}:{server_port}"
ssl_context = ssl.create_default_context() ssl_context = ssl.create_default_context()
...@@ -174,6 +175,12 @@ async def run_scp(server_ip, server_port, client_id, local_port, scp_args, inter ...@@ -174,6 +175,12 @@ async def run_scp(server_ip, server_port, client_id, local_port, scp_args, inter
tunnel_active = False tunnel_active = False
break break
# Check for shutdown signal
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received, terminating tunnel")
tunnel_active = False
break
# Check WebSocket connection and attempt reconnection if needed # Check WebSocket connection and attempt reconnection if needed
try: try:
# Try to send a small message to test connection # Try to send a small message to test connection
...@@ -187,6 +194,12 @@ async def run_scp(server_ip, server_port, client_id, local_port, scp_args, inter ...@@ -187,6 +194,12 @@ async def run_scp(server_ip, server_port, client_id, local_port, scp_args, inter
reconnected = False reconnected = False
while reconnect_attempts < max_reconnect_attempts and not reconnected: while reconnect_attempts < max_reconnect_attempts and not reconnected:
# Check for shutdown signal during reconnection
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received during reconnection")
tunnel_active = False
break
try: try:
if debug: print(f"[DEBUG] WebSocket reconnection attempt {reconnect_attempts + 1}/{max_reconnect_attempts}") if debug: print(f"[DEBUG] WebSocket reconnection attempt {reconnect_attempts + 1}/{max_reconnect_attempts}")
...@@ -386,6 +399,16 @@ def main(): ...@@ -386,6 +399,16 @@ def main():
if debug: print(f"[DEBUG] Final SCP args: {final_args}") if debug: print(f"[DEBUG] Final SCP args: {final_args}")
# Set up signal handling for clean exit
shutdown_event = asyncio.Event()
def signal_handler(signum, frame):
if debug: print(f"[DEBUG] Received signal {signum}, initiating shutdown")
shutdown_event.set()
# Register signal handler for SIGINT (Ctrl+C)
signal.signal(signal.SIGINT, signal_handler)
# Run the async SCP wrapper # Run the async SCP wrapper
exit_code = asyncio.run(run_scp( exit_code = asyncio.run(run_scp(
wssshd_host, # Use parsed wssshd_host as server_ip wssshd_host, # Use parsed wssshd_host as server_ip
...@@ -393,7 +416,8 @@ def main(): ...@@ -393,7 +416,8 @@ def main():
client_id, client_id,
local_port, local_port,
final_args, final_args,
args.interval # Pass the interval parameter args.interval, # Pass the interval parameter
shutdown_event # Pass the shutdown event
)) ))
sys.exit(exit_code) sys.exit(exit_code)
......
...@@ -30,6 +30,7 @@ import sys ...@@ -30,6 +30,7 @@ import sys
import uuid import uuid
import configparser import configparser
import os import os
import signal
debug = False debug = False
...@@ -92,7 +93,7 @@ async def handle_local_connection(reader, writer, ws, request_id): ...@@ -92,7 +93,7 @@ async def handle_local_connection(reader, writer, ws, request_id):
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, interval=30): async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, interval=30, shutdown_event=None):
"""Connect to wssshd and run SSH with reconnection support""" """Connect to wssshd and run SSH with reconnection support"""
uri = f"wss://{server_ip}:{server_port}" uri = f"wss://{server_ip}:{server_port}"
ssl_context = ssl.create_default_context() ssl_context = ssl.create_default_context()
...@@ -174,6 +175,12 @@ async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, inter ...@@ -174,6 +175,12 @@ async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, inter
tunnel_active = False tunnel_active = False
break break
# Check for shutdown signal
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received, terminating tunnel")
tunnel_active = False
break
# Check WebSocket connection and attempt reconnection if needed # Check WebSocket connection and attempt reconnection if needed
try: try:
# Try to send a small message to test connection # Try to send a small message to test connection
...@@ -187,6 +194,12 @@ async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, inter ...@@ -187,6 +194,12 @@ async def run_ssh(server_ip, server_port, client_id, local_port, ssh_args, inter
reconnected = False reconnected = False
while reconnect_attempts < max_reconnect_attempts and not reconnected: while reconnect_attempts < max_reconnect_attempts and not reconnected:
# Check for shutdown signal during reconnection
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received during reconnection")
tunnel_active = False
break
try: try:
if debug: print(f"[DEBUG] WebSocket reconnection attempt {reconnect_attempts + 1}/{max_reconnect_attempts}") if debug: print(f"[DEBUG] WebSocket reconnection attempt {reconnect_attempts + 1}/{max_reconnect_attempts}")
...@@ -381,6 +394,16 @@ def main(): ...@@ -381,6 +394,16 @@ def main():
if debug: print(f"[DEBUG] Final SSH args: {final_args}") if debug: print(f"[DEBUG] Final SSH args: {final_args}")
# Set up signal handling for clean exit
shutdown_event = asyncio.Event()
def signal_handler(signum, frame):
if debug: print(f"[DEBUG] Received signal {signum}, initiating shutdown")
shutdown_event.set()
# Register signal handler for SIGINT (Ctrl+C)
signal.signal(signal.SIGINT, signal_handler)
# Run the async SSH wrapper # Run the async SSH wrapper
exit_code = asyncio.run(run_ssh( exit_code = asyncio.run(run_ssh(
wssshd_host, # Use parsed wssshd_host as server_ip wssshd_host, # Use parsed wssshd_host as server_ip
...@@ -388,7 +411,8 @@ def main(): ...@@ -388,7 +411,8 @@ def main():
client_id, client_id,
local_port, local_port,
final_args, final_args,
args.interval # Pass the interval parameter args.interval, # Pass the interval parameter
shutdown_event # Pass the shutdown event
)) ))
sys.exit(exit_code) sys.exit(exit_code)
......
...@@ -27,6 +27,7 @@ import json ...@@ -27,6 +27,7 @@ import json
import socket import socket
import configparser import configparser
import os import os
import signal
debug = False debug = False
...@@ -53,13 +54,18 @@ async def forward_tcp_to_ws(tcp_reader, websocket, request_id): ...@@ -53,13 +54,18 @@ async def forward_tcp_to_ws(tcp_reader, websocket, request_id):
del active_tunnels[request_id] del active_tunnels[request_id]
if debug: print(f"[DEBUG] Tunnel {request_id} closed") if debug: print(f"[DEBUG] Tunnel {request_id} closed")
async def connect_to_server(server_ip, port, client_id, password, interval): async def connect_to_server(server_ip, port, client_id, password, interval, shutdown_event=None):
uri = f"wss://{server_ip}:{port}" uri = f"wss://{server_ip}:{port}"
ssl_context = ssl.create_default_context() ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE ssl_context.verify_mode = ssl.CERT_NONE
while True: while True:
# Check for shutdown signal
if shutdown_event and shutdown_event.is_set():
if debug: print("[DEBUG] Shutdown signal received, exiting")
break
try: try:
async with websockets.connect(uri, ssl=ssl_context) as websocket: async with websockets.connect(uri, ssl=ssl_context) as websocket:
# Register # Register
...@@ -152,7 +158,17 @@ def main(): ...@@ -152,7 +158,17 @@ def main():
global debug global debug
debug = args.debug debug = args.debug
asyncio.run(connect_to_server(args.server_ip, args.port, args.id, args.password, args.interval)) # Set up signal handling for clean exit
shutdown_event = asyncio.Event()
def signal_handler(signum, frame):
if debug: print(f"[DEBUG] Received signal {signum}, initiating shutdown")
shutdown_event.set()
# Register signal handler for SIGINT (Ctrl+C)
signal.signal(signal.SIGINT, signal_handler)
asyncio.run(connect_to_server(args.server_ip, args.port, args.id, args.password, args.interval, shutdown_event))
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment