Add reconnection logic to cluster client

- Client now attempts to reconnect if connection is lost
- Prevents processes from being restarted on reconnection
- Maintains persistent cluster node operation
parent b0c7da40
...@@ -142,31 +142,33 @@ class ClusterClient: ...@@ -142,31 +142,33 @@ class ClusterClient:
print(f"Client backend detection: CUDA={gpu_info['cuda']}, ROCm={gpu_info['rocm']}") print(f"Client backend detection: CUDA={gpu_info['cuda']}, ROCm={gpu_info['rocm']}")
print(f"Available backends: {available_backends}") print(f"Available backends: {available_backends}")
# Start analysis workers for available backends (including CPU) # Only start processes if not already started
for backend in available_backends: if not self.local_processes:
proc_name = f'analysis_{backend}' # Start analysis workers for available backends (including CPU)
cmd = [sys.executable, '-m', 'vidai.worker_analysis', backend] for backend in available_backends:
if self.optimize: proc_name = f'analysis_{backend}'
cmd.append('--optimize') cmd = [sys.executable, '-m', 'vidai.worker_analysis', backend]
if self.flash: if self.optimize:
cmd.append('--flash') cmd.append('--optimize')
self.local_processes[proc_name] = subprocess.Popen(cmd) if self.flash:
self.process_weights[proc_name] = 10 # Default weight cmd.append('--flash')
self.process_models[proc_name] = 'Qwen/Qwen2.5-VL-7B-Instruct' self.local_processes[proc_name] = subprocess.Popen(cmd)
print(f"Started analysis worker for {backend}") self.process_weights[proc_name] = 10 # Default weight
self.process_models[proc_name] = 'Qwen/Qwen2.5-VL-7B-Instruct'
# Start training workers for available backends (including CPU) print(f"Started analysis worker for {backend}")
for backend in available_backends:
proc_name = f'training_{backend}' # Start training workers for available backends (including CPU)
cmd = [sys.executable, '-m', 'vidai.worker_training', backend] for backend in available_backends:
if self.optimize: proc_name = f'training_{backend}'
cmd.append('--optimize') cmd = [sys.executable, '-m', 'vidai.worker_training', backend]
if self.flash: if self.optimize:
cmd.append('--flash') cmd.append('--optimize')
self.local_processes[proc_name] = subprocess.Popen(cmd) if self.flash:
self.process_weights[proc_name] = 5 # Training typically lower weight cmd.append('--flash')
self.process_models[proc_name] = 'Qwen/Qwen2.5-VL-7B-Instruct' self.local_processes[proc_name] = subprocess.Popen(cmd)
print(f"Started training worker for {backend}") self.process_weights[proc_name] = 5 # Training typically lower weight
self.process_models[proc_name] = 'Qwen/Qwen2.5-VL-7B-Instruct'
print(f"Started training worker for {backend}")
# Register processes with master # Register processes with master
await self._send_message({ await self._send_message({
...@@ -486,35 +488,60 @@ class ClusterClient: ...@@ -486,35 +488,60 @@ class ClusterClient:
print(f"Error handling shared model file {shared_file_path}: {e}") print(f"Error handling shared model file {shared_file_path}: {e}")
async def run(self) -> None: async def run(self) -> None:
"""Main client loop.""" """Main client loop with reconnection."""
if not await self.connect(): reconnect = True
return while reconnect: # Keep trying to connect/reconnect
if not await self.connect():
print("Failed to connect, retrying in 5 seconds...")
await asyncio.sleep(5)
continue
await self.start_local_processes() await self.start_local_processes()
# Start command handling task # Start command handling task
command_task = asyncio.create_task(self.handle_master_commands()) command_task = asyncio.create_task(self.handle_master_commands())
try: connection_lost = False
while self.connected: try:
await asyncio.sleep(1) while self.connected:
await asyncio.sleep(1)
# Send heartbeat # Send heartbeat
await self._send_message({'type': 'heartbeat'}) await self._send_message({'type': 'heartbeat'})
except KeyboardInterrupt: except KeyboardInterrupt:
print("Shutting down cluster client...") print("Shutting down cluster client...")
reconnect = False
finally: except Exception as e:
# Cleanup print(f"Connection lost: {e}, attempting to reconnect...")
command_task.cancel() connection_lost = True
for proc in self.local_processes.values():
proc.terminate() finally:
for proc in self.local_processes.values(): # Cleanup for this connection
proc.wait() command_task.cancel()
try:
await command_task
except:
pass
# Don't terminate processes on reconnection, just close websocket
if self.websocket:
try:
await self.websocket.close()
except:
pass
self.websocket = None
# If connection was lost (not keyboard interrupt), wait before reconnecting
if connection_lost and reconnect:
await asyncio.sleep(5)
if self.websocket: # Final cleanup
await self.websocket.close() for proc in self.local_processes.values():
proc.terminate()
for proc in self.local_processes.values():
proc.wait()
def start_cluster_client(host: str, port: int, token: str, optimize: bool = False, flash: bool = False, weight: int = 100, shared_dir: str = None) -> None: def start_cluster_client(host: str, port: int, token: str, optimize: bool = False, flash: bool = False, weight: int = 100, shared_dir: str = None) -> None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment