Improve wssshd graceful shutdown with better messaging and cleanup

- Add informative messages during graceful shutdown indicating what the server is waiting for
- Stop cleanup task immediately during shutdown to prevent keepalive timeouts
- Send tunnel close messages to ALL tunnels (both client and tool endpoints) during shutdown
- Import TUNNEL_CLOSE_MSG constant for proper tunnel closure
- Enhanced shutdown sequence with clear progress indicators
- Ensure all active connections receive proper shutdown notifications

This ensures wssshd provides clear feedback during shutdown and properly closes all tunnels without waiting for keepalive timeouts.
parent b46562ea
......@@ -11,7 +11,7 @@ import signal
import websockets
from functools import partial
from .config import load_config
from .websocket import handle_websocket, cleanup_task, shutdown_event, debug, clients, active_tunnels, active_terminals, SERVER_SHUTDOWN_MSG
from .websocket import handle_websocket, cleanup_task, shutdown_event, debug, clients, active_tunnels, active_terminals, SERVER_SHUTDOWN_MSG, TUNNEL_CLOSE_MSG
from .web import run_flask
......@@ -55,34 +55,42 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
# Wait for all notifications with timeout
if notify_tasks:
print(f"Waiting for {len(notify_tasks)} client notifications to complete...")
try:
await asyncio.wait_for(
asyncio.gather(*notify_tasks, return_exceptions=True),
timeout=0.3
)
print("All clients notified")
print("All clients notified successfully")
except asyncio.TimeoutError:
print("Timeout waiting for client notifications, proceeding with shutdown")
print("Timeout waiting for client notifications (0.3s), proceeding with shutdown")
except Exception as e:
print(f"Error during client notifications, proceeding with shutdown")
if debug: print(f"[DEBUG] Error during client notifications: {e}")
# Give clients a brief moment to process the shutdown message
print("Waiting for clients to process shutdown message...")
await asyncio.sleep(0.1)
# Close WebSocket server
print("Closing WebSocket server...")
try:
ws_server.close()
await ws_server.wait_closed()
print("WebSocket server closed successfully")
except Exception as e:
print("Error closing WebSocket server, continuing shutdown")
if debug: print(f"[DEBUG] Error closing WebSocket server: {e}")
# Cancel cleanup task
# Cancel cleanup task immediately to stop keepalive timeouts
print("Stopping cleanup task (keepalive timeouts)...")
if not cleanup_coro.done():
cleanup_coro.cancel()
try:
await cleanup_coro
except asyncio.CancelledError:
pass
print("Cleanup task stopped")
# Clean up active terminals more efficiently
print("Terminating active terminal processes...")
......@@ -101,12 +109,20 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
if term_procs:
# Wait for graceful termination
print(f"Waiting for {len(term_procs)} terminal processes to terminate gracefully...")
await asyncio.sleep(0.3)
# Force kill remaining processes
kill_tasks = []
# Check which processes are still running
still_running = []
for request_id, proc in term_procs:
if proc.poll() is None:
still_running.append((request_id, proc))
if still_running:
print(f"Force killing {len(still_running)} remaining terminal processes...")
# Force kill remaining processes
kill_tasks = []
for request_id, proc in still_running:
if debug: print(f"[DEBUG] Force killing terminal process {request_id}")
proc.kill()
# Create async task for waiting
......@@ -115,15 +131,21 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
# Wait for all kill operations to complete
if kill_tasks:
print("Waiting for force-killed processes to complete...")
try:
await asyncio.wait_for(
asyncio.gather(*kill_tasks, return_exceptions=True),
timeout=0.2
)
print("All terminal processes terminated")
except asyncio.TimeoutError:
print("Timeout waiting for some processes to terminate")
if debug: print("[DEBUG] Some processes still running after SIGKILL")
except Exception as e:
print("Error during process cleanup")
if debug: print(f"[DEBUG] Error during process cleanup: {e}")
else:
print("All terminal processes terminated gracefully")
# Clean up terminal records (optimized)
terminal_count = len(active_terminals)
......@@ -131,43 +153,58 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
active_terminals.clear()
if debug: print(f"[DEBUG] Cleaned up {terminal_count} terminal records")
# Clean up active tunnels (optimized)
print("Cleaning up active tunnels...")
# Clean up ALL tunnels (not just active ones)
print("Sending close messages to all tunnels...")
if active_tunnels:
# Create close tasks for all active tunnels
# Create close tasks for ALL tunnels
close_tasks = []
for request_id, tunnel in active_tunnels.items():
if tunnel.status == 'active': # Check tunnel status
client_info = clients.get(tunnel.client_id)
if client_info and client_info['status'] == 'active':
# Send close message to both client and tool endpoints
try:
# Send to client (wssshc) if websocket exists
if tunnel.client_ws:
close_task = asyncio.create_task(
tunnel.client_ws.send(SERVER_SHUTDOWN_MSG)
tunnel.client_ws.send(TUNNEL_CLOSE_MSG % request_id)
)
close_tasks.append((request_id, close_task))
close_tasks.append((f"{request_id}_client", close_task))
except Exception as e:
if debug: print(f"[DEBUG] Failed to create close task for {request_id}: {e}")
if debug: print(f"[DEBUG] Failed to send close to client for {request_id}: {e}")
try:
# Send to tool (wsssht/wsscp) if websocket exists
if tunnel.wsssh_ws:
close_task = asyncio.create_task(
tunnel.wsssh_ws.send(TUNNEL_CLOSE_MSG % request_id)
)
close_tasks.append((f"{request_id}_tool", close_task))
except Exception as e:
if debug: print(f"[DEBUG] Failed to send close to tool for {request_id}: {e}")
# Wait for all close tasks with timeout
if close_tasks:
print(f"Waiting for {len(close_tasks)} tunnel close messages...")
try:
await asyncio.wait_for(
asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
timeout=0.2
)
if debug: print(f"[DEBUG] Sent shutdown to {len(close_tasks)} clients")
print("All tunnel close messages sent")
if debug: print(f"[DEBUG] Sent close messages to {len(close_tasks)} endpoints")
except asyncio.TimeoutError:
if debug: print("[DEBUG] Timeout waiting for tunnel close notifications")
print("Timeout waiting for tunnel close messages (0.2s)")
if debug: print("[DEBUG] Timeout waiting for tunnel close messages")
except Exception as e:
if debug: print(f"[DEBUG] Error during tunnel close notifications: {e}")
print("Error during tunnel close messages")
if debug: print(f"[DEBUG] Error during tunnel close messages: {e}")
# Update tunnel statuses and clean up all tunnels
for request_id, tunnel in active_tunnels.items():
tunnel.update_status('closed', 'Server shutdown')
if debug: print(f"[DEBUG] Tunnel {request_id} status updated: {tunnel}")
tunnel_count = len(active_tunnels)
active_tunnels.clear()
if debug: print(f"[DEBUG] Cleaned up {len(close_tasks)} tunnels")
print(f"Cleaned up {tunnel_count} tunnels")
# Clean up clients (optimized)
client_count = len(clients)
......@@ -177,7 +214,12 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
# Stop Flask thread
if flask_thread and flask_thread.is_alive():
print("Waiting for web interface thread to stop...")
flask_thread.join(timeout=1.0)
if flask_thread.is_alive():
print("Web interface thread still running after timeout")
else:
print("Web interface thread stopped successfully")
print("WSSSH Daemon stopped cleanly")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment