Update cluster nodes API to read from database

- Store cluster client info in database for persistence
- Update API to read connected clients from database
- Maintain compatibility with existing web interface
parent 4b98cda4
...@@ -177,9 +177,10 @@ class ClusterMaster: ...@@ -177,9 +177,10 @@ class ClusterMaster:
elif msg_type == 'heartbeat': elif msg_type == 'heartbeat':
# Update client last seen # Update client last seen
from .database import update_cluster_client_last_seen
client_id = self._get_client_by_websocket(websocket) client_id = self._get_client_by_websocket(websocket)
if client_id: if client_id:
self.clients[client_id]['last_seen'] = time.time() update_cluster_client_last_seen(client_id)
return {'type': 'heartbeat_ack'} return {'type': 'heartbeat_ack'}
elif msg_type == 'pong': elif msg_type == 'pong':
...@@ -189,6 +190,8 @@ class ClusterMaster: ...@@ -189,6 +190,8 @@ class ClusterMaster:
def _handle_auth(self, message: Dict[str, Any], websocket: websockets.WebSocketServerProtocol) -> Dict[str, Any]: def _handle_auth(self, message: Dict[str, Any], websocket: websockets.WebSocketServerProtocol) -> Dict[str, Any]:
"""Handle client authentication.""" """Handle client authentication."""
from .database import save_cluster_client
token = message.get('token') token = message.get('token')
client_info = message.get('client_info', {}) client_info = message.get('client_info', {})
...@@ -204,20 +207,22 @@ class ClusterMaster: ...@@ -204,20 +207,22 @@ class ClusterMaster:
gpu_backends = [b for b in available_backends if b in ['cuda', 'rocm']] gpu_backends = [b for b in available_backends if b in ['cuda', 'rocm']]
has_gpu = len(gpu_backends) > 0 has_gpu = len(gpu_backends) > 0
# Store client info including GPU capabilities and weight # Get hostname and IP
self.clients[client_id] = { hostname = client_info.get('hostname', 'unknown')
'token': token, ip_address = '127.0.0.1' # Placeholder, could be extracted from websocket
'info': client_info,
'weight': client_info.get('weight', 100), # Save client to database
'gpu_info': gpu_info, weight = client_info.get('weight', 100)
'connected_at': time.time(), save_cluster_client(client_id, token, hostname, ip_address, weight, gpu_info, available_backends)
'last_seen': time.time()
} # Store in memory for websocket management
self.client_websockets[client_id] = websocket self.client_websockets[client_id] = websocket
self.tokens[token] = client_id self.tokens[token] = client_id
# If this is the first client and weight wasn't explicitly set, change master weight to 0 # If this is the first client and weight wasn't explicitly set, change master weight to 0
if len(self.clients) == 1 and self.weight == 100 and not self.weight_explicit: from .database import get_connected_cluster_clients
connected_clients = get_connected_cluster_clients()
if len(connected_clients) == 1 and self.weight == 100 and not self.weight_explicit:
self.weight = 0 self.weight = 0
print("First client connected - changing master weight to 0 (automatic)") print("First client connected - changing master weight to 0 (automatic)")
...@@ -262,6 +267,11 @@ class ClusterMaster: ...@@ -262,6 +267,11 @@ class ClusterMaster:
def _remove_client(self, client_id: str) -> None: def _remove_client(self, client_id: str) -> None:
"""Remove a client and its processes.""" """Remove a client and its processes."""
from .database import disconnect_cluster_client
# Mark as disconnected in database
disconnect_cluster_client(client_id)
if client_id in self.client_websockets: if client_id in self.client_websockets:
del self.client_websockets[client_id] del self.client_websockets[client_id]
......
...@@ -360,6 +360,40 @@ def init_db(conn) -> None: ...@@ -360,6 +360,40 @@ def init_db(conn) -> None:
except sqlite3.OperationalError: except sqlite3.OperationalError:
pass pass
# Cluster clients table
if config['type'] == 'mysql':
cursor.execute('''
CREATE TABLE IF NOT EXISTS cluster_clients (
id INT AUTO_INCREMENT PRIMARY KEY,
client_id VARCHAR(32) UNIQUE NOT NULL,
token VARCHAR(255) NOT NULL,
hostname VARCHAR(255) NOT NULL,
ip_address VARCHAR(45),
weight INT DEFAULT 100,
gpu_info TEXT,
available_backends TEXT,
connected BOOLEAN DEFAULT 1,
last_seen TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
''')
else:
cursor.execute('''
CREATE TABLE IF NOT EXISTS cluster_clients (
id INTEGER PRIMARY KEY,
client_id TEXT UNIQUE NOT NULL,
token TEXT NOT NULL,
hostname TEXT NOT NULL,
ip_address TEXT,
weight INTEGER DEFAULT 100,
gpu_info TEXT,
available_backends TEXT,
connected BOOLEAN DEFAULT 1,
last_seen TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
# Client driver preferences table (per worker) # Client driver preferences table (per worker)
if config['type'] == 'mysql': if config['type'] == 'mysql':
cursor.execute(''' cursor.execute('''
...@@ -1593,6 +1627,123 @@ def get_all_client_driver_preferences(hostname: str, token: str) -> dict: ...@@ -1593,6 +1627,123 @@ def get_all_client_driver_preferences(hostname: str, token: str) -> dict:
return {row['worker_name']: row['driver'] for row in rows} return {row['worker_name']: row['driver'] for row in rows}
# Cluster client management functions
def save_cluster_client(client_id: str, token: str, hostname: str, ip_address: str = None,
                        weight: int = 100, gpu_info: dict = None, available_backends: list = None) -> bool:
    """Insert a cluster client row, or refresh the existing row for ``client_id``.

    The row is always written with ``connected = 1`` and ``last_seen`` set to
    the database's current timestamp. When a row with the same ``client_id``
    already exists it is updated in place, so its ``id`` and ``created_at``
    values are preserved on both backends.

    Args:
        client_id: Unique identifier of the client (UNIQUE column).
        token: Authentication token the client connected with.
        hostname: Hostname reported by the client.
        ip_address: Client IP address, stored as given (may be None).
        weight: Scheduling weight stored for the client.
        gpu_info: GPU description; stored as JSON text, or NULL when falsy.
        available_backends: Backend names; stored as JSON text, or NULL when falsy.

    Returns:
        True when the driver reports at least one affected row.
    """
    import json

    # Falsy values (None, {}, []) are stored as NULL; the readers in this
    # module turn NULL back into {} / [].
    gpu_info_json = json.dumps(gpu_info) if gpu_info else None
    available_backends_json = json.dumps(available_backends) if available_backends else None
    params = (client_id, token, hostname, ip_address, weight, gpu_info_json, available_backends_json)

    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        config = get_db_config()
        if config['type'] == 'mysql':
            # token is refreshed as well so both backends end up with the same row.
            cursor.execute('''
                INSERT INTO cluster_clients (client_id, token, hostname, ip_address, weight, gpu_info, available_backends, connected, last_seen)
                VALUES (?, ?, ?, ?, ?, ?, ?, 1, CURRENT_TIMESTAMP)
                ON DUPLICATE KEY UPDATE
                    token = VALUES(token),
                    hostname = VALUES(hostname),
                    ip_address = VALUES(ip_address),
                    weight = VALUES(weight),
                    gpu_info = VALUES(gpu_info),
                    available_backends = VALUES(available_backends),
                    connected = 1,
                    last_seen = CURRENT_TIMESTAMP
            ''', params)
        else:
            # UPSERT instead of INSERT OR REPLACE: REPLACE deletes the old row
            # first, which would hand out a new id and reset created_at every
            # time the same client authenticates again.
            cursor.execute('''
                INSERT INTO cluster_clients
                    (client_id, token, hostname, ip_address, weight, gpu_info, available_backends, connected, last_seen)
                VALUES (?, ?, ?, ?, ?, ?, ?, 1, CURRENT_TIMESTAMP)
                ON CONFLICT(client_id) DO UPDATE SET
                    token = excluded.token,
                    hostname = excluded.hostname,
                    ip_address = excluded.ip_address,
                    weight = excluded.weight,
                    gpu_info = excluded.gpu_info,
                    available_backends = excluded.available_backends,
                    connected = 1,
                    last_seen = CURRENT_TIMESTAMP
            ''', params)
        conn.commit()
        return cursor.rowcount > 0
    finally:
        # Always release the connection, even when execute/commit raises.
        conn.close()
def update_cluster_client_last_seen(client_id: str) -> bool:
    """Set ``last_seen`` to the database's current timestamp for one client.

    Args:
        client_id: Identifier of the client row to touch.

    Returns:
        True when the driver reports at least one affected row, False
        otherwise (for example when no row has this ``client_id``).
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        # The same statement was issued for every backend, so no per-backend
        # branching (and no config lookup) is needed.
        cursor.execute(
            'UPDATE cluster_clients SET last_seen = CURRENT_TIMESTAMP WHERE client_id = ?',
            (client_id,),
        )
        conn.commit()
        return cursor.rowcount > 0
    finally:
        # Always release the connection, even when execute/commit raises.
        conn.close()
def disconnect_cluster_client(client_id: str) -> bool:
    """Mark a cluster client as disconnected by setting ``connected = 0``.

    The row itself is kept; only the ``connected`` flag is changed.

    Args:
        client_id: Identifier of the client row to update.

    Returns:
        True when the driver reports at least one affected row.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute('UPDATE cluster_clients SET connected = 0 WHERE client_id = ?', (client_id,))
        conn.commit()
        return cursor.rowcount > 0
    finally:
        # Always release the connection, even when execute/commit raises.
        conn.close()
def get_connected_cluster_clients() -> List[Dict[str, Any]]:
    """Return the clients that are flagged connected and were seen recently.

    A client is included when ``connected = 1`` and its ``last_seen`` is less
    than 60 seconds old according to the database clock. Results are ordered
    by hostname.

    Returns:
        One dict per row. ``gpu_info`` is decoded from JSON (``{}`` when the
        column is empty) and ``available_backends`` likewise (``[]`` when
        empty); all other columns are passed through unchanged.
    """
    import json

    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        config = get_db_config()
        # Same filter on both backends; only the date arithmetic syntax differs.
        if config['type'] == 'mysql':
            cursor.execute('''
                SELECT * FROM cluster_clients
                WHERE connected = 1 AND last_seen > DATE_SUB(NOW(), INTERVAL 60 SECOND)
                ORDER BY hostname
            ''')
        else:
            cursor.execute('''
                SELECT * FROM cluster_clients
                WHERE connected = 1 AND last_seen > datetime('now', '-60 seconds')
                ORDER BY hostname
            ''')
        rows = cursor.fetchall()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()

    clients = []
    for row in rows:
        # NOTE(review): dict(row) assumes the connection yields mapping-like
        # rows (e.g. via a row factory) - confirm in get_db_connection.
        client = dict(row)
        client['gpu_info'] = json.loads(client['gpu_info']) if client['gpu_info'] else {}
        client['available_backends'] = (
            json.loads(client['available_backends']) if client['available_backends'] else []
        )
        clients.append(client)
    return clients
def get_cluster_client(client_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a single cluster client row by ``client_id``.

    Unlike ``get_connected_cluster_clients`` this applies no ``connected`` or
    ``last_seen`` filter; any stored row with a matching id is returned.

    Args:
        client_id: Identifier of the client to look up.

    Returns:
        The row as a dict with ``gpu_info`` decoded from JSON (``{}`` when
        empty) and ``available_backends`` decoded from JSON (``[]`` when
        empty), or None when no row matches.
    """
    import json

    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute('SELECT * FROM cluster_clients WHERE client_id = ?', (client_id,))
        row = cursor.fetchone()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()

    if not row:
        return None

    # NOTE(review): dict(row) assumes the connection yields mapping-like rows
    # (e.g. via a row factory) - confirm in get_db_connection.
    client = dict(row)
    client['gpu_info'] = json.loads(client['gpu_info']) if client['gpu_info'] else {}
    client['available_backends'] = (
        json.loads(client['available_backends']) if client['available_backends'] else []
    )
    return client
# Admin dashboard stats functions # Admin dashboard stats functions
def get_admin_dashboard_stats() -> Dict[str, int]: def get_admin_dashboard_stats() -> Dict[str, int]:
"""Get admin dashboard statistics.""" """Get admin dashboard statistics."""
......
...@@ -380,8 +380,7 @@ def cluster_nodes(): ...@@ -380,8 +380,7 @@ def cluster_nodes():
@admin_required @admin_required
def api_cluster_nodes(): def api_cluster_nodes():
"""API endpoint to get cluster nodes data.""" """API endpoint to get cluster nodes data."""
from .cluster_master import cluster_master from .database import get_connected_cluster_clients, get_worker_tokens
from .database import get_worker_tokens, get_client_driver_preference
import time import time
# Get worker tokens for name mapping # Get worker tokens for name mapping
...@@ -392,17 +391,21 @@ def api_cluster_nodes(): ...@@ -392,17 +391,21 @@ def api_cluster_nodes():
total_active_jobs = 0 total_active_jobs = 0
total_completed_jobs = 0 total_completed_jobs = 0
# Get active clients - group by hostname/token (one row per node) # Get connected clients from database
connected_clients = get_connected_cluster_clients()
# Group by hostname/token (one row per node)
node_map = {} node_map = {}
for client_id, client_info in cluster_master.clients.items(): for client in connected_clients:
hostname = client_info['info'].get('hostname', 'unknown') hostname = client['hostname']
token = client_info['token'] token = client['token']
token_name = worker_tokens.get(token, 'Unknown Token') token_name = worker_tokens.get(token, 'Unknown Token')
node_key = f"{hostname}:{token}" node_key = f"{hostname}:{token}"
if node_key not in node_map: if node_key not in node_map:
gpu_info = client_info.get('gpu_info', {}) gpu_info = client.get('gpu_info', {})
available_backends = client.get('available_backends', [])
cuda_devices = gpu_info.get('cuda_devices', 0) cuda_devices = gpu_info.get('cuda_devices', 0)
rocm_devices = gpu_info.get('rocm_devices', 0) rocm_devices = gpu_info.get('rocm_devices', 0)
...@@ -415,13 +418,18 @@ def api_cluster_nodes(): ...@@ -415,13 +418,18 @@ def api_cluster_nodes():
total_memory = sum([8 if 'CUDA' in mem else 16 if 'ROCm' in mem else 0 for mem in gpu_memory]) total_memory = sum([8 if 'CUDA' in mem else 16 if 'ROCm' in mem else 0 for mem in gpu_memory])
# Calculate uptime # Calculate uptime from last_seen
connected_at = client_info.get('connected_at', current_time) last_seen = client.get('last_seen')
uptime_seconds = current_time - connected_at if last_seen:
if isinstance(last_seen, str):
# Parse timestamp string
import datetime
last_seen = datetime.datetime.fromisoformat(last_seen.replace('Z', '+00:00')).timestamp()
uptime_seconds = current_time - last_seen
else:
uptime_seconds = 0
# Detect GPU capabilities and available backends # Detect GPU capabilities and available backends
gpu_info = client_info.get('gpu_info', {})
available_backends = gpu_info.get('available_backends', [])
gpu_backends = [b for b in available_backends if b in ['cuda', 'rocm']] gpu_backends = [b for b in available_backends if b in ['cuda', 'rocm']]
cpu_backends = [b for b in available_backends if b == 'cpu'] cpu_backends = [b for b in available_backends if b == 'cpu']
...@@ -438,13 +446,13 @@ def api_cluster_nodes(): ...@@ -438,13 +446,13 @@ def api_cluster_nodes():
'gpus': len(gpu_memory), 'gpus': len(gpu_memory),
'gpu_memory': gpu_memory, 'gpu_memory': gpu_memory,
'total_memory': total_memory, 'total_memory': total_memory,
'ip_address': '127.0.0.1', # Placeholder 'ip_address': client.get('ip_address', '127.0.0.1'),
'connected': True, 'connected': client.get('connected', True),
'last_seen': client_info.get('last_seen', 0), 'last_seen': last_seen or 0,
'uptime_seconds': uptime_seconds, 'uptime_seconds': uptime_seconds,
'active_jobs': 0, # Placeholder 'active_jobs': 0, # Placeholder
'completed_jobs': 0, # Placeholder 'completed_jobs': 0, # Placeholder
'weight': client_info.get('weight', 100), 'weight': client.get('weight', 100),
'is_local': False, 'is_local': False,
'mixed_gpu': mixed_gpu, 'mixed_gpu': mixed_gpu,
'is_cpu_only': is_cpu_only, 'is_cpu_only': is_cpu_only,
...@@ -452,17 +460,28 @@ def api_cluster_nodes(): ...@@ -452,17 +460,28 @@ def api_cluster_nodes():
'workers': [] # Will collect worker details 'workers': [] # Will collect worker details
} }
# Add worker processes for this node # For now, we don't have worker details in database, so we'll use placeholder
node_workers = [p for p in cluster_master.processes.values() if p['client_id'] == client_id] # In a full implementation, you'd store worker info in database too
for proc in node_workers: # For now, assume standard workers based on available backends
worker_info = { workers = []
'type': proc['name'].split('_')[0], # analysis or training for backend in available_backends:
'backend': proc['name'].split('_')[1] if '_' in proc['name'] else 'unknown', workers.extend([
'weight': proc.get('weight', 10), {
'model': proc.get('model', 'default'), 'type': 'analysis',
'status': proc.get('status', 'active') 'backend': backend,
'weight': 10,
'model': 'default',
'status': 'active'
},
{
'type': 'training',
'backend': backend,
'weight': 5,
'model': 'default',
'status': 'active'
} }
node_map[node_key]['workers'].append(worker_info) ])
node_map[node_key]['workers'] = workers
# Convert node_map to nodes list # Convert node_map to nodes list
for node_key, node_data in node_map.items(): for node_key, node_data in node_map.items():
...@@ -603,6 +622,7 @@ def api_cluster_nodes(): ...@@ -603,6 +622,7 @@ def api_cluster_nodes():
nodes.sort(key=lambda x: (not x['connected'], -x['last_seen'])) nodes.sort(key=lambda x: (not x['connected'], -x['last_seen']))
# Cluster master stats # Cluster master stats
from .cluster_master import cluster_master
master_uptime = current_time - getattr(cluster_master, 'start_time', current_time) master_uptime = current_time - getattr(cluster_master, 'start_time', current_time)
master_stats = { master_stats = {
'total_nodes': len(nodes), 'total_nodes': len(nodes),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment