New version!

parent 3716e54b
This diff is collapsed.
......@@ -81,6 +81,7 @@ body,.table-wrap,.modal-box,.chat-messages,.studio .model-list,.studio .chat-msg
/* ── Main ────────────────────────────────────────────────────────── */
.main{min-height:calc(100vh - 44px)}
.container{max-width:1100px;margin:0 auto;padding:2rem 1.5rem}
.container--full{max-width:100%;padding:2rem 1.5rem}
/* ── Page header ─────────────────────────────────────────────────── */
.page-header{display:flex;justify-content:space-between;align-items:flex-start;margin-bottom:1.5rem;gap:1rem}
......@@ -125,6 +126,8 @@ body,.table-wrap,.modal-box,.chat-messages,.studio .model-list,.studio .chat-msg
.btn-ghost:hover{color:var(--text);border-color:var(--border-2)}
.btn-danger{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
.btn-danger:hover{background:rgba(248,113,113,.15);border-color:rgba(248,113,113,.4)}
.btn-warn{background:rgba(251,191,36,.08);color:#f59e0b;border:1px solid rgba(251,191,36,.25)}
.btn-warn:hover{background:rgba(251,191,36,.15);border-color:rgba(251,191,36,.45)}
.btn-sm{padding:.25rem .625rem;font-size:12px}
.btn-sm svg{width:11px;height:11px}
.btn:disabled{opacity:.4;cursor:not-allowed}
......@@ -176,6 +179,7 @@ td code{font-family:var(--mono);font-size:11.5px;background:var(--raised);paddin
.badge-user{background:var(--raised);color:var(--text-3);border:1px solid var(--border)}
.badge-ok{background:rgba(52,211,153,.08);color:var(--green);border:1px solid rgba(52,211,153,.2)}
.badge-warn{background:rgba(251,191,36,.08);color:#f59e0b;border:1px solid rgba(251,191,36,.2)}
.badge-err{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
.badge-danger{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
/* ── Modals ──────────────────────────────────────────────────────── */
......@@ -276,4 +280,5 @@ hr{border:none;border-top:1px solid var(--border);margin:1.125rem 0}
.nav-links{gap:0}
.nav-link{padding:.3rem .5rem;font-size:12.5px}
.container{padding:1.25rem 1rem}
.container--full{padding:1.25rem 1rem}
}
This diff is collapsed.
......@@ -120,7 +120,7 @@
<div style="display:grid;grid-template-columns:180px 1fr;gap:1rem;align-items:start">
<div class="form-row" style="margin:0">
<label class="form-label">Scope</label>
<select id="s-broker-scope" class="form-input">
<select id="s-broker-scope" class="form-input" onchange="toggleBrokerFields()">
<option value="user">user</option>
<option value="global">global</option>
</select>
......@@ -128,7 +128,7 @@
<div class="form-row" style="margin:0">
<label class="form-label">Username</label>
<input type="text" id="s-broker-username" class="form-input" placeholder="alice or global">
<span class="form-hint">Use `global` when scope is `global`; otherwise provide the AISBF username.</span>
<span class="form-hint">This is forced to `global` for global scope; user scope requires the AISBF username.</span>
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:1rem;align-items:start">
......@@ -150,6 +150,11 @@
<input type="text" id="s-broker-advertised-endpoint" class="form-input" placeholder="http://127.0.0.1:8776">
<span class="form-hint">Optional external URL advertised to the broker for this instance.</span>
</div>
<div class="form-row">
<label class="form-label">Websocket path override</label>
<input type="text" id="s-broker-websocket-path" class="form-input" placeholder="/api/coderai/wss">
<span class="form-hint">Optional manual websocket route override for proxied or custom broker deployments; leave empty to derive from scope.</span>
</div>
<div style="display:grid;grid-template-columns:repeat(3, minmax(0, 1fr));gap:1rem;align-items:start">
<div class="form-row" style="margin:0">
<label class="form-label">Heartbeat seconds</label>
......@@ -164,7 +169,7 @@
<input type="number" id="s-broker-request-timeout" class="form-input" min="1" placeholder="30">
</div>
</div>
<div style="display:grid;grid-template-columns:repeat(2, minmax(0, 1fr));gap:1rem;align-items:start">
<div style="display:grid;grid-template-columns:repeat(3, minmax(0, 1fr));gap:1rem;align-items:start">
<div class="form-row" style="margin:0">
<label class="form-label">Reconnect initial delay</label>
<input type="number" id="s-broker-reconnect-initial" class="form-input" min="1" placeholder="1">
......@@ -173,6 +178,11 @@
<label class="form-label">Reconnect max delay</label>
<input type="number" id="s-broker-reconnect-max" class="form-input" min="1" placeholder="60">
</div>
<div class="form-row" style="margin:0">
<label class="form-label">WS ping interval (s)</label>
<input type="number" id="s-broker-ws-ping" class="form-input" min="5" placeholder="20">
<span class="form-hint" style="font-size:11px">Keeps connection alive through nginx proxies. Lower if you get 504 timeouts.</span>
</div>
</div>
</div>
</div>
......@@ -188,6 +198,16 @@ function toggleHttps(){
/**
 * Sync the broker section of the settings form with its current controls.
 *
 * Shows the broker field group only while the "enabled" checkbox is ticked,
 * and ties the username input to the selected scope: global scope pins the
 * username to `global` and makes it read-only; any other scope unlocks the
 * input and clears a leftover `global` value so a real username is entered.
 */
function toggleBrokerFields(){
  const byId = (id) => document.getElementById(id);

  const fieldsPanel = byId('broker-fields');
  fieldsPanel.style.display = byId('s-broker-enabled').checked ? 'block' : 'none';

  const selectedScope = byId('s-broker-scope').value;
  const usernameField = byId('s-broker-username');

  if (selectedScope === 'global') {
    // Global scope always registers under the fixed `global` username.
    usernameField.value = 'global';
    usernameField.readOnly = true;
    return;
  }

  usernameField.readOnly = false;
  if (usernameField.value === 'global') {
    // A stale `global` left over from the other scope is not a valid user name here.
    usernameField.value = '';
  }
}
function showAlert(type, msg){
......@@ -232,11 +252,13 @@ async function loadSettings(){
document.getElementById('s-broker-client-id').value = broker.client_id ?? '';
document.getElementById('s-broker-registration-token').value = broker.registration_token ?? '';
document.getElementById('s-broker-advertised-endpoint').value = broker.advertised_endpoint ?? '';
document.getElementById('s-broker-websocket-path').value = broker.websocket_path ?? '';
document.getElementById('s-broker-heartbeat').value = broker.heartbeat_interval_seconds ?? 30;
document.getElementById('s-broker-connect-timeout').value = broker.connect_timeout_seconds ?? 10;
document.getElementById('s-broker-request-timeout').value = broker.request_timeout_seconds ?? 30;
document.getElementById('s-broker-reconnect-initial').value = broker.reconnect_initial_delay_seconds ?? 1;
document.getElementById('s-broker-reconnect-max').value = broker.reconnect_max_delay_seconds ?? 60;
document.getElementById('s-broker-ws-ping').value = broker.websocket_ping_interval ?? 20;
toggleBrokerFields();
}catch(e){ showAlert('error','Failed to load settings: '+e.message); }
}
......@@ -273,11 +295,13 @@ async function saveSettings(){
client_id: document.getElementById('s-broker-client-id').value.trim(),
registration_token: document.getElementById('s-broker-registration-token').value.trim(),
advertised_endpoint: document.getElementById('s-broker-advertised-endpoint').value.trim(),
websocket_path: document.getElementById('s-broker-websocket-path').value.trim(),
heartbeat_interval_seconds: parseInt(document.getElementById('s-broker-heartbeat').value) || 30,
connect_timeout_seconds: parseInt(document.getElementById('s-broker-connect-timeout').value) || 10,
request_timeout_seconds: parseInt(document.getElementById('s-broker-request-timeout').value) || 30,
reconnect_initial_delay_seconds: parseInt(document.getElementById('s-broker-reconnect-initial').value) || 1,
reconnect_max_delay_seconds: parseInt(document.getElementById('s-broker-reconnect-max').value) || 60,
websocket_ping_interval: parseInt(document.getElementById('s-broker-ws-ping').value) || 20,
transport: 'websocket',
},
};
......
......@@ -43,12 +43,23 @@ global_file_path = None
_aud_progress: dict = {
"current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0, "unit": "it",
"phase": "idle", "model": "",
}
def _aud_progress_loading(model_name: str = ""):
    """Put the shared audio progress tracker into the "loading" phase.

    Marks the tracker active, zeroes the step counters and throughput,
    restarts the elapsed-time clock and records which model is being loaded
    (an empty string when no name is given).
    """
    _aud_progress.update(
        phase="loading",
        active=True,
        current=0,
        total=0,
        it_per_s=0.0,
        started_at=time.monotonic(),
        model=model_name if model_name else "",
    )
def _aud_progress_reset(total: int, unit: str = "it"):
_aud_progress["current"] = 0
_aud_progress["total"] = total
_aud_progress["active"] = True
_aud_progress["phase"] = "generating"
_aud_progress["started_at"] = time.monotonic()
_aud_progress["it_per_s"] = 0.0
_aud_progress["unit"] = unit
......@@ -56,6 +67,7 @@ def _aud_progress_reset(total: int, unit: str = "it"):
def _aud_progress_done():
    """Mark the audio job finished and return the tracker to the idle phase.

    The step counter is raised to ``total`` if it had not reached it yet (it
    is never lowered), so the final reading does not fall short of the total.
    """
    finished_at = max(_aud_progress["current"], _aud_progress["total"])
    _aud_progress.update(current=finished_at, active=False, phase="idle")
def _aud_progress_step(step: int):
_aud_progress["current"] = step
......@@ -196,6 +208,8 @@ async def get_audio_progress():
"current": current,
"total": total,
"active": _aud_progress["active"],
"phase": _aud_progress.get("phase", "idle"),
"model": _aud_progress.get("model", ""),
"pct": int(current / total * 100) if total > 0 else 0,
"it_per_s": _aud_progress["it_per_s"],
"elapsed": round(elapsed, 1),
......@@ -209,6 +223,7 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
Generate music, sound effects, or ambient audio.
Compatible models: MusicGen, AudioGen, AudioLDM2, StableAudio.
"""
_aud_progress_loading(request.model or "audio")
model_info = multi_model_manager.request_model(request.model, model_type="audio_gen")
model_name = model_info.get('model_name')
if not model_name:
......
......@@ -123,18 +123,30 @@ import time as _time
_gen_progress: dict = {
"current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0,
"phase": "idle", "model": "",
}
def _progress_loading(model_name: str = ""):
    """Put the shared image progress tracker into the "loading" phase.

    Marks the tracker active, zeroes the step counters and throughput,
    restarts the elapsed-time clock and records which model is being loaded
    (an empty string when no name is given).
    """
    _gen_progress.update(
        phase="loading",
        active=True,
        current=0,
        total=0,
        it_per_s=0.0,
        started_at=_time.monotonic(),
        model=model_name if model_name else "",
    )
def _progress_reset(total: int):
    """Start tracking a new image generation run of ``total`` steps.

    Switches the tracker to the "generating" phase, marks it active, rewinds
    the step counter and throughput, and restarts the elapsed-time clock.
    """
    _gen_progress.update(
        current=0,
        total=total,
        active=True,
        phase="generating",
        started_at=_time.monotonic(),
        it_per_s=0.0,
    )
def _progress_done():
    """Mark the image job finished and return the tracker to the idle phase.

    The step counter is set to ``total`` so the final reading equals the total.
    """
    _gen_progress.update(
        current=_gen_progress["total"],
        active=False,
        phase="idle",
    )
def _progress_step(step: int):
_gen_progress["current"] = step
......@@ -894,6 +906,8 @@ async def get_image_progress():
"current": _gen_progress["current"],
"total": _gen_progress["total"],
"active": _gen_progress["active"],
"phase": _gen_progress.get("phase", "idle"),
"model": _gen_progress.get("model", ""),
"pct": int(_gen_progress["current"] / _gen_progress["total"] * 100)
if _gen_progress["total"] > 0 else 0,
"it_per_s": _gen_progress["it_per_s"],
......@@ -944,6 +958,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# =====================================================================
# Step 1: Ask the manager to resolve the model and manage VRAM
# =====================================================================
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(
requested_model=request.model,
model_type="image"
......@@ -1173,6 +1188,7 @@ async def create_image_edit(request: ImageEditRequest, http_request: Request = N
if not request.image:
raise HTTPException(status_code=400, detail="image is required")
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name')
if not model_name:
......@@ -1306,6 +1322,7 @@ async def create_image_inpaint(request: ImageInpaintRequest, http_request: Reque
global global_args
if not request.image or not request.mask:
raise HTTPException(status_code=400, detail="image and mask are required")
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name')
if not model_name:
......@@ -1414,6 +1431,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int):
async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None):
"""Upscale an image using Real-ESRGAN or PIL LANCZOS fallback."""
global global_args
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name') or request.model
model_key = f"upscale:{model_name}"
......
......@@ -41,6 +41,11 @@ class BearerAuthMiddleware(BaseHTTPMiddleware):
if not path.startswith("/v1/") or path in self._EXEMPT_PATHS:
return await call_next(request)
# Requests from the ASGI broker bridge are in-process and have no real
# Bearer token. Identify them by the sentinel server tuple set in asgi_bridge.py.
if request.scope.get("server") == ("internal", 80):
return await call_next(request)
from codai.admin import routes as _admin_routes
sm = _admin_routes.session_manager
if sm is None:
......@@ -98,6 +103,7 @@ class _Bucket:
self.window_start = now
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Apply per-IP, per-route-prefix rate limiting to API endpoints."""
......
This diff is collapsed.
......@@ -41,11 +41,44 @@ except (ImportError, AttributeError):
try:
from llama_cpp import Llama
from llama_cpp.llama_chat_format import ChatFormatterResponse
import llama_cpp as _llama_cpp
LLAMA_CPP_AVAILABLE = True
except ImportError:
LLAMA_CPP_AVAILABLE = False
Llama = None
ChatFormatterResponse = None
_llama_cpp = None
def _install_layer_log_callback():
    """Route llama.cpp's load-time layer/buffer log lines to stdout.

    Installs a filtering log callback through ``llama_log_set``. Only messages
    containing one of the load-phase keywords below are printed; everything
    else is dropped.

    Returns:
        The installed callback object, or ``None`` when llama_cpp is not
        available or the callback could not be registered. The caller must
        keep the returned object referenced for as long as the callback is
        installed so it is not garbage-collected while still registered.
    """
    if _llama_cpp is None:
        return None

    # Substrings that identify the interesting load-phase messages.
    load_markers = (
        'llm_load_tensors', 'llm_load_print_meta',
        'offload', 'layer', 'buffer size', 'buffer type',
        'GPU', 'CUDA', 'Vulkan', 'Metal', 'ROCm', 'SYCL',
        'CPU', 'VRAM', 'n_layer', 'n_gpu_layers',
    )

    def _forward_load_message(level, text, user_data):
        # Runs as a native callback: never let an exception escape from here.
        try:
            if isinstance(text, bytes):
                decoded = text.decode('utf-8', errors='replace')
            else:
                decoded = str(text)
            line = decoded.rstrip()
            if not line:
                return
            for marker in load_markers:
                if marker in line:
                    print(f" [llama.cpp] {line}", flush=True)
                    break
        except Exception:
            pass

    log_callback = _llama_cpp.llama_log_callback(_forward_load_message)

    try:
        _llama_cpp.llama_log_set(log_callback, None)
    except Exception:
        return None
    return log_callback  # caller must hold this reference
class VulkanBackend(ModelBackend):
......@@ -450,17 +483,18 @@ class VulkanBackend(ModelBackend):
# Try to find GGUF files in the repository
try:
from huggingface_hub import list_repo_files, hf_hub_download
from codai.models.cache import get_hf_hub_cache_dir
print(f"DEBUG: Searching for GGUF files in {model_path}...")
files = list(list_repo_files(model_path, repo_type="model"))
gguf_files = [f for f in files if f.lower().endswith('.gguf')]
if gguf_files:
# Prefer Q4_K_M or Q4_K quantizations, otherwise use first available
preferred = [f for f in gguf_files if 'q4_k_m' in f.lower() or 'q4_k' in f.lower()]
selected = preferred[0] if preferred else gguf_files[0]
print(f"DEBUG: Found GGUF files: {gguf_files}")
print(f"DEBUG: Selected: {selected}")
model_path = hf_hub_download(repo_id=model_path, filename=selected, cache_dir=kwargs.get('cache_dir'))
model_path = hf_hub_download(repo_id=model_path, filename=selected, cache_dir=kwargs.get('cache_dir') or get_hf_hub_cache_dir())
print(f"DEBUG: Downloaded: {model_path}")
else:
print(f"Warning: No GGUF files found in {model_path}, trying direct download...")
......@@ -471,8 +505,9 @@ class VulkanBackend(ModelBackend):
# Try to get from HuggingFace
try:
from huggingface_hub import hf_hub_download
from codai.models.cache import get_hf_hub_cache_dir
# Download the GGUF file
model_path = hf_hub_download(repo_id=model_path, filename="*.gguf", cache_dir=kwargs.get('cache_dir'))
model_path = hf_hub_download(repo_id=model_path, filename="*.gguf", cache_dir=kwargs.get('cache_dir') or get_hf_hub_cache_dir())
except Exception as e:
print(f"Warning: Could not download from HuggingFace: {e}")
# Try as-is
......@@ -557,18 +592,41 @@ class VulkanBackend(ModelBackend):
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# llama-cpp-python will use CUDA when available
# Pre-load summary
gpu_label = "all" if self.n_gpu_layers == -1 else str(self.n_gpu_layers)
print(f" n_gpu_layers : {gpu_label} | n_ctx : {self.n_ctx} | main_gpu : {self.main_gpu}")
if _llama_cpp:
gpu_supported = _llama_cpp.llama_supports_gpu_offload()
print(f" GPU offload : {'supported' if gpu_supported else 'NOT supported by this build'}")
_log_cb = _install_layer_log_callback()
try:
self.model = Llama(**llama_kwargs)
# Try to detect and set up chat template
self._finalize_chat_template_detection()
print(f"DEBUG: VulkanBackend loaded model: {model_path}")
print(f"DEBUG: n_gpu_layers={self.n_gpu_layers}, n_ctx={self.n_ctx}, no_ram={no_ram}")
print(f"DEBUG: chat_template={self.chat_template}")
except Exception as e:
print(f"Error loading GGUF model: {e}")
raise
finally:
# Restore llama.cpp's default (quiet) logging after load
if _llama_cpp:
try:
_llama_cpp.llama_log_set(None, None)
except Exception:
pass
_log_cb = None # release callback
# Post-load layer/buffer summary
try:
n_total = _llama_cpp.llama_model_n_layer(self.model.model)
n_gpu_actual = n_total if self.n_gpu_layers == -1 else min(self.n_gpu_layers, n_total)
n_cpu = n_total - n_gpu_actual
print(f" Layers total : {n_total}")
print(f" Layers → GPU : {n_gpu_actual} | Layers → CPU : {n_cpu}")
except Exception:
pass
# Try to detect and set up chat template
self._finalize_chat_template_detection()
print(f" chat_template: {self.chat_template}")
def generate(
self,
......
"""Helpers for executing in-process ASGI HTTP requests."""
import base64
import logging
import uuid
from urllib.parse import urlencode
logger = logging.getLogger(__name__)
def _build_multipart_body(multipart):
boundary = f"coderai-broker-{uuid.uuid4().hex}"
chunks = []
for field in multipart.get("fields") or []:
name = str(field.get("name") or "")
value = str(field.get("value") or "")
chunks.append(f"--{boundary}\r\n".encode("utf-8"))
chunks.append(f'Content-Disposition: form-data; name="{name}"\r\n\r\n'.encode("utf-8"))
chunks.append(value.encode("utf-8"))
chunks.append(b"\r\n")
for file_entry in multipart.get("files") or []:
name = str(file_entry.get("name") or "file")
filename = str(file_entry.get("filename") or "upload.bin")
content_type = str(file_entry.get("content_type") or "application/octet-stream")
data_base64 = file_entry.get("data_base64") or ""
file_bytes = base64.b64decode(data_base64) if data_base64 else b""
chunks.append(f"--{boundary}\r\n".encode("utf-8"))
chunks.append(
f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'.encode("utf-8")
)
chunks.append(f"Content-Type: {content_type}\r\n\r\n".encode("utf-8"))
chunks.append(file_bytes)
chunks.append(b"\r\n")
chunks.append(f"--{boundary}--\r\n".encode("utf-8"))
return b"".join(chunks), f"multipart/form-data; boundary={boundary}"
async def execute_internal_request(app, *, method, path, headers=None, query=None, body=b""):
logger.debug(
"ASGI bridge → %s %s query=%s body_bytes=%d",
method.upper(), path, query or {}, len(body),
)
request_headers = []
for key, value in (headers or {}).items():
request_headers.append((key.lower().encode("latin-1"), str(value).encode("latin-1")))
......@@ -49,4 +89,14 @@ async def execute_internal_request(app, *, method, path, headers=None, query=Non
response["body"] += message.get("body", b"")
await app(scope, receive, send)
body_preview = response["body"][:200].decode("utf-8", errors="replace") if response["body"] else ""
logger.debug(
"ASGI bridge ← %s %s status=%d content-type=%s body_bytes=%d body_preview=%r",
method.upper(), path,
response["status_code"],
response["headers"].get("content-type", ""),
len(response["body"]),
body_preview,
)
return response
"""Broker capability documents and registration payloads."""
import glob
import os
import platform
import socket
from typing import Any, Dict, Sequence
......@@ -41,15 +43,67 @@ DEFAULT_STUDIO_ENDPOINTS = [
def build_hardware_summary() -> Dict[str, Any]:
    """Build a conservative hardware summary with VRAM when available.

    GPU detection is best-effort and never raises:

    1. If ``torch`` imports and CUDA is available, every CUDA device is listed
       and the total/free VRAM of the current device is reported.
    2. Otherwise the ``mem_info_vram_total`` / ``mem_info_vram_used`` files
       under ``/sys/class/drm/card*/device`` are read and summed across cards.

    Returns:
        A dict with ``hostname``, ``platform``, ``gpus`` (list of per-device
        dicts), ``gpu_count``, ``total_vram_mb`` and ``available_vram_mb``.
        The GPU fields stay empty/zero when nothing could be detected.
    """
    gpus = []
    total_vram_mb = 0
    available_vram_mb = 0

    try:
        import torch

        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            for index in range(gpu_count):
                props = torch.cuda.get_device_properties(index)
                device_total_mb = int(props.total_memory / (1024 * 1024))
                if index == torch.cuda.current_device():
                    # Free/total figures are only queried for the current device.
                    free_bytes, total_bytes = torch.cuda.mem_get_info()
                    total_vram_mb = int(total_bytes / (1024 * 1024))
                    available_vram_mb = int(free_bytes / (1024 * 1024))
                gpus.append(
                    {
                        "index": index,
                        "name": torch.cuda.get_device_name(index),
                        "total_vram_mb": device_total_mb,
                    }
                )
            if gpus:
                if total_vram_mb == 0:
                    total_vram_mb = sum(gpu["total_vram_mb"] for gpu in gpus)
                if available_vram_mb == 0 and total_vram_mb:
                    available_vram_mb = total_vram_mb
    except Exception:
        # torch is optional and CUDA queries can fail; fall through to sysfs.
        pass

    if not gpus:
        for total_path in sorted(glob.glob("/sys/class/drm/card*/device/mem_info_vram_total")):
            used_path = total_path.replace("vram_total", "vram_used")
            if not os.path.exists(used_path):
                continue
            try:
                # Context managers close the sysfs handles promptly instead of
                # leaving them to the garbage collector.
                with open(total_path) as total_file:
                    device_total_mb = int(int(total_file.read()) / (1024 * 1024))
                with open(used_path) as used_file:
                    device_used_mb = int(int(used_file.read()) / (1024 * 1024))
                device_available_mb = max(0, device_total_mb - device_used_mb)
                # .../cardN/device/mem_info_vram_total -> "cardN"
                card_name = os.path.basename(os.path.dirname(os.path.dirname(total_path)))
                gpus.append(
                    {
                        "name": card_name,
                        "total_vram_mb": device_total_mb,
                    }
                )
                total_vram_mb += device_total_mb
                available_vram_mb += device_available_mb
            except Exception:
                # Unreadable or malformed entry: skip this card, keep the rest.
                continue

    return {
        "hostname": socket.gethostname(),
        "platform": platform.platform(),
        "gpus": gpus,
        "gpu_count": len(gpus),
        "total_vram_mb": total_vram_mb,
        "available_vram_mb": available_vram_mb,
    }
......@@ -64,6 +118,7 @@ def build_capabilities_document(
"server": "codai",
"version": version,
"transports": ["websocket"],
"tunnel_only": True,
"openai_compat": {
"chat_completions": True,
"responses": False,
......@@ -88,15 +143,23 @@ def build_register_message(
) -> Dict[str, Any]:
"""Build broker registration frame."""
registration_token = runtime.registration_token
return {
"v": 1,
"op": "register",
"request_id": request_id,
"registration_token": registration_token,
"capabilities": capabilities,
"payload": {
"endpoint": runtime.advertised_endpoint,
"transport": runtime.transport,
"registration_token": runtime.headers.get("Authorization", "").removeprefix("Bearer "),
"registration_token": registration_token,
"hardware": hardware,
"gpus": (hardware or {}).get("gpus", []),
"gpu_count": (hardware or {}).get("gpu_count", 0),
"total_vram_mb": (hardware or {}).get("total_vram_mb", 0),
"available_vram_mb": (hardware or {}).get("available_vram_mb", 0),
"studio_endpoints": list(studio_endpoints or DEFAULT_STUDIO_ENDPOINTS),
"capabilities": capabilities,
},
......
This diff is collapsed.
......@@ -21,12 +21,14 @@ class BrokerConfig:
client_id: str = ""
registration_token: str = ""
advertised_endpoint: str = ""
websocket_path: str = ""
transport: str = "websocket"
heartbeat_interval_seconds: int = 30
connect_timeout_seconds: int = 10
request_timeout_seconds: int = 30
reconnect_initial_delay_seconds: int = 1
reconnect_max_delay_seconds: int = 60
websocket_ping_interval: int = 20
@dataclass
......@@ -36,13 +38,28 @@ class BrokerRuntimeConfig:
enabled: bool
websocket_url: str = ""
headers: Dict[str, str] = field(default_factory=dict)
provider_id: str = ""
client_id: str = ""
username: str = ""
registration_token: str = ""
advertised_endpoint: str = ""
websocket_path: str = ""
transport: str = "websocket"
heartbeat_interval_seconds: int = 30
connect_timeout_seconds: int = 10
request_timeout_seconds: int = 30
reconnect_initial_delay_seconds: int = 1
reconnect_max_delay_seconds: int = 60
websocket_ping_interval: int = 20
def _join_broker_path(base_path: str, suffix: str) -> str:
normalized_base = (base_path or "").rstrip("/")
normalized_suffix = suffix if suffix.startswith("/") else f"/{suffix}"
if normalized_base.endswith("/api") and normalized_suffix.startswith("/api/"):
return f"{normalized_base}{normalized_suffix[4:]}"
return f"{normalized_base}{normalized_suffix}" if normalized_base else normalized_suffix
def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
......@@ -50,18 +67,29 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
runtime = BrokerRuntimeConfig(
enabled=config.enabled,
provider_id=config.provider_id,
client_id=config.client_id,
username=config.username,
registration_token=config.registration_token,
advertised_endpoint=config.advertised_endpoint,
websocket_path=config.websocket_path,
transport=config.transport,
heartbeat_interval_seconds=config.heartbeat_interval_seconds,
connect_timeout_seconds=config.connect_timeout_seconds,
request_timeout_seconds=config.request_timeout_seconds,
reconnect_initial_delay_seconds=config.reconnect_initial_delay_seconds,
reconnect_max_delay_seconds=config.reconnect_max_delay_seconds,
websocket_ping_interval=config.websocket_ping_interval,
)
if not config.enabled:
return runtime
if config.scope == "global":
custom_websocket_path = (config.websocket_path or "").strip()
if custom_websocket_path:
suffix = custom_websocket_path
if not suffix.startswith("/"):
suffix = f"/{suffix}"
elif config.scope == "global":
if config.username != "global":
raise BrokerConfigError("global broker scope requires username 'global'")
suffix = "/api/coderai/wss"
......@@ -88,7 +116,7 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
scheme = {"http": "ws", "https": "wss"}.get(split_url.scheme, split_url.scheme)
base_path = split_url.path.rstrip("/")
path = f"{base_path}{suffix}" if base_path else suffix
path = _join_broker_path(base_path, suffix)
query = urlencode(
{
"provider_id": config.provider_id,
......@@ -103,6 +131,7 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
"x-coderai-provider-id": config.provider_id,
"x-coderai-client-id": config.client_id,
"x-coderai-username": config.username,
"x-coderai-registration-token": config.registration_token,
"x-coderai-advertised-endpoint": config.advertised_endpoint,
}
return runtime
......@@ -3,13 +3,17 @@
from __future__ import annotations
import json
import logging
from base64 import b64encode
from base64 import b64decode
from time import perf_counter
from typing import Any
from codai.broker.asgi_bridge import execute_internal_request
from codai.broker.models import error_envelope, success_envelope
logger = logging.getLogger(__name__)
SUPPORTED_PREFIXES = (
"/v1/models",
"/v1/chat/completions",
......@@ -17,9 +21,18 @@ SUPPORTED_PREFIXES = (
"/v1/audio",
"/v1/video",
"/v1/pipelines",
"/v1/files",
"/coderai/capabilities",
"/admin",
"/static",
)
OP_ROUTE_MAP = {
"models.list": ("GET", "/v1/models"),
"chat.completions": ("POST", "/v1/chat/completions"),
"capabilities": ("GET", "/coderai/capabilities"),
}
TEXT_CONTENT_TYPES = (
"application/json",
"application/ld+json",
......@@ -50,9 +63,46 @@ def _is_text_response(content_type: str | None) -> bool:
async def execute_broker_request(app, envelope):
"""Validate and execute a broker request envelope."""
logger.debug(
"broker dispatch → op=%s request_id=%s path=%r method=%r stream=%s",
envelope.op, envelope.request_id, envelope.path, envelope.method, envelope.stream,
)
if envelope.op == "proxy":
proxy_payload = envelope.payload or {}
endpoint_path = str(proxy_payload.get("endpoint_path") or envelope.path or "").strip()
if endpoint_path and not endpoint_path.startswith("/"):
endpoint_path = f"/{endpoint_path}"
envelope.path = endpoint_path
envelope.method = str(proxy_payload.get("method") or envelope.method or "GET").upper()
envelope.headers = dict(proxy_payload.get("headers") or envelope.headers)
envelope.query = dict(proxy_payload.get("query_params") or envelope.query)
envelope.stream = bool(proxy_payload.get("stream", envelope.stream))
if "body" in proxy_payload:
envelope.payload = proxy_payload.get("body")
elif proxy_payload.get("body_base64") is not None:
envelope.payload = b64decode(proxy_payload.get("body_base64") or "")
elif proxy_payload.get("multipart") is not None:
envelope.payload = {"_broker_multipart": proxy_payload.get("multipart")}
logger.debug("broker dispatch proxy resolved → %s %s", envelope.method, envelope.path)
elif envelope.op in OP_ROUTE_MAP:
envelope.method, envelope.path = OP_ROUTE_MAP[envelope.op]
logger.debug("broker dispatch op mapped → %s %s", envelope.method, envelope.path)
elif not envelope.path:
logger.warning("broker dispatch unsupported op=%s request_id=%s", envelope.op, envelope.request_id)
return error_envelope(
envelope.request_id,
code="unsupported_operation",
message=f"Unsupported broker op: {envelope.op}",
)
envelope.validate()
if not is_supported_path(envelope.path):
logger.warning(
"broker dispatch unsupported path=%r op=%s request_id=%s",
envelope.path, envelope.op, envelope.request_id,
)
return error_envelope(
envelope.request_id,
code="unsupported_endpoint",
......@@ -60,18 +110,28 @@ async def execute_broker_request(app, envelope):
)
body: bytes
if isinstance(envelope.payload, (dict, list)):
if isinstance(envelope.payload, dict) and "_broker_multipart" in envelope.payload:
from codai.broker.asgi_bridge import _build_multipart_body
body, multipart_content_type = _build_multipart_body(envelope.payload["_broker_multipart"] or {})
headers = dict(envelope.headers)
headers["content-type"] = multipart_content_type
elif isinstance(envelope.payload, (dict, list)):
body = json.dumps(envelope.payload, separators=(",", ":")).encode("utf-8")
headers = dict(envelope.headers)
elif isinstance(envelope.payload, str):
body = envelope.payload.encode("utf-8")
headers = dict(envelope.headers)
elif isinstance(envelope.payload, bytes):
body = envelope.payload
headers = dict(envelope.headers)
elif envelope.payload is None:
body = b""
headers = dict(envelope.headers)
else:
body = json.dumps(envelope.payload, separators=(",", ":")).encode("utf-8")
headers = dict(envelope.headers)
headers = dict(envelope.headers)
if body and "content-type" not in {key.lower() for key in headers}:
headers["content-type"] = envelope.content_type
......@@ -103,17 +163,36 @@ async def execute_broker_request(app, envelope):
payload["content_type"] = content_type
if _is_text_response(content_type):
payload["body"] = response["body"].decode("utf-8")
body_text = response["body"].decode("utf-8")
payload["body"] = body_text
logger.debug(
"broker dispatch ← op=%s request_id=%s status=%d elapsed_ms=%s body_bytes=%d body_preview=%r",
envelope.op, envelope.request_id,
response["status_code"], elapsed_ms,
len(response["body"]),
body_text[:300],
)
else:
payload["body_base64"] = b64encode(response["body"]).decode("ascii")
filename = response_headers.get("x-filename")
if filename:
payload["filename"] = filename
logger.debug(
"broker dispatch ← op=%s request_id=%s status=%d elapsed_ms=%s body_bytes=%d (binary)",
envelope.op, envelope.request_id,
response["status_code"], elapsed_ms,
len(response["body"]),
)
if envelope.stream:
payload["stream"] = True
return success_envelope(
result = success_envelope(
envelope.request_id,
payload=payload,
metrics={"elapsed_ms": elapsed_ms},
)
logger.debug(
"broker dispatch envelope_bytes=%d op=%s request_id=%s",
len(json.dumps(result)), envelope.op, envelope.request_id,
)
return result
......@@ -9,8 +9,9 @@ class BrokerRequestEnvelope:
"""Normalized broker request payload."""
request_id: str
method: str
path: str
op: str
method: str = "GET"
path: str = ""
headers: Dict[str, str] = field(default_factory=dict)
query: Dict[str, Any] = field(default_factory=dict)
payload: Any = None
......@@ -22,9 +23,13 @@ class BrokerRequestEnvelope:
if not self.request_id or not isinstance(self.request_id, str):
raise ValueError("request_id is required")
if not self.op or not isinstance(self.op, str):
raise ValueError("op is required")
if self.op != "proxy" and not self.path:
raise ValueError("path is required")
if not self.method or not isinstance(self.method, str):
raise ValueError("method is required")
if not self.path or not isinstance(self.path, str):
if self.path and not isinstance(self.path, str):
raise ValueError("path is required")
......@@ -32,8 +37,9 @@ def success_envelope(request_id: str, payload: Any, event: str | None = None, me
"""Build a success response envelope."""
envelope = {
"v": 1,
"request_id": request_id,
"ok": True,
"status": "ok",
"payload": payload,
}
if event is not None:
......@@ -53,7 +59,8 @@ def error_envelope(request_id: str, code: str, message: str, details: Dict[str,
if details is not None:
error["details"] = details
return {
"v": 1,
"request_id": request_id,
"ok": False,
"status": "error",
"error": error,
}
......@@ -18,14 +18,17 @@ class BrokerService:
self.client.dispatcher = dispatch
self.task: asyncio.Task | None = None
self._started = False
def start(self):
if not self.client.runtime.enabled or self.task is not None:
if not self.client.runtime.enabled or self.task is not None or self._started:
return
self._started = True
self.task = asyncio.create_task(self.client.run_forever())
async def stop(self):
if self.task is None:
self._started = False
return
task = self.task
self.task = None
......@@ -34,3 +37,5 @@ class BrokerService:
await task
except asyncio.CancelledError:
pass
finally:
self._started = False
......@@ -18,6 +18,7 @@
import sys
import os
import logging
import threading as _t
# Import configuration from codai modules
from codai.cli import parse_args
......@@ -30,6 +31,83 @@ from codai.broker import BrokerConfigError, build_broker_runtime_config
logger = logging.getLogger(__name__)
def _migrate_hf_gguf_to_gguf_cache() -> None:
"""Move GGUF files stored in the HF cache into the flat GGUF cache directory.
Runs once at startup in a background thread. For repos whose only
non-trivial content is GGUF files, the HF cache entry is removed after
the files are safely copied across.
"""
import shutil
from codai.models.cache import get_hf_hub_cache_dir, get_model_cache_dir
hf_dir = get_hf_hub_cache_dir()
if not os.path.exists(hf_dir):
return
gguf_cache = get_model_cache_dir()
try:
from huggingface_hub import scan_cache_dir
info = scan_cache_dir(hf_dir)
except Exception:
return
_TRIVIAL_EXTS = {'.json', '.txt', '.md', '.py', '.gitattributes', '.model', '.tiktoken', '.vocab'}
migrated_count = 0
repos_to_purge = [] # HF cache repo dirs safe to delete after migration
for repo in info.repos:
if not repo.revisions:
continue
latest_rev = sorted(repo.revisions, key=lambda r: r.commit_hash)[-1]
gguf_files = [f for f in latest_rev.files if f.file_name.endswith('.gguf')]
if not gguf_files:
continue
# Determine whether this repo contains ONLY gguf + trivial metadata
non_trivial = [
f for f in latest_rev.files
if os.path.splitext(f.file_name)[1].lower() not in _TRIVIAL_EXTS
]
gguf_only_repo = non_trivial and all(f.file_name.endswith('.gguf') for f in non_trivial)
all_migrated = True
for f in gguf_files:
dest = os.path.join(gguf_cache, os.path.basename(f.file_name))
if os.path.exists(dest):
continue # already present in GGUF cache
try:
src = os.path.realpath(str(f.file_path)) # resolve symlink → blob
shutil.copy2(src, dest)
migrated_count += 1
logger.info("Migrated GGUF: %s → %s", f.file_name, dest)
except Exception as exc:
logger.warning("Could not migrate %s: %s", f.file_name, exc)
all_migrated = False
if gguf_only_repo and all_migrated:
repo_dir = os.path.join(hf_dir, f"models--{repo.repo_id.replace('/', '--')}")
if os.path.isdir(repo_dir):
repos_to_purge.append((repo.repo_id, repo_dir))
for repo_id, repo_dir in repos_to_purge:
try:
shutil.rmtree(repo_dir)
logger.info("Removed migrated HF cache entry: %s", repo_id)
except Exception as exc:
logger.warning("Could not remove HF cache entry %s: %s", repo_id, exc)
if migrated_count or repos_to_purge:
logger.info(
"GGUF cache migration: %d file(s) moved to %s, %d HF cache entr%s cleaned up.",
migrated_count, gguf_cache,
len(repos_to_purge), "ies" if len(repos_to_purge) != 1 else "y",
)
def main():
"""Main entry point for the codai server."""
# Suppress unraisable exceptions from LlamaModel.__del__
......@@ -54,10 +132,22 @@ def main():
config_mgr = ConfigManager(config_dir)
config = config_mgr.load()
# Apply cache directory overrides from config before any cache module is used
# Apply cache directory overrides from config before any cache module is used.
# We set env vars AND patch huggingface_hub.constants in case the library was
# already imported (constants are computed once at import time from env vars).
if config.models.hf_cache_dir:
hf_hub_cache = os.path.join(config.models.hf_cache_dir, 'hub')
os.environ['HF_HOME'] = config.models.hf_cache_dir
os.environ['HUGGINGFACE_HUB_CACHE'] = config.models.hf_cache_dir
os.environ['HUGGINGFACE_HUB_CACHE'] = hf_hub_cache
try:
import sys as _sys
if 'huggingface_hub.constants' in _sys.modules:
import huggingface_hub.constants as _hfc
_hfc.HF_HUB_CACHE = hf_hub_cache
if hasattr(_hfc, 'HF_HOME'):
_hfc.HF_HOME = config.models.hf_cache_dir
except Exception:
pass
if config.models.gguf_cache_dir:
os.environ['CODERAI_CACHE_DIR'] = config.models.gguf_cache_dir
......@@ -75,7 +165,7 @@ def main():
# Initialize admin session manager and expose config to admin routes
from pathlib import Path
init_session_manager(Path(config_dir))
init_session_manager(Path(config_dir), port=config.server.port)
set_config_manager(config_mgr)
# Handle early exit options (before heavy imports)
......@@ -145,7 +235,8 @@ def main():
print("No GGUF files found, downloading full HuggingFace repo...")
try:
from huggingface_hub import snapshot_download
cached_path = snapshot_download(model_id)
from codai.models.cache import get_hf_hub_cache_dir
cached_path = snapshot_download(model_id, cache_dir=get_hf_hub_cache_dir())
except Exception as e:
print(f"Error downloading full repo: {e}")
cached_path = None
......@@ -175,6 +266,9 @@ def main():
print(f"Error listing devices: {e}")
sys.exit(0)
# Migrate any GGUF files that ended up in the HF cache to the GGUF cache
_t.Thread(target=_migrate_hf_gguf_to_gguf_cache, daemon=True).start()
# Import core modules (only after early exits)
from codai.api import app
from codai.api.state import (
......@@ -724,18 +818,36 @@ def main():
queue_manager.max_size = config.server.queue_max_size
queue_manager.max_parallel_requests = config.server.max_parallel_requests
# Configure Python logging so broker/API log calls reach the terminal.
# uvicorn is started with log_config=None to keep our config in place.
_log_level = logging.DEBUG if global_debug else logging.INFO
logging.basicConfig(
level=_log_level,
format="%(asctime)s [%(levelname)-8s] %(name)s: %(message)s",
stream=sys.stdout,
force=True,
)
# Suppress noisy third-party libraries at WARNING unless in debug mode.
for _noisy in ("httpx", "httpcore", "urllib3", "multipart", "PIL"):
logging.getLogger(_noisy).setLevel(logging.WARNING)
if not global_debug:
logging.getLogger("websockets").setLevel(logging.WARNING)
logging.getLogger("asyncio").setLevel(logging.WARNING)
# Start the server
import uvicorn
print(f"\nStarting server on http://{config.server.host}:{config.server.port}")
print(f"API docs: http://{config.server.host}:{config.server.port}/docs")
print(f"Admin UI: http://{config.server.host}:{config.server.port}/admin")
if model_manager.backend is not None:
actual_backend = model_manager.backend_type
if hasattr(model_manager.backend, 'force_cuda') and model_manager.backend.force_cuda:
actual_backend = "cuda (via llama-cpp-python)"
print(f"Using backend: {actual_backend}")
_uvi_log_level = "debug" if global_debug else "info"
if config.server.https:
import ssl
ssl_keyfile = config.server.https_key_path
......@@ -758,14 +870,17 @@ def main():
except Exception as e:
print(f"Warning: Could not generate certificate: {e}")
print("Falling back to HTTP...")
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
log_level=_uvi_log_level, log_config=None)
return
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(ssl_certfile, ssl_keyfile)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port, ssl_context=ssl_context)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
ssl_context=ssl_context, log_level=_uvi_log_level, log_config=None)
else:
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
log_level=_uvi_log_level, log_config=None)
if __name__ == "__main__":
......
......@@ -56,6 +56,26 @@ def get_model_cache_dir() -> str:
return cache_dir
def get_hf_hub_cache_dir() -> str:
    """Return the HuggingFace Hub cache directory CoderAI is configured to use.

    Mirrors huggingface_hub's own env-var priority so that passing this value
    as ``cache_dir`` to snapshot_download / hf_hub_download targets the same
    location the library would choose on its own — even if the directory
    does not yet exist (first download).

    Resolution order (first non-empty value wins):
      1. ``HF_HUB_CACHE``            (current explicit cache path)
      2. ``HUGGINGFACE_HUB_CACHE``   (legacy explicit cache path)
      3. ``$HF_HOME/hub``            (parent-home style)
      4. ``$XDG_CACHE_HOME/huggingface/hub``
      5. ``~/.cache/huggingface/hub`` (built-in default)

    An environment variable set to the empty string is treated as unset.

    Returns:
        The cache directory path with ``~`` and ``$VAR`` references expanded.
        The directory is not created and may not exist.
    """
    explicit_cache = (
        os.environ.get('HF_HUB_CACHE')
        or os.environ.get('HUGGINGFACE_HUB_CACHE')
    )
    if explicit_cache:
        hf_hub_cache = explicit_cache
    else:
        hf_home = os.environ.get('HF_HOME')
        if not hf_home:
            cache_root = (
                os.environ.get('XDG_CACHE_HOME')
                or os.path.join(os.path.expanduser('~'), '.cache')
            )
            hf_home = os.path.join(cache_root, 'huggingface')
        hf_hub_cache = os.path.join(hf_home, 'hub')
    return os.path.expandvars(os.path.expanduser(hf_hub_cache))
def get_all_cache_dirs() -> dict:
"""Get all model cache directories."""
caches = {}
......@@ -540,6 +560,7 @@ def remove_all_cached_models() -> int:
# Export all public functions
__all__ = [
'get_model_cache_dir',
'get_hf_hub_cache_dir',
'get_all_cache_dirs',
'get_cached_model_path',
'is_huggingface_model_id',
......
......@@ -110,12 +110,10 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'animatediff', 'text2video', 'modelscope-t2v',
'zeroscope', 'lavie']):
caps.video_generation = True
caps.text_generation = True # T2V models also do text
return caps
if any(x in n for x in ['wan2.1-t2v', 'wan-t2v']):
caps.video_generation = True
caps.text_generation = True
return caps
# Image-to-video
......@@ -124,17 +122,14 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'wan2.1-i2v', 'wan-i2v', 'img2vid',
'image2video', 'motionctrl']):
caps.image_to_video = True
caps.image_to_text = True # I2V models process images
return caps
# Wan generic (detect sub-variant)
if 'wan' in n and ('video' in n or 'diffuser' in n):
if 'i2v' in n:
caps.image_to_video = True
caps.image_to_text = True
else:
caps.video_generation = True
caps.text_generation = True
return caps
# Video interpolation
......@@ -158,7 +153,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
if any(x in n for x in ['musicgen', 'audiogen', 'audioldm', 'stable-audio',
'mustango', 'noise2music', 'jukebox', 'audiocraft']):
caps.audio_generation = True
caps.text_generation = True # T2A models process text
return caps
if any(x in n for x in ['demucs', 'spleeter', 'asteroid', 'open-unmix']):
......@@ -174,7 +168,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
if any(x in n for x in ['kokoro', 'xtts', 'bark', 'tortoise',
'speecht5', 'matcha-tts', 'voicebox']):
caps.text_to_speech = True
caps.text_generation = True # TTS models process text
return caps
# Lip sync / dubbing
......@@ -199,13 +192,11 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.inpainting = True
caps.image_generation = True
caps.image_to_image = True
caps.text_generation = True
return caps
if 'controlnet' in n:
caps.controlnet = True
caps.image_generation = True
caps.text_generation = True
return caps
if any(x in n for x in [
......@@ -235,7 +226,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.image_generation = True
caps.image_to_image = True
caps.inpainting = True # most SD/SDXL/Flux checkpoints support inpainting via mask
caps.text_generation = True
return caps
# ── Image: analysis / processing ─────────────────────────────────────────
......@@ -295,12 +285,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'text-embedding', 'voyage-',
]):
caps.embeddings = True
caps.text_generation = True
return caps
# ── GGUF quantised text models ───────────────────────────────────────────
if '.gguf' in n or 'gguf' in n:
caps.text_generation = True
return caps
# Default: text generation
......@@ -315,17 +299,17 @@ _PIPELINE_TAG_CAPS: dict = {
'image-to-text': ['image_to_text', 'text_generation'],
'visual-question-answering': ['image_to_text', 'text_generation'],
'image-text-to-text': ['image_to_text', 'text_generation'],
'text-to-image': ['image_generation', 'image_to_image', 'text_generation'],
'text-to-image': ['image_generation', 'image_to_image'],
'unconditional-image-generation': ['image_generation'],
'image-to-image': ['image_to_image'], # sub-typed below
'automatic-speech-recognition': ['speech_to_text'],
'audio-to-audio': ['audio_to_audio'],
'text-to-speech': ['text_to_speech'],
'text-to-audio': ['audio_generation'],
'text-to-video': ['video_generation', 'text_generation'],
'text-to-video': ['video_generation'],
'image-to-video': ['image_to_video'],
'feature-extraction': ['embeddings', 'text_generation'],
'sentence-similarity': ['embeddings', 'text_generation'],
'feature-extraction': ['embeddings'],
'sentence-similarity': ['embeddings'],
'depth-estimation': ['depth_estimation', 'image_to_text'],
'image-segmentation': ['image_segmentation', 'image_to_text'],
'object-detection': ['object_detection', 'image_to_text'],
......
......@@ -1541,6 +1541,10 @@ class MultiModelManager:
2. Local HuggingFace hub cache scan.
3. HuggingFace API (network, one call per model per process lifetime).
For LoRA adapters (repos containing adapter_config.json) the size of
the base model is added so that VRAM requirements are not
underestimated.
Returns 0 on any failure.
"""
if model_id in MultiModelManager._hf_size_cache:
......@@ -1548,10 +1552,19 @@ class MultiModelManager:
weight_exts = {'.safetensors', '.bin', '.gguf', '.ggml', '.pt'}
def _resolve_base_model(base_model_id: str) -> int:
from codai.models.cache import is_huggingface_model_id
if not base_model_id or base_model_id == model_id:
return 0
if not is_huggingface_model_id(base_model_id):
return 0
return MultiModelManager._hf_cached_model_size_bytes(base_model_id)
# --- Try local HF hub cache first (no network) ---
try:
from huggingface_hub import scan_cache_dir
from codai.models.cache import get_all_cache_dirs
import json as _json
hf_dir = get_all_cache_dirs().get("huggingface")
if hf_dir:
info = scan_cache_dir(hf_dir)
......@@ -1560,11 +1573,26 @@ class MultiModelManager:
continue
revs = sorted(repo.revisions, key=lambda r: r.last_modified, reverse=True)
if revs:
rev = revs[0]
total = sum(
f.size_on_disk
for f in revs[0].files
for f in rev.files
if os.path.splitext(f.file_name)[1].lower() in weight_exts
)
# LoRA adapter: add base model size
for f in rev.files:
if f.file_name == "adapter_config.json":
try:
with open(f.file_path) as fp:
adapter_cfg = _json.load(fp)
base_id = (
adapter_cfg.get("base_model_name_or_path")
or adapter_cfg.get("base_model")
)
total += _resolve_base_model(base_id)
except Exception:
pass
break
if total > 0:
MultiModelManager._hf_size_cache[model_id] = total
return total
......@@ -1580,13 +1608,31 @@ class MultiModelManager:
with urllib.request.urlopen(req, timeout=10) as resp:
data = _json.loads(resp.read())
total = 0
has_adapter_config = False
for sib in data.get("siblings", []):
name = sib.get("rfilename", "")
if name == "adapter_config.json":
has_adapter_config = True
continue
if os.path.splitext(name)[1].lower() not in weight_exts:
continue
lfs = sib.get("lfs") or {}
size = lfs.get("size") or sib.get("size") or 0
total += size
# LoRA adapter: fetch adapter_config.json to get the base model
if has_adapter_config:
try:
cfg_url = f"https://huggingface.co/{model_id}/resolve/main/adapter_config.json"
cfg_req = urllib.request.Request(cfg_url, headers={"User-Agent": "coderai/1.0"})
with urllib.request.urlopen(cfg_req, timeout=10) as resp:
adapter_cfg = _json.loads(resp.read())
base_id = (
adapter_cfg.get("base_model_name_or_path")
or adapter_cfg.get("base_model")
)
total += _resolve_base_model(base_id)
except Exception:
pass
if total > 0:
MultiModelManager._hf_size_cache[model_id] = total
return total
......@@ -2101,9 +2147,12 @@ class MultiModelManager:
needed_gb = self._get_model_used_vram_gb(model_key, resolved_name)
free_gb = self._get_free_vram_gb()
if needed_gb > 0 and free_gb >= needed_gb:
# Require headroom beyond raw weight size for activation buffers
# and generation scratch (30% of model size, but never less than 1 GB).
headroom_gb = max(1.0, needed_gb * 0.30)
if needed_gb > 0 and free_gb >= needed_gb + headroom_gb:
print(f"Ondemand mode - keeping '{loaded_canonical}' in VRAM alongside new model "
f"(need {needed_gb:.1f} GB, have {free_gb:.1f} GB free)")
f"(need {needed_gb:.1f} GB + {headroom_gb:.1f} GB headroom, have {free_gb:.1f} GB free)")
else:
print(f"Ondemand mode - model switch detected:")
print(f" Requested: '{model_key}' (resolved: '{resolved_name}')")
......
......@@ -89,6 +89,22 @@ The outbound WebSocket connection must include:
- `username`: either `global` or the AISBF username for user-owned providers
- `registration_token`: provider-scoped secret from AISBF provider configuration
### Current server-side resolution order
AISBF resolves broker identity in this exact order when the WebSocket handshake arrives:
- `provider_id`: query param `provider_id`, then header `x-coderai-provider-id`, then default `coderai`
- `client_id`: query param `client_id`, then header `x-coderai-client-id`, then generated fallback `anon-<unix_timestamp>`
- `username`: query param `username`, then header `x-coderai-username`, then the path scope name (`global` or the `/api/u/{username}` path segment)
- `registration_token`: query param `registration_token`, then header `x-coderai-registration-token`
Important constraints:
- the `registration_token` is required for admission
- `Authorization: Bearer ...` is currently not used by the broker WebSocket admission check
- if you omit `client_id`, AISBF generates an `anon-*` client id and future broker routing will only work if AISBF also targets that exact generated value
- the `client_id` used by the CoderAI client must match the `coderai_config.client_id` used by the AISBF provider, or the broker can show the session as connected while requests still fail to route
## Optional Headers
AISBF also accepts or may expect these headers:
......@@ -109,6 +125,35 @@ Recommended behavior:
Open the outbound WebSocket to the correct scoped AISBF endpoint.
The handshake is a normal WebSocket upgrade request, which starts as an HTTP `GET` carrying query parameters. This is expected.
Recommended connect template:
```text
wss://<aisbf-host>/<optional-prefix>/api/coderai/wss?provider_id=<provider_id>&client_id=<stable_client_id>&username=global&registration_token=<provider_registration_token>
```
User-scoped template:
```text
wss://<aisbf-host>/<optional-prefix>/api/u/<username>/coderai/wss?provider_id=<provider_id>&client_id=<stable_client_id>&username=<username>&registration_token=<provider_registration_token>
```
Recommended handshake headers:
```text
x-coderai-provider-id: <provider_id>
x-coderai-client-id: <stable_client_id>
x-coderai-username: <username>
x-coderai-registration-token: <provider_registration_token>
```
Best practice:
- send the same identity in both query parameters and headers
- keep `client_id` stable across reconnects
- always reconnect with the same provider scope and owner scope
### 2. Wait for `registered` event
AISBF immediately sends a registration acknowledgment event on successful admission.
......@@ -135,11 +180,21 @@ Store:
- `client_id`
- `username`
- `scope_name`
- `owner_user_id`
- `expires_at`
Notes:
- this event means the socket is admitted and the session row exists
- it does not yet mean hardware/capabilities metadata has been uploaded
- the client should send the explicit `register` operation immediately after this event
### 3. Send explicit `register` operation
After the `registered` event, CoderAI must send a `register` message describing its capabilities, hardware inventory, and advertised endpoints.
AISBF currently processes `register` as a normal inbound WebSocket message and responds with `status=ok` using the same `request_id`.
### 4. Enter long-lived receive loop
Then keep listening for incoming broker requests from AISBF.
......@@ -233,6 +288,60 @@ CoderAI should send this after receiving the initial AISBF `registered` event.
AISBF replies with a success envelope.
### Fields AISBF currently reads from the `register` message
Top-level:
- `v`
- `op` with value `register`
- `request_id`
- optional top-level `registration_token`
- optional top-level `capabilities`
From `payload`:
- `endpoint`
- `transport`
- `registration_token`
- `studio_endpoints`
- `hardware`
- `gpus`
- `gpu_count`
- `total_vram_mb`
- `available_vram_mb`
- `capabilities`
AISBF behavior:
- if `payload.registration_token` or top-level `registration_token` is present and does not match the handshake token, AISBF replies with an error envelope
- if token matches, AISBF persists the metadata onto the broker session
- if top-level `capabilities` is missing, AISBF uses `payload.capabilities` instead
- if `gpus`, `gpu_count`, `total_vram_mb`, or `available_vram_mb` are omitted at the top level of `payload`, AISBF falls back to the values inside `payload.hardware`
Minimal acceptable `register` message:
```json
{
"v": 1,
"op": "register",
"request_id": "reg-1",
"payload": {
"transport": "websocket",
"registration_token": "<same_registration_token>",
"capabilities": {}
}
}
```
Recommended full `register` message:
- include `endpoint`
- include `transport`
- include `registration_token`
- include `hardware.gpus`, `hardware.gpu_count`, `hardware.total_vram_mb`, `hardware.available_vram_mb`
- include `studio_endpoints`
- include `capabilities`
### Hardware Reporting Requirements
The `register` payload should include the best hardware view available to the running CoderAI process.
......@@ -326,6 +435,37 @@ Heartbeat payloads may also refresh dynamic hardware state such as changing free
}
```
Current AISBF note:
- AISBF acknowledges heartbeat messages and merges the heartbeat `payload` into session metadata
- keep heartbeat payloads small and non-blocking
- use heartbeats for lightweight dynamic updates only; do not block the main receive loop on expensive hardware rescans
## Async Client Requirements
The broker WebSocket integration must be fully asynchronous.
CoderAI client requirements:
- the main receive loop must never block on model loading, inference, GPU inspection, or disk/network I/O
- expensive work should run in background tasks or worker executors while the socket remains responsive to incoming frames and ping/pong traffic
- the client should be able to receive broker requests while also sending progress or result frames for earlier requests
- the client must not serialize all work behind registration or heartbeat handling
AISBF broker behavior:
- AISBF now drains queued outbound broker requests in a background async task while independently reading inbound websocket messages
- this means the CoderAI client should expect inbound requests to arrive even while it is still sending heartbeat or response messages for unrelated work
- operations are correlated strictly by `request_id`; client implementations must not rely on message ordering alone
Recommended client architecture:
1. one async reader task for inbound WebSocket frames
2. one async writer path or send queue for outbound replies/events
3. per-request async tasks for local execution
4. a lightweight periodic heartbeat task
5. explicit request correlation by `request_id`
AISBF merges the dynamic updates carried in heartbeat payloads into the broker session metadata.
## Local HTTP Endpoints CoderAI Should Expose
......
This diff is collapsed.
......@@ -84,8 +84,13 @@ def test_build_broker_runtime_config_global_scope_builds_url_and_headers():
"x-coderai-provider-id": "provider-1",
"x-coderai-client-id": "client-1",
"x-coderai-username": "global",
"x-coderai-registration-token": "token-123",
"x-coderai-advertised-endpoint": "https://server.example.com",
}
assert runtime.provider_id == "provider-1"
assert runtime.client_id == "client-1"
assert runtime.username == "global"
assert runtime.registration_token == "token-123"
assert runtime.transport == "websocket"
assert runtime.heartbeat_interval_seconds == 30
assert runtime.connect_timeout_seconds == 10
......@@ -94,6 +99,38 @@ def test_build_broker_runtime_config_global_scope_builds_url_and_headers():
assert runtime.reconnect_max_delay_seconds == 60
def test_build_broker_runtime_config_global_scope_uses_global_service_paths():
    """A global-scope config must resolve to the shared ``/api/coderai/wss`` route."""
    broker_config = BrokerConfig(
        enabled=True,
        base_url="https://broker.example.com/base",
        scope="global",
        username="global",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
    )
    expected_prefix = "wss://broker.example.com/base/api/coderai/wss"

    runtime = build_broker_runtime_config(broker_config)

    assert runtime.websocket_url.startswith(expected_prefix)
def test_build_broker_runtime_config_user_scope_uses_user_service_paths():
    """A user-scope config must resolve to the ``/api/u/<username>/coderai/wss`` route."""
    settings = dict(
        enabled=True,
        base_url="https://broker.example.com/base",
        scope="user",
        username="alice",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
    )

    websocket_url = build_broker_runtime_config(BrokerConfig(**settings)).websocket_url

    assert websocket_url.startswith("wss://broker.example.com/base/api/u/alice/coderai/wss")
def test_build_broker_runtime_config_rejects_invalid_global_username():
try:
build_broker_runtime_config(
......@@ -139,6 +176,7 @@ def test_build_broker_runtime_config_user_scope_uses_user_path():
"x-coderai-provider-id": "provider-1",
"x-coderai-client-id": "client-1",
"x-coderai-username": "alice",
"x-coderai-registration-token": "token-123",
"x-coderai-advertised-endpoint": "https://server.example.com/alice",
}
......@@ -165,6 +203,28 @@ def test_build_broker_runtime_config_preserves_base_url_prefix_in_websocket_url(
)
def test_build_broker_runtime_config_does_not_duplicate_api_prefix_when_base_url_already_ends_with_api():
    """A base URL that already ends in ``/api`` must not yield ``/api/api/...``."""
    expected_url = (
        "wss://aisbf.cloud/api/coderai/wss"
        "?provider_id=provider-1"
        "&client_id=client-1"
        "&username=global"
        "&registration_token=token-123"
    )
    broker_config = BrokerConfig(
        enabled=True,
        base_url="https://aisbf.cloud/api",
        scope="global",
        username="global",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
    )

    runtime = build_broker_runtime_config(broker_config)

    assert runtime.websocket_url == expected_url
def test_build_broker_runtime_config_encodes_reserved_username_path_characters():
runtime = build_broker_runtime_config(
BrokerConfig(
......@@ -187,6 +247,48 @@ def test_build_broker_runtime_config_encodes_reserved_username_path_characters()
)
def test_build_broker_runtime_config_uses_manual_websocket_path_override():
    """An explicit ``websocket_path`` replaces the scope-derived route in the final URL."""
    override_path = "/custom/broker/socket"
    broker_config = BrokerConfig(
        enabled=True,
        base_url="https://broker.example.com/prefix",
        scope="global",
        username="global",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
        websocket_path=override_path,
    )
    expected_url = (
        "wss://broker.example.com/prefix/custom/broker/socket"
        "?provider_id=provider-1"
        "&client_id=client-1"
        "&username=global"
        "&registration_token=token-123"
    )

    runtime = build_broker_runtime_config(broker_config)

    assert runtime.websocket_path == "/custom/broker/socket"
    assert runtime.websocket_url == expected_url
def test_build_broker_runtime_config_normalizes_manual_websocket_path_override_without_leading_slash():
    """A ``websocket_path`` lacking a leading slash is kept as given but joined correctly into the URL."""
    settings = dict(
        enabled=True,
        base_url="https://broker.example.com",
        scope="user",
        username="alice",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
        websocket_path="broker/ws",
    )

    runtime = build_broker_runtime_config(BrokerConfig(**settings))

    # The stored path is untouched; only the assembled URL gains the separator.
    assert runtime.websocket_path == "broker/ws"
    assert runtime.websocket_url.startswith("wss://broker.example.com/broker/ws")
def test_build_broker_runtime_config_rejects_invalid_user_scope_username():
try:
build_broker_runtime_config(
......@@ -294,11 +396,17 @@ def test_build_register_message_includes_capabilities_and_hardware():
"v": 1,
"op": "register",
"request_id": "req-1",
"registration_token": "token-123",
"capabilities": capabilities,
"payload": {
"endpoint": "https://server.example.com/alice",
"transport": "websocket",
"registration_token": "token-123",
"hardware": {"gpu": True, "memory_gb": 24},
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0,
"studio_endpoints": EXPECTED_STUDIO_ENDPOINTS,
"capabilities": capabilities,
},
......@@ -318,17 +426,65 @@ def test_build_register_message_defaults_token_and_studio_endpoints_for_empty_ru
"v": 1,
"op": "register",
"request_id": "req-1",
"registration_token": "",
"capabilities": {"server": "codai"},
"payload": {
"endpoint": "",
"transport": "websocket",
"registration_token": "",
"hardware": None,
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0,
"studio_endpoints": DEFAULT_STUDIO_ENDPOINTS,
"capabilities": {"server": "codai"},
},
}
def test_build_hardware_summary_reports_torch_cuda_vram(monkeypatch):
    """With a stubbed ``torch`` exposing one CUDA device, the summary reports its VRAM in MB."""
    gib = 1024 * 1024 * 1024

    class StubDeviceProps:
        total_memory = 24 * gib

    class StubCuda:
        @staticmethod
        def is_available():
            return True

        @staticmethod
        def device_count():
            return 1

        @staticmethod
        def current_device():
            return 0

        @staticmethod
        def get_device_properties(index):
            return StubDeviceProps()

        @staticmethod
        def get_device_name(index):
            return "RTX Test"

        @staticmethod
        def mem_get_info():
            # (free bytes, total bytes)
            return (10 * gib, 24 * gib)

    class StubTorch:
        cuda = StubCuda()

    # Replace the real torch module (if any) for the duration of this test.
    monkeypatch.setitem(sys.modules, "torch", StubTorch())

    summary = build_hardware_summary()

    assert summary["gpu_count"] == 1
    assert summary["total_vram_mb"] == 24576
    assert summary["available_vram_mb"] == 10240
    assert summary["gpus"] == [{"index": 0, "name": "RTX Test", "total_vram_mb": 24576}]
def test_build_capabilities_document_lists_openai_and_studio_support():
document = build_capabilities_document(hardware={"gpu": True})
......
......@@ -6,6 +6,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import pytest
from fastapi import FastAPI
from fastapi import File, Form, UploadFile
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from starlette.responses import Response
......@@ -87,6 +88,7 @@ async def test_execute_broker_request_returns_success_envelope_for_json_route():
envelope = BrokerRequestEnvelope(
request_id="req-123",
op="chat.completions",
method="POST",
path="/v1/chat/completions",
headers={"accept": "application/json"},
......@@ -96,7 +98,7 @@ async def test_execute_broker_request_returns_success_envelope_for_json_route():
response = await execute_broker_request(app, envelope)
assert response["request_id"] == "req-123"
assert response["ok"] is True
assert response["status"] == "ok"
assert response["payload"] == {
"status_code": 201,
"headers": {
......@@ -125,6 +127,7 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
envelope = BrokerRequestEnvelope(
request_id="req-binary",
op="proxy",
method="GET",
path="/v1/images/render",
)
......@@ -132,7 +135,7 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
response = await execute_broker_request(app, envelope)
assert response["request_id"] == "req-binary"
assert response["ok"] is True
assert response["status"] == "ok"
assert response["payload"] == {
"status_code": 200,
"headers": {
......@@ -148,12 +151,75 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
assert response["metrics"]["elapsed_ms"] >= 0
@pytest.mark.anyio("asyncio")
async def test_execute_broker_request_maps_proxy_operation_to_internal_route():
    """A ``proxy`` envelope with no top-level path is routed via ``payload.endpoint_path``."""
    app = FastAPI()

    @app.post("/v1/video/dub")
    async def handle_dub(payload: dict):
        return {"received": payload, "route": "dub"}

    proxy_payload = {
        "endpoint_path": "v1/video/dub",
        "method": "POST",
        "headers": {"content-type": "application/json"},
        "body": {"prompt": "hello"},
    }
    envelope = BrokerRequestEnvelope(
        request_id="req-proxy-op",
        op="proxy",
        payload=proxy_payload,
    )

    response = await execute_broker_request(app, envelope)
    result = response["payload"]

    assert response["status"] == "ok"
    assert result["status_code"] == 200
    assert result["body"] == '{"received":{"prompt":"hello"},"route":"dub"}'
@pytest.mark.anyio("asyncio")
async def test_execute_broker_request_supports_proxy_multipart_payloads():
    """A ``proxy`` envelope carrying ``multipart`` fields and base64 files reaches a form route."""
    app = FastAPI()

    @app.post("/v1/audio/transcriptions")
    async def handle_transcription(model: str = Form(...), file: UploadFile = File(...)):
        contents = await file.read()
        return {"model": model, "filename": file.filename, "size": len(contents)}

    form_fields = [{"name": "model", "value": "whisper-large"}]
    uploaded_files = [
        {
            "name": "file",
            "filename": "sample.wav",
            "content_type": "audio/wav",
            "data_base64": "aGVsbG8=",
        }
    ]
    envelope = BrokerRequestEnvelope(
        request_id="req-multipart",
        op="proxy",
        payload={
            "endpoint_path": "v1/audio/transcriptions",
            "method": "POST",
            "multipart": {"fields": form_fields, "files": uploaded_files},
        },
    )

    response = await execute_broker_request(app, envelope)
    result = response["payload"]

    assert response["status"] == "ok"
    assert result["status_code"] == 200
    assert result["body"] == '{"model":"whisper-large","filename":"sample.wav","size":5}'
@pytest.mark.anyio("asyncio")
async def test_brokered_models_match_direct_http_response_shape():
direct_response = TestClient(real_app).get("/v1/models")
envelope = BrokerRequestEnvelope(
request_id="req-models-shape",
op="models.list",
method="GET",
path="/v1/models",
headers={"accept": "application/json"},
......@@ -163,7 +229,7 @@ async def test_brokered_models_match_direct_http_response_shape():
brokered_body = json.loads(brokered_response["payload"]["body"])
direct_body = direct_response.json()
assert brokered_response["ok"] is True
assert brokered_response["status"] == "ok"
assert brokered_response["payload"]["status_code"] == direct_response.status_code
assert brokered_response["payload"]["content_type"] == direct_response.headers["content-type"]
assert brokered_response["payload"]["headers"]["content-type"] == direct_response.headers["content-type"]
......@@ -192,6 +258,7 @@ async def test_execute_broker_request_rejects_unsupported_endpoint():
app = FastAPI()
envelope = BrokerRequestEnvelope(
request_id="req-unsupported",
op="proxy",
method="GET",
path="/internal",
)
......@@ -199,8 +266,9 @@ async def test_execute_broker_request_rejects_unsupported_endpoint():
response = await execute_broker_request(app, envelope)
assert response == {
"v": 1,
"request_id": "req-unsupported",
"ok": False,
"status": "error",
"error": {
"code": "unsupported_endpoint",
"message": "Unsupported endpoint: /internal",
......
This diff is collapsed.
This diff is collapsed.
......@@ -849,7 +849,10 @@ def test_settings_template_includes_broker_controls():
assert "AISBF Broker" in template
assert "s-broker-enabled" in template
assert "s-broker-base-url" in template
assert 'id="s-broker-scope" class="form-input" onchange="toggleBrokerFields()"' in template
assert "s-broker-provider-id" in template
assert "s-broker-client-id" in template
assert "s-broker-registration-token" in template
assert "s-broker-websocket-path" in template
assert "toggleBrokerFields()" in template
assert "forced to `global` for global scope" in template
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment