Update cluster nodes API to read from database

- Store cluster client info in database for persistence
- Update API to read connected clients from database
- Maintain compatibility with existing web interface
parent 4b98cda4
......@@ -177,9 +177,10 @@ class ClusterMaster:
elif msg_type == 'heartbeat':
# Update client last seen
from .database import update_cluster_client_last_seen
client_id = self._get_client_by_websocket(websocket)
if client_id:
self.clients[client_id]['last_seen'] = time.time()
update_cluster_client_last_seen(client_id)
return {'type': 'heartbeat_ack'}
elif msg_type == 'pong':
......@@ -189,6 +190,8 @@ class ClusterMaster:
def _handle_auth(self, message: Dict[str, Any], websocket: websockets.WebSocketServerProtocol) -> Dict[str, Any]:
"""Handle client authentication."""
from .database import save_cluster_client
token = message.get('token')
client_info = message.get('client_info', {})
......@@ -204,20 +207,22 @@ class ClusterMaster:
gpu_backends = [b for b in available_backends if b in ['cuda', 'rocm']]
has_gpu = len(gpu_backends) > 0
# Store client info including GPU capabilities and weight
self.clients[client_id] = {
'token': token,
'info': client_info,
'weight': client_info.get('weight', 100),
'gpu_info': gpu_info,
'connected_at': time.time(),
'last_seen': time.time()
}
# Get hostname and IP
hostname = client_info.get('hostname', 'unknown')
ip_address = '127.0.0.1' # Placeholder, could be extracted from websocket
# Save client to database
weight = client_info.get('weight', 100)
save_cluster_client(client_id, token, hostname, ip_address, weight, gpu_info, available_backends)
# Store in memory for websocket management
self.client_websockets[client_id] = websocket
self.tokens[token] = client_id
# If this is the first client and weight wasn't explicitly set, change master weight to 0
if len(self.clients) == 1 and self.weight == 100 and not self.weight_explicit:
from .database import get_connected_cluster_clients
connected_clients = get_connected_cluster_clients()
if len(connected_clients) == 1 and self.weight == 100 and not self.weight_explicit:
self.weight = 0
print("First client connected - changing master weight to 0 (automatic)")
......@@ -262,6 +267,11 @@ class ClusterMaster:
def _remove_client(self, client_id: str) -> None:
"""Remove a client and its processes."""
from .database import disconnect_cluster_client
# Mark as disconnected in database
disconnect_cluster_client(client_id)
if client_id in self.client_websockets:
del self.client_websockets[client_id]
......
......@@ -360,6 +360,40 @@ def init_db(conn) -> None:
except sqlite3.OperationalError:
pass
# Cluster clients table
if config['type'] == 'mysql':
cursor.execute('''
CREATE TABLE IF NOT EXISTS cluster_clients (
id INT AUTO_INCREMENT PRIMARY KEY,
client_id VARCHAR(32) UNIQUE NOT NULL,
token VARCHAR(255) NOT NULL,
hostname VARCHAR(255) NOT NULL,
ip_address VARCHAR(45),
weight INT DEFAULT 100,
gpu_info TEXT,
available_backends TEXT,
connected BOOLEAN DEFAULT 1,
last_seen TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
''')
else:
cursor.execute('''
CREATE TABLE IF NOT EXISTS cluster_clients (
id INTEGER PRIMARY KEY,
client_id TEXT UNIQUE NOT NULL,
token TEXT NOT NULL,
hostname TEXT NOT NULL,
ip_address TEXT,
weight INTEGER DEFAULT 100,
gpu_info TEXT,
available_backends TEXT,
connected BOOLEAN DEFAULT 1,
last_seen TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
# Client driver preferences table (per worker)
if config['type'] == 'mysql':
cursor.execute('''
......@@ -1593,6 +1627,123 @@ def get_all_client_driver_preferences(hostname: str, token: str) -> dict:
return {row['worker_name']: row['driver'] for row in rows}
# Cluster client management functions
def save_cluster_client(client_id: str, token: str, hostname: str, ip_address: str = None,
weight: int = 100, gpu_info: dict = None, available_backends: list = None) -> bool:
"""Save or update a cluster client in the database."""
import json
conn = get_db_connection()
cursor = conn.cursor()
gpu_info_json = json.dumps(gpu_info) if gpu_info else None
available_backends_json = json.dumps(available_backends) if available_backends else None
config = get_db_config()
if config['type'] == 'mysql':
cursor.execute('''
INSERT INTO cluster_clients (client_id, token, hostname, ip_address, weight, gpu_info, available_backends, connected, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?, 1, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE
hostname = VALUES(hostname),
ip_address = VALUES(ip_address),
weight = VALUES(weight),
gpu_info = VALUES(gpu_info),
available_backends = VALUES(available_backends),
connected = 1,
last_seen = CURRENT_TIMESTAMP
''', (client_id, token, hostname, ip_address, weight, gpu_info_json, available_backends_json))
else:
cursor.execute('''
INSERT OR REPLACE INTO cluster_clients
(client_id, token, hostname, ip_address, weight, gpu_info, available_backends, connected, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?, 1, CURRENT_TIMESTAMP)
''', (client_id, token, hostname, ip_address, weight, gpu_info_json, available_backends_json))
conn.commit()
success = cursor.rowcount > 0
conn.close()
return success
def update_cluster_client_last_seen(client_id: str) -> bool:
"""Update the last seen timestamp for a cluster client."""
conn = get_db_connection()
cursor = conn.cursor()
config = get_db_config()
if config['type'] == 'mysql':
cursor.execute('UPDATE cluster_clients SET last_seen = CURRENT_TIMESTAMP WHERE client_id = ?', (client_id,))
else:
cursor.execute('UPDATE cluster_clients SET last_seen = CURRENT_TIMESTAMP WHERE client_id = ?', (client_id,))
conn.commit()
success = cursor.rowcount > 0
conn.close()
return success
def disconnect_cluster_client(client_id: str) -> bool:
"""Mark a cluster client as disconnected."""
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute('UPDATE cluster_clients SET connected = 0 WHERE client_id = ?', (client_id,))
conn.commit()
success = cursor.rowcount > 0
conn.close()
return success
def get_connected_cluster_clients() -> List[Dict[str, Any]]:
"""Get all connected cluster clients."""
conn = get_db_connection()
cursor = conn.cursor()
config = get_db_config()
# Get clients that are connected and seen within the last 60 seconds
if config['type'] == 'mysql':
cursor.execute('''
SELECT * FROM cluster_clients
WHERE connected = 1 AND last_seen > DATE_SUB(NOW(), INTERVAL 60 SECOND)
ORDER BY hostname
''')
else:
cursor.execute('''
SELECT * FROM cluster_clients
WHERE connected = 1 AND last_seen > datetime('now', '-60 seconds')
ORDER BY hostname
''')
rows = cursor.fetchall()
conn.close()
clients = []
import json
for row in rows:
client = dict(row)
client['gpu_info'] = json.loads(client['gpu_info']) if client['gpu_info'] else {}
client['available_backends'] = json.loads(client['available_backends']) if client['available_backends'] else []
clients.append(client)
return clients
def get_cluster_client(client_id: str) -> Optional[Dict[str, Any]]:
"""Get a specific cluster client by client_id."""
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute('SELECT * FROM cluster_clients WHERE client_id = ?', (client_id,))
row = cursor.fetchone()
conn.close()
if row:
client = dict(row)
import json
client['gpu_info'] = json.loads(client['gpu_info']) if client['gpu_info'] else {}
client['available_backends'] = json.loads(client['available_backends']) if client['available_backends'] else []
return client
return None
# Admin dashboard stats functions
def get_admin_dashboard_stats() -> Dict[str, int]:
"""Get admin dashboard statistics."""
......
......@@ -380,8 +380,7 @@ def cluster_nodes():
@admin_required
def api_cluster_nodes():
"""API endpoint to get cluster nodes data."""
from .cluster_master import cluster_master
from .database import get_worker_tokens, get_client_driver_preference
from .database import get_connected_cluster_clients, get_worker_tokens
import time
# Get worker tokens for name mapping
......@@ -392,17 +391,21 @@ def api_cluster_nodes():
total_active_jobs = 0
total_completed_jobs = 0
# Get active clients - group by hostname/token (one row per node)
# Get connected clients from database
connected_clients = get_connected_cluster_clients()
# Group by hostname/token (one row per node)
node_map = {}
for client_id, client_info in cluster_master.clients.items():
hostname = client_info['info'].get('hostname', 'unknown')
token = client_info['token']
for client in connected_clients:
hostname = client['hostname']
token = client['token']
token_name = worker_tokens.get(token, 'Unknown Token')
node_key = f"{hostname}:{token}"
if node_key not in node_map:
gpu_info = client_info.get('gpu_info', {})
gpu_info = client.get('gpu_info', {})
available_backends = client.get('available_backends', [])
cuda_devices = gpu_info.get('cuda_devices', 0)
rocm_devices = gpu_info.get('rocm_devices', 0)
......@@ -415,22 +418,27 @@ def api_cluster_nodes():
total_memory = sum([8 if 'CUDA' in mem else 16 if 'ROCm' in mem else 0 for mem in gpu_memory])
# Calculate uptime
connected_at = client_info.get('connected_at', current_time)
uptime_seconds = current_time - connected_at
# Calculate uptime from last_seen
last_seen = client.get('last_seen')
if last_seen:
if isinstance(last_seen, str):
# Parse timestamp string
import datetime
last_seen = datetime.datetime.fromisoformat(last_seen.replace('Z', '+00:00')).timestamp()
uptime_seconds = current_time - last_seen
else:
uptime_seconds = 0
# Detect GPU capabilities and available backends
gpu_info = client_info.get('gpu_info', {})
available_backends = gpu_info.get('available_backends', [])
gpu_backends = [b for b in available_backends if b in ['cuda', 'rocm']]
cpu_backends = [b for b in available_backends if b == 'cpu']
has_cuda = 'cuda' in gpu_backends
has_rocm = 'rocm' in gpu_backends
has_cpu = len(cpu_backends) > 0
mixed_gpu = has_cuda and has_rocm
is_cpu_only = not gpu_backends and has_cpu
node_map[node_key] = {
'token': token,
'token_name': token_name,
......@@ -438,13 +446,13 @@ def api_cluster_nodes():
'gpus': len(gpu_memory),
'gpu_memory': gpu_memory,
'total_memory': total_memory,
'ip_address': '127.0.0.1', # Placeholder
'connected': True,
'last_seen': client_info.get('last_seen', 0),
'ip_address': client.get('ip_address', '127.0.0.1'),
'connected': client.get('connected', True),
'last_seen': last_seen or 0,
'uptime_seconds': uptime_seconds,
'active_jobs': 0, # Placeholder
'completed_jobs': 0, # Placeholder
'weight': client_info.get('weight', 100),
'weight': client.get('weight', 100),
'is_local': False,
'mixed_gpu': mixed_gpu,
'is_cpu_only': is_cpu_only,
......@@ -452,17 +460,28 @@ def api_cluster_nodes():
'workers': [] # Will collect worker details
}
# Add worker processes for this node
node_workers = [p for p in cluster_master.processes.values() if p['client_id'] == client_id]
for proc in node_workers:
worker_info = {
'type': proc['name'].split('_')[0], # analysis or training
'backend': proc['name'].split('_')[1] if '_' in proc['name'] else 'unknown',
'weight': proc.get('weight', 10),
'model': proc.get('model', 'default'),
'status': proc.get('status', 'active')
}
node_map[node_key]['workers'].append(worker_info)
# For now, we don't have worker details in database, so we'll use placeholder
# In a full implementation, you'd store worker info in database too
# For now, assume standard workers based on available backends
workers = []
for backend in available_backends:
workers.extend([
{
'type': 'analysis',
'backend': backend,
'weight': 10,
'model': 'default',
'status': 'active'
},
{
'type': 'training',
'backend': backend,
'weight': 5,
'model': 'default',
'status': 'active'
}
])
node_map[node_key]['workers'] = workers
# Convert node_map to nodes list
for node_key, node_data in node_map.items():
......@@ -603,6 +622,7 @@ def api_cluster_nodes():
nodes.sort(key=lambda x: (not x['connected'], -x['last_seen']))
# Cluster master stats
from .cluster_master import cluster_master
master_uptime = current_time - getattr(cluster_master, 'start_time', current_time)
master_stats = {
'total_nodes': len(nodes),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment