Enable dynamic backend switching for cluster clients with mixed GPU support

- Added restart_workers command from master to clients for backend switching (message format sketched below)
- Cluster clients can now restart their workers with different backends (CUDA/ROCm/CPU)
- Added mixed GPU detection - nodes with both CUDA and ROCm show 'Mixed GPU Available' indicator
- Clients with mixed GPUs can switch between CUDA and ROCm backends dynamically
- Updated API endpoint to send restart commands to connected clients
- Clients save driver preferences and restart workers immediately when changed
- Graceful fallback to an available backend if the requested backend is not available
- Visual indicator for nodes capable of backend switching
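
For reference, a minimal sketch of the two JSON messages involved in a backend switch. Field names and example values are taken from the diff below; the websocket transport and worker startup details are omitted, and the 'rocm' backend is only an example.

```python
import json

# Master -> client: ask the client to restart its workers on a new backend.
# Matches what restart_client_workers() sends and _handle_restart_workers() expects.
restart_request = json.dumps({
    'type': 'restart_workers',
    'backend': 'rocm',  # accepted values: 'cuda', 'rocm', 'cpu'
})

# Client -> master: after restarting, the client re-registers its worker processes.
register_reply = json.dumps({
    'type': 'register_processes',
    'processes': {
        'analysis_rocm': {'weight': 10, 'model': 'Qwen/Qwen2.5-VL-7B-Instruct', 'status': 'active'},
        'training_rocm': {'weight': 5, 'model': 'Qwen/Qwen2.5-VL-7B-Instruct', 'status': 'active'},
    },
})

print(restart_request)
print(register_reply)
```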
parent 6b838e4a
@@ -202,6 +202,7 @@ function renderNodesTable() {
             <td>
                 <strong>${node.workers_available}</strong> workers<br>
                 <small>${node.workers_summary || 'No workers'}</small>
+                ${node.mixed_gpu ? '<br><small style="color: #059669; font-weight: bold;">● Mixed GPU Available</small>' : ''}
             </td>
             <td>${node.ip_address}</td>
             <td>${formatUptime(node.uptime_seconds || 0)}</td>
...
@@ -239,6 +239,9 @@ class ClusterClient:
         elif msg_type == 'model_transfer_complete':
             await self._handle_model_transfer_complete(message)
+        elif msg_type == 'restart_workers':
+            await self._handle_restart_workers(message)
         elif msg_type == 'model_shared_file':
             await self._handle_model_shared_file(message)
@@ -356,6 +359,99 @@ class ClusterClient:
             # Clean up
             delattr(self, '_current_model_transfer')
 
+    async def _handle_restart_workers(self, message: Dict[str, Any]) -> None:
+        """Handle restart workers command from master."""
+        backend = message.get('backend', 'cuda')
+        if backend not in ['cuda', 'rocm', 'cpu']:
+            print(f"Invalid backend requested: {backend}")
+            return
+
+        print(f"Restarting workers with {backend} backend")
+
+        # Terminate existing workers
+        for proc in self.local_processes.values():
+            try:
+                proc.terminate()
+            except Exception:
+                pass
+
+        # Wait a bit for termination
+        await asyncio.sleep(2)
+
+        # Force kill if still running
+        for proc in self.local_processes.values():
+            try:
+                if proc.poll() is None:  # Still running
+                    proc.kill()
+            except Exception:
+                pass
+
+        # Clear the process bookkeeping
+        self.local_processes.clear()
+        self.process_weights.clear()
+        self.process_models.clear()
+
+        # Check if the requested backend is available; fall back if not
+        from .compat import get_available_backends
+        available_backends = get_available_backends()
+        if backend not in available_backends:
+            print(f"Requested backend {backend} not available, available: {available_backends}")
+            # Use first available backend as fallback
+            if available_backends:
+                backend = available_backends[0]
+                print(f"Using fallback backend: {backend}")
+            else:
+                print("No backends available, cannot restart workers")
+                return
+
+        # Start new workers with the specified backend
+        import subprocess
+        import sys
+
+        # Start analysis worker
+        try:
+            cmd = [sys.executable, '-m', 'vidai.worker_analysis', backend]
+            if self.optimize:
+                cmd.append('--optimize')
+            if self.flash:
+                cmd.append('--flash')
+            self.local_processes[f'analysis_{backend}'] = subprocess.Popen(cmd)
+            self.process_weights[f'analysis_{backend}'] = 10
+            self.process_models[f'analysis_{backend}'] = 'Qwen/Qwen2.5-VL-7B-Instruct'
+            print(f"Started analysis worker with {backend}")
+        except Exception as e:
+            print(f"Failed to start analysis worker: {e}")
+
+        # Start training worker
+        try:
+            cmd = [sys.executable, '-m', 'vidai.worker_training', backend]
+            if self.optimize:
+                cmd.append('--optimize')
+            if self.flash:
+                cmd.append('--flash')
+            self.local_processes[f'training_{backend}'] = subprocess.Popen(cmd)
+            self.process_weights[f'training_{backend}'] = 5
+            self.process_models[f'training_{backend}'] = 'Qwen/Qwen2.5-VL-7B-Instruct'
+            print(f"Started training worker with {backend}")
+        except Exception as e:
+            print(f"Failed to start training worker: {e}")
+
+        # Re-register the restarted processes with the master,
+        # pairing each worker name with its own weight and model
+        await self._send_message({
+            'type': 'register_processes',
+            'processes': {
+                name: {
+                    'weight': weight,
+                    'model': self.process_models.get(name, ''),
+                    'status': 'active'
+                }
+                for name, weight in self.process_weights.items()
+            }
+        })
+
     async def _handle_model_shared_file(self, message: Dict[str, Any]) -> None:
         """Handle model file available in shared directory."""
         model_path = message.get('model_path')
...
@@ -169,7 +169,7 @@ class ClusterMaster:
         msg_type = message.get('type')
 
         if msg_type == 'auth':
-            return self._handle_auth(message, client_sock)
+            return self._handle_auth(message, websocket)
         elif msg_type == 'register_processes':
             return self._handle_register_processes(message, websocket)
@@ -590,6 +590,23 @@ class ClusterMaster:
                 return True
         return False
 
+    def restart_client_workers(self, client_id: str, backend: str) -> bool:
+        """Restart all workers on a client with a different backend."""
+        if client_id not in self.client_websockets:
+            return False
+        if backend not in ['cuda', 'rocm', 'cpu']:
+            return False
+
+        # Send restart command to client
+        asyncio.create_task(self.client_websockets[client_id].send(
+            json.dumps({
+                'type': 'restart_workers',
+                'backend': backend
+            })
+        ))
+        return True
+
     async def _management_loop(self) -> None:
         """Main management loop."""
         while self.running:
...
@@ -418,7 +418,13 @@ def api_cluster_nodes():
             # Calculate uptime
             connected_at = client_info.get('connected_at', current_time)
             uptime_seconds = current_time - connected_at
+
+            # Detect mixed GPU availability
+            gpu_info = client_info.get('gpu_info', {})
+            has_cuda = gpu_info.get('cuda_available', False)
+            has_rocm = gpu_info.get('rocm_available', False)
+            mixed_gpu = has_cuda and has_rocm
             node_map[node_key] = {
                 'token': token,
                 'token_name': token_name,
@@ -434,6 +440,7 @@ def api_cluster_nodes():
                 'completed_jobs': 0,  # Placeholder
                 'weight': client_info.get('weight', 100),
                 'is_local': False,
+                'mixed_gpu': mixed_gpu,
                 'workers': []  # Will collect worker details
             }
@@ -632,8 +639,24 @@ def api_set_client_driver():
         return {'success': False, 'error': 'Missing token for remote client'}, 400
 
     from .database import set_client_driver_preference
-    success = set_client_driver_preference(hostname, token, driver)
-    return {'success': success}
+    from .cluster_master import cluster_master
+
+    # Save preference to database
+    db_success = set_client_driver_preference(hostname, token, driver)
+
+    # Find client and send restart command
+    client_id = None
+    for cid, client_info in cluster_master.clients.items():
+        if client_info['token'] == token:
+            client_id = cid
+            break
+
+    if client_id:
+        restart_success = cluster_master.restart_client_workers(client_id, driver)
+        return {'success': db_success and restart_success}
+    else:
+        # Client not currently connected, just save preference
+        return {'success': db_success}
 
 def switch_local_worker_backends(new_backend):
...