Implement secure websockets for cluster master and client with auto-generated...

Implement secure websockets for cluster master and client with auto-generated self-signed certificates
parent bb0f720a
...@@ -8,4 +8,5 @@ flash-attn>=2.0.0 ...@@ -8,4 +8,5 @@ flash-attn>=2.0.0
pyinstaller>=5.0.0 pyinstaller>=5.0.0
PyMySQL>=1.0.0 PyMySQL>=1.0.0
redis>=4.0.0 redis>=4.0.0
websockets>=12.0.0 websockets>=12.0.0
\ No newline at end of file cryptography>=42.0.0
\ No newline at end of file
...@@ -27,6 +27,7 @@ import sys ...@@ -27,6 +27,7 @@ import sys
import subprocess import subprocess
import asyncio import asyncio
import websockets import websockets
import ssl
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from .comm import Message from .comm import Message
from .config import get_analysis_backend, get_training_backend from .config import get_analysis_backend, get_training_backend
...@@ -50,10 +51,15 @@ class ClusterClient: ...@@ -50,10 +51,15 @@ class ClusterClient:
self.loop = None self.loop = None
async def connect(self) -> bool: async def connect(self) -> bool:
"""Connect to cluster master via websocket.""" """Connect to cluster master via secure websocket."""
try: try:
uri = f"ws://{self.host}:{self.port}/cluster" # Create SSL context that accepts self-signed certificates
self.websocket = await websockets.connect(uri) ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
uri = f"wss://{self.host}:{self.port}/cluster"
self.websocket = await websockets.connect(uri, ssl=ssl_context)
# Detect available backends # Detect available backends
from .compat import detect_gpu_backends, get_available_backends from .compat import detect_gpu_backends, get_available_backends
......
...@@ -26,6 +26,9 @@ import time ...@@ -26,6 +26,9 @@ import time
import hashlib import hashlib
import asyncio import asyncio
import websockets import websockets
import ssl
import os
import ipaddress
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from collections import defaultdict from collections import defaultdict
...@@ -45,13 +48,84 @@ class ClusterMaster: ...@@ -45,13 +48,84 @@ class ClusterMaster:
# Load balancing # Load balancing
self.process_queue = defaultdict(list) # process_type -> [(client_id, weight), ...] self.process_queue = defaultdict(list) # process_type -> [(client_id, weight), ...]
def _generate_ssl_cert(self) -> ssl.SSLContext:
    """Return a server-side TLS context for the cluster websocket server.

    Looks for ``cluster.crt`` and ``cluster.key`` in the current working
    directory. If either file is missing, a new 2048-bit RSA key and a
    self-signed certificate valid for 365 days are generated with the
    ``cryptography`` package and written to those paths.

    Returns:
        An ``ssl.SSLContext`` (``PROTOCOL_TLS_SERVER``) with the certificate
        chain loaded.

    Raises:
        ImportError: if a certificate has to be generated but the
            ``cryptography`` package is not installed.
        OSError / ssl.SSLError: if the files cannot be written or loaded.
    """
    cert_file = 'cluster.crt'
    key_file = 'cluster.key'

    if not (os.path.exists(cert_file) and os.path.exists(key_file)):
        print("Generating self-signed SSL certificate for cluster...")

        # Only the imports are guarded: they are the one step expected to
        # fail with ImportError when the optional dependency is missing.
        try:
            import datetime

            from cryptography import x509
            from cryptography.x509.oid import NameOID
            from cryptography.hazmat.primitives import hashes, serialization
            from cryptography.hazmat.primitives.asymmetric import rsa
        except ImportError:
            print("cryptography library not available, cannot generate SSL certificate")
            raise

        # Generate private key
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
        )

        # Self-signed: subject and issuer are the same name.
        subject = issuer = x509.Name([
            x509.NameAttribute(NameOID.COMMON_NAME, "VidAI Cluster Master"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "VidAI"),
        ])

        # Take one timezone-aware timestamp so both ends of the validity
        # window derive from the same instant (utcnow() is deprecated).
        now = datetime.datetime.now(datetime.timezone.utc)

        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + datetime.timedelta(days=365))
            .add_extension(
                # NOTE(review): these SANs only cover loopback / the wildcard
                # bind address; a client that verifies hostnames against a
                # real address would reject this certificate.
                x509.SubjectAlternativeName([
                    x509.IPAddress(ipaddress.IPv4Address('127.0.0.1')),
                    x509.IPAddress(ipaddress.IPv4Address('0.0.0.0')),
                ]),
                critical=False,
            )
            .sign(private_key, hashes.SHA256())
        )

        # The key is stored unencrypted, so create the file owner-only
        # (0o600) instead of relying on the process umask. The mode only
        # applies when the file is newly created.
        def _owner_only_opener(path, flags):
            return os.open(path, flags, 0o600)

        with open(key_file, "wb", opener=_owner_only_opener) as f:
            f.write(private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            ))

        with open(cert_file, "wb") as f:
            f.write(cert.public_bytes(serialization.Encoding.PEM))

        print(f"SSL certificate generated: {cert_file}, {key_file}")

    # Create SSL context
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Refuse legacy protocol versions regardless of interpreter defaults.
    ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
    ssl_context.load_cert_chain(cert_file, key_file)
    return ssl_context
async def start(self) -> None: async def start(self) -> None:
"""Start the cluster master server.""" """Start the cluster master server."""
self.running = True self.running = True
print(f"Cluster master started on port {self.port}") print(f"Cluster master started on port {self.port} (secure websocket)")
# Generate/load SSL certificate
ssl_context = self._generate_ssl_cert()
# Start websocket server # Start secure websocket server
start_server = websockets.serve(self._handle_client, '0.0.0.0', self.port) start_server = websockets.serve(self._handle_client, '0.0.0.0', self.port, ssl=ssl_context)
await start_server await start_server
# Start management loop in background # Start management loop in background
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment