Implement per-worker driver selection modal

- Modified modal to show individual GPU-requiring workers on each node
- Allow granular driver selection (CUDA/ROCm/CPU) for each worker subprocess
- Updated database schema to store driver preferences per worker (hostname + token + worker_name)
- Enhanced API to handle per-worker driver setting with form field parsing
- Added restart_client_worker method to cluster master for individual worker restarts
- Frontend now displays worker-specific driver selection controls in modal
- Maintains node-level table view while providing worker-level configuration
- Supports CPU-only nodes and mixed GPU/CPU worker configurations
- Backward compatible with existing single-driver preference system
parent 5cbdab26
......@@ -121,20 +121,17 @@
<div class="modal-content">
<div class="modal-header">
<span class="close" onclick="closeModal()">&times;</span>
<h3>Select Driver for <span id="modalHostname"></span></h3>
<h3>Select Drivers for Workers on <span id="modalHostname"></span></h3>
</div>
<form id="driverForm">
<input type="hidden" id="modalHostnameInput" name="hostname">
<input type="hidden" id="modalTokenInput" name="token">
<div class="form-group">
<label for="driverSelect">Preferred GPU Driver:</label>
<select id="driverSelect" name="driver">
<!-- Options will be populated dynamically -->
</select>
<div id="workerDriverSelections">
<!-- Worker driver selections will be populated dynamically -->
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" onclick="closeModal()">Cancel</button>
<button type="submit" class="btn">Save Preference</button>
<button type="submit" class="btn">Save Preferences</button>
</div>
</form>
</div>
......@@ -214,7 +211,7 @@ function renderNodesTable() {
}
function openDriverModal(hostname, token, hostnameValue) {
// Find the node data to get available backends
// Find the node data
const node = nodesData.find(n => n.hostname === hostname && n.token === token);
if (!node) {
console.error('Node not found:', hostname, token);
......@@ -225,30 +222,44 @@ function openDriverModal(hostname, token, hostnameValue) {
document.getElementById('modalHostnameInput').value = hostnameValue;
document.getElementById('modalTokenInput').value = token;
// Populate driver options based on available backends (GPU and CPU)
const driverSelect = document.getElementById('driverSelect');
driverSelect.innerHTML = '';
// Populate worker driver selections
const container = document.getElementById('workerDriverSelections');
container.innerHTML = '';
const availableBackends = node.available_backends || [];
if (availableBackends.length === 0) {
// Fallback for nodes without backend info
if (node.is_cpu_only) {
availableBackends.push('cpu');
} else {
availableBackends.push('cuda', 'rocm');
// Get GPU-requiring workers (analysis_cuda, training_cuda, analysis_rocm, training_rocm)
const gpuWorkers = [];
if (node.workers && node.workers.length > 0) {
node.workers.forEach(worker => {
if (worker.backend !== 'cpu') {
gpuWorkers.push(worker);
}
});
}
availableBackends.forEach(backend => {
const option = document.createElement('option');
option.value = backend;
option.textContent = backend.toUpperCase();
driverSelect.appendChild(option);
});
if (gpuWorkers.length === 0) {
container.innerHTML = '<p>No GPU-requiring workers found on this node.</p>';
return;
}
// Set CUDA as default if available, otherwise first available
const defaultBackend = availableBackends.includes('cuda') ? 'cuda' : availableBackends[0];
driverSelect.value = defaultBackend;
// Create form fields for each GPU worker
gpuWorkers.forEach(worker => {
const workerDiv = document.createElement('div');
workerDiv.className = 'form-group';
workerDiv.innerHTML = `
<label for="worker_${worker.type}_${worker.backend}_driver">${worker.type} (${worker.backend.toUpperCase()}) Worker:</label>
<select id="worker_${worker.type}_${worker.backend}_driver" name="worker_${worker.type}_${worker.backend}_driver">
<option value="cuda">CUDA</option>
<option value="rocm">ROCm</option>
<option value="cpu">CPU</option>
</select>
`;
container.appendChild(workerDiv);
// Set current preference or default
const select = workerDiv.querySelector('select');
// For now, default to the worker's current backend
select.value = worker.backend;
});
document.getElementById('driverModal').style.display = 'block';
}
......
......@@ -615,6 +615,29 @@ class ClusterMaster:
))
return True
def restart_client_worker(self, client_id: str, worker_name: str, backend: str) -> bool:
    """Ask a connected client to restart one worker on a different backend.

    The command is scheduled on the running event loop and sent
    asynchronously; this method does not wait for the send to finish.

    Args:
        client_id: Key into ``self.client_websockets`` identifying the client.
        worker_name: Name of the worker the client should restart.
        backend: Target backend; must be ``'cuda'``, ``'rocm'`` or ``'cpu'``.

    Returns:
        True if the restart message was scheduled for sending. False if the
        client has no websocket, the backend is invalid, or the send could
        not be scheduled (for example, no event loop is running).
    """
    if client_id not in self.client_websockets:
        return False

    if backend not in ('cuda', 'rocm', 'cpu'):
        print(f"Invalid backend requested: {backend} - only CUDA, ROCm, and CPU supported")
        return False

    def _report_send_failure(task) -> None:
        # The send runs after this method has returned, so the try/except
        # below cannot see its errors; report them from the finished task.
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            print(f"Failed to send restart command for worker {worker_name} to client {client_id}: {error}")

    try:
        # Resolve the loop before building the coroutine so that, when no
        # loop is running, no un-awaited coroutine object is left behind.
        loop = asyncio.get_running_loop()
        payload = json.dumps({
            'type': 'restart_worker',
            'worker_name': worker_name,
            'backend': backend
        })
        send_task = loop.create_task(self.client_websockets[client_id].send(payload))
    except Exception as e:
        print(f"Failed to send restart command for worker {worker_name} to client {client_id}: {e}")
        return False

    send_task.add_done_callback(_report_send_failure)
    return True
async def _management_loop(self) -> None:
"""Main management loop."""
while self.running:
......
......@@ -360,17 +360,18 @@ def init_db(conn) -> None:
except sqlite3.OperationalError:
pass
# Client driver preferences table
# Client driver preferences table (per worker)
if config['type'] == 'mysql':
cursor.execute('''
CREATE TABLE IF NOT EXISTS client_driver_preferences (
id INT AUTO_INCREMENT PRIMARY KEY,
hostname VARCHAR(255) NOT NULL,
token VARCHAR(255) NOT NULL,
worker_name VARCHAR(255) NOT NULL,
driver VARCHAR(10) NOT NULL DEFAULT 'cuda',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY unique_hostname_token (hostname, token)
UNIQUE KEY unique_hostname_token_worker (hostname, token, worker_name)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
''')
else:
......@@ -379,10 +380,11 @@ def init_db(conn) -> None:
id INTEGER PRIMARY KEY,
hostname TEXT NOT NULL,
token TEXT NOT NULL,
worker_name TEXT NOT NULL,
driver TEXT NOT NULL DEFAULT 'cuda',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(hostname, token)
UNIQUE(hostname, token, worker_name)
)
''')
......@@ -1521,20 +1523,25 @@ def cleanup_expired_sessions() -> None:
conn.close()
# Client driver preferences functions
def get_client_driver_preference(hostname: str, token: str, worker_name: str = None) -> str:
    """Get the preferred driver for a client worker (hostname + token + worker_name).

    Args:
        hostname: Hostname of the client node.
        token: Token identifying the client.
        worker_name: Worker whose preference is wanted. When omitted (or
            empty) the node-level row, stored with an empty worker name,
            is looked up instead.

    Returns:
        The stored driver string, or ``'cuda'`` when no matching row exists.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        # Bind the empty worker name as a parameter rather than writing "" in
        # the SQL text: a double-quoted token is an identifier in standard
        # SQL, so the literal form is not portable across database engines.
        cursor.execute(
            'SELECT driver FROM client_driver_preferences WHERE hostname = ? AND token = ? AND worker_name = ?',
            (hostname, token, worker_name or ''))
        row = cursor.fetchone()
    finally:
        # Release the connection even if the query raises.
        conn.close()
    return row['driver'] if row else 'cuda'  # Default to cuda
def set_client_driver_preference(hostname: str, token: str, driver: str) -> bool:
"""Set the preferred driver for a client (hostname + token)."""
def set_client_driver_preference(hostname: str, token: str, driver: str, worker_name: str = None) -> bool:
"""Set the preferred driver for a client worker (hostname + token + worker_name)."""
if driver not in ['cuda', 'rocm', 'cpu']:
return False
......@@ -1543,15 +1550,30 @@ def set_client_driver_preference(hostname: str, token: str, driver: str) -> bool
config = get_db_config()
if config['type'] == 'mysql':
if worker_name:
cursor.execute('''
INSERT INTO client_driver_preferences (hostname, token, worker_name, driver, updated_at)
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE driver = VALUES(driver), updated_at = CURRENT_TIMESTAMP
''', (hostname, token, worker_name, driver))
else:
# Fallback for old format
cursor.execute('''
INSERT INTO client_driver_preferences (hostname, token, driver, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
INSERT INTO client_driver_preferences (hostname, token, worker_name, driver, updated_at)
VALUES (?, ?, "", ?, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE driver = VALUES(driver), updated_at = CURRENT_TIMESTAMP
''', (hostname, token, driver))
else:
if worker_name:
cursor.execute('''
INSERT OR REPLACE INTO client_driver_preferences (hostname, token, driver, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
INSERT OR REPLACE INTO client_driver_preferences (hostname, token, worker_name, driver, updated_at)
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
''', (hostname, token, worker_name, driver))
else:
# Fallback for old format
cursor.execute('''
INSERT OR REPLACE INTO client_driver_preferences (hostname, token, worker_name, driver, updated_at)
VALUES (?, ?, "", ?, CURRENT_TIMESTAMP)
''', (hostname, token, driver))
conn.commit()
......@@ -1560,6 +1582,17 @@ def set_client_driver_preference(hostname: str, token: str, driver: str) -> bool
return success
def get_all_client_driver_preferences(hostname: str, token: str) -> dict:
    """Get all driver preferences for workers on a client (hostname + token).

    Args:
        hostname: Hostname of the client node.
        token: Token identifying the client.

    Returns:
        A dict mapping each stored ``worker_name`` to its ``driver``; empty
        when the client has no stored preferences.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute('SELECT worker_name, driver FROM client_driver_preferences WHERE hostname = ? AND token = ?',
                       (hostname, token))
        rows = cursor.fetchall()
    finally:
        # Release the connection even if the query raises.
        conn.close()
    return {row['worker_name']: row['driver'] for row in rows}
# Admin dashboard stats functions
def get_admin_dashboard_stats() -> Dict[str, int]:
"""Get admin dashboard statistics."""
......
......@@ -635,45 +635,66 @@ def detect_local_workers():
@app.route('/api/admin/cluster_nodes/set_driver', methods=['POST'])
@admin_required
def api_set_client_driver():
    """API endpoint to set driver preferences for workers on a client or local workers.

    Form fields:
        hostname: Required. Hostname of the target node.
        token: ``'local'`` for local workers, otherwise the client's token.
        driver: Local workers only; one driver applied to all of them.
        worker_<name>_driver: Remote clients only; one field per worker,
            each holding ``cuda``, ``rocm`` or ``cpu``.

    Returns:
        A ``{'success': bool}`` dict, or a dict with ``error`` and a 400
        status when required input is missing or invalid.
    """
    valid_drivers = ('cuda', 'rocm', 'cpu')
    field_prefix = 'worker_'
    field_suffix = '_driver'

    hostname = request.form.get('hostname')
    token = request.form.get('token')

    if not hostname:
        return {'success': False, 'error': 'Missing hostname parameter'}, 400

    # Handle local workers
    if token == 'local':
        # For local workers, expect a single driver for all workers
        driver = request.form.get('driver')
        if driver not in valid_drivers:
            return {'success': False, 'error': 'Invalid driver - only CUDA, ROCm, and CPU are supported'}, 400

        success = switch_local_worker_backends(driver)
        return {'success': success}

    # Handle remote clients - expect worker-specific preferences
    if not token:
        return {'success': False, 'error': 'Missing token for remote client'}, 400

    from .database import set_client_driver_preference, get_all_client_driver_preferences
    from .cluster_master import cluster_master

    # Parse worker driver preferences from form data
    worker_preferences = {}
    for key, value in request.form.items():
        if not (key.startswith(field_prefix) and key.endswith(field_suffix)):
            continue
        worker_name = key[len(field_prefix):-len(field_suffix)]
        # A field such as "worker_driver" matches both affixes but yields an
        # empty name, which would overwrite the node-level fallback row.
        if worker_name and value in valid_drivers:
            worker_preferences[worker_name] = value

    if not worker_preferences:
        return {'success': False, 'error': 'No valid worker driver preferences provided'}, 400

    # Snapshot the stored preferences BEFORE saving. Reading them after the
    # save would always return the new values, so no change would ever be
    # detected and no worker would ever be restarted.
    previous_prefs = get_all_client_driver_preferences(hostname, token)

    # Save preferences to database
    db_success = True
    for worker_name, driver in worker_preferences.items():
        if not set_client_driver_preference(hostname, token, driver, worker_name):
            db_success = False

    # Find client and send restart command if connected
    client_id = None
    for cid, client_info in cluster_master.clients.items():
        if client_info['token'] == token:
            client_id = cid
            break

    if not client_id:
        # Client not currently connected, just save preference
        return {'success': db_success}

    # Send restart command for each worker whose driver actually changed
    restart_success = True
    for worker_name, new_driver in worker_preferences.items():
        old_driver = previous_prefs.get(worker_name, 'cuda')
        if old_driver == new_driver:
            continue
        if not cluster_master.restart_client_worker(client_id, worker_name, new_driver):
            restart_success = False

    return {'success': db_success and restart_success}
def switch_local_worker_backends(new_backend):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment