New version!

parent 3716e54b
......@@ -40,9 +40,11 @@ templates = Jinja2Templates(directory=str(templates_dir))
# Session manager (will be initialized in main.py)
session_manager: Optional[SessionManager] = None
SESSION_COOKIE_NAME: str = "session" # overridden at startup to be port-specific
config_manager = None # set via set_config_manager()
_download_sessions: dict = {}
_download_status: dict = {} # session_id → latest progress state (survives SSE disconnect)
_download_cancelled: set = set() # session_ids the user has requested to cancel
def _url(request: Request, path: str) -> str:
......@@ -59,10 +61,17 @@ def _tmpl(request: Request, name: str, ctx: dict = None):
return templates.TemplateResponse(request, name, c)
def init_session_manager(config_dir: Path):
"""Initialize the session manager."""
global session_manager
def init_session_manager(config_dir: Path, port: int = 0):
"""Initialize the session manager.
port is used to derive a port-specific cookie name so that two instances
running on different ports on the same host don't share (and overwrite)
each other's browser cookie.
"""
global session_manager, SESSION_COOKIE_NAME
session_manager = SessionManager(config_dir)
if port:
SESSION_COOKIE_NAME = f"session_{port}"
def set_config_manager(mgr):
......@@ -73,6 +82,21 @@ def set_config_manager(mgr):
init_capability_cache(str(mgr.config_dir))
def _broker_notify_models_updated(request: Request) -> None:
"""Fire-and-forget: tell AISBF broker to refresh its model cache if connected."""
try:
broker_service = getattr(request.app.state, "broker_service", None)
if broker_service is None:
return
client = getattr(broker_service, "client", None)
if client is None:
return
import asyncio
asyncio.create_task(client.notify_models_updated())
except Exception:
pass
def _next_whisper_server_model_id(audio_models) -> str:
used_suffixes = set()
for model in audio_models or []:
......@@ -99,7 +123,7 @@ def get_current_user(request: Request) -> Optional[str]:
if session_manager is None:
return None
cookie = request.cookies.get("session")
cookie = request.cookies.get(SESSION_COOKIE_NAME)
if not cookie:
return None
......@@ -160,7 +184,7 @@ async def login(
redirect_path = "/admin/change-password" if must_change else "/admin"
response = RedirectResponse(url=_url(request, redirect_path), status_code=302)
response.set_cookie(
key="session",
key=SESSION_COOKIE_NAME,
value=session_cookie,
httponly=True,
secure=False, # Set to True if using HTTPS
......@@ -174,11 +198,11 @@ async def login(
async def logout(request: Request):
"""Handle logout."""
if session_manager:
cookie = request.cookies.get("session")
cookie = request.cookies.get(SESSION_COOKIE_NAME)
session_manager.destroy_session(cookie)
response = RedirectResponse(url=_url(request, "/login"), status_code=302)
response.delete_cookie("session")
response.delete_cookie(SESSION_COOKIE_NAME)
return response
......@@ -568,7 +592,56 @@ async def api_list_models(username: str = Depends(require_admin)):
return []
def _make_tqdm_class(pq, status=None):
_DISK_MIN_FREE_BYTES = 256 * 1024 * 1024 # 256 MB safety margin
def _check_disk_space(path: str, needed_bytes: int = 0) -> None:
"""Raise RuntimeError if `path`'s filesystem lacks enough free space."""
import os as _os, shutil
# Walk up to an existing ancestor — the target dir may not exist yet on first download
check_path = path
while check_path and not _os.path.exists(check_path):
parent = _os.path.dirname(check_path)
if parent == check_path:
break
check_path = parent
try:
free = shutil.disk_usage(check_path).free
except OSError:
return # can't stat — proceed anyway
required = needed_bytes + _DISK_MIN_FREE_BYTES
if free < required:
free_gb = free / 1e9
needed_gb = needed_bytes / 1e9
msg = (
f"Not enough disk space: {free_gb:.1f} GB free"
+ (f", ~{needed_gb:.1f} GB needed" if needed_bytes else "")
+ ". Free up space and try again."
)
raise RuntimeError(msg)
def _get_hf_expected_size(model_id: str, file_pattern: str) -> int:
"""Return expected download size in bytes for a HF model (best-effort, 0 on failure)."""
try:
import fnmatch
from huggingface_hub import model_info as _hf_model_info
info = _hf_model_info(model_id, files_metadata=True)
siblings = info.siblings or []
if file_pattern:
if file_pattern.startswith('.'):
pats = [f"*{file_pattern}"]
elif '/' in file_pattern:
pats = [file_pattern]
else:
pats = [f"*{file_pattern}"]
siblings = [s for s in siblings if any(fnmatch.fnmatch(s.rfilename, p) for p in pats)]
return sum(getattr(s, 'size', 0) or 0 for s in siblings)
except Exception:
return 0
def _make_tqdm_class(pq, status=None, session_id=None, cache_dir=None):
"""Return a tqdm-compatible class that forwards progress events to pq and optionally updates a status dict."""
import time as _time
......@@ -579,6 +652,7 @@ def _make_tqdm_class(pq, status=None):
self.total = int(total) if total else 0
self.n = int(initial) if initial else 0
self._start = _time.time()
self._update_count = 0
if self.total:
pq.put({"type": "start", "filename": self.desc, "total": self.total})
if status is not None:
......@@ -586,7 +660,13 @@ def _make_tqdm_class(pq, status=None):
"total": self.total, "downloaded": self.n, "percent": 0})
def update(self, n=1):
if session_id and session_id in _download_cancelled:
raise RuntimeError("Download cancelled by user")
self.n += n
self._update_count += 1
# Check disk space every 64 progress ticks
if cache_dir and self._update_count % 64 == 0:
_check_disk_space(cache_dir)
elapsed = (_time.time() - self._start) or 0.001
rate = self.n / elapsed
eta = (self.total - self.n) / rate if rate and self.total else None
......@@ -675,37 +755,97 @@ def _run_download_thread(session_id: str, model_id: str, file_pattern: str, pq):
status["last_info"] = evt.get("message", "")
try:
from codai.models.cache import is_huggingface_model_id
from codai.models.cache import is_huggingface_model_id, get_model_cache_dir, get_hf_hub_cache_dir
from huggingface_hub import snapshot_download
tqdm_cls = _make_tqdm_class(pq, status=status)
if is_huggingface_model_id(model_id):
if file_pattern:
# Convert suffix/quant pattern to fnmatch glob for allow_patterns
# GGUF files always land in the GGUF cache (flat); everything else
# (full repos, transformers checkpoints, diffusers, …) goes in the HF cache.
is_gguf_download = file_pattern and '.gguf' in file_pattern.lower()
if is_gguf_download:
gguf_cache = get_model_cache_dir()
dl_cache_dir = gguf_cache
else:
dl_cache_dir = get_hf_hub_cache_dir()
# Pre-check disk space using HF file-size metadata
expected_bytes = _get_hf_expected_size(model_id, file_pattern)
_check_disk_space(dl_cache_dir, expected_bytes)
tqdm_cls = _make_tqdm_class(pq, status=status, session_id=session_id, cache_dir=dl_cache_dir)
if is_gguf_download:
import fnmatch as _fnmatch
import shutil as _shutil
from huggingface_hub import list_repo_files, hf_hub_download
# If pattern has no wildcards and looks like an exact filename, skip listing.
_is_exact = ('*' not in file_pattern and '?' not in file_pattern
and file_pattern.lower().endswith('.gguf'))
if _is_exact:
matching = [file_pattern]
else:
# Resolve the pattern to actual filenames in the repo
if file_pattern.startswith('.'):
allow = [f"*{file_pattern}"] # ".gguf" → "*.gguf"
pat = f"*{file_pattern}"
elif '/' in file_pattern:
allow = [file_pattern] # exact subpath
pat = file_pattern
else:
allow = [f"*{file_pattern}"] # "Q4_K_M.gguf" → "*Q4_K_M.gguf"
pat = f"*{file_pattern}"
all_repo_files = list(list_repo_files(model_id))
matching = [
f for f in all_repo_files
if _fnmatch.fnmatch(f, pat) or _fnmatch.fnmatch(os.path.basename(f), pat)
]
if not matching:
push({"type": "error", "message": f"No files matching {file_pattern!r} found in {model_id}"})
return
last_dest = gguf_cache
for hf_filename in matching:
basename = os.path.basename(hf_filename)
push({"type": "info", "message": f"Downloading {basename} from {model_id}…"})
dl_path = hf_hub_download(
repo_id=model_id,
filename=hf_filename,
local_dir=gguf_cache,
tqdm_class=tqdm_cls,
)
# hf_hub_download preserves subfolder structure; flatten to cache root
flat_dest = os.path.join(gguf_cache, basename)
if os.path.abspath(dl_path) != os.path.abspath(flat_dest) and os.path.isfile(dl_path):
_shutil.move(dl_path, flat_dest)
last_dest = flat_dest
path = last_dest
elif file_pattern:
# Non-GGUF pattern — use snapshot into HF cache
if file_pattern.startswith('.'):
allow = [f"*{file_pattern}"]
elif '/' in file_pattern:
allow = [file_pattern]
else:
allow = [f"*{file_pattern}"]
push({"type": "info", "message": f"Downloading {allow[0]} from {model_id}…"})
path = snapshot_download(model_id, allow_patterns=allow, tqdm_class=tqdm_cls)
path = snapshot_download(model_id, cache_dir=dl_cache_dir, allow_patterns=allow, tqdm_class=tqdm_cls)
else:
push({"type": "info", "message": f"Downloading full repository {model_id}…"})
path = snapshot_download(model_id, tqdm_class=tqdm_cls)
path = snapshot_download(model_id, cache_dir=dl_cache_dir, tqdm_class=tqdm_cls)
else:
# Direct URL download (non-HF source)
import requests as _req
import hashlib
from codai.models.cache import get_model_cache_dir
cache_dir = get_model_cache_dir()
dl_cache_dir = get_model_cache_dir()
_check_disk_space(dl_cache_dir) # basic free-space sanity check before connecting
url_path = model_id.split('?')[0]
filename = os.path.basename(url_path) or "model.bin"
url_hash = hashlib.sha256(model_id.encode()).hexdigest()
dest = os.path.join(cache_dir, f"{url_hash}_{filename}")
dest = os.path.join(dl_cache_dir, f"{url_hash}_{filename}")
if os.path.exists(dest):
push({"type": "done", "path": dest})
......@@ -714,14 +854,22 @@ def _run_download_thread(session_id: str, model_id: str, file_pattern: str, pq):
resp = _req.get(model_id, stream=True, timeout=60, allow_redirects=True)
resp.raise_for_status()
total = int(resp.headers.get('content-length', 0))
if total:
_check_disk_space(dl_cache_dir, total)
push({"type": "start", "filename": filename, "total": total})
tqdm_cls = _make_tqdm_class(pq, status=status, session_id=session_id, cache_dir=dl_cache_dir)
downloaded = 0
start_t = time.time()
last_evt = 0.0
with open(dest, 'wb') as f:
for chunk in resp.iter_content(chunk_size=524288):
if chunk:
if session_id in _download_cancelled:
raise RuntimeError("Download cancelled by user")
# Check disk space roughly every 64 MB
if downloaded % (64 * 1024 * 1024) < len(chunk):
_check_disk_space(dl_cache_dir)
f.write(chunk)
downloaded += len(chunk)
now = time.time()
......@@ -742,8 +890,13 @@ def _run_download_thread(session_id: str, model_id: str, file_pattern: str, pq):
push({"type": "done", "path": str(path)})
except Exception as exc:
if session_id in _download_cancelled:
pq.put({"type": "cancelled", "message": "Download cancelled by user"})
_download_status.get(session_id, {}).update({"status": "cancelled"})
else:
push({"type": "error", "message": str(exc)})
finally:
_download_cancelled.discard(session_id)
def _gc():
time.sleep(300)
_download_sessions.pop(session_id, None)
......@@ -831,12 +984,43 @@ async def api_delete_model(
# --- Download status / cache management ---
@router.get("/admin/api/hf-files")
async def api_hf_repo_files(repo_id: str, username: str = Depends(require_admin)):
"""Return the file list for a HuggingFace repo with name and size metadata."""
import asyncio
def _fetch():
try:
from huggingface_hub import model_info as _hf_model_info
info = _hf_model_info(repo_id, files_metadata=True)
return {
"repo_id": repo_id,
"files": [
{"name": f.rfilename, "size": getattr(f, "size", None)}
for f in (info.siblings or [])
],
}
except Exception as exc:
return {"repo_id": repo_id, "files": [], "error": str(exc)}
return await asyncio.to_thread(_fetch)
@router.get("/admin/api/downloads")
async def api_list_downloads(username: str = Depends(require_admin)):
"""Return status of all active and recently completed download sessions."""
return list(_download_status.values())
@router.post("/admin/api/download-cancel/{session_id}")
async def api_cancel_download(session_id: str, username: str = Depends(require_admin)):
"""Request cancellation of an active download session."""
if session_id not in _download_sessions and session_id not in _download_status:
raise HTTPException(status_code=404, detail="Download session not found")
_download_cancelled.add(session_id)
return {"success": True}
@router.post("/admin/api/model-upload")
async def api_model_upload(request: Request, username: str = Depends(require_admin)):
"""Upload a GGUF model file in chunks."""
......@@ -873,6 +1057,22 @@ async def api_model_upload(request: Request, username: str = Depends(require_adm
# ── cache scan helpers (run in thread pool) ──────────────────────────────────
def _hf_repo_id_from_path(path: str) -> str:
"""Extract a HuggingFace repo ID from an HF hub cache path.
HF hub cache paths look like:
.../hub/models--OWNER--REPO-NAME/snapshots/HASH/filename.gguf
The first '--' inside the 'models--...' component separates owner from repo.
"""
for part in path.replace('\\', '/').split('/'):
if part.startswith('models--'):
repo_part = part[len('models--'):]
sep = repo_part.find('--')
if sep != -1:
return repo_part[:sep] + '/' + repo_part[sep + 2:]
return ''
def _scan_caches() -> dict:
import os
result: dict = {"hf": [], "gguf": []}
......@@ -883,8 +1083,11 @@ def _scan_caches() -> dict:
)
caches = get_all_cache_dirs()
# Collect configured models: key (path/id) → (settings_dict, model_type)
# Collect configured models.
# configured_settings: path → primary (first) config entry (backward compat)
# all_configs: path → list of all config entries (for multi-config support)
configured_settings: dict = {}
all_configs: dict = {}
if config_manager:
md = config_manager.models_data
for cat in ("text_models", "image_models", "audio_models",
......@@ -892,12 +1095,31 @@ def _scan_caches() -> dict:
"audio_gen_models", "embedding_models", "spatial_models"):
for m in md.get(cat, []):
if isinstance(m, str):
p = m
configured_settings[p] = ({}, cat)
p, s = m, {}
else:
# Whisper-server entries have no "path"; their file is at "model_path"
if m.get("backend") == "whisper-server" and m.get("model_path"):
p = m["model_path"]
else:
p = m.get("path") or m.get("id") or ""
if p:
configured_settings[p] = (m, cat)
s = m if isinstance(m, dict) else {}
if not p:
continue
if p not in configured_settings:
configured_settings[p] = (s, cat)
all_configs.setdefault(p, []).append({"settings": s, "cat": cat})
# Secondary index: basename → (settings_tuple, original_path)
# Used to reconnect a config to a re-downloaded file that landed at a different path.
# Only populated for .gguf entries whose basename is unique (avoids ambiguous matches).
_cfg_by_fname: dict = {}
for _p, _val in configured_settings.items():
_bn = os.path.basename(_p) if ('/' in _p or os.sep in _p) else _p
if _bn and _bn.endswith('.gguf'):
if _bn in _cfg_by_fname:
_cfg_by_fname[_bn] = None # mark as ambiguous — don't use
else:
_cfg_by_fname[_bn] = (_val, _p)
# HuggingFace cache
hf_dir = caches.get("huggingface")
......@@ -905,6 +1127,30 @@ def _scan_caches() -> dict:
try:
from huggingface_hub import scan_cache_dir
info = scan_cache_dir(hf_dir)
# Build set of repo IDs that have incomplete/corrupted cache entries.
# huggingface_hub reports these via info.warnings (CorruptedCacheException).
incomplete_repos: set = set()
for w in getattr(info, 'warnings', []):
rid = getattr(w, 'repo_id', None)
if rid:
incomplete_repos.add(str(rid))
# Also scan each repo's blobs directory for .incomplete marker files
# (used by some huggingface_hub versions for in-progress downloads).
try:
for _repo_entry in os.scandir(hf_dir):
if not _repo_entry.is_dir() or not _repo_entry.name.startswith('models--'):
continue
_blobs = os.path.join(_repo_entry.path, 'blobs')
if os.path.isdir(_blobs) and any(
n.endswith('.incomplete') or n.endswith('.lock')
for n in os.listdir(_blobs)
):
_rid = _repo_entry.name[len('models--'):].replace('--', '/', 1)
incomplete_repos.add(_rid)
except Exception:
pass
for repo in sorted(info.repos, key=lambda r: r.repo_id):
revs = sorted(repo.revisions, key=lambda r: r.commit_hash)
size_bytes = sum(r.size_on_disk for r in repo.revisions)
......@@ -921,22 +1167,31 @@ def _scan_caches() -> dict:
fname = hf_file.file_name
fsize = hf_file.size_on_disk
cfg = (configured_settings.get(fpath)
or configured_settings.get(fname)
or ({}, None))
or configured_settings.get(fname))
_fname_match = None if cfg else _cfg_by_fname.get(fname)
cfg = cfg or (_fname_match[0] if _fname_match else None) or ({}, None)
_configured_path = _fname_match[1] if _fname_match else None
cfg_s = cfg[0] if isinstance(cfg[0], dict) else {}
saved_caps = cfg_s.get("capabilities") or []
caps_list = saved_caps if saved_caps else detect_model_capabilities(fname).to_list()
result["gguf"].append({
_direct_match = fpath in configured_settings or fname in configured_settings
_gguf_key = fpath if fpath in all_configs else (fname if fname in all_configs else (_configured_path or fpath))
_entry = {
"filename": fname,
"path": fpath,
"size_gb": round(fsize / 1e9, 2),
"size_bytes": fsize,
"in_config": fpath in configured_settings or fname in configured_settings,
"in_config": _direct_match or bool(_fname_match),
"model_type": cfg[1] if cfg[1] and cfg[1] != "gguf_models" else "text_models",
"settings": cfg_s,
"capabilities": caps_list,
"source_repo": repo.repo_id,
})
"incomplete": repo.repo_id in incomplete_repos,
"configs": all_configs.get(_gguf_key, []),
}
if _configured_path:
_entry["configured_path"] = _configured_path
result["gguf"].append(_entry)
continue # skip adding to hf list
cfg = configured_settings.get(repo.repo_id, ({}, None))
......@@ -959,6 +1214,8 @@ def _scan_caches() -> dict:
"model_type": cfg[1] if cfg[1] and cfg[1] != "gguf_models" else "text_models",
"settings": cfg_settings,
"capabilities": caps_list,
"incomplete": repo.repo_id in incomplete_repos,
"configs": all_configs.get(repo.repo_id, []),
})
except Exception as e:
result["hf_error"] = str(e)
......@@ -966,26 +1223,52 @@ def _scan_caches() -> dict:
# GGUF cache (coderai-specific)
gguf_dir = caches.get("coderai") or get_model_cache_dir()
if gguf_dir and os.path.exists(gguf_dir):
# Files with these suffixes are known-incomplete downloads
_incomplete_gguf_stems = {
os.path.splitext(n)[0]
for n in os.listdir(gguf_dir)
if n.endswith(('.part', '.tmp', '.download', '.incomplete'))
}
for fname in sorted(os.listdir(gguf_dir)):
fpath = os.path.join(gguf_dir, fname)
if os.path.isfile(fpath):
if not os.path.isfile(fpath):
continue
# Skip the partial-download sentinel files themselves
if any(fname.endswith(s) for s in ('.part', '.tmp', '.download', '.incomplete')):
continue
size = os.path.getsize(fpath)
cfg = (configured_settings.get(fpath)
or configured_settings.get(fname)
or ({}, None))
or configured_settings.get(fname))
_fname_match = None if cfg else _cfg_by_fname.get(fname)
cfg = cfg or (_fname_match[0] if _fname_match else None) or ({}, None)
_configured_path = _fname_match[1] if _fname_match else None
cfg_s = cfg[0] if isinstance(cfg[0], dict) else {}
saved_caps = cfg_s.get("capabilities") or []
caps_list = saved_caps if saved_caps else detect_model_capabilities(fname).to_list()
result["gguf"].append({
# A file is incomplete if there is a same-stem partial file alongside it,
# or if an active download session targets this exact path/filename.
_stem = os.path.splitext(fname)[0]
_dl_active = any(
s.get("model_id") in (fname, fpath) and s.get("status") not in ("done", "error", "cancelled")
for s in _download_status.values()
)
_direct_match = fpath in configured_settings or fname in configured_settings
_key = fpath if fpath in all_configs else (fname if fname in all_configs else (_configured_path or fpath))
_entry = {
"filename": fname,
"path": fpath,
"size_gb": round(size / 1e9, 2),
"size_bytes": size,
"in_config": fpath in configured_settings or fname in configured_settings,
"in_config": _direct_match or bool(_fname_match),
"model_type": cfg[1] if cfg[1] and cfg[1] != "gguf_models" else "text_models",
"settings": cfg_s,
"capabilities": caps_list,
})
"incomplete": _stem in _incomplete_gguf_stems or _dl_active,
"configs": all_configs.get(_key, []),
}
if _configured_path:
_entry["configured_path"] = _configured_path
result["gguf"].append(_entry)
# Add configured GGUF models not yet in the list (e.g., HF repo IDs or external paths)
existing_paths = {m["path"] for m in result["gguf"]}
......@@ -999,20 +1282,24 @@ def _scan_caches() -> dict:
# Check if it's a GGUF model (ends with .gguf or is in a GGUF repo)
is_gguf = path.endswith('.gguf') or 'gguf' in path.lower() or mtype == "gguf_models"
if is_gguf:
# Try to get size if it's a local file
size_bytes = 0
if os.path.isfile(path):
size_bytes = os.path.getsize(path)
file_exists = os.path.isfile(path)
size_bytes = os.path.getsize(path) if file_exists else 0
caps = detect_model_capabilities(path)
s = settings if isinstance(settings, dict) else {}
# Derive HF repo ID from path when not explicitly stored in settings
source_repo = s.get("source_repo") or _hf_repo_id_from_path(path)
result["gguf"].append({
"filename": os.path.basename(path) if '/' in path else path,
"path": path,
"size_gb": round(size_bytes / 1e9, 2) if size_bytes else 0,
"size_bytes": size_bytes,
"in_config": True,
"missing": not file_exists,
"source_repo": source_repo,
"model_type": mtype if mtype and mtype != "gguf_models" else "text_models",
"settings": settings if isinstance(settings, dict) else {},
"settings": s,
"capabilities": caps.to_list(),
"configs": all_configs.get(path, []),
})
return result
......@@ -1224,6 +1511,7 @@ async def api_model_enable(request: Request, username: str = Depends(require_adm
if path not in lst:
lst.append(path)
config_manager.save_models()
_broker_notify_models_updated(request)
return {"success": True}
......@@ -1235,11 +1523,17 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad
import os as _os
data = await request.json()
path = data.get("path") or data.get("model_id", "")
config_id = (data.get("config_id") or "").strip()
# Also match by bare filename so entries stored without full path are caught
fname = _os.path.basename(path) if (_os.sep in path or "/" in path) else ""
def _matches(m_entry) -> bool:
key = m_entry if isinstance(m_entry, str) else m_entry.get("path", m_entry.get("id", ""))
if isinstance(m_entry, str):
return m_entry == path or (fname and _os.path.basename(m_entry) == fname)
if config_id:
# Targeted removal: only remove the entry with this config_id
return m_entry.get("config_id", "") == config_id
key = m_entry.get("path", m_entry.get("id", ""))
return key == path or (fname and _os.path.basename(key) == fname)
changed = False
......@@ -1253,6 +1547,7 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad
changed = True
if changed:
config_manager.save_models()
_broker_notify_models_updated(request)
return {"success": True}
......@@ -1292,6 +1587,7 @@ async def api_model_load(request: Request, username: str = Depends(require_admin
# Find the model config entry to determine its type
model_type = "text"
model_cfg: dict = {}
if config_manager:
md = config_manager.models_data
for cat, mtype in (("image_models", "image"), ("audio_models", "audio"),
......@@ -1304,6 +1600,7 @@ async def api_model_load(request: Request, username: str = Depends(require_admin
mid = m if isinstance(m, str) else m.get("path") or m.get("id") or ""
if mid == path:
model_type = mtype
model_cfg = m if isinstance(m, dict) else {}
break
result = multi_model_manager.request_model(path, model_type if model_type != "text" else None)
......@@ -1365,6 +1662,77 @@ async def api_model_load(request: Request, username: str = Depends(require_admin
if pipeline:
multi_model_manager.add_model(model_key, pipeline)
multi_model_manager.record_vram_delta(model_key, _snap)
elif model_type == "video":
import asyncio
from codai.api.video import _load_video_pipeline, _derive_device
model_key = f"video:{path}"
device = _derive_device()
_snap = multi_model_manager.vram_before_load()
_offload = model_cfg.get("offload_strategy") or None
pipe = await asyncio.to_thread(_load_video_pipeline, path, device, "t2v", _offload, model_cfg)
if pipe is None:
raise RuntimeError("Video model failed to load")
multi_model_manager.models[model_key] = pipe
multi_model_manager.current_model_key = model_key
multi_model_manager.active_in_vram = model_key
multi_model_manager.models_in_vram.add(model_key)
multi_model_manager.record_vram_delta(model_key, _snap)
elif model_type == "audio_gen":
import asyncio
from codai.api.audio_gen import _load_musicgen, _load_audioldm, _detect_audio_gen_type, _derive_device
model_key = f"audio_gen:{path}"
device = _derive_device()
_snap = multi_model_manager.vram_before_load()
gen_type = _detect_audio_gen_type(path)
if gen_type in ("musicgen", "audiogen"):
pipe = await asyncio.to_thread(_load_musicgen, path, device)
else:
pipe = await asyncio.to_thread(_load_audioldm, path, device)
if pipe is None:
raise RuntimeError("Audio gen model failed to load")
multi_model_manager.models[model_key] = pipe
multi_model_manager.current_model_key = model_key
multi_model_manager.active_in_vram = model_key
multi_model_manager.models_in_vram.add(model_key)
multi_model_manager.record_vram_delta(model_key, _snap)
elif model_type == "tts":
import asyncio
model_key = f"tts:{path}"
_snap = multi_model_manager.vram_before_load()
def _load_tts():
try:
from kokoro import Kokoro
return Kokoro(path)
except ImportError:
pass
try:
from bark import preload_models
preload_models()
return {"bark": True}
except ImportError:
pass
return None
tts_obj = await asyncio.to_thread(_load_tts)
if tts_obj is None:
raise RuntimeError("No supported TTS backend found (kokoro / bark)")
multi_model_manager.models[model_key] = tts_obj
multi_model_manager.current_model_key = model_key
multi_model_manager.active_in_vram = model_key
multi_model_manager.models_in_vram.add(model_key)
multi_model_manager.record_vram_delta(model_key, _snap)
elif model_type in ("embedding", "spatial", "vision"):
import asyncio
from codai.api.images import _load_diffusers_pipeline
from codai.api.state import get_global_args
model_key = f"{model_type}:{path}"
_snap = multi_model_manager.vram_before_load()
pipeline = await asyncio.to_thread(_load_diffusers_pipeline, path, get_global_args())
if pipeline is None:
raise RuntimeError(f"{model_type} model failed to load")
multi_model_manager.add_model(model_key, pipeline)
multi_model_manager.active_in_vram = model_key
multi_model_manager.models_in_vram.add(model_key)
multi_model_manager.record_vram_delta(model_key, _snap)
return {"success": True, "already_loaded": False}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
......@@ -1488,18 +1856,34 @@ async def api_model_configure(request: Request, username: str = Depends(require_
if not model_types:
model_types = ["text_models"]
# Remove from all categories (handles type changes and quant switches)
# paths_to_remove: the new path + the original path before a quant change
# config_id: when provided, identifies a specific config entry to update.
# A new UUID is assigned if none is given (new entry).
config_id = (data.get("config_id") or "").strip()
is_new_config = not config_id
if is_new_config:
config_id = str(_uuid.uuid4())
# Remove from all categories.
# When config_id matches an existing entry, remove only that entry so that
# sibling configs (same path, different config_id) are preserved.
# Fall back to path-based removal for entries that predate config_id support.
import os as _os
paths_to_remove = {path}
orig_path = (data.get("orig_path") or "").strip()
if orig_path and orig_path != path:
paths_to_remove.add(orig_path)
# Also match by bare filename so entries stored without full path are caught
fnames_to_remove = {_os.path.basename(p) for p in paths_to_remove if _os.sep in p or "/" in p}
def _should_remove(m_entry) -> bool:
key = m_entry if isinstance(m_entry, str) else m_entry.get("path", m_entry.get("id", ""))
if not isinstance(m_entry, dict):
# Legacy string entry — fall back to path matching
return m_entry in paths_to_remove
existing_cid = m_entry.get("config_id", "")
if existing_cid and not is_new_config:
# Targeted removal: only the entry that shares this config_id
return existing_cid == config_id
# Path-based removal (no config_id on either side, or new entry replacing old)
key = m_entry.get("path", m_entry.get("id", ""))
return key in paths_to_remove or (fnames_to_remove and _os.path.basename(key) in fnames_to_remove)
for cat in valid | {"gguf_models"}:
......@@ -1522,17 +1906,20 @@ async def api_model_configure(request: Request, username: str = Depends(require_
used_vram_gb = round(size_bytes / 1e9 * 1.2, 2)
# Build settings entry
entry: dict = {"path": path, "model_type": model_types[0], "model_types": model_types}
entry: dict = {"path": path, "model_type": model_types[0], "model_types": model_types, "config_id": config_id}
if used_vram_gb is not None:
entry["used_vram_gb"] = used_vram_gb
# Store video sub-types (t2v / i2v / v2v) when present
if data.get("video_subtypes"):
entry["video_subtypes"] = data["video_subtypes"]
for key in ("alias", "backend", "load_mode", "n_gpu_layers", "n_ctx",
for key in ("alias", "config_name", "backend", "load_mode", "n_gpu_layers", "n_ctx",
"max_gpu_percent", "manual_ram_gb", "load_in_4bit", "load_in_8bit",
"flash_attention", "no_ram", "offload_strategy", "offload_dir",
"system_prompt", "parser", "tools_closer_prompt", "grammar_guided",
"max_instances", "preload_all_instances", "capabilities"):
"max_instances", "preload_all_instances", "capabilities",
"model_template", "vae_path", "t5xxl_path", "clip_l_path",
"clip_g_path", "clip_vision_path", "lora_path", "lora_model_dir",
"max_vram", "sdcpp_flash_attn", "sdcpp_diffusion_flash_attn", "vae_tiling"):
if key in data:
entry[key] = data[key]
......@@ -1633,12 +2020,14 @@ async def api_get_settings(username: str = Depends(require_admin)):
"client_id": c.broker.client_id,
"registration_token": c.broker.registration_token,
"advertised_endpoint": c.broker.advertised_endpoint,
"websocket_path": c.broker.websocket_path,
"transport": c.broker.transport,
"heartbeat_interval_seconds": c.broker.heartbeat_interval_seconds,
"connect_timeout_seconds": c.broker.connect_timeout_seconds,
"request_timeout_seconds": c.broker.request_timeout_seconds,
"reconnect_initial_delay_seconds": c.broker.reconnect_initial_delay_seconds,
"reconnect_max_delay_seconds": c.broker.reconnect_max_delay_seconds,
"websocket_ping_interval": c.broker.websocket_ping_interval,
},
"system_prompt": c.system_prompt,
"tools_closer_prompt": c.tools_closer_prompt,
......@@ -1735,11 +2124,16 @@ async def api_save_settings(request: Request, username: str = Depends(require_ad
c.broker.enabled = bool(bro.get("enabled", c.broker.enabled))
c.broker.base_url = (bro.get("base_url") or "").strip()
c.broker.scope = (bro.get("scope") or c.broker.scope or "user").strip()
c.broker.username = (bro.get("username") or "").strip()
broker_username = (bro.get("username") or "").strip()
if c.broker.scope == "global":
c.broker.username = "global"
else:
c.broker.username = broker_username
c.broker.provider_id = (bro.get("provider_id") or "").strip()
c.broker.client_id = (bro.get("client_id") or "").strip()
c.broker.registration_token = (bro.get("registration_token") or "").strip()
c.broker.advertised_endpoint = (bro.get("advertised_endpoint") or "").strip()
c.broker.websocket_path = (bro.get("websocket_path") or "").strip()
c.broker.transport = (bro.get("transport") or c.broker.transport or "websocket").strip()
c.broker.heartbeat_interval_seconds = max(1, int(bro.get("heartbeat_interval_seconds", c.broker.heartbeat_interval_seconds)))
c.broker.connect_timeout_seconds = max(1, int(bro.get("connect_timeout_seconds", c.broker.connect_timeout_seconds)))
......@@ -1749,8 +2143,12 @@ async def api_save_settings(request: Request, username: str = Depends(require_ad
c.broker.reconnect_initial_delay_seconds,
int(bro.get("reconnect_max_delay_seconds", c.broker.reconnect_max_delay_seconds)),
)
from codai.broker.config import build_broker_runtime_config
c.broker.websocket_ping_interval = max(5, int(bro.get("websocket_ping_interval", c.broker.websocket_ping_interval)))
from codai.broker.config import BrokerConfigError, build_broker_runtime_config
try:
request.app.state.broker_runtime = build_broker_runtime_config(c.broker)
except BrokerConfigError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
config_manager.save_config()
return {"success": True}
......@@ -1851,6 +2249,7 @@ async def api_hf_search(
sizes: str = "", # comma-separated e.g. "7b,70b"
arch: str = "",
capabilities: str = "", # comma-separated e.g. "function-calling,vision"
component_type: str = "", # "vae" | "t5xxl" | "clip_l" | "clip_g" | "clip_vision" | "lora" | "encoder" | "controlnet" | "unet"
username: str = Depends(require_admin),
):
"""Proxy HuggingFace model search; supports multiple sizes via parallel requests."""
......@@ -1863,6 +2262,28 @@ async def api_hf_search(
if sort not in ("downloads", "likes", "lastModified", "createdAt"):
sort = "downloads"
# Component type → search keywords + HF tags
# Most components are safetensors, so override gguf_mode → "all" unless caller forced it
_COMP_SEARCH: dict = {
"vae": {"kw": "vae", "tags": ["vae"]},
"t5xxl": {"kw": "t5xxl OR t5-xxl", "tags": []},
"clip_l": {"kw": "clip-l encoder", "tags": []},
"clip_g": {"kw": "clip-g encoder", "tags": []},
"clip_vision": {"kw": "clip vision encoder","tags": []},
"lora": {"kw": "lora", "tags": ["lora"]},
"encoder": {"kw": "text encoder", "tags": []},
"controlnet": {"kw": "controlnet", "tags": ["controlnet"]},
"unet": {"kw": "unet", "tags": []},
}
comp_kw: str = ""
comp_tags: list = []
if component_type and component_type in _COMP_SEARCH:
spec = _COMP_SEARCH[component_type]
comp_kw = spec["kw"]
comp_tags = spec["tags"]
if gguf_mode == "gguf":
gguf_mode = "all" # components are usually safetensors; respect explicit "no-gguf" only
# Filter tags shared across all requests
filter_pairs: list = []
if gguf_mode == "gguf":
......@@ -1871,6 +2292,8 @@ async def api_hf_search(
filter_pairs.append(("filter", pipeline_tag))
if arch == "lora":
filter_pairs.append(("filter", "lora"))
for tag in comp_tags:
filter_pairs.append(("filter", tag))
# Capability filters
cap_list = [c.strip() for c in capabilities.split(",") if c.strip()]
......@@ -1879,6 +2302,8 @@ async def api_hf_search(
# Base search keywords
base_parts = [q.strip()] if q.strip() else []
if comp_kw:
base_parts.append(comp_kw)
if arch == "moe":
base_parts.append("moe")
......
......@@ -81,6 +81,7 @@ body,.table-wrap,.modal-box,.chat-messages,.studio .model-list,.studio .chat-msg
/* ── Main ────────────────────────────────────────────────────────── */
.main{min-height:calc(100vh - 44px)}
.container{max-width:1100px;margin:0 auto;padding:2rem 1.5rem}
.container--full{max-width:100%;padding:2rem 1.5rem}
/* ── Page header ─────────────────────────────────────────────────── */
.page-header{display:flex;justify-content:space-between;align-items:flex-start;margin-bottom:1.5rem;gap:1rem}
......@@ -125,6 +126,8 @@ body,.table-wrap,.modal-box,.chat-messages,.studio .model-list,.studio .chat-msg
.btn-ghost:hover{color:var(--text);border-color:var(--border-2)}
.btn-danger{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
.btn-danger:hover{background:rgba(248,113,113,.15);border-color:rgba(248,113,113,.4)}
.btn-warn{background:rgba(251,191,36,.08);color:#f59e0b;border:1px solid rgba(251,191,36,.25)}
.btn-warn:hover{background:rgba(251,191,36,.15);border-color:rgba(251,191,36,.45)}
.btn-sm{padding:.25rem .625rem;font-size:12px}
.btn-sm svg{width:11px;height:11px}
.btn:disabled{opacity:.4;cursor:not-allowed}
......@@ -176,6 +179,7 @@ td code{font-family:var(--mono);font-size:11.5px;background:var(--raised);paddin
.badge-user{background:var(--raised);color:var(--text-3);border:1px solid var(--border)}
.badge-ok{background:rgba(52,211,153,.08);color:var(--green);border:1px solid rgba(52,211,153,.2)}
.badge-warn{background:rgba(251,191,36,.08);color:#f59e0b;border:1px solid rgba(251,191,36,.2)}
.badge-err{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
.badge-danger{background:rgba(248,113,113,.08);color:var(--red);border:1px solid rgba(248,113,113,.2)}
/* ── Modals ──────────────────────────────────────────────────────── */
......@@ -276,4 +280,5 @@ hr{border:none;border-top:1px solid var(--border);margin:1.125rem 0}
.nav-links{gap:0}
.nav-link{padding:.3rem .5rem;font-size:12.5px}
.container{padding:1.25rem 1rem}
.container--full{padding:1.25rem 1rem}
}
{% extends "base.html" %}
{% block wrapper_class %}container container--full{% endblock %}
{% block title %}Models — CoderAI{% endblock %}
{% block head %}
......@@ -22,6 +23,7 @@
#info-drawer.open{transform:translateX(0)}
#info-sticky{position:sticky;top:0;background:var(--bg);border-bottom:1px solid var(--border);z-index:1;padding:1rem 1.25rem;display:flex;align-items:center;gap:.75rem}
#info-title{font-weight:600;font-size:14px;flex:1;overflow:hidden;text-overflow:ellipsis;white-space:nowrap}
._dl-quant-row:hover{background:var(--raised)}
</style>
{% endblock %}
......@@ -36,7 +38,7 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
</div>
<div class="header-actions">
<button class="btn btn-secondary" onclick="openModal('upload-modal')">Upload GGUF</button>
<button class="btn btn-primary" onclick="openModal('dl-modal')">Download model</button>
<button class="btn btn-primary" onclick="openDownloadFor('')">Download model</button>
</div>
</div>
......@@ -98,6 +100,16 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<div id="gguf-models-list"><span class="muted small">Loading…</span></div>
</div>
<!-- Component files -->
<div class="card" style="margin-top:1rem">
<div style="display:flex;align-items:center;justify-content:space-between;flex-wrap:wrap;gap:.5rem;margin-bottom:.5rem">
<div class="card-title" style="margin:0">Component files <span id="comp-badge" class="muted small"></span></div>
<button class="btn btn-secondary btn-sm" onclick="openDownloadFor('')">Download</button>
</div>
<p class="muted small" style="margin-top:0;margin-bottom:.75rem">VAEs, text encoders (T5-XXL, CLIP), and other standalone files used alongside main models.</p>
<div id="comp-list"><span class="muted small">Loading…</span></div>
</div>
<div class="card mb-0" style="margin-top:1rem" id="ws-model-builder">
<div class="card-title">Whisper-server simulated models</div>
<p class="muted small" style="margin-top:0">Create local audio models backed by dedicated whisper-server subprocess configurations.</p>
......@@ -238,7 +250,23 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
</div>
</div>
<!-- filter row 4: quant chips (file-level filter) -->
<!-- filter row 4: component type chips (single-select) -->
<div style="display:flex;align-items:flex-start;gap:.5rem;margin-bottom:.5rem">
<span class="fl" style="padding-top:.25rem;min-width:32px">Comp.</span>
<div class="chip-row" id="comp-type-chips">
<span class="chip" data-val="vae">VAE</span>
<span class="chip" data-val="t5xxl">T5-XXL</span>
<span class="chip" data-val="clip_l">CLIP-L</span>
<span class="chip" data-val="clip_g">CLIP-G</span>
<span class="chip" data-val="clip_vision">CLIP Vision</span>
<span class="chip" data-val="lora">LoRA</span>
<span class="chip" data-val="encoder">Encoder</span>
<span class="chip" data-val="controlnet">ControlNet</span>
<span class="chip" data-val="unet">UNet</span>
</div>
</div>
<!-- filter row 5: quant chips (file-level filter) -->
<div style="display:flex;align-items:flex-start;gap:.5rem;margin-bottom:1rem">
<span class="fl" style="padding-top:.25rem;min-width:32px">Quant</span>
<div class="chip-row" id="quant-chips">
......@@ -262,6 +290,23 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
</div>
<!-- Generic confirm modal -->
<div id="confirm-modal" class="modal" onclick="if(event.target===this)document.getElementById('confirm-modal-cancel').click()">
<div class="modal-box" style="max-width:420px">
<div class="modal-head">
<span class="modal-title" id="confirm-modal-title">Confirm</span>
<button class="modal-close" id="confirm-modal-x">&times;</button>
</div>
<div class="modal-body">
<p id="confirm-modal-msg" style="margin:0 0 1.25rem"></p>
<div style="display:flex;gap:.5rem;justify-content:flex-end">
<button class="btn btn-ghost" id="confirm-modal-cancel">Cancel</button>
<button class="btn btn-danger" id="confirm-modal-ok">Confirm</button>
</div>
</div>
</div>
</div>
<!-- Download modal -->
<div id="dl-modal" class="modal">
<div class="modal-box">
......@@ -273,20 +318,35 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<div id="dl-form">
<div class="form-row">
<label class="form-label">HuggingFace repo ID or URL</label>
<input type="text" id="dl-id" class="form-input" placeholder="e.g. bartowski/Llama-3.1-8B-Instruct-GGUF">
<div style="display:flex;gap:.5rem">
<input type="text" id="dl-id" class="form-input" style="flex:1" placeholder="e.g. bartowski/Llama-3.1-8B-Instruct-GGUF" onkeydown="if(event.key==='Enter'){event.preventDefault();browseHfFiles()}">
<button class="btn btn-ghost btn-sm" id="dl-browse-btn" onclick="browseHfFiles()" style="white-space:nowrap" title="Fetch available files from HuggingFace">Browse</button>
</div>
</div>
<!-- GGUF mode: specific file or pattern -->
<!-- Quant picker: shown when GGUF files are found via Browse -->
<div id="dl-quant-row" style="display:none;margin-top:.5rem">
<div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:.35rem">
<label class="form-label" style="margin:0">Select quantizations to download</label>
<span style="font-size:11px;display:flex;gap:.25rem">
<button class="btn btn-ghost btn-sm" style="font-size:10px;padding:.1rem .4rem" onclick="_dlSelectAll(true)">All</button>
<button class="btn btn-ghost btn-sm" style="font-size:10px;padding:.1rem .4rem" onclick="_dlSelectAll(false)">None</button>
</span>
</div>
<div id="dl-quant-list" style="max-height:220px;overflow-y:auto;border:1px solid var(--border);border-radius:6px;padding:.25rem .5rem;display:flex;flex-direction:column;gap:.1rem"></div>
<div id="dl-quant-note" style="font-size:11px;color:var(--text-2);margin-top:.4rem"></div>
</div>
<!-- GGUF mode: manual file / pattern (fallback when Browse not used) -->
<div id="dl-pattern-row" class="form-row">
<label class="form-label">File / pattern</label>
<label class="form-label">File / pattern <span class="muted">(or use Browse above)</span></label>
<input type="text" id="dl-pattern" class="form-input" placeholder=".gguf">
<span class="form-hint" id="dl-hint">Exact filename (e.g. <code>model-Q4_K_M.gguf</code>) or pattern (<code>.gguf</code>). Leave blank to download the first .gguf found.</span>
<span class="form-hint">Exact filename or suffix pattern. Leave blank to download all .gguf files.</span>
</div>
<!-- Snapshot mode: full repo via HF API -->
<!-- Snapshot mode: full repo -->
<div id="dl-snapshot-note" class="alert alert-info" style="display:none">
Will download the full repository using the HuggingFace snapshot API. This is the correct method for safetensors / non-GGUF models. Large repos may take a while.
</div>
<div class="form-actions">
<button class="btn btn-primary" onclick="startDownload()">Download</button>
<button class="btn btn-primary" id="dl-start-btn" onclick="startDownloadFromForm()">Download</button>
<button class="btn btn-ghost" onclick="closeModal('dl-modal')">Close</button>
</div>
</div>
......@@ -300,6 +360,9 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<span id="dl-pct">0%</span>
</div>
<div id="dl-log" style="display:none;background:var(--raised);border-radius:6px;padding:.4rem .6rem;font-size:11px;font-family:monospace;color:var(--text-2);max-height:72px;overflow-y:auto"></div>
<div style="display:flex;justify-content:flex-end;margin-top:.75rem">
<button class="btn btn-ghost btn-sm" onclick="minimizeDownload()" title="Close this window — the download continues in the background">Minimize ↓</button>
</div>
</div>
</div>
</div>
......@@ -409,11 +472,13 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<input type="hidden" id="cfg-path">
<input type="hidden" id="cfg-orig-path">
<input type="hidden" id="cfg-orig-type">
<input type="hidden" id="cfg-config-id">
<!-- identity -->
<div class="form-row">
<label class="form-label">Model ID / path</label>
<div id="cfg-id-label" style="font-size:12px;font-family:monospace;color:var(--text-2);word-break:break-all;padding:.3rem 0"></div>
<div id="cfg-path-moved-notice" style="display:none;font-size:11px;color:#f59e0b;margin-top:.2rem;padding:.25rem .5rem;background:rgba(251,191,36,.08);border:1px solid rgba(251,191,36,.25);border-radius:4px"></div>
</div>
<div class="form-row" id="cfg-quant-row" style="display:none">
<label class="form-label">Quantization</label>
......@@ -484,6 +549,11 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<label class="form-label">Alias <span class="muted">(optional)</span></label>
<input type="text" id="cfg-alias" class="form-input" placeholder="Friendly name">
</div>
<div class="form-row" id="cfg-config-name-row">
<label class="form-label">Configuration name <span class="muted">(optional)</span></label>
<input type="text" id="cfg-config-name" class="form-input" placeholder="Label for this config in the model list">
<div style="font-size:11px;color:var(--text-2);margin-top:.25rem">Shown on the config pill. Falls back to alias, then "Config N".</div>
</div>
<!-- backend -->
<div class="card-title" style="margin-top:1.25rem">Backend</div>
......@@ -563,9 +633,11 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<label class="form-label">Strategy</label>
<select id="cfg-offload-strategy" class="form-input">
<option value="auto">Auto</option>
<option value="cpu">CPU RAM</option>
<option value="model">CPU offload (model)</option>
<option value="sequential">CPU offload (sequential)</option>
<option value="cpu">CPU RAM (legacy)</option>
<option value="disk">Disk</option>
<option value="none">None</option>
<option value="none">None (GPU only)</option>
</select>
</div>
<div class="form-row" style="margin:0">
......@@ -574,6 +646,86 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
</div>
</div>
<!-- sd.cpp video options -->
<div class="card-title" style="margin-top:1.25rem">sd.cpp Video Options</div>
<div class="form-row" style="margin-bottom:.5rem">
<label class="form-label">Max VRAM budget (MB) <span class="muted">graph-cut execution limit, 0 = disabled</span></label>
<input type="number" id="cfg-max-vram" class="form-input" min="0" step="512" placeholder="0">
</div>
<div style="display:flex;gap:1.5rem;flex-wrap:wrap;margin-top:.5rem">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px"><input type="checkbox" id="cfg-sdcpp-flash"> Flash Attention</label>
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px"><input type="checkbox" id="cfg-sdcpp-diff-flash"> Diffusion Flash Attention</label>
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px"><input type="checkbox" id="cfg-vae-tiling"> VAE Tiling <span class="muted">(reduces VRAM at decode)</span></label>
</div>
<!-- components -->
<div class="card-title" style="margin-top:1.25rem">Components</div>
<div class="form-row">
<label class="form-label">Model template <span class="muted">(architecture hint for sd.cpp)</span></label>
<select id="cfg-template" class="form-input">
<option value="">Auto-detect</option>
<optgroup label="WAN">
<option value="wan_t2v_1.3b">WAN 1.3B T2V</option>
<option value="wan_t2v_14b">WAN 14B T2V</option>
<option value="wan_i2v_480p">WAN 1.3B I2V 480p</option>
<option value="wan_i2v_720p">WAN 14B I2V 720p</option>
<option value="wan_vace">WAN VACE</option>
</optgroup>
<optgroup label="FLUX">
<option value="flux_dev">FLUX.1-dev</option>
<option value="flux_schnell">FLUX.1-schnell</option>
</optgroup>
<optgroup label="Stable Diffusion">
<option value="sd15">Stable Diffusion 1.5</option>
<option value="sdxl">Stable Diffusion XL</option>
<option value="sd3">Stable Diffusion 3</option>
</optgroup>
<optgroup label="Other">
<option value="auraflow">AuraFlow</option>
<option value="chroma">Chroma</option>
<option value="hunyuan_video">HunyuanVideo</option>
<option value="mochi">Mochi</option>
<option value="ltxv">LTX-Video</option>
<option value="cogvideox">CogVideoX</option>
</optgroup>
</select>
</div>
<div class="form-row" style="margin-top:.5rem">
<label class="form-label">VAE <span class="muted">(optional)</span></label>
<select id="cfg-vae-sel" class="form-input" onchange="_onCompPick('vae')"></select>
<input type="text" id="cfg-vae-path" class="form-input" style="display:none;margin-top:.3rem" placeholder="Custom VAE path…">
</div>
<div class="form-row">
<label class="form-label">T5-XXL <span class="muted">(optional)</span></label>
<select id="cfg-t5xxl-sel" class="form-input" onchange="_onCompPick('t5xxl')"></select>
<input type="text" id="cfg-t5xxl-path" class="form-input" style="display:none;margin-top:.3rem" placeholder="Custom T5-XXL path…">
</div>
<div class="form-row">
<label class="form-label">CLIP-L <span class="muted">(optional)</span></label>
<select id="cfg-clip-l-sel" class="form-input" onchange="_onCompPick('clip-l')"></select>
<input type="text" id="cfg-clip-l-path" class="form-input" style="display:none;margin-top:.3rem" placeholder="Custom CLIP-L path…">
</div>
<div class="form-row">
<label class="form-label">CLIP-G <span class="muted">(optional)</span></label>
<select id="cfg-clip-g-sel" class="form-input" onchange="_onCompPick('clip-g')"></select>
<input type="text" id="cfg-clip-g-path" class="form-input" style="display:none;margin-top:.3rem" placeholder="Custom CLIP-G path…">
</div>
<div class="form-row">
<label class="form-label">CLIP Vision <span class="muted">(optional)</span></label>
<select id="cfg-clip-vision-sel" class="form-input" onchange="_onCompPick('clip-vision')"></select>
<input type="text" id="cfg-clip-vision-path" class="form-input" style="display:none;margin-top:.3rem" placeholder="Custom CLIP Vision path…">
</div>
<div class="form-row">
<label class="form-label">LoRA file <span class="muted">(optional — applied at load)</span></label>
<select id="cfg-lora-sel" class="form-input" onchange="_onCompPick('lora')"></select>
<input type="text" id="cfg-lora-path" class="form-input" style="display:none;margin-top:.3rem" placeholder="Custom LoRA path…">
<span class="form-hint" style="font-size:11px;color:var(--text-3)">Tip: use &lt;lora:name:strength&gt; in prompts to apply at inference time</span>
</div>
<div class="form-row">
<label class="form-label">LoRA directory <span class="muted">(optional — for prompt-based loras)</span></label>
<input type="text" id="cfg-lora-dir" class="form-input" placeholder="e.g. /models/loras">
</div>
<!-- generation -->
<div class="card-title" style="margin-top:1.25rem">Generation</div>
<div class="form-row">
......@@ -601,9 +753,11 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px"><input type="checkbox" id="cfg-grammar"> Grammar-guided generation</label>
</div>
<div class="form-actions" style="margin-top:1.5rem">
<div class="form-actions" style="margin-top:1.5rem;display:flex;gap:.5rem;flex-wrap:wrap;align-items:center">
<button class="btn btn-primary" onclick="saveModelConfig()">Save</button>
<button class="btn btn-secondary" id="cfg-save-new-btn" onclick="saveNewModelConfig()" style="display:none" title="Save as a separate configuration for the same model file">Save as new config</button>
<button class="btn btn-ghost" onclick="closeModal('cfg-modal')">Cancel</button>
<button class="btn btn-danger btn-sm" id="cfg-remove-config-btn" onclick="removeThisConfig()" style="display:none;margin-left:auto" title="Remove only this configuration (keeps other configs and the model file)">Remove this config</button>
</div>
</div>
</div>
......@@ -624,6 +778,7 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<script>
/* ── helpers ─────────────────────────────────────────── */
function esc(s){return String(s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;')}
function jq(v){return JSON.stringify(v).replace(/"/g,'&quot;')}
function fmtNum(n){if(!n)return'0';return n>=1e6?(n/1e6).toFixed(1)+'M':n>=1000?(n/1000).toFixed(1)+'k':String(n)}
function fmtGB(gb){if(!gb)return'—';return gb>=1?gb.toFixed(1)+' GB':(gb*1024).toFixed(0)+' MB'}
function fmtDate(s){try{return new Date(s).toLocaleDateString(undefined,{year:'numeric',month:'short',day:'numeric'})}catch{return s}}
......@@ -655,30 +810,6 @@ function _ggufBaseName(filename){
// Strip quant suffix and .gguf extension to get the model base name
return filename.replace(_QUANT_RE,'').replace(/\.gguf$/i,'');
}
function _switchGgufQuant(idx, newPath){
const m = _localModels[idx];
if(!m || !m.ggufGroup) return;
const f = m.ggufGroup.find(x=>x.path===newPath);
if(!f) return;
m.label = f.filename;
m.path = f.path;
m.size_gb = f.size_gb||0;
m.settings = f.settings||{};
m.in_config = f.in_config;
// Re-render quant badges in the row
const row = document.getElementById('gguf-row-'+idx);
if(!row) return;
const base = _ggufBaseName(f.filename);
const quantBadges = m.ggufGroup.length > 1
? m.ggufGroup.map(gf=>{
const q = _ggufQuant(gf.filename);
const active = gf.path === newPath;
return `<span class="badge ${active?'badge-ok':'badge-user'}" style="font-size:10px;padding:.1rem .3rem;cursor:pointer" title="${esc(gf.filename)}" onclick="_switchGgufQuant(${idx},${JSON.stringify(gf.path)})">${esc(q||gf.filename)}</span>`;
}).join(' ')
: '';
const nameCell = row.querySelector('td:first-child');
if(nameCell) nameCell.innerHTML = `${esc(base)}<br>${quantBadges}`;
}
/* ── tab / modal ─────────────────────────────────────── */
function switchTab(name,btn){
......@@ -690,6 +821,31 @@ function switchTab(name,btn){
function openModal(id){document.getElementById(id).classList.add('show')}
function closeModal(id){document.getElementById(id).classList.remove('show')}
// Custom confirm to avoid browser native-dialog blocking
function showConfirm(title, msg, okLabel){
return new Promise(resolve => {
document.getElementById('confirm-modal-title').textContent = title;
document.getElementById('confirm-modal-msg').textContent = msg;
const okBtn = document.getElementById('confirm-modal-ok');
const cancelBtn= document.getElementById('confirm-modal-cancel');
const xBtn = document.getElementById('confirm-modal-x');
okBtn.textContent = okLabel || 'Confirm';
openModal('confirm-modal');
function cleanup(result){
closeModal('confirm-modal');
okBtn.removeEventListener('click', onOk);
cancelBtn.removeEventListener('click', onCancel);
xBtn.removeEventListener('click', onCancel);
resolve(result);
}
function onOk(){ cleanup(true); }
function onCancel(){ cleanup(false); }
okBtn.addEventListener('click', onOk);
cancelBtn.addEventListener('click', onCancel);
xBtn.addEventListener('click', onCancel);
});
}
/* ── Global settings ─────────────────────────────────── */
let _defaultOffloadDir = './offload';
let _highlightCap = null; // capability to highlight in local models list (from ?local_cap= param)
......@@ -718,6 +874,24 @@ document.querySelectorAll('.tog-btn').forEach(btn=>{
document.querySelectorAll('.chip').forEach(c=>{
c.addEventListener('click',()=>c.classList.toggle('on'));
});
// Component type chips are single-select: clicking one deselects the rest.
// Also switch the format toggle to "All" when a component is selected,
// since most components are safetensors and not GGUF.
document.querySelectorAll('#comp-type-chips .chip').forEach(c=>{
c.addEventListener('click',()=>{
const wasOn = c.classList.contains('on');
document.querySelectorAll('#comp-type-chips .chip').forEach(x=>x.classList.remove('on'));
if(!wasOn){
c.classList.add('on');
if(_ggufMode === 'gguf'){
document.querySelectorAll('.tog-btn').forEach(b=>{
b.classList.toggle('on', b.dataset.val === 'all');
});
_ggufMode = 'all';
}
}
});
});
function getChips(id){return[...document.querySelectorAll('#'+id+' .chip.on')].map(c=>c.dataset.val)}
/* ── search ──────────────────────────────────────────── */
......@@ -764,6 +938,7 @@ async function doSearch(){
const arch = document.getElementById('filter-arch').value;
const sort = document.getElementById('filter-sort').value;
const sizes = getChips('size-chips').join(',');
const compType = getChips('comp-type-chips')[0] || '';
_activeQuants = new Set(getChips('quant-chips').map(v=>v.toUpperCase().split(' ')[0])); // strip ★
// Get selected capability filters (from our custom chips)
......@@ -778,6 +953,7 @@ async function doSearch(){
if(pipeline) params.append('pipeline_tag', pipeline);
if(sizes) params.append('sizes', sizes);
if(arch) params.append('arch', arch);
if(compType) params.append('component_type', compType);
const caps = getChips('cap-chips');
if(caps.length) params.append('capabilities', caps.join(','));
......@@ -812,12 +988,14 @@ async function doSearch(){
const capBadges = fmtCapabilities(m.capabilities||[]);
const isDownloaded = _cachedSearchIds.has(m.id);
const downloadedBadge = isDownloaded ? `<span class="badge badge-ok" style="font-size:10px;padding:.1rem .35rem;margin-left:.4rem;vertical-align:middle">✓ local</span>` : '';
const detectedRole = _detectCompRole(m.id.split('/').pop() || m.id);
const compBadge = detectedRole ? `<span class="badge badge-user" style="font-size:10px;padding:.1rem .35rem;margin-left:.4rem;vertical-align:middle">${esc(detectedRole.label)}</span>` : '';
return `
<div style="padding:.75rem 0;border-bottom:1px solid var(--border)">
<div style="display:flex;align-items:flex-start;justify-content:space-between;gap:.5rem">
<div style="min-width:0;flex:1">
<div style="font-weight:500;font-size:13px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;display:flex;align-items:center"
title="${esc(m.id)}">${vramDot}${esc(m.id)}${downloadedBadge}</div>
title="${esc(m.id)}">${vramDot}${esc(m.id)}${downloadedBadge}${compBadge}</div>
<div style="font-size:11px;color:var(--text-3);margin-top:.25rem;display:flex;align-items:center;gap:.5rem;flex-wrap:wrap">
${m.pipeline_tag?`<span class="badge badge-user">${esc(m.pipeline_tag)}</span>`:''}
${capBadges}
......@@ -993,18 +1171,38 @@ function looksLikeGguf(modelId, filePattern){
}
function openDownloadFor(modelId, filePattern){
document.getElementById('dl-id').value = modelId;
// Always reset to form view so the user can start a new download even while
// one is already running. The existing download continues on the server; the
// strip tracks it via polling independently of the EventSource.
if(_dlEs){ _dlEs.close(); _dlEs=null; }
_dlSessionId=null;
_dlQuantFiles=[];
document.getElementById('dl-form').style.display='block';
document.getElementById('dl-progress').style.display='none';
document.getElementById('dl-quant-row').style.display='none';
document.getElementById('dl-start-btn').textContent='Download';
document.getElementById('dl-id').value = modelId || '';
document.getElementById('dl-pattern').value = filePattern || '';
const isGguf = looksLikeGguf(modelId, filePattern);
if(isGguf){
if(!modelId){
// Blank form — show pattern row as default
document.getElementById('dl-pattern-row').style.display = 'block';
document.getElementById('dl-snapshot-note').style.display = 'none';
document.getElementById('dl-pattern').value = filePattern || '.gguf';
} else {
} else if(!isGguf && modelId.includes('/')){
// Looks like a full HF non-GGUF repo
document.getElementById('dl-pattern-row').style.display = 'none';
document.getElementById('dl-snapshot-note').style.display = 'flex';
document.getElementById('dl-pattern').value = '';
} else {
document.getElementById('dl-pattern-row').style.display = 'block';
document.getElementById('dl-snapshot-note').style.display = 'none';
}
openModal('dl-modal');
// Auto-browse if we have a specific HF repo ID so the quant picker appears immediately
if(modelId && modelId.includes('/')){
browseHfFiles();
}
}
/* progress helpers */
......@@ -1014,6 +1212,8 @@ function fmtEta(s){if(s===null||s===undefined)return'';s=Math.round(s);if(s<60)r
let _dlEs = null;
let _dlDone = false;
let _dlSessionId = null;
let _dlQuantFiles = [];
function _dlReset(){
const bar = document.getElementById('dl-bar');
......@@ -1071,29 +1271,66 @@ function handleProgressEvent(evt){
closeModal('dl-modal');
document.getElementById('dl-form').style.display='block';
document.getElementById('dl-progress').style.display='none';
document.getElementById('dl-quant-row').style.display='none';
document.getElementById('dl-pattern-row').style.display='block';
document.getElementById('dl-start-btn').textContent='Download';
_dlQuantFiles=[];
_dlReset();
},1800);
}else if(evt.type==='error'){
_dlDone=true;
showDownloadError(evt.message);
}else if(evt.type==='cancelled'){
_dlDone=true;
if(_dlEs){_dlEs.close();_dlEs=null;}
showDownloadError('Download cancelled');
}
// keepalive: ignore
}
async function startDownload(){
const id=document.getElementById('dl-id').value.trim();
if(!id){document.getElementById('dl-id').focus();return}
_dlDone=false;
_dlReset();
function minimizeDownload(){
closeModal('dl-modal');
}
async function reopenDownload(session_id){
openModal('dl-modal');
document.getElementById('dl-form').style.display='none';
document.getElementById('dl-progress').style.display='block';
// Already tracking this session with a live EventSource — just show the modal.
if(_dlSessionId===session_id && _dlEs) return;
// Different session or page reload: seed state from status API, then reconnect.
if(_dlEs){_dlEs.close();_dlEs=null;}
_dlSessionId=session_id;
_dlDone=false;
_dlReset();
try{
const r=await fetch(ROOT_PATH + '/admin/api/model-download',{
method:'POST',headers:{'Content-Type':'application/json'},
body:JSON.stringify({model_id:id,file_pattern:document.getElementById('dl-pattern').value||null})
});
if(!r.ok){const e=await r.json();showDownloadError(e.detail||'Request failed');return}
const {session_id}=await r.json();
const r=await fetch(ROOT_PATH + '/admin/api/downloads');
if(r.ok){
const all=await r.json();
const s=all.find(d=>d.session_id===session_id);
if(s){
document.getElementById('dl-filename').textContent=s.filename||s.model_id||'Downloading…';
const pct=s.percent||0;
const bar=document.getElementById('dl-bar');
bar.style.transition='none';
bar.style.width=pct+'%';
requestAnimationFrame(()=>{bar.style.transition='';});
document.getElementById('dl-pct').textContent=pct.toFixed(1)+'%';
if(s.downloaded!=null&&s.total!=null)
document.getElementById('dl-bytes').textContent=fmtBytes(s.downloaded)+' / '+fmtBytes(s.total);
if(s.rate) document.getElementById('dl-speed').textContent=fmtRate(s.rate);
if(s.eta!=null) document.getElementById('dl-eta').textContent=fmtEta(s.eta);
if(s.status==='done'){handleProgressEvent({type:'done'});return;}
if(s.status==='cancelled'){showDownloadError('Download cancelled');return;}
if(s.status==='error'){showDownloadError(s.error||'Download failed');return;}
}
}
}catch{}
// Connect a fresh EventSource to stream remaining progress events.
_dlEs=new EventSource(ROOT_PATH + '/admin/api/download-stream/'+session_id);
_dlEs.onmessage=function(e){
try{handleProgressEvent(JSON.parse(e.data))}catch{}
......@@ -1103,6 +1340,208 @@ async function startDownload(){
if(_dlEs&&_dlEs.readyState===EventSource.CLOSED) return;
showDownloadError('Connection to download stream lost');
};
}
async function stopDownload(session_id){
if(!confirm('Cancel this download?')) return;
try{
await fetch(ROOT_PATH + '/admin/api/download-cancel/'+session_id, {method:'POST'});
if(_dlSessionId===session_id){
if(_dlEs){_dlEs.close();_dlEs=null;}
_dlDone=true;
showDownloadError('Download cancelled');
}
}catch(e){
alert('Could not cancel download: '+e.message);
}
}
/* ── quant picker ──────────────────────────────────────── */
function _quantLabel(fname){
// Reuse the quant extractor; fall back to the bare filename
const q = _ggufQuant(fname);
return q || fname.replace(/\.gguf$/i,'');
}
const _RECOMMENDED_QUANTS = new Set(['Q4_K_M','Q5_K_M','Q6_K','Q4_K_S','Q5_K_S','IQ4_NL','IQ4_XS']);
function _dlSelectAll(val){
document.querySelectorAll('#dl-quant-list input[type=checkbox]').forEach(cb=>cb.checked=val);
}
function renderQuantPicker(ggufFiles, initialPattern){
_dlQuantFiles = ggufFiles;
// Sort: recommended first, then by size descending
const sorted = [...ggufFiles].sort((a,b)=>{
const ar = _RECOMMENDED_QUANTS.has(_quantLabel(a.name).toUpperCase());
const br = _RECOMMENDED_QUANTS.has(_quantLabel(b.name).toUpperCase());
if(ar!==br) return br-ar;
return (b.size||0)-(a.size||0);
});
const pat = (initialPattern||'').toLowerCase();
const list = document.getElementById('dl-quant-list');
list.innerHTML = sorted.map((f,i)=>{
const label = _quantLabel(f.name);
const upper = label.toUpperCase();
const isRec = _RECOMMENDED_QUANTS.has(upper);
const sz = f.size ? fmtBytes(f.size) : '?';
// Pre-check: recommended ones by default, or whatever matches the initial pattern
const checked = pat ? f.name.toLowerCase().includes(pat.replace('*','')) : isRec;
return `<label style="display:flex;align-items:center;gap:.6rem;padding:.3rem .25rem;border-radius:4px;cursor:pointer;font-size:13px" class="_dl-quant-row">
<input type="checkbox" data-fname="${esc(f.name)}" ${checked?'checked':''}>
<span style="flex:1;font-family:monospace;font-size:12px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="${esc(f.name)}">${esc(label)}</span>
${isRec?'<span class="badge badge-user" style="font-size:9px;padding:.1rem .3rem">rec</span>':''}
<span style="color:var(--text-2);font-size:11px;white-space:nowrap;flex-shrink:0">${esc(sz)}</span>
</label>`;
}).join('');
// Show picker, hide manual pattern row
document.getElementById('dl-quant-row').style.display = 'block';
document.getElementById('dl-pattern-row').style.display = 'none';
document.getElementById('dl-snapshot-note').style.display = 'none';
const note = document.getElementById('dl-quant-note');
const rec = sorted.filter(f=>_RECOMMENDED_QUANTS.has(_quantLabel(f.name).toUpperCase())).length;
note.textContent = rec
? `${rec} recommended quantization${rec>1?'s':''} pre-selected (marked "rec"). Larger Q = better quality, more VRAM.`
: `${ggufFiles.length} quantization${ggufFiles.length!==1?'s':''} available. Select which to download.`;
document.getElementById('dl-start-btn').textContent = 'Download selected';
}
async function browseHfFiles(){
const id = document.getElementById('dl-id').value.trim();
if(!id){ document.getElementById('dl-id').focus(); return; }
if(!id.includes('/')){
// Local path or bare name — can't browse, use pattern fallback
document.getElementById('dl-quant-row').style.display='none';
const isGguf = looksLikeGguf(id,'');
document.getElementById('dl-pattern-row').style.display = isGguf?'block':'none';
document.getElementById('dl-snapshot-note').style.display = isGguf?'none':'flex';
return;
}
const btn = document.getElementById('dl-browse-btn');
btn.textContent='…'; btn.disabled=true;
try{
const r = await fetch(ROOT_PATH+'/admin/api/hf-files?repo_id='+encodeURIComponent(id));
if(!r.ok){ throw new Error((await r.json()).detail||'Request failed'); }
const {files, error} = await r.json();
if(error){ throw new Error(error); }
const gguf = (files||[]).filter(f=>f.name.toLowerCase().endsWith('.gguf'));
if(gguf.length){
renderQuantPicker(gguf, document.getElementById('dl-pattern').value);
} else {
// Non-GGUF repo — hide picker, show snapshot note
_dlQuantFiles=[];
document.getElementById('dl-quant-row').style.display='none';
document.getElementById('dl-pattern-row').style.display='none';
document.getElementById('dl-snapshot-note').style.display='flex';
document.getElementById('dl-start-btn').textContent='Download';
const note = document.getElementById('dl-quant-note');
note.textContent='';
}
}catch(e){
// Fall back to manual pattern input
_dlQuantFiles=[];
document.getElementById('dl-quant-row').style.display='none';
document.getElementById('dl-pattern-row').style.display='block';
document.getElementById('dl-snapshot-note').style.display='none';
document.getElementById('dl-start-btn').textContent='Download';
alert('Could not browse repo: '+e.message);
}finally{
btn.textContent='Browse'; btn.disabled=false;
}
}
async function _startSingleDownload(modelId, filePattern){
/* Starts one download session, returns the session_id or null on error. */
const r = await fetch(ROOT_PATH+'/admin/api/model-download',{
method:'POST', headers:{'Content-Type':'application/json'},
body: JSON.stringify({model_id:modelId, file_pattern:filePattern||null})
});
if(!r.ok){ const e=await r.json(); throw new Error(e.detail||'Request failed'); }
return (await r.json()).session_id;
}
function _attachEventSource(session_id){
if(_dlEs){_dlEs.close();_dlEs=null;}
_dlSessionId=session_id;
_dlEs=new EventSource(ROOT_PATH+'/admin/api/download-stream/'+session_id);
_dlEs.onmessage=function(e){ try{handleProgressEvent(JSON.parse(e.data))}catch{} };
_dlEs.onerror=function(){
if(_dlDone) return;
if(_dlEs&&_dlEs.readyState===EventSource.CLOSED) return;
showDownloadError('Connection to download stream lost');
};
}
async function startDownloadFromForm(){
const id = document.getElementById('dl-id').value.trim();
if(!id){ document.getElementById('dl-id').focus(); return; }
// For HF repo IDs: always browse first so the user can pick specific files.
// Skip if the quant picker is already visible (Browse was already done).
const quantRowVisible = document.getElementById('dl-quant-row').style.display!=='none';
const looksLikeHfRepo = id.includes('/') && !id.startsWith('http');
if(!quantRowVisible && looksLikeHfRepo){
await browseHfFiles();
// If Browse produced a quant picker, stop here — let the user select files.
if(document.getElementById('dl-quant-row').style.display!=='none'){
return;
}
// No GGUF files found (non-GGUF repo or browse error) — fall through to snapshot download.
}
const pickerVisible = document.getElementById('dl-quant-row').style.display!=='none';
if(pickerVisible){
// Multi-file download from quant picker
const checked = [...document.querySelectorAll('#dl-quant-list input[type=checkbox]:checked')];
if(!checked.length){ alert('Select at least one quantization to download.'); return; }
_dlDone=false; _dlSessionId=null; _dlReset();
if(checked.length===1){
// Single selection → show progress in modal
document.getElementById('dl-form').style.display='none';
document.getElementById('dl-progress').style.display='block';
try{
const sid = await _startSingleDownload(id, checked[0].dataset.fname);
startPolling();
_attachEventSource(sid);
}catch(e){ showDownloadError(e.message); }
} else {
// Multiple selections → start all, show in strip
document.getElementById('dl-form').style.display='none';
document.getElementById('dl-progress').style.display='block';
document.getElementById('dl-filename').textContent=`Starting ${checked.length} downloads…`;
try{
const sids = [];
for(const cb of checked){
const sid = await _startSingleDownload(id, cb.dataset.fname);
sids.push(sid);
}
startPolling();
// Attach EventSource to the first one, rest tracked in strip
_attachEventSource(sids[0]);
document.getElementById('dl-filename').textContent=`Downloading ${checked.length} files — see strip below`;
}catch(e){ showDownloadError(e.message); }
}
} else {
// Manual / snapshot mode
await startDownload();
}
}
async function startDownload(){
const id=document.getElementById('dl-id').value.trim();
if(!id){document.getElementById('dl-id').focus();return}
_dlDone=false;
_dlSessionId=null;
_dlReset();
document.getElementById('dl-form').style.display='none';
document.getElementById('dl-progress').style.display='block';
try{
const sid = await _startSingleDownload(id, document.getElementById('dl-pattern').value||null);
_dlSessionId=sid;
startPolling();
_attachEventSource(sid);
}catch(e){showDownloadError(e.message)}
}
......@@ -1114,16 +1553,22 @@ async function pollDownloads(){
const r = await fetch(ROOT_PATH + '/admin/api/downloads');
if(!r.ok) return;
const all = await r.json();
const active = all.filter(d=>d.status!=='done'&&d.status!=='error');
const active = all.filter(d=>d.status!=='done'&&d.status!=='error'&&d.status!=='cancelled');
const strip = document.getElementById('dl-strip');
const list = document.getElementById('dl-strip-list');
if(!active.length){ strip.style.display='none'; return; }
if(!active.length){
strip.style.display='none';
stopPolling();
return;
}
startPolling(); // keep the timer alive (handles page-load case where timer was never set)
strip.style.display='block';
list.innerHTML = active.map(d=>{
const pct = d.percent||0;
const name = d.filename||d.model_id||'';
const spd = d.rate?fmtRate(d.rate):'';
const eta = d.eta!=null?fmtEta(d.eta):'';
const sid = d.session_id||'';
return `<div style="display:flex;align-items:center;gap:.75rem;padding:.2rem 0">
<div style="flex:1;min-width:0">
<div style="font-size:12px;font-weight:500;overflow:hidden;text-overflow:ellipsis;white-space:nowrap">${esc(d.model_id)}</div>
......@@ -1134,6 +1579,10 @@ async function pollDownloads(){
<div>${pct.toFixed(1)}%</div>
${spd?`<div>${esc(spd)}</div>`:''}
${eta?`<div class="muted">${esc(eta)}</div>`:''}
${sid?`<div style="margin-top:.25rem;display:flex;gap:.3rem">
<button class="btn btn-ghost" style="font-size:10px;padding:.1rem .35rem;line-height:1.4" onclick="reopenDownload('${sid}')" title="Open download progress">↗ View</button>
<button class="btn btn-ghost" style="font-size:10px;padding:.1rem .35rem;line-height:1.4;color:var(--error,#e55)" onclick="stopDownload('${sid}')" title="Cancel this download">✕ Stop</button>
</div>`:''}
</div>
</div>`;
}).join('<div style="border-top:1px solid var(--border);margin:.3rem 0"></div>');
......@@ -1146,7 +1595,12 @@ function startPolling(){
pollDownloads();
}
startPolling();
function stopPolling(){
if(_pollTimer){ clearInterval(_pollTimer); _pollTimer=null; }
}
// Single one-shot check on page load; polling only runs while downloads are active.
pollDownloads();
/* ── cache stats & local models ──────────────────────── */
async function loadCacheStats(){
......@@ -1166,6 +1620,183 @@ async function loadCacheStats(){
let _localModels = [];
let _ggufFiles = [];
let _hfModels = [];
function _renderConfigPills(idx, m) {
const configs = m.configs || [];
if (!configs.length) return '';
const pills = configs.map((c, cfgIdx) => {
const label = (c.settings && (c.settings.config_name || c.settings.alias)) || `Config ${cfgIdx + 1}`;
return `<span class="badge badge-user" style="font-size:10px;cursor:pointer;vertical-align:middle;margin:.1rem .1rem 0 0" onclick="openCfgModal(${idx},${cfgIdx})" title="Edit this configuration">${esc(label)}</span>`;
}).join('');
const addPill = `<span class="badge" style="font-size:10px;cursor:pointer;vertical-align:middle;margin:.1rem 0 0 0;background:var(--raised);border:1px dashed var(--border);color:var(--text-2)" onclick="openCfgModalNew(${idx})" title="Add another configuration for this model">+ Config</span>`;
return `<br style="line-height:.5rem">${pills}${addPill}`;
}
// ── Component file helpers ────────────────────────────────────────────────────
const _COMP_ROLES = [
{key:'vae', label:'VAE', pat:/vae|(^|[^a-z])ae\.(gguf|safetensors)/i},
{key:'t5xxl', label:'T5-XXL', pat:/t5.?xxl|t5.?v1_1/i},
{key:'clip_l', label:'CLIP-L', pat:/clip[_\-]l(?![_\-g])/i},
{key:'clip_g', label:'CLIP-G', pat:/clip[_\-]g(?![_\-]v)/i},
{key:'clip_vision', label:'CLIP Vision', pat:/clip[_\-]vi(?:s|t)/i},
{key:'lora', label:'LoRA', pat:/\blora\b/i},
{key:'encoder', label:'Encoder', pat:/encoder|text.?enc/i},
];
function _detectCompRole(filename) {
for(const r of _COMP_ROLES) if(r.pat.test(filename)) return r;
return null;
}
function _isCompFile(filename) {
return _COMP_ROLES.some(r => r.pat.test(filename));
}
function renderComponentsList(ggufFiles) {
// Files from the local cache (GGUF cache dir — any extension: .gguf, .safetensors, etc.)
const fileComps = ggufFiles.filter(f => _isCompFile(f.filename));
// HF repos whose repo-name looks like a component
const hfComps = _hfModels.filter(m => _isCompFile(m.id.split('/').pop() || m.id));
const total = fileComps.length + hfComps.length;
const badge = document.getElementById('comp-badge');
if(badge) badge.textContent = total ? `(${total})` : '';
const el = document.getElementById('comp-list');
if(!el) return;
if(!total){
el.innerHTML = '<span class="muted small">No component files detected. Use Download to fetch VAEs, text encoders, CLIP models or LoRAs.</span>';
return;
}
const fileRows = fileComps.map(f => {
const role = _detectCompRole(f.filename);
const ext = f.filename.includes('.') ? f.filename.split('.').pop().toUpperCase() : '';
return `<tr style="border-top:1px solid var(--border)">
<td style="padding:.4rem .25rem;font-family:monospace;font-size:11px;max-width:240px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="${esc(f.filename)}">${esc(f.filename)}</td>
<td style="padding:.4rem .25rem;white-space:nowrap">
<span class="badge badge-user" style="font-size:10px">${role?esc(role.label):'component'}</span>
${ext && ext !== 'GGUF' ? `<span class="badge" style="font-size:10px;background:var(--raised)">${esc(ext)}</span>` : ''}
</td>
<td style="text-align:right;padding:.4rem .25rem;white-space:nowrap;color:var(--text-2)">${fmtGB(f.size_gb)}</td>
<td style="padding:.4rem .25rem;font-family:monospace;font-size:10px;color:var(--text-3);overflow:hidden;text-overflow:ellipsis;white-space:nowrap;max-width:200px" title="${esc(f.path)}">${esc(f.path)}</td>
<td style="padding:.4rem .25rem;text-align:right;white-space:nowrap">
<button class="btn btn-danger btn-sm" onclick="_deleteCompFile(${jq(f.path)},'gguf')">Delete</button>
</td>
</tr>`;
});
const hfRows = hfComps.map(m => {
const role = _detectCompRole(m.id.split('/').pop() || m.id);
return `<tr style="border-top:1px solid var(--border)">
<td style="padding:.4rem .25rem;font-family:monospace;font-size:11px;max-width:240px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="${esc(m.id)}">${esc(m.id)}</td>
<td style="padding:.4rem .25rem;white-space:nowrap">
<span class="badge badge-user" style="font-size:10px">${role?esc(role.label):'component'}</span>
<span class="badge" style="font-size:10px;background:var(--raised)">HF</span>
</td>
<td style="text-align:right;padding:.4rem .25rem;white-space:nowrap;color:var(--text-2)">${fmtGB(m.size_gb)}</td>
<td style="padding:.4rem .25rem;font-size:10px;color:var(--text-3)">${m.file_count} file${m.file_count!==1?'s':''}</td>
<td style="padding:.4rem .25rem;text-align:right;white-space:nowrap">
<button class="btn btn-danger btn-sm" onclick="_deleteCompFile(${jq(m.id)},'hf')">Delete</button>
</td>
</tr>`;
});
el.innerHTML = '<table style="width:100%;border-collapse:collapse;font-size:13px">'+
'<thead><tr style="color:var(--text-2);font-size:10px;text-transform:uppercase;letter-spacing:.05em">'+
'<th style="text-align:left;padding:.3rem .25rem;font-weight:700">File / Repo</th>'+
'<th style="text-align:left;padding:.3rem .25rem;font-weight:700">Type</th>'+
'<th style="text-align:right;padding:.3rem .25rem;font-weight:700">Size</th>'+
'<th style="text-align:left;padding:.3rem .25rem;font-weight:700">Path / Info</th>'+
'<th></th></tr></thead><tbody>'+fileRows.join('')+hfRows.join('')+'</tbody></table>';
}
async function _deleteCompFile(id, cacheType) {
if(!await showConfirm('Delete component file', `Delete "${id}" from local cache? This cannot be undone.`, 'Delete')) return;
try{
const r = await fetch(ROOT_PATH + '/admin/api/cached-models/' + encodeURIComponent(id) + '?cache_type=' + (cacheType||'gguf'), {method:'DELETE'});
const d = await r.json();
if(d.success) refreshLocal();
else await showConfirm('Delete failed', d.detail || 'Unknown error', 'OK');
}catch(e){ await showConfirm('Error', e.message, 'OK'); }
}
// ── Component select combos ───────────────────────────────────────────────────
// role slug → input IDs: sel = 'cfg-{role}-sel', text = 'cfg-{role}-path'
// role slugs used: vae t5xxl clip-l clip-g clip-vision
const _COMP_FIELDS = [
{slug:'vae', pathId:'cfg-vae-path'},
{slug:'t5xxl', pathId:'cfg-t5xxl-path'},
{slug:'clip-l', pathId:'cfg-clip-l-path'},
{slug:'clip-g', pathId:'cfg-clip-g-path'},
{slug:'clip-vision',pathId:'cfg-clip-vision-path'},
{slug:'lora', pathId:'cfg-lora-path'},
];
function _populateCompSelects(){
const loraRol = _COMP_ROLES.find(r=>r.key==='lora');
const loras = _ggufFiles.filter(f=>loraRol?.pat.test(f.filename));
const other = _ggufFiles.filter(f=>_isCompFile(f.filename) && !loraRol?.pat.test(f.filename));
const rest = _ggufFiles.filter(f=>!_isCompFile(f.filename));
const hfComps = _hfModels.filter(m=>_isCompFile(m.id.split('/').pop()||m.id));
const hfRest = _hfModels.filter(m=>!_isCompFile(m.id.split('/').pop()||m.id));
const mkOpt = (val, text) => `<option value="${esc(val)}">${esc(text)}</option>`;
const otherOpts = other.length
? `<optgroup label="Component files">${other.map(f=>mkOpt(f.path, f.filename+' ('+fmtGB(f.size_gb)+')')).join('')}</optgroup>`
: '';
const loraOpts = loras.length
? `<optgroup label="LoRA files">${loras.map(f=>mkOpt(f.path, f.filename+' ('+fmtGB(f.size_gb)+')')).join('')}</optgroup>`
: '';
const restOpts = rest.length
? `<optgroup label="Local files">${rest.map(f=>mkOpt(f.path, f.filename+' ('+fmtGB(f.size_gb)+')')).join('')}</optgroup>`
: '';
const hfCompOpts = hfComps.length
? `<optgroup label="HuggingFace component repos">${hfComps.map(m=>mkOpt(m.id, m.id+' ('+fmtGB(m.size_gb)+')')).join('')}</optgroup>`
: '';
const hfRestOpts = hfRest.length
? `<optgroup label="HuggingFace repos">${hfRest.map(m=>mkOpt(m.id, m.id+' ('+fmtGB(m.size_gb)+')')).join('')}</optgroup>`
: '';
_COMP_FIELDS.forEach(({slug}) => {
const sel = document.getElementById('cfg-'+slug+'-sel');
if(!sel) return;
const html = slug === 'lora'
? `<option value="">None</option>${loraOpts}${otherOpts}${hfCompOpts}${restOpts}${hfRestOpts}<option value="__custom__">Custom path…</option>`
: `<option value="">None</option>${otherOpts}${hfCompOpts}${loraOpts}${restOpts}${hfRestOpts}<option value="__custom__">Custom path…</option>`;
sel.innerHTML = html;
});
}
function _setCompField(slug, savedPath){
const sel = document.getElementById('cfg-'+slug+'-sel');
const inp = document.getElementById('cfg-'+slug+'-path');
if(!sel) return;
if(!savedPath){
sel.value = '';
if(inp){ inp.value=''; inp.style.display='none'; }
return;
}
const match = Array.from(sel.options).find(o=>o.value===savedPath);
if(match){
sel.value = savedPath;
if(inp){ inp.value=''; inp.style.display='none'; }
} else {
sel.value = '__custom__';
if(inp){ inp.value=savedPath; inp.style.display=''; }
}
}
function _onCompPick(slug){
const sel = document.getElementById('cfg-'+slug+'-sel');
const pathId= 'cfg-'+slug+'-path';
const inp = document.getElementById(pathId);
if(!sel || !inp) return;
inp.style.display = sel.value === '__custom__' ? '' : 'none';
if(sel.value !== '__custom__') inp.value = '';
}
function _readCompField(slug){
const sel = document.getElementById('cfg-'+slug+'-sel');
const inp = document.getElementById('cfg-'+slug+'-path');
if(!sel) return '';
if(sel.value === '__custom__') return inp ? inp.value.trim() : '';
return sel.value || '';
}
function _renderWhisperServerRows(models){
if(!models.length) return '';
......@@ -1280,32 +1911,41 @@ async function loadCachedModels(){
// HF models
const hf = d.hf||[];
document.getElementById('hf-model-badge').textContent = hf.length ? `(${hf.length})` : '';
if(!hf.length){
_hfModels = hf;
const hfModelList = hf.filter(m => !_isCompFile(m.id.split('/').pop() || m.id));
document.getElementById('hf-model-badge').textContent = hfModelList.length ? `(${hfModelList.length})` : '';
if(!hfModelList.length){
hfEl.innerHTML = '<span class="muted small">No HuggingFace models cached.</span>';
}else{
const rows = hf.map(m=>{
const rows = hfModelList.map(m=>{
const idx = _localModels.length;
_localModels.push({label:m.id, path:m.id, cacheType:'hf', size_gb:m.size_gb||0,
defaultType:m.model_type||'text_models', settings:m.settings||{}, in_config:m.in_config,
capabilities:m.capabilities||[]});
capabilities:m.capabilities||[], incomplete:!!m.incomplete, configs:m.configs||[]});
const loaded = _loadedKeys.has(m.id) || [..._loadedKeys].some(k=>k.endsWith(':'+m.id)||k===m.id);
const capBadges = fmtCapabilities(m.capabilities||[]);
const instBadgeHf = m.in_config ? _instanceBadge([m.id], (m.settings||{}).max_instances||1) : '';
const _hfLm = _localModels[idx];
const _hfConfigPills = _renderConfigPills(idx, _hfLm);
const _hfConfigs = _hfLm.configs || [];
const _hfCfgMax = _hfConfigs.length === 1 ? ((_hfConfigs[0].settings||{}).max_instances||1) : 1;
const instBadgeHf = m.in_config ? _instanceBadge([m.id], _hfCfgMax) : '';
const hlHf = _highlightCap && (m.capabilities||[]).includes(_highlightCap);
return `<tr${hlHf?' class="local-cap-highlight"':''} style="border-top:1px solid var(--border)${hlHf?';background:rgba(110,207,126,.07);outline:2px solid rgba(110,207,126,.25);outline-offset:-1px':''}">
<td style="padding:.4rem .25rem;font-family:monospace;font-size:12px;max-width:260px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="${esc(m.id)}">${esc(m.id)}</td>
const incompleteBadgeHf = m.incomplete ? '<span class="badge" style="background:rgba(255,160,0,.18);color:#b87200;font-size:10px;margin-left:.3rem" title="Download may be incomplete — some files are missing or truncated">⚠ incomplete</span>' : '';
const _hfConfigCount = _hfConfigs.length;
return `<tr${hlHf?' class="local-cap-highlight"':''} style="border-top:1px solid var(--border)${hlHf?';background:rgba(110,207,126,.07);outline:2px solid rgba(110,207,126,.25);outline-offset:-1px':''}${m.incomplete?';background:rgba(255,160,0,.04)':''}">
<td style="padding:.4rem .25rem;font-family:monospace;font-size:12px;max-width:260px;overflow:hidden;text-overflow:ellipsis" title="${esc(m.id)}">${esc(m.id)}${incompleteBadgeHf}${_hfConfigPills}</td>
<td style="text-align:right;padding:.4rem .25rem;white-space:nowrap;color:var(--text-2)">${fmtGB(m.size_gb)}</td>
<td style="text-align:right;padding:.4rem .25rem;color:var(--text-2)">${m.file_count}</td>
<td style="padding:.4rem .25rem;font-size:11px">${capBadges||'<span class="muted small">—</span>'}</td>
<td style="text-align:center;padding:.4rem .25rem">${m.in_config?`<span class="badge badge-ok">enabled</span>${instBadgeHf?'<br>'+instBadgeHf:''}`:' <span class="muted small">—</span>'}</td>
<td style="text-align:center;padding:.4rem .25rem">${m.in_config?`<span class="badge badge-ok">enabled${_hfConfigCount>1?` ×${_hfConfigCount}`:''}</span>${instBadgeHf?'<br>'+instBadgeHf:''}`:' <span class="muted small">—</span>'}</td>
<td style="padding:.4rem .25rem;text-align:right;white-space:nowrap">
${m.in_config?(loaded
?`<button class="btn btn-ghost btn-sm" onclick="unloadModel(${idx})">Unload</button>`
:`<button class="btn btn-primary btn-sm" onclick="loadModel(${idx})">Load now</button>`):''}
<button class="btn btn-secondary btn-sm" onclick="openCfgModal(${idx})">${m.in_config?'Configure':'Add to CoderAI'}</button>
${m.in_config?`<button class="btn btn-ghost btn-sm" onclick="disableModel(${idx})">Remove</button>`:''}
<button class="btn btn-danger btn-sm" onclick="deleteModelConfirm(${idx})">Delete</button>
?`<button class="btn btn-ghost btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="unloadModel(${idx})">Unload</button>`
:`<button class="btn btn-primary btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="loadModel(${idx})">Load now</button>`):''}
<button class="btn btn-secondary btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="openCfgModal(${idx},0)">${m.in_config?'Configure':'Add'}</button>
${m.in_config?`<button class="btn btn-ghost btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="disableModel(${idx})">Remove</button>`:''}
<button class="btn btn-ghost btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="openDownloadFor('${esc(m.id)}')" title="Re-download from HuggingFace">⬇ Re-download</button>
<button class="btn btn-danger btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="deleteModelConfirm(${idx})">Delete</button>
</td>
</tr>`;
});
......@@ -1323,51 +1963,58 @@ async function loadCachedModels(){
const gguf = d.gguf||[];
_ggufFiles = gguf;
refreshWhisperGgufOptions();
document.getElementById('gguf-file-badge').textContent = gguf.length ? `(${gguf.length})` : '';
if(!gguf.length){
renderComponentsList(gguf);
const ggufModelFiles = gguf.filter(f => !_isCompFile(f.filename));
document.getElementById('gguf-file-badge').textContent = ggufModelFiles.length ? `(${ggufModelFiles.length})` : '';
if(!ggufModelFiles.length){
ggufEl.innerHTML = '<span class="muted small">No GGUF files cached.</span>';
}else{
// Group by base model name (strip quant suffix)
const groups = {};
gguf.forEach(f=>{
// One row per file — no grouping. Each quantization is independently
// visible and can be individually configured, removed, and deleted.
const rows = ggufModelFiles.map(f=>{
const base = _ggufBaseName(f.filename);
if(!groups[base]) groups[base] = [];
groups[base].push(f);
});
const rows = Object.entries(groups).map(([base, files])=>{
// Use the configured file if any, else first
const primary = files.find(f=>f.in_config) || files[0];
const quant = _ggufQuant(f.filename);
const idx = _localModels.length;
_localModels.push({label:primary.filename, path:primary.path, cacheType:'gguf',
size_gb:primary.size_gb||0, defaultType:primary.model_type||'text_models',
settings:primary.settings||{}, in_config:primary.in_config,
capabilities:primary.capabilities||[], ggufGroup:files});
const loaded = _loadedKeys.has(primary.path) || _loadedKeys.has(primary.filename) ||
[..._loadedKeys].some(k=>k.endsWith(':'+primary.path)||k.endsWith(':'+primary.filename));
const capBadges = fmtCapabilities(primary.capabilities||[]);
const in_config = files.some(f=>f.in_config);
const instBadgeGguf = in_config ? _instanceBadge([primary.path, primary.filename], (primary.settings||{}).max_instances||1) : '';
// Quant badges
const quantBadges = files.length > 1
? files.map(f=>{
const q = _ggufQuant(f.filename);
const active = f.path === primary.path;
return `<span class="badge ${active?'badge-ok':'badge-user'}" style="font-size:10px;padding:.1rem .3rem;cursor:pointer" title="${esc(f.filename)}" onclick="_switchGgufQuant(${idx},${JSON.stringify(f.path)})">${esc(q||f.filename)}</span>`;
}).join(' ')
: '';
const hlGguf = _highlightCap && files.some(f=>(f.capabilities||[]).includes(_highlightCap));
return `<tr${hlGguf?' class="local-cap-highlight"':''} id="gguf-row-${idx}" style="border-top:1px solid var(--border)${hlGguf?';background:rgba(110,207,126,.07);outline:2px solid rgba(110,207,126,.25);outline-offset:-1px':''}">
<td style="padding:.4rem .25rem;font-family:monospace;font-size:11px;max-width:280px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="${esc(primary.filename)}">${esc(base)}<br>${quantBadges}</td>
<td style="text-align:right;padding:.4rem .25rem;white-space:nowrap;color:var(--text-2)">${fmtGB(primary.size_gb)}</td>
_localModels.push({label:f.filename, path:f.path, cacheType:'gguf',
size_gb:f.size_gb||0, defaultType:f.model_type||'text_models',
settings:f.settings||{}, in_config:f.in_config,
capabilities:f.capabilities||[], ggufGroup:[f],
source_repo:f.source_repo||'',
incomplete:!!f.incomplete,
missing:!!f.missing,
configured_path:f.configured_path||'',
configs:f.configs||[]});
const loaded = _loadedKeys.has(f.path) || _loadedKeys.has(f.filename) ||
[..._loadedKeys].some(k=>k.endsWith(':'+f.path)||k.endsWith(':'+f.filename));
const capBadges = fmtCapabilities(f.capabilities||[]);
const lm = _localModels[idx];
const _cfgMax = lm.configs && lm.configs.length === 1 ? ((lm.configs[0].settings||{}).max_instances||1) : 1;
const instBadge = f.in_config ? _instanceBadge([f.path, f.filename], _cfgMax) : '';
const quantBadge = quant ? `<span class="badge badge-user" style="font-size:10px;padding:.1rem .3rem;margin-left:.3rem">${esc(quant)}</span>` : '';
const hl = _highlightCap && (f.capabilities||[]).includes(_highlightCap);
const missingBadge = f.missing ? ' <span class="badge" style="background:rgba(220,50,50,.18);color:#e05555;font-size:10px" title="File not found at configured path — re-download or remove this configuration">✕ file missing</span>' : '';
const incompleteBadge = f.incomplete ? ' <span class="badge" style="background:rgba(255,160,0,.18);color:#b87200;font-size:10px" title="Download may be incomplete">⚠ incomplete</span>' : '';
const redownloadTarget = f.source_repo || '';
const redownloadPattern = f.source_repo ? f.filename : '';
const configPills = _renderConfigPills(idx, lm);
const configCount = (lm.configs||[]).length;
const rowBg = f.missing ? ';background:rgba(220,50,50,.04);outline:1px solid rgba(220,50,50,.2);outline-offset:-1px'
: f.incomplete ? ';background:rgba(255,160,0,.04)' : '';
return `<tr${hl?' class="local-cap-highlight"':''} id="gguf-row-${idx}" style="border-top:1px solid var(--border)${hl?';background:rgba(110,207,126,.07);outline:2px solid rgba(110,207,126,.25);outline-offset:-1px':''}${rowBg}">
<td style="padding:.4rem .25rem;font-family:monospace;font-size:11px;max-width:280px;overflow:hidden;text-overflow:ellipsis" title="${esc(f.path)}">${esc(base)}${quantBadge}${missingBadge}${incompleteBadge}${configPills}</td>
<td style="text-align:right;padding:.4rem .25rem;white-space:nowrap;color:var(--text-2)">${f.missing?'<span class="muted small">—</span>':fmtGB(f.size_gb)}</td>
<td style="padding:.4rem .25rem;font-size:11px">${capBadges||'<span class="muted small">—</span>'}</td>
<td style="text-align:center;padding:.4rem .25rem">${in_config?`<span class="badge badge-ok">enabled</span>${instBadgeGguf?'<br>'+instBadgeGguf:''}`:' <span class="muted small">—</span>'}</td>
<td style="text-align:center;padding:.4rem .25rem">${f.in_config?`<span class="badge ${f.missing?'badge-err':'badge-ok'}">enabled${configCount>1?` ×${configCount}`:''}</span>${instBadge?'<br>'+instBadge:''}`:' <span class="muted small">—</span>'}</td>
<td style="padding:.4rem .25rem;text-align:right;white-space:nowrap">
${in_config?(loaded
?`<button class="btn btn-ghost btn-sm" onclick="unloadModel(${idx})">Unload</button>`
:`<button class="btn btn-primary btn-sm" onclick="loadModel(${idx})">Load now</button>`):''}
<button class="btn btn-secondary btn-sm" onclick="openCfgModal(${idx})">${in_config?'Configure':'Add to CoderAI'}</button>
${in_config?`<button class="btn btn-ghost btn-sm" onclick="disableModel(${idx})">Remove</button>`:''}
<button class="btn btn-danger btn-sm" onclick="deleteModelConfirm(${idx})">Delete</button>
${f.in_config&&!f.missing?(loaded
?`<button class="btn btn-ghost btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="unloadModel(${idx})">Unload</button>`
:`<button class="btn btn-primary btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="loadModel(${idx})">Load now</button>`):''}
${f.missing
?`<button class="btn btn-warn btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="openDownloadFor('${esc(redownloadTarget)}','${esc(redownloadPattern)}')" title="Re-download to restore this model">⬇ Re-download</button>`
:`<button class="btn btn-ghost btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="openDownloadFor('${esc(redownloadTarget)}','${esc(redownloadPattern)}')" title="${redownloadTarget?'Re-download from HuggingFace':'Download a replacement'}">⬇ Re-download</button>`}
<button class="btn btn-secondary btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="openCfgModal(${idx},0)">${f.in_config?'Configure':'Add'}</button>
${f.in_config?`<button class="btn btn-ghost btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="disableModel(${idx})">Remove</button>`:''}
${!f.missing?`<button class="btn btn-danger btn-sm" style="font-size:10px;padding:.15rem .4rem" onclick="deleteModelConfirm(${idx})">Delete</button>`:''}
</td>
</tr>`;
});
......@@ -1442,17 +2089,19 @@ function _findLoadedKey(paths){
return null;
}
function _instanceBadge(lookupPaths, maxCfg){
if(!maxCfg || maxCfg <= 1) return '';
function _instanceBadge(lookupPaths, cfgMax){
const instKey = _findLoadedKey(lookupPaths);
const info = instKey ? _instanceInfo[instKey] : null;
const loadedCount = info ? info.loaded : 0;
const maxCount = info ? info.max : maxCfg;
if(loadedCount === 0){
return `<span class="badge badge-user" style="font-size:10px" title="${maxCount} instances configured">×${maxCount} inst.</span>`;
if(info && info.max > 1){
// Live data: show actual loaded/max counts
const cls = info.loaded >= info.max ? 'badge-ok' : 'badge-warn';
return `<span class="badge ${cls}" style="font-size:10px" title="${info.loaded} of ${info.max} instances loaded">${info.loaded}/${info.max} inst.</span>`;
}
if(cfgMax && cfgMax > 1){
// Configured but not yet loaded — show the configured count
return `<span class="badge badge-user" style="font-size:10px" title="${cfgMax} instances configured">×${cfgMax} inst.</span>`;
}
const cls = loadedCount >= maxCount ? 'badge-ok' : 'badge-warn';
return `<span class="badge ${cls}" style="font-size:10px" title="${loadedCount} of ${maxCount} instances loaded">${loadedCount}/${maxCount} inst.</span>`;
return '';
}
async function refreshLocal(){
......@@ -1510,27 +2159,39 @@ refreshLocal();
async function clearCacheConfirm(type){
const labels = {hf:'HuggingFace', gguf:'GGUF', all:'ALL'};
if(!confirm(`Delete ${labels[type]} model cache? This cannot be undone.`)) return;
if(!await showConfirm('Clear cache', `Delete ${labels[type]} model cache? This cannot be undone.`, 'Delete')) return;
try{
const r = await fetch(ROOT_PATH + '/admin/api/cache?cache_type='+type, {method:'DELETE'});
const d = await r.json();
if(d.success){
refreshLocal();
alert(`Cache cleared. Freed ${fmtBytes(d.freed_bytes||0)}.`);
}else alert('Error clearing cache');
}catch(e){alert('Error: '+e.message)}
if(d.success) refreshLocal();
else await showConfirm('Error', 'Error clearing cache', 'OK');
}catch(e){ await showConfirm('Error', e.message, 'OK'); }
}
async function deleteModelConfirm(idx){
const m = _localModels[idx];
if(!confirm(`Delete "${m.label}" from local cache? This cannot be undone.`)) return;
// Always send the full path so the backend can locate the file regardless of cache location
const idForUrl = m.path;
if(!m){ alert('Error: model not found — please refresh the page.'); return; }
const extra = m.in_config ? ' It will also be removed from your configuration.' : '';
if(!await showConfirm('Delete model', `Delete "${m.label}" from local cache? This cannot be undone.${extra}`, 'Delete')) return;
try{
// De-configure first if the model is currently enabled, to avoid dangling config entries
if(m.in_config){
const paths = m.ggufGroup
? m.ggufGroup.filter(f=>f.in_config).map(f=>f.path)
: [m.path];
for(const p of paths.length ? paths : [m.path]){
await fetch(ROOT_PATH + '/admin/api/model-disable', {
method:'POST', headers:{'Content-Type':'application/json'},
body: JSON.stringify({path: p})
});
}
}
// Remove from cache
const idForUrl = m.path;
const r = await fetch(ROOT_PATH + '/admin/api/cached-models/'+encodeURIComponent(idForUrl)+'?cache_type='+m.cacheType, {method:'DELETE'});
const d = await r.json();
if(d.success) refreshLocal();
else alert('Error: '+(d.detail||'Unknown'));
else alert('Error: '+(d.detail||d.error||'Unknown'));
}catch(e){alert('Error: '+e.message)}
}
......@@ -1674,18 +2335,66 @@ async function saveWhisperServerEdit() {
} catch(e) { alert('Error: ' + e.message); }
}
function openCfgModal(idx){
function openCfgModalNew(idx) {
// Open modal for a brand-new config (no config_id) on an existing model
openCfgModal(idx, -1);
}
function openCfgModal(idx, cfgIdx){
const m = _localModels[idx];
if (m.cacheType === 'whisper-server') {
_openWhisperServerEdit(m);
return;
}
const s = m.settings || {};
document.getElementById('cfg-modal-title').textContent = m.in_config ? 'Configure model' : 'Add to CoderAI';
// Determine which config's settings to load
const configs = m.configs || [];
// Normalize cfgIdx: undefined/null → 0, -1 → "new"
const _cfgIdx = (cfgIdx == null) ? 0 : cfgIdx;
let s, configId, isNewConfig;
if (_cfgIdx === -1 || (configs.length === 0 && !m.in_config)) {
// New config: explicitly requested (-1) or truly unconfigured model
s = {};
configId = '';
isNewConfig = true;
} else if (_cfgIdx >= 0 && configs[_cfgIdx]) {
s = configs[_cfgIdx].settings || {};
configId = s.config_id || '';
isNewConfig = false;
} else {
// Fallback to legacy m.settings
s = m.settings || {};
configId = s.config_id || '';
isNewConfig = !configId;
}
document.getElementById('cfg-config-id').value = configId;
// Show/hide "Save as new config" and "Remove this config" buttons
const saveNewBtn = document.getElementById('cfg-save-new-btn');
saveNewBtn.style.display = (!isNewConfig && m.in_config) ? '' : 'none';
const removeConfigBtn = document.getElementById('cfg-remove-config-btn');
// Only show remove-this-config when the model has multiple configs — removing the last one is handled by the row-level Remove button
removeConfigBtn.style.display = (!isNewConfig && configId && configs.length > 1) ? '' : 'none';
const cfgLabel = isNewConfig
? (m.in_config ? 'Add configuration' : 'Add')
: (configs.length > 1 ? `Configure: ${s.alias || ('Config ' + (_cfgIdx + 1))}` : 'Configure model');
document.getElementById('cfg-modal-title').textContent = cfgLabel;
document.getElementById('cfg-id-label').textContent = m.label;
document.getElementById('cfg-path').value = m.path;
document.getElementById('cfg-orig-path').value = m.path; // frozen at open time
// If the file was re-downloaded to a new path, use the old configured path as orig_path
// so saveModelConfig removes the stale entry and writes one at the new location.
document.getElementById('cfg-orig-path').value = m.configured_path || m.path;
document.getElementById('cfg-orig-type').value = m.defaultType;
// Show a notice when the file moved (re-downloaded to a different location)
const pathMovedEl = document.getElementById('cfg-path-moved-notice');
if(pathMovedEl){
if(m.configured_path && m.configured_path !== m.path){
pathMovedEl.textContent = `File re-downloaded to new path. Saving will update the configuration.`;
pathMovedEl.style.display = '';
} else {
pathMovedEl.style.display = 'none';
}
}
// Quantization selector for grouped GGUF models
const quantRow = document.getElementById('cfg-quant-row');
......@@ -1749,6 +2458,10 @@ function openCfgModal(idx){
capsdet.textContent = capsFromConfig ? '' : '(auto-detected)';
document.getElementById('cfg-alias').value = s.alias || '';
document.getElementById('cfg-config-name').value = s.config_name || '';
// Show config-name field only when multiple configs exist or a new one is being added to an existing model
document.getElementById('cfg-config-name-row').style.display =
(configs.length > 1 || (isNewConfig && m.in_config)) ? '' : 'none';
document.getElementById('cfg-backend').value = s.backend || 'auto';
document.getElementById('cfg-load-mode').value = s.load_mode || 'on-request';
// Used VRAM
......@@ -1787,6 +2500,19 @@ function openCfgModal(idx){
document.getElementById('cfg-parser').value = s.parser || (!m.in_config ? _autoDetectParser(m.path) : 'auto');
document.getElementById('cfg-tools').checked = !!s.tools_closer_prompt;
document.getElementById('cfg-grammar').checked = !!s.grammar_guided;
document.getElementById('cfg-template').value = s.model_template || '';
document.getElementById('cfg-max-vram').value = s.max_vram != null ? s.max_vram : '';
document.getElementById('cfg-sdcpp-flash').checked = !!s.sdcpp_flash_attn;
document.getElementById('cfg-sdcpp-diff-flash').checked = !!s.sdcpp_diffusion_flash_attn;
document.getElementById('cfg-vae-tiling').checked = !!s.vae_tiling;
_populateCompSelects();
_setCompField('vae', s.vae_path || '');
_setCompField('t5xxl', s.t5xxl_path || '');
_setCompField('clip-l', s.clip_l_path || '');
_setCompField('clip-g', s.clip_g_path || '');
_setCompField('clip-vision',s.clip_vision_path|| '');
_setCompField('lora', s.lora_path || '');
document.getElementById('cfg-lora-dir').value = s.lora_model_dir || '';
openModal('cfg-modal');
}
......@@ -1805,6 +2531,32 @@ function _estimateVram(m, nCtx) {
return weights * 1.15 + kvCacheGb;
}
function saveNewModelConfig(){
// Clear config_id so saveModelConfig creates a new entry instead of updating
document.getElementById('cfg-config-id').value = '';
saveModelConfig();
}
async function removeThisConfig(){
const path = document.getElementById('cfg-path').value;
const configId = document.getElementById('cfg-config-id').value.trim();
if (!configId) return;
if (!await showConfirm('Remove configuration', 'Remove this configuration? The model file and other configurations will be kept.', 'Remove')) return;
try {
const r = await fetch(ROOT_PATH + '/admin/api/model-disable', {
method: 'POST', headers: {'Content-Type': 'application/json'},
body: JSON.stringify({path, config_id: configId})
});
if (!r.ok) {
let msg = `HTTP ${r.status}`;
try { const d = await r.json(); msg = d.detail || d.error || msg; } catch(_){}
throw new Error(msg);
}
closeModal('cfg-modal');
loadCachedModels();
} catch(e) { alert('Error: ' + e.message); }
}
async function saveModelConfig(){
const path = document.getElementById('cfg-path').value;
const maxGpu = parseFloat(document.getElementById('cfg-max-gpu').value);
......@@ -1812,13 +2564,16 @@ async function saveModelConfig(){
const usedVram = parseFloat(document.getElementById('cfg-used-vram').value);
const {primaryType, model_types, video_subtypes} = _getCheckedTypes();
const origPath = document.getElementById('cfg-orig-path').value;
const configId = document.getElementById('cfg-config-id').value.trim();
const data = {
path,
config_id: configId || undefined,
orig_path: (origPath && origPath !== path) ? origPath : undefined,
model_type: primaryType,
model_types: model_types,
video_subtypes: video_subtypes.length ? video_subtypes : undefined,
alias: document.getElementById('cfg-alias').value.trim() || null,
config_name: document.getElementById('cfg-config-name').value.trim() || null,
backend: document.getElementById('cfg-backend').value,
load_mode: document.getElementById('cfg-load-mode').value,
used_vram_gb: isNaN(usedVram) ? null : usedVram,
......@@ -1839,6 +2594,18 @@ async function saveModelConfig(){
tools_closer_prompt: document.getElementById('cfg-tools').checked,
grammar_guided: document.getElementById('cfg-grammar').checked,
capabilities: [...document.querySelectorAll('.cfg-cap-cb:checked')].map(cb => cb.value),
model_template: document.getElementById('cfg-template').value || null,
vae_path: _readCompField('vae') || null,
t5xxl_path: _readCompField('t5xxl') || null,
clip_l_path: _readCompField('clip-l') || null,
clip_g_path: _readCompField('clip-g') || null,
clip_vision_path: _readCompField('clip-vision') || null,
lora_path: _readCompField('lora') || null,
lora_model_dir: document.getElementById('cfg-lora-dir').value.trim() || null,
max_vram: parseFloat(document.getElementById('cfg-max-vram').value) || 0,
sdcpp_flash_attn: document.getElementById('cfg-sdcpp-flash').checked,
sdcpp_diffusion_flash_attn: document.getElementById('cfg-sdcpp-diff-flash').checked,
vae_tiling: document.getElementById('cfg-vae-tiling').checked,
};
try{
const r = await fetch(ROOT_PATH + '/admin/api/model-configure',{
......@@ -1924,19 +2691,25 @@ async function unloadModel(idx){
async function disableModel(idx){
const m = _localModels[idx];
if(!confirm('Remove this model from CoderAI config? It will stay in the local cache.')) return;
if(!m){ alert('Error: model not found — please refresh the page.'); return; }
if(!await showConfirm('Remove model', 'Remove this model from CoderAI config? It will stay in the local cache.', 'Remove')) return;
try{
// Collect all paths that are marked in_config for this model (covers quant groups)
const paths = m.ggufGroup
? m.ggufGroup.filter(f=>f.in_config).map(f=>f.path)
: (m.in_config ? [m.path] : []);
if(!paths.length) paths.push(m.path); // fallback
try{
// Disable each configured path in the group
for(const p of paths){
await fetch(ROOT_PATH + '/admin/api/model-disable',{
const r = await fetch(ROOT_PATH + '/admin/api/model-disable',{
method:'POST', headers:{'Content-Type':'application/json'},
body: JSON.stringify({path: p})
});
if(!r.ok){
let msg = `HTTP ${r.status}`;
try{ const d = await r.json(); msg = d.detail || d.error || msg; }catch(_){}
throw new Error(msg);
}
}
loadCachedModels();
}catch(e){ alert('Error: '+e.message); }
......
......@@ -120,7 +120,7 @@
<div style="display:grid;grid-template-columns:180px 1fr;gap:1rem;align-items:start">
<div class="form-row" style="margin:0">
<label class="form-label">Scope</label>
<select id="s-broker-scope" class="form-input">
<select id="s-broker-scope" class="form-input" onchange="toggleBrokerFields()">
<option value="user">user</option>
<option value="global">global</option>
</select>
......@@ -128,7 +128,7 @@
<div class="form-row" style="margin:0">
<label class="form-label">Username</label>
<input type="text" id="s-broker-username" class="form-input" placeholder="alice or global">
<span class="form-hint">Use `global` when scope is `global`; otherwise provide the AISBF username.</span>
<span class="form-hint">This is forced to `global` for global scope; user scope requires the AISBF username.</span>
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:1rem;align-items:start">
......@@ -150,6 +150,11 @@
<input type="text" id="s-broker-advertised-endpoint" class="form-input" placeholder="http://127.0.0.1:8776">
<span class="form-hint">Optional external URL advertised to the broker for this instance.</span>
</div>
<div class="form-row">
<label class="form-label">Websocket path override</label>
<input type="text" id="s-broker-websocket-path" class="form-input" placeholder="/api/coderai/wss">
<span class="form-hint">Optional manual websocket route override for proxied or custom broker deployments; leave empty to derive from scope.</span>
</div>
<div style="display:grid;grid-template-columns:repeat(3, minmax(0, 1fr));gap:1rem;align-items:start">
<div class="form-row" style="margin:0">
<label class="form-label">Heartbeat seconds</label>
......@@ -164,7 +169,7 @@
<input type="number" id="s-broker-request-timeout" class="form-input" min="1" placeholder="30">
</div>
</div>
<div style="display:grid;grid-template-columns:repeat(2, minmax(0, 1fr));gap:1rem;align-items:start">
<div style="display:grid;grid-template-columns:repeat(3, minmax(0, 1fr));gap:1rem;align-items:start">
<div class="form-row" style="margin:0">
<label class="form-label">Reconnect initial delay</label>
<input type="number" id="s-broker-reconnect-initial" class="form-input" min="1" placeholder="1">
......@@ -173,6 +178,11 @@
<label class="form-label">Reconnect max delay</label>
<input type="number" id="s-broker-reconnect-max" class="form-input" min="1" placeholder="60">
</div>
<div class="form-row" style="margin:0">
<label class="form-label">WS ping interval (s)</label>
<input type="number" id="s-broker-ws-ping" class="form-input" min="5" placeholder="20">
<span class="form-hint" style="font-size:11px">Keeps connection alive through nginx proxies. Lower if you get 504 timeouts.</span>
</div>
</div>
</div>
</div>
......@@ -188,6 +198,16 @@ function toggleHttps(){
function toggleBrokerFields(){
document.getElementById('broker-fields').style.display =
document.getElementById('s-broker-enabled').checked ? 'block' : 'none';
const scope = document.getElementById('s-broker-scope').value;
const usernameInput = document.getElementById('s-broker-username');
if(scope === 'global'){
usernameInput.value = 'global';
usernameInput.readOnly = true;
} else {
usernameInput.readOnly = false;
if(usernameInput.value === 'global') usernameInput.value = '';
}
}
function showAlert(type, msg){
......@@ -232,11 +252,13 @@ async function loadSettings(){
document.getElementById('s-broker-client-id').value = broker.client_id ?? '';
document.getElementById('s-broker-registration-token').value = broker.registration_token ?? '';
document.getElementById('s-broker-advertised-endpoint').value = broker.advertised_endpoint ?? '';
document.getElementById('s-broker-websocket-path').value = broker.websocket_path ?? '';
document.getElementById('s-broker-heartbeat').value = broker.heartbeat_interval_seconds ?? 30;
document.getElementById('s-broker-connect-timeout').value = broker.connect_timeout_seconds ?? 10;
document.getElementById('s-broker-request-timeout').value = broker.request_timeout_seconds ?? 30;
document.getElementById('s-broker-reconnect-initial').value = broker.reconnect_initial_delay_seconds ?? 1;
document.getElementById('s-broker-reconnect-max').value = broker.reconnect_max_delay_seconds ?? 60;
document.getElementById('s-broker-ws-ping').value = broker.websocket_ping_interval ?? 20;
toggleBrokerFields();
}catch(e){ showAlert('error','Failed to load settings: '+e.message); }
}
......@@ -273,11 +295,13 @@ async function saveSettings(){
client_id: document.getElementById('s-broker-client-id').value.trim(),
registration_token: document.getElementById('s-broker-registration-token').value.trim(),
advertised_endpoint: document.getElementById('s-broker-advertised-endpoint').value.trim(),
websocket_path: document.getElementById('s-broker-websocket-path').value.trim(),
heartbeat_interval_seconds: parseInt(document.getElementById('s-broker-heartbeat').value) || 30,
connect_timeout_seconds: parseInt(document.getElementById('s-broker-connect-timeout').value) || 10,
request_timeout_seconds: parseInt(document.getElementById('s-broker-request-timeout').value) || 30,
reconnect_initial_delay_seconds: parseInt(document.getElementById('s-broker-reconnect-initial').value) || 1,
reconnect_max_delay_seconds: parseInt(document.getElementById('s-broker-reconnect-max').value) || 60,
websocket_ping_interval: parseInt(document.getElementById('s-broker-ws-ping').value) || 20,
transport: 'websocket',
},
};
......
......@@ -43,12 +43,23 @@ global_file_path = None
_aud_progress: dict = {
"current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0, "unit": "it",
"phase": "idle", "model": "",
}
def _aud_progress_loading(model_name: str = ""):
_aud_progress["phase"] = "loading"
_aud_progress["active"] = True
_aud_progress["current"] = 0
_aud_progress["total"] = 0
_aud_progress["it_per_s"] = 0.0
_aud_progress["started_at"] = time.monotonic()
_aud_progress["model"] = model_name or ""
def _aud_progress_reset(total: int, unit: str = "it"):
_aud_progress["current"] = 0
_aud_progress["total"] = total
_aud_progress["active"] = True
_aud_progress["phase"] = "generating"
_aud_progress["started_at"] = time.monotonic()
_aud_progress["it_per_s"] = 0.0
_aud_progress["unit"] = unit
......@@ -56,6 +67,7 @@ def _aud_progress_reset(total: int, unit: str = "it"):
def _aud_progress_done():
_aud_progress["current"] = max(_aud_progress["current"], _aud_progress["total"])
_aud_progress["active"] = False
_aud_progress["phase"] = "idle"
def _aud_progress_step(step: int):
_aud_progress["current"] = step
......@@ -196,6 +208,8 @@ async def get_audio_progress():
"current": current,
"total": total,
"active": _aud_progress["active"],
"phase": _aud_progress.get("phase", "idle"),
"model": _aud_progress.get("model", ""),
"pct": int(current / total * 100) if total > 0 else 0,
"it_per_s": _aud_progress["it_per_s"],
"elapsed": round(elapsed, 1),
......@@ -209,6 +223,7 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
Generate music, sound effects, or ambient audio.
Compatible models: MusicGen, AudioGen, AudioLDM2, StableAudio.
"""
_aud_progress_loading(request.model or "audio")
model_info = multi_model_manager.request_model(request.model, model_type="audio_gen")
model_name = model_info.get('model_name')
if not model_name:
......
......@@ -123,18 +123,30 @@ import time as _time
_gen_progress: dict = {
"current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0,
"phase": "idle", "model": "",
}
def _progress_loading(model_name: str = ""):
_gen_progress["phase"] = "loading"
_gen_progress["active"] = True
_gen_progress["current"] = 0
_gen_progress["total"] = 0
_gen_progress["it_per_s"] = 0.0
_gen_progress["started_at"] = _time.monotonic()
_gen_progress["model"] = model_name or ""
def _progress_reset(total: int):
_gen_progress["current"] = 0
_gen_progress["total"] = total
_gen_progress["active"] = True
_gen_progress["phase"] = "generating"
_gen_progress["started_at"] = _time.monotonic()
_gen_progress["it_per_s"] = 0.0
def _progress_done():
_gen_progress["current"] = _gen_progress["total"]
_gen_progress["active"] = False
_gen_progress["phase"] = "idle"
def _progress_step(step: int):
_gen_progress["current"] = step
......@@ -894,6 +906,8 @@ async def get_image_progress():
"current": _gen_progress["current"],
"total": _gen_progress["total"],
"active": _gen_progress["active"],
"phase": _gen_progress.get("phase", "idle"),
"model": _gen_progress.get("model", ""),
"pct": int(_gen_progress["current"] / _gen_progress["total"] * 100)
if _gen_progress["total"] > 0 else 0,
"it_per_s": _gen_progress["it_per_s"],
......@@ -944,6 +958,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# =====================================================================
# Step 1: Ask the manager to resolve the model and manage VRAM
# =====================================================================
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(
requested_model=request.model,
model_type="image"
......@@ -1173,6 +1188,7 @@ async def create_image_edit(request: ImageEditRequest, http_request: Request = N
if not request.image:
raise HTTPException(status_code=400, detail="image is required")
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name')
if not model_name:
......@@ -1306,6 +1322,7 @@ async def create_image_inpaint(request: ImageInpaintRequest, http_request: Reque
global global_args
if not request.image or not request.mask:
raise HTTPException(status_code=400, detail="image and mask are required")
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name')
if not model_name:
......@@ -1414,6 +1431,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int):
async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None):
"""Upscale an image using Real-ESRGAN or PIL LANCZOS fallback."""
global global_args
_progress_loading(request.model or "image")
model_info = multi_model_manager.request_model(request.model, model_type="image")
model_name = model_info.get('model_name') or request.model
model_key = f"upscale:{model_name}"
......
......@@ -41,6 +41,11 @@ class BearerAuthMiddleware(BaseHTTPMiddleware):
if not path.startswith("/v1/") or path in self._EXEMPT_PATHS:
return await call_next(request)
# Requests from the ASGI broker bridge are in-process and have no real
# Bearer token. Identify them by the sentinel server tuple set in asgi_bridge.py.
if request.scope.get("server") == ("internal", 80):
return await call_next(request)
from codai.admin import routes as _admin_routes
sm = _admin_routes.session_manager
if sm is None:
......@@ -98,6 +103,7 @@ class _Bucket:
self.window_start = now
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Apply per-IP, per-route-prefix rate limiting to API endpoints."""
......
......@@ -57,18 +57,30 @@ global_file_path = None
_vid_progress: dict = {
"current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0,
"phase": "idle", "model": "",
}
def _vid_progress_loading(model_name: str = ""):
_vid_progress["phase"] = "loading"
_vid_progress["active"] = True
_vid_progress["current"] = 0
_vid_progress["total"] = 0
_vid_progress["it_per_s"] = 0.0
_vid_progress["started_at"] = time.monotonic()
_vid_progress["model"] = model_name or ""
def _vid_progress_reset(total: int):
_vid_progress["current"] = 0
_vid_progress["total"] = total
_vid_progress["active"] = True
_vid_progress["phase"] = "generating"
_vid_progress["started_at"] = time.monotonic()
_vid_progress["it_per_s"] = 0.0
def _vid_progress_done():
_vid_progress["current"] = _vid_progress["total"]
_vid_progress["active"] = False
_vid_progress["phase"] = "idle"
def _vid_progress_step(step: int):
_vid_progress["current"] = step
......@@ -190,19 +202,307 @@ def _detect_pipeline_class(model_name: str, mode: str):
return None
def _load_video_pipeline(model_name: str, device: str, mode: str):
import torch, gc
def _gguf_needs_wan_prefix(path: str) -> bool:
"""Return True if this GGUF has bare WAN tensor names (no model.diffusion_model. prefix)."""
import struct
try:
with open(path, 'rb') as f:
magic = f.read(4)
if magic != b'GGUF':
return False
f.read(4) # version
f.read(8) # tensor_count
kv_count = struct.unpack('<Q', f.read(8))[0]
def _read_str(fh):
n = struct.unpack('<Q', fh.read(8))[0]
return fh.read(n)
arch = b''
for _ in range(kv_count):
key = _read_str(f)
vtype = struct.unpack('<I', f.read(4))[0]
if vtype == 8: # string
val = _read_str(f)
if key == b'general.architecture':
arch = val
elif vtype in (0, 1, 7): # u8/i8/bool
f.read(1)
elif vtype in (2, 3): # u16/i16
f.read(2)
elif vtype in (4, 5, 6): # u32/i32/f32
f.read(4)
elif vtype in (10, 11, 12): # u64/i64/f64
f.read(8)
elif vtype == 9: # array — skip
atype = struct.unpack('<I', f.read(4))[0]
alen = struct.unpack('<Q', f.read(8))[0]
# skip array payload: only handle simple element types
sizes = {0:1,1:1,2:2,3:2,4:4,5:4,6:4,7:1,10:8,11:8,12:8}
if atype in sizes:
f.read(alen * sizes[atype])
else:
return False # complex array — bail out safely
else:
return False # unknown type — bail out
if arch.lower() != b'wan':
return False
# Read first tensor name; if it lacks the expected prefix, rewrite is needed
name = _read_str(f)
return not name.startswith(b'model.diffusion_model.')
except Exception:
return False
def _ensure_wan_prefixed_gguf(src_path: str) -> str:
"""Return a path to a GGUF with model.diffusion_model. prefixed tensor names.
If src_path already has the prefix, returns it unchanged.
Otherwise creates a sibling file <name>.sdcpp.gguf by rewriting the header
and streaming the original data section — no full file load into memory.
"""
import os as _os, struct, shutil
dst_path = src_path + '.sdcpp.gguf'
if _os.path.exists(dst_path) and _os.path.getmtime(dst_path) >= _os.path.getmtime(src_path):
print(f" [gguf] using cached prefixed GGUF: {dst_path}")
return dst_path
prefix = b'model.diffusion_model.'
print(f" [gguf] rewriting tensor names with '{prefix.decode()}' prefix …")
with open(src_path, 'rb') as src:
magic = src.read(4)
version = src.read(4)
tensor_count_bytes = src.read(8)
kv_count_bytes = src.read(8)
tensor_count = struct.unpack('<Q', tensor_count_bytes)[0]
kv_count = struct.unpack('<Q', kv_count_bytes)[0]
# ── collect raw KV bytes (pass through unchanged) ──────────────────
kv_bytes = bytearray()
def _read_str_raw(fh):
n_bytes = fh.read(8)
n = struct.unpack('<Q', n_bytes)[0]
data = fh.read(n)
return n_bytes + data, data
for _ in range(kv_count):
raw_key, _ = _read_str_raw(src)
kv_bytes += raw_key
vtype_bytes = src.read(4)
vtype = struct.unpack('<I', vtype_bytes)[0]
kv_bytes += vtype_bytes
if vtype == 8:
raw_val, _ = _read_str_raw(src)
kv_bytes += raw_val
elif vtype in (0, 1, 7):
kv_bytes += src.read(1)
elif vtype in (2, 3):
kv_bytes += src.read(2)
elif vtype in (4, 5, 6):
kv_bytes += src.read(4)
elif vtype in (10, 11, 12):
kv_bytes += src.read(8)
elif vtype == 9:
atype_bytes = src.read(4)
alen_bytes = src.read(8)
atype = struct.unpack('<I', atype_bytes)[0]
alen = struct.unpack('<Q', alen_bytes)[0]
sizes = {0:1,1:1,2:2,3:2,4:4,5:4,6:4,7:1,10:8,11:8,12:8}
elem_size = sizes.get(atype, 0)
arr_data = src.read(alen * elem_size) if elem_size else b''
kv_bytes += atype_bytes + alen_bytes + arr_data
# ── collect tensor info, prefixing each name ────────────────────────
ti_bytes = bytearray()
for _ in range(tensor_count):
raw_nlen, name = _read_str_raw(src)
new_name = prefix + name
ti_bytes += struct.pack('<Q', len(new_name)) + new_name
n_dims_bytes = src.read(4)
n_dims = struct.unpack('<I', n_dims_bytes)[0]
ti_bytes += n_dims_bytes
ti_bytes += src.read(n_dims * 8) # shape (u64 each)
ti_bytes += src.read(4) # dtype
ti_bytes += src.read(8) # offset within data section
# Alignment: GGUF data section starts at next 32-byte boundary
ALIGN = 32
header_size = 4 + 4 + 8 + 8 + len(kv_bytes) + len(ti_bytes)
pad = (ALIGN - header_size % ALIGN) % ALIGN
data_offset = src.tell() + ((ALIGN - src.tell() % ALIGN) % ALIGN)
src.seek(data_offset)
with open(dst_path, 'wb') as dst:
dst.write(magic + version + tensor_count_bytes + kv_count_bytes)
dst.write(kv_bytes)
dst.write(ti_bytes)
dst.write(b'\x00' * pad)
shutil.copyfileobj(src, dst, length=8 * 1024 * 1024)
print(f" [gguf] prefixed GGUF written: {dst_path}")
return dst_path
def _load_sdcpp_video_model(model_path: str, offload: str = None, model_cfg: dict = None):
"""Load a GGUF video model via stable-diffusion.cpp."""
try:
from stable_diffusion_cpp import StableDiffusion
import stable_diffusion_cpp.stable_diffusion_cpp as _sd_cpp
except ImportError:
raise RuntimeError("stable-diffusion-cpp-python required: pip install stable-diffusion-cpp-python")
import os as _os
model_cfg = model_cfg or {}
# Resolve bare filename to absolute path from the GGUF cache
if not _os.path.isabs(model_path) and not _os.path.exists(model_path):
try:
from codai.models.cache import get_model_cache_dir
candidate = _os.path.join(get_model_cache_dir(), model_path)
if _os.path.exists(candidate):
model_path = candidate
except Exception:
pass
if not _os.path.exists(model_path):
raise FileNotFoundError(f"GGUF video model not found: {model_path}")
# WAN DiT-only GGUFs (e.g. QuantStack) contain only the denoiser — no VAE or
# text encoders. They must be loaded via diffusion_model_path, not model_path.
# sd.cpp internally prepends "model.diffusion_model." when reading tensors from
# diffusion_model_path, so the original bare-named file is passed directly.
# VAE / text encoders must be supplied separately via model_cfg component paths.
_is_wan_dit = _gguf_needs_wan_prefix(model_path)
if _is_wan_dit:
print(f"Loading sd.cpp video model (WAN DiT, diffusion_model_path): {model_path}")
kwargs = {'diffusion_model_path': model_path, 'verbose': True}
else:
print(f"Loading sd.cpp video model: {model_path}")
kwargs = {'model_path': model_path, 'verbose': True}
if offload in ('model', 'cpu', 'sequential'):
kwargs['offload_params_to_cpu'] = True
kwargs['keep_clip_on_cpu'] = True
kwargs['keep_vae_on_cpu'] = True
# sd.cpp VRAM budget for graph-cut layer execution (0 = disabled)
max_vram = float(model_cfg.get('max_vram') or 0)
if max_vram > 0:
kwargs['max_vram'] = max_vram
# Flash attention variants (sdcpp_flash_attn = full, sdcpp_diffusion_flash_attn = DiT only)
if model_cfg.get('sdcpp_flash_attn'):
kwargs['flash_attn'] = True
if model_cfg.get('sdcpp_diffusion_flash_attn'):
kwargs['diffusion_flash_attn'] = True
# Inject component paths from per-model configuration
for key in ('vae_path', 't5xxl_path', 'clip_l_path', 'clip_g_path',
'clip_vision_path', 'lora_model_dir'):
val = (model_cfg.get(key) or '').strip()
if val:
kwargs[key] = val
print(f" [sd.cpp] {key}: {val}")
# A single LoRA file associated with this model: derive lora_model_dir from its parent
# and append <lora:basename:1.0> to the default prompt via lora_model_dir.
# sd.cpp needs a directory, so point it at the file's parent directory.
lora_path = (model_cfg.get('lora_path') or '').strip()
if lora_path and 'lora_model_dir' not in kwargs:
import os as _os2
kwargs['lora_model_dir'] = _os2.path.dirname(lora_path) or '.'
print(f" [sd.cpp] lora_path: {lora_path} → lora_model_dir: {kwargs['lora_model_dir']}")
@_sd_cpp.sd_log_callback
def _log_cb(level, text, data):
if text:
line = text.decode('utf-8', errors='replace').rstrip()
if line:
print(f" [sd.cpp] {line}", flush=True)
_sd_cpp.sd_set_log_callback(_log_cb, None)
try:
model = StableDiffusion(**kwargs)
finally:
_sd_cpp.sd_set_log_callback(None, None)
return model
def _generate_sdcpp_video(sd_model, request, model_cfg=None):
"""Generate frames via stable-diffusion.cpp and return (frames, fps)."""
mode = request.mode or 't2v'
fps = request.fps or 8
num_frames = request.num_frames or 25
steps = request.num_inference_steps or 20
_vid_progress_reset(steps)
def _progress_cb(step: int, total: int, elapsed: float):
_vid_progress_step(step)
kw = {
'prompt': request.prompt or '',
'negative_prompt': request.negative_prompt or '',
'width': request.width or 512,
'height': request.height or 512,
'video_frames': num_frames,
'sample_steps': steps,
'cfg_scale': request.guidance_scale or 7.0,
'seed': request.seed if request.seed is not None else -1,
'progress_callback': _progress_cb,
}
if (model_cfg or {}).get('vae_tiling'):
kw['vae_tiling'] = True
init_src = request.init_image or request.image
if mode in ('i2v', 'ti2v') and init_src:
kw['init_image'] = _pil_from_b64(init_src)
elif mode == 'interp':
if not init_src or not request.end_image:
raise ValueError("interp mode requires both init_image and end_image")
kw['init_image'] = _pil_from_b64(init_src)
kw['end_image'] = _pil_from_b64(request.end_image)
frames = sd_model.generate_video(**kw)
_vid_progress_done()
return list(frames), fps
def _load_video_pipeline(model_name: str, device: str, mode: str, offload: str = None, model_cfg: dict = None):
# GGUF models go through stable-diffusion.cpp, not diffusers
from codai.api.images import _is_gguf_model
if _is_gguf_model(model_name):
return _load_sdcpp_video_model(model_name, offload, model_cfg)
import sys, time, torch, gc
PClass = _detect_pipeline_class(model_name, mode)
if PClass is None:
raise RuntimeError("diffusers not installed: pip install diffusers")
precision = getattr(global_args, 'image_precision', 'bf16') if global_args else 'bf16'
dtype_map = {'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32}
torch_dtype = dtype_map.get(precision, torch.bfloat16)
# Explicit parameter wins; fall back to global CLI arg
if offload is None:
offload = getattr(global_args, 'offload_strategy', None) if global_args else None
# Normalise UI values to diffusers vocabulary
if offload == 'cpu':
offload = 'model'
# Lower the GIL switch interval so the asyncio event loop thread wins the GIL
# more often during Python-heavy component loading (diffusers from_pretrained
# holds the GIL for extended stretches while instantiating pipeline components).
_old_interval = sys.getswitchinterval()
sys.setswitchinterval(0.001)
try:
for attempt in range(3):
try:
time.sleep(0) # yield GIL before heavy loading begins
pipe = PClass.from_pretrained(model_name, torch_dtype=torch_dtype)
time.sleep(0) # yield GIL before GPU transfer
if offload == 'sequential' or attempt >= 2:
pipe.enable_sequential_cpu_offload()
elif offload == 'model' or attempt >= 1:
......@@ -219,8 +519,11 @@ def _load_video_pipeline(model_name: str, device: str, mode: str):
_torch.cuda.empty_cache()
except Exception:
pass
time.sleep(0)
continue
raise
finally:
sys.setswitchinterval(_old_interval)
# =============================================================================
......@@ -817,6 +1120,8 @@ async def get_video_progress():
"current": _vid_progress["current"],
"total": _vid_progress["total"],
"active": _vid_progress["active"],
"phase": _vid_progress.get("phase", "idle"),
"model": _vid_progress.get("model", ""),
"pct": int(_vid_progress["current"] / _vid_progress["total"] * 100)
if _vid_progress["total"] > 0 else 0,
"it_per_s": _vid_progress["it_per_s"],
......@@ -843,6 +1148,7 @@ async def video_generations(request: VideoGenerationRequest,
"""
if not request.model:
raise HTTPException(status_code=400, detail="model is required")
_vid_progress_loading(request.model)
# Infer mode from inputs if not set
if not request.mode or request.mode == 't2v':
......@@ -864,9 +1170,11 @@ async def video_generations(request: VideoGenerationRequest,
if pipe is None:
device = _derive_device()
_model_cfg = model_info.get('config') or {}
_offload = _model_cfg.get('offload_strategy') or None
try:
pipe = await asyncio.get_event_loop().run_in_executor(
None, _load_video_pipeline, model_name, device, request.mode)
None, _load_video_pipeline, model_name, device, request.mode, _offload, _model_cfg)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load video model: {e}")
multi_model_manager.models[model_key] = pipe
......@@ -875,7 +1183,18 @@ async def video_generations(request: VideoGenerationRequest,
if getattr(request, 'disable_safety_checker', False):
_disable_safety_checker(pipe)
_is_sdcpp_video = False
try:
from stable_diffusion_cpp import StableDiffusion as _SD
_is_sdcpp_video = isinstance(pipe, _SD)
except ImportError:
pass
try:
if _is_sdcpp_video:
frames, fps = await asyncio.get_event_loop().run_in_executor(
None, _generate_sdcpp_video, pipe, request, _model_cfg)
else:
frames, fps = await asyncio.get_event_loop().run_in_executor(
None, _generate_video, pipe, request)
except Exception as e:
......
......@@ -41,11 +41,44 @@ except (ImportError, AttributeError):
try:
from llama_cpp import Llama
from llama_cpp.llama_chat_format import ChatFormatterResponse
import llama_cpp as _llama_cpp
LLAMA_CPP_AVAILABLE = True
except ImportError:
LLAMA_CPP_AVAILABLE = False
Llama = None
ChatFormatterResponse = None
_llama_cpp = None
def _install_layer_log_callback():
"""Replace llama.cpp's log callback with one that prints load-time layer/buffer
messages directly to stdout. Returns the callback object — keep a reference
alive for the duration of the load so ctypes doesn't garbage-collect it."""
if _llama_cpp is None:
return None
# Keywords that identify interesting load-phase messages
_KEEP = (
'llm_load_tensors', 'llm_load_print_meta',
'offload', 'layer', 'buffer size', 'buffer type',
'GPU', 'CUDA', 'Vulkan', 'Metal', 'ROCm', 'SYCL',
'CPU', 'VRAM', 'n_layer', 'n_gpu_layers',
)
@_llama_cpp.llama_log_callback
def _cb(level, text, user_data):
try:
msg = (text.decode('utf-8', errors='replace') if isinstance(text, bytes) else str(text)).rstrip()
if msg and any(k in msg for k in _KEEP):
print(f" [llama.cpp] {msg}", flush=True)
except Exception:
pass
try:
_llama_cpp.llama_log_set(_cb, None)
except Exception:
return None
return _cb # caller must hold this reference
class VulkanBackend(ModelBackend):
......@@ -450,6 +483,7 @@ class VulkanBackend(ModelBackend):
# Try to find GGUF files in the repository
try:
from huggingface_hub import list_repo_files, hf_hub_download
from codai.models.cache import get_hf_hub_cache_dir
print(f"DEBUG: Searching for GGUF files in {model_path}...")
files = list(list_repo_files(model_path, repo_type="model"))
gguf_files = [f for f in files if f.lower().endswith('.gguf')]
......@@ -460,7 +494,7 @@ class VulkanBackend(ModelBackend):
selected = preferred[0] if preferred else gguf_files[0]
print(f"DEBUG: Found GGUF files: {gguf_files}")
print(f"DEBUG: Selected: {selected}")
model_path = hf_hub_download(repo_id=model_path, filename=selected, cache_dir=kwargs.get('cache_dir'))
model_path = hf_hub_download(repo_id=model_path, filename=selected, cache_dir=kwargs.get('cache_dir') or get_hf_hub_cache_dir())
print(f"DEBUG: Downloaded: {model_path}")
else:
print(f"Warning: No GGUF files found in {model_path}, trying direct download...")
......@@ -471,8 +505,9 @@ class VulkanBackend(ModelBackend):
# Try to get from HuggingFace
try:
from huggingface_hub import hf_hub_download
from codai.models.cache import get_hf_hub_cache_dir
# Download the GGUF file
model_path = hf_hub_download(repo_id=model_path, filename="*.gguf", cache_dir=kwargs.get('cache_dir'))
model_path = hf_hub_download(repo_id=model_path, filename="*.gguf", cache_dir=kwargs.get('cache_dir') or get_hf_hub_cache_dir())
except Exception as e:
print(f"Warning: Could not download from HuggingFace: {e}")
# Try as-is
......@@ -557,18 +592,41 @@ class VulkanBackend(ModelBackend):
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# llama-cpp-python will use CUDA when available
# Pre-load summary
gpu_label = "all" if self.n_gpu_layers == -1 else str(self.n_gpu_layers)
print(f" n_gpu_layers : {gpu_label} | n_ctx : {self.n_ctx} | main_gpu : {self.main_gpu}")
if _llama_cpp:
gpu_supported = _llama_cpp.llama_supports_gpu_offload()
print(f" GPU offload : {'supported' if gpu_supported else 'NOT supported by this build'}")
_log_cb = _install_layer_log_callback()
try:
self.model = Llama(**llama_kwargs)
# Try to detect and set up chat template
self._finalize_chat_template_detection()
print(f"DEBUG: VulkanBackend loaded model: {model_path}")
print(f"DEBUG: n_gpu_layers={self.n_gpu_layers}, n_ctx={self.n_ctx}, no_ram={no_ram}")
print(f"DEBUG: chat_template={self.chat_template}")
except Exception as e:
print(f"Error loading GGUF model: {e}")
raise
finally:
# Restore llama.cpp's default (quiet) logging after load
if _llama_cpp:
try:
_llama_cpp.llama_log_set(None, None)
except Exception:
pass
_log_cb = None # release callback
# Post-load layer/buffer summary
try:
n_total = _llama_cpp.llama_model_n_layer(self.model.model)
n_gpu_actual = n_total if self.n_gpu_layers == -1 else min(self.n_gpu_layers, n_total)
n_cpu = n_total - n_gpu_actual
print(f" Layers total : {n_total}")
print(f" Layers → GPU : {n_gpu_actual} | Layers → CPU : {n_cpu}")
except Exception:
pass
# Try to detect and set up chat template
self._finalize_chat_template_detection()
print(f" chat_template: {self.chat_template}")
def generate(
self,
......
"""Helpers for executing in-process ASGI HTTP requests."""
import base64
import logging
import uuid
from urllib.parse import urlencode
logger = logging.getLogger(__name__)
def _build_multipart_body(multipart):
boundary = f"coderai-broker-{uuid.uuid4().hex}"
chunks = []
for field in multipart.get("fields") or []:
name = str(field.get("name") or "")
value = str(field.get("value") or "")
chunks.append(f"--{boundary}\r\n".encode("utf-8"))
chunks.append(f'Content-Disposition: form-data; name="{name}"\r\n\r\n'.encode("utf-8"))
chunks.append(value.encode("utf-8"))
chunks.append(b"\r\n")
for file_entry in multipart.get("files") or []:
name = str(file_entry.get("name") or "file")
filename = str(file_entry.get("filename") or "upload.bin")
content_type = str(file_entry.get("content_type") or "application/octet-stream")
data_base64 = file_entry.get("data_base64") or ""
file_bytes = base64.b64decode(data_base64) if data_base64 else b""
chunks.append(f"--{boundary}\r\n".encode("utf-8"))
chunks.append(
f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'.encode("utf-8")
)
chunks.append(f"Content-Type: {content_type}\r\n\r\n".encode("utf-8"))
chunks.append(file_bytes)
chunks.append(b"\r\n")
chunks.append(f"--{boundary}--\r\n".encode("utf-8"))
return b"".join(chunks), f"multipart/form-data; boundary={boundary}"
async def execute_internal_request(app, *, method, path, headers=None, query=None, body=b""):
logger.debug(
"ASGI bridge → %s %s query=%s body_bytes=%d",
method.upper(), path, query or {}, len(body),
)
request_headers = []
for key, value in (headers or {}).items():
request_headers.append((key.lower().encode("latin-1"), str(value).encode("latin-1")))
......@@ -49,4 +89,14 @@ async def execute_internal_request(app, *, method, path, headers=None, query=Non
response["body"] += message.get("body", b"")
await app(scope, receive, send)
body_preview = response["body"][:200].decode("utf-8", errors="replace") if response["body"] else ""
logger.debug(
"ASGI bridge ← %s %s status=%d content-type=%s body_bytes=%d body_preview=%r",
method.upper(), path,
response["status_code"],
response["headers"].get("content-type", ""),
len(response["body"]),
body_preview,
)
return response
"""Broker capability documents and registration payloads."""
import glob
import os
import platform
import socket
from typing import Any, Dict, Sequence
......@@ -41,15 +43,67 @@ DEFAULT_STUDIO_ENDPOINTS = [
def build_hardware_summary() -> Dict[str, Any]:
"""Build a conservative default hardware summary."""
"""Build a conservative hardware summary with VRAM when available."""
gpus = []
total_vram_mb = 0
available_vram_mb = 0
try:
import torch
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
for index in range(gpu_count):
props = torch.cuda.get_device_properties(index)
device_total_mb = int(props.total_memory / (1024 * 1024))
if index == torch.cuda.current_device():
free_bytes, total_bytes = torch.cuda.mem_get_info()
total_vram_mb = int(total_bytes / (1024 * 1024))
available_vram_mb = int(free_bytes / (1024 * 1024))
gpus.append(
{
"index": index,
"name": torch.cuda.get_device_name(index),
"total_vram_mb": device_total_mb,
}
)
if gpus:
if total_vram_mb == 0:
total_vram_mb = sum(gpu["total_vram_mb"] for gpu in gpus)
if available_vram_mb == 0 and total_vram_mb:
available_vram_mb = total_vram_mb
except Exception:
pass
if not gpus:
for total_path in sorted(glob.glob("/sys/class/drm/card*/device/mem_info_vram_total")):
used_path = total_path.replace("vram_total", "vram_used")
if not os.path.exists(used_path):
continue
try:
device_total_mb = int(int(open(total_path).read()) / (1024 * 1024))
device_used_mb = int(int(open(used_path).read()) / (1024 * 1024))
device_available_mb = max(0, device_total_mb - device_used_mb)
card_name = os.path.basename(os.path.dirname(os.path.dirname(total_path)))
gpus.append(
{
"name": card_name,
"total_vram_mb": device_total_mb,
}
)
total_vram_mb += device_total_mb
available_vram_mb += device_available_mb
except Exception:
continue
return {
"hostname": socket.gethostname(),
"platform": platform.platform(),
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0,
"gpus": gpus,
"gpu_count": len(gpus),
"total_vram_mb": total_vram_mb,
"available_vram_mb": available_vram_mb,
}
......@@ -64,6 +118,7 @@ def build_capabilities_document(
"server": "codai",
"version": version,
"transports": ["websocket"],
"tunnel_only": True,
"openai_compat": {
"chat_completions": True,
"responses": False,
......@@ -88,15 +143,23 @@ def build_register_message(
) -> Dict[str, Any]:
"""Build broker registration frame."""
registration_token = runtime.registration_token
return {
"v": 1,
"op": "register",
"request_id": request_id,
"registration_token": registration_token,
"capabilities": capabilities,
"payload": {
"endpoint": runtime.advertised_endpoint,
"transport": runtime.transport,
"registration_token": runtime.headers.get("Authorization", "").removeprefix("Bearer "),
"registration_token": registration_token,
"hardware": hardware,
"gpus": (hardware or {}).get("gpus", []),
"gpu_count": (hardware or {}).get("gpu_count", 0),
"total_vram_mb": (hardware or {}).get("total_vram_mb", 0),
"available_vram_mb": (hardware or {}).get("available_vram_mb", 0),
"studio_endpoints": list(studio_endpoints or DEFAULT_STUDIO_ENDPOINTS),
"capabilities": capabilities,
},
......
......@@ -4,6 +4,7 @@ from __future__ import annotations
import json
import asyncio
import logging
import time
import uuid
from typing import Any, Awaitable, Callable
......@@ -16,10 +17,13 @@ from codai.broker.capabilities import (
build_hardware_summary,
build_register_message,
)
from codai.broker.dispatcher import OP_ROUTE_MAP
from codai.broker.models import BrokerRequestEnvelope, success_envelope
Dispatcher = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]
logger = logging.getLogger(__name__)
class BrokerClient:
def __init__(self, runtime, dispatcher: Dispatcher | None = None):
......@@ -27,54 +31,131 @@ class BrokerClient:
self.dispatcher = dispatcher
self.websocket = None
self.session_id = None
self.session_metadata: dict[str, Any] = {}
self._heartbeat_task: asyncio.Task | None = None
self._inflight_tasks: set[asyncio.Task] = set()
async def connect_and_register(self):
logger.info(
"CoderAI broker connecting url=%s provider=%s client=%s username=%s",
self.runtime.websocket_url,
self.runtime.provider_id,
self.runtime.client_id,
self.runtime.username,
)
ping_interval = self.runtime.websocket_ping_interval or 20
self.websocket = await websockets.connect(
self.runtime.websocket_url,
additional_headers=self.runtime.headers,
open_timeout=self.runtime.connect_timeout_seconds,
ping_interval=ping_interval,
ping_timeout=max(ping_interval * 2, 60),
)
# Client speaks first: send op=register before waiting for any server message.
hardware = build_hardware_summary()
capabilities = build_capabilities_document(
studio_endpoints=DEFAULT_STUDIO_ENDPOINTS,
hardware=hardware,
)
request_id = f"reg-{uuid.uuid4().hex}"
register_message = build_register_message(
runtime=self.runtime,
request_id=request_id,
hardware=hardware,
capabilities=capabilities,
studio_endpoints=DEFAULT_STUDIO_ENDPOINTS,
)
logger.debug("broker → sending op=register request_id=%s", request_id)
await self.websocket.send(json.dumps(register_message))
registered_message = json.loads(
await asyncio.wait_for(
# Server responds with event=registered containing the assigned session_id.
raw_registered = await asyncio.wait_for(
self.websocket.recv(),
timeout=self.runtime.request_timeout_seconds,
)
registered_message = json.loads(raw_registered)
logger.debug(
"broker ← received event=%s status=%s session_id=%s bytes=%d",
registered_message.get("event"),
registered_message.get("status"),
registered_message.get("session_id"),
len(raw_registered),
)
session_id = registered_message.get("session_id")
if (
registered_message.get("event") != "registered"
or registered_message.get("accepted") is not True
or registered_message.get("status") != "ok"
or not session_id
):
logger.error(
"broker registration rejected: event=%s status=%s message=%r",
registered_message.get("event"),
registered_message.get("status"),
raw_registered[:200],
)
raise ValueError("broker did not accept registration")
self.session_id = session_id
hardware = build_hardware_summary()
capabilities = build_capabilities_document(
studio_endpoints=DEFAULT_STUDIO_ENDPOINTS,
hardware=hardware,
self.session_metadata = {
key: registered_message.get(key)
for key in (
"session_id",
"provider_id",
"client_id",
"username",
"scope_name",
"owner_user_id",
"expires_at",
)
register_message = build_register_message(
runtime=self.runtime,
request_id=str(uuid.uuid4()),
hardware=hardware,
capabilities=capabilities,
studio_endpoints=DEFAULT_STUDIO_ENDPOINTS,
}
_gpu_names = ", ".join(
g.get("name", "?") for g in (hardware.get("gpus") or [])
) or "none"
logger.info(
"CoderAI broker registered provider=%s client=%s session_id=%s scope=%s"
" gpu_count=%s gpus=[%s] total_vram_mb=%s available_vram_mb=%s",
self.session_metadata.get("provider_id"),
self.session_metadata.get("client_id"),
session_id,
self.session_metadata.get("scope_name"),
hardware.get("gpu_count"),
_gpu_names,
hardware.get("total_vram_mb"),
hardware.get("available_vram_mb"),
)
await self.websocket.send(json.dumps(register_message))
return register_message
def message_to_envelope(self, message: dict[str, Any]) -> BrokerRequestEnvelope:
payload = message.get("payload") or {}
op = message.get("op") or "proxy"
method = message.get("method") or "GET"
path = message.get("path") or ""
headers = message.get("headers", {})
query = message.get("query", {})
stream = message.get("stream", False)
content_type = message.get("content_type", "application/json")
if op == "proxy":
method = payload.get("method", method)
path = payload.get("endpoint_path", path)
headers = payload.get("headers", headers)
query = payload.get("query_params", query)
stream = payload.get("stream", stream)
content_type = payload.get("content_type", content_type)
elif op in OP_ROUTE_MAP:
method, path = OP_ROUTE_MAP[op]
return BrokerRequestEnvelope(
request_id=message["request_id"],
method=message["method"],
path=message["path"],
headers=message.get("headers", {}),
query=message.get("query", {}),
payload=message.get("payload"),
stream=message.get("stream", False),
content_type=message.get("content_type", "application/json"),
op=op,
method=str(method).upper(),
path=str(path),
headers=headers,
query=query,
payload=payload,
stream=bool(stream),
content_type=content_type,
)
def next_reconnect_delay(self, current_delay):
......@@ -85,21 +166,191 @@ class BrokerClient:
while True:
try:
await self.connect_and_register()
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
delay = self.runtime.reconnect_initial_delay_seconds
# Watchdog must outlast one full heartbeat cycle so a quiet connection
# (no inflight requests, only periodic heartbeats) never times out.
recv_timeout = max(
self.runtime.request_timeout_seconds,
self.runtime.heartbeat_interval_seconds * 3,
)
while True:
message = await asyncio.wait_for(
self.websocket.recv(),
timeout=self.runtime.request_timeout_seconds,
timeout=recv_timeout,
)
parsed = json.loads(message)
op = parsed.get("op")
event = parsed.get("event")
request_id = parsed.get("request_id")
logger.debug(
"broker ← recv op=%s event=%s request_id=%s bytes=%d",
op, event, request_id, len(message),
)
# Server-initiated heartbeat (op=heartbeat): respond in-line.
if op == "heartbeat":
logger.debug("broker heartbeat ping request_id=%s", request_id)
await self.handle_message(parsed)
continue
# Heartbeat reply from server (status=ok, event=heartbeat, no op):
# just discard — the client initiated it, no response needed.
if event == "heartbeat" and not op:
logger.debug("broker heartbeat pong discarded request_id=%s", request_id)
continue
# Re-register ACK (event=registered, no op): discard — sent in
# response to our notify_models_updated() re-register push.
if event == "registered" and not op:
logger.debug("broker re-register ack discarded request_id=%s", request_id)
continue
logger.debug(
"broker dispatching op=%s request_id=%s as async task (inflight=%d)",
op, request_id, len(self._inflight_tasks),
)
await self.handle_message(message)
task = asyncio.create_task(self.handle_message(parsed))
self._inflight_tasks.add(task)
task.add_done_callback(self._inflight_tasks.discard)
except asyncio.CancelledError:
await self._stop_heartbeat_task()
await self._cancel_inflight_tasks()
await self._close_websocket()
raise
except Exception:
except Exception as error:
logger.warning(
"CoderAI broker connection lost provider=%s client=%s url=%s reconnect_in=%ss error=%s",
self.runtime.provider_id,
self.runtime.client_id,
self.runtime.websocket_url,
delay,
error,
)
await self._stop_heartbeat_task()
await self._cancel_inflight_tasks()
await self._close_websocket()
await asyncio.sleep(delay)
delay = self.next_reconnect_delay(delay)
finally:
await self._stop_heartbeat_task()
async def _close_websocket(self):
websocket = self.websocket
self.websocket = None
if websocket is None:
return
try:
await websocket.close()
except Exception:
pass
async def _stop_heartbeat_task(self):
if self._heartbeat_task is None:
return
task = self._heartbeat_task
self._heartbeat_task = None
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _cancel_inflight_tasks(self):
if not self._inflight_tasks:
return
tasks = list(self._inflight_tasks)
self._inflight_tasks.clear()
for task in tasks:
task.cancel()
for task in tasks:
try:
await task
except asyncio.CancelledError:
pass
except Exception:
pass
async def _send_keepalives(
self, request_id: str, interval: float = 30.0, estimated_timeout: float = 300.0
):
"""Send periodic pending keepalive messages for a long-running request.
Cancelled by the caller when the request finishes."""
try:
while True:
await asyncio.sleep(interval)
if self.websocket is None:
break
try:
await self.websocket.send(json.dumps({
"v": 1,
"event": "pending",
"status": "pending",
"request_id": request_id,
"payload": {"message": "processing", "estimated_timeout": estimated_timeout},
}))
logger.debug("broker → pending keepalive request_id=%s", request_id)
except Exception:
break
except asyncio.CancelledError:
pass
async def _heartbeat_loop(self):
started_at = time.monotonic()
while True:
await asyncio.sleep(self.runtime.heartbeat_interval_seconds)
if self.websocket is None:
continue
uptime = int(time.monotonic() - started_at)
available_vram_mb = build_hardware_summary().get("available_vram_mb", 0)
heartbeat = {
"v": 1,
"op": "heartbeat",
"request_id": f"hb-{uuid.uuid4()}",
"payload": {
"uptime": uptime,
"hardware": {
"available_vram_mb": available_vram_mb,
},
},
}
logger.debug(
"broker → heartbeat uptime=%ds available_vram_mb=%s",
uptime, available_vram_mb,
)
await self.websocket.send(json.dumps(heartbeat))
async def notify_models_updated(self) -> bool:
"""Send a re-register to AISBF so it refreshes its model cache.
Returns True if the message was sent, False if not connected.
AISBF handles op=register on a live session by calling _broker_refresh_models(),
which re-fetches /v1/models from this coderai instance.
"""
if self.websocket is None:
return False
try:
hardware = build_hardware_summary()
capabilities = build_capabilities_document(
studio_endpoints=DEFAULT_STUDIO_ENDPOINTS,
hardware=hardware,
)
register_message = build_register_message(
runtime=self.runtime,
request_id=f"rereg-{uuid.uuid4().hex}",
hardware=hardware,
capabilities=capabilities,
studio_endpoints=DEFAULT_STUDIO_ENDPOINTS,
)
await self.websocket.send(json.dumps(register_message))
logger.info(
"broker → re-register (models updated) provider=%s client=%s",
self.runtime.provider_id,
self.runtime.client_id,
)
return True
except Exception as exc:
logger.warning("broker notify_models_updated failed: %s", exc)
return False
async def handle_message(self, raw_message):
message = json.loads(raw_message)
message = json.loads(raw_message) if isinstance(raw_message, str) else raw_message
if message.get("op") == "heartbeat":
response = success_envelope(
message["request_id"],
......@@ -110,8 +361,73 @@ class BrokerClient:
return response
if self.dispatcher is not None:
op = message.get("op")
request_id = message.get("request_id")
payload = message.get("payload") or {}
endpoint_path = payload.get("endpoint_path") or message.get("path") or OP_ROUTE_MAP.get(op, (None, None))[1]
method = payload.get("method") or message.get("method") or OP_ROUTE_MAP.get(op, ("GET", None))[0]
stream = bool(payload.get("stream") or message.get("stream"))
logger.info(
"CoderAI broker received request op=%s request_id=%s provider=%s client=%s endpoint=%s method=%s stream=%s",
op, request_id,
message.get("provider_id") or self.session_metadata.get("provider_id"),
message.get("client_id") or self.session_metadata.get("client_id"),
endpoint_path, method, stream,
)
# Send an immediate "pending" acknowledgment so the broker side extends
# its deadline, then keep sending periodic keepalives while the request
# runs (e.g. during model download). The keepalive task is cancelled
# as soon as we have a final response to send.
keepalive_task: asyncio.Task | None = None
if request_id:
try:
await self.websocket.send(json.dumps({
"v": 1,
"event": "pending",
"status": "pending",
"request_id": request_id,
"payload": {"message": "processing", "estimated_timeout": 300},
}))
except Exception:
pass
keepalive_task = asyncio.create_task(
self._send_keepalives(request_id, interval=30.0, estimated_timeout=300.0)
)
try:
response = await self.dispatcher(message)
await self.websocket.send(json.dumps(response))
except Exception as exc:
logger.error(
"CoderAI broker dispatcher error op=%s request_id=%s: %s",
op, request_id, exc, exc_info=True,
)
raise
finally:
if keepalive_task is not None:
keepalive_task.cancel()
try:
await keepalive_task
except asyncio.CancelledError:
pass
reply_bytes = json.dumps(response)
logger.info(
"CoderAI broker replied request_id=%s status=%s event=%s reply_bytes=%d",
response.get("request_id"),
response.get("status"),
response.get("event"),
len(reply_bytes),
)
resp_payload = response.get("payload") or {}
if isinstance(resp_payload, dict):
logger.debug(
"broker reply payload status_code=%s content_type=%s body_bytes=%s",
resp_payload.get("status_code"),
resp_payload.get("content_type"),
len((resp_payload.get("body") or "").encode()) if isinstance(resp_payload.get("body"), str) else (
len(resp_payload.get("body", b"")) if isinstance(resp_payload.get("body"), bytes) else "n/a"
),
)
await self.websocket.send(reply_bytes)
return response
return None
......@@ -21,12 +21,14 @@ class BrokerConfig:
client_id: str = ""
registration_token: str = ""
advertised_endpoint: str = ""
websocket_path: str = ""
transport: str = "websocket"
heartbeat_interval_seconds: int = 30
connect_timeout_seconds: int = 10
request_timeout_seconds: int = 30
reconnect_initial_delay_seconds: int = 1
reconnect_max_delay_seconds: int = 60
websocket_ping_interval: int = 20
@dataclass
......@@ -36,13 +38,28 @@ class BrokerRuntimeConfig:
enabled: bool
websocket_url: str = ""
headers: Dict[str, str] = field(default_factory=dict)
provider_id: str = ""
client_id: str = ""
username: str = ""
registration_token: str = ""
advertised_endpoint: str = ""
websocket_path: str = ""
transport: str = "websocket"
heartbeat_interval_seconds: int = 30
connect_timeout_seconds: int = 10
request_timeout_seconds: int = 30
reconnect_initial_delay_seconds: int = 1
reconnect_max_delay_seconds: int = 60
websocket_ping_interval: int = 20
def _join_broker_path(base_path: str, suffix: str) -> str:
normalized_base = (base_path or "").rstrip("/")
normalized_suffix = suffix if suffix.startswith("/") else f"/{suffix}"
if normalized_base.endswith("/api") and normalized_suffix.startswith("/api/"):
return f"{normalized_base}{normalized_suffix[4:]}"
return f"{normalized_base}{normalized_suffix}" if normalized_base else normalized_suffix
def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
......@@ -50,18 +67,29 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
runtime = BrokerRuntimeConfig(
enabled=config.enabled,
provider_id=config.provider_id,
client_id=config.client_id,
username=config.username,
registration_token=config.registration_token,
advertised_endpoint=config.advertised_endpoint,
websocket_path=config.websocket_path,
transport=config.transport,
heartbeat_interval_seconds=config.heartbeat_interval_seconds,
connect_timeout_seconds=config.connect_timeout_seconds,
request_timeout_seconds=config.request_timeout_seconds,
reconnect_initial_delay_seconds=config.reconnect_initial_delay_seconds,
reconnect_max_delay_seconds=config.reconnect_max_delay_seconds,
websocket_ping_interval=config.websocket_ping_interval,
)
if not config.enabled:
return runtime
if config.scope == "global":
custom_websocket_path = (config.websocket_path or "").strip()
if custom_websocket_path:
suffix = custom_websocket_path
if not suffix.startswith("/"):
suffix = f"/{suffix}"
elif config.scope == "global":
if config.username != "global":
raise BrokerConfigError("global broker scope requires username 'global'")
suffix = "/api/coderai/wss"
......@@ -88,7 +116,7 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
scheme = {"http": "ws", "https": "wss"}.get(split_url.scheme, split_url.scheme)
base_path = split_url.path.rstrip("/")
path = f"{base_path}{suffix}" if base_path else suffix
path = _join_broker_path(base_path, suffix)
query = urlencode(
{
"provider_id": config.provider_id,
......@@ -103,6 +131,7 @@ def build_broker_runtime_config(config: BrokerConfig) -> BrokerRuntimeConfig:
"x-coderai-provider-id": config.provider_id,
"x-coderai-client-id": config.client_id,
"x-coderai-username": config.username,
"x-coderai-registration-token": config.registration_token,
"x-coderai-advertised-endpoint": config.advertised_endpoint,
}
return runtime
......@@ -3,13 +3,17 @@
from __future__ import annotations
import json
import logging
from base64 import b64encode
from base64 import b64decode
from time import perf_counter
from typing import Any
from codai.broker.asgi_bridge import execute_internal_request
from codai.broker.models import error_envelope, success_envelope
logger = logging.getLogger(__name__)
SUPPORTED_PREFIXES = (
"/v1/models",
"/v1/chat/completions",
......@@ -17,9 +21,18 @@ SUPPORTED_PREFIXES = (
"/v1/audio",
"/v1/video",
"/v1/pipelines",
"/v1/files",
"/coderai/capabilities",
"/admin",
"/static",
)
OP_ROUTE_MAP = {
"models.list": ("GET", "/v1/models"),
"chat.completions": ("POST", "/v1/chat/completions"),
"capabilities": ("GET", "/coderai/capabilities"),
}
TEXT_CONTENT_TYPES = (
"application/json",
"application/ld+json",
......@@ -50,9 +63,46 @@ def _is_text_response(content_type: str | None) -> bool:
async def execute_broker_request(app, envelope):
"""Validate and execute a broker request envelope."""
logger.debug(
"broker dispatch → op=%s request_id=%s path=%r method=%r stream=%s",
envelope.op, envelope.request_id, envelope.path, envelope.method, envelope.stream,
)
if envelope.op == "proxy":
proxy_payload = envelope.payload or {}
endpoint_path = str(proxy_payload.get("endpoint_path") or envelope.path or "").strip()
if endpoint_path and not endpoint_path.startswith("/"):
endpoint_path = f"/{endpoint_path}"
envelope.path = endpoint_path
envelope.method = str(proxy_payload.get("method") or envelope.method or "GET").upper()
envelope.headers = dict(proxy_payload.get("headers") or envelope.headers)
envelope.query = dict(proxy_payload.get("query_params") or envelope.query)
envelope.stream = bool(proxy_payload.get("stream", envelope.stream))
if "body" in proxy_payload:
envelope.payload = proxy_payload.get("body")
elif proxy_payload.get("body_base64") is not None:
envelope.payload = b64decode(proxy_payload.get("body_base64") or "")
elif proxy_payload.get("multipart") is not None:
envelope.payload = {"_broker_multipart": proxy_payload.get("multipart")}
logger.debug("broker dispatch proxy resolved → %s %s", envelope.method, envelope.path)
elif envelope.op in OP_ROUTE_MAP:
envelope.method, envelope.path = OP_ROUTE_MAP[envelope.op]
logger.debug("broker dispatch op mapped → %s %s", envelope.method, envelope.path)
elif not envelope.path:
logger.warning("broker dispatch unsupported op=%s request_id=%s", envelope.op, envelope.request_id)
return error_envelope(
envelope.request_id,
code="unsupported_operation",
message=f"Unsupported broker op: {envelope.op}",
)
envelope.validate()
if not is_supported_path(envelope.path):
logger.warning(
"broker dispatch unsupported path=%r op=%s request_id=%s",
envelope.path, envelope.op, envelope.request_id,
)
return error_envelope(
envelope.request_id,
code="unsupported_endpoint",
......@@ -60,18 +110,28 @@ async def execute_broker_request(app, envelope):
)
body: bytes
if isinstance(envelope.payload, (dict, list)):
if isinstance(envelope.payload, dict) and "_broker_multipart" in envelope.payload:
from codai.broker.asgi_bridge import _build_multipart_body
body, multipart_content_type = _build_multipart_body(envelope.payload["_broker_multipart"] or {})
headers = dict(envelope.headers)
headers["content-type"] = multipart_content_type
elif isinstance(envelope.payload, (dict, list)):
body = json.dumps(envelope.payload, separators=(",", ":")).encode("utf-8")
headers = dict(envelope.headers)
elif isinstance(envelope.payload, str):
body = envelope.payload.encode("utf-8")
headers = dict(envelope.headers)
elif isinstance(envelope.payload, bytes):
body = envelope.payload
headers = dict(envelope.headers)
elif envelope.payload is None:
body = b""
headers = dict(envelope.headers)
else:
body = json.dumps(envelope.payload, separators=(",", ":")).encode("utf-8")
headers = dict(envelope.headers)
if body and "content-type" not in {key.lower() for key in headers}:
headers["content-type"] = envelope.content_type
......@@ -103,17 +163,36 @@ async def execute_broker_request(app, envelope):
payload["content_type"] = content_type
if _is_text_response(content_type):
payload["body"] = response["body"].decode("utf-8")
body_text = response["body"].decode("utf-8")
payload["body"] = body_text
logger.debug(
"broker dispatch ← op=%s request_id=%s status=%d elapsed_ms=%s body_bytes=%d body_preview=%r",
envelope.op, envelope.request_id,
response["status_code"], elapsed_ms,
len(response["body"]),
body_text[:300],
)
else:
payload["body_base64"] = b64encode(response["body"]).decode("ascii")
filename = response_headers.get("x-filename")
if filename:
payload["filename"] = filename
logger.debug(
"broker dispatch ← op=%s request_id=%s status=%d elapsed_ms=%s body_bytes=%d (binary)",
envelope.op, envelope.request_id,
response["status_code"], elapsed_ms,
len(response["body"]),
)
if envelope.stream:
payload["stream"] = True
return success_envelope(
result = success_envelope(
envelope.request_id,
payload=payload,
metrics={"elapsed_ms": elapsed_ms},
)
logger.debug(
"broker dispatch envelope_bytes=%d op=%s request_id=%s",
len(json.dumps(result)), envelope.op, envelope.request_id,
)
return result
......@@ -9,8 +9,9 @@ class BrokerRequestEnvelope:
"""Normalized broker request payload."""
request_id: str
method: str
path: str
op: str
method: str = "GET"
path: str = ""
headers: Dict[str, str] = field(default_factory=dict)
query: Dict[str, Any] = field(default_factory=dict)
payload: Any = None
......@@ -22,9 +23,13 @@ class BrokerRequestEnvelope:
if not self.request_id or not isinstance(self.request_id, str):
raise ValueError("request_id is required")
if not self.op or not isinstance(self.op, str):
raise ValueError("op is required")
if self.op != "proxy" and not self.path:
raise ValueError("path is required")
if not self.method or not isinstance(self.method, str):
raise ValueError("method is required")
if not self.path or not isinstance(self.path, str):
if self.path and not isinstance(self.path, str):
raise ValueError("path is required")
......@@ -32,8 +37,9 @@ def success_envelope(request_id: str, payload: Any, event: str | None = None, me
"""Build a success response envelope."""
envelope = {
"v": 1,
"request_id": request_id,
"ok": True,
"status": "ok",
"payload": payload,
}
if event is not None:
......@@ -53,7 +59,8 @@ def error_envelope(request_id: str, code: str, message: str, details: Dict[str,
if details is not None:
error["details"] = details
return {
"v": 1,
"request_id": request_id,
"ok": False,
"status": "error",
"error": error,
}
......@@ -18,14 +18,17 @@ class BrokerService:
self.client.dispatcher = dispatch
self.task: asyncio.Task | None = None
self._started = False
def start(self):
if not self.client.runtime.enabled or self.task is not None:
if not self.client.runtime.enabled or self.task is not None or self._started:
return
self._started = True
self.task = asyncio.create_task(self.client.run_forever())
async def stop(self):
if self.task is None:
self._started = False
return
task = self.task
self.task = None
......@@ -34,3 +37,5 @@ class BrokerService:
await task
except asyncio.CancelledError:
pass
finally:
self._started = False
......@@ -18,6 +18,7 @@
import sys
import os
import logging
import threading as _t
# Import configuration from codai modules
from codai.cli import parse_args
......@@ -30,6 +31,83 @@ from codai.broker import BrokerConfigError, build_broker_runtime_config
logger = logging.getLogger(__name__)
def _migrate_hf_gguf_to_gguf_cache() -> None:
"""Move GGUF files stored in the HF cache into the flat GGUF cache directory.
Runs once at startup in a background thread. For repos whose only
non-trivial content is GGUF files, the HF cache entry is removed after
the files are safely copied across.
"""
import shutil
from codai.models.cache import get_hf_hub_cache_dir, get_model_cache_dir
hf_dir = get_hf_hub_cache_dir()
if not os.path.exists(hf_dir):
return
gguf_cache = get_model_cache_dir()
try:
from huggingface_hub import scan_cache_dir
info = scan_cache_dir(hf_dir)
except Exception:
return
_TRIVIAL_EXTS = {'.json', '.txt', '.md', '.py', '.gitattributes', '.model', '.tiktoken', '.vocab'}
migrated_count = 0
repos_to_purge = [] # HF cache repo dirs safe to delete after migration
for repo in info.repos:
if not repo.revisions:
continue
latest_rev = sorted(repo.revisions, key=lambda r: r.commit_hash)[-1]
gguf_files = [f for f in latest_rev.files if f.file_name.endswith('.gguf')]
if not gguf_files:
continue
# Determine whether this repo contains ONLY gguf + trivial metadata
non_trivial = [
f for f in latest_rev.files
if os.path.splitext(f.file_name)[1].lower() not in _TRIVIAL_EXTS
]
gguf_only_repo = non_trivial and all(f.file_name.endswith('.gguf') for f in non_trivial)
all_migrated = True
for f in gguf_files:
dest = os.path.join(gguf_cache, os.path.basename(f.file_name))
if os.path.exists(dest):
continue # already present in GGUF cache
try:
src = os.path.realpath(str(f.file_path)) # resolve symlink → blob
shutil.copy2(src, dest)
migrated_count += 1
logger.info("Migrated GGUF: %s → %s", f.file_name, dest)
except Exception as exc:
logger.warning("Could not migrate %s: %s", f.file_name, exc)
all_migrated = False
if gguf_only_repo and all_migrated:
repo_dir = os.path.join(hf_dir, f"models--{repo.repo_id.replace('/', '--')}")
if os.path.isdir(repo_dir):
repos_to_purge.append((repo.repo_id, repo_dir))
for repo_id, repo_dir in repos_to_purge:
try:
shutil.rmtree(repo_dir)
logger.info("Removed migrated HF cache entry: %s", repo_id)
except Exception as exc:
logger.warning("Could not remove HF cache entry %s: %s", repo_id, exc)
if migrated_count or repos_to_purge:
logger.info(
"GGUF cache migration: %d file(s) moved to %s, %d HF cache entr%s cleaned up.",
migrated_count, gguf_cache,
len(repos_to_purge), "ies" if len(repos_to_purge) != 1 else "y",
)
def main():
"""Main entry point for the codai server."""
# Suppress unraisable exceptions from LlamaModel.__del__
......@@ -54,10 +132,22 @@ def main():
config_mgr = ConfigManager(config_dir)
config = config_mgr.load()
# Apply cache directory overrides from config before any cache module is used
# Apply cache directory overrides from config before any cache module is used.
# We set env vars AND patch huggingface_hub.constants in case the library was
# already imported (constants are computed once at import time from env vars).
if config.models.hf_cache_dir:
hf_hub_cache = os.path.join(config.models.hf_cache_dir, 'hub')
os.environ['HF_HOME'] = config.models.hf_cache_dir
os.environ['HUGGINGFACE_HUB_CACHE'] = config.models.hf_cache_dir
os.environ['HUGGINGFACE_HUB_CACHE'] = hf_hub_cache
try:
import sys as _sys
if 'huggingface_hub.constants' in _sys.modules:
import huggingface_hub.constants as _hfc
_hfc.HF_HUB_CACHE = hf_hub_cache
if hasattr(_hfc, 'HF_HOME'):
_hfc.HF_HOME = config.models.hf_cache_dir
except Exception:
pass
if config.models.gguf_cache_dir:
os.environ['CODERAI_CACHE_DIR'] = config.models.gguf_cache_dir
......@@ -75,7 +165,7 @@ def main():
# Initialize admin session manager and expose config to admin routes
from pathlib import Path
init_session_manager(Path(config_dir))
init_session_manager(Path(config_dir), port=config.server.port)
set_config_manager(config_mgr)
# Handle early exit options (before heavy imports)
......@@ -145,7 +235,8 @@ def main():
print("No GGUF files found, downloading full HuggingFace repo...")
try:
from huggingface_hub import snapshot_download
cached_path = snapshot_download(model_id)
from codai.models.cache import get_hf_hub_cache_dir
cached_path = snapshot_download(model_id, cache_dir=get_hf_hub_cache_dir())
except Exception as e:
print(f"Error downloading full repo: {e}")
cached_path = None
......@@ -175,6 +266,9 @@ def main():
print(f"Error listing devices: {e}")
sys.exit(0)
# Migrate any GGUF files that ended up in the HF cache to the GGUF cache
_t.Thread(target=_migrate_hf_gguf_to_gguf_cache, daemon=True).start()
# Import core modules (only after early exits)
from codai.api import app
from codai.api.state import (
......@@ -724,6 +818,22 @@ def main():
queue_manager.max_size = config.server.queue_max_size
queue_manager.max_parallel_requests = config.server.max_parallel_requests
# Configure Python logging so broker/API log calls reach the terminal.
# uvicorn is started with log_config=None to keep our config in place.
_log_level = logging.DEBUG if global_debug else logging.INFO
logging.basicConfig(
level=_log_level,
format="%(asctime)s [%(levelname)-8s] %(name)s: %(message)s",
stream=sys.stdout,
force=True,
)
# Suppress noisy third-party libraries at WARNING unless in debug mode.
for _noisy in ("httpx", "httpcore", "urllib3", "multipart", "PIL"):
logging.getLogger(_noisy).setLevel(logging.WARNING)
if not global_debug:
logging.getLogger("websockets").setLevel(logging.WARNING)
logging.getLogger("asyncio").setLevel(logging.WARNING)
# Start the server
import uvicorn
print(f"\nStarting server on http://{config.server.host}:{config.server.port}")
......@@ -736,6 +846,8 @@ def main():
actual_backend = "cuda (via llama-cpp-python)"
print(f"Using backend: {actual_backend}")
_uvi_log_level = "debug" if global_debug else "info"
if config.server.https:
import ssl
ssl_keyfile = config.server.https_key_path
......@@ -758,14 +870,17 @@ def main():
except Exception as e:
print(f"Warning: Could not generate certificate: {e}")
print("Falling back to HTTP...")
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
log_level=_uvi_log_level, log_config=None)
return
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(ssl_certfile, ssl_keyfile)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port, ssl_context=ssl_context)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
ssl_context=ssl_context, log_level=_uvi_log_level, log_config=None)
else:
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port)
uvicorn.run(fastapi_app, host=config.server.host, port=config.server.port,
log_level=_uvi_log_level, log_config=None)
if __name__ == "__main__":
......
......@@ -56,6 +56,26 @@ def get_model_cache_dir() -> str:
return cache_dir
def get_hf_hub_cache_dir() -> str:
"""Return the HuggingFace Hub cache directory CoderAI is configured to use.
Mirrors huggingface_hub's own env-var priority so that passing this value
as ``cache_dir`` to snapshot_download / hf_hub_download always targets the
same location the library would choose on its own — even if the directory
does not yet exist (first download).
"""
# Priority mirrors huggingface_hub.constants:
# HUGGINGFACE_HUB_CACHE (explicit cache path)
# HF_HOME/hub (parent-home style)
# ~/.cache/huggingface/hub (built-in default)
hf_hub_cache = (
os.environ.get('HUGGINGFACE_HUB_CACHE')
or (os.path.join(os.environ['HF_HOME'], 'hub') if 'HF_HOME' in os.environ else None)
or os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')
)
return hf_hub_cache
def get_all_cache_dirs() -> dict:
"""Get all model cache directories."""
caches = {}
......@@ -540,6 +560,7 @@ def remove_all_cached_models() -> int:
# Export all public functions
__all__ = [
'get_model_cache_dir',
'get_hf_hub_cache_dir',
'get_all_cache_dirs',
'get_cached_model_path',
'is_huggingface_model_id',
......
......@@ -110,12 +110,10 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'animatediff', 'text2video', 'modelscope-t2v',
'zeroscope', 'lavie']):
caps.video_generation = True
caps.text_generation = True # T2V models also do text
return caps
if any(x in n for x in ['wan2.1-t2v', 'wan-t2v']):
caps.video_generation = True
caps.text_generation = True
return caps
# Image-to-video
......@@ -124,17 +122,14 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'wan2.1-i2v', 'wan-i2v', 'img2vid',
'image2video', 'motionctrl']):
caps.image_to_video = True
caps.image_to_text = True # I2V models process images
return caps
# Wan generic (detect sub-variant)
if 'wan' in n and ('video' in n or 'diffuser' in n):
if 'i2v' in n:
caps.image_to_video = True
caps.image_to_text = True
else:
caps.video_generation = True
caps.text_generation = True
return caps
# Video interpolation
......@@ -158,7 +153,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
if any(x in n for x in ['musicgen', 'audiogen', 'audioldm', 'stable-audio',
'mustango', 'noise2music', 'jukebox', 'audiocraft']):
caps.audio_generation = True
caps.text_generation = True # T2A models process text
return caps
if any(x in n for x in ['demucs', 'spleeter', 'asteroid', 'open-unmix']):
......@@ -174,7 +168,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
if any(x in n for x in ['kokoro', 'xtts', 'bark', 'tortoise',
'speecht5', 'matcha-tts', 'voicebox']):
caps.text_to_speech = True
caps.text_generation = True # TTS models process text
return caps
# Lip sync / dubbing
......@@ -199,13 +192,11 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.inpainting = True
caps.image_generation = True
caps.image_to_image = True
caps.text_generation = True
return caps
if 'controlnet' in n:
caps.controlnet = True
caps.image_generation = True
caps.text_generation = True
return caps
if any(x in n for x in [
......@@ -235,7 +226,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.image_generation = True
caps.image_to_image = True
caps.inpainting = True # most SD/SDXL/Flux checkpoints support inpainting via mask
caps.text_generation = True
return caps
# ── Image: analysis / processing ─────────────────────────────────────────
......@@ -295,12 +285,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
'text-embedding', 'voyage-',
]):
caps.embeddings = True
caps.text_generation = True
return caps
# ── GGUF quantised text models ───────────────────────────────────────────
if '.gguf' in n or 'gguf' in n:
caps.text_generation = True
return caps
# Default: text generation
......@@ -315,17 +299,17 @@ _PIPELINE_TAG_CAPS: dict = {
'image-to-text': ['image_to_text', 'text_generation'],
'visual-question-answering': ['image_to_text', 'text_generation'],
'image-text-to-text': ['image_to_text', 'text_generation'],
'text-to-image': ['image_generation', 'image_to_image', 'text_generation'],
'text-to-image': ['image_generation', 'image_to_image'],
'unconditional-image-generation': ['image_generation'],
'image-to-image': ['image_to_image'], # sub-typed below
'automatic-speech-recognition': ['speech_to_text'],
'audio-to-audio': ['audio_to_audio'],
'text-to-speech': ['text_to_speech'],
'text-to-audio': ['audio_generation'],
'text-to-video': ['video_generation', 'text_generation'],
'text-to-video': ['video_generation'],
'image-to-video': ['image_to_video'],
'feature-extraction': ['embeddings', 'text_generation'],
'sentence-similarity': ['embeddings', 'text_generation'],
'feature-extraction': ['embeddings'],
'sentence-similarity': ['embeddings'],
'depth-estimation': ['depth_estimation', 'image_to_text'],
'image-segmentation': ['image_segmentation', 'image_to_text'],
'object-detection': ['object_detection', 'image_to_text'],
......
......@@ -1541,6 +1541,10 @@ class MultiModelManager:
2. Local HuggingFace hub cache scan.
3. HuggingFace API (network, one call per model per process lifetime).
For LoRA adapters (repos containing adapter_config.json) the size of
the base model is added so that VRAM requirements are not
underestimated.
Returns 0 on any failure.
"""
if model_id in MultiModelManager._hf_size_cache:
......@@ -1548,10 +1552,19 @@ class MultiModelManager:
weight_exts = {'.safetensors', '.bin', '.gguf', '.ggml', '.pt'}
def _resolve_base_model(base_model_id: str) -> int:
from codai.models.cache import is_huggingface_model_id
if not base_model_id or base_model_id == model_id:
return 0
if not is_huggingface_model_id(base_model_id):
return 0
return MultiModelManager._hf_cached_model_size_bytes(base_model_id)
# --- Try local HF hub cache first (no network) ---
try:
from huggingface_hub import scan_cache_dir
from codai.models.cache import get_all_cache_dirs
import json as _json
hf_dir = get_all_cache_dirs().get("huggingface")
if hf_dir:
info = scan_cache_dir(hf_dir)
......@@ -1560,11 +1573,26 @@ class MultiModelManager:
continue
revs = sorted(repo.revisions, key=lambda r: r.last_modified, reverse=True)
if revs:
rev = revs[0]
total = sum(
f.size_on_disk
for f in revs[0].files
for f in rev.files
if os.path.splitext(f.file_name)[1].lower() in weight_exts
)
# LoRA adapter: add base model size
for f in rev.files:
if f.file_name == "adapter_config.json":
try:
with open(f.file_path) as fp:
adapter_cfg = _json.load(fp)
base_id = (
adapter_cfg.get("base_model_name_or_path")
or adapter_cfg.get("base_model")
)
total += _resolve_base_model(base_id)
except Exception:
pass
break
if total > 0:
MultiModelManager._hf_size_cache[model_id] = total
return total
......@@ -1580,13 +1608,31 @@ class MultiModelManager:
with urllib.request.urlopen(req, timeout=10) as resp:
data = _json.loads(resp.read())
total = 0
has_adapter_config = False
for sib in data.get("siblings", []):
name = sib.get("rfilename", "")
if name == "adapter_config.json":
has_adapter_config = True
continue
if os.path.splitext(name)[1].lower() not in weight_exts:
continue
lfs = sib.get("lfs") or {}
size = lfs.get("size") or sib.get("size") or 0
total += size
# LoRA adapter: fetch adapter_config.json to get the base model
if has_adapter_config:
try:
cfg_url = f"https://huggingface.co/{model_id}/resolve/main/adapter_config.json"
cfg_req = urllib.request.Request(cfg_url, headers={"User-Agent": "coderai/1.0"})
with urllib.request.urlopen(cfg_req, timeout=10) as resp:
adapter_cfg = _json.loads(resp.read())
base_id = (
adapter_cfg.get("base_model_name_or_path")
or adapter_cfg.get("base_model")
)
total += _resolve_base_model(base_id)
except Exception:
pass
if total > 0:
MultiModelManager._hf_size_cache[model_id] = total
return total
......@@ -2101,9 +2147,12 @@ class MultiModelManager:
needed_gb = self._get_model_used_vram_gb(model_key, resolved_name)
free_gb = self._get_free_vram_gb()
if needed_gb > 0 and free_gb >= needed_gb:
# Require headroom beyond raw weight size for activation buffers
# and generation scratch (30% of model size + 1 GB base).
headroom_gb = max(1.0, needed_gb * 0.30)
if needed_gb > 0 and free_gb >= needed_gb + headroom_gb:
print(f"Ondemand mode - keeping '{loaded_canonical}' in VRAM alongside new model "
f"(need {needed_gb:.1f} GB, have {free_gb:.1f} GB free)")
f"(need {needed_gb:.1f} GB + {headroom_gb:.1f} GB headroom, have {free_gb:.1f} GB free)")
else:
print(f"Ondemand mode - model switch detected:")
print(f" Requested: '{model_key}' (resolved: '{resolved_name}')")
......
......@@ -89,6 +89,22 @@ The outbound WebSocket connection must include:
- `username`: either `global` or the AISBF username for user-owned providers
- `registration_token`: provider-scoped secret from AISBF provider configuration
### Current server-side resolution order
AISBF resolves broker identity in this exact order when the WebSocket handshake arrives:
- `provider_id`: query param `provider_id`, then header `x-coderai-provider-id`, then default `coderai`
- `client_id`: query param `client_id`, then header `x-coderai-client-id`, then generated fallback `anon-<unix_timestamp>`
- `username`: query param `username`, then header `x-coderai-username`, then the path scope name (`global` or the `/api/u/{username}` path segment)
- `registration_token`: query param `registration_token`, then header `x-coderai-registration-token`
Important constraints:
- the `registration_token` is required for admission
- `Authorization: Bearer ...` is currently not used by the broker WebSocket admission check
- if you omit `client_id`, AISBF generates an `anon-*` client id and future broker routing will only work if AISBF also targets that exact generated value
- the `client_id` used by the CoderAI client must match the `coderai_config.client_id` used by the AISBF provider, or the broker can show the session as connected while requests still fail to route
## Optional Headers
AISBF also accepts or may expect these headers:
......@@ -109,6 +125,35 @@ Recommended behavior:
Open the outbound WebSocket to the correct scoped AISBF endpoint.
The handshake is a normal WebSocket upgrade request, which starts as an HTTP `GET` carrying query parameters. This is expected.
Recommended connect template:
```text
wss://<aisbf-host>/<optional-prefix>/api/coderai/wss?provider_id=<provider_id>&client_id=<stable_client_id>&username=global&registration_token=<provider_registration_token>
```
User-scoped template:
```text
wss://<aisbf-host>/<optional-prefix>/api/u/<username>/coderai/wss?provider_id=<provider_id>&client_id=<stable_client_id>&username=<username>&registration_token=<provider_registration_token>
```
Recommended handshake headers:
```text
x-coderai-provider-id: <provider_id>
x-coderai-client-id: <stable_client_id>
x-coderai-username: <username>
x-coderai-registration-token: <provider_registration_token>
```
Best practice:
- send the same identity in both query parameters and headers
- keep `client_id` stable across reconnects
- always reconnect with the same provider scope and owner scope
### 2. Wait for `registered` event
AISBF immediately sends a registration acknowledgment event on successful admission.
......@@ -135,11 +180,21 @@ Store:
- `client_id`
- `username`
- `scope_name`
- `owner_user_id`
- `expires_at`
Notes:
- this event means the socket is admitted and the session row exists
- it does not yet mean hardware/capabilities metadata has been uploaded
- the client should send the explicit `register` operation immediately after this event
### 3. Send explicit `register` operation
After the `registered` event, CoderAI must send a `register` message describing its capabilities, hardware inventory, and advertised endpoints.
AISBF currently processes `register` as a normal inbound WebSocket message and responds with `status=ok` using the same `request_id`.
### 4. Enter long-lived receive loop
Then keep listening for incoming broker requests from AISBF.
......@@ -233,6 +288,60 @@ CoderAI should send this after receiving the initial AISBF `registered` event.
AISBF replies with a success envelope.
### Fields AISBF currently reads from the `register` message
Top-level:
- `v`
- `op` with value `register`
- `request_id`
- optional top-level `registration_token`
- optional top-level `capabilities`
From `payload`:
- `endpoint`
- `transport`
- `registration_token`
- `studio_endpoints`
- `hardware`
- `gpus`
- `gpu_count`
- `total_vram_mb`
- `available_vram_mb`
- `capabilities`
AISBF behavior:
- if `payload.registration_token` or top-level `registration_token` is present and does not match the handshake token, AISBF replies with an error envelope
- if token matches, AISBF persists the metadata onto the broker session
- `payload.capabilities` takes precedence over missing top-level capability data
- if `gpus`, `gpu_count`, `total_vram_mb`, or `available_vram_mb` are omitted at the top level, AISBF falls back to the values inside `payload.hardware`
Minimal acceptable `register` message:
```json
{
"v": 1,
"op": "register",
"request_id": "reg-1",
"payload": {
"transport": "websocket",
"registration_token": "<same_registration_token>",
"capabilities": {}
}
}
```
Recommended full `register` message:
- include `endpoint`
- include `transport`
- include `registration_token`
- include `hardware.gpus`, `hardware.gpu_count`, `hardware.total_vram_mb`, `hardware.available_vram_mb`
- include `studio_endpoints`
- include `capabilities`
### Hardware Reporting Requirements
The `register` payload should include the best hardware view available to the running CoderAI process.
......@@ -326,6 +435,37 @@ Heartbeat payloads may also refresh dynamic hardware state such as changing free
}
```
Current AISBF note:
- AISBF acknowledges heartbeat messages and merges the heartbeat `payload` into session metadata
- keep heartbeat payloads small and non-blocking
- use heartbeats for lightweight dynamic updates only; do not block the main receive loop on expensive hardware rescans
## Async Client Requirements
The broker WebSocket integration must be fully asynchronous.
CoderAI client requirements:
- the main receive loop must never block on model loading, inference, GPU inspection, or disk/network I/O
- expensive work should run in background tasks or worker executors while the socket remains responsive to incoming frames and ping/pong traffic
- the client should be able to receive broker requests while also sending progress or result frames for earlier requests
- the client must not serialize all work behind registration or heartbeat handling
AISBF broker behavior:
- AISBF now drains queued outbound broker requests in a background async task while independently reading inbound websocket messages
- this means the CoderAI client should expect inbound requests to arrive even while it is still sending heartbeat or response messages for unrelated work
- operations are correlated strictly by `request_id`; client implementations must not rely on message ordering alone
Recommended client architecture:
1. one async reader task for inbound WebSocket frames
2. one async writer path or send queue for outbound replies/events
3. per-request async tasks for local execution
4. a lightweight periodic heartbeat task
5. explicit request correlation by `request_id`
AISBF merges those updates into the broker session metadata.
## Local HTTP Endpoints CoderAI Should Expose
......
# CoderAI Broker Implementation Reference
## Purpose
This document is the single source of truth for implementing the CoderAI side of the AISBF broker and bridge integration.
The target audience is another LLM or engineer implementing CoderAI, not AISBF.
This document is mirrored in `docs/coderai-broker-implementation-reference.md` and should be kept identical in purpose and protocol coverage.
## AISBF broker mode
AISBF now includes a public broker-side WebSocket endpoint for outbound-only NAT traversal.
- Broker WebSocket endpoint: `/api/coderai/broker/ws`
- Broker WebSocket endpoints:
- global scope: `/api/coderai/wss`
- user scope: `/api/u/{username}/coderai/wss`
- Broker session status endpoint: `/api/coderai/broker/providers/{provider_id}/status`
- Broker session listing endpoint: `/api/coderai/broker/sessions`
Each CoderAI provider is owned either by:
- the global config admin (`user_id = null`), or
- a specific AISBF user (`user_id = <id>`)
Registration tokens are resolved from the owning provider configuration. This means:
- the global admin configures the token for globally configured `coderai` providers
- each user configures the token for their own user-scoped `coderai` providers
- a broker session is only usable by requests belonging to the same owner principal
Broker registration is now scope-aware:
- global providers register with `username=global`
- user-owned providers register with `username=<aisbf_username>`
- the same scoped path must be used by the CoderAI client when connecting over WebSocket
- deployments behind TLS termination or reverse proxies must connect with the externally visible `wss://...` URL and preserve proxy headers so AISBF can remain scheme-aware
The AISBF dashboard now exposes this token directly inside each `coderai` provider configuration:
- token input is stored in `coderai_config.registration_token`
- global admins edit global provider tokens in the admin providers page
- users edit their own provider tokens in the user providers page
- token rotation is available inline and returns a newly generated provider-scoped secret
- broker session status is shown directly in the provider editor, including owner, client id, transport, last seen, and advertised Studio endpoints
CoderAI can keep a persistent outbound connection open to AISBF, register itself, and then receive routed provider operations over that same socket.
## What AISBF now expects
### Provider type
Use provider type:
```json
{
"type": "coderai"
}
```
### Provider config shape
```json
{
"id": "coderai",
"name": "CoderAI Local Bridge",
"endpoint": "http://127.0.0.1:11437",
"type": "coderai",
"api_key_required": false,
"coderai_config": {
"transport": "http",
"http_enabled": true,
"websocket_enabled": true,
"broker_enabled": true,
"broker_mode": false,
"broker_preferred": true,
"discovery_enabled": true,
"client_id": "aisbf-default",
"bridge_path": "/coderai/ws",
"registration_path": "/coderai/register",
"registration_token": "optional-shared-secret",
"bridge_token": "optional-bridge-secret",
"request_timeout": 300,
"model_timeout": 30
}
}
```
### AISBF behaviors
- For `transport=http`, AISBF uses the OpenAI Python client against `endpoint + /v1`.
- For `transport=websocket`, AISBF uses a WebSocket bridge and sends framed JSON envelopes.
- AISBF uses `models.list`, `chat.completions`, `capabilities`, `register`, and `proxy` bridge operations.
- `proxy` now supports arbitrary forwarded request headers, query params, multipart form payloads, binary/base64 bodies, progress polling endpoints, and non-chat streaming event envelopes for long-running jobs.
- AISBF treats `coderai` like an OpenAI-style Studio adapter family.
- AISBF can also forward arbitrary Studio-native endpoints through `proxy` when the provider transport is WebSocket.
- AISBF validates that broker-enabled `coderai` providers have a non-empty `registration_token`.
- AISBF persists broker session metadata to `~/.aisbf/coderai_broker_sessions.json` so the dashboard can still show the last known broker session after restart, even while disconnected.
## Required CoderAI HTTP endpoints
### 1. OpenAI-compatible endpoints
CoderAI should already expose these when HTTP mode is enabled:
- `GET /v1/models`
- `POST /v1/chat/completions`
- optional additional OpenAI-compatible endpoints that Studio may use directly via generic proxy
The `/v1/models` response should preferably include as much metadata as possible:
```json
{
"data": [
{
"id": "llama3.1:8b",
"name": "llama3.1:8b",
"description": "Local general-purpose chat model",
"context_length": 131072,
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["text"]
},
"supported_parameters": ["temperature", "top_p", "max_tokens"],
"default_parameters": {
"temperature": 0.7
},
"pricing": null,
"studio_capabilities": ["chat", "tool_use", "code_generation"]
}
]
}
```
### 2. Capabilities endpoint
Expose:
- `GET /coderai/capabilities`
Recommended response:
```json
{
"server": {
"name": "coderai",
"version": "0.1.0"
},
"transports": {
"http": true,
"websocket": true
},
"openai_compat": {
"chat_completions": true,
"models": true,
"responses": false,
"embeddings": true,
"images": true,
"audio": true
},
"studio": {
"enabled": true,
"endpoints": [
"v1/images/generate",
"v1/audio/tts",
"v1/audio/transcriptions",
"v1/video/dub"
]
},
"models": [
{
"id": "llama3.1:8b",
"studio_capabilities": ["chat", "tool_use", "code_generation"]
}
]
}
```
## Required WebSocket bridge
### Connection URL
CoderAI should accept WebSocket clients on:
- `/coderai/ws`
or another configured path mirrored in `coderai_config.bridge_path`.
### Headers AISBF sends
- `Authorization: Bearer <bridge_token_or_registration_token_or_api_key>` if available
- `x-coderai-client-id: <client_id>`
- `x-coderai-provider-id: <provider_id>`
### Broker connection query params
When CoderAI dials AISBF broker directly, it should connect using:
- `provider_id=<provider_id>`
- `client_id=<client_id>`
- `username=<username-or-global>`
- `registration_token=<owner-configured-token>`
AISBF broker admission currently resolves connection data in this order:
- `provider_id`: query param, then `x-coderai-provider-id`, then default `coderai`
- `client_id`: query param, then `x-coderai-client-id`, then generated fallback `anon-<timestamp>`
- `username`: query param, then `x-coderai-username`, then the scoped path value
- `registration_token`: query param, then `x-coderai-registration-token`
Important:
- the WebSocket broker flow starts as an HTTP `GET` upgrade request; this is expected
- the broker currently validates `registration_token`, not `Authorization: Bearer ...`, during WebSocket admission
- `client_id` must stay stable and match the `coderai_config.client_id` AISBF uses for that provider, or requests like `models.list` may fail even if the dashboard shows the session as connected
Example:
```text
wss://your-aisbf.example/api/coderai/wss?provider_id=coderai&client_id=workstation-01&username=global&registration_token=<owner-configured-token>
```
User-scoped example:
```text
wss://your-aisbf.example/api/u/alice/coderai/wss?provider_id=my-coderai&client_id=workstation-01&username=alice&registration_token=<owner-configured-token>
```
Recommended duplicate headers for robustness:
```text
x-coderai-provider-id: <provider_id>
x-coderai-client-id: <client_id>
x-coderai-username: <username>
x-coderai-registration-token: <registration_token>
```
### Broker registration flow
1. Open the WebSocket to the correct scoped AISBF broker URL.
2. Wait for AISBF to send an initial `registered` event.
3. Immediately send `op="register"` with hardware, capability, and endpoint metadata.
4. Wait for the `status="ok"` registration ack with the same `request_id`.
5. Enter the long-lived async broker loop for heartbeats, inbound requests, and outbound responses.
Initial server event example:
```json
{
"v": 1,
"event": "registered",
"session_id": "coderai_abc123",
"provider_id": "coderai",
"client_id": "workstation-01",
"username": "global",
"scope_name": "global",
"accepted": true,
"owner_user_id": null,
"expires_at": 1747179540
}
```
Required client follow-up:
```json
{
"v": 1,
"op": "register",
"request_id": "reg-1",
"payload": {
"endpoint": "wss://client-or-descriptive-local-endpoint",
"transport": "websocket",
"registration_token": "<same_registration_token>",
"hardware": {
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0
},
"studio_endpoints": [],
"capabilities": {}
}
}
```
AISBF reads these `register` fields today:
- top-level: `v`, `op`, `request_id`, optional `registration_token`, optional `capabilities`
- from `payload`: `endpoint`, `transport`, `registration_token`, `studio_endpoints`, `hardware`, `gpus`, `gpu_count`, `total_vram_mb`, `available_vram_mb`, `capabilities`
If `payload.registration_token` or top-level `registration_token` is present and does not match the handshake token, AISBF replies with an error envelope.
### Envelope format
AISBF sends one JSON request envelope per operation:
```json
{
"v": 1,
"op": "chat.completions",
"request_id": "coderai-1746960000000",
"provider_id": "coderai",
"client_id": "aisbf-default",
"registration_token": "optional-shared-secret",
"payload": {
"model": "llama3.1:8b",
"messages": [
{"role": "user", "content": "hello"}
],
"stream": false
}
}
```
### Non-streaming response envelope
```json
{
"v": 1,
"request_id": "coderai-1746960000000",
"status": "ok",
"payload": {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1746960000,
"model": "llama3.1:8b",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "hello"},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
}
}
```
### Error response envelope
```json
{
"v": 1,
"request_id": "coderai-1746960000000",
"status": "error",
"error": "Model not available",
"code": "model_not_found",
"details": {
"model": "missing-model"
}
}
```
### Streaming response envelopes
For `chat.completions` with `stream=true`, send multiple envelopes.
Each chunk envelope:
```json
{
"v": 1,
"request_id": "coderai-1746960000000",
"status": "ok",
"event": "chunk",
"payload": {
"chunk": "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"delta\":{\"content\":\"hel\"},\"index\":0,\"finish_reason\":null}]}\n\n"
}
}
```
Final envelope:
```json
{
"v": 1,
"request_id": "coderai-1746960000000",
"status": "ok",
"event": "done",
"payload": {}
}
```
Important:
- `payload.chunk` should be a full SSE fragment already formatted exactly as AISBF should relay it.
- This keeps AISBF transport-simple and lets CoderAI own protocol correctness.
- Include `data: [DONE]\n\n` as one of the streamed chunks when the upstream semantics require it.
## Async broker behavior requirements
The broker client must be fully asynchronous.
- do not block the WebSocket receive loop on hardware probing, model loading, inference, file I/O, or reconnect bookkeeping
- handle inbound requests concurrently and correlate replies by `request_id`
- keep heartbeat handling lightweight
- preserve responsiveness to ping/pong traffic while local work is running
AISBF now independently drains queued outbound broker requests while also reading inbound WebSocket messages, so client implementations should not assume a strict request/response lockstep over the socket.
## Broker session visibility, persistence, and multi-node routing
AISBF now tracks two broker states:
- live connected sessions held in memory for active request routing
- persisted session metadata snapshots stored in `~/.aisbf/coderai_broker_sessions.json`
Persisted metadata is dashboard-facing only. It is used to show the last known session details after restart, but it is not treated as an active transport path until CoderAI reconnects.
For multi-node AISBF deployments behind a reverse proxy / load balancer:
- session status and ownership metadata are stored in the configured AISBF cache backend
- requests are enqueued into cache-backed broker queues keyed by broker session id
- the AISBF node holding the live WebSocket consumes queued requests and forwards them to CoderAI
- replies are written back through cache-backed reply keys so the AISBF node that originated the request can receive the result
Redis is the preferred backend for this distributed mode. SQLite/MySQL can operate as polling-based fallbacks. Memory/file cache backends are not suitable for cross-node broker routing.
Expected behavior:
- after reconnect, the persisted snapshot is refreshed with the new live session details
- after disconnect or AISBF restart, the dashboard may still show the last known client id / endpoint / last seen, but `connected` remains false until a new WebSocket is established
## Bridge operations CoderAI must implement
### `op = "models.list"`
Request:
```json
{
"v": 1,
"op": "models.list",
"request_id": "...",
"provider_id": "coderai",
"client_id": "aisbf-default",
"payload": {}
}
```
Response payload should be equivalent to `GET /v1/models`.
### `op = "chat.completions"`
Payload is equivalent to OpenAI `POST /v1/chat/completions` request body.
### `op = "capabilities"`
Response payload should be equivalent to `GET /coderai/capabilities`.
### `op = "register"`
Purpose:
- allow an outbound-only CoderAI agent to announce itself
- report its reachable transports
- report enabled Studio-native endpoints
- report model inventory
- attach metadata to the live AISBF broker session
Request payload from AISBF:
```json
{
"provider_id": "coderai",
"client_id": "aisbf-default",
"transport": "websocket",
"endpoint": "wss://broker.example/coderai/ws"
}
```
Recommended response payload:
```json
{
"accepted": true,
"client_id": "aisbf-default",
"session_id": "sess_123",
"expires_at": 1746963600,
"transports": {
"http": false,
"websocket": true
},
"models": [
{"id": "llama3.1:8b", "studio_capabilities": ["chat", "tool_use"]}
],
"studio_endpoints": [
"v1/video/dub",
"v1/audio/tts"
]
}
```
### `op = "proxy"`
Purpose:
- tunnel arbitrary Studio-native endpoint requests over WebSocket when AISBF cannot directly reach CoderAI over HTTP.
Request payload:
```json
{
"endpoint_path": "v1/video/dub",
"method": "POST",
"headers": {
"x-request-id": "studio-job-123",
"accept": "text/event-stream"
},
"query_params": {
"job_id": "dub_123"
},
"body": {
"model": "local-video-model",
"input": "Dub this clip to Italian"
},
"multipart": {
"fields": [{"name": "model", "value": "whisper-large"}],
"files": [{"name": "file", "filename": "sample.wav", "content_type": "audio/wav", "data_base64": "<base64>"}]
},
"stream": true
}
```
Response payload:
```json
{
"status_code": 200,
"headers": {
"content-type": "application/json"
},
"body": {
"job_id": "dub_123",
"status": "queued"
}
}
```
Binary response payloads may instead use:
```json
{
"status_code": 200,
"content_type": "audio/mpeg",
"body_base64": "<base64>",
"headers": {
"content-disposition": "attachment; filename=preview.mp3"
}
}
```
Streaming and progress responses may emit multiple envelopes with `event` values like `progress`, `output`, `log`, `data`, `chunk`, and finally `done` or `completed`.
Recommended progress chunk payload:
```json
{
"v": 1,
"request_id": "coderai-1746960000000",
"status": "ok",
"event": "progress",
"payload": {
"chunk": "event: progress\ndata: {\"active\":true,\"current\":5,\"total\":20,\"pct\":25,\"elapsed\":12}\n\n"
}
}
```
Capability advertisements should include endpoint metadata for custom pipelines, including supported methods, streaming mode, expected input/output modalities, and whether multipart or binary transport is required.
## Recommended CoderAI architecture
### Server components
1. **OpenAI compatibility router**
- exposes `/v1/models`, `/v1/chat/completions`, and any other supported OpenAI endpoints
2. **Studio-native router**
- exposes endpoints such as `v1/video/dub`, `v1/audio/tts`, `v1/images/generate`, etc.
3. **Capabilities registry**
- enumerates enabled endpoints
- enumerates loaded models
- computes normalized `studio_capabilities`
4. **WebSocket bridge server**
- accepts AISBF envelopes
- dispatches by `op`
- for `proxy`, internally calls the same handler used by HTTP routes
- for `chat.completions`, either:
- returns a full JSON result, or
- emits `chunk` envelopes carrying ready-made SSE fragments
5. **Optional outbound broker client**
- when behind NAT, CoderAI can establish and maintain an outbound WebSocket connection to an AISBF-reachable broker endpoint
- that broker can multiplex messages by `client_id`
## NAT-friendly model
There are two viable patterns.
### Pattern A: AISBF directly opens WebSocket to CoderAI
- simplest
- works when CoderAI is reachable by `ws://` or `wss://`
- no NAT punching support
### Pattern B: CoderAI dials outward and stays connected
- best for NAT/private LAN
- CoderAI opens a persistent outbound WebSocket to a public AISBF-side broker
- broker stores the live session keyed by `client_id`
- AISBF routes provider operations to that session
If you implement Pattern B, keep the same envelope contract. Only the connection initiator changes.
### Recommended outbound broker flow
1. CoderAI opens persistent WebSocket to AISBF broker endpoint.
2. AISBF immediately acknowledges with `event=registered` and a `session_id`.
3. CoderAI sends `op=register` with endpoint, transports, capabilities, models, and Studio endpoints.
4. AISBF stores that live session under `provider_id + client_id`.
5. All AISBF provider operations can now be delivered to that live outbound socket.
6. If the socket drops, AISBF marks the session offline and fails in-flight requests.
## Strong recommendations for metadata
For every model, provide:
- `id`
- `name`
- `description`
- `context_length`
- `architecture.input_modalities`
- `architecture.output_modalities`
- `supported_parameters`
- `default_parameters`
- `studio_capabilities`
For server capabilities, provide:
- transport availability
- OpenAI-compatible endpoint availability
- Studio-native endpoint availability
- current server version
- optional hardware metadata (`gpu`, `memory_gb`, `quantization`, `throughput_hint`)
## Minimal Python implementation sketch for CoderAI
```python
from fastapi import FastAPI, WebSocket
from fastapi.responses import JSONResponse
import json
app = FastAPI()
@app.get("/v1/models")
async def list_models():
return {
"data": [
{
"id": "llama3.1:8b",
"name": "llama3.1:8b",
"context_length": 131072,
"studio_capabilities": ["chat", "tool_use", "code_generation"],
}
]
}
@app.get("/coderai/capabilities")
async def capabilities():
return {
"server": {"name": "coderai", "version": "0.1.0"},
"transports": {"http": True, "websocket": True},
"openai_compat": {"chat_completions": True, "models": True},
"studio": {"enabled": True, "endpoints": ["v1/video/dub"]},
}
@app.websocket("/coderai/ws")
async def coderai_ws(ws: WebSocket):
await ws.accept()
while True:
message = await ws.receive_text()
envelope = json.loads(message)
op = envelope["op"]
request_id = envelope["request_id"]
if op == "models.list":
await ws.send_text(json.dumps({
"v": 1,
"request_id": request_id,
"status": "ok",
"payload": await list_models(),
}))
elif op == "capabilities":
await ws.send_text(json.dumps({
"v": 1,
"request_id": request_id,
"status": "ok",
"payload": await capabilities(),
}))
else:
await ws.send_text(json.dumps({
"v": 1,
"request_id": request_id,
"status": "error",
"error": f"Unsupported op: {op}",
}))
```
## Implementation checklist for the CoderAI-side LLM session
- add `/coderai/capabilities`
- add `/coderai/register`
- add `/coderai/ws`
- expose model metadata with `studio_capabilities`
- support `models.list`
- support `chat.completions`
- support streaming `chunk` and `done` events
- support `proxy` for Studio-native endpoints
- optionally support persistent outbound broker mode for NAT traversal
- protect bridge/register endpoints with a shared secret or signed token
## Compatibility notes
- AISBF currently assumes WebSocket streamed chunks arrive already formatted as SSE fragments.
- AISBF currently expects WebSocket non-streaming responses to carry the raw OpenAI-compatible response under `payload`.
- AISBF can consume either direct HTTP OpenAI compatibility or the WebSocket bridge for chat/model listing.
- AISBF generic Studio proxy now uses the provider bridge for `coderai`, making NAT traversal possible for non-chat endpoints too.
......@@ -84,8 +84,13 @@ def test_build_broker_runtime_config_global_scope_builds_url_and_headers():
"x-coderai-provider-id": "provider-1",
"x-coderai-client-id": "client-1",
"x-coderai-username": "global",
"x-coderai-registration-token": "token-123",
"x-coderai-advertised-endpoint": "https://server.example.com",
}
assert runtime.provider_id == "provider-1"
assert runtime.client_id == "client-1"
assert runtime.username == "global"
assert runtime.registration_token == "token-123"
assert runtime.transport == "websocket"
assert runtime.heartbeat_interval_seconds == 30
assert runtime.connect_timeout_seconds == 10
......@@ -94,6 +99,38 @@ def test_build_broker_runtime_config_global_scope_builds_url_and_headers():
assert runtime.reconnect_max_delay_seconds == 60
def test_build_broker_runtime_config_global_scope_uses_global_service_paths():
runtime = build_broker_runtime_config(
BrokerConfig(
enabled=True,
base_url="https://broker.example.com/base",
scope="global",
username="global",
provider_id="provider-1",
client_id="client-1",
registration_token="token-123",
)
)
assert runtime.websocket_url.startswith("wss://broker.example.com/base/api/coderai/wss")
def test_build_broker_runtime_config_user_scope_uses_user_service_paths():
runtime = build_broker_runtime_config(
BrokerConfig(
enabled=True,
base_url="https://broker.example.com/base",
scope="user",
username="alice",
provider_id="provider-1",
client_id="client-1",
registration_token="token-123",
)
)
assert runtime.websocket_url.startswith("wss://broker.example.com/base/api/u/alice/coderai/wss")
def test_build_broker_runtime_config_rejects_invalid_global_username():
try:
build_broker_runtime_config(
......@@ -139,6 +176,7 @@ def test_build_broker_runtime_config_user_scope_uses_user_path():
"x-coderai-provider-id": "provider-1",
"x-coderai-client-id": "client-1",
"x-coderai-username": "alice",
"x-coderai-registration-token": "token-123",
"x-coderai-advertised-endpoint": "https://server.example.com/alice",
}
......@@ -165,6 +203,28 @@ def test_build_broker_runtime_config_preserves_base_url_prefix_in_websocket_url(
)
def test_build_broker_runtime_config_does_not_duplicate_api_prefix_when_base_url_already_ends_with_api():
runtime = build_broker_runtime_config(
BrokerConfig(
enabled=True,
base_url="https://aisbf.cloud/api",
scope="global",
username="global",
provider_id="provider-1",
client_id="client-1",
registration_token="token-123",
)
)
assert runtime.websocket_url == (
"wss://aisbf.cloud/api/coderai/wss"
"?provider_id=provider-1"
"&client_id=client-1"
"&username=global"
"&registration_token=token-123"
)
def test_build_broker_runtime_config_encodes_reserved_username_path_characters():
runtime = build_broker_runtime_config(
BrokerConfig(
......@@ -187,6 +247,48 @@ def test_build_broker_runtime_config_encodes_reserved_username_path_characters()
)
def test_build_broker_runtime_config_uses_manual_websocket_path_override():
runtime = build_broker_runtime_config(
BrokerConfig(
enabled=True,
base_url="https://broker.example.com/prefix",
scope="global",
username="global",
provider_id="provider-1",
client_id="client-1",
registration_token="token-123",
websocket_path="/custom/broker/socket",
)
)
assert runtime.websocket_path == "/custom/broker/socket"
assert runtime.websocket_url == (
"wss://broker.example.com/prefix/custom/broker/socket"
"?provider_id=provider-1"
"&client_id=client-1"
"&username=global"
"&registration_token=token-123"
)
def test_build_broker_runtime_config_normalizes_manual_websocket_path_override_without_leading_slash():
runtime = build_broker_runtime_config(
BrokerConfig(
enabled=True,
base_url="https://broker.example.com",
scope="user",
username="alice",
provider_id="provider-1",
client_id="client-1",
registration_token="token-123",
websocket_path="broker/ws",
)
)
assert runtime.websocket_path == "broker/ws"
assert runtime.websocket_url.startswith("wss://broker.example.com/broker/ws")
def test_build_broker_runtime_config_rejects_invalid_user_scope_username():
try:
build_broker_runtime_config(
......@@ -294,11 +396,17 @@ def test_build_register_message_includes_capabilities_and_hardware():
"v": 1,
"op": "register",
"request_id": "req-1",
"registration_token": "token-123",
"capabilities": capabilities,
"payload": {
"endpoint": "https://server.example.com/alice",
"transport": "websocket",
"registration_token": "token-123",
"hardware": {"gpu": True, "memory_gb": 24},
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0,
"studio_endpoints": EXPECTED_STUDIO_ENDPOINTS,
"capabilities": capabilities,
},
......@@ -318,17 +426,65 @@ def test_build_register_message_defaults_token_and_studio_endpoints_for_empty_ru
"v": 1,
"op": "register",
"request_id": "req-1",
"registration_token": "",
"capabilities": {"server": "codai"},
"payload": {
"endpoint": "",
"transport": "websocket",
"registration_token": "",
"hardware": None,
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0,
"studio_endpoints": DEFAULT_STUDIO_ENDPOINTS,
"capabilities": {"server": "codai"},
},
}
def test_build_hardware_summary_reports_torch_cuda_vram(monkeypatch):
class FakeProps:
total_memory = 24 * 1024 * 1024 * 1024
class FakeCuda:
@staticmethod
def is_available():
return True
@staticmethod
def device_count():
return 1
@staticmethod
def current_device():
return 0
@staticmethod
def get_device_properties(index):
return FakeProps()
@staticmethod
def get_device_name(index):
return "RTX Test"
@staticmethod
def mem_get_info():
return (10 * 1024 * 1024 * 1024, 24 * 1024 * 1024 * 1024)
class FakeTorch:
cuda = FakeCuda()
monkeypatch.setitem(sys.modules, "torch", FakeTorch())
hardware = build_hardware_summary()
assert hardware["gpu_count"] == 1
assert hardware["total_vram_mb"] == 24576
assert hardware["available_vram_mb"] == 10240
assert hardware["gpus"] == [{"index": 0, "name": "RTX Test", "total_vram_mb": 24576}]
def test_build_capabilities_document_lists_openai_and_studio_support():
document = build_capabilities_document(hardware={"gpu": True})
......
......@@ -6,6 +6,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import pytest
from fastapi import FastAPI
from fastapi import File, Form, UploadFile
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from starlette.responses import Response
......@@ -87,6 +88,7 @@ async def test_execute_broker_request_returns_success_envelope_for_json_route():
envelope = BrokerRequestEnvelope(
request_id="req-123",
op="chat.completions",
method="POST",
path="/v1/chat/completions",
headers={"accept": "application/json"},
......@@ -96,7 +98,7 @@ async def test_execute_broker_request_returns_success_envelope_for_json_route():
response = await execute_broker_request(app, envelope)
assert response["request_id"] == "req-123"
assert response["ok"] is True
assert response["status"] == "ok"
assert response["payload"] == {
"status_code": 201,
"headers": {
......@@ -125,6 +127,7 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
envelope = BrokerRequestEnvelope(
request_id="req-binary",
op="proxy",
method="GET",
path="/v1/images/render",
)
......@@ -132,7 +135,7 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
response = await execute_broker_request(app, envelope)
assert response["request_id"] == "req-binary"
assert response["ok"] is True
assert response["status"] == "ok"
assert response["payload"] == {
"status_code": 200,
"headers": {
......@@ -148,12 +151,75 @@ async def test_execute_broker_request_preserves_binary_payload_metadata():
assert response["metrics"]["elapsed_ms"] >= 0
@pytest.mark.anyio("asyncio")
async def test_execute_broker_request_maps_proxy_operation_to_internal_route():
app = FastAPI()
@app.post("/v1/video/dub")
async def dub_route(payload: dict):
return {"received": payload, "route": "dub"}
envelope = BrokerRequestEnvelope(
request_id="req-proxy-op",
op="proxy",
payload={
"endpoint_path": "v1/video/dub",
"method": "POST",
"headers": {"content-type": "application/json"},
"body": {"prompt": "hello"},
},
)
response = await execute_broker_request(app, envelope)
assert response["status"] == "ok"
assert response["payload"]["status_code"] == 200
assert response["payload"]["body"] == '{"received":{"prompt":"hello"},"route":"dub"}'
@pytest.mark.anyio("asyncio")
async def test_execute_broker_request_supports_proxy_multipart_payloads():
app = FastAPI()
@app.post("/v1/audio/transcriptions")
async def transcription_route(model: str = Form(...), file: UploadFile = File(...)):
data = await file.read()
return {"model": model, "filename": file.filename, "size": len(data)}
envelope = BrokerRequestEnvelope(
request_id="req-multipart",
op="proxy",
payload={
"endpoint_path": "v1/audio/transcriptions",
"method": "POST",
"multipart": {
"fields": [{"name": "model", "value": "whisper-large"}],
"files": [
{
"name": "file",
"filename": "sample.wav",
"content_type": "audio/wav",
"data_base64": "aGVsbG8=",
}
],
},
},
)
response = await execute_broker_request(app, envelope)
assert response["status"] == "ok"
assert response["payload"]["status_code"] == 200
assert response["payload"]["body"] == '{"model":"whisper-large","filename":"sample.wav","size":5}'
@pytest.mark.anyio("asyncio")
async def test_brokered_models_match_direct_http_response_shape():
direct_response = TestClient(real_app).get("/v1/models")
envelope = BrokerRequestEnvelope(
request_id="req-models-shape",
op="models.list",
method="GET",
path="/v1/models",
headers={"accept": "application/json"},
......@@ -163,7 +229,7 @@ async def test_brokered_models_match_direct_http_response_shape():
brokered_body = json.loads(brokered_response["payload"]["body"])
direct_body = direct_response.json()
assert brokered_response["ok"] is True
assert brokered_response["status"] == "ok"
assert brokered_response["payload"]["status_code"] == direct_response.status_code
assert brokered_response["payload"]["content_type"] == direct_response.headers["content-type"]
assert brokered_response["payload"]["headers"]["content-type"] == direct_response.headers["content-type"]
......@@ -192,6 +258,7 @@ async def test_execute_broker_request_rejects_unsupported_endpoint():
app = FastAPI()
envelope = BrokerRequestEnvelope(
request_id="req-unsupported",
op="proxy",
method="GET",
path="/internal",
)
......@@ -199,8 +266,9 @@ async def test_execute_broker_request_rejects_unsupported_endpoint():
response = await execute_broker_request(app, envelope)
assert response == {
"v": 1,
"request_id": "req-unsupported",
"ok": False,
"status": "error",
"error": {
"code": "unsupported_endpoint",
"message": "Unsupported endpoint: /internal",
......
......@@ -28,6 +28,7 @@ class FakeWebSocket:
def __init__(self, messages):
self._messages = list(messages)
self.sent_messages = []
self.closed = False
async def recv(self):
return self._messages.pop(0)
......@@ -35,6 +36,10 @@ class FakeWebSocket:
async def send(self, message):
self.sent_messages.append(message)
async def close(self):
self.closed = True
return None
@pytest.mark.anyio("asyncio")
async def test_broker_client_waits_for_registered_before_register():
......@@ -46,18 +51,31 @@ async def test_broker_client_waits_for_registered_before_register():
"accepted": True,
"session_id": "session-123",
}
)
),
json.dumps({"request_id": "session-123", "status": "ok"}),
]
)
runtime = BrokerRuntimeConfig(
enabled=True,
websocket_url="wss://broker.example/ws",
headers={"Authorization": "Bearer token"},
registration_token="token",
advertised_endpoint="http://localhost:8000",
)
with patch("codai.broker.client.websockets.connect", new=AsyncMock(return_value=websocket)) as connect_mock:
client = BrokerClient(runtime)
with patch(
"codai.broker.client.build_hardware_summary",
return_value={
"hostname": "test-host",
"platform": "linux",
"gpus": [],
"gpu_count": 0,
"total_vram_mb": 0,
"available_vram_mb": 0,
},
):
await client.connect_and_register()
connect_mock.assert_awaited_once_with(
......@@ -71,7 +89,11 @@ async def test_broker_client_waits_for_registered_before_register():
register_message = json.loads(websocket.sent_messages[0])
assert register_message["op"] == "register"
assert register_message["request_id"] == "session-123"
assert register_message["capabilities"] == register_message["payload"]["capabilities"]
assert register_message["payload"]["registration_token"] == "token"
assert register_message["payload"]["hardware"]["gpu_count"] == 0
assert register_message["payload"]["gpus"] == []
assert register_message["payload"]["capabilities"]["transports"] == ["websocket"]
assert register_message["payload"]["studio_endpoints"] == DEFAULT_STUDIO_ENDPOINTS
......@@ -98,7 +120,10 @@ async def test_broker_client_rejects_registered_ack_without_session_id():
@pytest.mark.anyio("asyncio")
async def test_broker_client_passes_connect_timeout_to_websocket_connection():
websocket = FakeWebSocket(
[json.dumps({"event": "registered", "accepted": True, "session_id": "session-123"})]
[
json.dumps({"event": "registered", "accepted": True, "session_id": "session-123"}),
json.dumps({"request_id": "session-123", "status": "ok"}),
]
)
runtime = BrokerRuntimeConfig(
enabled=True,
......@@ -121,7 +146,10 @@ async def test_broker_client_passes_connect_timeout_to_websocket_connection():
@pytest.mark.anyio("asyncio")
async def test_broker_client_applies_request_timeout_to_socket_reads():
websocket = FakeWebSocket(
[json.dumps({"event": "registered", "accepted": True, "session_id": "session-123"})]
[
json.dumps({"event": "registered", "accepted": True, "session_id": "session-123"}),
json.dumps({"request_id": "session-123", "status": "ok"}),
]
)
runtime = BrokerRuntimeConfig(enabled=True, request_timeout_seconds=23)
timeout_calls = []
......@@ -136,7 +164,107 @@ async def test_broker_client_applies_request_timeout_to_socket_reads():
client = BrokerClient(runtime)
await client.connect_and_register()
assert timeout_calls == [runtime.request_timeout_seconds]
assert timeout_calls == [runtime.request_timeout_seconds, runtime.request_timeout_seconds]
@pytest.mark.anyio("asyncio")
async def test_broker_client_rejects_failed_register_ack():
websocket = FakeWebSocket(
[
json.dumps({"event": "registered", "accepted": True, "session_id": "session-123"}),
json.dumps({"request_id": "session-123", "status": "error"}),
]
)
runtime = BrokerRuntimeConfig(enabled=True, headers={"Authorization": "Bearer token"})
with patch("codai.broker.client.websockets.connect", new=AsyncMock(return_value=websocket)):
client = BrokerClient(runtime)
with pytest.raises(ValueError, match="broker did not acknowledge register message"):
await client.connect_and_register()
@pytest.mark.anyio("asyncio")
async def test_broker_client_stores_registered_session_metadata():
websocket = FakeWebSocket(
[
json.dumps(
{
"event": "registered",
"accepted": True,
"session_id": "session-123",
"provider_id": "provider-1",
"client_id": "client-1",
"username": "alice",
"scope_name": "alice",
"owner_user_id": 42,
"expires_at": "2026-05-13T00:00:00Z",
}
),
json.dumps({"request_id": "session-123", "status": "ok"}),
]
)
runtime = BrokerRuntimeConfig(enabled=True, headers={"Authorization": "Bearer token"})
with patch("codai.broker.client.websockets.connect", new=AsyncMock(return_value=websocket)):
client = BrokerClient(runtime)
await client.connect_and_register()
assert client.session_metadata == {
"session_id": "session-123",
"provider_id": "provider-1",
"client_id": "client-1",
"username": "alice",
"scope_name": "alice",
"owner_user_id": 42,
"expires_at": "2026-05-13T00:00:00Z",
}
@pytest.mark.anyio("asyncio")
async def test_broker_client_run_forever_closes_registered_socket_before_reconnect():
runtime = BrokerRuntimeConfig(
enabled=True,
reconnect_initial_delay_seconds=1,
reconnect_max_delay_seconds=4,
)
client = BrokerClient(runtime)
attempts = []
sockets = []
reconnected = asyncio.Event()
class DisconnectingWebSocket(FakeWebSocket):
def __init__(self):
super().__init__([])
async def recv(self):
raise RuntimeError("disconnect")
async def fake_connect_and_register():
attempts.append("connect")
websocket = DisconnectingWebSocket()
sockets.append(websocket)
client.websocket = websocket
if len(attempts) >= 2:
reconnected.set()
real_sleep = asyncio.sleep
async def fake_sleep(delay):
await real_sleep(0)
client.connect_and_register = AsyncMock(side_effect=fake_connect_and_register)
with patch("codai.broker.client.asyncio.sleep", new=AsyncMock(side_effect=fake_sleep)):
task = asyncio.create_task(client.run_forever())
await asyncio.wait_for(reconnected.wait(), timeout=0.2)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
assert attempts[:2] == ["connect", "connect"]
assert sockets[0].closed is True
@pytest.mark.anyio("asyncio")
......@@ -158,8 +286,9 @@ async def test_broker_client_replies_to_heartbeat():
)
assert response == {
"v": 1,
"request_id": "req-heartbeat",
"ok": True,
"status": "ok",
"event": "heartbeat",
"payload": {"ts": 1715443200},
}
......@@ -174,7 +303,7 @@ async def test_broker_client_dispatches_non_heartbeat_messages():
dispatcher = AsyncMock(
return_value={
"request_id": "req-dispatch",
"ok": True,
"status": "ok",
"payload": {"status": "handled"},
}
)
......@@ -193,7 +322,7 @@ async def test_broker_client_dispatches_non_heartbeat_messages():
dispatcher.assert_awaited_once_with(json.loads(raw_message))
assert response == {
"request_id": "req-dispatch",
"ok": True,
"status": "ok",
"payload": {"status": "handled"},
}
assert websocket.sent_messages == [json.dumps(response)]
......@@ -214,11 +343,8 @@ async def test_broker_client_dispatches_request_messages_through_fastapi_app():
service = BrokerService(client, app)
client.websocket = FakeWebSocket([])
message = {
"op": "request",
"op": "chat.completions",
"request_id": "req-app",
"method": "POST",
"path": "/v1/chat/completions",
"headers": {"accept": "application/json"},
"payload": {"message": "hello"},
}
......@@ -227,9 +353,10 @@ async def test_broker_client_dispatches_request_messages_through_fastapi_app():
assert envelope == BrokerRequestEnvelope(
request_id="req-app",
op="chat.completions",
method="POST",
path="/v1/chat/completions",
headers={"accept": "application/json"},
headers={},
query={},
payload={"message": "hello"},
stream=False,
......@@ -237,7 +364,7 @@ async def test_broker_client_dispatches_request_messages_through_fastapi_app():
)
expected_response = await execute_broker_request(app, envelope)
assert response["request_id"] == "req-app"
assert response["ok"] is True
assert response["status"] == "ok"
assert response["payload"] == expected_response["payload"]
assert response["payload"]["status_code"] == 200
assert response["payload"]["content_type"] == "application/json"
......@@ -245,6 +372,62 @@ async def test_broker_client_dispatches_request_messages_through_fastapi_app():
assert client.websocket.sent_messages == [json.dumps(response)]
@pytest.mark.anyio("asyncio")
async def test_broker_client_maps_proxy_op_to_endpoint_envelope():
client = BrokerClient(BrokerRuntimeConfig(enabled=True))
message = {
"op": "proxy",
"request_id": "req-proxy",
"payload": {
"endpoint_path": "v1/video/dub",
"method": "POST",
"headers": {"accept": "application/json"},
"query_params": {"job_id": "123"},
"body": {"prompt": "hello"},
"stream": True,
},
}
envelope = client.message_to_envelope(message)
assert envelope == BrokerRequestEnvelope(
request_id="req-proxy",
op="proxy",
method="POST",
path="v1/video/dub",
headers={"accept": "application/json"},
query={"job_id": "123"},
payload=message["payload"],
stream=True,
content_type="application/json",
)
@pytest.mark.anyio("asyncio")
async def test_broker_client_maps_models_list_operation_to_internal_route():
client = BrokerClient(BrokerRuntimeConfig(enabled=True))
envelope = client.message_to_envelope(
{
"op": "models.list",
"request_id": "req-models",
"payload": {},
}
)
assert envelope == BrokerRequestEnvelope(
request_id="req-models",
op="models.list",
method="GET",
path="/v1/models",
headers={},
query={},
payload={},
stream=False,
content_type="application/json",
)
@pytest.mark.anyio("asyncio")
async def test_broker_client_next_reconnect_delay_caps_at_max():
runtime = BrokerRuntimeConfig(
......@@ -335,6 +518,65 @@ async def test_broker_client_run_forever_reconnects_after_disconnect():
assert sleep_calls[0] == 1
@pytest.mark.anyio("asyncio")
async def test_broker_client_run_forever_processes_requests_concurrently():
runtime = BrokerRuntimeConfig(enabled=True, request_timeout_seconds=1)
client = BrokerClient(runtime)
first_started = asyncio.Event()
release_first = asyncio.Event()
second_finished = asyncio.Event()
class MessageWebSocket:
def __init__(self):
self.sent_messages = []
self._messages = [
json.dumps({"op": "chat.completions", "request_id": "req-1", "payload": {"message": "one"}}),
json.dumps({"op": "capabilities", "request_id": "req-2", "payload": {}}),
]
async def recv(self):
if self._messages:
return self._messages.pop(0)
await asyncio.sleep(0)
raise RuntimeError("disconnect")
async def send(self, message):
self.sent_messages.append(json.loads(message))
async def close(self):
return None
websocket = MessageWebSocket()
async def fake_connect_and_register():
client.websocket = websocket
async def dispatcher(message):
if message["request_id"] == "req-1":
first_started.set()
await release_first.wait()
return {"v": 1, "request_id": "req-1", "status": "ok", "payload": {"done": 1}}
second_finished.set()
return {"v": 1, "request_id": "req-2", "status": "ok", "payload": {"done": 2}}
client.connect_and_register = AsyncMock(side_effect=fake_connect_and_register)
client.dispatcher = dispatcher
task = asyncio.create_task(client.run_forever())
await asyncio.wait_for(first_started.wait(), timeout=0.2)
await asyncio.wait_for(second_finished.wait(), timeout=0.2)
release_first.set()
await asyncio.sleep(0.05)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
sent_ids = [message["request_id"] for message in websocket.sent_messages]
assert "req-2" in sent_ids
@pytest.mark.anyio("asyncio")
async def test_broker_client_run_forever_handles_heartbeat_before_reconnect():
runtime = BrokerRuntimeConfig(
......@@ -394,8 +636,9 @@ async def test_broker_client_run_forever_handles_heartbeat_before_reconnect():
assert attempts[:2] == ["connect", "connect"]
assert sleep_calls[0] == 1
assert json.loads(sockets[0].sent_messages[0]) == {
"v": 1,
"request_id": "req-heartbeat",
"ok": True,
"status": "ok",
"event": "heartbeat",
"payload": {"ts": 1715443200},
}
......@@ -429,6 +672,32 @@ async def test_broker_service_start_and_stop_manage_background_task():
assert service.task is None
@pytest.mark.anyio("asyncio")
async def test_broker_service_start_is_idempotent_while_running():
runtime = BrokerRuntimeConfig(enabled=True)
client = BrokerClient(runtime)
started = asyncio.Event()
async def fake_run_forever():
started.set()
await asyncio.Future()
client.run_forever = AsyncMock(side_effect=fake_run_forever)
service = BrokerService(client)
service.start()
service.start()
await started.wait()
task = service.task
assert task is not None
await service.stop()
client.run_forever.assert_awaited_once_with()
@pytest.mark.anyio("asyncio")
async def test_fastapi_lifespan_stops_broker_service_before_model_cleanup():
from fastapi import FastAPI
......
......@@ -28,8 +28,9 @@ async def test_stream_chunk_envelope_preserves_request_id_and_order():
)
assert chunk_one == {
"v": 1,
"request_id": "req-stream",
"ok": True,
"status": "ok",
"event": "stream",
"payload": {
"sequence": 0,
......@@ -37,8 +38,9 @@ async def test_stream_chunk_envelope_preserves_request_id_and_order():
},
}
assert chunk_two == {
"v": 1,
"request_id": "req-stream",
"ok": True,
"status": "ok",
"event": "stream",
"payload": {
"sequence": 1,
......@@ -56,8 +58,9 @@ async def test_finalize_stream_attaches_metrics():
)
assert response == {
"v": 1,
"request_id": "req-stream",
"ok": True,
"status": "ok",
"event": "stream_end",
"payload": {
"total_chunks": 2,
......@@ -82,6 +85,7 @@ async def test_execute_broker_request_wraps_streaming_response_metadata():
envelope = BrokerRequestEnvelope(
request_id="req-stream-route",
op="proxy",
method="GET",
path="/v1/chat/completions",
stream=True,
......@@ -89,9 +93,9 @@ async def test_execute_broker_request_wraps_streaming_response_metadata():
response = await execute_broker_request(app, envelope)
assert set(response) == {"request_id", "ok", "payload", "metrics"}
assert set(response) == {"v", "request_id", "status", "payload", "metrics"}
assert response["request_id"] == "req-stream-route"
assert response["ok"] is True
assert response["status"] == "ok"
assert response["payload"] == {
"status_code": 200,
"headers": {
......@@ -127,6 +131,7 @@ async def test_brokered_streaming_route_preserves_event_stream_body():
envelope = BrokerRequestEnvelope(
request_id="req-stream-equivalence",
op="proxy",
method="GET",
path="/v1/chat/completions",
headers={"accept": "text/event-stream"},
......@@ -135,7 +140,7 @@ async def test_brokered_streaming_route_preserves_event_stream_body():
brokered_response = await execute_broker_request(app, envelope)
assert brokered_response["ok"] is True
assert brokered_response["status"] == "ok"
assert brokered_response["payload"]["status_code"] == direct_status_code
assert brokered_response["payload"]["content_type"] == direct_content_type
assert brokered_response["payload"]["headers"]["content-type"] == direct_content_type
......
......@@ -849,7 +849,10 @@ def test_settings_template_includes_broker_controls():
assert "AISBF Broker" in template
assert "s-broker-enabled" in template
assert "s-broker-base-url" in template
assert 'id="s-broker-scope" class="form-input" onchange="toggleBrokerFields()"' in template
assert "s-broker-provider-id" in template
assert "s-broker-client-id" in template
assert "s-broker-registration-token" in template
assert "s-broker-websocket-path" in template
assert "toggleBrokerFields()" in template
assert "forced to `global` for global scope" in template
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment