Add reconnection logic to cluster client

- Client now attempts to reconnect if connection is lost
- Prevents processes from being restarted on reconnection
- Maintains persistent cluster node operation
parent b0c7da40
......@@ -142,6 +142,8 @@ class ClusterClient:
print(f"Client backend detection: CUDA={gpu_info['cuda']}, ROCm={gpu_info['rocm']}")
print(f"Available backends: {available_backends}")
# Only start processes if not already started
if not self.local_processes:
# Start analysis workers for available backends (including CPU)
for backend in available_backends:
proc_name = f'analysis_{backend}'
......@@ -486,15 +488,20 @@ class ClusterClient:
print(f"Error handling shared model file {shared_file_path}: {e}")
async def run(self) -> None:
"""Main client loop."""
"""Main client loop with reconnection."""
reconnect = True
while reconnect: # Keep trying to connect/reconnect
if not await self.connect():
return
print("Failed to connect, retrying in 5 seconds...")
await asyncio.sleep(5)
continue
await self.start_local_processes()
# Start command handling task
command_task = asyncio.create_task(self.handle_master_commands())
connection_lost = False
try:
while self.connected:
await asyncio.sleep(1)
......@@ -504,18 +511,38 @@ class ClusterClient:
except KeyboardInterrupt:
print("Shutting down cluster client...")
reconnect = False
except Exception as e:
print(f"Connection lost: {e}, attempting to reconnect...")
connection_lost = True
finally:
# Cleanup
# Cleanup for this connection
command_task.cancel()
try:
await command_task
except:
pass
# Don't terminate processes on reconnection, just close websocket
if self.websocket:
try:
await self.websocket.close()
except:
pass
self.websocket = None
# If connection was lost (not keyboard interrupt), wait before reconnecting
if connection_lost and reconnect:
await asyncio.sleep(5)
# Final cleanup
for proc in self.local_processes.values():
proc.terminate()
for proc in self.local_processes.values():
proc.wait()
if self.websocket:
await self.websocket.close()
def start_cluster_client(host: str, port: int, token: str, optimize: bool = False, flash: bool = False, weight: int = 100, shared_dir: str = None) -> None:
"""Start the cluster client."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment