Fix client disconnection handling in cluster master

- Converted assign_job_to_worker, _transfer_job_files, _transfer_file_via_websocket, enable_process, disable_process, update_process_weight, restart_client_workers, and restart_client_worker to async methods
- Added proper exception handling for websocket send operations
- When a websocket send fails because the connection is broken, the affected client is now removed so that its workers are no longer available for selection
- This ensures that disconnected clients are immediately removed from the worker pool and jobs are re-assigned to available workers
parent a2c308f1
...@@ -444,7 +444,7 @@ class ClusterMaster: ...@@ -444,7 +444,7 @@ class ClusterMaster:
return max(1, cuda_count + rocm_count) return max(1, cuda_count + rocm_count)
def assign_job_to_worker(self, worker_key: str, job_data: dict) -> Optional[str]: async def assign_job_to_worker(self, worker_key: str, job_data: dict) -> Optional[str]:
"""Assign a job to a worker and handle file/model transfer.""" """Assign a job to a worker and handle file/model transfer."""
from .models import estimate_model_vram_requirements from .models import estimate_model_vram_requirements
import uuid import uuid
...@@ -491,11 +491,15 @@ class ClusterMaster: ...@@ -491,11 +491,15 @@ class ClusterMaster:
'job_data': job_data 'job_data': job_data
} }
# Send via websocket (async) # Send via websocket synchronously to catch connection errors
asyncio.create_task(self.client_websockets[client_id].send(json.dumps(job_message))) await self.client_websockets[client_id].send(json.dumps(job_message))
print(f"Job {job_id} assigned to worker {worker_key} on client {client_id}") print(f"Job {job_id} assigned to worker {worker_key} on client {client_id}")
except Exception as e: except Exception as e:
print(f"Failed to send job {job_id} to worker {worker_key}: {e}") print(f"Failed to send job {job_id} to worker {worker_key}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
# Clean up the failed assignment # Clean up the failed assignment
self.worker_jobs[worker_key].remove(job_id) self.worker_jobs[worker_key].remove(job_id)
self.worker_vram_usage[worker_key] -= vram_required self.worker_vram_usage[worker_key] -= vram_required
...@@ -511,7 +515,7 @@ class ClusterMaster: ...@@ -511,7 +515,7 @@ class ClusterMaster:
return job_id return job_id
def _transfer_job_files(self, client_id: str, job_data: dict, job_id: str) -> None: async def _transfer_job_files(self, client_id: str, job_data: dict, job_id: str) -> None:
"""Transfer job files to client.""" """Transfer job files to client."""
media_path = job_data.get('local_path') media_path = job_data.get('local_path')
if not media_path: if not media_path:
...@@ -529,21 +533,29 @@ class ClusterMaster: ...@@ -529,21 +533,29 @@ class ClusterMaster:
shutil.copy2(media_path, shared_file_path) shutil.copy2(media_path, shared_file_path)
# Notify client of shared file location # Notify client of shared file location
asyncio.create_task(self.client_websockets[client_id].send(json.dumps({ try:
await self.client_websockets[client_id].send(json.dumps({
'type': 'job_file_shared', 'type': 'job_file_shared',
'job_id': job_id, 'job_id': job_id,
'shared_file_path': shared_file_path, 'shared_file_path': shared_file_path,
'original_path': media_path 'original_path': media_path
}))) }))
except Exception as e:
print(f"Failed to send shared file notification for job {job_id}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
return
except Exception as e: except Exception as e:
print(f"Failed to use shared directory for job {job_id}: {e}") print(f"Failed to use shared directory for job {job_id}: {e}")
# Fall back to websocket transfer # Fall back to websocket transfer
self._transfer_file_via_websocket(client_id, media_path, job_id) await self._transfer_file_via_websocket(client_id, media_path, job_id)
else: else:
# Use websocket transfer # Use websocket transfer
self._transfer_file_via_websocket(client_id, media_path, job_id) await self._transfer_file_via_websocket(client_id, media_path, job_id)
def _transfer_file_via_websocket(self, client_id: str, file_path: str, job_id: str) -> None: async def _transfer_file_via_websocket(self, client_id: str, file_path: str, job_id: str) -> None:
"""Transfer file via websocket.""" """Transfer file via websocket."""
try: try:
with open(file_path, 'rb') as f: with open(file_path, 'rb') as f:
...@@ -554,29 +566,37 @@ class ClusterMaster: ...@@ -554,29 +566,37 @@ class ClusterMaster:
total_size = len(file_data) total_size = len(file_data)
# Send start message # Send start message
asyncio.create_task(self.client_websockets[client_id].send(json.dumps({ try:
await self.client_websockets[client_id].send(json.dumps({
'type': 'job_file_transfer_start', 'type': 'job_file_transfer_start',
'job_id': job_id, 'job_id': job_id,
'file_path': file_path, 'file_path': file_path,
'total_size': total_size 'total_size': total_size
}))) }))
# Send chunks # Send chunks
for i in range(0, total_size, chunk_size): for i in range(0, total_size, chunk_size):
chunk = file_data[i:i + chunk_size] chunk = file_data[i:i + chunk_size]
asyncio.create_task(self.client_websockets[client_id].send(json.dumps({ await self.client_websockets[client_id].send(json.dumps({
'type': 'job_file_chunk', 'type': 'job_file_chunk',
'job_id': job_id, 'job_id': job_id,
'offset': i, 'offset': i,
'data': chunk.hex() 'data': chunk.hex()
}))) }))
# Send completion # Send completion
asyncio.create_task(self.client_websockets[client_id].send(json.dumps({ await self.client_websockets[client_id].send(json.dumps({
'type': 'job_file_transfer_complete', 'type': 'job_file_transfer_complete',
'job_id': job_id, 'job_id': job_id,
'file_path': file_path 'file_path': file_path
}))) }))
except Exception as e:
print(f"Failed to transfer file {file_path} for job {job_id}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
raise # Re-raise to indicate failure
except Exception as e: except Exception as e:
print(f"Failed to transfer file {file_path} for job {job_id}: {e}") print(f"Failed to transfer file {file_path} for job {job_id}: {e}")
...@@ -702,12 +722,20 @@ class ClusterMaster: ...@@ -702,12 +722,20 @@ class ClusterMaster:
f.write(model_data) f.write(model_data)
# Send file path to client # Send file path to client
try:
await self.client_websockets[client_id].send(json.dumps({ await self.client_websockets[client_id].send(json.dumps({
'type': 'model_shared_file', 'type': 'model_shared_file',
'model_path': model_path, 'model_path': model_path,
'shared_file_path': shared_file_path, 'shared_file_path': shared_file_path,
'total_size': len(model_data) 'total_size': len(model_data)
})) }))
except Exception as e:
print(f"Failed to send model shared file notification to client {client_id}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
return False
print(f"Model {model_path} placed in shared directory for client {client_id}: {shared_file_path}") print(f"Model {model_path} placed in shared directory for client {client_id}: {shared_file_path}")
return True return True
...@@ -770,19 +798,27 @@ class ClusterMaster: ...@@ -770,19 +798,27 @@ class ClusterMaster:
print("huggingface_hub not available for model download") print("huggingface_hub not available for model download")
return None return None
def enable_process(self, process_key: str) -> bool: async def enable_process(self, process_key: str) -> bool:
"""Enable a specific process.""" """Enable a specific process."""
if process_key in self.processes: if process_key in self.processes:
self.processes[process_key]['status'] = 'active' self.processes[process_key]['status'] = 'active'
# Send command to client # Send command to client
client_id = self.processes[process_key]['client_id'] client_id = self.processes[process_key]['client_id']
if client_id in self.client_websockets: if client_id in self.client_websockets:
asyncio.create_task(self.client_websockets[client_id].send( try:
await self.client_websockets[client_id].send(
json.dumps({ json.dumps({
'type': 'enable_process', 'type': 'enable_process',
'process_name': self.processes[process_key]['name'] 'process_name': self.processes[process_key]['name']
}) })
)) )
except Exception as e:
print(f"Failed to send enable_process command to client {client_id}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
return False
return True return True
return False return False
...@@ -951,23 +987,31 @@ class ClusterMaster: ...@@ -951,23 +987,31 @@ class ClusterMaster:
return None return None
def disable_process(self, process_key: str) -> bool: async def disable_process(self, process_key: str) -> bool:
"""Disable a specific process.""" """Disable a specific process."""
if process_key in self.processes: if process_key in self.processes:
self.processes[process_key]['status'] = 'disabled' self.processes[process_key]['status'] = 'disabled'
# Send command to client # Send command to client
client_id = self.processes[process_key]['client_id'] client_id = self.processes[process_key]['client_id']
if client_id in self.client_websockets: if client_id in self.client_websockets:
asyncio.create_task(self.client_websockets[client_id].send( try:
await self.client_websockets[client_id].send(
json.dumps({ json.dumps({
'type': 'disable_process', 'type': 'disable_process',
'process_name': self.processes[process_key]['name'] 'process_name': self.processes[process_key]['name']
}) })
)) )
except Exception as e:
print(f"Failed to send disable_process command to client {client_id}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
return False
return True return True
return False return False
def update_process_weight(self, process_key: str, weight: int) -> bool: async def update_process_weight(self, process_key: str, weight: int) -> bool:
"""Update process weight for load balancing.""" """Update process weight for load balancing."""
if process_key in self.processes: if process_key in self.processes:
self.processes[process_key]['weight'] = weight self.processes[process_key]['weight'] = weight
...@@ -986,17 +1030,25 @@ class ClusterMaster: ...@@ -986,17 +1030,25 @@ class ClusterMaster:
# Send command to client # Send command to client
client_id = self.processes[process_key]['client_id'] client_id = self.processes[process_key]['client_id']
if client_id in self.client_websockets: if client_id in self.client_websockets:
asyncio.create_task(self.client_websockets[client_id].send( try:
await self.client_websockets[client_id].send(
json.dumps({ json.dumps({
'type': 'update_weight', 'type': 'update_weight',
'process_name': self.processes[process_key]['name'], 'process_name': self.processes[process_key]['name'],
'weight': weight 'weight': weight
}) })
)) )
except Exception as e:
print(f"Failed to send update_weight command to client {client_id}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
return False
return True return True
return False return False
def restart_client_workers(self, client_id: str, backend: str) -> bool: async def restart_client_workers(self, client_id: str, backend: str) -> bool:
"""Restart all workers on a client with a different backend.""" """Restart all workers on a client with a different backend."""
if client_id not in self.client_websockets: if client_id not in self.client_websockets:
return False return False
...@@ -1006,15 +1058,23 @@ class ClusterMaster: ...@@ -1006,15 +1058,23 @@ class ClusterMaster:
return False return False
# Send restart command to client # Send restart command to client
asyncio.create_task(self.client_websockets[client_id].send( try:
await self.client_websockets[client_id].send(
json.dumps({ json.dumps({
'type': 'restart_workers', 'type': 'restart_workers',
'backend': backend 'backend': backend
}) })
)) )
return True return True
except Exception as e:
print(f"Failed to send restart_workers command to client {client_id}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
return False
def restart_client_worker(self, client_id: str, worker_name: str, backend: str) -> bool: async def restart_client_worker(self, client_id: str, worker_name: str, backend: str) -> bool:
"""Restart a specific worker on a client with a different backend.""" """Restart a specific worker on a client with a different backend."""
if client_id not in self.client_websockets: if client_id not in self.client_websockets:
return False return False
...@@ -1025,16 +1085,20 @@ class ClusterMaster: ...@@ -1025,16 +1085,20 @@ class ClusterMaster:
# Send restart command for specific worker to client # Send restart command for specific worker to client
try: try:
asyncio.create_task(self.client_websockets[client_id].send( await self.client_websockets[client_id].send(
json.dumps({ json.dumps({
'type': 'restart_worker', 'type': 'restart_worker',
'worker_name': worker_name, 'worker_name': worker_name,
'backend': backend 'backend': backend
}) })
)) )
return True return True
except Exception as e: except Exception as e:
print(f"Failed to send restart command for worker {worker_name} to client {client_id}: {e}") print(f"Failed to send restart command for worker {worker_name} to client {client_id}: {e}")
# Connection is broken, remove the client
if client_id in self.clients:
print(f"Removing disconnected client {client_id}")
self._remove_client(client_id)
return False return False
async def _management_loop(self) -> None: async def _management_loop(self) -> None:
...@@ -1074,7 +1138,7 @@ class ClusterMaster: ...@@ -1074,7 +1138,7 @@ class ClusterMaster:
print(f"Job {job['id']} waiting for available workers") print(f"Job {job['id']} waiting for available workers")
worker_key = self.select_worker_for_job(process_type, job['data'].get('model_path', 'Qwen/Qwen2.5-VL-7B-Instruct'), job['data']) worker_key = self.select_worker_for_job(process_type, job['data'].get('model_path', 'Qwen/Qwen2.5-VL-7B-Instruct'), job['data'])
if worker_key: if worker_key:
job_id = self.assign_job_to_worker(worker_key, job['data']) job_id = await self.assign_job_to_worker(worker_key, job['data'])
if job_id: if job_id:
from .database import update_queue_status from .database import update_queue_status
update_queue_status(job['id'], 'processing', {'job_id': job_id, 'status': 'Assigned to worker'}, job_id=job_id) update_queue_status(job['id'], 'processing', {'job_id': job_id, 'status': 'Assigned to worker'}, job_id=job_id)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment