New version!

parent 3716e54b
This diff is collapsed.
...@@ -81,6 +81,7 @@ body,.table-wrap,.modal-box,.chat-messages,.studio .model-list,.studio .chat-msg ...@@ -81,6 +81,7 @@ body,.table-wrap,.modal-box,.chat-messages,.studio .model-list,.studio .chat-msg
/* ── Main ────────────────────────────────────────────────────────── */ /* ── Main ────────────────────────────────────────────────────────── */
.main{min-height:calc(100vh - 44px)} .main{min-height:calc(100vh - 44px)}
.container{max-width:1100px;margin:0 auto;padding:2rem 1.5rem} .container{max-width:1100px;margin:0 auto;padding:2rem 1.5rem}
.container--full{max-width:100%;padding:2rem 1.5rem}
/* ── Page header ─────────────────────────────────────────────────── */ /* ── Page header ─────────────────────────────────────────────────── */
.page-header{display:flex;justify-content:space-between;align-items:flex-start;margin-bottom:1.5rem;gap:1rem} .page-header{display:flex;justify-content:space-between;align-items:flex-start;margin-bottom:1.5rem;gap:1rem}
...@@ -125,6 +126,8 @@ body,.table-wrap,.modal-box,.chat-messages,.studio .model-list,.studio .chat-msg ...@@ -125,6 +126,8 @@ body,.table-wrap,.modal-box,.chat-messages,.studio .model-list,.studio .chat-msg
.btn-ghost:hover{color:var(--text);border-color:var(--border-2)} .btn-ghost:hover{color:var(--text);border-color:var(--border-2)}
.btn-danger{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)} .btn-danger{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
.btn-danger:hover{background:rgba(248,113,113,.15);border-color:rgba(248,113,113,.4)} .btn-danger:hover{background:rgba(248,113,113,.15);border-color:rgba(248,113,113,.4)}
.btn-warn{background:rgba(251,191,36,.08);color:#f59e0b;border:1px solid rgba(251,191,36,.25)}
.btn-warn:hover{background:rgba(251,191,36,.15);border-color:rgba(251,191,36,.45)}
.btn-sm{padding:.25rem .625rem;font-size:12px} .btn-sm{padding:.25rem .625rem;font-size:12px}
.btn-sm svg{width:11px;height:11px} .btn-sm svg{width:11px;height:11px}
.btn:disabled{opacity:.4;cursor:not-allowed} .btn:disabled{opacity:.4;cursor:not-allowed}
...@@ -176,6 +179,7 @@ td code{font-family:var(--mono);font-size:11.5px;background:var(--raised);paddin ...@@ -176,6 +179,7 @@ td code{font-family:var(--mono);font-size:11.5px;background:var(--raised);paddin
.badge-user{background:var(--raised);color:var(--text-3);border:1px solid var(--border)} .badge-user{background:var(--raised);color:var(--text-3);border:1px solid var(--border)}
.badge-ok{background:rgba(52,211,153,.08);color:var(--green);border:1px solid rgba(52,211,153,.2)} .badge-ok{background:rgba(52,211,153,.08);color:var(--green);border:1px solid rgba(52,211,153,.2)}
.badge-warn{background:rgba(251,191,36,.08);color:#f59e0b;border:1px solid rgba(251,191,36,.2)} .badge-warn{background:rgba(251,191,36,.08);color:#f59e0b;border:1px solid rgba(251,191,36,.2)}
.badge-err{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
.badge-danger{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)} .badge-danger{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
/* ── Modals ──────────────────────────────────────────────────────── */ /* ── Modals ──────────────────────────────────────────────────────── */
...@@ -276,4 +280,5 @@ hr{border:none;border-top:1px solid var(--border);margin:1.125rem 0} ...@@ -276,4 +280,5 @@ hr{border:none;border-top:1px solid var(--border);margin:1.125rem 0}
.nav-links{gap:0} .nav-links{gap:0}
.nav-link{padding:.3rem .5rem;font-size:12.5px} .nav-link{padding:.3rem .5rem;font-size:12.5px}
.container{padding:1.25rem 1rem} .container{padding:1.25rem 1rem}
.container--full{padding:1.25rem 1rem}
} }
This diff is collapsed.
...@@ -120,7 +120,7 @@ ...@@ -120,7 +120,7 @@
<div style="display:grid;grid-template-columns:180px 1fr;gap:1rem;align-items:start"> <div style="display:grid;grid-template-columns:180px 1fr;gap:1rem;align-items:start">
<div class="form-row" style="margin:0"> <div class="form-row" style="margin:0">
<label class="form-label">Scope</label> <label class="form-label">Scope</label>
<select id="s-broker-scope" class="form-input"> <select id="s-broker-scope" class="form-input" onchange="toggleBrokerFields()">
<option value="user">user</option> <option value="user">user</option>
<option value="global">global</option> <option value="global">global</option>
</select> </select>
...@@ -128,7 +128,7 @@ ...@@ -128,7 +128,7 @@
<div class="form-row" style="margin:0"> <div class="form-row" style="margin:0">
<label class="form-label">Username</label> <label class="form-label">Username</label>
<input type="text" id="s-broker-username" class="form-input" placeholder="alice or global"> <input type="text" id="s-broker-username" class="form-input" placeholder="alice or global">
<span class="form-hint">Use `global` when scope is `global`; otherwise provide the AISBF username.</span> <span class="form-hint">This is forced to `global` for global scope; user scope requires the AISBF username.</span>
</div> </div>
</div> </div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:1rem;align-items:start"> <div style="display:grid;grid-template-columns:1fr 1fr;gap:1rem;align-items:start">
...@@ -150,6 +150,11 @@ ...@@ -150,6 +150,11 @@
<input type="text" id="s-broker-advertised-endpoint" class="form-input" placeholder="http://127.0.0.1:8776"> <input type="text" id="s-broker-advertised-endpoint" class="form-input" placeholder="http://127.0.0.1:8776">
<span class="form-hint">Optional external URL advertised to the broker for this instance.</span> <span class="form-hint">Optional external URL advertised to the broker for this instance.</span>
</div> </div>
<div class="form-row">
<label class="form-label">Websocket path override</label>
<input type="text" id="s-broker-websocket-path" class="form-input" placeholder="/api/coderai/wss">
<span class="form-hint">Optional manual websocket route override for proxied or custom broker deployments; leave empty to derive from scope.</span>
</div>
<div style="display:grid;grid-template-columns:repeat(3, minmax(0, 1fr));gap:1rem;align-items:start"> <div style="display:grid;grid-template-columns:repeat(3, minmax(0, 1fr));gap:1rem;align-items:start">
<div class="form-row" style="margin:0"> <div class="form-row" style="margin:0">
<label class="form-label">Heartbeat seconds</label> <label class="form-label">Heartbeat seconds</label>
...@@ -164,7 +169,7 @@ ...@@ -164,7 +169,7 @@
<input type="number" id="s-broker-request-timeout" class="form-input" min="1" placeholder="30"> <input type="number" id="s-broker-request-timeout" class="form-input" min="1" placeholder="30">
</div> </div>
</div> </div>
<div style="display:grid;grid-template-columns:repeat(2, minmax(0, 1fr));gap:1rem;align-items:start"> <div style="display:grid;grid-template-columns:repeat(3, minmax(0, 1fr));gap:1rem;align-items:start">
<div class="form-row" style="margin:0"> <div class="form-row" style="margin:0">
<label class="form-label">Reconnect initial delay</label> <label class="form-label">Reconnect initial delay</label>
<input type="number" id="s-broker-reconnect-initial" class="form-input" min="1" placeholder="1"> <input type="number" id="s-broker-reconnect-initial" class="form-input" min="1" placeholder="1">
...@@ -173,6 +178,11 @@ ...@@ -173,6 +178,11 @@
<label class="form-label">Reconnect max delay</label> <label class="form-label">Reconnect max delay</label>
<input type="number" id="s-broker-reconnect-max" class="form-input" min="1" placeholder="60"> <input type="number" id="s-broker-reconnect-max" class="form-input" min="1" placeholder="60">
</div> </div>
<div class="form-row" style="margin:0">
<label class="form-label">WS ping interval (s)</label>
<input type="number" id="s-broker-ws-ping" class="form-input" min="5" placeholder="20">
<span class="form-hint" style="font-size:11px">Keeps connection alive through nginx proxies. Lower if you get 504 timeouts.</span>
</div>
</div> </div>
</div> </div>
</div> </div>
...@@ -188,6 +198,16 @@ function toggleHttps(){ ...@@ -188,6 +198,16 @@ function toggleHttps(){
/**
 * Sync the broker section of the settings form with its controls.
 *
 * Shows the broker field group only while the "enabled" checkbox is ticked,
 * and locks the username input to the literal value `global` whenever the
 * scope select is set to `global`. Switching back to another scope unlocks
 * the input and clears a leftover `global` value so a real username can be
 * typed.
 */
function toggleBrokerFields(){
  const byId = (id) => document.getElementById(id);

  byId('broker-fields').style.display =
    byId('s-broker-enabled').checked ? 'block' : 'none';

  const isGlobalScope = byId('s-broker-scope').value === 'global';
  const userField = byId('s-broker-username');

  if(isGlobalScope){
    // Global scope always registers under the fixed "global" username.
    userField.value = 'global';
    userField.readOnly = true;
    return;
  }

  userField.readOnly = false;
  if(userField.value === 'global'){
    // Drop the placeholder identity left behind by a previous global scope.
    userField.value = '';
  }
}
function showAlert(type, msg){ function showAlert(type, msg){
...@@ -232,11 +252,13 @@ async function loadSettings(){ ...@@ -232,11 +252,13 @@ async function loadSettings(){
document.getElementById('s-broker-client-id').value = broker.client_id ?? ''; document.getElementById('s-broker-client-id').value = broker.client_id ?? '';
document.getElementById('s-broker-registration-token').value = broker.registration_token ?? ''; document.getElementById('s-broker-registration-token').value = broker.registration_token ?? '';
document.getElementById('s-broker-advertised-endpoint').value = broker.advertised_endpoint ?? ''; document.getElementById('s-broker-advertised-endpoint').value = broker.advertised_endpoint ?? '';
document.getElementById('s-broker-websocket-path').value = broker.websocket_path ?? '';
document.getElementById('s-broker-heartbeat').value = broker.heartbeat_interval_seconds ?? 30; document.getElementById('s-broker-heartbeat').value = broker.heartbeat_interval_seconds ?? 30;
document.getElementById('s-broker-connect-timeout').value = broker.connect_timeout_seconds ?? 10; document.getElementById('s-broker-connect-timeout').value = broker.connect_timeout_seconds ?? 10;
document.getElementById('s-broker-request-timeout').value = broker.request_timeout_seconds ?? 30; document.getElementById('s-broker-request-timeout').value = broker.request_timeout_seconds ?? 30;
document.getElementById('s-broker-reconnect-initial').value = broker.reconnect_initial_delay_seconds ?? 1; document.getElementById('s-broker-reconnect-initial').value = broker.reconnect_initial_delay_seconds ?? 1;
document.getElementById('s-broker-reconnect-max').value = broker.reconnect_max_delay_seconds ?? 60; document.getElementById('s-broker-reconnect-max').value = broker.reconnect_max_delay_seconds ?? 60;
document.getElementById('s-broker-ws-ping').value = broker.websocket_ping_interval ?? 20;
toggleBrokerFields(); toggleBrokerFields();
}catch(e){ showAlert('error','Failed to load settings: '+e.message); } }catch(e){ showAlert('error','Failed to load settings: '+e.message); }
} }
...@@ -273,11 +295,13 @@ async function saveSettings(){ ...@@ -273,11 +295,13 @@ async function saveSettings(){
client_id: document.getElementById('s-broker-client-id').value.trim(), client_id: document.getElementById('s-broker-client-id').value.trim(),
registration_token: document.getElementById('s-broker-registration-token').value.trim(), registration_token: document.getElementById('s-broker-registration-token').value.trim(),
advertised_endpoint: document.getElementById('s-broker-advertised-endpoint').value.trim(), advertised_endpoint: document.getElementById('s-broker-advertised-endpoint').value.trim(),
websocket_path: document.getElementById('s-broker-websocket-path').value.trim(),
heartbeat_interval_seconds: parseInt(document.getElementById('s-broker-heartbeat').value) || 30, heartbeat_interval_seconds: parseInt(document.getElementById('s-broker-heartbeat').value) || 30,
connect_timeout_seconds: parseInt(document.getElementById('s-broker-connect-timeout').value) || 10, connect_timeout_seconds: parseInt(document.getElementById('s-broker-connect-timeout').value) || 10,
request_timeout_seconds: parseInt(document.getElementById('s-broker-request-timeout').value) || 30, request_timeout_seconds: parseInt(document.getElementById('s-broker-request-timeout').value) || 30,
reconnect_initial_delay_seconds: parseInt(document.getElementById('s-broker-reconnect-initial').value) || 1, reconnect_initial_delay_seconds: parseInt(document.getElementById('s-broker-reconnect-initial').value) || 1,
reconnect_max_delay_seconds: parseInt(document.getElementById('s-broker-reconnect-max').value) || 60, reconnect_max_delay_seconds: parseInt(document.getElementById('s-broker-reconnect-max').value) || 60,
websocket_ping_interval: parseInt(document.getElementById('s-broker-ws-ping').value) || 20,
transport: 'websocket', transport: 'websocket',
}, },
}; };
......
...@@ -43,12 +43,23 @@ global_file_path = None ...@@ -43,12 +43,23 @@ global_file_path = None
_aud_progress: dict = { _aud_progress: dict = {
"current": 0, "total": 0, "active": False, "current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0, "unit": "it", "started_at": 0.0, "it_per_s": 0.0, "unit": "it",
"phase": "idle", "model": "",
} }
def _aud_progress_loading(model_name: str = ""):
    """Put the shared audio progress tracker into the "loading" phase.

    Marks the tracker active, zeroes the step counters and the rate,
    restarts the monotonic start timestamp and records the model being
    loaded. The "unit" entry is left as it was.

    Args:
        model_name: Name of the model being loaded; any falsy value is
            stored as an empty string.
    """
    _aud_progress.update(
        {
            "phase": "loading",
            "active": True,
            "current": 0,
            "total": 0,
            "it_per_s": 0.0,
            "started_at": time.monotonic(),
            "model": model_name if model_name else "",
        }
    )
def _aud_progress_reset(total: int, unit: str = "it"):
    """Start a fresh "generating" run in the shared audio progress tracker.

    Args:
        total: Number of steps the run is expected to take.
        unit: Label for one step (defaults to "it").

    The "model" entry is not modified here.
    """
    fresh_run = {
        "current": 0,
        "total": total,
        "active": True,
        "phase": "generating",
        "started_at": time.monotonic(),
        "it_per_s": 0.0,
        "unit": unit,
    }
    _aud_progress.update(fresh_run)
...@@ -56,6 +67,7 @@ def _aud_progress_reset(total: int, unit: str = "it"): ...@@ -56,6 +67,7 @@ def _aud_progress_reset(total: int, unit: str = "it"):
def _aud_progress_done():
    """Mark the current audio run as finished and return the tracker to idle.

    The step counter is raised to the total if it is still below it, so a
    completed run never reports less than 100%; it is never lowered.
    """
    final_step = max(_aud_progress["current"], _aud_progress["total"])
    _aud_progress.update(current=final_step, active=False, phase="idle")
def _aud_progress_step(step: int): def _aud_progress_step(step: int):
_aud_progress["current"] = step _aud_progress["current"] = step
...@@ -196,6 +208,8 @@ async def get_audio_progress(): ...@@ -196,6 +208,8 @@ async def get_audio_progress():
"current": current, "current": current,
"total": total, "total": total,
"active": _aud_progress["active"], "active": _aud_progress["active"],
"phase": _aud_progress.get("phase", "idle"),
"model": _aud_progress.get("model", ""),
"pct": int(current / total * 100) if total > 0 else 0, "pct": int(current / total * 100) if total > 0 else 0,
"it_per_s": _aud_progress["it_per_s"], "it_per_s": _aud_progress["it_per_s"],
"elapsed": round(elapsed, 1), "elapsed": round(elapsed, 1),
...@@ -209,6 +223,7 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request ...@@ -209,6 +223,7 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
Generate music, sound effects, or ambient audio. Generate music, sound effects, or ambient audio.
Compatible models: MusicGen, AudioGen, AudioLDM2, StableAudio. Compatible models: MusicGen, AudioGen, AudioLDM2, StableAudio.
""" """
_aud_progress_loading(request.model or "audio")
model_info = multi_model_manager.request_model(request.model, model_type="audio_gen") model_info = multi_model_manager.request_model(request.model, model_type="audio_gen")
model_name = model_info.get('model_name') model_name = model_info.get('model_name')
if not model_name: if not model_name:
......
...@@ -123,18 +123,30 @@ import time as _time ...@@ -123,18 +123,30 @@ import time as _time
_gen_progress: dict = { _gen_progress: dict = {
"current": 0, "total": 0, "active": False, "current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0, "started_at": 0.0, "it_per_s": 0.0,
"phase": "idle", "model": "",
} }
def _progress_loading(model_name: str = ""):
    """Put the shared image progress tracker into the "loading" phase.

    Marks the tracker active, zeroes the step counters and the rate,
    restarts the monotonic start timestamp and records the model being
    loaded.

    Args:
        model_name: Name of the model being loaded; any falsy value is
            stored as an empty string.
    """
    _gen_progress.update(
        {
            "phase": "loading",
            "active": True,
            "current": 0,
            "total": 0,
            "it_per_s": 0.0,
            "started_at": _time.monotonic(),
            "model": model_name if model_name else "",
        }
    )
def _progress_reset(total: int):
    """Start a fresh "generating" run in the shared image progress tracker.

    Args:
        total: Number of steps the run is expected to take.

    The "model" entry is not modified here.
    """
    fresh_run = {
        "current": 0,
        "total": total,
        "active": True,
        "phase": "generating",
        "started_at": _time.monotonic(),
        "it_per_s": 0.0,
    }
    _gen_progress.update(fresh_run)
def _progress_done():
    """Mark the current image run as finished and return the tracker to idle.

    The step counter is set to the recorded total, whatever its previous
    value was.
    """
    _gen_progress.update(
        current=_gen_progress["total"],
        active=False,
        phase="idle",
    )
def _progress_step(step: int): def _progress_step(step: int):
_gen_progress["current"] = step _gen_progress["current"] = step
...@@ -894,6 +906,8 @@ async def get_image_progress(): ...@@ -894,6 +906,8 @@ async def get_image_progress():
"current": _gen_progress["current"], "current": _gen_progress["current"],
"total": _gen_progress["total"], "total": _gen_progress["total"],
"active": _gen_progress["active"], "active": _gen_progress["active"],
"phase": _gen_progress.get("phase", "idle"),
"model": _gen_progress.get("model", ""),
"pct": int(_gen_progress["current"] / _gen_progress["total"] * 100) "pct": int(_gen_progress["current"] / _gen_progress["total"] * 100)
if _gen_progress["total"] > 0 else 0, if _gen_progress["total"] > 0 else 0,
"it_per_s": _gen_progress["it_per_s"], "it_per_s": _gen_progress["it_per_s"],
...@@ -944,6 +958,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -944,6 +958,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# ===================================================================== # =====================================================================
# Step 1: Ask the manager to resolve the model and manage VRAM # Step 1: Ask the manager to resolve the model and manage VRAM
# ===================================================================== # =====================================================================
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model( model_info = multi_model_manager.request_model(
requested_model=request.model, requested_model=request.model,
model_type="image" model_type="image"
...@@ -1173,6 +1188,7 @@ async def create_image_edit(request: ImageEditRequest, http_request: Request = N ...@@ -1173,6 +1188,7 @@ async def create_image_edit(request: ImageEditRequest, http_request: Request = N
if not request.image: if not request.image:
raise HTTPException(status_code=400, detail="image is required") raise HTTPException(status_code=400, detail="image is required")
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image") model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name') model_name = model_info.get('model_name')
if not model_name: if not model_name:
...@@ -1306,6 +1322,7 @@ async def create_image_inpaint(request: ImageInpaintRequest, http_request: Reque ...@@ -1306,6 +1322,7 @@ async def create_image_inpaint(request: ImageInpaintRequest, http_request: Reque
global global_args global global_args
if not request.image or not request.mask: if not request.image or not request.mask:
raise HTTPException(status_code=400, detail="image and mask are required") raise HTTPException(status_code=400, detail="image and mask are required")
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image") model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name') model_name = model_info.get('model_name')
if not model_name: if not model_name:
...@@ -1414,6 +1431,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int): ...@@ -1414,6 +1431,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int):
async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None): async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None):
"""Upscale an image using Real-ESRGAN or PIL LANCZOS fallback.""" """Upscale an image using Real-ESRGAN or PIL LANCZOS fallback."""
global global_args global global_args
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image") model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name') or request.model model_name = model_info.get('model_name') or request.model
model_key = f"upscale:{model_name}" model_key = f"upscale:{model_name}"
......
...@@ -41,6 +41,11 @@ class BearerAuthMiddleware(BaseHTTPMiddleware): ...@@ -41,6 +41,11 @@ class BearerAuthMiddleware(BaseHTTPMiddleware):
if not path.startswith("/v1/") or path in self._EXEMPT_PATHS: if not path.startswith("/v1/") or path in self._EXEMPT_PATHS:
return await call_next(request) return await call_next(request)
# Requests from the ASGI broker bridge are in-process and have no real
# Bearer token. Identify them by the sentinel server tuple set in asgi_bridge.py.
if request.scope.get("server") == ("internal", 80):
return await call_next(request)
from codai.admin import routes as _admin_routes from codai.admin import routes as _admin_routes
sm = _admin_routes.session_manager sm = _admin_routes.session_manager
if sm is None: if sm is None:
...@@ -98,6 +103,7 @@ class _Bucket: ...@@ -98,6 +103,7 @@ class _Bucket:
self.window_start = now self.window_start = now
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
"""Apply per-IP, per-route-prefix rate limiting to API endpoints.""" """Apply per-IP, per-route-prefix rate limiting to API endpoints."""
......
This diff is collapsed.
...@@ -41,11 +41,44 @@ except (ImportError, AttributeError): ...@@ -41,11 +41,44 @@ except (ImportError, AttributeError):
try: try:
from llama_cpp import Llama from llama_cpp import Llama
from llama_cpp.llama_chat_format import ChatFormatterResponse from llama_cpp.llama_chat_format import ChatFormatterResponse
import llama_cpp as _llama_cpp
LLAMA_CPP_AVAILABLE = True LLAMA_CPP_AVAILABLE = True
except ImportError: except ImportError:
LLAMA_CPP_AVAILABLE = False LLAMA_CPP_AVAILABLE = False
Llama = None Llama = None
ChatFormatterResponse = None ChatFormatterResponse = None
_llama_cpp = None
def _install_layer_log_callback():
    """Route llama.cpp load-time layer/buffer log lines to stdout.

    Registers a log callback through ``llama_log_set`` that prints only the
    messages containing one of the load-phase keywords below.

    Returns:
        The registered callback object, or ``None`` when llama_cpp is not
        importable or the callback could not be installed. The caller must
        keep the returned object referenced for as long as it is installed
        so that ctypes does not garbage-collect it.
    """
    if _llama_cpp is None:
        return None

    # Substrings that mark a log line as relevant to layer/buffer placement.
    load_keywords = (
        'llm_load_tensors', 'llm_load_print_meta',
        'offload', 'layer', 'buffer size', 'buffer type',
        'GPU', 'CUDA', 'Vulkan', 'Metal', 'ROCm', 'SYCL',
        'CPU', 'VRAM', 'n_layer', 'n_gpu_layers',
    )

    def _forward_to_stdout(level, text, user_data):
        try:
            if isinstance(text, bytes):
                decoded = text.decode('utf-8', errors='replace')
            else:
                decoded = str(text)
            line = decoded.rstrip()
            if not line:
                return
            if any(keyword in line for keyword in load_keywords):
                print(f" [llama.cpp] {line}", flush=True)
        except Exception:
            # Best effort: a failure while logging must never disturb the load.
            pass

    _cb = _llama_cpp.llama_log_callback(_forward_to_stdout)

    try:
        _llama_cpp.llama_log_set(_cb, None)
    except Exception:
        return None
    return _cb  # caller must hold this reference
class VulkanBackend(ModelBackend): class VulkanBackend(ModelBackend):
...@@ -450,17 +483,18 @@ class VulkanBackend(ModelBackend): ...@@ -450,17 +483,18 @@ class VulkanBackend(ModelBackend):
# Try to find GGUF files in the repository # Try to find GGUF files in the repository
try: try:
from huggingface_hub import list_repo_files, hf_hub_download from huggingface_hub import list_repo_files, hf_hub_download
from codai.models.cache import get_hf_hub_cache_dir
print(f"DEBUG: Searching for GGUF files in {model_path}...") print(f"DEBUG: Searching for GGUF files in {model_path}...")
files = list(list_repo_files(model_path, repo_type="model")) files = list(list_repo_files(model_path, repo_type="model"))
gguf_files = [f for f in files if f.lower().endswith('.gguf')] gguf_files = [f for f in files if f.lower().endswith('.gguf')]
if gguf_files: if gguf_files:
# Prefer Q4_K_M or Q4_K quantizations, otherwise use first available # Prefer Q4_K_M or Q4_K quantizations, otherwise use first available
preferred = [f for f in gguf_files if 'q4_k_m' in f.lower() or 'q4_k' in f.lower()] preferred = [f for f in gguf_files if 'q4_k_m' in f.lower() or 'q4_k' in f.lower()]
selected = preferred[0] if preferred else gguf_files[0] selected = preferred[0] if preferred else gguf_files[0]
print(f"DEBUG: Found GGUF files: {gguf_files}") print(f"DEBUG: Found GGUF files: {gguf_files}")
print(f"DEBUG: Selected: {selected}") print(f"DEBUG: Selected: {selected}")
model_path = hf_hub_download(repo_id=model_path, filename=selected, cache_dir=kwargs.get('cache_dir')) model_path = hf_hub_download(repo_id=model_path, filename=selected, cache_dir=kwargs.get('cache_dir') or get_hf_hub_cache_dir())
print(f"DEBUG: Downloaded: {model_path}") print(f"DEBUG: Downloaded: {model_path}")
else: else:
print(f"Warning: No GGUF files found in {model_path}, trying direct download...") print(f"Warning: No GGUF files found in {model_path}, trying direct download...")
...@@ -471,8 +505,9 @@ class VulkanBackend(ModelBackend): ...@@ -471,8 +505,9 @@ class VulkanBackend(ModelBackend):
# Try to get from HuggingFace # Try to get from HuggingFace
try: try:
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from codai.models.cache import get_hf_hub_cache_dir
# Download the GGUF file # Download the GGUF file
model_path = hf_hub_download(repo_id=model_path, filename="*.gguf", cache_dir=kwargs.get('cache_dir')) model_path = hf_hub_download(repo_id=model_path, filename="*.gguf", cache_dir=kwargs.get('cache_dir') or get_hf_hub_cache_dir())
except Exception as e: except Exception as e:
print(f"Warning: Could not download from HuggingFace: {e}") print(f"Warning: Could not download from HuggingFace: {e}")
# Try as-is # Try as-is
...@@ -557,18 +592,41 @@ class VulkanBackend(ModelBackend): ...@@ -557,18 +592,41 @@ class VulkanBackend(ModelBackend):
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# llama-cpp-python will use CUDA when available # llama-cpp-python will use CUDA when available
# Pre-load summary
gpu_label = "all" if self.n_gpu_layers == -1 else str(self.n_gpu_layers)
print(f" n_gpu_layers : {gpu_label} | n_ctx : {self.n_ctx} | main_gpu : {self.main_gpu}")
if _llama_cpp:
gpu_supported = _llama_cpp.llama_supports_gpu_offload()
print(f" GPU offload : {'supported' if gpu_supported else 'NOT supported by this build'}")
_log_cb = _install_layer_log_callback()
try: try:
self.model = Llama(**llama_kwargs) self.model = Llama(**llama_kwargs)
# Try to detect and set up chat template
self._finalize_chat_template_detection()
print(f"DEBUG: VulkanBackend loaded model: {model_path}")
print(f"DEBUG: n_gpu_layers={self.n_gpu_layers}, n_ctx={self.n_ctx}, no_ram={no_ram}")
print(f"DEBUG: chat_template={self.chat_template}")
except Exception as e: except Exception as e:
print(f"Error loading GGUF model: {e}") print(f"Error loading GGUF model: {e}")
raise raise
finally:
# Restore llama.cpp's default (quiet) logging after load
if _llama_cpp:
try:
_llama_cpp.llama_log_set(None, None)
except Exception:
pass
_log_cb = None # release callback
# Post-load layer/buffer summary
try:
n_total = _llama_cpp.llama_model_n_layer(self.model.model)
n_gpu_actual = n_total if self.n_gpu_layers == -1 else min(self.n_gpu_layers, n_total)
n_cpu = n_total - n_gpu_actual
print(f" Layers total : {n_total}")
print(f" Layers → GPU : {n_gpu_actual} | Layers → CPU : {n_cpu}")
except Exception:
pass
# Try to detect and set up chat template
self._finalize_chat_template_detection()
print(f" chat_template: {self.chat_template}")
def generate( def generate(
self, self,
......
"""Helpers for executing in-process ASGI HTTP requests.""" """Helpers for executing in-process ASGI HTTP requests."""
import base64
import logging
import uuid
from urllib.parse import urlencode from urllib.parse import urlencode
logger = logging.getLogger(__name__)
def _build_multipart_body(multipart):
    """Encode a broker multipart description as a ``multipart/form-data`` body.

    Args:
        multipart: Mapping with two optional keys. ``"fields"`` is a list of
            ``{"name", "value"}`` mappings; ``"files"`` is a list of
            ``{"name", "filename", "content_type", "data_base64"}`` mappings.
            Missing or ``None`` entries are treated as empty, and ``None``
            itself is accepted as an empty description.

    Returns:
        A ``(body, content_type)`` tuple: the encoded body bytes and the
        matching ``Content-Type`` header value, including the boundary.

    Raises:
        binascii.Error: If a file's ``data_base64`` is incorrectly padded.
    """

    def _escape_param(text):
        # Per the HTML multipart/form-data encoding rules, quotes and line
        # breaks inside a name/filename are percent-escaped so they cannot
        # terminate the quoted string or inject extra part headers.
        return text.replace("\r", "%0D").replace("\n", "%0A").replace('"', "%22")

    multipart = multipart or {}
    boundary = f"coderai-broker-{uuid.uuid4().hex}"
    delimiter = f"--{boundary}\r\n".encode("utf-8")
    chunks = []

    for field in multipart.get("fields") or []:
        name = _escape_param(str(field.get("name") or ""))
        raw_value = field.get("value")
        # Only a missing value becomes empty; falsy values such as 0 are kept.
        value = "" if raw_value is None else str(raw_value)
        chunks.append(delimiter)
        chunks.append(f'Content-Disposition: form-data; name="{name}"\r\n\r\n'.encode("utf-8"))
        chunks.append(value.encode("utf-8"))
        chunks.append(b"\r\n")

    for file_entry in multipart.get("files") or []:
        name = _escape_param(str(file_entry.get("name") or "file"))
        filename = _escape_param(str(file_entry.get("filename") or "upload.bin"))
        content_type = str(file_entry.get("content_type") or "application/octet-stream")
        # A header value must stay on one line.
        content_type = content_type.replace("\r", "").replace("\n", "")
        data_base64 = file_entry.get("data_base64") or ""
        file_bytes = base64.b64decode(data_base64) if data_base64 else b""
        chunks.append(delimiter)
        chunks.append(
            f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'.encode("utf-8")
        )
        chunks.append(f"Content-Type: {content_type}\r\n\r\n".encode("utf-8"))
        chunks.append(file_bytes)
        chunks.append(b"\r\n")

    chunks.append(f"--{boundary}--\r\n".encode("utf-8"))
    return b"".join(chunks), f"multipart/form-data; boundary={boundary}"
async def execute_internal_request(app, *, method, path, headers=None, query=None, body=b""): async def execute_internal_request(app, *, method, path, headers=None, query=None, body=b""):
logger.debug(
"ASGI bridge → %s %s query=%s body_bytes=%d",
method.upper(), path, query or {}, len(body),
)
request_headers = [] request_headers = []
for key, value in (headers or {}).items(): for key, value in (headers or {}).items():
request_headers.append((key.lower().encode("latin-1"), str(value).encode("latin-1"))) request_headers.append((key.lower().encode("latin-1"), str(value).encode("latin-1")))
...@@ -49,4 +89,14 @@ async def execute_internal_request(app, *, method, path, headers=None, query=Non ...@@ -49,4 +89,14 @@ async def execute_internal_request(app, *, method, path, headers=None, query=Non
response["body"] += message.get("body", b"") response["body"] += message.get("body", b"")
await app(scope, receive, send) await app(scope, receive, send)
body_preview = response["body"][:200].decode("utf-8", errors="replace") if response["body"] else ""
logger.debug(
"ASGI bridge ← %s %s status=%d content-type=%s body_bytes=%d body_preview=%r",
method.upper(), path,
response["status_code"],
response["headers"].get("content-type", ""),
len(response["body"]),
body_preview,
)
return response return response
"""Broker capability documents and registration payloads.""" """Broker capability documents and registration payloads."""
import glob
import os
import platform import platform
import socket import socket
from typing import Any, Dict, Sequence from typing import Any, Dict, Sequence
...@@ -41,15 +43,67 @@ DEFAULT_STUDIO_ENDPOINTS = [ ...@@ -41,15 +43,67 @@ DEFAULT_STUDIO_ENDPOINTS = [
def build_hardware_summary() -> Dict[str, Any]:
    """Build a conservative hardware summary with VRAM when available.

    Detection is best-effort and never raises.  CUDA devices are enumerated
    through ``torch`` first; when that yields no devices, the
    ``mem_info_vram_total`` / ``mem_info_vram_used`` files under
    ``/sys/class/drm`` are read instead.

    Returns:
        A dict with ``hostname``, ``platform``, ``gpus`` (one entry per
        detected device), ``gpu_count``, ``total_vram_mb`` and
        ``available_vram_mb``.  The GPU fields are empty/zero when no
        device could be detected.
    """
    mib = 1024 * 1024
    gpus = []
    total_vram_mb = 0
    available_vram_mb = 0

    try:
        import torch

        if torch.cuda.is_available():
            for index in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(index)
                device_total_mb = props.total_memory // mib
                if index == torch.cuda.current_device():
                    # Free/total figures are only queried for the active device.
                    free_bytes, total_bytes = torch.cuda.mem_get_info()
                    total_vram_mb = total_bytes // mib
                    available_vram_mb = free_bytes // mib
                gpus.append(
                    {
                        "index": index,
                        "name": torch.cuda.get_device_name(index),
                        "total_vram_mb": device_total_mb,
                    }
                )
            if gpus:
                # Fall back to per-device totals when mem_get_info gave nothing.
                if total_vram_mb == 0:
                    total_vram_mb = sum(gpu["total_vram_mb"] for gpu in gpus)
                if available_vram_mb == 0 and total_vram_mb:
                    available_vram_mb = total_vram_mb
    except Exception:
        # Deliberate best-effort: torch may be missing or CUDA may fail to
        # initialise; whatever was collected so far is kept.
        pass

    if not gpus:
        for total_path in sorted(glob.glob("/sys/class/drm/card*/device/mem_info_vram_total")):
            used_path = total_path.replace("vram_total", "vram_used")
            if not os.path.exists(used_path):
                continue
            try:
                # Context managers make sure the sysfs handles are closed.
                with open(total_path) as total_file:
                    device_total_mb = int(total_file.read()) // mib
                with open(used_path) as used_file:
                    device_used_mb = int(used_file.read()) // mib
            except (OSError, ValueError):
                # Unreadable or malformed counters: skip this card.
                continue
            card_name = os.path.basename(os.path.dirname(os.path.dirname(total_path)))
            gpus.append(
                {
                    "name": card_name,
                    "total_vram_mb": device_total_mb,
                }
            )
            total_vram_mb += device_total_mb
            available_vram_mb += max(0, device_total_mb - device_used_mb)

    return {
        "hostname": socket.gethostname(),
        "platform": platform.platform(),
        "gpus": gpus,
        "gpu_count": len(gpus),
        "total_vram_mb": total_vram_mb,
        "available_vram_mb": available_vram_mb,
    }
...@@ -64,6 +118,7 @@ def build_capabilities_document( ...@@ -64,6 +118,7 @@ def build_capabilities_document(
"server": "codai", "server": "codai",
"version": version, "version": version,
"transports": ["websocket"], "transports": ["websocket"],
"tunnel_only": True,
"openai_compat": { "openai_compat": {
"chat_completions": True, "chat_completions": True,
"responses": False, "responses": False,
...@@ -88,15 +143,23 @@ def build_register_message( ...@@ -88,15 +143,23 @@ def build_register_message(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Build broker registration frame.""" """Build broker registration frame."""
registration_token = runtime.registration_token
return { return {
"v": 1, "v": 1,
"op": "register", "op": "register",
"request_id": request_id, "request_id": request_id,
"registration_token": registration_token,
"capabilities": capabilities,
"payload": { "payload": {
"endpoint": runtime.advertised_endpoint, "endpoint": runtime.advertised_endpoint,
"transport": runtime.transport, "transport": runtime.transport,
"registration_token": runtime.headers.get("Authorization", "").removeprefix("Bearer "), "registration_token": registration_token,
"hardware": hardware, "hardware": hardware,
"gpus": (hardware or {}).get("gpus", []),
"gpu_count": (hardware or {}).get("gpu_count", 0),
"total_vram_mb": (hardware or {}).get("total_vram_mb", 0),
"available_vram_mb": (hardware or {}).get("available_vram_mb", 0),
"studio_endpoints": list(studio_endpoints or DEFAULT_STUDIO_ENDPOINTS), "studio_endpoints": list(studio_endpoints or DEFAULT_STUDIO_ENDPOINTS),
"capabilities": capabilities, "capabilities": capabilities,
}, },
......
This diff is collapsed.
...@@ -21,12 +21,14 @@ class BrokerConfig: ...@@ -21,12 +21,14 @@ class BrokerConfig:
client_id: str = "" client_id: str = ""
registration_token: str = "" registration_token: str = ""
advertised_endpoint: str = "" advertised_endpoint: str = ""
websocket_path: str = ""
transport: str = "websocket" transport: str = "websocket"
heartbeat_interval_seconds: int = 30 heartbeat_interval_seconds: int = 30
connect_timeout_seconds: int = 10 connect_timeout_seconds: int = 10
request_timeout_seconds: int = 30 request_timeout_seconds: int = 30
reconnect_initial_delay_seconds: int = 1 reconnect_initial_delay_seconds: int = 1
reconnect_max_delay_seconds: int = 60 reconnect_max_delay_seconds: int = 60
websocket_ping_interval: int = 20
@dataclass @dataclass
...@@ -36,13 +38,28 @@ class BrokerRuntimeConfig: ...@@ -36,13 +38,28 @@ class BrokerRuntimeConfig:
enabled: bool enabled: bool
websocket_url: str = "" websocket_url: str = ""
headers: Dict[str, str] = field(default_factory=dict) headers: Dict[str, str] = field(default_factory=dict)
provider_id: str = ""
client_id: str = ""
username: str = ""
registration_token: str = ""
advertised_endpoint: str = "" advertised_endpoint: str = ""
websocket_path: str = ""
transport: str = "websocket" transport: str = "websocket"
heartbeat_interval_seconds: int = 30 heartbeat_interval_seconds: int = 30
connect_timeout_seconds: int = 10 connect_timeout_seconds: int = 10
request_timeout_seconds: int = 30 request_timeout_seconds: int = 30
reconnect_initial_delay_seconds: int = 1 reconnect_initial_delay_seconds: int = 1
reconnect_max_delay_seconds: int = 60 reconnect_max_delay_seconds: int = 60
websocket_ping_interval: int = 20
def _join_broker_path(base_path: str, suffix: str) -> str:
normalized_base = (base_path or "").rstrip("/")
normalized_suffix = suffix if suffix.startswith("/") else f"/{suffix}"
if normalized_base.endswith("/api") and normalized_suffix.startswith("/api/"):
return f"{normalized_base}{normalized_suffix[4:]}"
return f"{normalized_base}{normalized_suffix}" if normalized_base else normalized_suffix
def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig: def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
...@@ -50,18 +67,29 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig: ...@@ -50,18 +67,29 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
runtime = BrokerRuntimeConfig( runtime = BrokerRuntimeConfig(
enabled=config.enabled, enabled=config.enabled,
provider_id=config.provider_id,
client_id=config.client_id,
username=config.username,
registration_token=config.registration_token,
advertised_endpoint=config.advertised_endpoint, advertised_endpoint=config.advertised_endpoint,
websocket_path=config.websocket_path,
transport=config.transport, transport=config.transport,
heartbeat_interval_seconds=config.heartbeat_interval_seconds, heartbeat_interval_seconds=config.heartbeat_interval_seconds,
connect_timeout_seconds=config.connect_timeout_seconds, connect_timeout_seconds=config.connect_timeout_seconds,
request_timeout_seconds=config.request_timeout_seconds, request_timeout_seconds=config.request_timeout_seconds,
reconnect_initial_delay_seconds=config.reconnect_initial_delay_seconds, reconnect_initial_delay_seconds=config.reconnect_initial_delay_seconds,
reconnect_max_delay_seconds=config.reconnect_max_delay_seconds, reconnect_max_delay_seconds=config.reconnect_max_delay_seconds,
websocket_ping_interval=config.websocket_ping_interval,
) )
if not config.enabled: if not config.enabled:
return runtime return runtime
if config.scope == "global": custom_websocket_path = (config.websocket_path or "").strip()
if custom_websocket_path:
suffix = custom_websocket_path
if not suffix.startswith("/"):
suffix = f"/{suffix}"
elif config.scope == "global":
if config.username != "global": if config.username != "global":
raise BrokerConfigError("global broker scope requires username 'global'") raise BrokerConfigError("global broker scope requires username 'global'")
suffix = "/api/coderai/wss" suffix = "/api/coderai/wss"
...@@ -88,7 +116,7 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig: ...@@ -88,7 +116,7 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
scheme = {"http": "ws", "https": "wss"}.get(split_url.scheme, split_url.scheme) scheme = {"http": "ws", "https": "wss"}.get(split_url.scheme, split_url.scheme)
base_path = split_url.path.rstrip("/") base_path = split_url.path.rstrip("/")
path = f"{base_path}{suffix}" if base_path else suffix path = _join_broker_path(base_path, suffix)
query = urlencode( query = urlencode(
{ {
"provider_id": config.provider_id, "provider_id": config.provider_id,
...@@ -103,6 +131,7 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig: ...@@ -103,6 +131,7 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
"x-coderai-provider-id": config.provider_id, "x-coderai-provider-id": config.provider_id,
"x-coderai-client-id": config.client_id, "x-coderai-client-id": config.client_id,
"x-coderai-username": config.username, "x-coderai-username": config.username,
"x-coderai-registration-token": config.registration_token,
"x-coderai-advertised-endpoint": config.advertised_endpoint, "x-coderai-advertised-endpoint": config.advertised_endpoint,
} }
return runtime return runtime
...@@ -3,13 +3,17 @@ ...@@ -3,13 +3,17 @@
from __future__ import annotations from __future__ import annotations
import json import json
import logging
from base64 import b64encode from base64 import b64encode
from base64 import b64decode
from time import perf_counter from time import perf_counter
from typing import Any from typing import Any
from codai.broker.asgi_bridge import execute_internal_request from codai.broker.asgi_bridge import execute_internal_request
from codai.broker.models import error_envelope, success_envelope from codai.broker.models import error_envelope, success_envelope
logger = logging.getLogger(__name__)
SUPPORTED_PREFIXES = ( SUPPORTED_PREFIXES = (
"/v1/models", "/v1/models",
"/v1/chat/completions", "/v1/chat/completions",
...@@ -17,9 +21,18 @@ SUPPORTED_PREFIXES = ( ...@@ -17,9 +21,18 @@ SUPPORTED_PREFIXES = (
"/v1/audio", "/v1/audio",
"/v1/video", "/v1/video",
"/v1/pipelines", "/v1/pipelines",
"/v1/files",
"/coderai/capabilities", "/coderai/capabilities",
"/admin",
"/static",
) )
# Maps a broker op name to the (HTTP method, path) pair that
# execute_broker_request assigns to the envelope for that op.
OP_ROUTE_MAP = {
"models.list": ("GET", "/v1/models"),
"chat.completions": ("POST", "/v1/chat/completions"),
"capabilities": ("GET", "/coderai/capabilities"),
}
TEXT_CONTENT_TYPES = ( TEXT_CONTENT_TYPES = (
"application/json", "application/json",
"application/ld+json", "application/ld+json",
...@@ -50,9 +63,46 @@ def _is_text_response(content_type: str | None) -> bool: ...@@ -50,9 +63,46 @@ def _is_text_response(content_type: str | None) -> bool:
async def execute_broker_request(app, envelope): async def execute_broker_request(app, envelope):
"""Validate and execute a broker request envelope.""" """Validate and execute a broker request envelope."""
logger.debug(
"broker dispatch → op=%s request_id=%s path=%r method=%r stream=%s",
envelope.op, envelope.request_id, envelope.path, envelope.method, envelope.stream,
)
if envelope.op == "proxy":
proxy_payload = envelope.payload or {}
endpoint_path = str(proxy_payload.get("endpoint_path") or envelope.path or "").strip()
if endpoint_path and not endpoint_path.startswith("/"):
endpoint_path = f"/{endpoint_path}"
envelope.path = endpoint_path
envelope.method = str(proxy_payload.get("method") or envelope.method or "GET").upper()
envelope.headers = dict(proxy_payload.get("headers") or envelope.headers)
envelope.query = dict(proxy_payload.get("query_params") or envelope.query)
envelope.stream = bool(proxy_payload.get("stream", envelope.stream))
if "body" in proxy_payload:
envelope.payload = proxy_payload.get("body")
elif proxy_payload.get("body_base64") is not None:
envelope.payload = b64decode(proxy_payload.get("body_base64") or "")
elif proxy_payload.get("multipart") is not None:
envelope.payload = {"_broker_multipart": proxy_payload.get("multipart")}
logger.debug("broker dispatch proxy resolved → %s %s", envelope.method, envelope.path)
elif envelope.op in OP_ROUTE_MAP:
envelope.method, envelope.path = OP_ROUTE_MAP[envelope.op]
logger.debug("broker dispatch op mapped → %s %s", envelope.method, envelope.path)
elif not envelope.path:
logger.warning("broker dispatch unsupported op=%s request_id=%s", envelope.op, envelope.request_id)
return error_envelope(
envelope.request_id,
code="unsupported_operation",
message=f"Unsupported broker op: {envelope.op}",
)
envelope.validate() envelope.validate()
if not is_supported_path(envelope.path): if not is_supported_path(envelope.path):
logger.warning(
"broker dispatch unsupported path=%r op=%s request_id=%s",
envelope.path, envelope.op, envelope.request_id,
)
return error_envelope( return error_envelope(
envelope.request_id, envelope.request_id,
code="unsupported_endpoint", code="unsupported_endpoint",
...@@ -60,18 +110,28 @@ async def execute_broker_request(app, envelope): ...@@ -60,18 +110,28 @@ async def execute_broker_request(app, envelope):
) )
body: bytes body: bytes
if isinstance(envelope.payload, (dict, list)): if isinstance(envelope.payload, dict) and "_broker_multipart" in envelope.payload:
from codai.broker.asgi_bridge import _build_multipart_body
body, multipart_content_type = _build_multipart_body(envelope.payload["_broker_multipart"] or {})
headers = dict(envelope.headers)
headers["content-type"] = multipart_content_type
elif isinstance(envelope.payload, (dict, list)):
body = json.dumps(envelope.payload, separators=(",", ":")).encode("utf-8") body = json.dumps(envelope.payload, separators=(",", ":")).encode("utf-8")
headers = dict(envelope.headers)
elif isinstance(envelope.payload, str): elif isinstance(envelope.payload, str):
body = envelope.payload.encode("utf-8") body = envelope.payload.encode("utf-8")
headers = dict(envelope.headers)
elif isinstance(envelope.payload, bytes): elif isinstance(envelope.payload, bytes):
body = envelope.payload body = envelope.payload
headers = dict(envelope.headers)
elif envelope.payload is None: elif envelope.payload is None:
body = b"" body = b""
headers = dict(envelope.headers)
else: else:
body = json.dumps(envelope.payload, separators=(",", ":")).encode("utf-8") body = json.dumps(envelope.payload, separators=(",", ":")).encode("utf-8")
headers = dict(envelope.headers)
headers = dict(envelope.headers)
if body and "content-type" not in {key.lower() for key in headers}: if body and "content-type" not in {key.lower() for key in headers}:
headers["content-type"] = envelope.content_type headers["content-type"] = envelope.content_type
...@@ -103,17 +163,36 @@ async def execute_broker_request(app, envelope): ...@@ -103,17 +163,36 @@ async def execute_broker_request(app, envelope):
payload["content_type"] = content_type payload["content_type"] = content_type
if _is_text_response(content_type): if _is_text_response(content_type):
payload["body"] = response["body"].decode("utf-8") body_text = response["body"].decode("utf-8")
payload["body"] = body_text
logger.debug(
"broker dispatch ← op=%s request_id=%s status=%d elapsed_ms=%s body_bytes=%d body_preview=%r",
envelope.op, envelope.request_id,
response["status_code"], elapsed_ms,
len(response["body"]),
body_text[:300],
)
else: else:
payload["body_base64"] = b64encode(response["body"]).decode("ascii") payload["body_base64"] = b64encode(response["body"]).decode("ascii")
filename = response_headers.get("x-filename") filename = response_headers.get("x-filename")
if filename: if filename:
payload["filename"] = filename payload["filename"] = filename
logger.debug(
"broker dispatch ← op=%s request_id=%s status=%d elapsed_ms=%s body_bytes=%d (binary)",
envelope.op, envelope.request_id,
response["status_code"], elapsed_ms,
len(response["body"]),
)
if envelope.stream: if envelope.stream:
payload["stream"] = True payload["stream"] = True
return success_envelope( result = success_envelope(
envelope.request_id, envelope.request_id,
payload=payload, payload=payload,
metrics={"elapsed_ms": elapsed_ms}, metrics={"elapsed_ms": elapsed_ms},
) )
logger.debug(
"broker dispatch envelope_bytes=%d op=%s request_id=%s",
len(json.dumps(result)), envelope.op, envelope.request_id,
)
return result
...@@ -9,8 +9,9 @@ class BrokerRequestEnvelope: ...@@ -9,8 +9,9 @@ class BrokerRequestEnvelope:
"""Normalized broker request payload.""" """Normalized broker request payload."""
request_id: str request_id: str
method: str op: str
path: str method: str = "GET"
path: str = ""
headers: Dict[str, str] = field(default_factory=dict) headers: Dict[str, str] = field(default_factory=dict)
query: Dict[str, Any] = field(default_factory=dict) query: Dict[str, Any] = field(default_factory=dict)
payload: Any = None payload: Any = None
...@@ -22,9 +23,13 @@ class BrokerRequestEnvelope: ...@@ -22,9 +23,13 @@ class BrokerRequestEnvelope:
if not self.request_id or not isinstance(self.request_id, str): if not self.request_id or not isinstance(self.request_id, str):
raise ValueError("request_id is required") raise ValueError("request_id is required")
if not self.op or not isinstance(self.op, str):
raise ValueError("op is required")
if self.op != "proxy" and not self.path:
raise ValueError("path is required")
if not self.method or not isinstance(self.method, str): if not self.method or not isinstance(self.method, str):
raise ValueError("method is required") raise ValueError("method is required")
if not self.path or not isinstance(self.path, str): if self.path and not isinstance(self.path, str):
raise ValueError("path is required") raise ValueError("path is required")
...@@ -32,8 +37,9 @@ def success_envelope(request_id: str, payload: Any, event: str | None = None, me ...@@ -32,8 +37,9 @@ def success_envelope(request_id: str, payload: Any, event: str | None = None, me
"""Build a success response envelope.""" """Build a success response envelope."""
envelope = { envelope = {
"v": 1,
"request_id": request_id, "request_id": request_id,
"ok": True, "status": "ok",
"payload": payload, "payload": payload,
} }
if event is not None: if event is not None:
...@@ -53,7 +59,8 @@ def error_envelope(request_id: str, code: str, message: str, details: Dict[str, ...@@ -53,7 +59,8 @@ def error_envelope(request_id: str, code: str, message: str, details: Dict[str,
if details is not None: if details is not None:
error["details"] = details error["details"] = details
return { return {
"v": 1,
"request_id": request_id, "request_id": request_id,
"ok": False, "status": "error",
"error": error, "error": error,
} }
...@@ -18,14 +18,17 @@ class BrokerService: ...@@ -18,14 +18,17 @@ class BrokerService:
self.client.dispatcher = dispatch self.client.dispatcher = dispatch
self.task: asyncio.Task | None = None self.task: asyncio.Task | None = None
self._started = False
def start(self): def start(self):
if not self.client.runtime.enabled or self.task is not None: if not self.client.runtime.enabled or self.task is not None or self._started:
return return
self._started = True
self.task = asyncio.create_task(self.client.run_forever()) self.task = asyncio.create_task(self.client.run_forever())
async def stop(self): async def stop(self):
if self.task is None: if self.task is None:
self._started = False
return return
task = self.task task = self.task
self.task = None self.task = None
...@@ -34,3 +37,5 @@ class BrokerService: ...@@ -34,3 +37,5 @@ class BrokerService:
await task await task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
finally:
self._started = False
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import sys import sys
import os import os
import logging import logging
import threading as _t
# Import configuration from codai modules # Import configuration from codai modules
from codai.cli import parse_args from codai.cli import parse_args
...@@ -30,6 +31,83 @@ from codai.broker import BrokerConfigError, build_broker_runtime_config ...@@ -30,6 +31,83 @@ from codai.broker import BrokerConfigError, build_broker_runtime_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _migrate_hf_gguf_to_gguf_cache() -> None:
    """Copy GGUF files stored in the HF cache into the flat GGUF cache directory.

    Intended to run once at startup in a background thread.  For repos whose
    only non-trivial content is GGUF files, the HF cache entry is removed
    after every GGUF file is present in the GGUF cache.

    The function is best-effort: a missing HF cache, an unavailable
    ``huggingface_hub`` or a failing scan makes it return silently, and
    per-file copy failures are logged and keep the source repo in place.
    """
    import shutil
    from codai.models.cache import get_hf_hub_cache_dir, get_model_cache_dir

    hf_dir = get_hf_hub_cache_dir()
    if not os.path.exists(hf_dir):
        return

    gguf_cache = get_model_cache_dir()

    try:
        from huggingface_hub import scan_cache_dir
        info = scan_cache_dir(hf_dir)
    except Exception:
        return

    _TRIVIAL_EXTS = {'.json', '.txt', '.md', '.py', '.gitattributes', '.model', '.tiktoken', '.vocab'}

    def _is_trivial(file_name):
        # splitext() reports no extension for dotfiles such as
        # ".gitattributes", so fall back to the whole lowercase name.
        base = os.path.basename(file_name).lower()
        return (os.path.splitext(base)[1] or base) in _TRIVIAL_EXTS

    migrated_count = 0
    repos_to_purge = []  # HF cache repo dirs safe to delete after migration

    for repo in info.repos:
        if not repo.revisions:
            continue
        # Commit hashes are not chronological; use the modification time.
        latest_rev = max(repo.revisions, key=lambda r: r.last_modified)
        gguf_files = [f for f in latest_rev.files if f.file_name.endswith('.gguf')]
        if not gguf_files:
            continue

        # Determine whether this repo contains ONLY gguf + trivial metadata
        non_trivial = [f for f in latest_rev.files if not _is_trivial(f.file_name)]
        gguf_only_repo = bool(non_trivial) and all(
            f.file_name.endswith('.gguf') for f in non_trivial
        )

        all_migrated = True
        for f in gguf_files:
            dest = os.path.join(gguf_cache, os.path.basename(f.file_name))
            if os.path.exists(dest):
                continue  # already present in GGUF cache
            # Copy under a temporary name and rename afterwards so that an
            # interrupted copy never leaves a truncated file at ``dest``.
            partial = f"{dest}.partial"
            try:
                src = os.path.realpath(str(f.file_path))  # resolve symlink → blob
                shutil.copy2(src, partial)
                os.replace(partial, dest)
                migrated_count += 1
                logger.info("Migrated GGUF: %s → %s", f.file_name, dest)
            except Exception as exc:
                logger.warning("Could not migrate %s: %s", f.file_name, exc)
                all_migrated = False
                try:
                    os.remove(partial)
                except OSError:
                    pass

        if gguf_only_repo and all_migrated:
            repo_dir = os.path.join(hf_dir, f"models--{repo.repo_id.replace('/', '--')}")
            if os.path.isdir(repo_dir):
                repos_to_purge.append((repo.repo_id, repo_dir))

    for repo_id, repo_dir in repos_to_purge:
        try:
            shutil.rmtree(repo_dir)
            logger.info("Removed migrated HF cache entry: %s", repo_id)
        except Exception as exc:
            logger.warning("Could not remove HF cache entry %s: %s", repo_id, exc)

    if migrated_count or repos_to_purge:
        logger.info(
            "GGUF cache migration: %d file(s) moved to %s, %d HF cache entr%s cleaned up.",
            migrated_count, gguf_cache,
            len(repos_to_purge), "ies" if len(repos_to_purge) != 1 else "y",
        )
def main(): def main():
"""Main entry point for the codai server.""" """Main entry point for the codai server."""
# Suppress unraisable exceptions from LlamaModel.__del__ # Suppress unraisable exceptions from LlamaModel.__del__
...@@ -54,10 +132,22 @@ def main(): ...@@ -54,10 +132,22 @@ def main():
config_mgr = ConfigManager(config_dir) config_mgr = ConfigManager(config_dir)
config = config_mgr.load() config = config_mgr.load()
# Apply cache directory overrides from config before any cache module is used # Apply cache directory overrides from config before any cache module is used.
# We set env vars AND patch huggingface_hub.constants in case the library was
# already imported (constants are computed once at import time from env vars).
if config.models.hf_cache_dir: if config.models.hf_cache_dir:
hf_hub_cache = os.path.join(config.models.hf_cache_dir, 'hub')
os.environ['HF_HOME'] = config.models.hf_cache_dir os.environ['HF_HOME'] = config.models.hf_cache_dir
os.environ['HUGGINGFACE_HUB_CACHE'] = config.models.hf_cache_dir os.environ['HUGGINGFACE_HUB_CACHE'] = hf_hub_cache
try:
import sys as _sys
if 'huggingface_hub.constants' in _sys.modules:
import huggingface_hub.constants as _hfc
_hfc.HF_HUB_CACHE = hf_hub_cache
if hasattr(_hfc, 'HF_HOME'):
_hfc.HF_HOME = config.models.hf_cache_dir
except Exception:
pass
if config.models.gguf_cache_dir: if config.models.gguf_cache_dir:
os.environ['CODERAI_CACHE_DIR'] = config.models.gguf_cache_dir os.environ['CODERAI_CACHE_DIR'] = config.models.gguf_cache_dir
...@@ -75,7 +165,7 @@ def main(): ...@@ -75,7 +165,7 @@ def main():
# Initialize admin session manager and expose config to admin routes # Initialize admin session manager and expose config to admin routes
from pathlib import Path from pathlib import Path
init_session_manager(Path(config_dir)) init_session_manager(Path(config_dir), port=config.server.port)
set_config_manager(config_mgr) set_config_manager(config_mgr)
# Handle early exit options (before heavy imports) # Handle early exit options (before heavy imports)
...@@ -145,7 +235,8 @@ def main(): ...@@ -145,7 +235,8 @@ def main():
print("No GGUF files found, downloading full HuggingFace repo...") print("No GGUF files found, downloading full HuggingFace repo...")
try: try:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
cached_path = snapshot_download(model_id) from codai.models.cache import get_hf_hub_cache_dir
cached_path = snapshot_download(model_id, cache_dir=get_hf_hub_cache_dir())
except Exception as e: except Exception as e:
print(f"Error downloading full repo: {e}") print(f"Error downloading full repo: {e}")
cached_path = None cached_path = None
...@@ -175,6 +266,9 @@ def main(): ...@@ -175,6 +266,9 @@ def main():
print(f"Error listing devices: {e}") print(f"Error listing devices: {e}")
sys.exit(0) sys.exit(0)
# Migrate any GGUF files that ended up in the HF cache to the GGUF cache
_t.Thread(target=_migrate_hf_gguf_to_gguf_cache, daemon=True).start()
# Import core modules (only after early exits) # Import core modules (only after early exits)
from codai.api import app from codai.api import app
from codai.api.state import ( from codai.api.state import (
...@@ -724,18 +818,36 @@ def main(): ...@@ -724,18 +818,36 @@ def main():
queue_manager.max_size = config.server.queue_max_size queue_manager.max_size = config.server.queue_max_size
queue_manager.max_parallel_requests = config.server.max_parallel_requests queue_manager.max_parallel_requests = config.server.max_parallel_requests
# Configure Python logging so broker/API log calls reach the terminal.
# uvicorn is started with log_config=None to keep our config in place.
_log_level = logging.DEBUG if global_debug else logging.INFO
logging.basicConfig(
level=_log_level,
format="%(asctime)s [%(levelname)-8s] %(name)s: %(message)s",
stream=sys.stdout,
force=True,
)
# Cap noisy third-party loggers at WARNING; websockets and asyncio are capped only when not in debug mode.
for _noisy in ("httpx", "httpcore", "urllib3", "multipart", "PIL"):
logging.getLogger(_noisy).setLevel(logging.WARNING)
if not global_debug:
logging.getLogger("websockets").setLevel(logging.WARNING)
logging.getLogger("asyncio").setLevel(logging.WARNING)
# Start the server # Start the server
import uvicorn import uvicorn
print(f"\nStarting server on http://{config.server.host}:{config.server.port}") print(f"\nStarting server on http://{config.server.host}:{config.server.port}")
print(f"API docs: http://{config.server.host}:{config.server.port}/docs") print(f"API docs: http://{config.server.host}:{config.server.port}/docs")
print(f"Admin UI: http://{config.server.host}:{config.server.port}/admin") print(f"Admin UI: http://{config.server.host}:{config.server.port}/admin")
if model_manager.backend is not None: if model_manager.backend is not None:
actual_backend = model_manager.backend_type actual_backend = model_manager.backend_type
if hasattr(model_manager.backend, 'force_cuda') and model_manager.backend.force_cuda: if hasattr(model_manager.backend, 'force_cuda') and model_manager.backend.force_cuda:
actual_backend = "cuda (via llama-cpp-python)" actual_backend = "cuda (via llama-cpp-python)"
print(f"Using backend: {actual_backend}") print(f"Using backend: {actual_backend}")
_uvi_log_level = "debug" if global_debug else "info"
if config.server.https: if config.server.https:
import ssl import ssl
ssl_keyfile = config.server.https_key_path ssl_keyfile = config.server.https_key_path
...@@ -758,14 +870,17 @@ def main(): ...@@ -758,14 +870,17 @@ def main():
except Exception as e: except Exception as e:
print(f"Warning: Could not generate certificate: {e}") print(f"Warning: Could not generate certificate: {e}")
print("Falling back to HTTP...") print("Falling back to HTTP...")
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port) uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
log_level=_uvi_log_level, log_config=None)
return return
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(ssl_certfile, ssl_keyfile) ssl_context.load_cert_chain(ssl_certfile, ssl_keyfile)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port, ssl_context=ssl_context) uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
ssl_context=ssl_context, log_level=_uvi_log_level, log_config=None)
else: else:
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port) uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
log_level=_uvi_log_level, log_config=None)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -56,6 +56,26 @@ def get_model_cache_dir() -> str: ...@@ -56,6 +56,26 @@ def get_model_cache_dir() -> str:
return cache_dir return cache_dir
def get_hf_hub_cache_dir() -> str:
    """Return the HuggingFace Hub cache directory CoderAI is configured to use.

    The value is meant to be passed as ``cache_dir`` to snapshot_download /
    hf_hub_download so downloads land in the same place whether or not the
    directory exists yet (first download).

    Resolution order (first match wins):
        1. ``HUGGINGFACE_HUB_CACHE``  (explicit cache path; kept first so the
           result is unchanged whenever it is set)
        2. ``HF_HUB_CACHE``           (current name of the explicit cache path)
        3. ``$HF_HOME/hub``
        4. ``$XDG_CACHE_HOME/huggingface/hub``
        5. ``~/.cache/huggingface/hub``

    ``~`` and environment-variable references in the result are expanded.
    """
    explicit = os.environ.get('HUGGINGFACE_HUB_CACHE') or os.environ.get('HF_HUB_CACHE')
    if explicit:
        hf_hub_cache = explicit
    else:
        if 'HF_HOME' in os.environ:
            hf_home = os.environ['HF_HOME']
        else:
            cache_root = os.environ.get('XDG_CACHE_HOME') or os.path.join(
                os.path.expanduser('~'), '.cache'
            )
            hf_home = os.path.join(cache_root, 'huggingface')
        hf_hub_cache = os.path.join(hf_home, 'hub')
    # An unexpanded "~" would otherwise become a literal directory name.
    return os.path.expandvars(os.path.expanduser(hf_hub_cache))
def get_all_cache_dirs() -> dict: def get_all_cache_dirs() -> dict:
"""Get all model cache directories.""" """Get all model cache directories."""
caches = {} caches = {}
...@@ -540,6 +560,7 @@ def remove_all_cached_models() -> int: ...@@ -540,6 +560,7 @@ def remove_all_cached_models() -> int:
# Export all public functions # Export all public functions
__all__ = [ __all__ = [
'get_model_cache_dir', 'get_model_cache_dir',
'get_hf_hub_cache_dir',
'get_all_cache_dirs', 'get_all_cache_dirs',
'get_cached_model_path', 'get_cached_model_path',
'is_huggingface_model_id', 'is_huggingface_model_id',
......
...@@ -110,12 +110,10 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities: ...@@ -110,12 +110,10 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'animatediff', 'text2video', 'modelscope-t2v', 'animatediff', 'text2video', 'modelscope-t2v',
'zeroscope', 'lavie']): 'zeroscope', 'lavie']):
caps.video_generation = True caps.video_generation = True
caps.text_generation = True # T2V models also do text
return caps return caps
if any(x in n for x in ['wan2.1-t2v', 'wan-t2v']): if any(x in n for x in ['wan2.1-t2v', 'wan-t2v']):
caps.video_generation = True caps.video_generation = True
caps.text_generation = True
return caps return caps
# Image-to-video # Image-to-video
...@@ -124,17 +122,14 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities: ...@@ -124,17 +122,14 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'wan2.1-i2v', 'wan-i2v', 'img2vid', 'wan2.1-i2v', 'wan-i2v', 'img2vid',
'image2video', 'motionctrl']): 'image2video', 'motionctrl']):
caps.image_to_video = True caps.image_to_video = True
caps.image_to_text = True # I2V models process images
return caps return caps
# Wan generic (detect sub-variant) # Wan generic (detect sub-variant)
if 'wan' in n and ('video' in n or 'diffuser' in n): if 'wan' in n and ('video' in n or 'diffuser' in n):
if 'i2v' in n: if 'i2v' in n:
caps.image_to_video = True caps.image_to_video = True
caps.image_to_text = True
else: else:
caps.video_generation = True caps.video_generation = True
caps.text_generation = True
return caps return caps
# Video interpolation # Video interpolation
...@@ -158,7 +153,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities: ...@@ -158,7 +153,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
if any(x in n for x in ['musicgen', 'audiogen', 'audioldm', 'stable-audio', if any(x in n for x in ['musicgen', 'audiogen', 'audioldm', 'stable-audio',
'mustango', 'noise2music', 'jukebox', 'audiocraft']): 'mustango', 'noise2music', 'jukebox', 'audiocraft']):
caps.audio_generation = True caps.audio_generation = True
caps.text_generation = True # T2A models process text
return caps return caps
if any(x in n for x in ['demucs', 'spleeter', 'asteroid', 'open-unmix']): if any(x in n for x in ['demucs', 'spleeter', 'asteroid', 'open-unmix']):
...@@ -174,7 +168,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities: ...@@ -174,7 +168,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
if any(x in n for x in ['kokoro', 'xtts', 'bark', 'tortoise', if any(x in n for x in ['kokoro', 'xtts', 'bark', 'tortoise',
'speecht5', 'matcha-tts', 'voicebox']): 'speecht5', 'matcha-tts', 'voicebox']):
caps.text_to_speech = True caps.text_to_speech = True
caps.text_generation = True # TTS models process text
return caps return caps
# Lip sync / dubbing # Lip sync / dubbing
...@@ -199,13 +192,11 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities: ...@@ -199,13 +192,11 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.inpainting = True caps.inpainting = True
caps.image_generation = True caps.image_generation = True
caps.image_to_image = True caps.image_to_image = True
caps.text_generation = True
return caps return caps
if 'controlnet' in n: if 'controlnet' in n:
caps.controlnet = True caps.controlnet = True
caps.image_generation = True caps.image_generation = True
caps.text_generation = True
return caps return caps
if any(x in n for x in [ if any(x in n for x in [
...@@ -235,7 +226,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities: ...@@ -235,7 +226,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.image_generation = True caps.image_generation = True
caps.image_to_image = True caps.image_to_image = True
caps.inpainting = True # most SD/SDXL/Flux checkpoints support inpainting via mask caps.inpainting = True # most SD/SDXL/Flux checkpoints support inpainting via mask
caps.text_generation = True
return caps return caps
# ── Image: analysis / processing ───────────────────────────────────────── # ── Image: analysis / processing ─────────────────────────────────────────
...@@ -295,12 +285,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities: ...@@ -295,12 +285,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'text-embedding', 'voyage-', 'text-embedding', 'voyage-',
]): ]):
caps.embeddings = True caps.embeddings = True
caps.text_generation = True
return caps
# ── GGUF quantised text models ───────────────────────────────────────────
if '.gguf' in n or 'gguf' in n:
caps.text_generation = True
return caps return caps
# Default: text generation # Default: text generation
...@@ -315,17 +299,17 @@ _PIPELINE_TAG_CAPS: dict = { ...@@ -315,17 +299,17 @@ _PIPELINE_TAG_CAPS: dict = {
'image-to-text': ['image_to_text', 'text_generation'], 'image-to-text': ['image_to_text', 'text_generation'],
'visual-question-answering': ['image_to_text', 'text_generation'], 'visual-question-answering': ['image_to_text', 'text_generation'],
'image-text-to-text': ['image_to_text', 'text_generation'], 'image-text-to-text': ['image_to_text', 'text_generation'],
'text-to-image': ['image_generation', 'image_to_image', 'text_generation'], 'text-to-image': ['image_generation', 'image_to_image'],
'unconditional-image-generation': ['image_generation'], 'unconditional-image-generation': ['image_generation'],
'image-to-image': ['image_to_image'], # sub-typed below 'image-to-image': ['image_to_image'], # sub-typed below
'automatic-speech-recognition': ['speech_to_text'], 'automatic-speech-recognition': ['speech_to_text'],
'audio-to-audio': ['audio_to_audio'], 'audio-to-audio': ['audio_to_audio'],
'text-to-speech': ['text_to_speech'], 'text-to-speech': ['text_to_speech'],
'text-to-audio': ['audio_generation'], 'text-to-audio': ['audio_generation'],
'text-to-video': ['video_generation', 'text_generation'], 'text-to-video': ['video_generation'],
'image-to-video': ['image_to_video'], 'image-to-video': ['image_to_video'],
'feature-extraction': ['embeddings', 'text_generation'], 'feature-extraction': ['embeddings'],
'sentence-similarity': ['embeddings', 'text_generation'], 'sentence-similarity': ['embeddings'],
'depth-estimation': ['depth_estimation', 'image_to_text'], 'depth-estimation': ['depth_estimation', 'image_to_text'],
'image-segmentation': ['image_segmentation', 'image_to_text'], 'image-segmentation': ['image_segmentation', 'image_to_text'],
'object-detection': ['object_detection', 'image_to_text'], 'object-detection': ['object_detection', 'image_to_text'],
......
...@@ -1541,6 +1541,10 @@ class MultiModelManager: ...@@ -1541,6 +1541,10 @@ class MultiModelManager:
2. Local HuggingFace hub cache scan. 2. Local HuggingFace hub cache scan.
3. HuggingFace API (network, one call per model per process lifetime). 3. HuggingFace API (network, one call per model per process lifetime).
For LoRA adapters (repos containing adapter_config.json) the size of
the base model is added so that VRAM requirements are not
underestimated.
Returns 0 on any failure. Returns 0 on any failure.
""" """
if model_id in MultiModelManager._hf_size_cache: if model_id in MultiModelManager._hf_size_cache:
...@@ -1548,10 +1552,19 @@ class MultiModelManager: ...@@ -1548,10 +1552,19 @@ class MultiModelManager:
weight_exts = {'.safetensors', '.bin', '.gguf', '.ggml', '.pt'} weight_exts = {'.safetensors', '.bin', '.gguf', '.ggml', '.pt'}
def _resolve_base_model(base_model_id: str) -> int:
from codai.models.cache import is_huggingface_model_id
if not base_model_id or base_model_id == model_id:
return 0
if not is_huggingface_model_id(base_model_id):
return 0
return MultiModelManager._hf_cached_model_size_bytes(base_model_id)
# --- Try local HF hub cache first (no network) --- # --- Try local HF hub cache first (no network) ---
try: try:
from huggingface_hub import scan_cache_dir from huggingface_hub import scan_cache_dir
from codai.models.cache import get_all_cache_dirs from codai.models.cache import get_all_cache_dirs
import json as _json
hf_dir = get_all_cache_dirs().get("huggingface") hf_dir = get_all_cache_dirs().get("huggingface")
if hf_dir: if hf_dir:
info = scan_cache_dir(hf_dir) info = scan_cache_dir(hf_dir)
...@@ -1560,11 +1573,26 @@ class MultiModelManager: ...@@ -1560,11 +1573,26 @@ class MultiModelManager:
continue continue
revs = sorted(repo.revisions, key=lambda r: r.last_modified, reverse=True) revs = sorted(repo.revisions, key=lambda r: r.last_modified, reverse=True)
if revs: if revs:
rev = revs[0]
total = sum( total = sum(
f.size_on_disk f.size_on_disk
for f in revs[0].files for f in rev.files
if os.path.splitext(f.file_name)[1].lower() in weight_exts if os.path.splitext(f.file_name)[1].lower() in weight_exts
) )
# LoRA adapter: add base model size
for f in rev.files:
if f.file_name == "adapter_config.json":
try:
with open(f.file_path) as fp:
adapter_cfg = _json.load(fp)
base_id = (
adapter_cfg.get("base_model_name_or_path")
or adapter_cfg.get("base_model")
)
total += _resolve_base_model(base_id)
except Exception:
pass
break
if total > 0: if total > 0:
MultiModelManager._hf_size_cache[model_id] = total MultiModelManager._hf_size_cache[model_id] = total
return total return total
...@@ -1580,13 +1608,31 @@ class MultiModelManager: ...@@ -1580,13 +1608,31 @@ class MultiModelManager:
with urllib.request.urlopen(req, timeout=10) as resp: with urllib.request.urlopen(req, timeout=10) as resp:
data = _json.loads(resp.read()) data = _json.loads(resp.read())
total = 0 total = 0
has_adapter_config = False
for sib in data.get("siblings", []): for sib in data.get("siblings", []):
name = sib.get("rfilename", "") name = sib.get("rfilename", "")
if name == "adapter_config.json":
has_adapter_config = True
continue
if os.path.splitext(name)[1].lower() not in weight_exts: if os.path.splitext(name)[1].lower() not in weight_exts:
continue continue
lfs = sib.get("lfs") or {} lfs = sib.get("lfs") or {}
size = lfs.get("size") or sib.get("size") or 0 size = lfs.get("size") or sib.get("size") or 0
total += size total += size
# LoRA adapter: fetch adapter_config.json to get the base model
if has_adapter_config:
try:
cfg_url = f"https://huggingface.co/{model_id}/resolve/main/adapter_config.json"
cfg_req = urllib.request.Request(cfg_url, headers={"User-Agent": "coderai/1.0"})
with urllib.request.urlopen(cfg_req, timeout=10) as resp:
adapter_cfg = _json.loads(resp.read())
base_id = (
adapter_cfg.get("base_model_name_or_path")
or adapter_cfg.get("base_model")
)
total += _resolve_base_model(base_id)
except Exception:
pass
if total > 0: if total > 0:
MultiModelManager._hf_size_cache[model_id] = total MultiModelManager._hf_size_cache[model_id] = total
return total return total
...@@ -2101,9 +2147,12 @@ class MultiModelManager: ...@@ -2101,9 +2147,12 @@ class MultiModelManager:
needed_gb = self._get_model_used_vram_gb(model_key, resolved_name) needed_gb = self._get_model_used_vram_gb(model_key, resolved_name)
free_gb = self._get_free_vram_gb() free_gb = self._get_free_vram_gb()
if needed_gb > 0 and free_gb >= needed_gb: # Require headroom beyond raw weight size for activation buffers
# and generation scratch (30% of model size, but never less than 1 GB).
headroom_gb = max(1.0, needed_gb * 0.30)
if needed_gb > 0 and free_gb >= needed_gb + headroom_gb:
print(f"Ondemand mode - keeping '{loaded_canonical}' in VRAM alongside new model " print(f"Ondemand mode - keeping '{loaded_canonical}' in VRAM alongside new model "
f"(need {needed_gb:.1f} GB, have {free_gb:.1f} GB free)") f"(need {needed_gb:.1f} GB + {headroom_gb:.1f} GB headroom, have {free_gb:.1f} GB free)")
else: else:
print(f"Ondemand mode - model switch detected:") print(f"Ondemand mode - model switch detected:")
print(f" Requested: '{model_key}' (resolved: '{resolved_name}')") print(f" Requested: '{model_key}' (resolved: '{resolved_name}')")
......
...@@ -89,6 +89,22 @@ The outbound WebSocket connection must include: ...@@ -89,6 +89,22 @@ The outbound WebSocket connection must include:
- `username`: either `global` or the AISBF username for user-owned providers - `username`: either `global` or the AISBF username for user-owned providers
- `registration_token`: provider-scoped secret from AISBF provider configuration - `registration_token`: provider-scoped secret from AISBF provider configuration
### Current server-side resolution order
AISBF resolves broker identity in this exact order when the WebSocket handshake arrives:
- `provider_id`: query param `provider_id`, then header `x-coderai-provider-id`, then default `coderai`
- `client_id`: query param `client_id`, then header `x-coderai-client-id`, then generated fallback `anon-<unix_timestamp>`
- `username`: query param `username`, then header `x-coderai-username`, then the path scope name (`global` or the `/api/u/{username}` path segment)
- `registration_token`: query param `registration_token`, then header `x-coderai-registration-token`
Important constraints:
- the `registration_token` is required for admission
- `Authorization: Bearer ...` is currently not used by the broker WebSocket admission check
- if you omit `client_id`, AISBF generates an `anon-*` client id and future broker routing will only work if AISBF also targets that exact generated value
- the `client_id` used by the CoderAI client must match the `coderai_config.client_id` used by the AISBF provider, or the broker can show the session as connected while requests still fail to route
## Optional Headers ## Optional Headers
AISBF also accepts or may expect these headers: AISBF also accepts or may expect these headers:
...@@ -109,6 +125,35 @@ Recommended behavior: ...@@ -109,6 +125,35 @@ Recommended behavior:
Open the outbound WebSocket to the correct scoped AISBF endpoint. Open the outbound WebSocket to the correct scoped AISBF endpoint.
The handshake is a normal WebSocket upgrade request, which starts as an HTTP `GET` carrying query parameters. This is expected.
Recommended connect template:
```text
wss://<aisbf-host>/<optional-prefix>/api/coderai/wss?provider_id=<provider_id>&client_id=<stable_client_id>&username=global&registration_token=<provider_registration_token>
```
User-scoped template:
```text
wss://<aisbf-host>/<optional-prefix>/api/u/<username>/coderai/wss?provider_id=<provider_id>&client_id=<stable_client_id>&username=<username>&registration_token=<provider_registration_token>
```
Recommended handshake headers:
```text
x-coderai-provider-id: <provider_id>
x-coderai-client-id: <stable_client_id>
x-coderai-username: <username>
x-coderai-registration-token: <provider_registration_token>
```
Best practice:
- send the same identity in both query parameters and headers
- keep `client_id` stable across reconnects
- always reconnect with the same provider scope and owner scope
### 2. Wait for `registered` event ### 2. Wait for `registered` event
AISBF immediately sends a registration acknowledgment event on successful admission. AISBF immediately sends a registration acknowledgment event on successful admission.
...@@ -135,11 +180,21 @@ Store: ...@@ -135,11 +180,21 @@ Store:
- `client_id` - `client_id`
- `username` - `username`
- `scope_name` - `scope_name`
- `owner_user_id`
- `expires_at`
Notes:
- this event means the socket is admitted and the session row exists
- it does not yet mean hardware/capabilities metadata has been uploaded
- the client should send the explicit `register` operation immediately after this event
### 3. Send explicit `register` operation ### 3. Send explicit `register` operation
After the `registered` event, CoderAI must send a `register` message describing its capabilities, hardware inventory, and advertised endpoints. After the `registered` event, CoderAI must send a `register` message describing its capabilities, hardware inventory, and advertised endpoints.
AISBF currently processes `register` as a normal inbound WebSocket message and responds with `status=ok` using the same `request_id`.
### 4. Enter long-lived receive loop ### 4. Enter long-lived receive loop
Then keep listening for incoming broker requests from AISBF. Then keep listening for incoming broker requests from AISBF.
...@@ -233,6 +288,60 @@ CoderAI should send this after receiving the initial AISBF `registered` event. ...@@ -233,6 +288,60 @@ CoderAI should send this after receiving the initial AISBF `registered` event.
AISBF replies with a success envelope. AISBF replies with a success envelope.
### Fields AISBF currently reads from the `register` message
Top-level:
- `v`
- `op` with value `register`
- `request_id`
- optional top-level `registration_token`
- optional top-level `capabilities`
From `payload`:
- `endpoint`
- `transport`
- `registration_token`
- `studio_endpoints`
- `hardware`
- `gpus`
- `gpu_count`
- `total_vram_mb`
- `available_vram_mb`
- `capabilities`
AISBF behavior:
- if `payload.registration_token` or top-level `registration_token` is present and does not match the handshake token, AISBF replies with an error envelope
- if token matches, AISBF persists the metadata onto the broker session
- `payload.capabilities` takes precedence over missing top-level capability data
- if `gpus`, `gpu_count`, `total_vram_mb`, or `available_vram_mb` are omitted at the top level of `payload`, AISBF falls back to the values inside `payload.hardware`
Minimal acceptable `register` message:
```json
{
"v": 1,
"op": "register",
"request_id": "reg-1",
"payload": {
"transport": "websocket",
"registration_token": "<same_registration_token>",
"capabilities": {}
}
}
```
Recommended full `register` message:
- include `endpoint`
- include `transport`
- include `registration_token`
- include `hardware.gpus`, `hardware.gpu_count`, `hardware.total_vram_mb`, `hardware.available_vram_mb`
- include `studio_endpoints`
- include `capabilities`
### Hardware Reporting Requirements ### Hardware Reporting Requirements
The `register` payload should include the best hardware view available to the running CoderAI process. The `register` payload should include the best hardware view available to the running CoderAI process.
...@@ -326,6 +435,37 @@ Heartbeat payloads may also refresh dynamic hardware state such as changing free ...@@ -326,6 +435,37 @@ Heartbeat payloads may also refresh dynamic hardware state such as changing free
} }
``` ```
Current AISBF note:
- AISBF acknowledges heartbeat messages and merges the heartbeat `payload` into session metadata
- keep heartbeat payloads small and non-blocking
- use heartbeats for lightweight dynamic updates only; do not block the main receive loop on expensive hardware rescans
## Async Client Requirements
The broker WebSocket integration must be fully asynchronous.
CoderAI client requirements:
- the main receive loop must never block on model loading, inference, GPU inspection, or disk/network I/O
- expensive work should run in background tasks or worker executors while the socket remains responsive to incoming frames and ping/pong traffic
- the client should be able to receive broker requests while also sending progress or result frames for earlier requests
- the client must not serialize all work behind registration or heartbeat handling
AISBF broker behavior:
- AISBF now drains queued outbound broker requests in a background async task while independently reading inbound WebSocket messages
- this means the CoderAI client should expect inbound requests to arrive even while it is still sending heartbeat or response messages for unrelated work
- operations are correlated strictly by `request_id`; client implementations must not rely on message ordering alone
Recommended client architecture:
1. one async reader task for inbound WebSocket frames
2. one async writer path or send queue for outbound replies/events
3. per-request async tasks for local execution
4. a lightweight periodic heartbeat task
5. explicit request correlation by `request_id`
AISBF merges those updates into the broker session metadata. AISBF merges those updates into the broker session metadata.
## Local HTTP Endpoints CoderAI Should Expose ## Local HTTP Endpoints CoderAI Should Expose
......
This diff is collapsed.
...@@ -84,8 +84,13 @@ def test_build_broker_runtime_config_global_scope_builds_url_and_headers(): ...@@ -84,8 +84,13 @@ def test_build_broker_runtime_config_global_scope_builds_url_and_headers():
"x-coderai-provider-id": "provider-1", "x-coderai-provider-id": "provider-1",
"x-coderai-client-id": "client-1", "x-coderai-client-id": "client-1",
"x-coderai-username": "global", "x-coderai-username": "global",
"x-coderai-registration-token": "token-123",
"x-coderai-advertised-endpoint": "https://server.example.com", "x-coderai-advertised-endpoint": "https://server.example.com",
} }
assert runtime.provider_id == "provider-1"
assert runtime.client_id == "client-1"
assert runtime.username == "global"
assert runtime.registration_token == "token-123"
assert runtime.transport == "websocket" assert runtime.transport == "websocket"
assert runtime.heartbeat_interval_seconds == 30 assert runtime.heartbeat_interval_seconds == 30
assert runtime.connect_timeout_seconds == 10 assert runtime.connect_timeout_seconds == 10
...@@ -94,6 +99,38 @@ def test_build_broker_runtime_config_global_scope_builds_url_and_headers(): ...@@ -94,6 +99,38 @@ def test_build_broker_runtime_config_global_scope_builds_url_and_headers():
assert runtime.reconnect_max_delay_seconds == 60 assert runtime.reconnect_max_delay_seconds == 60
def test_build_broker_runtime_config_global_scope_uses_global_service_paths():
    """Global scope resolves to the shared ``/api/coderai/wss`` path under the base URL prefix."""
    broker_config = BrokerConfig(
        enabled=True,
        base_url="https://broker.example.com/base",
        scope="global",
        username="global",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
    )
    expected_prefix = "wss://broker.example.com/base/api/coderai/wss"

    websocket_url = build_broker_runtime_config(broker_config).websocket_url

    assert websocket_url.startswith(expected_prefix)
def test_build_broker_runtime_config_user_scope_uses_user_service_paths():
    """User scope resolves to the per-user ``/api/u/<username>/coderai/wss`` path."""
    broker_config = BrokerConfig(
        enabled=True,
        base_url="https://broker.example.com/base",
        scope="user",
        username="alice",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
    )
    expected_prefix = "wss://broker.example.com/base/api/u/alice/coderai/wss"

    websocket_url = build_broker_runtime_config(broker_config).websocket_url

    assert websocket_url.startswith(expected_prefix)
def test_build_broker_runtime_config_rejects_invalid_global_username(): def test_build_broker_runtime_config_rejects_invalid_global_username():
try: try:
build_broker_runtime_config( build_broker_runtime_config(
...@@ -139,6 +176,7 @@ def test_build_broker_runtime_config_user_scope_uses_user_path(): ...@@ -139,6 +176,7 @@ def test_build_broker_runtime_config_user_scope_uses_user_path():
"x-coderai-provider-id": "provider-1", "x-coderai-provider-id": "provider-1",
"x-coderai-client-id": "client-1", "x-coderai-client-id": "client-1",
"x-coderai-username": "alice", "x-coderai-username": "alice",
"x-coderai-registration-token": "token-123",
"x-coderai-advertised-endpoint": "https://server.example.com/alice", "x-coderai-advertised-endpoint": "https://server.example.com/alice",
} }
...@@ -165,6 +203,28 @@ def test_build_broker_runtime_config_preserves_base_url_prefix_in_websocket_url( ...@@ -165,6 +203,28 @@ def test_build_broker_runtime_config_preserves_base_url_prefix_in_websocket_url(
) )
def test_build_broker_runtime_config_does_not_duplicate_api_prefix_when_base_url_already_ends_with_api():
    """A base URL already ending in ``/api`` must not produce an ``/api/api/...`` path."""
    broker_config = BrokerConfig(
        enabled=True,
        base_url="https://aisbf.cloud/api",
        scope="global",
        username="global",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
    )
    expected_url = (
        "wss://aisbf.cloud/api/coderai/wss"
        "?provider_id=provider-1"
        "&client_id=client-1"
        "&username=global"
        "&registration_token=token-123"
    )

    runtime = build_broker_runtime_config(broker_config)

    assert runtime.websocket_url == expected_url
def test_build_broker_runtime_config_encodes_reserved_username_path_characters(): def test_build_broker_runtime_config_encodes_reserved_username_path_characters():
runtime = build_broker_runtime_config( runtime = build_broker_runtime_config(
BrokerConfig( BrokerConfig(
...@@ -187,6 +247,48 @@ def test_build_broker_runtime_config_encodes_reserved_username_path_characters() ...@@ -187,6 +247,48 @@ def test_build_broker_runtime_config_encodes_reserved_username_path_characters()
) )
def test_build_broker_runtime_config_uses_manual_websocket_path_override():
    """An explicit ``websocket_path`` replaces the scope-derived path but keeps the base URL prefix."""
    override_path = "/custom/broker/socket"
    broker_config = BrokerConfig(
        enabled=True,
        base_url="https://broker.example.com/prefix",
        scope="global",
        username="global",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
        websocket_path=override_path,
    )
    expected_url = (
        "wss://broker.example.com/prefix/custom/broker/socket"
        "?provider_id=provider-1"
        "&client_id=client-1"
        "&username=global"
        "&registration_token=token-123"
    )

    runtime = build_broker_runtime_config(broker_config)

    # The override is stored as given and is reflected in the final URL.
    assert runtime.websocket_path == "/custom/broker/socket"
    assert runtime.websocket_url == expected_url
def test_build_broker_runtime_config_normalizes_manual_websocket_path_override_without_leading_slash():
    """An override lacking a leading slash is kept verbatim on the config but joined correctly in the URL."""
    broker_config = BrokerConfig(
        enabled=True,
        base_url="https://broker.example.com",
        scope="user",
        username="alice",
        provider_id="provider-1",
        client_id="client-1",
        registration_token="token-123",
        websocket_path="broker/ws",
    )
    expected_prefix = "wss://broker.example.com/broker/ws"

    runtime = build_broker_runtime_config(broker_config)

    # The stored path is untouched; only the assembled URL gains the separator.
    assert runtime.websocket_path == "broker/ws"
    assert runtime.websocket_url.startswith(expected_prefix)
def test_build_broker_runtime_config_rejects_invalid_user_scope_username(): def test_build_broker_runtime_config_rejects_invalid_user_scope_username():
try: try:
build_broker_runtime_config( build_broker_runtime_config(
...@@ -294,11 +396,17 @@ def test_build_register_message_includes_capabilities_and_hardware(): ...@@ -294,11 +396,17 @@ def test_build_register_message_includes_capabilities_and_hardware():
"v": 1, "v": 1,
"op": "register", "op": "register",
"request_id": "req-1", "request_id": "req-1",
"registration_token": "token-123",
"capabilities": capabilities,
"payload": { "payload": {
"endpoint": "https://server.example.com/alice", "endpoint": "https://server.example.com/alice",
"transport": "websocket", "transport": "websocket",
"registration_token": "token-123", "registration_token": "token-123",
"hardware": {"gpu": True, "memory_gb": 24}, "hardware": {"gpu": True, "memory_gb": 24},
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0,
"studio_endpoints": EXPECTED_STUDIO_ENDPOINTS, "studio_endpoints": EXPECTED_STUDIO_ENDPOINTS,
"capabilities": capabilities, "capabilities": capabilities,
}, },
...@@ -318,17 +426,65 @@ def test_build_register_message_defaults_token_and_studio_endpoints_for_empty_ru ...@@ -318,17 +426,65 @@ def test_build_register_message_defaults_token_and_studio_endpoints_for_empty_ru
"v": 1, "v": 1,
"op": "register", "op": "register",
"request_id": "req-1", "request_id": "req-1",
"registration_token": "",
"capabilities": {"server": "codai"},
"payload": { "payload": {
"endpoint": "", "endpoint": "",
"transport": "websocket", "transport": "websocket",
"registration_token": "", "registration_token": "",
"hardware": None, "hardware": None,
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0,
"studio_endpoints": DEFAULT_STUDIO_ENDPOINTS, "studio_endpoints": DEFAULT_STUDIO_ENDPOINTS,
"capabilities": {"server": "codai"}, "capabilities": {"server": "codai"},
}, },
} }
def test_build_hardware_summary_reports_torch_cuda_vram(monkeypatch):
    """With a stubbed ``torch`` exposing one 24 GiB GPU (10 GiB free), the summary reports it in MB."""
    gib = 1024 * 1024 * 1024

    class FakeProps:
        total_memory = 24 * gib

    class FakeCuda:
        @staticmethod
        def is_available():
            return True

        @staticmethod
        def device_count():
            return 1

        @staticmethod
        def current_device():
            return 0

        @staticmethod
        def get_device_properties(index):
            return FakeProps()

        @staticmethod
        def get_device_name(index):
            return "RTX Test"

        @staticmethod
        def mem_get_info():
            # (free_bytes, total_bytes)
            return (10 * gib, 24 * gib)

    class FakeTorch:
        cuda = FakeCuda()

    # Replace the importable ``torch`` module with the stub for this test only.
    monkeypatch.setitem(sys.modules, "torch", FakeTorch())

    summary = build_hardware_summary()

    expected_gpus = [{"index": 0, "name": "RTX Test", "total_vram_mb": 24576}]
    assert summary["gpu_count"] == 1
    assert summary["total_vram_mb"] == 24576
    assert summary["available_vram_mb"] == 10240
    assert summary["gpus"] == expected_gpus
def test_build_capabilities_document_lists_openai_and_studio_support(): def test_build_capabilities_document_lists_openai_and_studio_support():
document = build_capabilities_document(hardware={"gpu": True}) document = build_capabilities_document(hardware={"gpu": True})
......
...@@ -6,6 +6,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[1])) ...@@ -6,6 +6,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import File, Form, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from starlette.responses import Response from starlette.responses import Response
...@@ -87,6 +88,7 @@ async def test_execute_broker_request_returns_success_envelope_for_json_route(): ...@@ -87,6 +88,7 @@ async def test_execute_broker_request_returns_success_envelope_for_json_route():
envelope = BrokerRequestEnvelope( envelope = BrokerRequestEnvelope(
request_id="req-123", request_id="req-123",
op="chat.completions",
method="POST", method="POST",
path="/v1/chat/completions", path="/v1/chat/completions",
headers={"accept": "application/json"}, headers={"accept": "application/json"},
...@@ -96,7 +98,7 @@ async def test_execute_broker_request_returns_success_envelope_for_json_route(): ...@@ -96,7 +98,7 @@ async def test_execute_broker_request_returns_success_envelope_for_json_route():
response = await execute_broker_request(app, envelope) response = await execute_broker_request(app, envelope)
assert response["request_id"] == "req-123" assert response["request_id"] == "req-123"
assert response["ok"] is True assert response["status"] == "ok"
assert response["payload"] == { assert response["payload"] == {
"status_code": 201, "status_code": 201,
"headers": { "headers": {
...@@ -125,6 +127,7 @@ async def test_execute_broker_request_preserves_binary_payload_metadata(): ...@@ -125,6 +127,7 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
envelope = BrokerRequestEnvelope( envelope = BrokerRequestEnvelope(
request_id="req-binary", request_id="req-binary",
op="proxy",
method="GET", method="GET",
path="/v1/images/render", path="/v1/images/render",
) )
...@@ -132,7 +135,7 @@ async def test_execute_broker_request_preserves_binary_payload_metadata(): ...@@ -132,7 +135,7 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
response = await execute_broker_request(app, envelope) response = await execute_broker_request(app, envelope)
assert response["request_id"] == "req-binary" assert response["request_id"] == "req-binary"
assert response["ok"] is True assert response["status"] == "ok"
assert response["payload"] == { assert response["payload"] == {
"status_code": 200, "status_code": 200,
"headers": { "headers": {
...@@ -148,12 +151,75 @@ async def test_execute_broker_request_preserves_binary_payload_metadata(): ...@@ -148,12 +151,75 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
assert response["metrics"]["elapsed_ms"] >= 0 assert response["metrics"]["elapsed_ms"] >= 0
@pytest.mark.anyio("asyncio")
async def test_execute_broker_request_maps_proxy_operation_to_internal_route():
    """A ``proxy`` op carrying ``endpoint_path``/``method``/``body`` in its payload reaches the matching app route."""
    app = FastAPI()

    @app.post("/v1/video/dub")
    async def dub_route(payload: dict):
        return {"received": payload, "route": "dub"}

    proxy_payload = {
        "endpoint_path": "v1/video/dub",
        "method": "POST",
        "headers": {"content-type": "application/json"},
        "body": {"prompt": "hello"},
    }
    request_envelope = BrokerRequestEnvelope(
        request_id="req-proxy-op",
        op="proxy",
        payload=proxy_payload,
    )

    result = await execute_broker_request(app, request_envelope)

    assert result["status"] == "ok"
    result_payload = result["payload"]
    assert result_payload["status_code"] == 200
    assert result_payload["body"] == '{"received":{"prompt":"hello"},"route":"dub"}'
@pytest.mark.anyio("asyncio")
async def test_execute_broker_request_supports_proxy_multipart_payloads():
    """A ``proxy`` op with a ``multipart`` payload is delivered as a form upload.

    The payload lists one plain form field and one base64-encoded file; the
    route reports the field value, the uploaded filename, and the byte count
    of the file so all three can be asserted on.
    """
    target_app = FastAPI()

    @target_app.post("/v1/audio/transcriptions")
    async def transcription_route(model: str = Form(...), file: UploadFile = File(...)):
        uploaded_bytes = await file.read()
        return {"model": model, "filename": file.filename, "size": len(uploaded_bytes)}

    form_fields = [{"name": "model", "value": "whisper-large"}]
    form_files = [
        {
            "name": "file",
            "filename": "sample.wav",
            "content_type": "audio/wav",
            # base64 of b"hello" -> 5 bytes, matching the asserted size below.
            "data_base64": "aGVsbG8=",
        }
    ]
    multipart_envelope = BrokerRequestEnvelope(
        request_id="req-multipart",
        op="proxy",
        payload={
            "endpoint_path": "v1/audio/transcriptions",
            "method": "POST",
            "multipart": {"fields": form_fields, "files": form_files},
        },
    )

    result = await execute_broker_request(target_app, multipart_envelope)

    assert result["status"] == "ok"
    proxied = result["payload"]
    assert proxied["status_code"] == 200
    assert proxied["body"] == '{"model":"whisper-large","filename":"sample.wav","size":5}'
@pytest.mark.anyio("asyncio") @pytest.mark.anyio("asyncio")
async def test_brokered_models_match_direct_http_response_shape(): async def test_brokered_models_match_direct_http_response_shape():
direct_response = TestClient(real_app).get("/v1/models") direct_response = TestClient(real_app).get("/v1/models")
envelope = BrokerRequestEnvelope( envelope = BrokerRequestEnvelope(
request_id="req-models-shape", request_id="req-models-shape",
op="models.list",
method="GET", method="GET",
path="/v1/models", path="/v1/models",
headers={"accept": "application/json"}, headers={"accept": "application/json"},
...@@ -163,7 +229,7 @@ async def test_brokered_models_match_direct_http_response_shape(): ...@@ -163,7 +229,7 @@ async def test_brokered_models_match_direct_http_response_shape():
brokered_body = json.loads(brokered_response["payload"]["body"]) brokered_body = json.loads(brokered_response["payload"]["body"])
direct_body = direct_response.json() direct_body = direct_response.json()
assert brokered_response["ok"] is True assert brokered_response["status"] == "ok"
assert brokered_response["payload"]["status_code"] == direct_response.status_code assert brokered_response["payload"]["status_code"] == direct_response.status_code
assert brokered_response["payload"]["content_type"] == direct_response.headers["content-type"] assert brokered_response["payload"]["content_type"] == direct_response.headers["content-type"]
assert brokered_response["payload"]["headers"]["content-type"] == direct_response.headers["content-type"] assert brokered_response["payload"]["headers"]["content-type"] == direct_response.headers["content-type"]
...@@ -192,6 +258,7 @@ async def test_execute_broker_request_rejects_unsupported_endpoint(): ...@@ -192,6 +258,7 @@ async def test_execute_broker_request_rejects_unsupported_endpoint():
app = FastAPI() app = FastAPI()
envelope = BrokerRequestEnvelope( envelope = BrokerRequestEnvelope(
request_id="req-unsupported", request_id="req-unsupported",
op="proxy",
method="GET", method="GET",
path="/internal", path="/internal",
) )
...@@ -199,8 +266,9 @@ async def test_execute_broker_request_rejects_unsupported_endpoint(): ...@@ -199,8 +266,9 @@ async def test_execute_broker_request_rejects_unsupported_endpoint():
response = await execute_broker_request(app, envelope) response = await execute_broker_request(app, envelope)
assert response == { assert response == {
"v": 1,
"request_id": "req-unsupported", "request_id": "req-unsupported",
"ok": False, "status": "error",
"error": { "error": {
"code": "unsupported_endpoint", "code": "unsupported_endpoint",
"message": "Unsupported endpoint: /internal", "message": "Unsupported endpoint: /internal",
......
This diff is collapsed.
This diff is collapsed.
...@@ -849,7 +849,10 @@ def test_settings_template_includes_broker_controls(): ...@@ -849,7 +849,10 @@ def test_settings_template_includes_broker_controls():
assert "AISBF Broker" in template assert "AISBF Broker" in template
assert "s-broker-enabled" in template assert "s-broker-enabled" in template
assert "s-broker-base-url" in template assert "s-broker-base-url" in template
assert 'id="s-broker-scope" class="form-input" onchange="toggleBrokerFields()"' in template
assert "s-broker-provider-id" in template assert "s-broker-provider-id" in template
assert "s-broker-client-id" in template assert "s-broker-client-id" in template
assert "s-broker-registration-token" in template assert "s-broker-registration-token" in template
assert "s-broker-websocket-path" in template
assert "toggleBrokerFields()" in template assert "toggleBrokerFields()" in template
assert "forced to `global` for global scope" in template
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment