Implement per-worker driver selection modal

- Modified modal to show individual GPU-requiring workers on each node
- Allow granular driver selection (CUDA/ROCm/CPU) for each worker subprocess
- Updated database schema to store driver preferences per worker (hostname + token + worker_name)
- Enhanced API to handle per-worker driver setting with form field parsing
- Added restart_client_worker method to cluster master for individual worker restarts
- Frontend now displays worker-specific driver selection controls in modal
- Maintains node-level table view while providing worker-level configuration
- Supports CPU-only nodes and mixed GPU/CPU worker configurations
- Backward compatible with existing single-driver preference system
parent 5cbdab26
...@@ -121,20 +121,17 @@ ...@@ -121,20 +121,17 @@
<div class="modal-content"> <div class="modal-content">
<div class="modal-header"> <div class="modal-header">
<span class="close" onclick="closeModal()">&times;</span> <span class="close" onclick="closeModal()">&times;</span>
<h3>Select Driver for <span id="modalHostname"></span></h3> <h3>Select Drivers for Workers on <span id="modalHostname"></span></h3>
</div> </div>
<form id="driverForm"> <form id="driverForm">
<input type="hidden" id="modalHostnameInput" name="hostname"> <input type="hidden" id="modalHostnameInput" name="hostname">
<input type="hidden" id="modalTokenInput" name="token"> <input type="hidden" id="modalTokenInput" name="token">
<div class="form-group"> <div id="workerDriverSelections">
<label for="driverSelect">Preferred GPU Driver:</label> <!-- Worker driver selections will be populated dynamically -->
<select id="driverSelect" name="driver">
<!-- Options will be populated dynamically -->
</select>
</div> </div>
<div class="modal-footer"> <div class="modal-footer">
<button type="button" class="btn btn-secondary" onclick="closeModal()">Cancel</button> <button type="button" class="btn btn-secondary" onclick="closeModal()">Cancel</button>
<button type="submit" class="btn">Save Preference</button> <button type="submit" class="btn">Save Preferences</button>
</div> </div>
</form> </form>
</div> </div>
...@@ -214,7 +211,7 @@ function renderNodesTable() { ...@@ -214,7 +211,7 @@ function renderNodesTable() {
} }
function openDriverModal(hostname, token, hostnameValue) { function openDriverModal(hostname, token, hostnameValue) {
// Find the node data to get available backends // Find the node data
const node = nodesData.find(n => n.hostname === hostname && n.token === token); const node = nodesData.find(n => n.hostname === hostname && n.token === token);
if (!node) { if (!node) {
console.error('Node not found:', hostname, token); console.error('Node not found:', hostname, token);
...@@ -225,30 +222,44 @@ function openDriverModal(hostname, token, hostnameValue) { ...@@ -225,30 +222,44 @@ function openDriverModal(hostname, token, hostnameValue) {
document.getElementById('modalHostnameInput').value = hostnameValue; document.getElementById('modalHostnameInput').value = hostnameValue;
document.getElementById('modalTokenInput').value = token; document.getElementById('modalTokenInput').value = token;
// Populate driver options based on available backends (GPU and CPU) // Populate worker driver selections
const driverSelect = document.getElementById('driverSelect'); const container = document.getElementById('workerDriverSelections');
driverSelect.innerHTML = ''; container.innerHTML = '';
const availableBackends = node.available_backends || []; // Get GPU-requiring workers (analysis_cuda, training_cuda, analysis_rocm, training_rocm)
if (availableBackends.length === 0) { const gpuWorkers = [];
// Fallback for nodes without backend info if (node.workers && node.workers.length > 0) {
if (node.is_cpu_only) { node.workers.forEach(worker => {
availableBackends.push('cpu'); if (worker.backend !== 'cpu') {
} else { gpuWorkers.push(worker);
availableBackends.push('cuda', 'rocm');
} }
});
} }
availableBackends.forEach(backend => { if (gpuWorkers.length === 0) {
const option = document.createElement('option'); container.innerHTML = '<p>No GPU-requiring workers found on this node.</p>';
option.value = backend; return;
option.textContent = backend.toUpperCase(); }
driverSelect.appendChild(option);
});
// Set CUDA as default if available, otherwise first available // Create form fields for each GPU worker
const defaultBackend = availableBackends.includes('cuda') ? 'cuda' : availableBackends[0]; gpuWorkers.forEach(worker => {
driverSelect.value = defaultBackend; const workerDiv = document.createElement('div');
workerDiv.className = 'form-group';
workerDiv.innerHTML = `
<label for="worker_${worker.type}_${worker.backend}_driver">${worker.type} (${worker.backend.toUpperCase()}) Worker:</label>
<select id="worker_${worker.type}_${worker.backend}_driver" name="worker_${worker.type}_${worker.backend}_driver">
<option value="cuda">CUDA</option>
<option value="rocm">ROCm</option>
<option value="cpu">CPU</option>
</select>
`;
container.appendChild(workerDiv);
// Set current preference or default
const select = workerDiv.querySelector('select');
// For now, default to the worker's current backend
select.value = worker.backend;
});
document.getElementById('driverModal').style.display = 'block'; document.getElementById('driverModal').style.display = 'block';
} }
......
...@@ -615,6 +615,29 @@ class ClusterMaster: ...@@ -615,6 +615,29 @@ class ClusterMaster:
)) ))
return True return True
def restart_client_worker(self, client_id: str, worker_name: str, backend: str) -> bool:
"""Restart a specific worker on a client with a different backend."""
if client_id not in self.client_websockets:
return False
if backend not in ['cuda', 'rocm', 'cpu']:
print(f"Invalid backend requested: {backend} - only CUDA, ROCm, and CPU supported")
return False
# Send restart command for specific worker to client
try:
asyncio.create_task(self.client_websockets[client_id].send(
json.dumps({
'type': 'restart_worker',
'worker_name': worker_name,
'backend': backend
})
))
return True
except Exception as e:
print(f"Failed to send restart command for worker {worker_name} to client {client_id}: {e}")
return False
async def _management_loop(self) -> None: async def _management_loop(self) -> None:
"""Main management loop.""" """Main management loop."""
while self.running: while self.running:
......
...@@ -360,17 +360,18 @@ def init_db(conn) -> None: ...@@ -360,17 +360,18 @@ def init_db(conn) -> None:
except sqlite3.OperationalError: except sqlite3.OperationalError:
pass pass
# Client driver preferences table # Client driver preferences table (per worker)
if config['type'] == 'mysql': if config['type'] == 'mysql':
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS client_driver_preferences ( CREATE TABLE IF NOT EXISTS client_driver_preferences (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
hostname VARCHAR(255) NOT NULL, hostname VARCHAR(255) NOT NULL,
token VARCHAR(255) NOT NULL, token VARCHAR(255) NOT NULL,
worker_name VARCHAR(255) NOT NULL,
driver VARCHAR(10) NOT NULL DEFAULT 'cuda', driver VARCHAR(10) NOT NULL DEFAULT 'cuda',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY unique_hostname_token (hostname, token) UNIQUE KEY unique_hostname_token_worker (hostname, token, worker_name)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
''') ''')
else: else:
...@@ -379,10 +380,11 @@ def init_db(conn) -> None: ...@@ -379,10 +380,11 @@ def init_db(conn) -> None:
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
hostname TEXT NOT NULL, hostname TEXT NOT NULL,
token TEXT NOT NULL, token TEXT NOT NULL,
worker_name TEXT NOT NULL,
driver TEXT NOT NULL DEFAULT 'cuda', driver TEXT NOT NULL DEFAULT 'cuda',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(hostname, token) UNIQUE(hostname, token, worker_name)
) )
''') ''')
...@@ -1521,20 +1523,25 @@ def cleanup_expired_sessions() -> None: ...@@ -1521,20 +1523,25 @@ def cleanup_expired_sessions() -> None:
conn.close() conn.close()
# Client driver preferences functions # Client driver preferences functions (per worker)
def get_client_driver_preference(hostname: str, token: str) -> str: def get_client_driver_preference(hostname: str, token: str, worker_name: str = None) -> str:
"""Get the preferred driver for a client (hostname + token).""" """Get the preferred driver for a client worker (hostname + token + worker_name)."""
conn = get_db_connection() conn = get_db_connection()
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT driver FROM client_driver_preferences WHERE hostname = ? AND token = ?', if worker_name:
cursor.execute('SELECT driver FROM client_driver_preferences WHERE hostname = ? AND token = ? AND worker_name = ?',
(hostname, token, worker_name))
else:
# Fallback for old format (per node)
cursor.execute('SELECT driver FROM client_driver_preferences WHERE hostname = ? AND token = ? AND worker_name = ""',
(hostname, token)) (hostname, token))
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
return row['driver'] if row else 'cuda' # Default to cuda return row['driver'] if row else 'cuda' # Default to cuda
def set_client_driver_preference(hostname: str, token: str, driver: str) -> bool: def set_client_driver_preference(hostname: str, token: str, driver: str, worker_name: str = None) -> bool:
"""Set the preferred driver for a client (hostname + token).""" """Set the preferred driver for a client worker (hostname + token + worker_name)."""
if driver not in ['cuda', 'rocm', 'cpu']: if driver not in ['cuda', 'rocm', 'cpu']:
return False return False
...@@ -1543,15 +1550,30 @@ def set_client_driver_preference(hostname: str, token: str, driver: str) -> bool ...@@ -1543,15 +1550,30 @@ def set_client_driver_preference(hostname: str, token: str, driver: str) -> bool
config = get_db_config() config = get_db_config()
if config['type'] == 'mysql': if config['type'] == 'mysql':
if worker_name:
cursor.execute('''
INSERT INTO client_driver_preferences (hostname, token, worker_name, driver, updated_at)
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE driver = VALUES(driver), updated_at = CURRENT_TIMESTAMP
''', (hostname, token, worker_name, driver))
else:
# Fallback for old format
cursor.execute(''' cursor.execute('''
INSERT INTO client_driver_preferences (hostname, token, driver, updated_at) INSERT INTO client_driver_preferences (hostname, token, worker_name, driver, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP) VALUES (?, ?, "", ?, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE driver = VALUES(driver), updated_at = CURRENT_TIMESTAMP ON DUPLICATE KEY UPDATE driver = VALUES(driver), updated_at = CURRENT_TIMESTAMP
''', (hostname, token, driver)) ''', (hostname, token, driver))
else: else:
if worker_name:
cursor.execute(''' cursor.execute('''
INSERT OR REPLACE INTO client_driver_preferences (hostname, token, driver, updated_at) INSERT OR REPLACE INTO client_driver_preferences (hostname, token, worker_name, driver, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
''', (hostname, token, worker_name, driver))
else:
# Fallback for old format
cursor.execute('''
INSERT OR REPLACE INTO client_driver_preferences (hostname, token, worker_name, driver, updated_at)
VALUES (?, ?, "", ?, CURRENT_TIMESTAMP)
''', (hostname, token, driver)) ''', (hostname, token, driver))
conn.commit() conn.commit()
...@@ -1560,6 +1582,17 @@ def set_client_driver_preference(hostname: str, token: str, driver: str) -> bool ...@@ -1560,6 +1582,17 @@ def set_client_driver_preference(hostname: str, token: str, driver: str) -> bool
return success return success
def get_all_client_driver_preferences(hostname: str, token: str) -> dict:
"""Get all driver preferences for workers on a client (hostname + token)."""
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute('SELECT worker_name, driver FROM client_driver_preferences WHERE hostname = ? AND token = ?',
(hostname, token))
rows = cursor.fetchall()
conn.close()
return {row['worker_name']: row['driver'] for row in rows}
# Admin dashboard stats functions # Admin dashboard stats functions
def get_admin_dashboard_stats() -> Dict[str, int]: def get_admin_dashboard_stats() -> Dict[str, int]:
"""Get admin dashboard statistics.""" """Get admin dashboard statistics."""
......
...@@ -635,45 +635,66 @@ def detect_local_workers(): ...@@ -635,45 +635,66 @@ def detect_local_workers():
@app.route('/api/admin/cluster_nodes/set_driver', methods=['POST']) @app.route('/api/admin/cluster_nodes/set_driver', methods=['POST'])
@admin_required @admin_required
def api_set_client_driver(): def api_set_client_driver():
"""API endpoint to set driver preference for a client or local workers.""" """API endpoint to set driver preferences for workers on a client or local workers."""
hostname = request.form.get('hostname') hostname = request.form.get('hostname')
token = request.form.get('token') token = request.form.get('token')
driver = request.form.get('driver')
if not hostname or not driver:
return {'success': False, 'error': 'Missing required parameters'}, 400
if driver not in ['cuda', 'rocm', 'cpu']: if not hostname:
return {'success': False, 'error': 'Invalid driver - only CUDA, ROCm, and CPU are supported'}, 400 return {'success': False, 'error': 'Missing hostname parameter'}, 400
# Handle local workers # Handle local workers
if token == 'local': if token == 'local':
# For local workers, expect a single driver for all workers
driver = request.form.get('driver')
if not driver or driver not in ['cuda', 'rocm', 'cpu']:
return {'success': False, 'error': 'Invalid driver - only CUDA, ROCm, and CPU are supported'}, 400
success = switch_local_worker_backends(driver) success = switch_local_worker_backends(driver)
return {'success': success} return {'success': success}
# Handle remote clients # Handle remote clients - expect worker-specific preferences
if not token: if not token:
return {'success': False, 'error': 'Missing token for remote client'}, 400 return {'success': False, 'error': 'Missing token for remote client'}, 400
from .database import set_client_driver_preference from .database import set_client_driver_preference, get_all_client_driver_preferences
from .cluster_master import cluster_master from .cluster_master import cluster_master
# Save preference to database # Parse worker driver preferences from form data
db_success = set_client_driver_preference(hostname, token, driver) worker_preferences = {}
for key, value in request.form.items():
if key.startswith('worker_') and key.endswith('_driver'):
worker_name = key[7:-7] # Remove 'worker_' prefix and '_driver' suffix
if value in ['cuda', 'rocm', 'cpu']:
worker_preferences[worker_name] = value
# Find client and send restart command if not worker_preferences:
return {'success': False, 'error': 'No valid worker driver preferences provided'}, 400
# Save preferences to database
db_success = True
for worker_name, driver in worker_preferences.items():
if not set_client_driver_preference(hostname, token, driver, worker_name):
db_success = False
# Find client and send restart command if connected
client_id = None client_id = None
for cid, client_info in cluster_master.clients.items(): for cid, client_info in cluster_master.clients.items():
if client_info['token'] == token: if client_info['token'] == token:
client_id = cid client_id = cid
break break
restart_success = True
if client_id: if client_id:
restart_success = cluster_master.restart_client_workers(client_id, driver) # Send restart command for each worker that changed
existing_prefs = get_all_client_driver_preferences(hostname, token)
for worker_name, new_driver in worker_preferences.items():
old_driver = existing_prefs.get(worker_name, 'cuda')
if old_driver != new_driver:
# Only restart if driver actually changed
worker_restart_success = cluster_master.restart_client_worker(client_id, worker_name, new_driver)
if not worker_restart_success:
restart_success = False
return {'success': db_success and restart_success} return {'success': db_success and restart_success}
else:
# Client not currently connected, just save preference
return {'success': db_success}
def switch_local_worker_backends(new_backend): def switch_local_worker_backends(new_backend):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment