Improve wssshd graceful shutdown with better messaging and cleanup

- Add informative messages during graceful shutdown indicating what the server is waiting for
- Stop cleanup task immediately during shutdown to prevent keepalive timeouts
- Send tunnel close messages to ALL tunnels (both client and tool endpoints) during shutdown
- Import TUNNEL_CLOSE_MSG constant for proper tunnel closure
- Enhance shutdown sequence with clear progress indicators
- Ensure all active connections receive proper shutdown notifications

This ensures wssshd provides clear feedback during shutdown and properly closes all tunnels without waiting for keepalive timeouts.
parent b46562ea
...@@ -11,7 +11,7 @@ import signal ...@@ -11,7 +11,7 @@ import signal
import websockets import websockets
from functools import partial from functools import partial
from .config import load_config from .config import load_config
from .websocket import handle_websocket, cleanup_task, shutdown_event, debug, clients, active_tunnels, active_terminals, SERVER_SHUTDOWN_MSG from .websocket import handle_websocket, cleanup_task, shutdown_event, debug, clients, active_tunnels, active_terminals, SERVER_SHUTDOWN_MSG, TUNNEL_CLOSE_MSG
from .web import run_flask from .web import run_flask
...@@ -55,34 +55,42 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread): ...@@ -55,34 +55,42 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
# Wait for all notifications with timeout # Wait for all notifications with timeout
if notify_tasks: if notify_tasks:
print(f"Waiting for {len(notify_tasks)} client notifications to complete...")
try: try:
await asyncio.wait_for( await asyncio.wait_for(
asyncio.gather(*notify_tasks, return_exceptions=True), asyncio.gather(*notify_tasks, return_exceptions=True),
timeout=0.3 timeout=0.3
) )
print("All clients notified") print("All clients notified successfully")
except asyncio.TimeoutError: except asyncio.TimeoutError:
print("Timeout waiting for client notifications, proceeding with shutdown") print("Timeout waiting for client notifications (0.3s), proceeding with shutdown")
except Exception as e: except Exception as e:
print(f"Error during client notifications, proceeding with shutdown")
if debug: print(f"[DEBUG] Error during client notifications: {e}") if debug: print(f"[DEBUG] Error during client notifications: {e}")
# Give clients a brief moment to process the shutdown message # Give clients a brief moment to process the shutdown message
print("Waiting for clients to process shutdown message...")
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# Close WebSocket server # Close WebSocket server
print("Closing WebSocket server...")
try: try:
ws_server.close() ws_server.close()
await ws_server.wait_closed() await ws_server.wait_closed()
print("WebSocket server closed successfully")
except Exception as e: except Exception as e:
print("Error closing WebSocket server, continuing shutdown")
if debug: print(f"[DEBUG] Error closing WebSocket server: {e}") if debug: print(f"[DEBUG] Error closing WebSocket server: {e}")
# Cancel cleanup task # Cancel cleanup task immediately to stop keepalive timeouts
print("Stopping cleanup task (keepalive timeouts)...")
if not cleanup_coro.done(): if not cleanup_coro.done():
cleanup_coro.cancel() cleanup_coro.cancel()
try: try:
await cleanup_coro await cleanup_coro
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
print("Cleanup task stopped")
# Clean up active terminals more efficiently # Clean up active terminals more efficiently
print("Terminating active terminal processes...") print("Terminating active terminal processes...")
...@@ -101,29 +109,43 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread): ...@@ -101,29 +109,43 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
if term_procs: if term_procs:
# Wait for graceful termination # Wait for graceful termination
print(f"Waiting for {len(term_procs)} terminal processes to terminate gracefully...")
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
# Force kill remaining processes # Check which processes are still running
kill_tasks = [] still_running = []
for request_id, proc in term_procs: for request_id, proc in term_procs:
if proc.poll() is None: if proc.poll() is None:
still_running.append((request_id, proc))
if still_running:
print(f"Force killing {len(still_running)} remaining terminal processes...")
# Force kill remaining processes
kill_tasks = []
for request_id, proc in still_running:
if debug: print(f"[DEBUG] Force killing terminal process {request_id}") if debug: print(f"[DEBUG] Force killing terminal process {request_id}")
proc.kill() proc.kill()
# Create async task for waiting # Create async task for waiting
task = asyncio.get_event_loop().run_in_executor(None, proc.wait) task = asyncio.get_event_loop().run_in_executor(None, proc.wait)
kill_tasks.append(task) kill_tasks.append(task)
# Wait for all kill operations to complete # Wait for all kill operations to complete
if kill_tasks: if kill_tasks:
try: print("Waiting for force-killed processes to complete...")
await asyncio.wait_for( try:
asyncio.gather(*kill_tasks, return_exceptions=True), await asyncio.wait_for(
timeout=0.2 asyncio.gather(*kill_tasks, return_exceptions=True),
) timeout=0.2
except asyncio.TimeoutError: )
if debug: print("[DEBUG] Some processes still running after SIGKILL") print("All terminal processes terminated")
except Exception as e: except asyncio.TimeoutError:
if debug: print(f"[DEBUG] Error during process cleanup: {e}") print("Timeout waiting for some processes to terminate")
if debug: print("[DEBUG] Some processes still running after SIGKILL")
except Exception as e:
print("Error during process cleanup")
if debug: print(f"[DEBUG] Error during process cleanup: {e}")
else:
print("All terminal processes terminated gracefully")
# Clean up terminal records (optimized) # Clean up terminal records (optimized)
terminal_count = len(active_terminals) terminal_count = len(active_terminals)
...@@ -131,43 +153,58 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread): ...@@ -131,43 +153,58 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
active_terminals.clear() active_terminals.clear()
if debug: print(f"[DEBUG] Cleaned up {terminal_count} terminal records") if debug: print(f"[DEBUG] Cleaned up {terminal_count} terminal records")
# Clean up active tunnels (optimized) # Clean up ALL tunnels (not just active ones)
print("Cleaning up active tunnels...") print("Sending close messages to all tunnels...")
if active_tunnels: if active_tunnels:
# Create close tasks for all active tunnels # Create close tasks for ALL tunnels
close_tasks = [] close_tasks = []
for request_id, tunnel in active_tunnels.items(): for request_id, tunnel in active_tunnels.items():
if tunnel.status == 'active': # Check tunnel status # Send close message to both client and tool endpoints
client_info = clients.get(tunnel.client_id) try:
if client_info and client_info['status'] == 'active': # Send to client (wssshc) if websocket exists
try: if tunnel.client_ws:
close_task = asyncio.create_task( close_task = asyncio.create_task(
tunnel.client_ws.send(SERVER_SHUTDOWN_MSG) tunnel.client_ws.send(TUNNEL_CLOSE_MSG % request_id)
) )
close_tasks.append((request_id, close_task)) close_tasks.append((f"{request_id}_client", close_task))
except Exception as e: except Exception as e:
if debug: print(f"[DEBUG] Failed to create close task for {request_id}: {e}") if debug: print(f"[DEBUG] Failed to send close to client for {request_id}: {e}")
try:
# Send to tool (wsssht/wsscp) if websocket exists
if tunnel.wsssh_ws:
close_task = asyncio.create_task(
tunnel.wsssh_ws.send(TUNNEL_CLOSE_MSG % request_id)
)
close_tasks.append((f"{request_id}_tool", close_task))
except Exception as e:
if debug: print(f"[DEBUG] Failed to send close to tool for {request_id}: {e}")
# Wait for all close tasks with timeout # Wait for all close tasks with timeout
if close_tasks: if close_tasks:
print(f"Waiting for {len(close_tasks)} tunnel close messages...")
try: try:
await asyncio.wait_for( await asyncio.wait_for(
asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True), asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
timeout=0.2 timeout=0.2
) )
if debug: print(f"[DEBUG] Sent shutdown to {len(close_tasks)} clients") print("All tunnel close messages sent")
if debug: print(f"[DEBUG] Sent close messages to {len(close_tasks)} endpoints")
except asyncio.TimeoutError: except asyncio.TimeoutError:
if debug: print("[DEBUG] Timeout waiting for tunnel close notifications") print("Timeout waiting for tunnel close messages (0.2s)")
if debug: print("[DEBUG] Timeout waiting for tunnel close messages")
except Exception as e: except Exception as e:
if debug: print(f"[DEBUG] Error during tunnel close notifications: {e}") print("Error during tunnel close messages")
if debug: print(f"[DEBUG] Error during tunnel close messages: {e}")
# Update tunnel statuses and clean up all tunnels # Update tunnel statuses and clean up all tunnels
for request_id, tunnel in active_tunnels.items(): for request_id, tunnel in active_tunnels.items():
tunnel.update_status('closed', 'Server shutdown') tunnel.update_status('closed', 'Server shutdown')
if debug: print(f"[DEBUG] Tunnel {request_id} status updated: {tunnel}") if debug: print(f"[DEBUG] Tunnel {request_id} status updated: {tunnel}")
tunnel_count = len(active_tunnels)
active_tunnels.clear() active_tunnels.clear()
if debug: print(f"[DEBUG] Cleaned up {len(close_tasks)} tunnels") print(f"Cleaned up {tunnel_count} tunnels")
# Clean up clients (optimized) # Clean up clients (optimized)
client_count = len(clients) client_count = len(clients)
...@@ -177,7 +214,12 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread): ...@@ -177,7 +214,12 @@ async def shutdown_server(ws_server, cleanup_coro, flask_thread):
# Stop Flask thread # Stop Flask thread
if flask_thread and flask_thread.is_alive(): if flask_thread and flask_thread.is_alive():
print("Waiting for web interface thread to stop...")
flask_thread.join(timeout=1.0) flask_thread.join(timeout=1.0)
if flask_thread.is_alive():
print("Web interface thread still running after timeout")
else:
print("Web interface thread stopped successfully")
print("WSSSH Daemon stopped cleanly") print("WSSSH Daemon stopped cleanly")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment