Add reconnection logic to cluster client

- Client now attempts to reconnect if connection is lost
- Prevents processes from being restarted on reconnection
- Maintains persistent cluster node operation
parent b0c7da40
......@@ -142,31 +142,33 @@ class ClusterClient:
print(f"Client backend detection: CUDA={gpu_info['cuda']}, ROCm={gpu_info['rocm']}")
print(f"Available backends: {available_backends}")
# Start analysis workers for available backends (including CPU)
for backend in available_backends:
proc_name = f'analysis_{backend}'
cmd = [sys.executable, '-m', 'vidai.worker_analysis', backend]
if self.optimize:
cmd.append('--optimize')
if self.flash:
cmd.append('--flash')
self.local_processes[proc_name] = subprocess.Popen(cmd)
self.process_weights[proc_name] = 10 # Default weight
self.process_models[proc_name] = 'Qwen/Qwen2.5-VL-7B-Instruct'
print(f"Started analysis worker for {backend}")
# Start training workers for available backends (including CPU)
for backend in available_backends:
proc_name = f'training_{backend}'
cmd = [sys.executable, '-m', 'vidai.worker_training', backend]
if self.optimize:
cmd.append('--optimize')
if self.flash:
cmd.append('--flash')
self.local_processes[proc_name] = subprocess.Popen(cmd)
self.process_weights[proc_name] = 5 # Training typically lower weight
self.process_models[proc_name] = 'Qwen/Qwen2.5-VL-7B-Instruct'
print(f"Started training worker for {backend}")
# Only start processes if not already started
if not self.local_processes:
# Start analysis workers for available backends (including CPU)
for backend in available_backends:
proc_name = f'analysis_{backend}'
cmd = [sys.executable, '-m', 'vidai.worker_analysis', backend]
if self.optimize:
cmd.append('--optimize')
if self.flash:
cmd.append('--flash')
self.local_processes[proc_name] = subprocess.Popen(cmd)
self.process_weights[proc_name] = 10 # Default weight
self.process_models[proc_name] = 'Qwen/Qwen2.5-VL-7B-Instruct'
print(f"Started analysis worker for {backend}")
# Start training workers for available backends (including CPU)
for backend in available_backends:
proc_name = f'training_{backend}'
cmd = [sys.executable, '-m', 'vidai.worker_training', backend]
if self.optimize:
cmd.append('--optimize')
if self.flash:
cmd.append('--flash')
self.local_processes[proc_name] = subprocess.Popen(cmd)
self.process_weights[proc_name] = 5 # Training typically lower weight
self.process_models[proc_name] = 'Qwen/Qwen2.5-VL-7B-Instruct'
print(f"Started training worker for {backend}")
# Register processes with master
await self._send_message({
......@@ -486,35 +488,60 @@ class ClusterClient:
print(f"Error handling shared model file {shared_file_path}: {e}")
async def run(self) -> None:
    """Run the main client loop, reconnecting whenever the link drops.

    Repeatedly connects to the master, registers the local worker
    processes, starts the command-handling task and then sends a heartbeat
    once per second while ``self.connected`` is true.  When the connection
    is lost the websocket is closed and, after a 5 second pause, a new
    connection attempt is made.  Worker subprocesses are deliberately left
    running across reconnects and are only terminated when this coroutine
    finally exits (keyboard interrupt, cancellation or an unexpected
    error).
    """

    async def _close_websocket() -> None:
        # Closing is best-effort: the socket is usually already broken
        # when we get here, so a failure to close must not mask the
        # original error or prevent a reconnect.
        if self.websocket:
            try:
                await self.websocket.close()
            except Exception as e:
                print(f"Error closing websocket: {e}")
            self.websocket = None

    reconnect = True
    try:
        while reconnect:  # Keep trying to connect/reconnect
            if not await self.connect():
                print("Failed to connect, retrying in 5 seconds...")
                await asyncio.sleep(5)
                continue

            connection_lost = False
            command_task = None
            try:
                # Registration talks to the master, so it lives inside the
                # try block: a send failure here must trigger a reconnect
                # instead of escaping run() without any cleanup.
                await self.start_local_processes()

                # Start command handling task
                command_task = asyncio.create_task(self.handle_master_commands())

                while self.connected:
                    await asyncio.sleep(1)
                    # Send heartbeat
                    await self._send_message({'type': 'heartbeat'})
            except KeyboardInterrupt:
                print("Shutting down cluster client...")
                reconnect = False
            except Exception as e:
                print(f"Connection lost: {e}, attempting to reconnect...")
                connection_lost = True
            else:
                # self.connected went False without raising; still back off
                # before reconnecting so a flapping master cannot cause a
                # tight reconnect loop.
                connection_lost = True
            finally:
                # Cleanup for this connection
                if command_task is not None:
                    command_task.cancel()
                    try:
                        await command_task
                    except asyncio.CancelledError:
                        # Expected result of the cancel() above.
                        pass
                    except Exception as e:
                        print(f"Command handler stopped with error: {e}")
                # Don't terminate processes on reconnection, just close websocket
                await _close_websocket()

            # If connection was lost (not keyboard interrupt), wait before reconnecting
            if connection_lost and reconnect:
                await asyncio.sleep(5)
    finally:
        # Final cleanup runs on every exit path (including cancellation),
        # so worker subprocesses are never left orphaned.
        await _close_websocket()
        for proc in self.local_processes.values():
            proc.terminate()
        for proc in self.local_processes.values():
            proc.wait()
def start_cluster_client(host: str, port: int, token: str, optimize: bool = False, flash: bool = False, weight: int = 100, shared_dir: str = None) -> None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment