Add task management, quantization, and hardware telemetry

Tasks / queue management:
- Central in-memory task registry with cooperative cancel, pause/resume,
  and step progress across image/video/audio/text generation + LoRA training
- Tasks admin page (live 2s poll): cancel, interrupt, pause/resume, restart,
  remove; done jobs auto-drop from the list; bounded persisted job history
- Disable interrupted-training recovery via --no-resume-jobs + settings toggle

Quantization / acceleration:
- TurboQuant embedding vector quantization (data-free, inner-product
  preserving): built-in NumPy backend + optional turboquant-py library,
  selectable per embedding model; /v1/embeddings `quantization` param
- llama.cpp KV-cache quantization (cache_type_k/v) for GGUF text models,
  configurable in the Models UI

Hardware telemetry:
- Thermal cooldown state surfaced on the Tasks page (banner + per-task badge)
- Live CPU/GPU/RAM/VRAM usage + temperature panel via /admin/api/system-stats

Docs: API documentation gaps/accuracy pass + Swagger overhaul; DISTRIBUTION.md
implementation spec. Plus I2V LoRA training channel-mismatch fix.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 9494d1bd
@@ -4,6 +4,11 @@ This document describes the full HTTP API exposed by CoderAI, including OpenAI-c
The API is implemented with FastAPI in `codai/api/app.py` and routers under `codai/api/`, with admin routes under `codai/admin/routes.py`.
Interactive, always-up-to-date OpenAPI docs are served by the running server at
`/docs` (Swagger UI, linked as **API Docs** in the admin nav) and `/redoc`, with the
raw schema at `/openapi.json`. Every endpoint there carries a tag, summary, and
per-field descriptions generated from the code.
## Base URL
Default local server:
@@ -377,9 +382,9 @@ Request fields:
| `response_format` | string | `url` | `url` or `b64_json` |
| `seed` | integer/null | random | Deterministic seed |
| `negative_prompt` | string/null | `null` | Negative prompt |
-| `disable_safety_checker` | boolean | `false` | Disable safety checker where supported |
+| `disable_safety_checker` | boolean | `false` | Null the diffusers safety checker (only affects SD 1.x/2.x; SDXL/Flux ship none) |
| `vae_model` | string/null | `null` | Per-request VAE override |
-| `loras` | array/null | `null` | LoRA adapters `{model, weight, name}` |
+| `loras` | array/null | `null` | LoRA adapters — see [LoRA references](#lora-references-in-requests) for all supported fields |
| `character_profiles` | string[]/null | `null` | Saved character profile names |
| `character_references` | string[]/null | `null` | Inline reference images |
| `character_strength` | number | `0.6` | IP-Adapter/reference strength |
@@ -403,14 +408,15 @@ curl -s "$CODERAI_URL/v1/images/generations" \
}' | jq
```
-LoRA example:
+LoRA example (the weights can also be sent inline or by registry id — see
[LoRA references](#lora-references-in-requests)):
```json
{
"model": "image-model",
"prompt": "portrait of <character-token> as a space pilot",
"loras": [
-{"model": "/home/me/loras/space_uniform.safetensors", "weight": 0.8, "name": "uniform"}
+{"id": "name:space_uniform", "weight": 0.8, "name": "uniform"}
]
}
```
@@ -603,14 +609,15 @@ Primary fields:
| `num_inference_steps` | integer/null | model default | Diffusion steps |
| `guidance_scale` | number/null | model default | CFG/guidance |
| `seed` | integer/null | random | Seed |
-| `mode` | string | `t2v` | `t2v`, `i2v`, `v2v`, `ti2v`, `interp` |
+| `mode` | string | `t2v` | `t2v`, `i2v` (init_image, prompt dropped), `ti2v` (init_image + prompt), `v2v`, `interp`. The server gracefully falls back between Wan t2v/i2v pipelines when a model supports only one. |
| `image` / `init_image` | string/null | `null` | Initial/reference frame |
| `end_image` | string/null | `null` | End frame for interpolation |
| `video` | string/null | `null` | Input video for v2v/post-processing |
| `strength` | number/null | `null` | Denoising strength |
| `camera_motion` | string/null | `null` | `zoom-in`, `pan-left`, etc. |
| `character_profiles` | string[]/null | `null` | Saved character profiles |
-| `loras` | array/null | `null` | Video LoRA adapters |
+| `loras` | array/null | `null` | Video LoRA adapters — see [LoRA references](#lora-references-in-requests) |
| `disable_safety_checker` | boolean | `false` | Null the diffusers safety checker (no effect on models without one, e.g. Wan) |
| `response_format` | string | `url` | `url` or `b64_mp4` |
Text-to-video example:
@@ -1046,6 +1053,7 @@ Request fields:
| `image` | string/string[]/null | `null` | Optional image input(s) for multimodal embeddings |
| `encoding_format` | string | `float` | `float` or `base64` |
| `dimensions` | integer/null | `null` | Optional truncation size |
| `quantization` | string/null | `null` | TurboQuant vector quantization: `turbo`/`turbo8`/`turbo6`/`turbo4`/`turbo2` |
Example:
@@ -1085,6 +1093,41 @@ Multimodal embedding example:
}
```
### TurboQuant embedding quantization
TurboQuant ([arXiv:2504.19874](https://arxiv.org/abs/2504.19874)) is a data-free,
inner-product-preserving vector quantizer: it randomly rotates each embedding so its
coordinates concentrate, then applies a per-coordinate scalar quantizer. Quantized
vectors keep their dot products / cosine similarity, so they can be stored 3–12×
smaller in a vector DB. Set `quantization` to enable it:
- `turbo`/`turbo8` = 8-bit (near-lossless, ~3×), `turbo6`, `turbo4` (~6×), `turbo2` (~12×).
- With `encoding_format: "float"` (default) the response returns the **lossy
reconstructed** float vectors (same shape) — drop-in, behaves like a quantized store.
- With `encoding_format: "base64"` each `embedding` is the **compact packed bytes**
(`[float16 norm][packbits(b-bit rotated codes)]`), and the response carries a
top-level `quantization` block (`bits`, `seed`, `dim`, `dim_padded`, `radius`,
`bytes_per_vector`, `layout`) describing how to decode them.
The implementation backend is chosen per embedding model in the admin **Models**
config (TurboQuant section): `builtin` (NumPy, always available) or `library`
(the optional `turboquant-py[torch]` package, which adds the paper's QJL stage).
TurboQuant must be enabled for the model, or a request `quantization` is rejected
with HTTP 400. Selecting the `library` backend when the package is not installed
also returns HTTP 400 rather than silently degrading.
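The rotate-then-quantize idea described above can be illustrated with a minimal NumPy sketch. It is not CoderAI's `builtin` backend and not the quantizer from the paper: the uniform scalar quantizer, the 4-sigma clip used for `radius`, and the function names `random_rotation`, `quantize`, and `dequantize` are all assumptions made for illustration.

```python
import numpy as np


def random_rotation(dim: int, seed: int) -> np.ndarray:
    """Return a random orthogonal matrix (QR of a Gaussian matrix)."""
    rng = np.random.default_rng(seed)
    q, r = np.linalg.qr(rng.standard_normal((dim, dim)))
    # Flip column signs so the result depends only on the seed.
    return q * np.sign(np.diag(r))


def quantize(x: np.ndarray, rot: np.ndarray, bits: int):
    """Rotate a vector, then quantize each coordinate to `bits` (<= 8) bits."""
    norm = float(np.linalg.norm(x))
    y = rot @ (x / norm)                 # unit vector in the rotated basis
    radius = 4.0 / np.sqrt(x.size)       # hypothetical clip range (~4 sigma)
    levels = 2 ** bits - 1
    scaled = (np.clip(y, -radius, radius) + radius) / (2 * radius) * levels
    return np.round(scaled).astype(np.uint8), norm, radius


def dequantize(codes, norm, radius, rot, bits):
    """Invert `quantize`: map codes back to floats and undo the rotation."""
    levels = 2 ** bits - 1
    y = codes.astype(np.float64) / levels * (2 * radius) - radius
    return norm * (rot.T @ y)


# Demo: compare a vector with its 8-bit reconstruction.
rng = np.random.default_rng(0)
a = rng.standard_normal(64)
rot = random_rotation(64, seed=1)
codes, norm, radius = quantize(a, rot, bits=8)
a_hat = dequantize(codes, norm, radius, rot, bits=8)
cosine = float(a @ a_hat / (np.linalg.norm(a) * np.linalg.norm(a_hat)))
print("cosine(original, reconstructed) =", cosine)
```

The rotation is what makes a single fixed clip range reasonable: after it, the coordinates of a unit vector all have roughly the same scale, so one quantizer can serve every coordinate without looking at the data.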
```bash
curl -s "$CODERAI_URL/v1/embeddings" \
-H "Authorization: Bearer $CODERAI_TOKEN" \
-H "Content-Type: application/json" \
-d '{
"model": "BAAI/bge-small-en-v1.5",
"input": ["first document", "second document"],
"quantization": "turbo4",
"encoding_format": "base64"
}' | jq
```
## Character Profiles
Character profiles are named collections of reference images used for visual identity consistency in image/video generation.
@@ -1282,20 +1325,71 @@ curl -s "$CODERAI_URL/v1/loras/progress" \
### LoRA Registry
-- `GET /v1/loras`
-- `GET /v1/loras/{name}`
-- `DELETE /v1/loras/{name}`
+- `GET /v1/loras` — list registered LoRAs (name, weight path, metadata)
+- `GET /v1/loras/{name}` — fetch one registered LoRA
+- `DELETE /v1/loras/{name}` — delete a registered LoRA
### Upload LoRA Weights
`POST /v1/loras/upload`
Upload a LoRA file into a **content-addressed (sha256) blob store** so a client on a
different machine can use it without sharing the server's filesystem. Accepts the file
in three ways:
- **multipart/form-data** with a `file` field,
- **JSON** `{"file": "<base64>"}` (a `data:` URI is also accepted; `data` is an alias),
- a **raw** request body (the bytes of the `.safetensors`).
Returns `{"id": "sha256:<hex>", "bytes": <n>, "existed": <bool>}`. Reference the returned
`id` in any image/video request via `"loras": [{"id": "sha256:<hex>", "weight": ...}]`.
```bash
curl -s "$CODERAI_URL/v1/loras/upload" \
-H "Authorization: Bearer $CODERAI_TOKEN" \
-F "file=@./alice_identity.safetensors" | jq
# → {"id":"sha256:1f3b…","bytes":18874368,"existed":false}
```
### Check Uploaded Blob
`GET /v1/loras/blob/{hash}`
-Use a trained LoRA in image/video requests:
+Existence check for an uploaded blob — `200` with `{id, bytes, exists}` when present,
`404` when absent — so a client can skip re-uploading a file the server already has.
`hash` may be a bare hex sha256 or `sha256:<hex>`.
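A client can compute the blob id locally before deciding whether to upload. The sketch below assumes the server's `sha256:` id is the SHA-256 of the raw file bytes, which the phrase "content-addressed (sha256)" suggests but the text does not spell out. The helper names are hypothetical, and no HTTP request is made here.

```python
import hashlib


def lora_blob_id(data: bytes) -> str:
    """Return the id a content-addressed store would give these bytes."""
    return "sha256:" + hashlib.sha256(data).hexdigest()


def normalize_hash(value: str) -> str:
    """Accept a bare hex digest or 'sha256:<hex>' and return bare lowercase hex."""
    hex_part = value.split(":", 1)[1] if value.startswith("sha256:") else value
    hex_part = hex_part.lower()
    if len(hex_part) != 64 or any(c not in "0123456789abcdef" for c in hex_part):
        raise ValueError(f"not a sha256 digest: {value!r}")
    return hex_part


def blob_check_url(base_url: str, data: bytes) -> str:
    """Build the existence-check URL for a local file's bytes."""
    digest = normalize_hash(lora_blob_id(data))
    return f"{base_url.rstrip('/')}/v1/loras/blob/{digest}"
```

A client would `GET` the URL from `blob_check_url` and upload only on a `404`.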
### LoRA references in requests
The `loras` array in image (`/v1/images/generations`) and video
(`/v1/video/generations`) requests accepts LoRA weights supplied in several ways. The
server resolves each entry in this **priority order**:
| Field | Example | Meaning |
|---|---|---|
| `id` | `"name:alice_identity"` | A registered/trained LoRA by name |
| `id` | `"sha256:1f3b…"` | An uploaded blob (from `/v1/loras/upload`) |
| `file` / `data` | `"<base64>"` or `data:` URI | Inline weights, sent with the request |
| `url` | `"https://…/lora.safetensors"` | Server downloads and caches it |
| `model` / `path` | `"/path/to/lora.safetensors"` or HF id | Legacy local path / HF id (shared filesystem only) |
Common fields: `weight` (float, scale; default `1.0`) and `name` (optional adapter name).
```json
{
"model": "image-model",
"prompt": "alice_person in a cyberpunk alley",
-"loras": [{"model": "alice_identity", "weight": 0.85}]
+"loras": [
{"id": "name:alice_identity", "weight": 0.85},
{"id": "sha256:1f3b9c…", "weight": 0.6, "name": "jacket"}
]
}
```
> The previous `{"model": "alice_identity"}` form still works, but prefer `id`
> (`"name:<registered>"`) or an uploaded `sha256:` blob so requests don't depend on the
> client and server sharing a filesystem.
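The priority order in the table can be read as a simple first-match rule. The sketch below is one possible reading, not the server's resolver: the function name, the returned kind labels, and the choice to reject an `id` without a `name:` or `sha256:` prefix are assumptions.

```python
from typing import Any, Dict, Tuple

# (field, kind) pairs in the priority order of the table above.
_FIELD_PRIORITY = (
    ("id", "id"),
    ("file", "inline"),
    ("data", "inline"),
    ("url", "url"),
    ("model", "legacy"),
    ("path", "legacy"),
)


def classify_lora_ref(entry: Dict[str, Any]) -> Tuple[str, str]:
    """Return (kind, value) for the highest-priority field set in `entry`."""
    for field, kind in _FIELD_PRIORITY:
        value = entry.get(field)
        if not value:
            continue
        if field == "id":
            if value.startswith("name:"):
                return "registered", value[len("name:"):]
            if value.startswith("sha256:"):
                return "blob", value[len("sha256:"):]
            raise ValueError(f"unrecognized LoRA id: {value!r}")
        return kind, value
    raise ValueError("LoRA entry has no id, file/data, url, or model/path field")
```

For example, an entry carrying both `url` and `model` would resolve through `url`, because it sits higher in the table.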
## 2D / 3D / Spatial APIs
### Image to 3D
@@ -1706,7 +1800,8 @@ curl -s "$CODERAI_URL/admin/api/tokens" \
| `GET` | `/admin/api/model-loaded-status` | none | Loaded model / pool info |
| `POST` | `/admin/api/model-load` | `{path}` | Load model now |
| `POST` | `/admin/api/model-unload` | `{path}` | Unload model |
-| `POST` | `/admin/api/model-configure` | model config JSON | Configure model |
+| `POST` | `/admin/api/model-configure` | model config JSON | Configure model (incl. the `acceleration` block — see [Acceleration and Distillation](#acceleration-and-distillation)) |
| `GET` | `/admin/api/accel-presets` | none | Catalog of acceleration/distillation presets (Lightning, Lightx2v, Turbo, LCM, Hyper-SD) |
Download with SSE progress:
@@ -1768,6 +1863,76 @@ Logged-in users can access profile metadata through admin routes:
| `GET` | `/admin/api/voices/{name}` | Voice detail |
| `DELETE` | `/admin/api/voices/{name}` | Delete voice |
## Acceleration and Distillation
Image and video models can be configured to use a **distillation adapter** (Lightning,
Lightx2v / phased DMD, SDXL-Turbo, LCM-LoRA, Hyper-SD). When enabled, the distill LoRA
is **fused into the pipeline at load time** and the correct low step-count / low-guidance
/ scheduler defaults are applied at generation time — cutting inference from ~25–50 steps
to **1–8 steps at guidance ≈ 1.0** (a 5–10× speedup). It is orthogonal to per-request
character LoRAs, which still apply on top.
This is **per-model configuration** (set via `POST /admin/api/model-configure` or the
admin Models page), not a per-request field. The catalog of presets is served by
`GET /admin/api/accel-presets`.
The `acceleration` block in a model's config:
```json
"acceleration": {
"enabled": true,
"preset": "wan22_lightning_4step",
"lora": "lightx2v/Wan2.2-Lightning",
"lora_weight": 1.0,
"steps": 4,
"guidance_scale": 1.0,
"flow_shift": 5.0,
"scheduler": ""
}
```
- `enabled` — `false` or an absent block means no change to current behaviour.
- `preset` — a catalog key (below) or `"custom"`. When not `"custom"`, unset fields are
filled from the preset; any explicit field overrides it.
- `lora` — distill LoRA path or HF repo (`repo` or `repo:weight_name.safetensors`).
`null` for full-model presets such as SDXL-Turbo.
- `steps` / `guidance_scale` — defaults applied when the request omits them.
- `flow_shift` — optional Wan flow-match scheduler shift.
- `scheduler` — optional scheduler class override (e.g. `LCMScheduler`).
Preset catalog (`GET /admin/api/accel-presets`):
| Preset key | Applies to | Steps | Guidance | Notes |
|---|---|---:|---:|---|
| `wan22_lightning_4step` | video | 4 | 1.0 | Wan2.2 Lightning (4-step DMD) |
| `wan21_lightx2v_4step` | video | 4 | 1.0 | Wan2.1 Lightx2v (4-step) |
| `sdxl_lightning_4step` | image | 4 | 1.0 | SDXL-Lightning (4-step) |
| `sdxl_lightning_8step` | image | 8 | 1.0 | SDXL-Lightning (8-step) |
| `sdxl_turbo` | image | 4 | 1.0 | SDXL-Turbo (full model, 1–4 step) |
| `sdxl_lcm` | image | 6 | 1.5 | SDXL LCM-LoRA (`LCMScheduler`) |
| `hyper_sdxl_8step` | image | 8 | 1.0 | Hyper-SD SDXL (8-step) |
| `sd15_lcm` | image | 6 | 1.5 | SD1.5 LCM-LoRA (`LCMScheduler`) |
> The preset LoRA repo ids are best-effort defaults; override `lora` (and any numeric
> field) per model. A LoRA-fuse failure is logged and generation proceeds un-accelerated.
> sd.cpp models get the step/guidance defaults and optional `<lora:…>` prompt injection
> (more limited than diffusers).
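The "preset fills unset fields, explicit fields override" rule can be sketched as a small merge. This is an illustration only: `resolve_acceleration` is a hypothetical name, the two catalog entries are copied from the table above, and treating `null` and `""` as "unset" is an assumption based on the `"scheduler": ""` example.

```python
from typing import Any, Dict, Optional

# Two entries copied from the preset table; the real catalog is served by
# GET /admin/api/accel-presets.
PRESETS: Dict[str, Dict[str, Any]] = {
    "wan22_lightning_4step": {"steps": 4, "guidance_scale": 1.0},
    "sdxl_lcm": {"steps": 6, "guidance_scale": 1.5, "scheduler": "LCMScheduler"},
}


def resolve_acceleration(
    block: Optional[Dict[str, Any]],
    presets: Dict[str, Dict[str, Any]] = PRESETS,
) -> Optional[Dict[str, Any]]:
    """Merge an `acceleration` block with its preset; None means disabled."""
    if not block or not block.get("enabled"):
        return None
    preset = block.get("preset", "custom")
    resolved = dict(presets.get(preset, {})) if preset != "custom" else {}
    for key, value in block.items():
        if key in ("enabled", "preset"):
            continue
        if value is not None and value != "":   # assumed meaning of "unset"
            resolved[key] = value
    return resolved


print(resolve_acceleration({"enabled": True, "preset": "sdxl_lcm", "steps": 8, "scheduler": ""}))
```

In this reading the explicit `steps` wins over the preset while the blank `scheduler` falls back to the preset's `LCMScheduler`.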
### KV-cache quantization (GGUF text models)
GGUF/llama.cpp text models can quantize the **KV cache** to fit longer contexts in
less VRAM. This is **per-model configuration** (set via `POST /admin/api/model-configure`
or the admin Models UI), independent of the weight-quantization flags:
```json
"cache_type_k": "q8_0",
"cache_type_v": "q8_0"
```
Accepted values: `q8_0` (near-lossless, ~2× smaller KV), `q5_1`, `q5_0`, `q4_1`,
`q4_0` (smallest), or omit/blank for the default `f16`. A sub-8-bit **value** cache
(`q5_*`/`q4_*`) requires flash attention; CoderAI auto-enables it for that model.
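The memory effect of each cache type can be estimated from ggml's block sizes, which store 32 values per block. The sketch below is a rough estimate, not CoderAI code: the function name and the model shape in the example are hypothetical, and `head_dim` is assumed to be a multiple of 32.

```python
from typing import Optional

# Bytes ggml uses to store 32 values in each cache type (f16 is 2 bytes/value).
_BYTES_PER_32_VALUES = {
    "f16": 64, "q8_0": 34, "q5_1": 24, "q5_0": 22, "q4_1": 20, "q4_0": 18,
}


def kv_cache_bytes(
    n_layers: int,
    n_ctx: int,
    n_kv_heads: int,
    head_dim: int,
    cache_type_k: Optional[str] = None,
    cache_type_v: Optional[str] = None,
) -> int:
    """Estimate KV-cache size; a missing/blank type means the default f16."""
    values = n_layers * n_ctx * n_kv_heads * head_dim   # per K or V tensor set
    total = 0
    for cache_type in (cache_type_k, cache_type_v):
        name = cache_type or "f16"
        if name not in _BYTES_PER_32_VALUES:
            raise ValueError(f"unsupported cache type: {name!r}")
        total += values * _BYTES_PER_32_VALUES[name] // 32
    return total


# Hypothetical model shape: 32 layers, 8 KV heads of dim 128, 4096-token context.
f16 = kv_cache_bytes(32, 4096, 8, 128)
q8 = kv_cache_bytes(32, 4096, 8, 128, "q8_0", "q8_0")
print(f"f16: {f16 / 2**20:.0f} MiB, q8_0: {q8 / 2**20:.0f} MiB")
```

The `q8_0` block carries a 2-byte scale per 32 one-byte values, which is why the saving is slightly under the ~2× quoted above.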
## AISBF / Broker Integration
CoderAI exposes:
......
# CoderAI Distribution Plan
> **Purpose of this document.** This is an implementation-ready specification for
> distributing CoderAI as **prebuilt, no-compile, no-venv-recreation artifacts** that run
> on a freshly installed machine after a simple download/extract/install. It is detailed
> enough that a future request of *"implement DISTRIBUTION.md"* can be executed end to end
> without re-deriving the design. When implementing, treat the **file layouts, commands,
> version pins, and CI skeletons below as the contract**; deviate only where a TODO or an
> "Open parameters" item explicitly leaves a choice.
---
## 1. Goals and non-goals
### Goals
- **No compilation on the user's machine.** All native modules are prebuilt in CI.
- **No venv creation / dependency resolution on install.** Ship a self-contained Python
interpreter with every package already installed.
- **Runs on "any" reasonably modern Linux** (≈ 2021+) and on **Windows 10/11**, after a
download + extract (Linux tarball), `docker run` (OCI image), or a double-click installer
(Windows).
- **Three artifacts from one CI pipeline:**
1. Linux **relocatable tarball** (`.tar.zst`) — extract and run, no Docker.
2. Linux **OCI image** (Docker/Podman) — most robust GPU path.
3. Windows **installer** (`.exe`, Inno Setup) — easy install + launcher.
- **Backend coverage: NVIDIA (CUDA) + Vulkan + CPU** (matches today's `all` venv).
### Non-goals
- **Bundling models.** Models are large, and `township_output/`-style generated media is
gitignored. Models are downloaded at runtime via the admin UI / Hugging Face. The
artifacts ship with an **empty `models/` dir**.
- **Bundling GPU drivers.** The host must provide the GPU **driver** (NVIDIA driver, or a
Vulkan ICD via mesa/AMDVLK/Intel). See the GPU contract (§4).
- **macOS** (tracked separately; `osxbuild.sh` exists). This doc may be extended later.
- **ROCm/HIP** as a first-class bundled backend (AMD users go through Vulkan).
---
## 2. Why the current artifacts are not portable (constraints)
Established by inspecting the live build environment:
| Fact | Value | Consequence |
|---|---|---|
| Build-host glibc | **2.42** (Debian testing) | Anything compiled/linked here needs glibc ≥ 2.42. Ubuntu 24.04 = 2.39, Debian 12 = 2.36, RHEL 9 = 2.34 → **current binaries won't run there.** Fix: build in an **old-glibc container**. |
| `venv_all` size | **~11 GB** | nvidia wheels 2.7G + torch 1.2G + flash-attn 0.9G + triton 0.6G + onnxruntime 0.43G + llama_cpp 0.3G. Download size is inherent. |
| Python | `/usr/bin/python3.13`, `pyvenv.cfg home=/usr/bin` | A venv has **no stdlib of its own**; it points back at system Python. Copying a venv to a machine lacking that exact 3.13 fails. Fix: ship a **standalone interpreter** (its own stdlib). |
| CUDA runtime | bundled in `site-packages/nvidia/*` (2.7G) | Host needs **only the NVIDIA driver**, not the CUDA toolkit. ✓ |
| Compiled native modules | `llama_cpp`, `stable_diffusion_cpp`, `whispercpp`, `flash_attn`, `causal_conv1d` | These are glibc/ABI-bound and link CUDA/Vulkan. They are what CI must prebuild. |
**The core principle:** portability comes from *where and how you build* (old glibc, standalone
Python), not from a packing trick applied to the current env. `conda-pack`/`venv-pack` of the
current 2.42 env would only run on equally-new distros and is therefore rejected as the primary
mechanism.
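The glibc constraint reduces to a version comparison: a binary linked against glibc X runs only on hosts with glibc ≥ X. A minimal sketch using the standard library's `platform.libc_ver()` follows; the helper names are hypothetical and the sample versions are the ones from the table above.

```python
import platform
from typing import Optional, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn '2.31' into (2, 31) so versions compare numerically."""
    return tuple(int(part) for part in text.split("."))


def runs_on(host_glibc: str, built_against: str) -> bool:
    """True if a binary built against `built_against` can load on the host."""
    return parse_version(host_glibc) >= parse_version(built_against)


def host_glibc() -> Optional[str]:
    """Return this machine's glibc version, or None if it is not glibc."""
    lib, version = platform.libc_ver()
    return version if lib == "glibc" and version else None


# Versions from the table: current build host vs. common target distros.
for distro, version in (("Ubuntu 24.04", "2.39"), ("Debian 12", "2.36"), ("RHEL 9", "2.34")):
    print(distro, "runs a 2.42 build:", runs_on(version, "2.42"),
          "| runs a 2.31 build:", runs_on(version, "2.31"))
```

Comparing tuples rather than strings matters here, since as text `"2.9"` would sort after `"2.31"`.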
---
## 3. Architecture: one pipeline, three artifacts
```
┌─────────────────────────────────────────────┐
│ Native-module build (per-OS, old toolchain) │
│ → wheels: llama_cpp, stable_diffusion_cpp, │
│ whispercpp (CUDA + Vulkan) │
└───────────────┬──────────────┬──────────────┘
│ wheels │ wheels
┌────────────────────────┴───┐ ┌────┴───────────────────────┐
LINUX → │ assemble standalone Python │ │ WINDOWS standalone Python │ ← WINDOWS
│ (python-build-standalone) │ │ + all wheels (uv) │
│ + all wheels (uv) │ └────┬────────────────────────┘
└───────┬───────────┬─────────┘ │
│ │ │
① tar.zst │ ② OCI image (buildx) │ ③ Inno Setup .exe
(extract │ (docker/podman) │ (installer + launcher)
& run) │ │
```
Shared design across all three:
- **Standalone Python** (no system Python, no venv recreation).
- **CUDA via pip `nvidia-*` wheels** → host needs only the driver.
- **Vulkan via the host loader** (`libvulkan.so.1` / `vulkan-1.dll`) → host needs a GPU ICD.
- **Models never bundled.**
- The existing `codai/platform_paths.py` already abstracts per-OS paths — rely on it; do not
hardcode paths in launchers beyond the install root.
---
## 4. GPU / runtime contract (document this for users verbatim)
| Backend | What ships in the artifact | What the host must provide |
|---|---|---|
| **CUDA (NVIDIA)** | `torch` cu-wheels + `nvidia-*` runtime wheels + `llama_cpp`/`sd_cpp`/`whispercpp` compiled with CUDA | **NVIDIA driver only** (no CUDA toolkit). OCI: NVIDIA Container Toolkit. |
| **Vulkan (any GPU)** | native modules compiled with `GGML_VULKAN=ON` / `SD_VULKAN` | Host **Vulkan loader + ICD**: Linux `libvulkan1` + `mesa-vulkan-drivers`/AMDVLK/NVIDIA; Windows `vulkan-1.dll` (driver-provided, always present). |
| **CPU** | CPU paths of every module | Nothing beyond glibc / VC++ runtime. Always the fallback. |
Never bundle the GPU driver or the host Vulkan ICD. Bundling the Vulkan **loader** is allowed
but discouraged (prefer dlopen of the host's, so the host's ICD is found).
---
## 5. Shared building blocks
### 5.1 Standalone Python (interpreter source of truth)
- Use **`python-build-standalone`** (astral / `indygreg`) release matching the repo's Python
minor (currently **3.13**). Use the **`install_only`** variant.
- Linux x86_64: `cpython-3.13.<patch>+<date>-x86_64-unknown-linux-gnu-install_only.tar.gz`
- Windows x86_64: `cpython-3.13.<patch>+<date>-x86_64-pc-windows-msvc-install_only.tar.gz`
- These are **relocatable** and contain their own stdlib. Pin the exact release in
`packaging/versions.env` (see §9).
- Install packages into it with **`uv pip install`** (fast, deterministic, offline-capable
from a wheel cache). `uv` is fetched as a standalone binary in CI.
### 5.2 Native-module wheels (the only things CI compiles)
Compile each into a wheel (`uv pip wheel` / `pip wheel`) using the same CMAKE flags the current
`build.sh` uses:
| Module | Linux CMAKE_ARGS | Windows CMAKE_ARGS | Notes |
|---|---|---|---|
| `llama-cpp-python` | `-DGGML_VULKAN=ON -DGGML_CUDA=ON` | `-DGGML_VULKAN=ON -DGGML_CUDA=ON` | one wheel covers CUDA+Vulkan+CPU |
| `stable-diffusion-cpp-python` | `-DSD_VULKAN=ON -DSD_CUDA=ON -DSD_WEBM=OFF` (or `-DSD_USE_SYSTEM_WEBM=ON`) | `-DSD_CUDA=ON -DSD_VULKAN=ON -DSD_WEBM=OFF` | the libwebm submodule fix is mandatory (see build.sh:94-101) |
| `whispercpp` | `-DWHISPER_VULKAN=ON -DGGML_VULKAN=ON` | `-DWHISPER_VULKAN=ON -DGGML_VULKAN=ON` | built from source in build.sh:288 |
**Optional/extras (NOT in base bundle):** `flash-attn` (~0.9G, CUDA-arch-specific, slow/fragile
on old glibc), `bitsandbytes`, `causal_conv1d`, `turboquant-py[torch]` (optional TurboQuant
embedding-quantization backend; the built-in NumPy backend works without it). Ship these as a
separate **"cuda-extras"** download/layer the user can opt into; the app must already degrade
gracefully without them.
### 5.3 Pure/prebuilt wheels (no compilation)
`torch`, `torchvision`, `torchaudio`, `nvidia-*`, `triton`, `transformers`, `diffusers`,
`accelerate`, `onnxruntime`(-gpu), `insightface`, `sentence-transformers`, plus everything in
`requirements.txt` / `requirements-nvidia.txt` / `requirements-vulkan.txt`. These resolve to
manylinux/win wheels from PyPI + the PyTorch index. No build step.
### 5.4 Build base images (old glibc = forward compatibility)
- **Primary build base: glibc 2.31** (Debian 11 "bullseye" or `manylinux_2_31` equivalent).
Runs on Ubuntu 20.04+, Debian 11+, RHEL 9+ — ≈ everything from 2021. Easier to install the
Vulkan SDK + CUDA toolkit than older bases.
- **Fallback for wider reach: `manylinux_2_28`** (glibc 2.28). Older toolchain; use only if a
user needs pre-2021 distros.
- Pin the chosen base digest in `packaging/versions.env`.
---
## 6. Repository layout to create
```
packaging/
versions.env # pinned versions (python-build-standalone, uv, cuda, vulkan sdk, base image digests)
common/
requirements.lock # fully pinned, hash-locked resolution (uv pip compile output)
assemble_env.sh # shared: lay down standalone python + uv pip install wheels + app (Linux/macOS)
assemble_env.ps1 # shared: Windows equivalent
app_payload.txt # include/exclude globs for the codai/ app copy (excludes models, __pycache__, township_output, venv_*)
linux/
Dockerfile.build # old-glibc + CUDA + Vulkan SDK → compiles native wheels into /wheels
build_native_wheels.sh # invoked inside Dockerfile.build
make_tarball.sh # assemble standalone-python bundle → coderai-linux-x64.tar.zst
launcher/coderai # runtime launcher (sets PYTHONHOME/LD_LIBRARY_PATH, execs python coderai)
Dockerfile.runtime # slim runtime image (artifact ②)
windows/
build_native_wheels.ps1 # MSVC + CUDA + Vulkan SDK → native wheels
make_bundle.ps1 # assemble standalone python + wheels + app → staging dir
installer.iss # Inno Setup script (artifact ③)
launcher/coderai-launcher.ps1 # starts server, opens browser to admin UI
ci/
.github/workflows/release.yml # orchestrates all three (copy into .github/workflows/ when implementing)
DISTRIBUTION.md # this file
```
> When implementing, **create `packaging/` and move the new scripts there**; do not bloat the
> existing top-level `build.sh`/`build.ps1` (those remain the *developer* build path). The
> distribution scripts are a separate, CI-oriented track that consumes prebuilt wheels.
---
## 7. Artifact ① — Linux relocatable tarball
### 7.1 Final layout (what the user extracts)
```
coderai/
python/ # python-build-standalone tree: bin/python3, lib/python3.13/, incl. all site-packages
app/ # copy of repo (codai/, coderai launcher, templates, static) minus excludes
models/ # empty; CODERAI_MODELS_DIR points here by default
bin/coderai # launcher (chmod +x)
VERSION
README-RUN.txt
```
### 7.2 Launcher `bin/coderai` (exact behavior)
```sh
#!/bin/sh
# Resolve the install root from this script's location (handles symlinks).
HERE="$(cd "$(dirname "$(readlink -f "$0")")/.." && pwd)"
export PYTHONHOME="$HERE/python"
# Bundled CUDA runtime libs ship inside the nvidia/* wheels:
NV="$HERE/python/lib/python3.13/site-packages/nvidia"
LIBS="$HERE/python/lib"
if [ -d "$NV" ]; then
for d in "$NV"/*/lib; do LIBS="$LIBS:$d"; done
fi
export LD_LIBRARY_PATH="$LIBS${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
# Keep all state inside the bundle by default. codai/platform_paths.py honors the XDG
# vars (user_config_dir/user_data_dir/user_cache_dir), so point them at the install root
# unless the user already set them. (If a unified CODERAI_HOME is added app-side later,
# set that instead — see §12.)
export XDG_CONFIG_HOME="${XDG_CONFIG_HOME:-$HERE/config}"
export XDG_DATA_HOME="${XDG_DATA_HOME:-$HERE/data}"
export XDG_CACHE_HOME="${XDG_CACHE_HOME:-$HERE/cache}"
exec "$HERE/python/bin/python3" "$HERE/app/coderai" "$@"
```
- Do **not** rely on `activate`; call the interpreter directly with `PYTHONHOME` set.
- RPATHs of the compiled native `.so`s should be `$ORIGIN`-relative (patchelf in CI if needed)
so they find their sibling ggml/CUDA libs without `LD_LIBRARY_PATH` gymnastics; the
`LD_LIBRARY_PATH` above is the belt-and-suspenders for the nvidia wheels.
### 7.3 Build steps (`packaging/linux/make_tarball.sh`)
1. `docker build -f packaging/linux/Dockerfile.build` → produces `/wheels/*.whl` (native) +
exports them to the host (`docker create` + `docker cp`, or a buildx `--output`).
2. Download + extract `python-build-standalone` into `coderai/python/`.
3. Fetch `uv` static binary.
4. `uv pip install --python coderai/python/bin/python3 -r packaging/common/requirements.lock
--find-links /wheels` (native wheels resolved locally; the rest from PyPI + torch index).
5. Copy app via `packaging/common/app_payload.txt` includes into `coderai/app/`.
6. Drop in `bin/coderai`, `VERSION`, `README-RUN.txt`; `mkdir coderai/models`.
7. **Prune** to shrink the bundle: `__pycache__`, `*.pyi`, test dirs, `pip`, (optionally)
`*.dist-info/RECORD`, `.a` static libs, duplicate `nvidia/*` headers. (Document expected
size ≈ 4–5 GB zstd.)
8. `tar --zstd -cf dist/coderai-linux-x64.tar.zst coderai/` + `sha256sum`.
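Part of the prune in step 7 can be sketched as a directory walk. This covers only two of the listed categories (`__pycache__` directories and `.a` static libraries); the function name `prune_bundle` is hypothetical and the real `make_tarball.sh` may well do this in shell instead.

```python
import os
import shutil
from pathlib import Path


def prune_bundle(root: Path) -> int:
    """Delete __pycache__ dirs and *.a files under `root`; return bytes freed."""
    freed = 0
    for dirpath, dirnames, filenames in os.walk(root):
        for name in list(dirnames):
            if name == "__pycache__":
                target = Path(dirpath) / name
                freed += sum(f.stat().st_size for f in target.rglob("*") if f.is_file())
                shutil.rmtree(target)
                dirnames.remove(name)   # do not descend into the removed dir
        for name in filenames:
            if name.endswith(".a"):
                target = Path(dirpath) / name
                freed += target.stat().st_size
                target.unlink()
    return freed
```

Calling `prune_bundle(Path("coderai/python"))` before step 8 would report how many bytes these two categories contributed.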
### 7.4 User experience
```sh
tar --zstd -xf coderai-linux-x64.tar.zst
./coderai/bin/coderai # starts server; prints http://127.0.0.1:<port>/admin
```
README-RUN.txt documents: NVIDIA → install driver; AMD/Intel Vulkan →
`sudo apt install libvulkan1 mesa-vulkan-drivers`; no GPU → works on CPU.
---
## 8. Artifact ② — Linux OCI image
### 8.1 Dockerfile.runtime (multi-stage; reuse the wheels from §7 step 1)
```dockerfile
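# ---- An earlier `assembler` stage (not shown) lays down /opt/coderai from the wheels built by
# Dockerfile.build (or COPY --from a wheels image); only the runtime stage is shown here ----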
FROM debian:bookworm-slim AS runtime
RUN apt-get update && apt-get install -y --no-install-recommends \
libvulkan1 mesa-vulkan-drivers libgomp1 ffmpeg ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# standalone python + installed packages assembled exactly like the tarball's coderai/python:
COPY --from=assembler /opt/coderai /opt/coderai
ENV PYTHONHOME=/opt/coderai/python \
PATH=/opt/coderai/python/bin:$PATH \
XDG_CONFIG_HOME=/config \
XDG_DATA_HOME=/models \
XDG_CACHE_HOME=/cache
# (If a unified CODERAI_HOME override is added app-side per §12, set it here instead.)
EXPOSE 8776
VOLUME ["/models", "/config", "/cache"]
ENTRYPOINT ["/opt/coderai/python/bin/python3", "/opt/coderai/app/coderai"]
```
- Base is plain Debian slim (CUDA runtime arrives via pip nvidia wheels). NVIDIA Container
Toolkit injects the driver at `--gpus all`.
- Provide CPU/Vulkan-only and CUDA variants if size matters (tag suffixes `:cpu`, `:cuda`).
### 8.2 Run
```sh
# NVIDIA
docker run --gpus all -p 8776:8776 -v "$PWD/models:/models" -v "$PWD/config:/config" ghcr.io/<org>/coderai:latest
# AMD/Intel via Vulkan
docker run --device /dev/dri -p 8776:8776 -v "$PWD/models:/models" ghcr.io/<org>/coderai:latest
# CPU
docker run -p 8776:8776 -v "$PWD/models:/models" ghcr.io/<org>/coderai:latest
```
Publish to **GHCR**; also `docker save | zstd` a loadable tarball attached to the Release for
offline/Podman use.
---
## 9. Artifact ③ — Windows installer
### 9.1 Strategy
Prebuild native wheels on a `windows-latest` CI runner (MSVC Build Tools + CUDA toolkit +
Vulkan SDK), assemble a standalone-Python bundle, wrap in an **Inno Setup** installer. No
compiler on the user's machine.
### 9.2 Bundle assembly (`packaging/windows/make_bundle.ps1`)
1. `build_native_wheels.ps1` compiles `llama-cpp-python`, `stable-diffusion-cpp-python`,
`whispercpp` with the CMAKE flags from §5.2 → `wheels\`.
2. Extract `python-build-standalone` Windows `install_only` into `staging\python\`.
3. `uv pip install --python staging\python\python.exe -r packaging\common\requirements.lock
--find-links wheels`.
4. Copy app (per `app_payload.txt`) into `staging\app\`.
5. Add `staging\models\` (empty), `staging\coderai-launcher.ps1`, icon, VERSION.
### 9.3 Launcher (`coderai-launcher.ps1` → wrapped as `coderai.exe` via a tiny shim or `ps2exe`)
- Sets `PYTHONHOME=<install>\python`, `PATH` to include `python` + bundled `nvidia\*\bin`.
- Starts `python.exe app\coderai` minimized (or as a background process / tray icon).
- Waits for the port, then `Start-Process http://127.0.0.1:<port>/admin`.
- A second shortcut "Stop CoderAI" kills the process.
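
The wait-then-open step above can be sketched as follows. This is an illustration in Python rather than the PowerShell the launcher is specified in; `wait_for_port` and `open_admin_when_ready` are hypothetical helper names, and the timeout and polling interval are assumed values, not part of the spec.

```python
import socket
import time
import webbrowser


def wait_for_port(host: str, port: int, timeout: float = 60.0, interval: float = 0.5) -> bool:
    """Poll until a TCP connection to host:port succeeds or the timeout elapses."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # A successful connect means the server is accepting connections.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Not listening yet (or refused); wait and retry.
            time.sleep(interval)
    return False


def open_admin_when_ready(port: int, host: str = "127.0.0.1") -> bool:
    """Open the admin UI in the default browser once the server is reachable."""
    if not wait_for_port(host, port):
        return False
    return webbrowser.open(f"http://{host}:{port}/admin")
```

Polling the socket instead of sleeping for a fixed time keeps first launch responsive on fast machines while still tolerating slow cold starts.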
### 9.4 Inno Setup (`installer.iss`) requirements
- Install to `{localappdata}\Programs\CoderAI` (no admin) **or** `{autopf}\CoderAI` (admin) —
default to per-user to avoid UAC.
- Bundle + silently install the **VC++ 2015-2022 redistributable** if absent (check the
registry key; run `vc_redist.x64.exe /quiet /norestart`).
- Create Start-Menu group + desktop shortcut → the launcher.
- Register an uninstaller (Inno does this automatically); clean removal of the whole tree.
- Two installer SKUs:
- **Offline** `CoderAI-Setup-x64.exe` (~5–8 GB) — everything embedded.
- **Web/stub** `CoderAI-WebSetup-x64.exe` (small) — downloads the GPU payload (the big
`nvidia-*`/torch wheels) on first run from the Release assets; good for users who don't
need GPU or want a small download. Implemented via Inno's `[Code]` + `idpDownload` or a
first-run step in the launcher.
- GPU notes shown on the finish page: NVIDIA driver for CUDA; Vulkan works via any modern GPU
driver (`vulkan-1.dll` is driver-provided).
### 9.5 Windows GPU coverage (confirmed feasible)
- **CUDA**: identical to Linux — CUDA DLLs ship in the pip `nvidia-*`/torch cu-wheels; host
needs only the NVIDIA driver.
- **Vulkan**: `vulkan-1.dll` is provided by every GPU driver (NVIDIA/AMD/Intel) in System32 →
the Vulkan-compiled `.pyd` works across vendors.
- **CPU**: always.
### 9.6 Rejected Windows options (record so we don't revisit)
- **PyInstaller `--onefile`** (current `build.ps1 --package`): fragile with torch/diffusers,
~11 GB self-extract per launch, AV false positives. Keep only as a fallback build.
- **MSIX**: sandbox fights the large ML stack + GPU access.
- **WSL2 + reuse Linux artifacts**: easiest to produce, GPU works (CUDA in WSL2), but "enable
WSL2" is not double-click-easy → document as a power-user path, not the default.
---
## 10. CI pipeline (`packaging/ci/.github/workflows/release.yml`)
Trigger: `on: push: tags: ['v*']` + manual `workflow_dispatch`.
Jobs:
1. **lock** — `uv pip compile` the merged requirements (`requirements.txt` +
`requirements-nvidia.txt` + `requirements-vulkan.txt` + the native modules as `--find-links`
placeholders) into `packaging/common/requirements.lock` (hash-pinned). Cache it.
2. **linux-native-wheels** — build `packaging/linux/Dockerfile.build` (glibc 2.31 + CUDA +
Vulkan SDK); export `/wheels`. Cache by hash of the three native module versions.
3. **linux-tarball** — needs (1)(2); run `make_tarball.sh`; upload `*.tar.zst` + `.sha256`.
4. **linux-image** — needs (1)(2); `docker buildx build --push` to GHCR (`:latest`, `:vX.Y.Z`,
`:cpu`, `:cuda`); also `docker save | zstd` artifact.
5. **windows-native-wheels** — `windows-latest`; install MSVC + CUDA + Vulkan SDK (chocolatey /
official installers, pinned); run `build_native_wheels.ps1`; upload `wheels\`.
6. **windows-installer** — needs (5); `make_bundle.ps1` then compile `installer.iss` with the
Inno Setup CLI (`iscc`); upload both installer SKUs.
7. **release** — needs (3)(4)(6); create/attach all assets to the GitHub Release; publish
checksums + this doc's "GPU contract" as release notes.
Caching: key the native-wheel jobs on `(module versions, base image digest, python version)` so
they only rebuild when those change — the expensive compile happens rarely.
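
One way to derive such a cache key is to hash the three inputs in a stable order. The sketch below is an assumption about how this could be done, not the pipeline's actual implementation; `native_wheel_cache_key` and the `native-wheels-` prefix are hypothetical names.

```python
import hashlib


def native_wheel_cache_key(
    module_versions: dict[str, str],
    base_image_digest: str,
    python_version: str,
) -> str:
    """Build a deterministic cache key from the inputs that affect the native build."""
    # Sort module names so the key does not depend on dict insertion order.
    parts = [f"{name}={version}" for name, version in sorted(module_versions.items())]
    parts.append(f"base={base_image_digest}")
    parts.append(f"python={python_version}")
    digest = hashlib.sha256("\n".join(parts).encode("utf-8")).hexdigest()
    # A truncated digest keeps the key short while remaining collision-resistant enough for caching.
    return f"native-wheels-{digest[:16]}"
```

Any change to a module version, the base image digest, or the Python version yields a different key, which forces a rebuild; identical inputs reuse the cached wheels.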
---
## 11. Version pins (`packaging/versions.env`)
Fill these at implementation time; keep them the single source of truth referenced by every
script.
```
PYTHON_VERSION=3.13 # must match repo (currently 3.13.12)
PBS_RELEASE= # python-build-standalone tag, e.g. 20250xxx
UV_VERSION= # astral uv pinned version
LINUX_BUILD_BASE= # e.g. debian:11-slim@sha256:... (glibc 2.31)
CUDA_VERSION= # toolkit used to compile native wheels (match torch cuXX)
VULKAN_SDK_VERSION= # LunarG SDK pinned
LLAMA_CPP_PYTHON_VERSION=
SD_CPP_PYTHON_VERSION=
WHISPERCPP_REF= # git ref for whisper.cpp python bindings
INNO_SETUP_VERSION=
```
Keep `CUDA_VERSION` consistent with the `torch` cuXX wheels (e.g. torch cu124 ↔ CUDA 12.4) so
the bundled `nvidia-*` runtime matches what the native modules were compiled against.
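
The consistency rule above could be enforced by a small check in CI. The following is a minimal sketch under stated assumptions: it presumes `versions.env` uses plain `KEY=VALUE` lines with optional trailing `#` comments (as in the template above), and `parse_env`, `cuda_tag`, and `cuda_matches_torch` are hypothetical helper names.

```python
import re


def parse_env(text: str) -> dict[str, str]:
    """Parse KEY=VALUE lines, ignoring blank lines and trailing '# comments'."""
    pins: dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()
        if not line or "=" not in line:
            continue
        key, value = line.split("=", 1)
        pins[key.strip()] = value.strip()
    return pins


def cuda_tag(cuda_version: str) -> str:
    """Map a toolkit version such as '12.4' or '12.4.1' to a torch-style tag such as 'cu124'."""
    match = re.match(r"^(\d+)\.(\d+)", cuda_version)
    if not match:
        raise ValueError(f"unrecognized CUDA version: {cuda_version!r}")
    return f"cu{match.group(1)}{match.group(2)}"


def cuda_matches_torch(pins: dict[str, str], torch_cuda_tag: str) -> bool:
    """Return True when the pinned CUDA_VERSION corresponds to the torch wheel tag."""
    return cuda_tag(pins["CUDA_VERSION"]) == torch_cuda_tag
```

With `CUDA_VERSION=12.4` pinned, `cuda_matches_torch(pins, "cu124")` holds, while a mismatched tag such as `"cu121"` fails the check.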
---
## 12. App-side changes likely needed (verify during implementation)
- **`coderai` entry** already works as `python coderai` (it imports `codai.main.main`). Confirm
it runs headless with no extra args and binds `config.server.host:port`.
- **Config/models dir.** Today `codai/platform_paths.py` derives locations from the **XDG**
vars on Linux (`XDG_CONFIG_HOME`/`XDG_DATA_HOME`/`XDG_CACHE_HOME`, with
`user_config_dir()`/`user_data_dir()`/`user_cache_dir()`) and from Windows dir vars via
`_windows_dir(env_var, fallback)`. `codai/config.py` builds everything under a single
`config_dir` (`config.json`, `models.json`, `auth.json`, `pipelines.json`). The launchers
in §7/§9 therefore point the **XDG vars** (Linux) / the Windows dir vars at the install
root — this works **without any app change**. *Optional improvement:* add a single
`CODERAI_HOME` env var that overrides all of them at once (cleaner for the container's
`/config` + `/models` volumes); if added, update the launchers to set it instead. Verify
how `main.py` chooses `config_dir` (CLI arg vs default) so the env override actually wins.
- **First-run UX**: with empty `models/`, the server must start and the admin UI must let the
user download a model (it already has `/admin/api/model-download`). Verify no model is
required at boot.
- **Graceful absence of extras** (`flash-attn`, `bitsandbytes`): confirm import guards so the
base bundle runs without them.
- **Whisper server binary**: `build.sh` also compiles a standalone `whisper-server`; decide
whether the bundle ships it or uses the in-process `whispercpp` wheel. Prefer the wheel to
avoid shipping a second native binary.
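
The optional `CODERAI_HOME` override described above might resolve the config directory along these lines on Linux. This is only a sketch: `resolve_config_dir` is a hypothetical function, and the `config` and `coderai` subdirectory names are assumptions that would need to be checked against what `codai/platform_paths.py` actually uses.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_config_dir(env: Optional[Mapping[str, str]] = None) -> Path:
    """Pick the config directory: CODERAI_HOME first, then XDG_CONFIG_HOME, then ~/.config."""
    env = os.environ if env is None else env

    # Proposed single override: everything lives under one root.
    home = env.get("CODERAI_HOME")
    if home:
        return Path(home) / "config"  # subdirectory name is an assumption

    # Current behavior per the spec: derive from the XDG variable when set.
    xdg = env.get("XDG_CONFIG_HOME")
    if xdg:
        return Path(xdg) / "coderai"  # app directory name is an assumption

    # XDG default when nothing is set.
    return Path.home() / ".config" / "coderai"
```

Passing the environment as a parameter makes the precedence order easy to test without mutating `os.environ`.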
---
## 13. Verification matrix (acceptance criteria)
Run after building each artifact. "Pass" = server boots, `/v1/models` responds, a tiny CPU
text generation and a small image generation succeed.
| Target | Host | GPU path to verify |
|---|---|---|
| tarball | Ubuntu 22.04 clean | CPU; NVIDIA (driver only); AMD Vulkan (mesa) |
| tarball | Debian 12 clean | CPU; NVIDIA |
| tarball | Fedora/RHEL 9 clean | CPU (glibc 2.34 ≥ build glibc) |
| OCI image | Docker + NVIDIA Container Toolkit | `--gpus all` CUDA |
| OCI image | Podman | CPU + `--device /dev/dri` Vulkan |
| Windows installer | Win 11 clean, NVIDIA | CUDA + browser auto-open |
| Windows installer | Win 10, AMD | Vulkan + CPU |
Each must require **no compiler, no pip, no manual venv** on the target. Capture install size
and cold-start time.
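
The "`/v1/models` responds" part of the pass criterion could be automated with a probe like the one below. It is a minimal sketch using only the standard library; `models_endpoint_ok` is a hypothetical name, and it assumes the endpoint is reachable without an API token, which may not hold for a configured server.

```python
import json
import urllib.request


def models_endpoint_ok(base_url: str, timeout: float = 10.0) -> bool:
    """Return True if GET <base_url>/v1/models answers 200 with a JSON body."""
    url = base_url.rstrip("/") + "/v1/models"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            if resp.status != 200:
                return False
            json.load(resp)  # raises ValueError if the body is not valid JSON
            return True
    except (OSError, ValueError):
        # URLError/HTTPError are OSError subclasses; JSON errors are ValueError.
        return False
```

The generation checks in the matrix would need model-specific requests and are left out of this sketch.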
---
## 14. Open parameters (decide at implementation, defaults in **bold**)
1. Build glibc floor: **2.31 (Debian 11)** vs 2.28 (manylinux_2_28, wider/older).
2. `flash-attn`/`bitsandbytes`: **separate cuda-extras download** vs in-base (bigger).
3. Windows installer SKUs: **ship both** offline + web-stub vs offline only.
4. OCI variants: **single `all` image** first; add `:cpu`/`:cuda` slims if there are size complaints.
5. Compression: **zstd** (fast, good ratio) for the Linux tarball.
6. Registry: **GHCR**; mirror to Docker Hub optional.
7. Models dir default: **inside install root** (`coderai/models`) overridable by env/volume.
---
## 15. Out-of-scope follow-ups (note for later)
- macOS `.dmg` / notarized app (Metal backend; `osxbuild.sh` is the starting point).
- ARM64 builds (Linux aarch64 tarball + image) — same design, different base + wheels.
- Auto-update channel for the tarball/installer.
- Signed Windows installer (code-signing cert) to avoid SmartScreen warnings.
...@@ -167,7 +167,7 @@ def require_admin(request: Request) -> str: ...@@ -167,7 +167,7 @@ def require_admin(request: Request) -> str:
return username return username
@router.get("/login", response_class=HTMLResponse) @router.get("/login", response_class=HTMLResponse, summary="Admin login page")
async def login_page(request: Request): async def login_page(request: Request):
"""Display login page.""" """Display login page."""
# If already logged in, redirect to dashboard # If already logged in, redirect to dashboard
...@@ -178,7 +178,7 @@ async def login_page(request: Request): ...@@ -178,7 +178,7 @@ async def login_page(request: Request):
return _tmpl(request, "login.html", {"error": None}) return _tmpl(request, "login.html", {"error": None})
@router.post("/login") @router.post("/login", summary="Authenticate admin login")
async def login( async def login(
request: Request, request: Request,
username: str = Form(...), username: str = Form(...),
...@@ -211,7 +211,7 @@ async def login( ...@@ -211,7 +211,7 @@ async def login(
return response return response
@router.get("/logout") @router.get("/logout", summary="Log out")
async def logout(request: Request): async def logout(request: Request):
"""Handle logout.""" """Handle logout."""
if session_manager: if session_manager:
...@@ -223,7 +223,7 @@ async def logout(request: Request): ...@@ -223,7 +223,7 @@ async def logout(request: Request):
return response return response
@router.get("/admin/change-password", response_class=HTMLResponse) @router.get("/admin/change-password", response_class=HTMLResponse, summary="Change-password page")
async def change_password_page(request: Request, username: str = Depends(require_auth)): async def change_password_page(request: Request, username: str = Depends(require_auth)):
user = session_manager.get_user(username) user = session_manager.get_user(username)
must_change = user.get("must_change_password", False) if user else False must_change = user.get("must_change_password", False) if user else False
...@@ -235,7 +235,7 @@ async def change_password_page(request: Request, username: str = Depends(require ...@@ -235,7 +235,7 @@ async def change_password_page(request: Request, username: str = Depends(require
}) })
@router.post("/admin/change-password") @router.post("/admin/change-password", summary="Change admin password")
async def change_password( async def change_password(
request: Request, request: Request,
old_password: Optional[str] = Form(None), old_password: Optional[str] = Form(None),
...@@ -271,7 +271,7 @@ async def change_password( ...@@ -271,7 +271,7 @@ async def change_password(
return RedirectResponse(url=_url(request, "/admin"), status_code=302) return RedirectResponse(url=_url(request, "/admin"), status_code=302)
@router.get("/admin", response_class=HTMLResponse) @router.get("/admin", response_class=HTMLResponse, summary="Admin dashboard")
async def admin_dashboard(request: Request, username: str = Depends(require_auth)): async def admin_dashboard(request: Request, username: str = Depends(require_auth)):
is_admin = session_manager.is_admin(username) is_admin = session_manager.is_admin(username)
return _tmpl(request, "dashboard.html", { return _tmpl(request, "dashboard.html", {
...@@ -279,7 +279,7 @@ async def admin_dashboard(request: Request, username: str = Depends(require_auth ...@@ -279,7 +279,7 @@ async def admin_dashboard(request: Request, username: str = Depends(require_auth
}) })
@router.get("/admin/models", response_class=HTMLResponse) @router.get("/admin/models", response_class=HTMLResponse, summary="Models admin page")
async def models_page(request: Request, username: str = Depends(require_admin)): async def models_page(request: Request, username: str = Depends(require_admin)):
return _tmpl(request, "models.html", { return _tmpl(request, "models.html", {
"username": username, "username": username,
...@@ -288,12 +288,12 @@ async def models_page(request: Request, username: str = Depends(require_admin)): ...@@ -288,12 +288,12 @@ async def models_page(request: Request, username: str = Depends(require_admin)):
}) })
@router.get("/admin/tokens", response_class=HTMLResponse) @router.get("/admin/tokens", response_class=HTMLResponse, summary="API tokens admin page")
async def tokens_page(request: Request, username: str = Depends(require_admin)): async def tokens_page(request: Request, username: str = Depends(require_admin)):
return _tmpl(request, "tokens.html", {"username": username, "is_admin": True}) return _tmpl(request, "tokens.html", {"username": username, "is_admin": True})
@router.get("/admin/users", response_class=HTMLResponse) @router.get("/admin/users", response_class=HTMLResponse, summary="Users admin page")
async def users_page(request: Request, username: str = Depends(require_admin)): async def users_page(request: Request, username: str = Depends(require_admin)):
users = session_manager.list_users() users = session_manager.list_users()
return _tmpl(request, "users.html", { return _tmpl(request, "users.html", {
...@@ -301,7 +301,12 @@ async def users_page(request: Request, username: str = Depends(require_admin)): ...@@ -301,7 +301,12 @@ async def users_page(request: Request, username: str = Depends(require_admin)):
}) })
@router.get("/chat", response_class=HTMLResponse) @router.get("/admin/tasks", response_class=HTMLResponse, summary="Tasks admin page")
async def tasks_page(request: Request, username: str = Depends(require_admin)):
return _tmpl(request, "tasks.html", {"username": username, "is_admin": True})
@router.get("/chat", response_class=HTMLResponse, summary="Studio (chat) page")
async def chat_page(request: Request, username: str = Depends(require_auth)): async def chat_page(request: Request, username: str = Depends(require_auth)):
return _tmpl(request, "chat.html", { return _tmpl(request, "chat.html", {
"username": username, "is_admin": session_manager.is_admin(username), "username": username, "is_admin": session_manager.is_admin(username),
...@@ -309,7 +314,7 @@ async def chat_page(request: Request, username: str = Depends(require_auth)): ...@@ -309,7 +314,7 @@ async def chat_page(request: Request, username: str = Depends(require_auth)):
# API endpoints for admin operations # API endpoints for admin operations
@router.get("/admin/api/status") @router.get("/admin/api/status", summary="Server and model status")
async def api_status(username: str = Depends(require_auth)): async def api_status(username: str = Depends(require_auth)):
"""Get system status.""" """Get system status."""
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
...@@ -486,7 +491,7 @@ async def api_status(username: str = Depends(require_auth)): ...@@ -486,7 +491,7 @@ async def api_status(username: str = Depends(require_auth)):
} }
@router.post("/admin/api/users") @router.post("/admin/api/users", summary="Create a user")
async def api_create_user( async def api_create_user(
request: Request, request: Request,
username: str = Depends(require_admin) username: str = Depends(require_admin)
...@@ -507,7 +512,7 @@ async def api_create_user( ...@@ -507,7 +512,7 @@ async def api_create_user(
return {"success": True} return {"success": True}
@router.delete("/admin/api/users/{user_id}") @router.delete("/admin/api/users/{user_id}", summary="Delete a user")
async def api_delete_user( async def api_delete_user(
user_id: int, user_id: int,
username: str = Depends(require_admin) username: str = Depends(require_admin)
...@@ -528,7 +533,7 @@ async def api_delete_user( ...@@ -528,7 +533,7 @@ async def api_delete_user(
# --- Token management endpoints --- # --- Token management endpoints ---
@router.get("/admin/api/tokens", response_model=list) @router.get("/admin/api/tokens", response_model=list, summary="List API tokens")
async def api_list_tokens(username: str = Depends(require_admin)): async def api_list_tokens(username: str = Depends(require_admin)):
"""List all API tokens.""" """List all API tokens."""
auth_data = session_manager._load_auth_data() auth_data = session_manager._load_auth_data()
...@@ -545,7 +550,7 @@ async def api_list_tokens(username: str = Depends(require_admin)): ...@@ -545,7 +550,7 @@ async def api_list_tokens(username: str = Depends(require_admin)):
return tokens return tokens
@router.post("/admin/api/tokens") @router.post("/admin/api/tokens", summary="Create an API token")
async def api_create_token(request: Request, username: str = Depends(require_admin)): async def api_create_token(request: Request, username: str = Depends(require_admin)):
"""Create a new API token.""" """Create a new API token."""
data = await request.json() data = await request.json()
...@@ -580,7 +585,7 @@ async def api_create_token(request: Request, username: str = Depends(require_adm ...@@ -580,7 +585,7 @@ async def api_create_token(request: Request, username: str = Depends(require_adm
} }
@router.delete("/admin/api/tokens/{token_id}") @router.delete("/admin/api/tokens/{token_id}", summary="Revoke an API token")
async def api_delete_token(token_id: int, username: str = Depends(require_admin)): async def api_delete_token(token_id: int, username: str = Depends(require_admin)):
"""Delete an API token.""" """Delete an API token."""
auth_data = session_manager._load_auth_data() auth_data = session_manager._load_auth_data()
...@@ -598,7 +603,7 @@ async def api_delete_token(token_id: int, username: str = Depends(require_admin) ...@@ -598,7 +603,7 @@ async def api_delete_token(token_id: int, username: str = Depends(require_admin)
# --- Models management endpoints --- # --- Models management endpoints ---
@router.get("/admin/api/models") @router.get("/admin/api/models", summary="List configured models")
async def api_list_models(username: str = Depends(require_admin)): async def api_list_models(username: str = Depends(require_admin)):
"""List all configured models with details.""" """List all configured models with details."""
models_data = session_manager._load_auth_data() # TODO: move to ModelManager models_data = session_manager._load_auth_data() # TODO: move to ModelManager
...@@ -921,7 +926,7 @@ def _run_download_thread(session_id: str, model_id: str, file_pattern: str, pq): ...@@ -921,7 +926,7 @@ def _run_download_thread(session_id: str, model_id: str, file_pattern: str, pq):
_t.Thread(target=_gc, daemon=True).start() _t.Thread(target=_gc, daemon=True).start()
@router.post("/admin/api/model-download") @router.post("/admin/api/model-download", summary="Download a model")
async def api_download_model( async def api_download_model(
request: Request, request: Request,
username: str = Depends(require_admin) username: str = Depends(require_admin)
...@@ -947,7 +952,7 @@ async def api_download_model( ...@@ -947,7 +952,7 @@ async def api_download_model(
return {"session_id": session_id} return {"session_id": session_id}
@router.get("/admin/api/download-stream/{session_id}") @router.get("/admin/api/download-stream/{session_id}", summary="Stream model download progress")
async def api_download_stream( async def api_download_stream(
session_id: str, session_id: str,
request: Request, request: Request,
...@@ -982,7 +987,7 @@ async def api_download_stream( ...@@ -982,7 +987,7 @@ async def api_download_stream(
) )
@router.delete("/admin/api/models/{model_identifier}") @router.delete("/admin/api/models/{model_identifier}", summary="Remove a configured model")
async def api_delete_model( async def api_delete_model(
model_identifier: str, model_identifier: str,
username: str = Depends(require_admin) username: str = Depends(require_admin)
...@@ -1001,7 +1006,7 @@ async def api_delete_model( ...@@ -1001,7 +1006,7 @@ async def api_delete_model(
# --- Download status / cache management --- # --- Download status / cache management ---
@router.get("/admin/api/hf-files") @router.get("/admin/api/hf-files", summary="List files in a Hugging Face repo")
async def api_hf_repo_files(repo_id: str, username: str = Depends(require_admin)): async def api_hf_repo_files(repo_id: str, username: str = Depends(require_admin)):
"""Return the file list for a HuggingFace repo with name and size metadata.""" """Return the file list for a HuggingFace repo with name and size metadata."""
import asyncio import asyncio
...@@ -1023,13 +1028,13 @@ async def api_hf_repo_files(repo_id: str, username: str = Depends(require_admin) ...@@ -1023,13 +1028,13 @@ async def api_hf_repo_files(repo_id: str, username: str = Depends(require_admin)
return await asyncio.to_thread(_fetch) return await asyncio.to_thread(_fetch)
@router.get("/admin/api/downloads") @router.get("/admin/api/downloads", summary="List active downloads")
async def api_list_downloads(username: str = Depends(require_admin)): async def api_list_downloads(username: str = Depends(require_admin)):
"""Return status of all active and recently completed download sessions.""" """Return status of all active and recently completed download sessions."""
return list(_download_status.values()) return list(_download_status.values())
@router.post("/admin/api/download-cancel/{session_id}") @router.post("/admin/api/download-cancel/{session_id}", summary="Cancel a download")
async def api_cancel_download(session_id: str, username: str = Depends(require_admin)): async def api_cancel_download(session_id: str, username: str = Depends(require_admin)):
"""Request cancellation of an active download session.""" """Request cancellation of an active download session."""
if session_id not in _download_sessions and session_id not in _download_status: if session_id not in _download_sessions and session_id not in _download_status:
...@@ -1038,7 +1043,7 @@ async def api_cancel_download(session_id: str, username: str = Depends(require_a ...@@ -1038,7 +1043,7 @@ async def api_cancel_download(session_id: str, username: str = Depends(require_a
return {"success": True} return {"success": True}
@router.post("/admin/api/model-upload") @router.post("/admin/api/model-upload", summary="Upload a model file")
async def api_model_upload(request: Request, username: str = Depends(require_admin)): async def api_model_upload(request: Request, username: str = Depends(require_admin)):
"""Upload a GGUF model file in chunks.""" """Upload a GGUF model file in chunks."""
from codai.models.cache import get_model_cache_dir from codai.models.cache import get_model_cache_dir
...@@ -1480,28 +1485,28 @@ def _do_delete_model(model_id: str, cache_type: str) -> dict: ...@@ -1480,28 +1485,28 @@ def _do_delete_model(model_id: str, cache_type: str) -> dict:
return {"success": False, "detail": "Unknown cache_type"} return {"success": False, "detail": "Unknown cache_type"}
@router.get("/admin/api/cached-models") @router.get("/admin/api/cached-models", summary="List cached models")
async def api_cached_models(username: str = Depends(require_admin)): async def api_cached_models(username: str = Depends(require_admin)):
"""Scan both caches and return all locally stored models.""" """Scan both caches and return all locally stored models."""
import asyncio import asyncio
return await asyncio.to_thread(_scan_caches) return await asyncio.to_thread(_scan_caches)
@router.get("/admin/api/cache-stats") @router.get("/admin/api/cache-stats", summary="Model cache statistics")
async def api_cache_stats(username: str = Depends(require_admin)): async def api_cache_stats(username: str = Depends(require_admin)):
"""Return disk-usage statistics for each cache.""" """Return disk-usage statistics for each cache."""
import asyncio import asyncio
return await asyncio.to_thread(_get_cache_stats) return await asyncio.to_thread(_get_cache_stats)
@router.delete("/admin/api/cache") @router.delete("/admin/api/cache", summary="Clear the model cache")
async def api_clear_cache(cache_type: str = "all", username: str = Depends(require_admin)): async def api_clear_cache(cache_type: str = "all", username: str = Depends(require_admin)):
"""Bulk-delete cache. cache_type: all | hf | gguf""" """Bulk-delete cache. cache_type: all | hf | gguf"""
import asyncio import asyncio
return await asyncio.to_thread(_do_clear_cache, cache_type) return await asyncio.to_thread(_do_clear_cache, cache_type)
@router.delete("/admin/api/cached-models/{model_id:path}") @router.delete("/admin/api/cached-models/{model_id:path}", summary="Evict a cached model")
async def api_delete_cached_model( async def api_delete_cached_model(
model_id: str, model_id: str,
cache_type: str = "hf", cache_type: str = "hf",
...@@ -1512,7 +1517,7 @@ async def api_delete_cached_model( ...@@ -1512,7 +1517,7 @@ async def api_delete_cached_model(
return await asyncio.to_thread(_do_delete_model, model_id, cache_type) return await asyncio.to_thread(_do_delete_model, model_id, cache_type)
@router.post("/admin/api/model-enable") @router.post("/admin/api/model-enable", summary="Enable a model")
async def api_model_enable(request: Request, username: str = Depends(require_admin)): async def api_model_enable(request: Request, username: str = Depends(require_admin)):
"""Register a cached model in models.json so CoderAI can use it.""" """Register a cached model in models.json so CoderAI can use it."""
if config_manager is None: if config_manager is None:
...@@ -1532,7 +1537,7 @@ async def api_model_enable(request: Request, username: str = Depends(require_adm ...@@ -1532,7 +1537,7 @@ async def api_model_enable(request: Request, username: str = Depends(require_adm
return {"success": True} return {"success": True}
@router.post("/admin/api/model-disable") @router.post("/admin/api/model-disable", summary="Disable a model")
async def api_model_disable(request: Request, username: str = Depends(require_admin)): async def api_model_disable(request: Request, username: str = Depends(require_admin)):
"""Remove a model from models.json (keeps it cached locally).""" """Remove a model from models.json (keeps it cached locally)."""
if config_manager is None: if config_manager is None:
...@@ -1568,7 +1573,7 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad ...@@ -1568,7 +1573,7 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad
return {"success": True} return {"success": True}
@router.get("/admin/api/model-loaded-status") @router.get("/admin/api/model-loaded-status", summary="Model load status")
async def api_model_loaded_status(username: str = Depends(require_admin)): async def api_model_loaded_status(username: str = Depends(require_admin)):
"""Return loaded model keys with per-model instance pool info.""" """Return loaded model keys with per-model instance pool info."""
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
...@@ -1593,7 +1598,7 @@ async def api_model_loaded_status(username: str = Depends(require_admin)): ...@@ -1593,7 +1598,7 @@ async def api_model_loaded_status(username: str = Depends(require_admin)):
return {"loaded": loaded, "instances": instance_pools, "configured_max": configured_max} return {"loaded": loaded, "instances": instance_pools, "configured_max": configured_max}
@router.post("/admin/api/model-load") @router.post("/admin/api/model-load", summary="Load a model into memory")
async def api_model_load(request: Request, username: str = Depends(require_admin)): async def api_model_load(request: Request, username: str = Depends(require_admin)):
"""Load a configured model into VRAM (same VRAM checks as a real request).""" """Load a configured model into VRAM (same VRAM checks as a real request)."""
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
...@@ -1755,7 +1760,7 @@ async def api_model_load(request: Request, username: str = Depends(require_admin ...@@ -1755,7 +1760,7 @@ async def api_model_load(request: Request, username: str = Depends(require_admin
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/admin/api/model-unload") @router.post("/admin/api/model-unload", summary="Unload a model")
async def api_model_unload(request: Request, username: str = Depends(require_admin)): async def api_model_unload(request: Request, username: str = Depends(require_admin)):
"""Unload a model from VRAM (keeps it available for on-request reload).""" """Unload a model from VRAM (keeps it available for on-request reload)."""
import gc import gc
...@@ -1799,7 +1804,7 @@ async def api_model_unload(request: Request, username: str = Depends(require_adm ...@@ -1799,7 +1804,7 @@ async def api_model_unload(request: Request, username: str = Depends(require_adm
return {"success": True, "was_loaded": True} return {"success": True, "was_loaded": True}
@router.post("/admin/api/model-configure") @router.post("/admin/api/model-configure", summary="Update a model's configuration")
async def api_model_configure(request: Request, username: str = Depends(require_admin)): async def api_model_configure(request: Request, username: str = Depends(require_admin)):
"""Save per-model configuration and register/update in models.json.""" """Save per-model configuration and register/update in models.json."""
if config_manager is None: if config_manager is None:
...@@ -1939,7 +1944,8 @@ async def api_model_configure(request: Request, username: str = Depends(require_ ...@@ -1939,7 +1944,8 @@ async def api_model_configure(request: Request, username: str = Depends(require_
"lora_train_base_model", "lora_train_base_model",
"max_vram", "sdcpp_flash_attn", "sdcpp_diffusion_flash_attn", "vae_tiling", "max_vram", "sdcpp_flash_attn", "sdcpp_diffusion_flash_attn", "vae_tiling",
"component_quantization", "output_crf", "force_vram_update", "component_quantization", "output_crf", "force_vram_update",
"balanced_gpu_percent", "acceleration"): "balanced_gpu_percent", "acceleration",
"cache_type_k", "cache_type_v", "turboquant"):
if key in data: if key in data:
entry[key] = data[key] entry[key] = data[key]
...@@ -1970,7 +1976,7 @@ async def api_model_configure(request: Request, username: str = Depends(require_ ...@@ -1970,7 +1976,7 @@ async def api_model_configure(request: Request, username: str = Depends(require_
return {"success": True, "applied_live": applied} return {"success": True, "applied_live": applied}
@router.get("/admin/api/accel-presets") @router.get("/admin/api/accel-presets", summary="List acceleration / distillation presets")
async def api_accel_presets(username: str = Depends(require_admin)): async def api_accel_presets(username: str = Depends(require_admin)):
"""Return the acceleration/distillation preset catalog (Lightning / Turbo / """Return the acceleration/distillation preset catalog (Lightning / Turbo /
LCM / Hyper-SD) so the model-config UI dropdown stays in sync with the Python LCM / Hyper-SD) so the model-config UI dropdown stays in sync with the Python
...@@ -1982,9 +1988,264 @@ async def api_accel_presets(username: str = Depends(require_admin)): ...@@ -1982,9 +1988,264 @@ async def api_accel_presets(username: str = Depends(require_admin)):
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("/admin/api/turboquant-info", summary="TurboQuant backend availability")
async def api_turboquant_info(username: str = Depends(require_admin)):
    """Report which TurboQuant embedding-quantization backends are available so
    the model-config UI can offer 'builtin' (always) and 'library' (turboquant-py
    when installed)."""
    try:
        from codai.models import turboquant as _tq
        return {
            "builtin": True,
            "library": _tq.have_library(),
            "library_package": "turboquant-py",
            "bit_widths": [8, 6, 4, 2],
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# --- Task / queue management ---

@router.get("/admin/api/tasks", summary="List active and recent tasks")
async def api_tasks(username: str = Depends(require_admin)):
    """Unified live view of long-running work: in-flight / recent generations
    (image, video, audio, text) from the task registry, durable LoRA training
    jobs, and queued requests waiting for a slot. The Tasks page polls this."""
    from codai.tasks import task_registry
    from codai.api.loras import list_jobs
    from codai.queue.manager import queue_manager

    tasks = []
    seen = set()

    # Training jobs are authoritative (persisted, survive restarts).
    for j in list_jobs():
        jid = j.get("job_id")
        if not jid:
            continue
        seen.add(jid)
        status = j.get("status") or "unknown"
        norm = "running" if status in ("preparing", "training", "saving") else status
        active = norm in ("queued", "running")
        tasks.append({
            "id": jid,
            "kind": "training",
            "title": j.get("name") or "",
            "model": j.get("base_model") or "",
            "status": norm,
            "step": j.get("step") or 0,
            "total": j.get("total") or 0,
            "message": j.get("message") or "",
            "started_at": j.get("started_at"),
            "active": active,
            "cancellable": active,
            "pausable": norm == "running",
            "paused": bool(task_registry.is_paused(jid)),
            "restartable": status in ("cancelled", "error", "interrupted", "done"),
        })

    # Generations + anything else from the live registry (skip training dupes).
    for t in task_registry.list():
        if t["id"] in seen or t.get("kind") == "training":
            continue
        seen.add(t["id"])
        t = dict(t)
        t["cancellable"] = bool(t.get("cancellable", True) and t.get("active", False))
        t["pausable"] = (t.get("status") == "running")
        t["restartable"] = False
        tasks.append(t)

    # Queued requests waiting for a free slot/model (e.g. text) not shown yet.
    for w in queue_manager.list_waiting():
        rid = w.get("request_id")
        if not rid or rid in seen or rid.startswith("lora-train-"):
            continue
        seen.add(rid)
        tasks.append({
            "id": rid,
            "kind": "text" if rid.startswith("req-") else "request",
            "title": "",
            "model": w.get("model_key") or "",
            "status": "queued",
            "step": 0, "total": 0,
            "message": "waiting for a free slot",
            "started_at": w.get("enqueued_at"),
            "active": True,
            "cancellable": False,
            "restartable": False,
        })

    # Successfully-finished work is dropped from the live list — a "done" job is
    # no longer actionable, so it shouldn't clutter the view. Terminal-but-notable
    # states (cancelled / error / interrupted) stay, so they can be inspected,
    # restarted, or removed manually.
    tasks = [t for t in tasks if t.get("status") != "done"]

    # A thermal pause is global hardware state: while cooling, every running
    # worker is blocked at its next checkpoint. Surface it on the running tasks
    # and as a top-level banner so the Tasks page shows "cooling down".
    cooling = {"active": False}
    try:
        from codai.models import thermal
        cs = thermal.get_cooldown_state()
        if cs.get("active"):
            parts = []
            if cs.get("gpu") is not None:
                parts.append(f"GPU {cs['gpu']:.0f}°C")
            if cs.get("cpu") is not None:
                parts.append(f"CPU {cs['cpu']:.0f}°C")
            waited = int(cs.get("waited") or 0)
            detail = ", ".join(parts)
            label = "Cooling down" + (f" — {detail}" if detail else "")
            if waited:
                label += f" ({waited}s)"
            cooling = {"active": True, "message": label,
                       "gpu": cs.get("gpu"), "cpu": cs.get("cpu"), "waited": waited}
            for t in tasks:
                if t.get("active") and t.get("status") == "running":
                    t["cooling"] = True
                    t["cooling_message"] = label
    except Exception:
        pass
    return {"tasks": tasks, "queue": queue_manager.get_metrics(), "thermal": cooling}
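
Review note (illustrative only): the status handling for training jobs in `api_tasks` can be read as a small pure function, which would also make it easy to unit-test. The sketch below restates that logic under hypothetical names (`normalize_training_job`, `visible_tasks`); it is not code from this change and leaves out the display fields.

```python
# Internal training phases that the Tasks page reports as plain "running".
RUNNING_PHASES = ("preparing", "training", "saving")
# Raw statuses from which a training job may be restarted.
RESTARTABLE = ("cancelled", "error", "interrupted", "done")


def normalize_training_job(job):
    """Reduce a persisted training-job record to the flags the UI acts on."""
    status = job.get("status") or "unknown"
    norm = "running" if status in RUNNING_PHASES else status
    active = norm in ("queued", "running")
    return {
        "id": job.get("job_id"),
        "status": norm,
        "active": active,
        "cancellable": active,
        "pausable": norm == "running",
        # Restartability is decided on the raw status, before normalization.
        "restartable": status in RESTARTABLE,
    }


def visible_tasks(tasks):
    """Drop successfully finished tasks; keep cancelled / error / interrupted."""
    return [t for t in tasks if t.get("status") != "done"]
```

Note that a `done` job is marked restartable and then filtered out of the live list, so in practice restart is reachable from the list only for cancelled, errored, and interrupted jobs.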

def _read_vram_info() -> Optional[dict]:
    """Best-effort {used, total, gpu} in GB. CUDA via torch, else AMD/Intel sysfs."""
    try:
        import torch
        if torch.cuda.is_available():
            free, total = torch.cuda.mem_get_info()
            return {"used": (total - free) / 1e9, "total": total / 1e9,
                    "gpu": torch.cuda.get_device_name(0)}
    except Exception:
        pass
    try:
        import glob as _glob
        for card in sorted(_glob.glob("/sys/class/drm/card[0-9]")):
            dev = card + "/device"
            tot = dev + "/mem_info_vram_total"
            if not os.path.exists(tot):
                continue
            # Use context managers so the sysfs handles are closed on every poll.
            with open(tot) as fh:
                total_b = int(fh.read())
            with open(dev + "/mem_info_vram_used") as fh:
                used_b = int(fh.read())
            return {"used": used_b / 1e9, "total": total_b / 1e9, "gpu": ""}
    except Exception:
        pass
    return None
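
Review note (illustrative only): the sysfs branch of `_read_vram_info` could be factored into a helper that takes the device directory as a parameter, so it can be exercised against a temporary directory instead of real hardware. The function names here are hypothetical; the file names `mem_info_vram_total` and `mem_info_vram_used` are the ones read in the code above.

```python
import os
from typing import Optional


def read_int_file(path: str) -> Optional[int]:
    """Read a single integer from a text file, or None if missing or malformed."""
    try:
        with open(path) as fh:
            return int(fh.read().strip())
    except (OSError, ValueError):
        return None


def read_sysfs_vram(device_dir: str) -> Optional[dict]:
    """Return {used, total, gpu} in GB from a DRM device directory, else None."""
    total = read_int_file(os.path.join(device_dir, "mem_info_vram_total"))
    used = read_int_file(os.path.join(device_dir, "mem_info_vram_used"))
    if not total or used is None:
        return None
    return {"used": used / 1e9, "total": total / 1e9, "gpu": ""}
```

Returning `None` for a single unreadable card, rather than raising, would let a caller continue with the next card instead of abandoning the whole scan.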

@router.get("/admin/api/system-stats", summary="Live CPU / GPU / RAM / VRAM usage and temperatures")
async def api_system_stats(username: str = Depends(require_admin)):
    """Lightweight hardware telemetry for the Tasks page header: CPU & GPU
    utilization and temperature, plus RAM and VRAM usage. All fields are
    best-effort and may be null when a sensor/metric is unavailable."""
    from codai.models import thermal
    cpu = {"util": None, "temp": thermal.read_cpu_temp()}
    ram = None
    try:
        import psutil
        cpu["util"] = psutil.cpu_percent(interval=None)
        vm = psutil.virtual_memory()
        ram = {"used": vm.used / 1e9, "total": vm.total / 1e9, "percent": vm.percent}
    except Exception:
        pass
    gpu = {"util": thermal.read_gpu_util(), "temp": thermal.read_gpu_temp()}
    vram = _read_vram_info()
    if vram and vram.get("total"):
        vram["percent"] = round(vram["used"] / vram["total"] * 100, 1)
        if gpu.get("name") is None:
            gpu["name"] = vram.get("gpu") or ""
    return {"cpu": cpu, "gpu": gpu, "ram": ram, "vram": vram}

def _do_task_cancel(task_id: str) -> bool:
    """Cancel a task by id. Training ids route through loras.cancel_job (handles
    queued vs running + the durable job record); everything else goes through the
    in-memory task registry."""
    from codai.tasks import task_registry
    from codai.api.loras import cancel_job
    if cancel_job(task_id):
        return True
    return task_registry.cancel(task_id)


@router.post("/admin/api/tasks/{task_id}/cancel", summary="Cancel a task")
async def api_task_cancel(task_id: str, username: str = Depends(require_admin)):
    """Cancel a queued or running task. Running generations/training stop at the
    next step boundary; queued items are dropped before they start."""
    if not _do_task_cancel(task_id):
        raise HTTPException(status_code=404, detail="Task not found")
    return {"ok": True, "task_id": task_id, "action": "cancel"}


@router.post("/admin/api/tasks/{task_id}/interrupt", summary="Interrupt a running task")
async def api_task_interrupt(task_id: str, username: str = Depends(require_admin)):
    """Alias of cancel for a running task."""
    if not _do_task_cancel(task_id):
        raise HTTPException(status_code=404, detail="Task not found")
    return {"ok": True, "task_id": task_id, "action": "interrupt"}

@router.delete("/admin/api/tasks/{task_id}", summary="Remove a finished task")
async def api_task_remove(task_id: str, username: str = Depends(require_admin)):
    """Dismiss a finished/cancelled/errored task from the Tasks view. Refuses to
    remove a task that is still active — cancel it first."""
    from codai.tasks import task_registry
    from codai.api.loras import remove_job
    # Training job (durable record) first.
    if remove_job(task_id):
        return {"ok": True, "task_id": task_id, "removed": True}
    # Otherwise a live registry (generation) task — only when it's not active.
    t = task_registry.get(task_id)
    if t is None:
        raise HTTPException(status_code=404, detail="Task not found")
    if t.get("active"):
        raise HTTPException(status_code=409, detail="Task is still active — cancel it first")
    task_registry.remove(task_id)
    return {"ok": True, "task_id": task_id, "removed": True}

@router.post("/admin/api/tasks/{task_id}/pause", summary="Pause a running task")
async def api_task_pause(task_id: str, username: str = Depends(require_admin)):
    """Pause a running task. It suspends at the next step boundary (holding the
    model/GPU) until resumed. Works for generations and LoRA training."""
    from codai.tasks import task_registry
    if not task_registry.pause(task_id):
        raise HTTPException(status_code=404,
                            detail="Task not found or not running")
    return {"ok": True, "task_id": task_id, "action": "pause"}


@router.post("/admin/api/tasks/{task_id}/resume", summary="Resume a paused task")
async def api_task_resume(task_id: str, username: str = Depends(require_admin)):
    """Resume a task previously paused from the Tasks page."""
    from codai.tasks import task_registry
    if not task_registry.resume(task_id):
        raise HTTPException(status_code=404, detail="Task not found")
    return {"ok": True, "task_id": task_id, "action": "resume"}
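
Review note (illustrative only): the cancel, pause, and resume endpoints above all describe cooperative control, where the worker notices the request at its next step boundary. One common way to build that is with two `threading.Event` objects per task. The sketch below is an assumption about the general shape, not the implementation in `codai.tasks`; `TaskControl` and `run_steps` are hypothetical names.

```python
import threading


class TaskControl:
    """Per-task control flags checked by the worker once per step."""

    def __init__(self):
        self._cancel = threading.Event()
        self._resume = threading.Event()
        self._resume.set()  # set means "not paused"

    def cancel(self):
        self._cancel.set()
        self._resume.set()  # wake a paused worker so it can observe the cancel

    def pause(self):
        self._resume.clear()

    def resume(self):
        self._resume.set()

    @property
    def paused(self):
        return not self._resume.is_set()

    def checkpoint(self):
        """Block while paused; return False once the task has been cancelled."""
        self._resume.wait()
        return not self._cancel.is_set()


def run_steps(ctl, total):
    """Run up to `total` steps, stopping at the first checkpoint after a cancel."""
    done = 0
    for _ in range(total):
        if not ctl.checkpoint():
            break
        done += 1  # stand-in for one denoising / training step
    return done
```

Because a paused worker blocks inside `checkpoint()` rather than exiting, it keeps whatever it has loaded, which matches the "holding the model/GPU" wording in the pause docstring. Cancelling also sets the resume flag so that a paused task can still be cancelled.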

@router.post("/admin/api/tasks/{task_id}/restart", summary="Restart a training task")
async def api_task_restart(task_id: str, username: str = Depends(require_admin)):
    """Restart a finished/cancelled/errored/interrupted LoRA training job, resuming
    from its last on-disk checkpoint (only training tasks are restartable)."""
    from codai.api.loras import restart_job
    jid = restart_job(task_id)
    if not jid:
        raise HTTPException(status_code=400,
                            detail="Task is not restartable (training jobs only, and the saved request must exist)")
    return {"ok": True, "job_id": jid, "status": "queued"}
# --- System endpoints ---
@router.post("/admin/api/system/reload") @router.post("/admin/api/system/reload", summary="Reload server configuration")
async def api_reload_config(username: str = Depends(require_admin)):
    """Reload configuration from disk."""
    try:
@@ -2010,12 +2271,12 @@ from datetime import datetime
# --- Settings page ---
@router.get("/admin/settings", response_class=HTMLResponse) @router.get("/admin/settings", response_class=HTMLResponse, summary="Settings page")
async def settings_page(request: Request, username: str = Depends(require_admin)):
    return _tmpl(request, "settings.html", {"username": username, "is_admin": True})
@router.get("/admin/api/settings") @router.get("/admin/api/settings", summary="Get server settings")
async def api_get_settings(username: str = Depends(require_admin)):
    """Return current config.json as JSON."""
    if config_manager is None or config_manager.config is None:
@@ -2072,6 +2333,9 @@ async def api_get_settings(username: str = Depends(require_admin)):
            "gpu_resume": c.thermal.gpu_resume,
            "poll_seconds": c.thermal.poll_seconds,
        },
        "jobs": {
            "resume_on_restart": c.jobs.resume_on_restart,
        },
        "broker": {
            "enabled": c.broker.enabled,
            "base_url": c.broker.base_url,
@@ -2097,7 +2361,7 @@ async def api_get_settings(username: str = Depends(require_admin)):
    }
@router.post("/admin/api/settings") @router.post("/admin/api/settings", summary="Update server settings")
async def api_save_settings(request: Request, username: str = Depends(require_admin)):
    """Update and persist config.json from submitted JSON. Only sections present in the payload are updated."""
    if config_manager is None or config_manager.config is None:
@@ -2204,6 +2468,17 @@ async def api_save_settings(request: Request, username: str = Depends(require_ad
        except Exception:
            pass
    if "jobs" in data:
        jb = data["jobs"]
        c.jobs.resume_on_restart = bool(jb.get("resume_on_restart", c.jobs.resume_on_restart))
        # Apply live so the change takes effect on the next restart-recovery pass
        # without needing a server restart to re-read config.
        try:
            from codai.api.loras import set_resume_enabled
            set_resume_enabled(c.jobs.resume_on_restart)
        except Exception:
            pass
    if "broker" in data:
        bro = data["broker"]
        c.broker.enabled = bool(bro.get("enabled", c.broker.enabled))
@@ -2243,12 +2518,12 @@ async def api_save_settings(request: Request, username: str = Depends(require_ad
# Archive management
# =============================================================================
@router.get("/admin/archive", response_class=HTMLResponse) @router.get("/admin/archive", response_class=HTMLResponse, summary="Archive page")
async def archive_page(request: Request, username: str = Depends(require_admin)):
    return _tmpl(request, "archive.html", {"username": username, "is_admin": True})
@router.get("/admin/api/archive") @router.get("/admin/api/archive", summary="List archived generations")
async def api_archive_list(
    limit: int = 50,
    offset: int = 0,
@@ -2260,7 +2535,7 @@ async def api_archive_list(
    return {"entries": entries, "total": total}
@router.get("/admin/api/archive/{gen_id}") @router.get("/admin/api/archive/{gen_id}", summary="Get an archived generation")
async def api_archive_get(gen_id: str, username: str = Depends(require_admin)):
    from codai.api.archive import archive_manager
    entry = archive_manager.get_entry(gen_id)
@@ -2269,7 +2544,7 @@ async def api_archive_get(gen_id: str, username: str = Depends(require_admin)):
    return entry
@router.delete("/admin/api/archive/{gen_id}") @router.delete("/admin/api/archive/{gen_id}", summary="Delete an archived generation")
async def api_archive_delete(gen_id: str, username: str = Depends(require_admin)):
    from codai.api.archive import archive_manager
    if not archive_manager.delete_entry(gen_id):
@@ -2277,7 +2552,7 @@ async def api_archive_delete(gen_id: str, username: str = Depends(require_admin)
    return {"success": True}
@router.get("/admin/api/archive/{gen_id}/files/{filename}") @router.get("/admin/api/archive/{gen_id}/files/{filename}", summary="Download an archived file")
async def api_archive_file(
    gen_id: str,
    filename: str,
@@ -2293,7 +2568,7 @@ async def api_archive_file(
    return FileResponse(path, media_type=media_type)
@router.get("/admin/api/archive-settings") @router.get("/admin/api/archive-settings", summary="Get archive settings")
async def api_archive_settings_get(username: str = Depends(require_admin)):
    if config_manager is None or config_manager.config is None:
        raise HTTPException(status_code=503, detail="Config not ready")
@@ -2325,7 +2600,7 @@ def _hf_file_size(sibling: dict) -> int:
    return lfs.get("size") or sibling.get("size") or 0
@router.get("/admin/api/hf-search") @router.get("/admin/api/hf-search", summary="Search Hugging Face models")
async def api_hf_search(
    q: str = "",
    gguf_mode: str = "gguf",  # "gguf" | "all" | "no-gguf"
@@ -2482,7 +2757,7 @@ async def api_hf_search(
        raise HTTPException(status_code=502, detail=f"HuggingFace API error: {e}")
@router.get("/admin/api/hf-model-files") @router.get("/admin/api/hf-model-files", summary="List Hugging Face model files")
async def api_hf_model_files(model_id: str, username: str = Depends(require_admin)):
    """Return GGUF files (name, size, VRAM estimate, quant type) for an HF model repo."""
    import urllib.request
@@ -2523,13 +2798,13 @@ async def api_hf_model_files(model_id: str, username: str = Depends(require_admi
# Character profile management proxy (admin UI)
# =============================================================================
@router.get("/admin/api/characters") @router.get("/admin/api/characters", summary="List characters")
async def api_list_characters(username: str = Depends(require_auth)):
    from codai.api.characters import _list_characters
    return {"characters": _list_characters()}
@router.get("/admin/api/characters/{name}") @router.get("/admin/api/characters/{name}", summary="Get a character")
async def api_get_character(name: str, username: str = Depends(require_auth)):
    from codai.api.characters import _load_character_meta, _load_character_images
    meta = _load_character_meta(name)
@@ -2545,7 +2820,7 @@ async def api_get_character(name: str, username: str = Depends(require_auth)):
    }
@router.get("/admin/api/characters/{name}/thumbnail") @router.get("/admin/api/characters/{name}/thumbnail", summary="Character thumbnail")
async def api_character_thumbnail(name: str, username: str = Depends(require_auth)):
    import os as _os
    from codai.api.characters import _char_dir, _load_character_meta
@@ -2563,7 +2838,7 @@ async def api_character_thumbnail(name: str, username: str = Depends(require_aut
    raise HTTPException(status_code=404)
@router.delete("/admin/api/characters/{name}") @router.delete("/admin/api/characters/{name}", summary="Delete a character")
async def api_delete_character(name: str, username: str = Depends(require_auth)):
    import os as _os, shutil
    from codai.api.characters import _char_dir
@@ -2578,13 +2853,13 @@ async def api_delete_character(name: str, username: str = Depends(require_auth))
# Environment profile management proxy (admin UI)
# =============================================================================
@router.get("/admin/api/environments") @router.get("/admin/api/environments", summary="List environments")
async def api_list_environments(username: str = Depends(require_auth)):
    from codai.api.environments import _list_environments
    return {"environments": _list_environments()}
@router.get("/admin/api/environments/{name}") @router.get("/admin/api/environments/{name}", summary="Get an environment")
async def api_get_environment(name: str, username: str = Depends(require_auth)):
    from codai.api.environments import _load_environment_meta, _load_environment_images
    meta = _load_environment_meta(name)
@@ -2600,7 +2875,7 @@ async def api_get_environment(name: str, username: str = Depends(require_auth)):
    }
@router.get("/admin/api/environments/{name}/thumbnail") @router.get("/admin/api/environments/{name}/thumbnail", summary="Environment thumbnail")
async def api_environment_thumbnail(name: str, username: str = Depends(require_auth)):
    import os as _os
    from codai.api.environments import _env_dir, _load_environment_meta
@@ -2618,7 +2893,7 @@ async def api_environment_thumbnail(name: str, username: str = Depends(require_a
    raise HTTPException(status_code=404)
@router.delete("/admin/api/environments/{name}") @router.delete("/admin/api/environments/{name}", summary="Delete an environment")
async def api_delete_environment(name: str, username: str = Depends(require_auth)):
    import os as _os, shutil
    from codai.api.environments import _env_dir
@@ -2633,13 +2908,13 @@ async def api_delete_environment(name: str, username: str = Depends(require_auth
# Voice profile management proxy (admin UI)
# =============================================================================
@router.get("/admin/api/voices") @router.get("/admin/api/voices", summary="List voices")
async def api_list_voices(username: str = Depends(require_auth)):
    from codai.api.voice_clone import _list_voices
    return {"voices": _list_voices()}
@router.get("/admin/api/voices/{name}") @router.get("/admin/api/voices/{name}", summary="Get a voice")
async def api_get_voice(name: str, username: str = Depends(require_auth)):
    from codai.api.voice_clone import _load_voice
    meta = _load_voice(name)
@@ -2648,7 +2923,7 @@ async def api_get_voice(name: str, username: str = Depends(require_auth)):
    return {"voice": meta}
@router.delete("/admin/api/voices/{name}") @router.delete("/admin/api/voices/{name}", summary="Delete a voice")
async def api_delete_voice(name: str, username: str = Depends(require_auth)):
    import os as _os, shutil
    from codai.api.voice_clone import _voice_path
@@ -2659,7 +2934,7 @@ async def api_delete_voice(name: str, username: str = Depends(require_auth)):
    return {"deleted": True, "name": name}
@router.get("/admin/api/hf-model-info") @router.get("/admin/api/hf-model-info", summary="Get Hugging Face model info")
async def api_hf_model_info(model_id: str, username: str = Depends(require_admin)):
    """Full metadata for a single HuggingFace model repo."""
    import urllib.request
......
@@ -27,6 +27,7 @@
<a href="{{ root_path }}/docs" class="nav-link" target="_blank">API Docs</a>
{% if is_admin|default(false) %}
<a href="{{ root_path }}/admin/models" class="nav-link {% if '/models' in request.url.path %}active{% endif %}">Models</a>
<a href="{{ root_path }}/admin/tasks" class="nav-link {% if '/tasks' in request.url.path %}active{% endif %}">Tasks</a>
<a href="{{ root_path }}/admin/tokens" class="nav-link {% if '/tokens' in request.url.path %}active{% endif %}">Tokens</a>
<a href="{{ root_path }}/admin/users" class="nav-link {% if '/users' in request.url.path %}active{% endif %}">Users</a>
<a href="{{ root_path }}/admin/archive" class="nav-link {% if '/archive' in request.url.path %}active{% endif %}">Archive</a>
......
@@ -616,6 +616,28 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<label class="form-label">Context size</label>
<input type="number" id="cfg-n-ctx" class="form-input" min="128" step="128" value="2048">
</div>
<div class="form-row" style="margin:0" id="cfg-kv-k-row">
<label class="form-label">KV cache — Keys <span class="muted">(GGUF text; shrinks KV VRAM)</span></label>
<select id="cfg-cache-type-k" class="form-input">
<option value="">Default (f16)</option>
<option value="q8_0">q8_0 (near-lossless, ~2×)</option>
<option value="q5_1">q5_1</option>
<option value="q5_0">q5_0</option>
<option value="q4_1">q4_1</option>
<option value="q4_0">q4_0 (smallest)</option>
</select>
</div>
<div class="form-row" style="margin:0" id="cfg-kv-v-row">
<label class="form-label">KV cache — Values <span class="muted">(sub-8-bit needs Flash Attn)</span></label>
<select id="cfg-cache-type-v" class="form-input">
<option value="">Default (f16)</option>
<option value="q8_0">q8_0 (near-lossless, ~2×)</option>
<option value="q5_1">q5_1</option>
<option value="q5_0">q5_0</option>
<option value="q4_1">q4_1</option>
<option value="q4_0">q4_0 (smallest)</option>
</select>
</div>
<div class="form-row" style="margin:0">
<label class="form-label">Max GPU % <span class="muted">(optional)</span></label>
<input type="number" id="cfg-max-gpu" class="form-input" min="1" max="100" placeholder="e.g. 90">
@@ -732,6 +754,36 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
</div>
</div>
<!-- TurboQuant embedding vector quantization (embedding models) -->
<div id="cfg-turboquant-section" style="display:none">
<div class="card-title" style="margin-top:1.25rem">TurboQuant <span class="muted" style="font-weight:normal">(embedding vector quantization — data-free, inner-product preserving)</span></div>
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px;margin:.4rem 0">
<input type="checkbox" id="cfg-tq-enabled" onchange="onTurboQuantToggle()"> Enable TurboQuant
<span class="muted">compress embeddings to 2–8 bits/coord for smaller vector stores</span></label>
<div id="cfg-tq-fields" style="display:none">
<div style="display:flex;gap:1rem;flex-wrap:wrap">
<div class="form-row" style="max-width:260px">
<label class="form-label">Backend</label>
<select id="cfg-tq-backend" class="form-input">
<option value="builtin">Built-in (NumPy, always available)</option>
<option value="library">turboquant-py library (QJL)</option>
</select>
<span class="form-hint" id="cfg-tq-backend-hint"></span>
</div>
<div class="form-row" style="max-width:200px">
<label class="form-label">Default bits</label>
<select id="cfg-tq-bits" class="form-input">
<option value="8">8-bit (near-lossless, 3×)</option>
<option value="6">6-bit</option>
<option value="4">4-bit (6×)</option>
<option value="2">2-bit (12×)</option>
</select>
</div>
</div>
<span class="form-hint">A request's <code>quantization</code> field (e.g. <code>turbo4</code>) overrides the default bits. With <code>encoding_format:"base64"</code> the response is compact packed bytes; otherwise it returns the lossy reconstruction as floats.</span>
</div>
</div>
<!-- components -->
<div class="card-title" style="margin-top:1.25rem">Components</div>
<div class="form-row">
@@ -2268,9 +2320,9 @@ async function refreshLocal(){
loadGlobalSettings();
refreshLocal();
// Toggle the acceleration section as image/video model types are checked/unchecked. // Toggle the acceleration / TurboQuant sections as model types are checked/unchecked.
document.querySelectorAll('.cfg-type-cb').forEach(cb =>
cb.addEventListener('change', () => _refreshAccelVisibility())); cb.addEventListener('change', () => { _refreshAccelVisibility(); _refreshTurboQuantVisibility(); }));
// ── Deep-link from Studio: /admin/models?tab=search&q=...&pipeline=...&gguf=...
// ── or: /admin/models?local_cap=CAPABILITY — highlight local models with that capability
@@ -2633,6 +2685,8 @@ function openCfgModal(idx, cfgIdx){
  document.getElementById('cfg-force-vram-update').checked = !!s.force_vram_update;
  document.getElementById('cfg-gpu-layers').value = s.n_gpu_layers !== undefined ? s.n_gpu_layers : -1;
  document.getElementById('cfg-n-ctx').value = nCtxForEst;
  document.getElementById('cfg-cache-type-k').value = s.cache_type_k || '';
  document.getElementById('cfg-cache-type-v').value = s.cache_type_v || '';
  document.getElementById('cfg-max-instances').value = s.max_instances != null ? s.max_instances : 1;
  document.getElementById('cfg-preload-all-instances').checked = !!s.preload_all_instances;
  _updatePreloadAllVisibility();
@@ -2678,6 +2732,7 @@ function openCfgModal(idx, cfgIdx){
  document.getElementById('cfg-lora-dir').value = s.lora_model_dir || '';
  document.getElementById('cfg-lora-train-base').value = s.lora_train_base_model || '';
  _populateAccel(s.acceleration);
  _populateTurboQuant(s.turboquant);
  openModal('cfg-modal');
}
@@ -2768,6 +2823,56 @@ function _collectAccel(){
  };
}
// ---- TurboQuant (embedding vector quantization) ----
let _tqInfo = null;
async function _loadTurboQuantInfo(){
if (_tqInfo) return _tqInfo;
try {
const r = await fetch(ROOT_PATH + '/admin/api/turboquant-info');
_tqInfo = await r.json();
} catch(e){ _tqInfo = {builtin:true, library:false}; }
return _tqInfo;
}
function _turboQuantApplies(){
return [...document.querySelectorAll('.cfg-type-cb:checked')]
.some(cb => cb.value === 'embedding_models');
}
function _refreshTurboQuantVisibility(){
const section = document.getElementById('cfg-turboquant-section');
if (section) section.style.display = _turboQuantApplies() ? '' : 'none';
}
function onTurboQuantToggle(){
const on = document.getElementById('cfg-tq-enabled').checked;
document.getElementById('cfg-tq-fields').style.display = on ? '' : 'none';
}
async function _populateTurboQuant(t){
await _loadTurboQuantInfo();
_refreshTurboQuantVisibility();
// Reflect library availability in the backend dropdown + hint.
const libOpt = document.querySelector('#cfg-tq-backend option[value="library"]');
const hint = document.getElementById('cfg-tq-backend-hint');
const libOk = !!(_tqInfo && _tqInfo.library);
if (libOpt){ libOpt.disabled = !libOk; libOpt.textContent =
'turboquant-py library (QJL)' + (libOk ? '' : ' — not installed'); }
if (hint) hint.textContent = libOk
? 'turboquant-py detected.'
: 'Install "turboquant-py[torch]" to enable the library backend.';
t = t || {};
document.getElementById('cfg-tq-enabled').checked = !!t.enabled;
document.getElementById('cfg-tq-backend').value =
(t.backend === 'library' && libOk) ? 'library' : 'builtin';
document.getElementById('cfg-tq-bits').value = String(t.bits || 8);
onTurboQuantToggle();
}
function _collectTurboQuant(){
if (!document.getElementById('cfg-tq-enabled').checked) return null;
return {
enabled: true,
backend: document.getElementById('cfg-tq-backend').value || 'builtin',
bits: parseInt(document.getElementById('cfg-tq-bits').value) || 8,
};
}
function _updatePreloadAllVisibility() { function _updatePreloadAllVisibility() {
const loadMode = document.getElementById('cfg-load-mode').value; const loadMode = document.getElementById('cfg-load-mode').value;
const maxInst = parseInt(document.getElementById('cfg-max-instances').value) || 1; const maxInst = parseInt(document.getElementById('cfg-max-instances').value) || 1;
...@@ -2834,6 +2939,8 @@ async function saveModelConfig(){ ...@@ -2834,6 +2939,8 @@ async function saveModelConfig(){
preload_all_instances: document.getElementById('cfg-preload-all-instances').checked, preload_all_instances: document.getElementById('cfg-preload-all-instances').checked,
n_gpu_layers: parseInt(document.getElementById('cfg-gpu-layers').value) || -1, n_gpu_layers: parseInt(document.getElementById('cfg-gpu-layers').value) || -1,
n_ctx: parseInt(document.getElementById('cfg-n-ctx').value) || 2048, n_ctx: parseInt(document.getElementById('cfg-n-ctx').value) || 2048,
cache_type_k: document.getElementById('cfg-cache-type-k').value || null,
cache_type_v: document.getElementById('cfg-cache-type-v').value || null,
max_gpu_percent: isNaN(maxGpu) ? null : maxGpu, max_gpu_percent: isNaN(maxGpu) ? null : maxGpu,
manual_ram_gb: isNaN(ramGb) ? null : ramGb, manual_ram_gb: isNaN(ramGb) ? null : ramGb,
load_in_4bit: document.getElementById('cfg-4bit').checked, load_in_4bit: document.getElementById('cfg-4bit').checked,
...@@ -2866,6 +2973,7 @@ async function saveModelConfig(){ ...@@ -2866,6 +2973,7 @@ async function saveModelConfig(){
balanced_gpu_percent: (document.getElementById('cfg-balanced-gpu-pct').value.trim() === '' balanced_gpu_percent: (document.getElementById('cfg-balanced-gpu-pct').value.trim() === ''
? null : parseFloat(document.getElementById('cfg-balanced-gpu-pct').value)), ? null : parseFloat(document.getElementById('cfg-balanced-gpu-pct').value)),
acceleration: _collectAccel(), acceleration: _collectAccel(),
turboquant: _collectTurboQuant(),
}; };
try{ try{
const r = await fetch(ROOT_PATH + '/admin/api/model-configure',{ const r = await fetch(ROOT_PATH + '/admin/api/model-configure',{
......
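The TurboQuant form logic above (send `null` when disabled, fall back to the built-in backend when the library is not installed, default to 8 bits) could be mirrored on the server when validating a saved config. The following is only a sketch under that assumption: `normalize_turboquant` and its `library_available` argument are hypothetical names and do not appear in this commit.

```python
from typing import Optional


def normalize_turboquant(cfg: Optional[dict], library_available: bool) -> Optional[dict]:
    """Hypothetical helper mirroring _populateTurboQuant/_collectTurboQuant."""
    # Disabled or missing config: the form submits `turboquant: null`.
    if not cfg or not cfg.get("enabled"):
        return None
    # Only honour the library backend when it is actually installed;
    # anything else falls back to the built-in backend, as in the UI.
    if cfg.get("backend") == "library" and library_available:
        backend = "library"
    else:
        backend = "builtin"
    # Same default as the form: a missing, zero, or unparsable value becomes 8.
    try:
        bits = int(cfg.get("bits") or 8)
    except (TypeError, ValueError):
        bits = 8
    return {"enabled": True, "backend": backend, "bits": bits}
```

Keeping the fallback in one place like this would mean a config saved on a machine with `turboquant-py` installed still loads cleanly on one without it.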
...@@ -153,6 +153,24 @@ ...@@ -153,6 +153,24 @@
</div> </div>
</div> </div>
<!-- Background jobs -->
<div class="card mb-0" style="margin-top:1rem">
<div class="card-title">Background Jobs</div>
<span class="form-hint" style="display:block;margin-bottom:.75rem">
Controls how interrupted LoRA training is handled when CoderAI restarts.
Turning this off is equivalent to passing the <code>--no-resume-jobs</code> launch flag.
</span>
<div class="form-row" style="margin:0">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer">
<input type="checkbox" id="s-jobs-resume">
<span style="font-size:13px;font-weight:500">Resume interrupted training on restart</span>
</label>
<span class="form-hint">When off, a training job that was running at restart is marked
<em>cancelled</em> instead of resuming. Its checkpoint is kept, so you can still
restart it manually from the Tasks page.</span>
</div>
</div>
<div class="card mb-0" style="margin-top:1rem"> <div class="card mb-0" style="margin-top:1rem">
<div class="card-title">AISBF Broker</div> <div class="card-title">AISBF Broker</div>
<div class="form-row"> <div class="form-row">
...@@ -328,6 +346,9 @@ async function loadSettings(){ ...@@ -328,6 +346,9 @@ async function loadSettings(){
document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87; document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87;
document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5; document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5;
toggleThermalFields(); toggleThermalFields();
// Background jobs
const jobs = d.jobs || {};
document.getElementById('s-jobs-resume').checked = jobs.resume_on_restart !== false;
}catch(e){ showAlert('error','Failed to load settings: '+e.message); } }catch(e){ showAlert('error','Failed to load settings: '+e.message); }
} }
...@@ -363,6 +384,9 @@ async function saveSettings(){ ...@@ -363,6 +384,9 @@ async function saveSettings(){
cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87, cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87,
poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5, poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5,
}, },
jobs:{
resume_on_restart: document.getElementById('s-jobs-resume').checked,
},
broker:{ broker:{
enabled: document.getElementById('s-broker-enabled').checked, enabled: document.getElementById('s-broker-enabled').checked,
base_url: document.getElementById('s-broker-base-url').value.trim(), base_url: document.getElementById('s-broker-base-url').value.trim(),
......
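The settings code above treats `resume_on_restart` as on unless it is explicitly `false`, and the same behaviour can be disabled with `--no-resume-jobs`. One way the two inputs might be combined at startup is sketched below. This is an assumption, not the commit's code: the helper names are hypothetical, and letting the launch flag take precedence over the saved setting is a guess.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser exposing only the flag named in the settings hint."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--no-resume-jobs", action="store_true",
                        help="do not resume interrupted training on restart")
    return parser


def should_resume_jobs(settings: dict, no_resume_flag: bool = False) -> bool:
    """Return True when interrupted training should be resumed."""
    # Assumption: the launch flag wins over whatever is stored in settings.
    if no_resume_flag:
        return False
    jobs = settings.get("jobs") or {}
    # Same default as the settings page: on unless explicitly set to False.
    return jobs.get("resume_on_restart") is not False
```

For example, `should_resume_jobs({}, build_parser().parse_args([]).no_resume_jobs)` resumes by default, because neither the flag nor the setting turned it off.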
{% extends "base.html" %}
{% block title %}Tasks — CoderAI{% endblock %}
{% block content %}
<div class="page-header">
<div>
<h1>Tasks</h1>
<p>Live view of generations and LoRA training. Cancel, interrupt, pause, resume, restart, or remove a job.</p>
</div>
<div class="header-actions">
<span id="queue-summary" class="dim small"></span>
</div>
</div>
<div id="thermal-banner" style="display:none;margin:0 0 1rem;padding:.6rem .85rem;border-radius:8px;
background:rgba(245,158,11,.12);border:1px solid rgba(245,158,11,.4);color:#f59e0b;font-size:13px">
<span style="font-weight:600">❄ Thermal cooldown</span>
<span id="thermal-banner-msg" class="mono"></span>
— running work is paused until the hardware cools.
</div>
<!-- Live hardware telemetry -->
<div id="sys-stats" style="display:grid;grid-template-columns:repeat(auto-fit,minmax(220px,1fr));
gap:.75rem;margin:0 0 1.25rem">
<div class="sys-tile" id="tile-cpu"></div>
<div class="sys-tile" id="tile-gpu"></div>
<div class="sys-tile" id="tile-ram"></div>
<div class="sys-tile" id="tile-vram"></div>
</div>
<style>
.sys-tile{border:1px solid var(--border,#2a2a2a);border-radius:10px;padding:.7rem .85rem;
background:var(--card-bg,rgba(255,255,255,.02))}
.sys-tile .sys-head{display:flex;justify-content:space-between;align-items:baseline;margin-bottom:.45rem}
.sys-tile .sys-name{font-size:12px;font-weight:600;letter-spacing:.03em;text-transform:uppercase;color:var(--text-muted,#9aa0a6)}
.sys-tile .sys-val{font-size:13px;font-weight:600}
.sys-tile .sys-sub{font-size:11px;color:var(--text-muted,#9aa0a6);margin-top:.35rem;display:flex;justify-content:space-between}
.sys-bar{height:8px;border-radius:5px;background:rgba(127,127,127,.18);overflow:hidden}
.sys-bar > span{display:block;height:100%;border-radius:5px;transition:width .4s ease,background .4s ease}
.sys-ok > span{background:#22c55e}.sys-warn > span{background:#f59e0b}.sys-hot > span{background:#ef4444}
.sys-temp-ok{color:#22c55e}.sys-temp-warn{color:#f59e0b}.sys-temp-hot{color:#ef4444}
</style>
<div class="table-wrap">
<table>
<thead>
<tr>
<th>Type</th><th>Name / Model</th><th>Status</th>
<th style="width:220px">Progress</th><th>Started</th><th style="text-align:right">Actions</th>
</tr>
</thead>
<tbody id="tasks-body">
<tr class="empty-row"><td colspan="6">No tasks yet</td></tr>
</tbody>
</table>
</div>
{% endblock %}
{% block scripts %}
<script>
function esc(s) { return String(s == null ? '' : s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;'); }
function fmtTime(s) {
if (!s) return '';
try {
// started_at is unix seconds (float) from the server.
const d = new Date(s * 1000);
return d.toLocaleTimeString(undefined, {hour:'2-digit', minute:'2-digit', second:'2-digit'});
} catch { return ''; }
}
const KIND_LABEL = {training:'Training', image:'Image', video:'Video', audio:'Audio', text:'Text', pipeline:'Pipeline', request:'Request'};
const STATUS_BADGE = {
running:'badge-admin', queued:'badge-user', done:'badge-ok', error:'badge-err',
cancelled:'badge-user', interrupted:'badge-warn'
};
function progressBar(t) {
const total = t.total || 0, step = t.step || 0;
if (!total) {
return t.status === 'running' ? '<span class="dim small">working…</span>' : '<span class="dim small">—</span>';
}
const pct = Math.max(0, Math.min(100, Math.round(step / total * 100)));
return `<div class="progress"><div class="progress-fill" style="width:${pct}%"></div></div>
<span class="dim small">${step}/${total} (${pct}%)</span>`;
}
function actions(t) {
const btns = [];
if (t.paused) {
btns.push(`<button class="btn btn-primary btn-sm" onclick="taskAction('${esc(t.id)}','resume')">Resume</button>`);
} else if (t.pausable) {
btns.push(`<button class="btn btn-ghost btn-sm" onclick="taskAction('${esc(t.id)}','pause')">Pause</button>`);
}
if (t.cancellable) {
const label = t.status === 'running' ? 'Interrupt' : 'Cancel';
const act = t.status === 'running' ? 'interrupt' : 'cancel';
btns.push(`<button class="btn btn-danger btn-sm" onclick="taskAction('${esc(t.id)}','${act}')">${label}</button>`);
}
if (t.restartable) {
btns.push(`<button class="btn btn-ghost btn-sm" onclick="taskAction('${esc(t.id)}','restart')">Restart</button>`);
}
if (!t.active) {
btns.push(`<button class="btn btn-ghost btn-sm" onclick="removeTask('${esc(t.id)}')">Remove</button>`);
}
return btns.join(' ') || '<span class="dim small">—</span>';
}
// ---- Live hardware telemetry ----
function _utilClass(pct){ return pct == null ? 'sys-ok' : (pct >= 90 ? 'sys-hot' : pct >= 70 ? 'sys-warn' : 'sys-ok'); }
function _tempClass(t){ return t == null ? '' : (t >= 90 ? 'sys-temp-hot' : t >= 80 ? 'sys-temp-warn' : 'sys-temp-ok'); }
function _bar(pct){
const p = pct == null ? 0 : Math.max(0, Math.min(100, pct));
return `<div class="sys-bar ${_utilClass(pct)}"><span style="width:${p}%"></span></div>`;
}
function _utilTile(name, pct, temp){
const valTxt = pct == null ? 'n/a' : `${Math.round(pct)}%`;
const tempTxt = temp == null ? '<span class="dim">temp n/a</span>'
: `<span class="${_tempClass(temp)}">${Math.round(temp)}°C</span>`;
return `<div class="sys-head"><span class="sys-name">${name}</span><span class="sys-val">${valTxt}</span></div>`
+ _bar(pct) + `<div class="sys-sub"><span>utilization</span>${tempTxt}</div>`;
}
function _memTile(name, used, total, pct){
const valTxt = (used == null || total == null) ? 'n/a' : `${used.toFixed(1)} / ${total.toFixed(1)} GB`;
const p = pct != null ? pct : (used != null && total ? used/total*100 : null);
return `<div class="sys-head"><span class="sys-name">${name}</span><span class="sys-val">${valTxt}</span></div>`
+ _bar(p) + `<div class="sys-sub"><span>${p == null ? '' : Math.round(p)+'% used'}</span><span></span></div>`;
}
async function loadSystemStats(){
try {
const s = await fetch(ROOT_PATH + '/admin/api/system-stats').then(r => r.json());
const cpu = s.cpu || {}, gpu = s.gpu || {}, ram = s.ram || {}, vram = s.vram || {};
document.getElementById('tile-cpu').innerHTML = _utilTile('CPU', cpu.util, cpu.temp);
document.getElementById('tile-gpu').innerHTML = _utilTile('GPU', gpu.util, gpu.temp);
document.getElementById('tile-ram').innerHTML = _memTile('RAM', ram.used, ram.total, ram.percent);
document.getElementById('tile-vram').innerHTML =
_memTile('VRAM', vram.used, vram.total, vram.percent);
} catch(e){ /* keep last render on transient errors */ }
}
let _refreshing = false;
async function loadTasks() {
if (_refreshing) return;
_refreshing = true;
try {
const data = await fetch(ROOT_PATH + '/admin/api/tasks').then(r => r.json());
const tasks = data.tasks || [];
const q = data.queue || {};
document.getElementById('queue-summary').textContent =
`${q.active || 0} active · ${q.waiting || 0} waiting · max ${q.max_parallel_requests || 0} parallel`;
const therm = data.thermal || {};
const banner = document.getElementById('thermal-banner');
if (therm.active) {
document.getElementById('thermal-banner-msg').textContent = ' ' + (therm.message || '');
banner.style.display = '';
} else {
banner.style.display = 'none';
}
const tbody = document.getElementById('tasks-body');
if (!tasks.length) {
tbody.innerHTML = '<tr class="empty-row"><td colspan="6">No tasks yet</td></tr>';
return;
}
tbody.innerHTML = tasks.map(t => {
const badge = STATUS_BADGE[t.status] || 'badge-dim';
const title = t.title || '(untitled)';
let statusCell;
if (t.cooling) {
statusCell = `<span class="badge badge-warn">❄ Cooling down</span>`
+ `<div class="dim small">${esc(t.cooling_message || 'paused for thermal cooldown')}</div>`;
} else if (t.paused) {
statusCell = `<span class="badge badge-warn">⏸ Paused</span>`
+ `<div class="dim small">suspended — click Resume to continue</div>`;
} else {
statusCell = `<span class="badge ${badge}">${esc(t.status)}</span>`
+ (t.message ? `<div class="dim small">${esc(t.message)}</div>` : '');
}
return `<tr>
<td><span class="badge badge-user">${esc(KIND_LABEL[t.kind] || t.kind)}</span></td>
<td><div class="td-name">${esc(title)}</div><div class="dim small mono">${esc(t.model || '')}</div></td>
<td>${statusCell}</td>
<td>${progressBar(t)}</td>
<td class="dim small">${fmtTime(t.started_at)}</td>
<td style="text-align:right">${actions(t)}</td>
</tr>`;
}).join('');
} catch (e) {
// transient fetch errors during a model swap are fine; keep last render.
} finally {
_refreshing = false;
}
}
async function taskAction(id, action) {
const verb = {cancel:'Cancel', interrupt:'Interrupt', restart:'Restart', pause:'Pause', resume:'Resume'}[action] || action;
// Only confirm destructive actions; pause/resume/restart act immediately.
if ((action === 'cancel' || action === 'interrupt') && !confirm(`${verb} this task?`)) return;
try {
const r = await fetch(ROOT_PATH + '/admin/api/tasks/' + encodeURIComponent(id) + '/' + action, {method:'POST'});
if (!r.ok) {
const e = await r.json().catch(() => ({}));
alert(e.detail || (verb + ' failed'));
}
} catch (e) { alert(e.message); }
loadTasks();
}
async function removeTask(id) {
try {
const r = await fetch(ROOT_PATH + '/admin/api/tasks/' + encodeURIComponent(id), {method:'DELETE'});
if (!r.ok) {
const e = await r.json().catch(() => ({}));
alert(e.detail || 'Remove failed');
}
} catch (e) { alert(e.message); }
loadTasks();
}
loadTasks();
loadSystemStats();
setInterval(loadTasks, 2000);
setInterval(loadSystemStats, 2000);
</script>
{% endblock %}
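The progress and colour-threshold rules used by the Tasks page script can be restated compactly, for instance for a command-line client polling the same endpoints. The Python below is a sketch that reproduces the page's arithmetic; the function names are hypothetical and nothing here is part of the commit.

```python
import math
from typing import Optional


def progress_percent(step: Optional[float], total: Optional[float]) -> Optional[int]:
    """Clamp step/total to 0..100, or None when the total is unknown."""
    if not total:
        return None
    # floor(x + 0.5) rounds halves up, matching JavaScript's Math.round.
    pct = math.floor((step or 0) / total * 100 + 0.5)
    return max(0, min(100, pct))


def util_class(pct: Optional[float]) -> str:
    """Utilization bar colour: hot at 90% and above, warn at 70% and above."""
    if pct is None:
        return "ok"
    return "hot" if pct >= 90 else "warn" if pct >= 70 else "ok"


def temp_class(temp_c: Optional[float]) -> str:
    """Temperature colour: hot at 90 C and above, warn at 80 C and above."""
    if temp_c is None:
        return ""
    return "hot" if temp_c >= 90 else "warn" if temp_c >= 80 else "ok"
```

Note that utilization and temperature use different warning thresholds (70 versus 80), so the two cannot share a single classifier.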
...@@ -189,25 +189,25 @@ if admin_static_dir.exists(): ...@@ -189,25 +189,25 @@ if admin_static_dir.exists():
app.mount("/static/admin", StaticFiles(directory=str(admin_static_dir)), name="admin_static") app.mount("/static/admin", StaticFiles(directory=str(admin_static_dir)), name="admin_static")
# Include routers from submodules # Include routers from submodules
app.include_router(transcriptions_router) app.include_router(transcriptions_router, tags=["Audio"])
app.include_router(images_router) app.include_router(images_router, tags=["Images"])
app.include_router(tts_router) app.include_router(tts_router, tags=["Audio"])
app.include_router(text_router) app.include_router(text_router, tags=["Text"])
app.include_router(video_router) app.include_router(video_router, tags=["Video"])
app.include_router(audio_gen_router) app.include_router(audio_gen_router, tags=["Audio"])
app.include_router(audio_stems_router) app.include_router(audio_stems_router, tags=["Audio"])
app.include_router(audio_clean_router) app.include_router(audio_clean_router, tags=["Audio"])
app.include_router(embeddings_router) app.include_router(embeddings_router, tags=["Embeddings"])
app.include_router(pipelines_router) app.include_router(pipelines_router, tags=["Pipelines"])
app.include_router(custom_pipelines_router) app.include_router(custom_pipelines_router, tags=["Pipelines"])
app.include_router(voice_clone_router) app.include_router(voice_clone_router, tags=["Audio"])
app.include_router(voice_convert_router) app.include_router(voice_convert_router, tags=["Audio"])
app.include_router(faceswap_router) app.include_router(faceswap_router, tags=["Images"])
app.include_router(characters_router) app.include_router(characters_router, tags=["Characters"])
app.include_router(loras_router) app.include_router(loras_router, tags=["LoRAs"])
app.include_router(environments_router) app.include_router(environments_router, tags=["Environments"])
app.include_router(spatial_router) app.include_router(spatial_router, tags=["Spatial / 3D"])
app.include_router(admin_router) app.include_router(admin_router, tags=["Admin"])
@app.exception_handler(401) @app.exception_handler(401)
...@@ -222,20 +222,35 @@ async def unauthorized_redirect(request: Request, exc: HTTPException): ...@@ -222,20 +222,35 @@ async def unauthorized_redirect(request: Request, exc: HTTPException):
return JSONResponse(status_code=401, content={"detail": exc.detail}) return JSONResponse(status_code=401, content={"detail": exc.detail})
@app.get("/v1/models", response_model=ModelList) from codai.tasks import TaskCancelled, task_registry
@app.exception_handler(TaskCancelled)
async def task_cancelled_handler(request: Request, exc: TaskCancelled):
"""A worker observed its task was cancelled and unwound. Finish the task
(cancelled) and return 499 (client-closed-request style). The task id is
carried on the exception so any generation/training worker can simply
`raise` without bookkeeping."""
tid = exc.args[0] if exc.args else None
if tid:
task_registry.finish(tid, "cancelled", "cancelled by user")
return JSONResponse(status_code=499, content={"detail": "Task cancelled", "task_id": tid})
@app.get("/v1/models", response_model=ModelList, summary="List available models", tags=["Core"])
async def list_models(): async def list_models():
"""List available models.""" """List available models."""
models = multi_model_manager.list_models() models = multi_model_manager.list_models()
return ModelList(data=models) return ModelList(data=models)
@app.get("/coderai/capabilities") @app.get("/coderai/capabilities", summary="Server capability document", tags=["Core"])
async def get_broker_capabilities(): async def get_broker_capabilities():
"""Return broker capability metadata.""" """Return broker capability metadata."""
return build_capabilities_document(hardware=build_hardware_summary()) return build_capabilities_document(hardware=build_hardware_summary())
@app.get("/v1/files/{filename}") @app.get("/v1/files/{filename}", summary="Download a generated file", tags=["Files"])
async def get_file(filename: str): async def get_file(filename: str):
"""Serve uploaded/generated files.""" """Serve uploaded/generated files."""
if not global_file_path: if not global_file_path:
...@@ -256,7 +271,7 @@ _VIDEO_EXTS = {'.mp4', '.webm', '.avi', '.mov'} ...@@ -256,7 +271,7 @@ _VIDEO_EXTS = {'.mp4', '.webm', '.avi', '.mov'}
_AUDIO_EXTS = {'.wav', '.mp3', '.ogg', '.flac', '.aac', '.m4a'} _AUDIO_EXTS = {'.wav', '.mp3', '.ogg', '.flac', '.aac', '.m4a'}
@app.get("/v1/archive") @app.get("/v1/archive", summary="List archived generations", tags=["Files"])
async def list_archive(request: Request): async def list_archive(request: Request):
"""List all generated files in the output directory.""" """List all generated files in the output directory."""
if not global_file_path or not os.path.isdir(global_file_path): if not global_file_path or not os.path.isdir(global_file_path):
...@@ -292,7 +307,7 @@ async def list_archive(request: Request): ...@@ -292,7 +307,7 @@ async def list_archive(request: Request):
return {"files": files} return {"files": files}
@app.delete("/v1/archive/{filename}") @app.delete("/v1/archive/{filename}", summary="Delete an archived file", tags=["Files"])
async def delete_archive_file(filename: str): async def delete_archive_file(filename: str):
"""Delete a generated file from the output directory.""" """Delete a generated file from the output directory."""
if not global_file_path: if not global_file_path:
......
...@@ -116,8 +116,15 @@ class AudioCleanupRequest(BaseModel): ...@@ -116,8 +116,15 @@ class AudioCleanupRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/cleanup") @router.post("/v1/audio/cleanup", summary="Clean / restore audio")
async def cleanup_audio(request: AudioCleanupRequest, http_request: Request = None): async def cleanup_audio(request: AudioCleanupRequest, http_request: Request = None):
"""Restore/clean a noisy audio clip.
Applies any combination of noise reduction, loudness normalization, mains-hum
removal and click/crackle repair. Uses an ML restoration backend when available,
falling back to an ffmpeg-based best-effort path when `fallback_mode` is set.
Returns the cleaned audio plus the backend and quality tier that were used.
"""
try: try:
audio_bytes = _decode_audio(request.audio) audio_bytes = _decode_audio(request.audio)
except Exception as exc: except Exception as exc:
......
...@@ -31,6 +31,7 @@ from fastapi import APIRouter, HTTPException, Request ...@@ -31,6 +31,7 @@ from fastapi import APIRouter, HTTPException, Request
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
from codai.pydantic.audiogenrequest import AudioGenerationRequest, AudioGenerationResponse from codai.pydantic.audiogenrequest import AudioGenerationRequest, AudioGenerationResponse
from codai.tasks import task_registry, TaskCancelled
router = APIRouter() router = APIRouter()
...@@ -160,7 +161,7 @@ def _detect_audio_gen_type(model_name: str) -> str: ...@@ -160,7 +161,7 @@ def _detect_audio_gen_type(model_name: str) -> str:
return 'musicgen' return 'musicgen'
def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest): def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest, task_id=None):
"""Run generation and return (audio_bytes, ext).""" """Run generation and return (audio_bytes, ext)."""
import numpy as np, io as _io import numpy as np, io as _io
...@@ -191,6 +192,9 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest): ...@@ -191,6 +192,9 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
_aud_progress_reset(num_steps, unit="it") _aud_progress_reset(num_steps, unit="it")
def _aud_step_cb(pipe, step_index, timestep, callback_kwargs): def _aud_step_cb(pipe, step_index, timestep, callback_kwargs):
task_registry.raise_if_cancelled(task_id)
task_registry.wait_if_paused(task_id)
task_registry.step(task_id, step_index + 1)
_aud_progress_step(step_index + 1) _aud_progress_step(step_index + 1)
return callback_kwargs return callback_kwargs
...@@ -222,7 +226,7 @@ def _decode_b64_or_url(data: str) -> bytes: ...@@ -222,7 +226,7 @@ def _decode_b64_or_url(data: str) -> bytes:
return base64.b64decode(data) return base64.b64decode(data)
@router.get("/v1/audio/progress") @router.get("/v1/audio/progress", summary="Audio generation progress")
async def get_audio_progress(): async def get_audio_progress():
"""Return current audio generation progress including speed.""" """Return current audio generation progress including speed."""
elapsed = time.monotonic() - _aud_progress["started_at"] if _aud_progress["active"] else 0.0 elapsed = time.monotonic() - _aud_progress["started_at"] if _aud_progress["active"] else 0.0
...@@ -241,7 +245,7 @@ async def get_audio_progress(): ...@@ -241,7 +245,7 @@ async def get_audio_progress():
} }
@router.post("/v1/audio/generate", response_model=AudioGenerationResponse) @router.post("/v1/audio/generate", response_model=AudioGenerationResponse, summary="Generate audio, music or SFX")
async def audio_generate(request: AudioGenerationRequest, http_request: Request = None): async def audio_generate(request: AudioGenerationRequest, http_request: Request = None):
""" """
Generate music, sound effects, or ambient audio. Generate music, sound effects, or ambient audio.
...@@ -274,14 +278,22 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request ...@@ -274,14 +278,22 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
multi_model_manager.models[model_key] = pipe multi_model_manager.models[model_key] = pipe
multi_model_manager.current_model_key = model_key multi_model_manager.current_model_key = model_key
_tid = task_registry.register(
"audio", title=(request.prompt or "")[:80], model=model_name or "")
task_registry.start(_tid)
try: try:
audio_bytes, ext = await asyncio.get_event_loop().run_in_executor( audio_bytes, ext = await asyncio.get_event_loop().run_in_executor(
None, _generate_audio, pipe, model_name, request) None, _generate_audio, pipe, model_name, request, _tid)
except TaskCancelled:
_aud_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except Exception as e: except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_aud_progress_done() _aud_progress_done()
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
finally: finally:
_aud_progress_done() _aud_progress_done()
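# Note: `finally` also runs on the error and cancel paths, so this call
# relies on finish() never letting "done" replace "error" or block "cancelled".
task_registry.finish(_tid, "done")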
result = _save_audio_response(audio_bytes, ext, http_request) result = _save_audio_response(audio_bytes, ext, http_request)
......
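The step callback above calls three registry hooks on every step: `raise_if_cancelled`, `wait_if_paused`, and `step`. A cooperative registry offering those hooks could look roughly like the sketch below. It is an illustration only and not the real `codai.tasks` module: `MiniTaskRegistry`, `snapshot`, `run_steps`, and the choice that the first terminal status wins in `finish` are all assumptions.

```python
import threading
import uuid

TERMINAL = ("done", "error", "cancelled")


class TaskCancelled(Exception):
    """Raised inside a worker; args[0] carries the task id."""


class MiniTaskRegistry:
    """Illustrative stand-in for a task registry (hypothetical)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._tasks = {}

    def register(self, kind, title="", model="", total=0):
        tid = uuid.uuid4().hex
        running = threading.Event()
        running.set()  # set means "not paused"
        with self._lock:
            self._tasks[tid] = {"kind": kind, "title": title, "model": model,
                                "status": "queued", "step": 0, "total": total,
                                "cancel": threading.Event(), "running": running}
        return tid

    def _get(self, tid):
        # Unknown ids (including None) return None, so every hook is a no-op.
        with self._lock:
            return self._tasks.get(tid)

    def start(self, tid):
        task = self._get(tid)
        if task:
            task["status"] = "running"

    def pause(self, tid):
        task = self._get(tid)
        if task:
            task["running"].clear()

    def resume(self, tid):
        task = self._get(tid)
        if task:
            task["running"].set()

    def cancel(self, tid):
        task = self._get(tid)
        if task:
            task["cancel"].set()
            task["running"].set()  # wake a paused worker so it sees the cancel

    def raise_if_cancelled(self, tid):
        task = self._get(tid)
        if task and task["cancel"].is_set():
            raise TaskCancelled(tid)

    def wait_if_paused(self, tid):
        task = self._get(tid)
        if task:
            task["running"].wait()  # blocks only while paused
            if task["cancel"].is_set():
                raise TaskCancelled(tid)

    def step(self, tid, n):
        task = self._get(tid)
        if task:
            task["step"] = n

    def finish(self, tid, status, message=""):
        task = self._get(tid)
        # Assumption: the first terminal status wins; later calls are ignored.
        if task and task["status"] not in TERMINAL:
            task["status"] = status
            task["message"] = message

    def snapshot(self, tid):
        task = self._get(tid)
        return (task["status"], task["step"]) if task else None


def run_steps(registry, tid, num_steps, cancel_at=None):
    """Stand-in for a generation loop calling the three hooks each step."""
    for i in range(num_steps):
        if cancel_at is not None and i == cancel_at:
            registry.cancel(tid)  # simulate a user pressing Interrupt
        registry.raise_if_cancelled(tid)
        registry.wait_if_paused(tid)
        registry.step(tid, i + 1)
```

Because cancellation is only observed at step boundaries, a worker stops at the next callback rather than mid-step. That is the trade-off of a cooperative design: no thread is killed, but a long single step delays the interrupt.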
...@@ -166,8 +166,15 @@ class AudioStemRequest(BaseModel): ...@@ -166,8 +166,15 @@ class AudioStemRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/stems") @router.post("/v1/audio/stems", summary="Separate audio into stems")
async def separate_stems(request: AudioStemRequest, http_request: Request = None): async def separate_stems(request: AudioStemRequest, http_request: Request = None):
"""Split a track into its component stems (source separation).
Separates an input clip according to `stem_mode` (e.g. vocals/instrumental, or a
full 4-stem split). Uses an ML separation provider when available, falling back to
an ffmpeg-based best-effort split when `fallback_mode` is set. Returns one audio
output per stem along with the backend and quality tier used.
"""
try: try:
audio_bytes = _decode_audio(request.audio) audio_bytes = _decode_audio(request.audio)
except Exception as exc: except Exception as exc:
......
...@@ -419,7 +419,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]: ...@@ -419,7 +419,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ───────────────────────────────────────────────────────────────── # ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/characters") @router.post("/v1/characters", summary="Create or replace a character profile")
async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)): async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named character profile.""" """Save or update a named character profile."""
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
...@@ -430,13 +430,13 @@ async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_a ...@@ -430,13 +430,13 @@ async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_a
return {"ok": True, "name": meta['name'], "image_count": meta['image_count']} return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/characters") @router.get("/v1/characters", summary="List character profiles")
async def list_characters(_auth=Depends(_require_api_auth)): async def list_characters(_auth=Depends(_require_api_auth)):
"""List all saved character profiles (metadata only, no images).""" """List all saved character profiles (metadata only, no images)."""
return {"characters": _list_characters()} return {"characters": _list_characters()}
@router.get("/v1/characters/{name}") @router.get("/v1/characters/{name}", summary="Get a character profile")
async def get_character(name: str, _auth=Depends(_require_api_auth)): async def get_character(name: str, _auth=Depends(_require_api_auth)):
"""Get a character profile including its reference images as base64.""" """Get a character profile including its reference images as base64."""
meta = _load_character_meta(name) meta = _load_character_meta(name)
...@@ -452,7 +452,7 @@ async def get_character(name: str, _auth=Depends(_require_api_auth)): ...@@ -452,7 +452,7 @@ async def get_character(name: str, _auth=Depends(_require_api_auth)):
} }
@router.delete("/v1/characters/{name}") @router.delete("/v1/characters/{name}", summary="Delete a character profile")
async def delete_character(name: str, _auth=Depends(_require_api_auth)): async def delete_character(name: str, _auth=Depends(_require_api_auth)):
"""Delete a character profile.""" """Delete a character profile."""
cdir = _char_dir(name) cdir = _char_dir(name)
...@@ -463,7 +463,7 @@ async def delete_character(name: str, _auth=Depends(_require_api_auth)): ...@@ -463,7 +463,7 @@ async def delete_character(name: str, _auth=Depends(_require_api_auth)):
return {"ok": True, "name": name} return {"ok": True, "name": name}
@router.patch("/v1/characters/{name}") @router.patch("/v1/characters/{name}", summary="Update a character profile")
async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)): async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)):
"""Update a character profile: description, add images, or remove images by index.""" """Update a character profile: description, add images, or remove images by index."""
meta = _load_character_meta(name) meta = _load_character_meta(name)
...@@ -512,7 +512,7 @@ async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_ ...@@ -512,7 +512,7 @@ async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_
return {"ok": True, "name": name, "image_count": meta['image_count']} return {"ok": True, "name": name, "image_count": meta['image_count']}
@router.post("/v1/characters/generate") @router.post("/v1/characters/generate", summary="Generate character reference images")
async def generate_character(req: CharacterGenerateRequest, request: Request): async def generate_character(req: CharacterGenerateRequest, request: Request):
""" """
Generate a character profile from a text prompt. Generate a character profile from a text prompt.
...@@ -585,7 +585,7 @@ async def generate_character(req: CharacterGenerateRequest, request: Request): ...@@ -585,7 +585,7 @@ async def generate_character(req: CharacterGenerateRequest, request: Request):
return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]} return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]}
@router.post("/v1/characters/extract") @router.post("/v1/characters/extract", summary="Extract a character from media")
async def extract_character(req: CharacterExtractRequest): async def extract_character(req: CharacterExtractRequest):
""" """
Extract a character profile from source images and/or videos. Extract a character profile from source images and/or videos.
......
...@@ -380,7 +380,7 @@ class AudioMusicDubRequest(BaseModel): ...@@ -380,7 +380,7 @@ class AudioMusicDubRequest(BaseModel):
model_config = ConfigDict(extra='allow') model_config = ConfigDict(extra='allow')
@router.get('/v1/pipelines/custom') @router.get('/v1/pipelines/custom', summary="List saved custom pipelines")
async def list_custom_pipelines(): async def list_custom_pipelines():
"""List all saved custom pipeline definitions.""" """List all saved custom pipeline definitions."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -389,7 +389,7 @@ async def list_custom_pipelines(): ...@@ -389,7 +389,7 @@ async def list_custom_pipelines():
return {'pipelines': config_manager.pipelines_data} return {'pipelines': config_manager.pipelines_data}
@router.get('/v1/pipelines/step-types') @router.get('/v1/pipelines/step-types', summary="List available pipeline step types")
async def list_step_types(): async def list_step_types():
"""List available step types with their parameter schemas.""" """List available step types with their parameter schemas."""
return { return {
...@@ -400,7 +400,7 @@ async def list_step_types(): ...@@ -400,7 +400,7 @@ async def list_step_types():
} }
@router.post('/v1/pipelines/custom') @router.post('/v1/pipelines/custom', summary="Create a custom pipeline")
async def create_custom_pipeline(pipeline: PipelineDefinition): async def create_custom_pipeline(pipeline: PipelineDefinition):
"""Save a new custom pipeline definition.""" """Save a new custom pipeline definition."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -416,7 +416,7 @@ async def create_custom_pipeline(pipeline: PipelineDefinition): ...@@ -416,7 +416,7 @@ async def create_custom_pipeline(pipeline: PipelineDefinition):
return {'created': True, 'pipeline': data} return {'created': True, 'pipeline': data}
@router.put('/v1/pipelines/custom/{pipeline_id}') @router.put('/v1/pipelines/custom/{pipeline_id}', summary="Update a custom pipeline")
async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition): async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition):
"""Update an existing custom pipeline.""" """Update an existing custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -433,7 +433,7 @@ async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition) ...@@ -433,7 +433,7 @@ async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition)
return {'updated': True, 'pipeline': data} return {'updated': True, 'pipeline': data}
@router.delete('/v1/pipelines/custom/{pipeline_id}') @router.delete('/v1/pipelines/custom/{pipeline_id}', summary="Delete a custom pipeline")
async def delete_custom_pipeline(pipeline_id: str): async def delete_custom_pipeline(pipeline_id: str):
"""Delete a custom pipeline.""" """Delete a custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -447,7 +447,7 @@ async def delete_custom_pipeline(pipeline_id: str): ...@@ -447,7 +447,7 @@ async def delete_custom_pipeline(pipeline_id: str):
return {'deleted': True, 'id': pipeline_id} return {'deleted': True, 'id': pipeline_id}
@router.post('/v1/pipelines/custom/{pipeline_id}/run') @router.post('/v1/pipelines/custom/{pipeline_id}/run', summary="Run a saved custom pipeline")
async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_request: Request = None): async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_request: Request = None):
"""Execute a saved custom pipeline.""" """Execute a saved custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -459,14 +459,20 @@ async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_r ...@@ -459,14 +459,20 @@ async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_r
return await _execute_pipeline(pipeline_def, body.input or '', http_request) return await _execute_pipeline(pipeline_def, body.input or '', http_request)
@router.post('/v1/pipelines/run') @router.post('/v1/pipelines/run', summary="Run an inline pipeline definition")
async def run_inline_pipeline(pipeline: PipelineDefinition, http_request: Request = None): async def run_inline_pipeline(pipeline: PipelineDefinition, http_request: Request = None):
"""Execute an inline pipeline definition without saving it.""" """Execute an inline pipeline definition without saving it."""
return await _execute_pipeline(pipeline.model_dump(), '', http_request) return await _execute_pipeline(pipeline.model_dump(), '', http_request)
@router.post('/v1/pipelines/audio-understand') @router.post('/v1/pipelines/audio-understand', summary="Transcribe and analyze audio")
async def run_audio_understanding(request: AudioUnderstandRequest, http_request: Request = None): async def run_audio_understanding(request: AudioUnderstandRequest, http_request: Request = None):
"""Transcribe and analyze an audio clip in one pass.
Convenience pipeline that transcribes the input audio and then reasons over the
transcript (summary/understanding) using the configured text model. Returns the
transcript together with the model's analysis.
"""
if not request.audio: if not request.audio:
raise HTTPException(status_code=400, detail='Provide audio input') raise HTTPException(status_code=400, detail='Provide audio input')
...@@ -543,8 +549,14 @@ async def run_full_music_dub(request: AudioMusicDubRequest, http_request: Reques ...@@ -543,8 +549,14 @@ async def run_full_music_dub(request: AudioMusicDubRequest, http_request: Reques
} }
@router.post('/v1/pipelines/audio-music-dub') @router.post('/v1/pipelines/audio-music-dub', summary="Dub a song into another language")
async def run_audio_music_dub(request: AudioMusicDubRequest, http_request: Request = None): async def run_audio_music_dub(request: AudioMusicDubRequest, http_request: Request = None):
"""Dub a song into another language while preserving the backing music.
Splits the track into vocals and instrumental, transcribes and translates the
lyrics, re-sings/voice-converts the translated vocals, then remixes them over the
original instrumental. Returns every intermediate stem plus the final mixed result.
"""
if not request.audio: if not request.audio:
raise HTTPException(status_code=400, detail='Provide audio input') raise HTTPException(status_code=400, detail='Provide audio input')
......
...@@ -101,7 +101,7 @@ def _embed_texts(model_obj, texts: List[str], dimensions=None) -> List[List[floa ...@@ -101,7 +101,7 @@ def _embed_texts(model_obj, texts: List[str], dimensions=None) -> List[List[floa
return results return results
@router.post("/v1/embeddings", response_model=EmbeddingsResponse) @router.post("/v1/embeddings", response_model=EmbeddingsResponse, summary="Create embeddings")
async def create_embeddings(request: EmbeddingsRequest, http_request: Request = None): async def create_embeddings(request: EmbeddingsRequest, http_request: Request = None):
""" """
OpenAI-compatible embeddings endpoint. OpenAI-compatible embeddings endpoint.
...@@ -116,10 +116,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -116,10 +116,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
model_key = model_info['model_key'] model_key = model_info['model_key']
model_obj = model_info.get('model_object') model_obj = model_info.get('model_object')
if model_obj is None:
device = _derive_device()
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}") _emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}")
or multi_model_manager.config.get(model_name) or {}) or multi_model_manager.config.get(model_name) or {})
if model_obj is None:
device = _derive_device()
try: try:
model_obj = await asyncio.get_event_loop().run_in_executor( model_obj = await asyncio.get_event_loop().run_in_executor(
None, _load_embedding_model, model_name, device, _emb_cfg) None, _load_embedding_model, model_name, device, _emb_cfg)
...@@ -136,7 +137,59 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -136,7 +137,59 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Embedding failed: {e}") raise HTTPException(status_code=500, detail=f"Embedding failed: {e}")
if request.encoding_format == 'base64': # Optional TurboQuant vector quantization (data-free, inner-product preserving).
# The per-model config block (turboquant: {enabled, backend, bits}) is the
# source of truth for enable/disable + which implementation to use; the
# per-request `quantization` field triggers it and can override the bit width.
from codai.models import turboquant as _tq
_raw = _emb_cfg.get('_raw_cfg') if isinstance(_emb_cfg.get('_raw_cfg'), dict) else {}
tq_cfg = _emb_cfg.get('turboquant') or _raw.get('turboquant') or {}
tq_enabled = tq_cfg.get('enabled', None) # None = no explicit model setting
tq_backend = (tq_cfg.get('backend') or 'builtin')
quant_meta = None
quant_bits = None
req_spec = getattr(request, 'quantization', None)
if not req_spec and tq_enabled and tq_cfg.get('bits'):
req_spec = f"turbo{tq_cfg.get('bits')}" # model-configured default
if req_spec:
if tq_enabled is False:
raise HTTPException(
status_code=400,
detail="TurboQuant is disabled for this model (enable it in the "
"model configuration).")
quant_bits = _tq._parse_quant_spec(req_spec)
if quant_bits is None:
raise HTTPException(
status_code=400,
detail=f"Unsupported quantization '{req_spec}' "
"(use 'turbo'/'turbo8'/'turbo6'/'turbo4'/'turbo2')")
if quant_bits is not None and request.encoding_format == 'base64':
# Compact wire form: each embedding is base64 of [f16 norm][packed codes].
# The compact packing is the built-in wire format regardless of backend
# (the upstream library exposes its own opaque store, not per-vector blobs).
blobs, meta = await asyncio.get_event_loop().run_in_executor(
None, _tq.quantize_base64, vectors, quant_bits)
data = [EmbeddingObject(index=i, embedding=b) for i, b in enumerate(blobs)]
quant_meta = {
"method": meta.method, "bits": meta.bits, "seed": meta.seed,
"dim": meta.dim, "dim_padded": meta.dim_padded, "radius": meta.radius,
"bytes_per_vector": meta.bytes_per_vector, "backend": "builtin",
"layout": "base64([float16 norm][packbits(rotated b-bit codes, MSB-first per numpy.packbits)])",
}
elif quant_bits is not None:
# Lossy reconstruction returned as plain floats (quantized-store fidelity).
try:
vectors = await asyncio.get_event_loop().run_in_executor(
None, lambda: _tq.reconstruct(vectors, quant_bits, backend=tq_backend))
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]
eff_backend = tq_backend if tq_backend != 'auto' else _tq.backend_name()
quant_meta = {"method": "turboquant", "bits": quant_bits,
"encoding": "float-reconstruction", "backend": eff_backend}
elif request.encoding_format == 'base64':
import struct import struct
data = [EmbeddingObject( data = [EmbeddingObject(
index=i, index=i,
...@@ -146,8 +199,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -146,8 +199,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)] data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]
total_tokens = sum(len(t.split()) for t in texts) total_tokens = sum(len(t.split()) for t in texts)
return EmbeddingsResponse( resp = EmbeddingsResponse(
data=data, data=data,
model=request.model, model=request.model,
usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens}, usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
) )
if quant_meta is not None:
resp.quantization = quant_meta
return resp
\ No newline at end of file
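Review note on the compact wire form above: the `layout` string describes each embedding as base64 of a float16 norm followed by `numpy.packbits` output. The sketch below shows how a client might unpack one such blob back into its norm and integer codes. It is illustrative only. The float16 byte order (little-endian here) and the assumption that each code's bits are stored most-significant-first are not confirmed by the diff, and `pack_vector` / `unpack_vector` are hypothetical helper names, not part of `codai.models.turboquant`. Turning codes back into floats would additionally need the rotation and codebook implied by `seed` and `radius`, which this hunk does not show.

```python
import base64
import numpy as np


def pack_vector(norm: float, codes: np.ndarray, bits: int) -> str:
    """Hypothetical encoder that mirrors the documented layout (round-trip demo only)."""
    # Expand every code into `bits` binary digits, most significant bit first.
    shifts = np.arange(bits - 1, -1, -1)
    bit_matrix = ((codes[:, None] >> shifts) & 1).astype(np.uint8)
    packed = np.packbits(bit_matrix.ravel())  # zero-pads the final byte if needed
    header = np.array([norm], dtype="<f2").tobytes()  # assumption: little-endian float16
    return base64.b64encode(header + packed.tobytes()).decode("ascii")


def unpack_vector(blob: str, bits: int, dim_padded: int):
    """Split one blob into (norm, codes); `bits` and `dim_padded` come from the response metadata."""
    raw = base64.b64decode(blob)
    norm = float(np.frombuffer(raw[:2], dtype="<f2")[0])
    # Drop any padding bits that packbits appended to fill the last byte.
    flat = np.unpackbits(np.frombuffer(raw[2:], dtype=np.uint8))[: dim_padded * bits]
    weights = 1 << np.arange(bits - 1, -1, -1)
    codes = flat.reshape(dim_padded, bits) @ weights
    return norm, codes


blob = pack_vector(1.5, np.arange(8), bits=4)
norm, codes = unpack_vector(blob, bits=4, dim_padded=8)
```

With 4-bit codes and a padded dimension of 8, the payload is 2 header bytes plus 4 code bytes, which is the kind of figure `bytes_per_vector` in the response metadata would be expected to report.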
...@@ -307,7 +307,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]: ...@@ -307,7 +307,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ───────────────────────────────────────────────────────────────── # ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/environments") @router.post("/v1/environments", summary="Create or replace an environment profile")
async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)): async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named environment profile.""" """Save or update a named environment profile."""
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
...@@ -318,13 +318,13 @@ async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_a ...@@ -318,13 +318,13 @@ async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_a
return {"ok": True, "name": meta['name'], "image_count": meta['image_count']} return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/environments") @router.get("/v1/environments", summary="List environment profiles")
async def list_environments(_auth=Depends(_require_api_auth)): async def list_environments(_auth=Depends(_require_api_auth)):
"""List all saved environment profiles (metadata only).""" """List all saved environment profiles (metadata only)."""
return {"environments": _list_environments()} return {"environments": _list_environments()}
@router.get("/v1/environments/{name}") @router.get("/v1/environments/{name}", summary="Get an environment profile")
async def get_environment(name: str, _auth=Depends(_require_api_auth)): async def get_environment(name: str, _auth=Depends(_require_api_auth)):
"""Get an environment profile including its reference images as base64.""" """Get an environment profile including its reference images as base64."""
meta = _load_environment_meta(name) meta = _load_environment_meta(name)
...@@ -340,7 +340,7 @@ async def get_environment(name: str, _auth=Depends(_require_api_auth)): ...@@ -340,7 +340,7 @@ async def get_environment(name: str, _auth=Depends(_require_api_auth)):
} }
@router.delete("/v1/environments/{name}") @router.delete("/v1/environments/{name}", summary="Delete an environment profile")
async def delete_environment(name: str, _auth=Depends(_require_api_auth)): async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
"""Delete an environment profile.""" """Delete an environment profile."""
edir = _env_dir(name) edir = _env_dir(name)
...@@ -351,7 +351,7 @@ async def delete_environment(name: str, _auth=Depends(_require_api_auth)): ...@@ -351,7 +351,7 @@ async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
return {"ok": True, "name": name} return {"ok": True, "name": name}
@router.patch("/v1/environments/{name}") @router.patch("/v1/environments/{name}", summary="Update an environment profile")
async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)): async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)):
"""Update an environment profile: description, add images, or remove images by index.""" """Update an environment profile: description, add images, or remove images by index."""
meta = _load_environment_meta(name) meta = _load_environment_meta(name)
...@@ -398,7 +398,7 @@ async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depen ...@@ -398,7 +398,7 @@ async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depen
return {"ok": True, "name": name, "image_count": meta['image_count']} return {"ok": True, "name": name, "image_count": meta['image_count']}
@router.post("/v1/environments/generate") @router.post("/v1/environments/generate", summary="Generate environment reference images")
async def generate_environment(req: EnvironmentGenerateRequest, request: Request): async def generate_environment(req: EnvironmentGenerateRequest, request: Request):
""" """
Generate an environment profile from a text prompt. Generate an environment profile from a text prompt.
...@@ -471,7 +471,7 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request ...@@ -471,7 +471,7 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request
return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]} return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]}
@router.post("/v1/environments/extract") @router.post("/v1/environments/extract", summary="Extract an environment from media")
async def extract_environment(req: EnvironmentExtractRequest): async def extract_environment(req: EnvironmentExtractRequest):
""" """
Extract an environment profile from source images and/or videos. Extract an environment profile from source images and/or videos.
......
...@@ -144,7 +144,7 @@ class FaceSwapRequest(BaseModel): ...@@ -144,7 +144,7 @@ class FaceSwapRequest(BaseModel):
# Endpoint # Endpoint
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post('/v1/images/faceswap') @router.post('/v1/images/faceswap', summary="Swap faces between images")
async def faceswap(request: FaceSwapRequest, http_request: Request = None): async def faceswap(request: FaceSwapRequest, http_request: Request = None):
""" """
Swap the face from source_face into every face found in target. Swap the face from source_face into every face found in target.
......
...@@ -37,6 +37,7 @@ from pydantic import BaseModel, ConfigDict ...@@ -37,6 +37,7 @@ from pydantic import BaseModel, ConfigDict
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
from codai.pydantic.imagerequest import ImageGenerationRequest from codai.pydantic.imagerequest import ImageGenerationRequest
from codai.api.state import get_load_mode from codai.api.state import get_load_mode
from codai.tasks import task_registry, TaskCancelled
# ============================================================================= # =============================================================================
...@@ -756,6 +757,13 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request= ...@@ -756,6 +757,13 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
_progress_reset(num_steps) _progress_reset(num_steps)
# Register this generation as a cancellable task (live view + cooperative
# cancel via the step callback below).
_tid = task_registry.register(
"image", title=(request.prompt or "")[:80],
model=getattr(request, 'model', '') or '', total=num_steps)
task_registry.start(_tid)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Prompt embedding cache # Prompt embedding cache
# Try to encode the prompt once and reuse the embeddings. # Try to encode the prompt once and reuse the embeddings.
...@@ -830,6 +838,11 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request= ...@@ -830,6 +838,11 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
embed_kwargs = {} embed_kwargs = {}
def _step_cb(pipe, step_index, timestep, callback_kwargs): def _step_cb(pipe, step_index, timestep, callback_kwargs):
# Cooperative cancellation: abort at the next step boundary if cancelled.
task_registry.raise_if_cancelled(_tid)
# Cooperative pause: block here while the user has paused this task.
task_registry.wait_if_paused(_tid)
task_registry.step(_tid, step_index + 1)
_progress_step(step_index + 1) _progress_step(step_index + 1)
# Mid-generation thermal checkpoint: pause between denoise steps if too hot. # Mid-generation thermal checkpoint: pause between denoise steps if too hot.
try: try:
...@@ -912,10 +925,21 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request= ...@@ -912,10 +925,21 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
try: try:
result = await asyncio.to_thread(pipeline, **call_kwargs) result = await asyncio.to_thread(pipeline, **call_kwargs)
except TaskCancelled:
_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except TypeError: except TypeError:
# Older pipeline that doesn't support callback_on_step_end # Older pipeline that doesn't support callback_on_step_end
call_kwargs.pop('callback_on_step_end', None) call_kwargs.pop('callback_on_step_end', None)
try:
result = await asyncio.to_thread(pipeline, **call_kwargs) result = await asyncio.to_thread(pipeline, **call_kwargs)
except TaskCancelled:
_progress_done()
raise
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_progress_done()
raise
finally: finally:
_progress_done() _progress_done()
...@@ -967,6 +991,7 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request= ...@@ -967,6 +991,7 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
except Exception: except Exception:
pass pass
task_registry.finish(_tid, "done")
return { return {
"created": timestamp, "created": timestamp,
"data": images, "data": images,
...@@ -1014,7 +1039,17 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None ...@@ -1014,7 +1039,17 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
_progress_reset(steps) _progress_reset(steps)
# sd.cpp runs the whole diffusion inside one C call, so it can't be aborted
# mid-step (raising from its progress callback won't reliably unwind the C
# extension). We still register the task for visibility + step progress; a
# cancel takes effect when control returns to Python.
_tid = task_registry.register(
"image", title=(request.prompt or "")[:80],
model=getattr(request, 'model', '') or '', total=steps)
task_registry.start(_tid)
def _sdcpp_progress(step: int, total: int, elapsed: float): def _sdcpp_progress(step: int, total: int, elapsed: float):
task_registry.step(_tid, step)
_progress_step(step) _progress_step(step)
# Use request seed if provided, otherwise use CLI default seed # Use request seed if provided, otherwise use CLI default seed
...@@ -1045,6 +1080,10 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None ...@@ -1045,6 +1080,10 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
seed=seed if seed is not None else 42, seed=seed if seed is not None else 42,
batch_count=request.n if request.n else 1, batch_count=request.n if request.n else 1,
) )
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_progress_done()
raise
finally: finally:
_progress_done() _progress_done()
...@@ -1087,6 +1126,7 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None ...@@ -1087,6 +1126,7 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
except Exception: except Exception:
pass pass
task_registry.finish(_tid, "done")
return { return {
"created": int(time.time()), "created": int(time.time()),
"data": images "data": images
...@@ -1185,7 +1225,7 @@ def _load_sdcpp_model(model_path: str, global_args, model_config: dict = None): ...@@ -1185,7 +1225,7 @@ def _load_sdcpp_model(model_path: str, global_args, model_config: dict = None):
router = APIRouter() router = APIRouter()
@router.get("/v1/images/progress") @router.get("/v1/images/progress", summary="Image generation progress")
async def get_image_progress(): async def get_image_progress():
"""Return current image generation step progress including speed.""" """Return current image generation step progress including speed."""
elapsed = _time.monotonic() - _gen_progress["started_at"] if _gen_progress["active"] else 0.0 elapsed = _time.monotonic() - _gen_progress["started_at"] if _gen_progress["active"] else 0.0
...@@ -1202,7 +1242,7 @@ async def get_image_progress(): ...@@ -1202,7 +1242,7 @@ async def get_image_progress():
} }
@router.post("/v1/images/generations") @router.post("/v1/images/generations", summary="Generate images (text-to-image)")
async def create_image_generation(request: ImageGenerationRequest, http_request: Request = None): async def create_image_generation(request: ImageGenerationRequest, http_request: Request = None):
""" """
Image generation endpoint (OpenAI-compatible). Image generation endpoint (OpenAI-compatible).
...@@ -1497,7 +1537,7 @@ def _load_img2img_pipeline(model_name: str, global_args): ...@@ -1497,7 +1537,7 @@ def _load_img2img_pipeline(model_name: str, global_args):
raise raise
@router.post("/v1/images/edits") @router.post("/v1/images/edits", summary="Edit an image (instruction / img2img)")
async def create_image_edit(request: ImageEditRequest, http_request: Request = None): async def create_image_edit(request: ImageEditRequest, http_request: Request = None):
""" """
Image-to-image editing endpoint (OpenAI-compatible). Image-to-image editing endpoint (OpenAI-compatible).
...@@ -1638,7 +1678,7 @@ def _load_inpaint_pipeline(model_name: str, global_args): ...@@ -1638,7 +1678,7 @@ def _load_inpaint_pipeline(model_name: str, global_args):
raise raise
@router.post("/v1/images/inpaint") @router.post("/v1/images/inpaint", summary="Inpaint a masked region")
async def create_image_inpaint(request: ImageInpaintRequest, http_request: Request = None): async def create_image_inpaint(request: ImageInpaintRequest, http_request: Request = None):
"""Inpaint a masked region of an image (OpenAI-compatible extension).""" """Inpaint a masked region of an image (OpenAI-compatible extension)."""
global global_args global global_args
...@@ -1750,7 +1790,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int): ...@@ -1750,7 +1790,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int):
return img.resize((w * scale, h * scale), PILImage.LANCZOS) return img.resize((w * scale, h * scale), PILImage.LANCZOS)
@router.post("/v1/images/upscale") @router.post("/v1/images/upscale", summary="Upscale an image")
async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None): async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None):
"""Upscale an image using Real-ESRGAN or PIL LANCZOS fallback.""" """Upscale an image using Real-ESRGAN or PIL LANCZOS fallback."""
global global_args global global_args
...@@ -1862,7 +1902,7 @@ def _resolve_spatial_model(requested: Optional[str], capability: str) -> Optiona ...@@ -1862,7 +1902,7 @@ def _resolve_spatial_model(requested: Optional[str], capability: str) -> Optiona
return None return None
@router.post("/v1/images/depth") @router.post("/v1/images/depth", summary="Estimate a depth map")
async def create_image_depth(request: ImageDepthRequest, http_request: Request = None): async def create_image_depth(request: ImageDepthRequest, http_request: Request = None):
"""Estimate depth map from an image.""" """Estimate depth map from an image."""
global global_args global global_args
...@@ -1968,7 +2008,7 @@ def _run_segmentation(seg_model, image_bytes: bytes, points, boxes): ...@@ -1968,7 +2008,7 @@ def _run_segmentation(seg_model, image_bytes: bytes, points, boxes):
return PILImage.fromarray(out) return PILImage.fromarray(out)
@router.post("/v1/images/segment") @router.post("/v1/images/segment", summary="Segment an image")
async def create_image_segment(request: ImageSegmentRequest, http_request: Request = None): async def create_image_segment(request: ImageSegmentRequest, http_request: Request = None):
"""Segment objects in an image using SAM or similar models.""" """Segment objects in an image using SAM or similar models."""
global global_args global global_args
...@@ -2035,7 +2075,7 @@ def _run_deblur(image_bytes: bytes, strength: float) -> "PILImage.Image": ...@@ -2035,7 +2075,7 @@ def _run_deblur(image_bytes: bytes, strength: float) -> "PILImage.Image":
return PILImage.fromarray((sharpened * 255).astype(np.uint8)) return PILImage.fromarray((sharpened * 255).astype(np.uint8))
@router.post("/v1/images/deblur") @router.post("/v1/images/deblur", summary="Deblur an image")
async def create_image_deblur(request: ImageDeblurRequest, http_request: Request = None): async def create_image_deblur(request: ImageDeblurRequest, http_request: Request = None):
"""Remove blur from an image using Wiener deconvolution and unsharp masking.""" """Remove blur from an image using Wiener deconvolution and unsharp masking."""
raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image) raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image)
...@@ -2093,7 +2133,7 @@ def _run_unpixelate(image_bytes: bytes, scale: int, model_path: Optional[str]) - ...@@ -2093,7 +2133,7 @@ def _run_unpixelate(image_bytes: bytes, scale: int, model_path: Optional[str]) -
return PILImage.fromarray(out_arr) return PILImage.fromarray(out_arr)
@router.post("/v1/images/unpixelate") @router.post("/v1/images/unpixelate", summary="Restore a pixelated image")
async def create_image_unpixelate(request: ImageUnpixelateRequest, http_request: Request = None): async def create_image_unpixelate(request: ImageUnpixelateRequest, http_request: Request = None):
"""Remove pixelation / upscale with detail recovery using Real-ESRGAN.""" """Remove pixelation / upscale with detail recovery using Real-ESRGAN."""
raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image) raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image)
...@@ -2155,7 +2195,7 @@ def _generate_clothing_mask(img_arr) -> "np.ndarray": ...@@ -2155,7 +2195,7 @@ def _generate_clothing_mask(img_arr) -> "np.ndarray":
return fg_mask return fg_mask
@router.post("/v1/images/outfit") @router.post("/v1/images/outfit", summary="Change outfit / clothing")
async def create_image_outfit(request: ImageOutfitRequest, http_request: Request = None): async def create_image_outfit(request: ImageOutfitRequest, http_request: Request = None):
"""Change the outfit/clothing in an image or video using inpainting.""" """Change the outfit/clothing in an image or video using inpainting."""
global global_args global global_args
......
...@@ -45,6 +45,7 @@ from pydantic import BaseModel, ConfigDict ...@@ -45,6 +45,7 @@ from pydantic import BaseModel, ConfigDict
from codai.platform_paths import default_loras_dir from codai.platform_paths import default_loras_dir
from codai.queue.manager import queue_manager from codai.queue.manager import queue_manager
from codai.tasks import task_registry, TaskCancelled
router = APIRouter() router = APIRouter()
...@@ -75,7 +76,19 @@ _jobs_lock = threading.Lock() ...@@ -75,7 +76,19 @@ _jobs_lock = threading.Lock()
_jobs: dict = {} # job_id -> record _jobs: dict = {} # job_id -> record
_active_job_id: Optional[str] = None _active_job_id: Optional[str] = None
_bg_tasks: set = set() # strong refs to detached train tasks _bg_tasks: set = set() # strong refs to detached train tasks
# Job ids with a pending cancel. Survives the window before a queued job's task
# is registered (the worker checks this set right after acquiring the GPU lock).
_cancel_requested: set = set()
# Job ids that must resume from their checkpoint even when global recovery is off
# (an explicit manual Restart from the Tasks page).
_force_resume_jobs: set = set()
_JOB_ACTIVE_STATES = ("queued", "preparing", "training", "saving") _JOB_ACTIVE_STATES = ("queued", "preparing", "training", "saving")
# Cap on retained *terminal* (done/cancelled/error/interrupted) job records so the
# persisted registry can't grow without bound. Active jobs are always kept; the
# oldest terminal records beyond this many are pruned (with their sidecar request
# files). A removed job is deleted outright (remove_job); this just bounds the
# slow accumulation of finished ones nobody removed by hand.
_JOB_HISTORY_MAX = 200
_MIRROR_FIELDS = ("active", "name", "step", "total", "status", "message", _MIRROR_FIELDS = ("active", "name", "step", "total", "status", "message",
"started_at", "path") "started_at", "path")
_last_job_persist = 0.0 _last_job_persist = 0.0
...@@ -85,6 +98,25 @@ def _jobs_file() -> str: ...@@ -85,6 +98,25 @@ def _jobs_file() -> str:
return os.path.join(_loras_dir(), "_train_jobs.json") return os.path.join(_loras_dir(), "_train_jobs.json")
def _prune_jobs_locked() -> None:
    """Drop the oldest terminal job records beyond _JOB_HISTORY_MAX. Caller holds
    _jobs_lock. Also removes each pruned job's persisted request sidecar."""
    terminal = [(jid, r) for jid, r in _jobs.items()
                if r.get("status") not in _JOB_ACTIVE_STATES]
    excess = len(terminal) - _JOB_HISTORY_MAX
    if excess <= 0:
        return
    terminal.sort(key=lambda kv: kv[1].get("updated_at") or kv[1].get("started_at") or 0)
    for jid, _ in terminal[:excess]:
        _jobs.pop(jid, None)
        try:
            p = _job_req_path(jid)
            if os.path.isfile(p):
                os.remove(p)
        except Exception:
            pass
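Review note: the pruning rule above can be exercised in isolation. The following is a minimal sketch, assuming plain dict records; `prune_terminal` and `ACTIVE_STATES` are hypothetical stand-ins for `_prune_jobs_locked` and `_JOB_ACTIVE_STATES`, and the sidecar-file cleanup is left out.

```python
from typing import Dict, List

# Hypothetical stand-in for _JOB_ACTIVE_STATES in the diff.
ACTIVE_STATES = ("queued", "preparing", "training", "saving")


def prune_terminal(jobs: Dict[str, dict], history_max: int) -> List[str]:
    """Remove the oldest terminal records beyond history_max; return removed ids."""
    terminal = [(jid, rec) for jid, rec in jobs.items()
                if rec.get("status") not in ACTIVE_STATES]
    excess = len(terminal) - history_max
    if excess <= 0:
        return []
    # Oldest first: prefer updated_at, fall back to started_at, then 0.
    terminal.sort(key=lambda kv: kv[1].get("updated_at") or kv[1].get("started_at") or 0)
    removed = [jid for jid, _ in terminal[:excess]]
    for jid in removed:
        jobs.pop(jid, None)
    return removed


jobs = {
    "a": {"status": "done", "updated_at": 10},
    "b": {"status": "training", "started_at": 1},   # active: never pruned
    "c": {"status": "error", "started_at": 5},
    "d": {"status": "cancelled", "updated_at": 20},
}
removed = prune_terminal(jobs, history_max=2)
```

Active records are never candidates, so the cap only bounds finished history.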
def _persist_jobs_locked(force: bool = False) -> None: def _persist_jobs_locked(force: bool = False) -> None:
"""Write the registry to disk. Throttled to ~5s unless forced (status change) """Write the registry to disk. Throttled to ~5s unless forced (status change)
so per-step progress updates don't hammer the disk.""" so per-step progress updates don't hammer the disk."""
...@@ -93,6 +125,7 @@ def _persist_jobs_locked(force: bool = False) -> None: ...@@ -93,6 +125,7 @@ def _persist_jobs_locked(force: bool = False) -> None:
if not force and (now - _last_job_persist) < 5.0: if not force and (now - _last_job_persist) < 5.0:
return return
_last_job_persist = now _last_job_persist = now
_prune_jobs_locked()
try: try:
p = _jobs_file() p = _jobs_file()
tmp = p + ".tmp" tmp = p + ".tmp"
...@@ -127,11 +160,29 @@ def _update_job(job_id: Optional[str], force: bool = False, **fields) -> None: ...@@ -127,11 +160,29 @@ def _update_job(job_id: Optional[str], force: bool = False, **fields) -> None:
_persist_jobs_locked(force=force) _persist_jobs_locked(force=force)
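Review note: the throttle-unless-forced write pattern used by `_persist_jobs_locked` can be sketched on its own. This is an illustration under assumptions, not the code in this commit: `ThrottledJsonStore` is a hypothetical name, the clock is injectable for testing, and the final `os.replace` step is assumed (the diff only shows the `.tmp` path being built).

```python
import json
import os
import tempfile
import time


class ThrottledJsonStore:
    """Write JSON to disk at most once per min_interval seconds unless forced."""

    def __init__(self, path, min_interval=5.0, clock=time.monotonic):
        self.path = path
        self.min_interval = min_interval
        self.clock = clock
        self._last = None   # time of the last successful write
        self.writes = 0

    def persist(self, data, force=False):
        now = self.clock()
        if not force and self._last is not None and (now - self._last) < self.min_interval:
            return False    # throttled: skip this write
        self._last = now
        tmp = self.path + ".tmp"
        with open(tmp, "w") as f:
            json.dump(data, f)
        os.replace(tmp, self.path)   # atomic swap, so readers never see a partial file
        self.writes += 1
        return True


fake_now = [0.0]
store = ThrottledJsonStore(os.path.join(tempfile.mkdtemp(), "jobs.json"),
                           clock=lambda: fake_now[0])
first = store.persist({"step": 1})               # first write always goes through
second = store.persist({"step": 2})              # inside the 5 s window: skipped
third = store.persist({"step": 3}, force=True)   # status change: forced
fake_now[0] = 6.0
fourth = store.persist({"step": 4})              # window elapsed: written
```

Forcing on status changes keeps the on-disk state accurate at the moments a restart would care about, while per-step ticks stay cheap.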
# When False (set via --no-resume-jobs or config.jobs.resume_on_restart=false),
# interrupted training is NOT recovered on restart: mid-flight jobs are marked
# 'cancelled' (not 'interrupted') and the per-job checkpoint auto-resume is
# disabled. Checkpoints are kept on disk, so a job can still be restarted
# manually (Tasks page) — that path passes resume=True explicitly.
_RESUME_ENABLED: bool = True
def set_resume_enabled(enabled: bool) -> None:
    """Enable/disable interrupted-training recovery on restart. Call before
    set_global_args() (which runs _load_jobs_on_start)."""
    global _RESUME_ENABLED
    _RESUME_ENABLED = bool(enabled)
def _load_jobs_on_start() -> None: def _load_jobs_on_start() -> None:
"""Load the persisted registry and mark any job that was mid-flight when the """Load the persisted registry and reconcile jobs that were mid-flight when
process died as 'interrupted' (its in-memory GPU training is gone). Clients the process died (their in-memory GPU training is gone).
polling such a job learn it must be resubmitted (server resumes from the last
on-disk checkpoint).""" With recovery enabled they become 'interrupted' — a polling client resubmits
and the server resumes from the last on-disk checkpoint. With recovery
disabled (--no-resume-jobs / config) they become 'cancelled' and are not
auto-resumed; checkpoints are kept so they can be restarted manually."""
global _jobs global _jobs
try: try:
with open(_jobs_file()) as f: with open(_jobs_file()) as f:
...@@ -143,9 +194,13 @@ def _load_jobs_on_start() -> None: ...@@ -143,9 +194,13 @@ def _load_jobs_on_start() -> None:
changed = False changed = False
for rec in data.values(): for rec in data.values():
if rec.get("status") in _JOB_ACTIVE_STATES: if rec.get("status") in _JOB_ACTIVE_STATES:
if _RESUME_ENABLED:
rec["status"] = "interrupted" rec["status"] = "interrupted"
rec["active"] = False
rec["message"] = "interrupted by server restart — resubmit to resume" rec["message"] = "interrupted by server restart — resubmit to resume"
else:
rec["status"] = "cancelled"
rec["message"] = "cancelled — recovery disabled on restart (checkpoint kept)"
rec["active"] = False
changed = True changed = True
with _jobs_lock: with _jobs_lock:
_jobs = data _jobs = data
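Review note: the startup reconciliation in this hunk reduces to a small pure function. A minimal sketch, assuming dict records; `reconcile` and `ACTIVE_STATES` are hypothetical names and the messages are paraphrased.

```python
# Hypothetical stand-in for _JOB_ACTIVE_STATES in the diff.
ACTIVE_STATES = ("queued", "preparing", "training", "saving")


def reconcile(records: dict, resume_enabled: bool) -> bool:
    """Mark mid-flight jobs after a restart; return True if anything changed."""
    changed = False
    for rec in records.values():
        if rec.get("status") in ACTIVE_STATES:
            if resume_enabled:
                rec["status"] = "interrupted"
                rec["message"] = "interrupted by server restart; resubmit to resume"
            else:
                rec["status"] = "cancelled"
                rec["message"] = "cancelled; recovery disabled on restart (checkpoint kept)"
            rec["active"] = False   # cleared in both branches
            changed = True
    return changed


with_recovery = {"j1": {"status": "training", "active": True},
                 "j2": {"status": "done", "active": False}}
without_recovery = {"j1": {"status": "queued", "active": True}}
changed_a = reconcile(with_recovery, resume_enabled=True)
changed_b = reconcile(without_recovery, resume_enabled=False)
```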
...@@ -494,6 +549,129 @@ def _set_progress(**kw): ...@@ -494,6 +549,129 @@ def _set_progress(**kw):
mirror = {k: kw[k] for k in kw if k in _MIRROR_FIELDS} mirror = {k: kw[k] for k in kw if k in _MIRROR_FIELDS}
if mirror: if mirror:
_update_job(job_id, **mirror) _update_job(job_id, **mirror)
# Mirror step/total into the live task registry (task id == job id).
if "step" in kw or "total" in kw:
task_registry.step(job_id, kw.get("step", 0), kw.get("total"))
if "message" in kw:
task_registry.update(job_id, message=str(kw["message"])[:200])
def _check_train_cancel() -> None:
    """Raise TaskCancelled if the running training job was cancelled, or block
    while it is paused. Called each training step; the task id is the active job
    id. Cancellation is re-checked after the pause wait, so a job cancelled while
    paused stops without running another step."""
    jid = _active_job_id
    if not jid:
        return
    if jid in _cancel_requested or task_registry.is_cancelled(jid):
        raise TaskCancelled(jid)
    # Cooperative pause: block here (in the training thread) while paused.
    was_paused = task_registry.is_paused(jid)
    if was_paused:
        _update_job(jid, force=True, message="paused")
    task_registry.wait_if_paused(jid)
    # The wait may have ended because of a cancel rather than a resume.
    if jid in _cancel_requested or task_registry.is_cancelled(jid):
        raise TaskCancelled(jid)
    if was_paused:
        _update_job(jid, force=True, message="training")
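Review note: the `codai.tasks` registry is not part of this hunk, so the following is only a sketch of how a cooperative cancel/pause checkpoint can be built from two `threading.Event` objects. `TaskControl`, `Cancelled`, and `checkpoint` are hypothetical names, not the project's API.

```python
import threading


class Cancelled(Exception):
    """Raised inside the worker when the task has been cancelled."""


class TaskControl:
    def __init__(self):
        self._cancel = threading.Event()
        self._resume = threading.Event()
        self._resume.set()           # set means "not paused"

    def cancel(self):
        self._cancel.set()
        self._resume.set()           # wake a paused worker so it can see the cancel

    def pause(self):
        self._resume.clear()

    def resume(self):
        self._resume.set()

    def checkpoint(self):
        """Call once per step: raise if cancelled, otherwise block while paused."""
        if self._cancel.is_set():
            raise Cancelled()
        self._resume.wait()
        if self._cancel.is_set():    # re-check: the wait may have ended due to cancel
            raise Cancelled()


def run(ctrl, steps, log):
    try:
        for step in range(steps):
            ctrl.checkpoint()
            log.append(step)
    except Cancelled:
        log.append("cancelled")


ctrl = TaskControl()
ctrl.pause()                         # start paused
log = []
worker = threading.Thread(target=run, args=(ctrl, 3, log))
worker.start()
ctrl.cancel()                        # cancel while paused
worker.join(timeout=5)
```

Because `cancel()` also releases the pause, a paused worker never stays blocked after a cancel, and the second check keeps it from running one extra step.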
def _job_req_path(job_id: str) -> str:
    """Per-job sidecar file holding the full training request, so a job can be
    restarted (resumed from its checkpoint) after the original client is gone.
    Kept separate from _train_jobs.json so inline base64 images don't bloat the
    registry that's rewritten on every progress tick."""
    return os.path.join(_loras_dir(), "_jobs", f"{job_id}.json")


def _save_job_request(job_id: str, req) -> None:
    try:
        os.makedirs(os.path.dirname(_job_req_path(job_id)), exist_ok=True)
        data = req.model_dump() if hasattr(req, "model_dump") else req.dict()
        with open(_job_req_path(job_id), "w") as f:
            json.dump(data, f)
    except Exception as e:
        print(f" [lora] could not persist request for {job_id}: {e}")


def _load_job_request(job_id: str):
    try:
        with open(_job_req_path(job_id)) as f:
            return LoraTrainRequest(**json.load(f))
    except Exception:
        return None
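Review note: the sidecar round trip can be sketched with the standard library alone. This assumes the request is already a JSON-serializable dict (the diff serializes a pydantic model); `req_path`, `save_request`, and `load_request` are hypothetical names.

```python
import json
import os
import tempfile


def req_path(root: str, job_id: str) -> str:
    # One file per job under <root>/_jobs/, separate from the main registry.
    return os.path.join(root, "_jobs", f"{job_id}.json")


def save_request(root: str, job_id: str, payload: dict) -> None:
    path = req_path(root, job_id)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(payload, f)


def load_request(root: str, job_id: str):
    """Return the saved payload, or None if it is missing or unreadable."""
    try:
        with open(req_path(root, job_id)) as f:
            return json.load(f)
    except (OSError, ValueError):    # missing file or invalid JSON
        return None


root = tempfile.mkdtemp()
save_request(root, "job-1", {"name": "demo", "steps": 100})
loaded = load_request(root, "job-1")
missing = load_request(root, "job-2")
```

Returning `None` rather than raising lets a restart path report "payload could not be recovered" without a traceback.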
def list_jobs() -> list:
    """Snapshot of all training jobs (newest first) for the Tasks view."""
    with _jobs_lock:
        recs = [dict(r) for r in _jobs.values()]
    recs.sort(key=lambda r: r.get("started_at") or r.get("updated_at") or 0, reverse=True)
    return recs


def remove_job(job_id: str) -> bool:
    """Dismiss a finished/cancelled/errored training job from the view. Refuses a
    job that is still active (cancel it first → returns False). Drops the job
    record + its saved request sidecar; the trained LoRA weights and any training
    checkpoint on disk are kept."""
    with _jobs_lock:
        rec = _jobs.get(job_id)
        if rec is None:
            return False
        if rec.get("status") in _JOB_ACTIVE_STATES:
            return False
        _jobs.pop(job_id, None)
        _persist_jobs_locked(force=True)
    try:
        p = _job_req_path(job_id)
        if os.path.isfile(p):
            os.remove(p)
    except Exception:
        pass
    _cancel_requested.discard(job_id)
    _force_resume_jobs.discard(job_id)
    task_registry.remove(job_id)
    return True
def cancel_job(job_id: str) -> bool:
    """Cancel a queued or running training job. Running jobs stop at the next
    step; queued jobs are marked cancelled immediately (and aborted before any
    GPU work if they later acquire the lock). Checkpoints are kept."""
    with _jobs_lock:
        rec = _jobs.get(job_id)
        status = rec.get("status") if rec else None
    if rec is None:
        return False
    _cancel_requested.add(job_id)
    task_registry.cancel(job_id)  # no-op if the task isn't registered yet
    if status in ("queued", "interrupted", "preparing"):
        _update_job(job_id, status="cancelled", active=False, force=True,
                    message="cancelled")
    return True


def restart_job(job_id: str) -> Optional[str]:
    """Restart a finished/cancelled/interrupted training job, resuming from its
    last on-disk checkpoint. Returns the (same) job_id on success, None if the
    request payload could not be recovered."""
    import asyncio
    req = _load_job_request(job_id)
    if req is None:
        return None
    _force_resume_jobs.add(job_id)
    _cancel_requested.discard(job_id)
    _update_job(job_id, status="queued", active=True, force=True,
                message="restarting (resume from checkpoint)", step=0)
    task = asyncio.create_task(_detached_train(req, job_id))
    _bg_tasks.add(task)
    task.add_done_callback(_bg_tasks.discard)
    return job_id


def _resume_allowed(req) -> bool:
    """Whether to auto-resume this training from its on-disk checkpoint. True when
    recovery is enabled globally, or when this is an explicit manual restart
    (restart_job adds the job id to _force_resume_jobs)."""
    return _RESUME_ENABLED or (_active_job_id in _force_resume_jobs)
def _lora_debug_enabled() -> bool: def _lora_debug_enabled() -> bool:
...@@ -961,7 +1139,7 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -961,7 +1139,7 @@ def _train_sd15(req, base_path, images, instance_prompt,
# Resume from a mid-training checkpoint if one survives a prior restart. # Resume from a mid-training checkpoint if one survives a prior restart.
start_step = 0 start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="image", rank=rank, _ck = _load_train_state(name, base_path=base_path, target="image", rank=rank,
session=getattr(req, "session", None)) session=getattr(req, "session", None)) if _resume_allowed(req) else None
if _ck: if _ck:
try: try:
_apply_peft_checkpoint(name, "default", unet) _apply_peft_checkpoint(name, "default", unet)
...@@ -995,6 +1173,7 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -995,6 +1173,7 @@ def _train_sd15(req, base_path, images, instance_prompt,
unet.train() unet.train()
n = len(latents_list) n = len(latents_list)
for step in range(start_step, steps): for step in range(start_step, steps):
_check_train_cancel()
latents = latents_list[step % n].to(device) latents = latents_list[step % n].to(device)
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
bsz = latents.shape[0] bsz = latents.shape[0]
...@@ -1104,7 +1283,7 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -1104,7 +1283,7 @@ def _train_sdxl(req, base_path, images, instance_prompt,
# Resume from a mid-training checkpoint if one survives a prior restart. # Resume from a mid-training checkpoint if one survives a prior restart.
start_step = 0 start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="image", rank=rank, _ck = _load_train_state(name, base_path=base_path, target="image", rank=rank,
session=getattr(req, "session", None)) session=getattr(req, "session", None)) if _resume_allowed(req) else None
if _ck: if _ck:
try: try:
_apply_peft_checkpoint(name, "default", unet) _apply_peft_checkpoint(name, "default", unet)
...@@ -1165,6 +1344,7 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -1165,6 +1344,7 @@ def _train_sdxl(req, base_path, images, instance_prompt,
unet.train() unet.train()
n = len(latents_list) n = len(latents_list)
for step in range(start_step, steps): for step in range(start_step, steps):
_check_train_cancel()
latents = latents_list[step % n].to(device) latents = latents_list[step % n].to(device)
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
bsz = latents.shape[0] bsz = latents.shape[0]
...@@ -1383,7 +1563,7 @@ def _train_wan(req, base_path, images, instance_prompt, ...@@ -1383,7 +1563,7 @@ def _train_wan(req, base_path, images, instance_prompt,
# runs for hours — a reboot otherwise throws all of it away. # runs for hours — a reboot otherwise throws all of it away.
start_step = 0 start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="video", rank=rank, _ck = _load_train_state(name, base_path=base_path, target="video", rank=rank,
session=getattr(req, "session", None)) session=getattr(req, "session", None)) if _resume_allowed(req) else None
if _ck: if _ck:
try: try:
_apply_peft_checkpoint(name, "t", experts[0][1]) _apply_peft_checkpoint(name, "t", experts[0][1])
...@@ -1410,9 +1590,34 @@ def _train_wan(req, base_path, images, instance_prompt, ...@@ -1410,9 +1590,34 @@ def _train_wan(req, base_path, images, instance_prompt,
return experts[0][1] if t_val >= float(boundary) else experts[1][1] return experts[0][1] if t_val >= float(boundary) else experts[1][1]
return experts[0][1] return experts[0][1]
    def _tr_in_channels(tr):
        """Channels the transformer's patch_embedding expects."""
        try:
            c = getattr(getattr(tr, "config", None), "in_channels", None)
            if c:
                return int(c)
        except Exception:
            pass
        try:
            return int(tr.patch_embedding.weight.shape[1])
        except Exception:
            return None

    # I2V transformers (in_channels=36 = 16 noise + 16 image-cond + 4 mask) expect a
    # wider input than the bare VAE latent (z_dim, typically 16). We have no init-image
    # conditioning during identity-LoRA training, so feed the well-defined "no condition"
    # case: zero-pad the noise latent up to in_channels (zero image latents + zero mask).
    # The transformer still outputs z_dim channels, so the flow-matching target/loss is
    # unchanged, and the attention-projection LoRA it trains applies during real i2v too.
    _want_in_ch = _tr_in_channels(experts[0][1])
    if _want_in_ch and _want_in_ch > z_dim:
        _dbg_lora(f"Wan I2V transformer wants {_want_in_ch} in-channels; zero-padding "
                  f"the {z_dim}-channel latent (no-init-image training regime)")
_set_progress(status="training", message="training (Wan video LoRA)") _set_progress(status="training", message="training (Wan video LoRA)")
n = len(latents_list) n = len(latents_list)
for step in range(start_step, steps): for step in range(start_step, steps):
_check_train_cancel()
x0 = latents_list[step % n].to(device, dtype=compute_dtype) x0 = latents_list[step % n].to(device, dtype=compute_dtype)
noise = torch.randn_like(x0) noise = torch.randn_like(x0)
# Rectified-flow timestep with Wan resolution shift applied to sigma. # Rectified-flow timestep with Wan resolution shift applied to sigma.
...@@ -1427,7 +1632,14 @@ def _train_wan(req, base_path, images, instance_prompt, ...@@ -1427,7 +1632,14 @@ def _train_wan(req, base_path, images, instance_prompt,
target = (noise.float() - x0f).to(compute_dtype) # flow-matching velocity target = (noise.float() - x0f).to(compute_dtype) # flow-matching velocity
timestep = (sigma * num_train_t).to(torch.float32) # cast internally by Wan timestep = (sigma * num_train_t).to(torch.float32) # cast internally by Wan
tr = _pick_expert(float(sigma.item())) tr = _pick_expert(float(sigma.item()))
pred = tr(hidden_states=x_t, timestep=timestep, # Pad to the transformer's expected in_channels for I2V models (see note above).
x_in = x_t
want_ch = _tr_in_channels(tr) or x_t.shape[1]
if want_ch > x_t.shape[1]:
pad = torch.zeros((x_t.shape[0], want_ch - x_t.shape[1], *x_t.shape[2:]),
device=x_t.device, dtype=x_t.dtype)
x_in = torch.cat([x_t, pad], dim=1)
pred = tr(hidden_states=x_in, timestep=timestep,
encoder_hidden_states=encoder_hidden_states.to(device), encoder_hidden_states=encoder_hidden_states.to(device),
return_dict=False)[0] return_dict=False)[0]
loss = F.mse_loss(pred.float(), target.float(), reduction="mean") loss = F.mse_loss(pred.float(), target.float(), reduction="mean")
...@@ -1552,12 +1764,45 @@ def _train_lora_blocking(req: LoraTrainRequest, job_id: Optional[str] = None) -> ...@@ -1552,12 +1764,45 @@ def _train_lora_blocking(req: LoraTrainRequest, job_id: Optional[str] = None) ->
global _active_job_id global _active_job_id
_train_lock.acquire() _train_lock.acquire()
_active_job_id = job_id _active_job_id = job_id
# Live cancellable task (id == job id). Registered here, when the job actually
# starts on the GPU, so its progress mirrors via _set_progress.
if job_id:
task_registry.register("training", title=getattr(req, "name", "") or "",
model=getattr(req, "base_model", "") or "",
total=getattr(req, "steps", 0) or 0,
job_id=job_id, task_id=job_id, restartable=True,
status="running")
task_registry.start(job_id)
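    # A job cancelled while still queued aborts before any GPU work.
    try:
        _check_train_cancel()
    except TaskCancelled:
        _update_job(job_id, status="cancelled", active=False, force=True,
                    message="cancelled before start")
        # This path runs before the try/finally below, so clean up by hand:
        # close the live task and clear the pending cancel/force-resume flags.
        task_registry.finish(job_id, "cancelled")
        _cancel_requested.discard(job_id)
        _force_resume_jobs.discard(job_id)
        _active_job_id = None
        _train_lock.release()
        raise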
try: try:
result = _train_lora_sync(req) result = _train_lora_sync(req)
if job_id: if job_id:
_update_job(job_id, status="done", active=False, force=True, _update_job(job_id, status="done", active=False, force=True,
message="done", path=result.get("path")) message="done", path=result.get("path"))
task_registry.finish(job_id, "done")
return result return result
except TaskCancelled:
# Cooperative cancel from the training loop — not an error.
try:
_set_progress(active=False, status="cancelled", message="cancelled")
except Exception:
pass
if job_id:
_update_job(job_id, status="cancelled", active=False, force=True,
message="cancelled by user")
task_registry.finish(job_id, "cancelled")
# The base/wan caches may hold a half-applied adapter — drop them.
_drop_base_cache()
_drop_wan_cache()
raise
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
...@@ -1570,11 +1815,15 @@ def _train_lora_blocking(req: LoraTrainRequest, job_id: Optional[str] = None) -> ...@@ -1570,11 +1815,15 @@ def _train_lora_blocking(req: LoraTrainRequest, job_id: Optional[str] = None) ->
if job_id: if job_id:
_update_job(job_id, status="error", active=False, force=True, _update_job(job_id, status="error", active=False, force=True,
message=f"training failed: {e}"[:300]) message=f"training failed: {e}"[:300])
task_registry.finish(job_id, "error", f"{e}"[:200])
_drop_base_cache() _drop_base_cache()
_drop_wan_cache() _drop_wan_cache()
raise raise
finally: finally:
_active_job_id = None _active_job_id = None
if job_id:
_cancel_requested.discard(job_id)
_force_resume_jobs.discard(job_id)
_train_lock.release() _train_lock.release()
...@@ -1591,7 +1840,7 @@ async def _run_train_job(req: LoraTrainRequest, job_id: str) -> dict: ...@@ -1591,7 +1840,7 @@ async def _run_train_job(req: LoraTrainRequest, job_id: str) -> dict:
await queue_manager.release(lease) await queue_manager.release(lease)
@router.post("/v1/loras/train") @router.post("/v1/loras/train", summary="Train a new LoRA")
async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)): async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)):
"""Train a LoRA. Admitted through the central request scheduler so concurrent """Train a LoRA. Admitted through the central request scheduler so concurrent
trainings queue and run one-at-a-time alongside all other model requests. trainings queue and run one-at-a-time alongside all other model requests.
...@@ -1615,6 +1864,9 @@ async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)): ...@@ -1615,6 +1864,9 @@ async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)):
target=(req.target or "image"), status="queued", active=True, target=(req.target or "image"), status="queued", active=True,
step=0, total=req.steps or 0, message="queued", step=0, total=req.steps or 0, message="queued",
started_at=time.time(), path=None) started_at=time.time(), path=None)
# Persist the full request so the job can be restarted later (Tasks page)
# without the original client.
_save_job_request(job_id, req)
if not req.wait: if not req.wait:
# Detach: run independently of this request's lifetime. Keep a strong # Detach: run independently of this request's lifetime. Keep a strong
...@@ -1639,12 +1891,15 @@ async def _detached_train(req: LoraTrainRequest, job_id: str) -> None: ...@@ -1639,12 +1891,15 @@ async def _detached_train(req: LoraTrainRequest, job_id: str) -> None:
record (the client polls for them); nothing is raised to a waiting request.""" record (the client polls for them); nothing is raised to a waiting request."""
try: try:
await _run_train_job(req, job_id) await _run_train_job(req, job_id)
except TaskCancelled:
# Already recorded as 'cancelled' by _train_lora_blocking — don't clobber.
pass
except Exception as e: except Exception as e:
_update_job(job_id, status="error", active=False, force=True, _update_job(job_id, status="error", active=False, force=True,
message=f"training failed: {e}"[:300]) message=f"training failed: {e}"[:300])
@router.get("/v1/loras/progress") @router.get("/v1/loras/progress", summary="LoRA training progress")
async def lora_progress(job: Optional[str] = None, session: Optional[str] = None): async def lora_progress(job: Optional[str] = None, session: Optional[str] = None):
"""Training progress. """Training progress.
...@@ -1678,7 +1933,7 @@ async def lora_progress(job: Optional[str] = None, session: Optional[str] = None ...@@ -1678,7 +1933,7 @@ async def lora_progress(job: Optional[str] = None, session: Optional[str] = None
return dict(_progress) return dict(_progress)
@router.post("/v1/loras/upload") @router.post("/v1/loras/upload", summary="Upload LoRA weights (content-addressed)")
async def upload_lora(request: Request, _auth=Depends(_require_api_auth)): async def upload_lora(request: Request, _auth=Depends(_require_api_auth)):
"""Upload a LoRA file into the content-addressed store so a remote client can """Upload a LoRA file into the content-addressed store so a remote client can
use it without sharing a filesystem. use it without sharing a filesystem.
...@@ -1713,7 +1968,7 @@ async def upload_lora(request: Request, _auth=Depends(_require_api_auth)): ...@@ -1713,7 +1968,7 @@ async def upload_lora(request: Request, _auth=Depends(_require_api_auth)):
return {"id": f"sha256:{h}", "bytes": len(data), "existed": existed} return {"id": f"sha256:{h}", "bytes": len(data), "existed": existed}
@router.get("/v1/loras/blob/{hash}") @router.get("/v1/loras/blob/{hash}", summary="Check an uploaded LoRA blob exists")
async def lora_blob_info(hash: str, _auth=Depends(_require_api_auth)): async def lora_blob_info(hash: str, _auth=Depends(_require_api_auth)):
"""Existence check for an uploaded LoRA blob. 200 with metadata when present, """Existence check for an uploaded LoRA blob. 200 with metadata when present,
404 when absent — lets a client skip re-uploading a file the server already 404 when absent — lets a client skip re-uploading a file the server already
...@@ -1725,8 +1980,14 @@ async def lora_blob_info(hash: str, _auth=Depends(_require_api_auth)): ...@@ -1725,8 +1980,14 @@ async def lora_blob_info(hash: str, _auth=Depends(_require_api_auth)):
"bytes": os.path.getsize(p), "exists": True} "bytes": os.path.getsize(p), "exists": True}
@router.get("/v1/loras") @router.get("/v1/loras", summary="List registered LoRAs")
async def list_loras(_auth=Depends(_require_api_auth)): async def list_loras(_auth=Depends(_require_api_auth)):
"""List every trained/registered LoRA in the registry.
Returns one entry per LoRA with its `name`, on-disk weight `path` and any saved
metadata (base model, target, training params). These names can be referenced from
image/video requests via `loras: [{"id": "name:<name>"}]`.
"""
out = [] out = []
d = _loras_dir() d = _loras_dir()
if os.path.isdir(d): if os.path.isdir(d):
...@@ -1746,8 +2007,9 @@ async def list_loras(_auth=Depends(_require_api_auth)): ...@@ -1746,8 +2007,9 @@ async def list_loras(_auth=Depends(_require_api_auth)):
return {"loras": out} return {"loras": out}
@router.get("/v1/loras/{name}") @router.get("/v1/loras/{name}", summary="Get a registered LoRA")
async def get_lora(name: str, _auth=Depends(_require_api_auth)): async def get_lora(name: str, _auth=Depends(_require_api_auth)):
"""Fetch one registered LoRA by name, including its weight path and metadata."""
wf = _lora_weight_file(name) wf = _lora_weight_file(name)
if not wf: if not wf:
raise HTTPException(status_code=404, detail=f"LoRA '{name}' not found") raise HTTPException(status_code=404, detail=f"LoRA '{name}' not found")
...@@ -1762,8 +2024,9 @@ async def get_lora(name: str, _auth=Depends(_require_api_auth)): ...@@ -1762,8 +2024,9 @@ async def get_lora(name: str, _auth=Depends(_require_api_auth)):
return {"name": name, "path": wf, **meta} return {"name": name, "path": wf, **meta}
@router.delete("/v1/loras/{name}") @router.delete("/v1/loras/{name}", summary="Delete a registered LoRA")
async def delete_lora(name: str, _auth=Depends(_require_api_auth)): async def delete_lora(name: str, _auth=Depends(_require_api_auth)):
"""Delete a registered LoRA and all its files from the registry."""
d = _lora_dir(name) d = _lora_dir(name)
if not os.path.isdir(d): if not os.path.isdir(d):
raise HTTPException(status_code=404, detail=f"LoRA '{name}' not found") raise HTTPException(status_code=404, detail=f"LoRA '{name}' not found")
......
...@@ -117,7 +117,7 @@ class ImageToVideoPipelineRequest(BaseModel): ...@@ -117,7 +117,7 @@ class ImageToVideoPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/image-to-video") @router.post("/v1/pipelines/image-to-video", summary="Image-to-video pipeline")
async def pipeline_image_to_video(request: ImageToVideoPipelineRequest, http_request: Request = None): async def pipeline_image_to_video(request: ImageToVideoPipelineRequest, http_request: Request = None):
"""Generate an image then animate it into a video.""" """Generate an image then animate it into a video."""
steps = [] steps = []
...@@ -197,7 +197,7 @@ class VideoDubPipelineRequest(BaseModel): ...@@ -197,7 +197,7 @@ class VideoDubPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/video-dub") @router.post("/v1/pipelines/video-dub", summary="Video dubbing pipeline")
async def pipeline_video_dub(request: VideoDubPipelineRequest, http_request: Request = None): async def pipeline_video_dub(request: VideoDubPipelineRequest, http_request: Request = None):
"""Transcribe → translate → TTS dub → burn subtitles.""" """Transcribe → translate → TTS dub → burn subtitles."""
body = { body = {
...@@ -240,7 +240,7 @@ class StoryPipelineRequest(BaseModel): ...@@ -240,7 +240,7 @@ class StoryPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/story") @router.post("/v1/pipelines/story", summary="Story pipeline (multi-scene)")
async def pipeline_story(request: StoryPipelineRequest, http_request: Request = None): async def pipeline_story(request: StoryPipelineRequest, http_request: Request = None):
"""LLM generates script → image per scene → animate first scene → optional TTS narration.""" """LLM generates script → image per scene → animate first scene → optional TTS narration."""
n = min(request.num_scenes or 3, 6) n = min(request.num_scenes or 3, 6)
...@@ -377,7 +377,7 @@ class AudioDubPipelineRequest(BaseModel): ...@@ -377,7 +377,7 @@ class AudioDubPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/audio-dub") @router.post("/v1/pipelines/audio-dub", summary="Audio dubbing pipeline")
async def pipeline_audio_dub(request: AudioDubPipelineRequest, http_request: Request = None): async def pipeline_audio_dub(request: AudioDubPipelineRequest, http_request: Request = None):
"""Transcribe → (translate) → clone voice → replace audio track.""" """Transcribe → (translate) → clone voice → replace audio track."""
import os, tempfile, subprocess, base64 import os, tempfile, subprocess, base64
......
...@@ -499,7 +499,7 @@ class ImageTo3DRequest(BaseModel): ...@@ -499,7 +499,7 @@ class ImageTo3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/images/to3d") @router.post("/v1/images/to3d", summary="Image to 3D model")
async def image_to_3d(request: ImageTo3DRequest, http_request: Request = None): async def image_to_3d(request: ImageTo3DRequest, http_request: Request = None):
"""Convert a 2D image to a 3D representation. """Convert a 2D image to a 3D representation.
...@@ -567,7 +567,7 @@ class ImageFrom3DRequest(BaseModel): ...@@ -567,7 +567,7 @@ class ImageFrom3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/images/from3d") @router.post("/v1/images/from3d", summary="Render a 3D model to an image")
async def image_from_3d(request: ImageFrom3DRequest, http_request: Request = None): async def image_from_3d(request: ImageFrom3DRequest, http_request: Request = None):
"""Render a 3D model (GLB/OBJ) to a 2D PNG image from a specified camera angle.""" """Render a 3D model (GLB/OBJ) to a 2D PNG image from a specified camera angle."""
raw = _decode_b64(request.model_data) raw = _decode_b64(request.model_data)
...@@ -600,7 +600,7 @@ class VideoTo3DRequest(BaseModel): ...@@ -600,7 +600,7 @@ class VideoTo3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/video/to3d") @router.post("/v1/video/to3d", summary="Video to 3D model")
async def video_to_3d(request: VideoTo3DRequest, http_request: Request = None): async def video_to_3d(request: VideoTo3DRequest, http_request: Request = None):
"""Convert a 2D video to a 3D video frame-by-frame. """Convert a 2D video to a 3D video frame-by-frame.
...@@ -641,7 +641,7 @@ class VideoFrom3DRequest(BaseModel): ...@@ -641,7 +641,7 @@ class VideoFrom3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/video/from3d") @router.post("/v1/video/from3d", summary="Render a 3D model to a video")
async def video_from_3d(request: VideoFrom3DRequest, http_request: Request = None): async def video_from_3d(request: VideoFrom3DRequest, http_request: Request = None):
"""Render a 3D model as a 360° turntable video.""" """Render a 3D model as a 360° turntable video."""
raw = _decode_b64(request.model_data) raw = _decode_b64(request.model_data)
...@@ -674,7 +674,7 @@ class Generate3DRequest(BaseModel): ...@@ -674,7 +674,7 @@ class Generate3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/3d/generate") @router.post("/v1/3d/generate", summary="Generate a 3D model from a prompt")
async def generate_3d(request: Generate3DRequest, http_request: Request = None): async def generate_3d(request: Generate3DRequest, http_request: Request = None):
"""Generate a 3D model (GLB) from a text prompt and/or an image. """Generate a 3D model (GLB) from a text prompt and/or an image.
......
...@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__) ...@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
# Import from codai modules # Import from codai modules
from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager
from codai.queue.manager import QueueManager, queue_manager from codai.queue.manager import QueueManager, queue_manager
from codai.tasks import task_registry
from codai.api.prompt_cache import prompt_cache_manager from codai.api.prompt_cache import prompt_cache_manager
from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool
from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser
...@@ -92,7 +93,7 @@ def set_grammar_guided_gen(enabled: bool): ...@@ -92,7 +93,7 @@ def set_grammar_guided_gen(enabled: bool):
router = APIRouter() router = APIRouter()
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions", summary="Chat completions")
async def chat_completions(request: ChatCompletionRequest, http_request: Request = None): async def chat_completions(request: ChatCompletionRequest, http_request: Request = None):
"""Chat completions endpoint with streaming and tool support.""" """Chat completions endpoint with streaming and tool support."""
...@@ -1248,6 +1249,7 @@ async def stream_chat_response( ...@@ -1248,6 +1249,7 @@ async def stream_chat_response(
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time()) created = int(time.time())
request_id = f"req-{uuid.uuid4().hex[:8]}" request_id = f"req-{uuid.uuid4().hex[:8]}"
_tid = None
generated_text = "" generated_text = ""
...@@ -1320,6 +1322,9 @@ async def stream_chat_response( ...@@ -1320,6 +1322,9 @@ async def stream_chat_response(
# Mark as starting processing # Mark as starting processing
await queue_manager.start_processing(request_id, model_name) await queue_manager.start_processing(request_id, model_name)
_tid = task_registry.register("text", title=(model_name or "chat"),
model=model_name or "", task_id=request_id)
task_registry.start(_tid)
# Send "Model starting" message # Send "Model starting" message
data = { data = {
...@@ -1374,6 +1379,9 @@ async def stream_chat_response( ...@@ -1374,6 +1379,9 @@ async def stream_chat_response(
response_format=response_format, response_format=response_format,
enable_thinking=enable_thinking, enable_thinking=enable_thinking,
): ):
# Cooperative cancellation: stop streaming if the task was cancelled.
if task_registry.is_cancelled(_tid):
break
chunk_count += 1 chunk_count += 1
# Always filter malformed content (regex-based, works per-chunk) # Always filter malformed content (regex-based, works per-chunk)
filtered_chunk = filter_malformed_content(chunk) filtered_chunk = filter_malformed_content(chunk)
...@@ -1580,6 +1588,9 @@ async def stream_chat_response( ...@@ -1580,6 +1588,9 @@ async def stream_chat_response(
finally: finally:
# Always clean up queue state # Always clean up queue state
await queue_manager.finish_processing() await queue_manager.finish_processing()
if _tid:
task_registry.finish(
_tid, "cancelled" if task_registry.is_cancelled(_tid) else "done")
async def generate_chat_response( async def generate_chat_response(
messages: List[Dict], messages: List[Dict],
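Reviewer note: the streaming hunk above relies on cooperative cancellation, where the admin side only sets a flag and the generator loop checks it once per chunk. The following is a minimal sketch of that pattern. `MiniRegistry`, `fake_stream`, and `stream_with_cancel` are hypothetical stand-ins; the real `codai.tasks.task_registry` is not shown in this diff and may differ.

```python
import asyncio
import threading
import uuid


class MiniRegistry:
    """Hypothetical stand-in for codai.tasks.task_registry (assumed API)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._tasks = {}

    def register(self, kind, title="", task_id=None):
        tid = task_id or f"task-{uuid.uuid4().hex[:8]}"
        with self._lock:
            self._tasks[tid] = {"kind": kind, "title": title,
                                "status": "queued", "cancel": False}
        return tid

    def start(self, tid):
        with self._lock:
            self._tasks[tid]["status"] = "running"

    def cancel(self, tid):
        # Only sets a flag; the worker decides when to stop.
        with self._lock:
            self._tasks[tid]["cancel"] = True

    def is_cancelled(self, tid):
        with self._lock:
            return self._tasks[tid]["cancel"]

    def finish(self, tid, status):
        with self._lock:
            self._tasks[tid]["status"] = status

    def status(self, tid):
        with self._lock:
            return self._tasks[tid]["status"]


async def fake_stream(n):
    """Stand-in for the model's token stream."""
    for i in range(n):
        await asyncio.sleep(0)
        yield f"tok{i}"


async def stream_with_cancel(registry, cancel_after):
    tid = registry.register("text", title="demo")
    registry.start(tid)
    out = []
    try:
        async for chunk in fake_stream(10):
            # Checked once per chunk, mirroring the loop in the diff.
            if registry.is_cancelled(tid):
                break
            out.append(chunk)
            if len(out) == cancel_after:
                registry.cancel(tid)  # simulates an admin clicking "cancel"
    finally:
        # Final status is decided in one place, whatever ended the loop.
        registry.finish(
            tid, "cancelled" if registry.is_cancelled(tid) else "done")
    return tid, out


registry = MiniRegistry()
tid, chunks = asyncio.run(stream_with_cancel(registry, cancel_after=3))
```

Because the flag is only read at chunk boundaries, a cancel takes effect after the chunk in flight, never in the middle of one.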
...@@ -1789,7 +1800,7 @@ async def generate_chat_response( ...@@ -1789,7 +1800,7 @@ async def generate_chat_response(
from codai.pydantic.textrequest import CompletionRequest from codai.pydantic.textrequest import CompletionRequest
@router.post("/v1/completions") @router.post("/v1/completions", summary="Legacy text completions")
async def completions(request: CompletionRequest): async def completions(request: CompletionRequest):
"""Legacy text completions endpoint (for backward compatibility).""" """Legacy text completions endpoint (for backward compatibility)."""
# Get the model for this request # Get the model for this request
......
...@@ -119,7 +119,7 @@ def _format_response(fmt: str, text: str, segments: list): ...@@ -119,7 +119,7 @@ def _format_response(fmt: str, text: str, segments: list):
router = APIRouter() router = APIRouter()
@router.post("/v1/audio/transcriptions") @router.post("/v1/audio/transcriptions", summary="Transcribe audio to text")
async def create_transcription( async def create_transcription(
model: str = Form(...), model: str = Form(...),
file: UploadFile = File(...), file: UploadFile = File(...),
......
...@@ -64,7 +64,7 @@ class TTSResponse(BaseModel): ...@@ -64,7 +64,7 @@ class TTSResponse(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/speech") @router.post("/v1/audio/speech", summary="Text-to-speech synthesis")
async def create_speech(request: TTSRequest, http_request: Request = None): async def create_speech(request: TTSRequest, http_request: Request = None):
""" """
Text-to-speech endpoint (OpenAI-compatible). Text-to-speech endpoint (OpenAI-compatible).
......
...@@ -45,6 +45,7 @@ from codai.pydantic.videorequest import ( ...@@ -45,6 +45,7 @@ from codai.pydantic.videorequest import (
CharacterDialogLine, CharacterDialogLine,
) )
from codai.api.images import _disable_safety_checker from codai.api.images import _disable_safety_checker
from codai.tasks import task_registry, TaskCancelled
router = APIRouter() router = APIRouter()
...@@ -627,7 +628,15 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None): ...@@ -627,7 +628,15 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None):
_vid_progress_reset(steps) _vid_progress_reset(steps)
# sd.cpp runs the whole diffusion in one C call → not interruptible mid-step;
# register for visibility + step progress (cancel applies once back in Python).
_tid = task_registry.register(
"video", title=(prompt or mode or "")[:80],
model=getattr(request, 'model', '') or '', total=steps)
task_registry.start(_tid)
def _progress_cb(step: int, total: int, elapsed: float): def _progress_cb(step: int, total: int, elapsed: float):
task_registry.step(_tid, step)
_vid_progress_step(step) _vid_progress_step(step)
kw = { kw = {
...@@ -654,8 +663,14 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None): ...@@ -654,8 +663,14 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None):
kw['init_image'] = _pil_from_b64(init_src) kw['init_image'] = _pil_from_b64(init_src)
kw['end_image'] = _pil_from_b64(request.end_image) kw['end_image'] = _pil_from_b64(request.end_image)
try:
frames = sd_model.generate_video(**kw) frames = sd_model.generate_video(**kw)
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_vid_progress_done() _vid_progress_done()
raise
_vid_progress_done()
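# Honour a cancel requested while the C call was running (see note above).
task_registry.raise_if_cancelled(_tid)
task_registry.finish(_tid, "done")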
return list(frames), fps return list(frames), fps
...@@ -1483,7 +1498,17 @@ def _generate_video(pipe, request: VideoGenerationRequest): ...@@ -1483,7 +1498,17 @@ def _generate_video(pipe, request: VideoGenerationRequest):
_vid_progress_reset(kw['num_inference_steps']) _vid_progress_reset(kw['num_inference_steps'])
_tid = task_registry.register(
"video", title=(request.prompt or mode or "")[:80],
model=getattr(request, 'model', '') or '', total=kw['num_inference_steps'])
task_registry.start(_tid)
def _vid_step_cb(pipe, step_index, timestep, callback_kwargs): def _vid_step_cb(pipe, step_index, timestep, callback_kwargs):
# Cooperative cancellation: abort at the next step boundary if cancelled.
task_registry.raise_if_cancelled(_tid)
# Cooperative pause: block here while the user has paused this task.
task_registry.wait_if_paused(_tid)
task_registry.step(_tid, step_index + 1)
_vid_progress_step(step_index + 1) _vid_progress_step(step_index + 1)
# Mid-generation thermal checkpoint: pause between denoise steps if the # Mid-generation thermal checkpoint: pause between denoise steps if the
# CPU/GPU went over the limit during this (multi-minute) generation. # CPU/GPU went over the limit during this (multi-minute) generation.
...@@ -1547,8 +1572,17 @@ def _generate_video(pipe, request: VideoGenerationRequest): ...@@ -1547,8 +1572,17 @@ def _generate_video(pipe, request: VideoGenerationRequest):
# previous clip's (common within a match) and only swapping when they differ. # previous clip's (common within a match) and only swapping when they differ.
# Left loaded after the run so the next clip with the same set pays nothing. # Left loaded after the run so the next clip with the same set pays nothing.
_sync_video_loras(pipe, getattr(request, 'loras', None)) _sync_video_loras(pipe, getattr(request, 'loras', None))
try:
frames = _run_pipeline(pipe, kw) frames = _run_pipeline(pipe, kw)
except TaskCancelled:
_vid_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_vid_progress_done()
raise
_vid_progress_done() _vid_progress_done()
task_registry.finish(_tid, "done")
return frames, fps return frames, fps
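Reviewer note: `_vid_step_cb` above checks for cancel and pause at every denoise-step boundary. A minimal sketch of how such step-boundary flags could be built on `threading.Event` follows. `StepControl`, `run_steps`, and the local `TaskCancelled` are hypothetical illustrations, not the actual `codai.tasks` implementation.

```python
import threading
import time


class TaskCancelled(Exception):
    """Hypothetical stand-in for codai.tasks.TaskCancelled."""


class StepControl:
    """Per-task pause/cancel flags checked at step boundaries (assumed design)."""

    def __init__(self):
        self._resume = threading.Event()
        self._resume.set()          # set = running, cleared = paused
        self._cancel = threading.Event()
        self.step = 0

    def pause(self):
        self._resume.clear()

    def resume(self):
        self._resume.set()

    def cancel(self):
        self._cancel.set()
        self._resume.set()          # wake a paused worker so it can exit

    def raise_if_cancelled(self):
        if self._cancel.is_set():
            raise TaskCancelled()

    def wait_if_paused(self):
        self._resume.wait()         # returns immediately unless paused
        self.raise_if_cancelled()   # a cancel may have arrived while paused


def run_steps(ctl, total):
    """Stand-in for a denoising loop that calls the step callback."""
    for i in range(total):
        ctl.raise_if_cancelled()
        ctl.wait_if_paused()
        ctl.step = i + 1
        time.sleep(0.01)            # pretend to do one step of work
    return "done"


ctl = StepControl()
result = {}


def worker():
    try:
        result["status"] = run_steps(ctl, total=1000)
    except TaskCancelled:
        result["status"] = "cancelled"


t = threading.Thread(target=worker)
ctl.pause()                         # start paused so the demo is deterministic
t.start()
time.sleep(0.05)
steps_while_paused = ctl.step       # worker is blocked in wait_if_paused
ctl.resume()
time.sleep(0.1)
ctl.cancel()
t.join(timeout=5)
```

Having `cancel()` also release the pause event is the important design point: otherwise a task that is paused and then cancelled would block forever.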
...@@ -1979,7 +2013,7 @@ def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str: ...@@ -1979,7 +2013,7 @@ def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str:
# Progress endpoint # Progress endpoint
# ============================================================================= # =============================================================================
@router.get("/v1/video/progress") @router.get("/v1/video/progress", summary="Video generation progress")
async def get_video_progress(): async def get_video_progress():
"""Return current video generation step progress including speed.""" """Return current video generation step progress including speed."""
elapsed = time.monotonic() - _vid_progress["started_at"] if _vid_progress["active"] else 0.0 elapsed = time.monotonic() - _vid_progress["started_at"] if _vid_progress["active"] else 0.0
...@@ -2000,7 +2034,7 @@ async def get_video_progress(): ...@@ -2000,7 +2034,7 @@ async def get_video_progress():
# Main generation endpoint # Main generation endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/generations", response_model=VideoGenerationResponse) @router.post("/v1/video/generations", response_model=VideoGenerationResponse, summary="Generate video")
async def video_generations(request: VideoGenerationRequest, async def video_generations(request: VideoGenerationRequest,
http_request: Request = None): http_request: Request = None):
""" """
...@@ -2269,7 +2303,7 @@ async def video_generations(request: VideoGenerationRequest, ...@@ -2269,7 +2303,7 @@ async def video_generations(request: VideoGenerationRequest,
# Video upscale endpoint # Video upscale endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/upscale") @router.post("/v1/video/upscale", summary="Upscale a video")
async def video_upscale(request: VideoUpscaleRequest, http_request: Request = None): async def video_upscale(request: VideoUpscaleRequest, http_request: Request = None):
""" """
Upscale a video using ffmpeg lanczos or Real-ESRGAN. Upscale a video using ffmpeg lanczos or Real-ESRGAN.
...@@ -2299,7 +2333,7 @@ async def video_upscale(request: VideoUpscaleRequest, http_request: Request = No ...@@ -2299,7 +2333,7 @@ async def video_upscale(request: VideoUpscaleRequest, http_request: Request = No
# Subtitle generation endpoint # Subtitle generation endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/subtitle") @router.post("/v1/video/subtitle", summary="Subtitle / caption a video")
async def video_subtitle(request: VideoSubtitleRequest, http_request: Request = None): async def video_subtitle(request: VideoSubtitleRequest, http_request: Request = None):
""" """
Generate subtitles for a video. Generate subtitles for a video.
...@@ -2353,7 +2387,7 @@ async def video_subtitle(request: VideoSubtitleRequest, http_request: Request = ...@@ -2353,7 +2387,7 @@ async def video_subtitle(request: VideoSubtitleRequest, http_request: Request =
# Frame interpolation endpoint # Frame interpolation endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/interpolate") @router.post("/v1/video/interpolate", summary="Interpolate video frames")
async def video_interpolate(request: VideoInterpolateRequest, http_request: Request = None): async def video_interpolate(request: VideoInterpolateRequest, http_request: Request = None):
""" """
Increase video FPS via frame interpolation. Increase video FPS via frame interpolation.
...@@ -2400,7 +2434,7 @@ async def video_interpolate(request: VideoInterpolateRequest, http_request: Requ ...@@ -2400,7 +2434,7 @@ async def video_interpolate(request: VideoInterpolateRequest, http_request: Requ
# Video dubbing endpoint # Video dubbing endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/dub") @router.post("/v1/video/dub", summary="Dub a video")
async def video_dub(request: VideoDubRequest, http_request: Request = None): async def video_dub(request: VideoDubRequest, http_request: Request = None):
""" """
Translate and re-dub a video. Translate and re-dub a video.
......
...@@ -185,13 +185,13 @@ class VoicePatchRequest(BaseModel): ...@@ -185,13 +185,13 @@ class VoicePatchRequest(BaseModel):
# Voice profile management # Voice profile management
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/v1/audio/voices") @router.get("/v1/audio/voices", summary="List voice profiles")
async def list_voices(): async def list_voices():
"""List all saved voice profiles.""" """List all saved voice profiles."""
return {"voices": _list_voices()} return {"voices": _list_voices()}
@router.post("/v1/audio/voices") @router.post("/v1/audio/voices", summary="Create a voice profile")
async def create_voice( async def create_voice(
name: str = Form(...), name: str = Form(...),
transcript: str = Form(...), transcript: str = Form(...),
...@@ -216,7 +216,7 @@ async def create_voice( ...@@ -216,7 +216,7 @@ async def create_voice(
return {"created": True, "voice": meta} return {"created": True, "voice": meta}
@router.delete("/v1/audio/voices/{name}") @router.delete("/v1/audio/voices/{name}", summary="Delete a voice profile")
async def delete_voice(name: str): async def delete_voice(name: str):
"""Delete a saved voice profile.""" """Delete a saved voice profile."""
import shutil import shutil
...@@ -227,7 +227,7 @@ async def delete_voice(name: str): ...@@ -227,7 +227,7 @@ async def delete_voice(name: str):
return {"deleted": True, "name": name} return {"deleted": True, "name": name}
@router.patch("/v1/audio/voices/{name}") @router.patch("/v1/audio/voices/{name}", summary="Update a voice profile")
async def patch_voice(name: str, req: VoicePatchRequest): async def patch_voice(name: str, req: VoicePatchRequest):
"""Update description, transcript, or reference audio of a saved voice profile.""" """Update description, transcript, or reference audio of a saved voice profile."""
meta = _load_voice(name) meta = _load_voice(name)
...@@ -259,7 +259,7 @@ async def patch_voice(name: str, req: VoicePatchRequest): ...@@ -259,7 +259,7 @@ async def patch_voice(name: str, req: VoicePatchRequest):
return {"updated": True, "voice": meta} return {"updated": True, "voice": meta}
@router.get("/v1/audio/voices/{name}") @router.get("/v1/audio/voices/{name}", summary="Get a voice profile")
async def get_voice(name: str): async def get_voice(name: str):
"""Get a single voice profile metadata.""" """Get a single voice profile metadata."""
meta = _load_voice(name) meta = _load_voice(name)
...@@ -268,7 +268,7 @@ async def get_voice(name: str): ...@@ -268,7 +268,7 @@ async def get_voice(name: str):
return {"voice": meta} return {"voice": meta}
@router.post("/v1/audio/voices/extract") @router.post("/v1/audio/voices/extract", summary="Extract a voice profile from a sample")
async def extract_voice(req: VoiceExtractRequest): async def extract_voice(req: VoiceExtractRequest):
""" """
Extract a voice profile from a source audio or video file. Extract a voice profile from a source audio or video file.
...@@ -358,7 +358,7 @@ class VoiceCloneRequest(BaseModel): ...@@ -358,7 +358,7 @@ class VoiceCloneRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/clone") @router.post("/v1/audio/clone", summary="Clone a voice / synthesize cloned speech")
async def clone_voice(request: VoiceCloneRequest, http_request: Request = None): async def clone_voice(request: VoiceCloneRequest, http_request: Request = None):
""" """
Synthesize speech in a cloned voice using F5-TTS. Synthesize speech in a cloned voice using F5-TTS.
......
...@@ -94,7 +94,7 @@ class VoiceConvertRequest(BaseModel): ...@@ -94,7 +94,7 @@ class VoiceConvertRequest(BaseModel):
model_config = ConfigDict(extra='allow') model_config = ConfigDict(extra='allow')
@router.post('/v1/audio/convert') @router.post('/v1/audio/convert', summary="Voice conversion (speech-to-speech)")
async def convert_voice(request: VoiceConvertRequest, http_request: Request = None): async def convert_voice(request: VoiceConvertRequest, http_request: Request = None):
""" """
Voice conversion: preserves pitch/melody/expression, changes only timbre. Voice conversion: preserves pitch/melody/expression, changes only timbre.
......
...@@ -78,6 +78,40 @@ except ImportError: ...@@ -78,6 +78,40 @@ except ImportError:
_llama_cpp = None _llama_cpp = None
# Friendly KV-cache quant names → llama.cpp GGML type. q8_0 is near-lossless and
# the safe first choice (llama.cpp's own default stays f16); the q5/q4 types
# trade a little accuracy for a smaller cache (q4_0 is roughly half of q8_0).
_KV_TYPE_ALIASES = {
'f16': 'GGML_TYPE_F16', 'fp16': 'GGML_TYPE_F16', 'f32': 'GGML_TYPE_F32',
'q8_0': 'GGML_TYPE_Q8_0', 'q8': 'GGML_TYPE_Q8_0', 'q8_1': 'GGML_TYPE_Q8_1',
'q5_0': 'GGML_TYPE_Q5_0', 'q5_1': 'GGML_TYPE_Q5_1', 'q5': 'GGML_TYPE_Q5_1',
'q4_0': 'GGML_TYPE_Q4_0', 'q4_1': 'GGML_TYPE_Q4_1', 'q4': 'GGML_TYPE_Q4_1',
'iq4_nl': 'GGML_TYPE_IQ4_NL',
}
# Quantized KV types. llama.cpp rejects a quantized V cache (q8_0 included)
# unless flash attention is enabled; the K cache has no such requirement.
_KV_NEEDS_FLASH = {'q8_0', 'q8', 'q8_1', 'q5_0', 'q5_1', 'q5',
                   'q4_0', 'q4_1', 'q4', 'iq4_nl'}
def _ggml_kv_type(name):
    """Map a KV-cache quant name to the llama.cpp GGML type int, or None.

    Returns None (→ keep the llama.cpp default, f16) for falsy, 'none', 'auto',
    'default' or unknown names, and when llama_cpp is not importable or this
    build lacks the type. Unknown and unsupported names print a notice instead
    of failing."""
    if not name or _llama_cpp is None:
        return None
    key = str(name).strip().lower().replace('-', '_').replace(' ', '')
    if key in ('', 'none', 'auto', 'default', 'f16default'):
        return None
    const = _KV_TYPE_ALIASES.get(key)
    if const is None:
        print(f" KV cache type '{name}' not recognized, using default (f16)")
        return None
    val = getattr(_llama_cpp, const, None)
    if val is None:
        print(f" KV cache type '{name}' unsupported by this llama.cpp build, using f16")
    return val
def _install_layer_log_callback(): def _install_layer_log_callback():
"""Replace llama.cpp's log callback with one that prints load-time layer/buffer """Replace llama.cpp's log callback with one that prints load-time layer/buffer
messages directly to stdout. Returns the callback object — keep a reference messages directly to stdout. Returns the callback object — keep a reference
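Reviewer note: to put numbers on the VRAM trade-off behind the KV-cache types above, here is a rough size estimate. The bytes-per-block figures are taken from ggml's 32-element block layouts as I understand them, and the model shape is a hypothetical 8B-class configuration; real allocations may differ slightly because of padding. `kv_cache_bytes` is an illustrative helper, not part of the codebase.

```python
# Bytes per 32-element block for common ggml types (f16 stores 2 bytes/element).
BLOCK_BYTES = {"f16": 64, "q8_0": 34, "q5_1": 24, "q5_0": 22, "q4_1": 20, "q4_0": 18}


def kv_cache_bytes(n_layers, n_ctx, n_kv_heads, head_dim,
                   type_k="f16", type_v="f16"):
    """Approximate KV-cache size: one K and one V tensor per layer."""
    row = n_kv_heads * head_dim
    if row % 32:
        raise ValueError("row length must be a multiple of the 32-element block")
    elems = n_layers * n_ctx * row          # elements in K (V has the same count)
    blocks = elems // 32
    return blocks * (BLOCK_BYTES[type_k] + BLOCK_BYTES[type_v])


# Hypothetical shape: 32 layers, 8 KV heads of dim 128, 32k context.
shape = dict(n_layers=32, n_ctx=32768, n_kv_heads=8, head_dim=128)
sizes = {t: kv_cache_bytes(**shape, type_k=t, type_v=t) for t in BLOCK_BYTES}
for t, b in sizes.items():
    print(f"{t:>5}: {b / 2**30:.2f} GiB")
```

Under these assumptions f16 needs 4 GiB, q8_0 a little over half of that, and q4_0 a little over half of q8_0, which is why the lower-bit types matter mostly for long contexts.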
...@@ -614,6 +648,32 @@ class VulkanBackend(ModelBackend): ...@@ -614,6 +648,32 @@ class VulkanBackend(ModelBackend):
if 'rope_freq_scale' in kwargs: if 'rope_freq_scale' in kwargs:
llama_kwargs['rope_freq_scale'] = kwargs['rope_freq_scale'] llama_kwargs['rope_freq_scale'] = kwargs['rope_freq_scale']
        # KV-cache quantization (llama.cpp type_k / type_v). Shrinks the KV cache
        # so long contexts fit in less VRAM. Read from the per-model config, with
        # the raw models.json entry as a fallback (carried in _raw_cfg). Use `or`
        # rather than a .get() default: callers may pass the key with a None
        # value, which would otherwise bypass the fallback.
        _raw_cfg = kwargs.get('_raw_cfg') or {}
        _ck = kwargs.get('cache_type_k') or _raw_cfg.get('cache_type_k')
        _cv = kwargs.get('cache_type_v') or _raw_cfg.get('cache_type_v')
        _flash = bool(kwargs.get('flash_attn', _raw_cfg.get('flash_attn',
                      _raw_cfg.get('flash_attention', False))))
        _tk = _ggml_kv_type(_ck)
        _tv = _ggml_kv_type(_cv)
        if _tk is not None:
            llama_kwargs['type_k'] = _tk
        if _tv is not None:
            llama_kwargs['type_v'] = _tv
        # llama.cpp requires flash attention for a quantized V cache; auto-enable
        # it (with a note) so the config "just works".
        _v_needs_flash = str(_cv or '').strip().lower().replace('-', '_') in _KV_NEEDS_FLASH
        if (_tk is not None or _tv is not None):
            if _v_needs_flash and not _flash:
                _flash = True
                print(" KV cache: quantized V cache needs flash attention, enabling it")
            if _flash:
                llama_kwargs['flash_attn'] = True
            print(f" KV cache: type_k={_ck or 'f16'} type_v={_cv or 'f16'}"
                  f"{' (flash_attn on)' if _flash else ''}")
# Force CUDA if requested # Force CUDA if requested
if self.force_cuda: if self.force_cuda:
# Set environment variable to force CUDA # Set environment variable to force CUDA
......
...@@ -247,4 +247,11 @@ configuration directory (--config DIR, default: OS-specific CoderAI directory). ...@@ -247,4 +247,11 @@ configuration directory (--config DIR, default: OS-specific CoderAI directory).
action="store_true", action="store_true",
help="List available Vulkan GPU devices and exit", help="List available Vulkan GPU devices and exit",
) )
parser.add_argument(
"--no-resume-jobs",
action="store_true",
help="Do not resume/recover interrupted LoRA training jobs on restart. "
"Mid-flight jobs are marked 'cancelled' (checkpoints are kept, so they "
"can still be restarted manually from the Tasks page).",
)
return parser.parse_args() return parser.parse_args()
...@@ -126,6 +126,17 @@ class ThermalConfig: ...@@ -126,6 +126,17 @@ class ThermalConfig:
poll_seconds: float = 5.0 # how often to re-check while cooling down poll_seconds: float = 5.0 # how often to re-check while cooling down
@dataclass
class JobsConfig:
    """Background-job (LoRA training) configuration."""
    # When True, an interrupted training job (process restart) is left
    # 'interrupted' so it can resume from its on-disk checkpoint. When False,
    # such jobs are marked 'cancelled' on startup and not auto-resumed (their
    # checkpoints are kept, so they can be restarted manually from the Tasks
    # page). The --no-resume-jobs CLI flag forces this off for one run.
    resume_on_restart: bool = True
@dataclass @dataclass
class Config: class Config:
"""Main configuration class.""" """Main configuration class."""
...@@ -139,6 +150,7 @@ class Config: ...@@ -139,6 +150,7 @@ class Config:
whisper: WhisperConfig = field(default_factory=WhisperConfig) whisper: WhisperConfig = field(default_factory=WhisperConfig)
archive: ArchiveConfig = field(default_factory=ArchiveConfig) archive: ArchiveConfig = field(default_factory=ArchiveConfig)
thermal: ThermalConfig = field(default_factory=ThermalConfig) thermal: ThermalConfig = field(default_factory=ThermalConfig)
jobs: JobsConfig = field(default_factory=JobsConfig)
broker: BrokerConfig = field(default_factory=BrokerConfig) broker: BrokerConfig = field(default_factory=BrokerConfig)
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
tools_closer_prompt: bool = False tools_closer_prompt: bool = False
...@@ -293,6 +305,7 @@ class ConfigManager: ...@@ -293,6 +305,7 @@ class ConfigManager:
whisper=WhisperConfig(**config_data.get("whisper", {})), whisper=WhisperConfig(**config_data.get("whisper", {})),
archive=ArchiveConfig(**config_data.get("archive", {})), archive=ArchiveConfig(**config_data.get("archive", {})),
thermal=ThermalConfig(**config_data.get("thermal", {})), thermal=ThermalConfig(**config_data.get("thermal", {})),
jobs=JobsConfig(**config_data.get("jobs", {})),
broker=BrokerConfig(**config_data.get("broker", {})), broker=BrokerConfig(**config_data.get("broker", {})),
system_prompt=config_data.get("system_prompt"), system_prompt=config_data.get("system_prompt"),
tools_closer_prompt=config_data.get("tools_closer_prompt", False), tools_closer_prompt=config_data.get("tools_closer_prompt", False),
...@@ -411,6 +424,9 @@ class ConfigManager: ...@@ -411,6 +424,9 @@ class ConfigManager:
"gpu_resume": self.config.thermal.gpu_resume, "gpu_resume": self.config.thermal.gpu_resume,
"poll_seconds": self.config.thermal.poll_seconds, "poll_seconds": self.config.thermal.poll_seconds,
}, },
"jobs": {
"resume_on_restart": self.config.jobs.resume_on_restart,
},
"broker": { "broker": {
"enabled": self.config.broker.enabled, "enabled": self.config.broker.enabled,
"base_url": self.config.broker.base_url, "base_url": self.config.broker.base_url,
......
...@@ -147,6 +147,8 @@ def build_runtime_kwargs(model_cfg, model_type): ...@@ -147,6 +147,8 @@ def build_runtime_kwargs(model_cfg, model_type):
} }
if model_type == "text": if model_type == "text":
kwargs['ctx'] = model_cfg.get('n_ctx', model_cfg.get('context_size')) kwargs['ctx'] = model_cfg.get('n_ctx', model_cfg.get('context_size'))
kwargs['cache_type_k'] = model_cfg.get('cache_type_k')
kwargs['cache_type_v'] = model_cfg.get('cache_type_v')
elif model_type == "image": elif model_type == "image":
kwargs['llm_path'] = model_cfg.get('llm_path') kwargs['llm_path'] = model_cfg.get('llm_path')
kwargs['vae_path'] = model_cfg.get('vae_path') kwargs['vae_path'] = model_cfg.get('vae_path')
...@@ -865,9 +867,16 @@ def main(): ...@@ -865,9 +867,16 @@ def main():
from codai.api.characters import set_global_args as set_chars_global_args from codai.api.characters import set_global_args as set_chars_global_args
set_chars_global_args(global_args) set_chars_global_args(global_args)
# Set LoRA training module global args # Set LoRA training module global args. Resolve job-recovery first (the
from codai.api.loras import set_global_args as set_loras_global_args # --no-resume-jobs flag overrides the persisted config setting), then call
# set_global_args, which runs _load_jobs_on_start and honours the flag.
from codai.api.loras import (set_global_args as set_loras_global_args,
set_resume_enabled as set_loras_resume_enabled)
_resume_jobs = bool(getattr(config.jobs, "resume_on_restart", True)) and not getattr(args, "no_resume_jobs", False)
set_loras_resume_enabled(_resume_jobs)
set_loras_global_args(global_args) set_loras_global_args(global_args)
if not _resume_jobs:
print("LoRA job recovery: DISABLED (interrupted training will be cancelled on restart)")
# Set environment profiles module global args # Set environment profiles module global args
from codai.api.environments import set_global_args as set_envs_global_args from codai.api.environments import set_global_args as set_envs_global_args
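Reviewer note: the `_resume_jobs` expression above means the CLI flag can only switch recovery off, never back on. A small sketch of the four combinations follows; the `JobsConfig` here is a trimmed copy for illustration, and `resolve_resume` is a hypothetical helper rather than a function in the codebase.

```python
import argparse
from dataclasses import dataclass


@dataclass
class JobsConfig:
    """Trimmed copy of the config dataclass, for illustration only."""
    resume_on_restart: bool = True


def resolve_resume(jobs_cfg, argv):
    """Hypothetical helper mirroring the precedence used in main()."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--no-resume-jobs", action="store_true")
    args = parser.parse_args(argv)
    # Recovery runs only if the persisted setting allows it AND the flag is absent.
    return bool(jobs_cfg.resume_on_restart) and not args.no_resume_jobs


table = {
    (cfg, flag): resolve_resume(JobsConfig(cfg),
                                ["--no-resume-jobs"] if flag else [])
    for cfg in (True, False)
    for flag in (True, False)
}
```

Only the combination "setting on, flag absent" resumes jobs. Because the flag is not written back to the config, the persisted setting applies again on the next start without it.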
......
...@@ -790,6 +790,17 @@ class MultiModelManager: ...@@ -790,6 +790,17 @@ class MultiModelManager:
# build_kwargs_from_config populates it from the model's # build_kwargs_from_config populates it from the model's
# 'flash_attention' setting; CLI/global is NOT consulted here. # 'flash_attention' setting; CLI/global is NOT consulted here.
kwargs['flash_attn'] = bool(config.get('flash_attn', False)) kwargs['flash_attn'] = bool(config.get('flash_attn', False))
# KV-cache quantization (llama.cpp type_k/type_v) — pass through
# to the backend, with the raw models.json entry as a fallback.
_raw = config.get('_raw_cfg') if isinstance(config.get('_raw_cfg'), dict) else {}
for _kvk in ('cache_type_k', 'cache_type_v'):
_kvv = config.get(_kvk)
if _kvv is None:
_kvv = _raw.get(_kvk)
if _kvv:
kwargs[_kvk] = _kvv
if _raw and '_raw_cfg' not in kwargs:
kwargs['_raw_cfg'] = _raw
no_ram = _cfg_or_global('no_ram', 'no_ram', False) no_ram = _cfg_or_global('no_ram', 'no_ram', False)
kwargs['no_ram'] = bool(no_ram) kwargs['no_ram'] = bool(no_ram)
offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto') offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto')
...@@ -872,6 +883,17 @@ class MultiModelManager: ...@@ -872,6 +883,17 @@ class MultiModelManager:
# build_kwargs_from_config populates it from the model's # build_kwargs_from_config populates it from the model's
# 'flash_attention' setting; CLI/global is NOT consulted here. # 'flash_attention' setting; CLI/global is NOT consulted here.
kwargs['flash_attn'] = bool(config.get('flash_attn', False)) kwargs['flash_attn'] = bool(config.get('flash_attn', False))
# KV-cache quantization (llama.cpp type_k/type_v) — pass through
# to the backend, with the raw models.json entry as a fallback.
_raw = config.get('_raw_cfg') if isinstance(config.get('_raw_cfg'), dict) else {}
for _kvk in ('cache_type_k', 'cache_type_v'):
_kvv = config.get(_kvk)
if _kvv is None:
_kvv = _raw.get(_kvk)
if _kvv:
kwargs[_kvk] = _kvv
if _raw and '_raw_cfg' not in kwargs:
kwargs['_raw_cfg'] = _raw
no_ram = _cfg_or_global('no_ram', 'no_ram', False) no_ram = _cfg_or_global('no_ram', 'no_ram', False)
kwargs['no_ram'] = bool(no_ram) kwargs['no_ram'] = bool(no_ram)
offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto') offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto')
......
...@@ -35,10 +35,62 @@ Semantics (per sensor, when enabled): ...@@ -35,10 +35,62 @@ Semantics (per sensor, when enabled):
import os import os
import shutil import shutil
import subprocess import subprocess
import threading
import time import time
from typing import Optional, Tuple from typing import Optional, Tuple
# ---------------------------------------------------------------------------
# Cooldown state (published for the admin Tasks view)
# ---------------------------------------------------------------------------
# A thermal pause is a *global* hardware event: every worker that reaches a
# checkpoint blocks until temps recover. We publish a single process-wide state
# so the Tasks page can show that running work is paused for cooldown. A waiter
# counter (not a bool) keeps the state correct when several workers pause at
# once: the state is "active" while any worker is still cooling.
_cooldown_lock = threading.Lock()
_cooldown_waiters = 0
_cooldown_state: dict = {
    "active": False, "since": 0.0, "waited": 0.0,
    "gpu": None, "cpu": None, "message": "",
}


def get_cooldown_state() -> dict:
    """Snapshot of the current thermal cooldown (see module note). ``active`` is
    True while at least one worker is paused waiting for the hardware to cool."""
    with _cooldown_lock:
        return dict(_cooldown_state)


def _cooldown_enter() -> None:
    global _cooldown_waiters
    with _cooldown_lock:
        _cooldown_waiters += 1
        _cooldown_state["active"] = True
        if not _cooldown_state.get("since"):
            _cooldown_state["since"] = time.time()


def _cooldown_update(gpu, cpu, waited, message) -> None:
    with _cooldown_lock:
        _cooldown_state["gpu"] = gpu
        _cooldown_state["cpu"] = cpu
        _cooldown_state["waited"] = waited
        _cooldown_state["message"] = message


def _cooldown_exit() -> None:
    global _cooldown_waiters
    with _cooldown_lock:
        _cooldown_waiters = max(0, _cooldown_waiters - 1)
        if _cooldown_waiters == 0:
            _cooldown_state.update({
                "active": False, "since": 0.0, "waited": 0.0,
                "gpu": None, "cpu": None, "message": "",
            })
# ---------------------------------------------------------------------------
# Temperature readers
# ---------------------------------------------------------------------------
...@@ -199,6 +251,57 @@ def read_cpu_temp() -> Optional[float]: ...@@ -199,6 +251,57 @@ def read_cpu_temp() -> Optional[float]:
return val
_gpu_util_cache: Tuple[float, Optional[float]] = (0.0, None)
def _read_gpu_util_uncached() -> Optional[float]:
"""Highest GPU utilization across all devices, in %, or None if unreadable."""
if _NVIDIA_SMI:
out = _run([
_NVIDIA_SMI,
"--query-gpu=utilization.gpu",
"--format=csv,noheader,nounits",
])
if out:
vals = []
for line in out.splitlines():
line = line.strip()
if line:
try:
vals.append(float(line))
except ValueError:
pass
if vals:
return max(vals)
if _ROCM_SMI:
out = _run([_ROCM_SMI, "--showuse"])
if out:
vals = []
for line in out.splitlines():
low = line.lower()
if "gpu use" in low and "%" in line:
for tok in line.replace("%", " ").split():
try:
vals.append(float(tok))
except ValueError:
continue
if vals:
return max(vals)
return None
def read_gpu_util() -> Optional[float]:
"""GPU utilization % (cached ~2s), or None if unreadable."""
global _gpu_util_cache
now = time.monotonic()
ts, val = _gpu_util_cache
if now - ts < _CACHE_TTL:
return val
val = _read_gpu_util_uncached()
_gpu_util_cache = (now, val)
return val
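The short-lived cache used by `read_gpu_util` can be sketched generically. This is a minimal illustration under stated assumptions: `make_cached` and its `clock` parameter are hypothetical names introduced here, and the 2-second TTL mirrors the "cached ~2s" note in the docstring rather than the actual `_CACHE_TTL` value. Unlike the original, the sketch marks "never read" with `None` instead of a zero timestamp, so the first call always performs a read regardless of the monotonic clock's origin.

```python
import time
from typing import Callable, Optional, Tuple


def make_cached(reader: Callable[[], Optional[float]],
                ttl: float = 2.0,
                clock: Callable[[], float] = time.monotonic
                ) -> Callable[[], Optional[float]]:
    """Wrap an expensive reader so it runs at most once per `ttl` seconds."""
    last: Optional[Tuple[float, Optional[float]]] = None  # (timestamp, value)

    def cached() -> Optional[float]:
        nonlocal last
        now = clock()
        if last is not None and now - last[0] < ttl:
            return last[1]  # still fresh: skip the subprocess call
        value = reader()
        last = (now, value)
        return value

    return cached


# Demonstration with a fake clock so the behavior is deterministic.
calls = []
fake_now = [100.0]


def fake_reader() -> float:
    calls.append(1)
    return 42.0


cached_reader = make_cached(fake_reader, ttl=2.0, clock=lambda: fake_now[0])
first = cached_reader()
fake_now[0] = 101.0
second = cached_reader()      # within the TTL, served from cache
fake_now[0] = 102.5
third = cached_reader()       # TTL expired, reader runs again
print(first, second, third, len(calls))
```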
def read_cpu_temp_avg(samples: int = 3, max_seconds: float = 3.0) -> Optional[float]:
"""Averaged CPU temperature for stable resume/cooldown decisions.
...@@ -372,6 +475,8 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None, ...@@ -372,6 +475,8 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None,
f"until cooldown (GPU<={settings.gpu_resume:.0f}°C / "
f"CPU<={settings.cpu_resume:.0f}°C)")
waited = 0.0
_cooldown_enter()
try:
while True:
# Re-evaluate against resume thresholds (lower than trigger → hysteresis).
# CPU temps are noisy, so average a few samples for the resume decision
...@@ -388,9 +493,12 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None, ...@@ -388,9 +493,12 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None,
if not still:
break
msg = ", ".join(f"{lbl} {t:.0f}°C>{r:.0f}°C" for lbl, t, r in still)
_cooldown_update(gt, ct, waited, msg)
print(f"[thermal] Cooling{desc}: {msg} — waiting "
f"({int(waited)}s elapsed)")
time.sleep(settings.poll_seconds)
waited += settings.poll_seconds
finally:
_cooldown_exit()
print(f"[thermal] Temperatures back within safe limits{desc} — resuming "
f"after {int(waited)}s")
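The waiter-counter idea behind `_cooldown_enter` / `_cooldown_exit` can be shown in a few lines. This is a simplified sketch, assuming only the active flag matters; `CooldownTracker` is a hypothetical class used for illustration, whereas the module itself keeps the state in module-level globals.

```python
import threading


class CooldownTracker:
    """Hypothetical stand-in for the module-level cooldown state."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._waiters = 0
        self.active = False

    def enter(self) -> None:
        with self._lock:
            self._waiters += 1
            self.active = True

    def exit(self) -> None:
        with self._lock:
            # Clamp at zero so an unmatched exit cannot drive the count negative.
            self._waiters = max(0, self._waiters - 1)
            if self._waiters == 0:
                self.active = False


tracker = CooldownTracker()
tracker.enter()                  # worker A pauses
tracker.enter()                  # worker B pauses
tracker.exit()                   # worker A resumes
still_active = tracker.active    # B is still cooling
tracker.exit()                   # worker B resumes
finally_active = tracker.active
print(still_active, finally_active)
```

A plain boolean would have flipped to inactive as soon as worker A resumed, which is why a counter is used; calling the exit in a `finally` block keeps the count balanced even if the wait loop raises.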
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""TurboQuant-style data-free vector quantization.
A faithful, dependency-light implementation of the core idea behind TurboQuant
(Zandieh et al., *TurboQuant: Online Vector Quantization with Near-optimal
Distortion Rate*, arXiv:2504.19874, ICLR 2026): randomly rotate each vector so
its coordinates become near-Gaussian and well concentrated, then apply a simple
per-coordinate uniform scalar quantizer. The rotation makes the cheap uniform
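quantizer near rate-distortion optimal, and (crucially for retrieval) the
reconstruction error is small and close to zero-mean, so inner products / cosine
similarities between quantized vectors are approximately preserved. Strict
unbiasedness is what the paper's QJL residual stage adds; see the scope note below.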
What this gives coderai: an optional compact representation for ``/v1/embeddings``
output (nominally 4× smaller than float32 at 8 bits and 8× at 4 bits, before
power-of-two padding and the 2-byte norm) whose dot products closely track the
full-precision embeddings, suitable for storing in a vector DB.
Scope / honesty: this is the rotation + scalar-quantization core (a *data-free,
calibration-free* quantizer). It does **not** implement the paper's extra 1-bit
QJL residual stage, which buys a little more accuracy at the same bit budget;
the structure here is deliberately simple, deterministic and fast (an O(d log d)
randomized Hadamard transform, no stored rotation matrix, no torch dependency).
The rotation is keyed by ``(dim, seed)`` only — never by the data — so every
vector from the same model lands in the *same* rotated space and quantized
vectors remain mutually comparable. Use a fixed ``seed`` per deployment.
"""
from __future__ import annotations
import base64
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
# Default global seed for the keyed rotation. Stable across processes so the
# same embedding model always quantizes into the same space.
DEFAULT_SEED = 0x7B_C0_DE
# --- Optional upstream library backend ------------------------------------
# When the `turboquant-py` package (pip install "turboquant-py[torch]") is
# installed, the float-reconstruction path can delegate to its QJL/inner-product
# quantizer, which adds the paper's 1-bit residual stage this built-in core
# omits. It is used opportunistically and every call is guarded — any import or
# API mismatch transparently falls back to the built-in NumPy implementation, so
# there is never a hard dependency. Control via CODERAI_TURBOQUANT_LIB:
# auto (default) = use the library if importable; off/0 = always built-in.
_LIB_MODE = os.environ.get("CODERAI_TURBOQUANT_LIB", "auto").strip().lower()
def _lib():
if _LIB_MODE in ("off", "0", "false", "no", "none"):
return None
try:
import turboquant as _tq # turboquant-py
return _tq
except Exception:
return None
def have_library() -> bool:
"""True if the optional turboquant-py backend is importable and enabled."""
return _lib() is not None
def backend_name() -> str:
"""Name of the active reconstruction backend ('turboquant-py' or 'builtin')."""
return "turboquant-py" if have_library() else "builtin"
def _next_pow2(n: int) -> int:
p = 1
while p < n:
p <<= 1
return p
def _signs(dim_padded: int, seed: int) -> np.ndarray:
"""Deterministic ±1 sign vector keyed by (dim_padded, seed)."""
rng = np.random.default_rng(np.uint64(seed) ^ np.uint64(dim_padded))
return rng.integers(0, 2, size=dim_padded, dtype=np.int8).astype(np.float32) * 2.0 - 1.0
def _fwht(a: np.ndarray) -> np.ndarray:
"""Fast Walsh-Hadamard transform along the last axis (works on a float32 copy).
``a`` must have a power-of-two last dimension. Returns the *unnormalized*
transform (H with H@H == n*I); callers scale by 1/sqrt(n) for orthonormality.
"""
a = a.astype(np.float32, copy=True)
n = a.shape[-1]
h = 1
while h < n:
# vectorized butterfly over the last axis
a = a.reshape(*a.shape[:-1], n // (2 * h), 2, h)
x = a[..., 0, :]
y = a[..., 1, :]
a = np.concatenate([x + y, x - y], axis=-1)
a = a.reshape(*a.shape[:-2], n)
h *= 2
return a
def _rotate(x: np.ndarray, signs: np.ndarray) -> np.ndarray:
"""Orthonormal randomized Hadamard rotation R(x) = (1/sqrt(P)) H (s ⊙ x)."""
p = signs.shape[0]
return _fwht(x * signs) / np.sqrt(p, dtype=np.float32)
def _irotate(y: np.ndarray, signs: np.ndarray) -> np.ndarray:
"""Inverse rotation R^-1(y) = s ⊙ ((1/sqrt(P)) H y) (R is orthonormal)."""
p = signs.shape[0]
return (_fwht(y) / np.sqrt(p, dtype=np.float32)) * signs
def _clip_radius(dim_padded: int) -> float:
"""Clip range for the rotated *unit* vector's coordinates.
After rotating a unit vector, each coordinate is ~N(0, 1/P); ~4 sigma covers
the distribution with negligible clipping while keeping the quantizer step
small. Returns r so coordinates are quantized over [-r, r].
"""
return 4.0 / np.sqrt(float(dim_padded))
@dataclass
class TurboQuantMeta:
method: str
bits: int
seed: int
dim: int # original embedding dimension
dim_padded: int # power-of-two rotation size
radius: float
bytes_per_vector: int
def _parse_quant_spec(spec: Optional[str]) -> Optional[int]:
"""Map a request quantization string to a bit width, or None.
Accepts ``turbo`` (=8 bit), ``turbo8``/``turbo6``/``turbo4``/``turbo2``,
or a bare integer string. Returns None for falsy / unrecognized values so
callers can treat it as "no quantization".
"""
if not spec:
return None
s = str(spec).strip().lower().replace("-", "").replace("_", "")
if s in ("turbo", "turboquant"):
return 8
if s.startswith("turbo"):
s = s[5:]
if s.isdigit():
b = int(s)
return b if b in (2, 4, 6, 8) else None
return None
def _as_2d(vectors) -> Tuple[np.ndarray, bool]:
arr = np.asarray(vectors, dtype=np.float32)
single = arr.ndim == 1
if single:
arr = arr[None, :]
return arr, single
def _prepare(vectors, bits: int, seed: int):
"""Normalize, rotate and quantize vectors; return (codes, norms, signs, meta, single)."""
arr, single = _as_2d(vectors)
n, dim = arr.shape
p = _next_pow2(dim)
signs = _signs(p, seed)
norms = np.linalg.norm(arr, axis=1, keepdims=True).astype(np.float32)
safe = np.where(norms == 0.0, 1.0, norms)
unit = arr / safe
padded = np.zeros((n, p), dtype=np.float32)
padded[:, :dim] = unit
rot = _rotate(padded, signs) # ~N(0, 1/P) coordinates
r = _clip_radius(p)
levels = (1 << bits) - 1
q = np.clip((rot + r) / (2.0 * r), 0.0, 1.0)
codes = np.rint(q * levels).astype(np.uint16) # in [0, levels]
meta = TurboQuantMeta(
method="turboquant", bits=bits, seed=seed, dim=dim, dim_padded=p,
radius=float(r), bytes_per_vector=(p * bits + 7) // 8 + 2,
)
return codes, norms.reshape(-1), signs, meta, single
def _decode_codes(codes: np.ndarray, norms: np.ndarray, signs: np.ndarray,
meta: TurboQuantMeta) -> np.ndarray:
"""Inverse of :func:`_prepare`: dequantize, un-rotate and rescale back to dim ``meta.dim``."""
r = meta.radius
levels = (1 << meta.bits) - 1
rot = codes.astype(np.float32) / levels * (2.0 * r) - r
padded = _irotate(rot, signs)
out = padded[:, :meta.dim] * norms.reshape(-1, 1)
return out
# ---------------------------------------------------------------------------
# Bit packing (generic 2/4/6/8-bit, vectorized via numpy bit-planes)
# ---------------------------------------------------------------------------
def _pack_bits(codes: np.ndarray, bits: int) -> np.ndarray:
"""Pack per-row uint codes (each ``bits`` wide) into bytes. Returns (n, ceil(P*bits/8))."""
n, p = codes.shape
planes = ((codes[:, :, None] >> np.arange(bits, dtype=np.uint16)) & 1).astype(np.uint8)
flat = planes.reshape(n, p * bits)
return np.packbits(flat, axis=1)
def _unpack_bits(packed: np.ndarray, p: int, bits: int) -> np.ndarray:
"""Inverse of :func:`_pack_bits`."""
n = packed.shape[0]
flat = np.unpackbits(packed, axis=1)[:, : p * bits]
planes = flat.reshape(n, p, bits).astype(np.uint16)
weights = (np.uint16(1) << np.arange(bits, dtype=np.uint16))
return (planes * weights).sum(axis=2).astype(np.uint16)
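The bit layout produced by `_pack_bits` can be reproduced without NumPy, which may help when writing a decoder in another environment. This is a sketch under stated assumptions: `pack_codes` and `unpack_codes` are hypothetical names, and the layout (each code emitted least-significant bit first, then the bit stream packed most-significant bit first into bytes, matching the default behavior of `np.packbits`) is inferred from the code above rather than from a published wire specification.

```python
from typing import List


def pack_codes(codes: List[int], bits: int) -> bytes:
    """Pack fixed-width codes into bytes using the bit-plane layout."""
    # Each code contributes its bits least-significant first.
    stream = [(code >> k) & 1 for code in codes for k in range(bits)]
    stream += [0] * (-len(stream) % 8)  # zero-pad to a whole number of bytes
    out = bytearray()
    for i in range(0, len(stream), 8):
        byte = 0
        for bit in stream[i:i + 8]:
            byte = (byte << 1) | bit  # first bit of each group lands in the MSB
        out.append(byte)
    return bytes(out)


def unpack_codes(blob: bytes, count: int, bits: int) -> List[int]:
    """Inverse of pack_codes for `count` codes of `bits` bits each."""
    stream = [(byte >> (7 - k)) & 1 for byte in blob for k in range(8)]
    return [
        sum(stream[i * bits + k] << k for k in range(bits))
        for i in range(count)
    ]


codes_6bit = [0, 63, 21, 42]
packed = pack_codes(codes_6bit, 6)
restored = unpack_codes(packed, len(codes_6bit), 6)
print(len(packed), restored)
```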
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _reconstruct_library(arr2d: np.ndarray, bits: int, seed: int) -> Optional[np.ndarray]:
"""Reconstruct via turboquant-py; None if the library is unavailable / errors."""
tq = _lib()
if tq is None:
return None
try:
TQ = tq.TurboQuant(dim=int(arr2d.shape[1]), bit_width=int(bits),
mode="mse", seed=int(seed) & 0x7FFFFFFF)
rec = TQ.dequantize(TQ.quantize(arr2d))
rec = np.asarray(rec, dtype=np.float32)
if rec.shape != arr2d.shape:
return None
return rec
except Exception:
return None
def _reconstruct_builtin(arr2d: np.ndarray, bits: int, seed: int) -> np.ndarray:
codes, norms, signs, meta, _ = _prepare(arr2d, bits, seed)
return _decode_codes(codes, norms, signs, meta)
def reconstruct(vectors, bits: int, seed: int = DEFAULT_SEED,
backend: str = "builtin") -> List[List[float]]:
"""Quantize then dequantize — the lossy float reconstruction.
The returned vectors are the same shape as the input and behave (in inner
product / cosine) like ``bits``-bit TurboQuant-stored embeddings.
``backend`` selects the implementation explicitly:
* ``builtin``: the built-in NumPy quantizer (always available).
* ``library``: the upstream ``turboquant-py``, called here with ``mode="mse"``;
raises if it is not installed/enabled, rather than silently degrading.
* ``auto``: library if available, else built-in.
"""
arr, single = _as_2d(vectors)
b = (backend or "builtin").strip().lower()
out = None
if b in ("library", "external", "turboquant-py", "turboquantpy", "auto"):
out = _reconstruct_library(arr, bits, seed)
if out is None and b != "auto":
raise RuntimeError(
"TurboQuant 'library' backend selected but turboquant-py is "
"unavailable or failed — install \"turboquant-py[torch]\" or "
"switch the model's TurboQuant backend to 'builtin'.")
if out is None:
out = _reconstruct_builtin(arr, bits, seed)
lst = out.tolist()
return lst[0] if single else lst
def quantize_packed(vectors, bits: int, seed: int = DEFAULT_SEED
) -> Tuple[List[bytes], TurboQuantMeta]:
"""Quantize to the compact wire form: one ``bytes`` blob per vector.
Each blob is ``[float16 norm][packed b-bit rotated codes]``. Decode with
:func:`unpack_blob` using the returned :class:`TurboQuantMeta`.
"""
codes, norms, _signs, meta, _single = _prepare(vectors, bits, seed)
packed = _pack_bits(codes, bits)
norm16 = norms.astype(np.float16)
blobs = [norm16[i].tobytes() + packed[i].tobytes() for i in range(packed.shape[0])]
return blobs, meta
def quantize_base64(vectors, bits: int, seed: int = DEFAULT_SEED
) -> Tuple[List[str], TurboQuantMeta]:
"""Like :func:`quantize_packed` but each blob base64-encoded (JSON-friendly)."""
blobs, meta = quantize_packed(vectors, bits, seed)
return [base64.b64encode(b).decode("ascii") for b in blobs], meta
def unpack_blob(blob: bytes, meta: TurboQuantMeta) -> List[float]:
"""Decode a single packed blob (or base64 str) back to a float vector."""
if isinstance(blob, str):
blob = base64.b64decode(blob)
norm = np.frombuffer(blob[:2], dtype=np.float16).astype(np.float32)
packed = np.frombuffer(blob[2:], dtype=np.uint8)[None, :]
codes = _unpack_bits(packed, meta.dim_padded, meta.bits)
signs = _signs(meta.dim_padded, meta.seed)
out = _decode_codes(codes, norm, signs, meta)
return out[0].tolist()
if __name__ == "__main__":
# Self-test: rotation round-trips, and quantization preserves inner products.
rng = np.random.default_rng(1)
d = 384
X = rng.standard_normal((64, d)).astype(np.float32)
X /= np.linalg.norm(X, axis=1, keepdims=True)
p = _next_pow2(d)
s = _signs(p, DEFAULT_SEED)
pad = np.zeros((X.shape[0], p), dtype=np.float32); pad[:, :d] = X
assert np.allclose(_irotate(_rotate(pad, s), s), pad, atol=1e-4), "rotation not invertible"
for bits in (8, 4, 2):
R = np.asarray(reconstruct(X, bits))
# cosine between original and reconstruction
cos = (X * R).sum(1) / (np.linalg.norm(R, axis=1) + 1e-9)
# preservation of pairwise inner products
G0 = X @ X.T
G1 = R @ R.T
err = np.abs(G0 - G1).mean()
blobs, meta = quantize_packed(X, bits)
rt = np.asarray(unpack_blob(blobs[0], meta))
assert np.allclose(rt, R[0], atol=1e-5), "packed blob != reconstruct"
print(f"bits={bits}: mean|Δip|={err:.4f} meanCos={cos.mean():.4f} "
f"bytes/vec={meta.bytes_per_vector} (float32={d*4})")
print("turboquant self-test OK")
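The rotate-then-quantize pipeline can also be followed in a dependency-free form. The sketch below is an illustration, not the module's implementation: it uses the classic in-place butterfly rather than the vectorized reshape, the function names are hypothetical, and the sign vector comes from Python's `random` module instead of the keyed NumPy generator, so its codes are not interchangeable with the real ones.

```python
import math
import random
from typing import List


def fwht(values: List[float]) -> List[float]:
    """Unnormalized Walsh-Hadamard transform; length must be a power of two."""
    a = list(values)
    n = len(a)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                x, y = a[i], a[i + h]
                a[i], a[i + h] = x + y, x - y
        h *= 2
    return a


def rotate(x: List[float], signs: List[float]) -> List[float]:
    """R(x) = (1/sqrt(P)) H (s * x), an orthonormal map."""
    scale = 1.0 / math.sqrt(len(x))
    return [v * scale for v in fwht([xi * si for xi, si in zip(x, signs)])]


def irotate(y: List[float], signs: List[float]) -> List[float]:
    """Inverse rotation: s * ((1/sqrt(P)) H y)."""
    scale = 1.0 / math.sqrt(len(y))
    return [v * scale * s for v, s in zip(fwht(y), signs)]


def quantize_roundtrip(unit: List[float], signs: List[float], bits: int) -> List[float]:
    """Rotate, quantize uniformly over [-r, r], dequantize, rotate back."""
    p = len(unit)
    r = 4.0 / math.sqrt(p)                 # same 4-sigma clip radius as above
    levels = (1 << bits) - 1
    codes = [round(min(max((v + r) / (2 * r), 0.0), 1.0) * levels)
             for v in rotate(unit, signs)]
    return irotate([c / levels * 2 * r - r for c in codes], signs)


rng = random.Random(0)
signs = [rng.choice((-1.0, 1.0)) for _ in range(64)]
x = [rng.gauss(0.0, 1.0) for _ in range(64)]
norm = math.sqrt(sum(v * v for v in x))
x = [v / norm for v in x]

back = irotate(rotate(x, signs), signs)
rec = quantize_roundtrip(x, signs, 8)
cosine = sum(a * b for a, b in zip(x, rec)) / math.sqrt(sum(b * b for b in rec))
print(max(abs(a - b) for a, b in zip(x, back)), cosine)
```

The rotation round trip is exact up to floating-point error because the Hadamard matrix satisfies H·H = P·I and the sign flips are their own inverse; only the quantization step loses information.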
...@@ -17,16 +17,17 @@ ...@@ -17,16 +17,17 @@
"""Pydantic models for embeddings API.""" """Pydantic models for embeddings API."""
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
class EmbeddingsRequest(BaseModel): class EmbeddingsRequest(BaseModel):
model: str model: str = Field(..., description="Embedding model id to use.")
input: Union[str, List[str]] # text(s) to embed input: Union[str, List[str]] = Field(..., description="Text or list of texts to embed.")
image: Optional[Union[str, List[str]]] = None # base64/URL image(s) for multimodal embed image: Optional[Union[str, List[str]]] = Field(None, description="Base64/URL image(s) for multimodal embedding models.")
encoding_format: Optional[str] = "float" # float | base64 encoding_format: Optional[str] = Field("float", description="Return embeddings as 'float' arrays or 'base64'.")
dimensions: Optional[int] = None # truncate to N dims if supported dimensions: Optional[int] = Field(None, description="Truncate embeddings to N dimensions (if the model supports it).")
user: Optional[str] = None quantization: Optional[str] = Field(None, description="Optional TurboQuant vector quantization: 'turbo' (8-bit), 'turbo8', 'turbo6', 'turbo4' or 'turbo2'. With encoding_format='float' the (lossy) reconstructed vectors are returned; with 'base64' the compact packed bytes are returned plus a 'quantization' metadata block describing how to decode them.")
user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
......
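To judge what the `quantization` option saves on the wire, the size formula from `TurboQuantMeta.bytes_per_vector` can be evaluated directly. The sketch below assumes a 384-dimensional embedding model purely as an example; `packed_bytes_per_vector` is a hypothetical helper, and the model id in the request body is made up. Only request fields defined in `EmbeddingsRequest` are used, and the response layout is deliberately not shown because it is not specified here.

```python
import json


def next_pow2(n: int) -> int:
    p = 1
    while p < n:
        p <<= 1
    return p


def packed_bytes_per_vector(dim: int, bits: int) -> int:
    """2-byte float16 norm plus the packed codes of the padded vector."""
    return (next_pow2(dim) * bits + 7) // 8 + 2


dim = 384  # example embedding width; pads to 512 for the rotation
for bits in (8, 6, 4, 2):
    size = packed_bytes_per_vector(dim, bits)
    print(f"bits={bits}: {size} bytes vs {dim * 4} float32 bytes")

# Request body for /v1/embeddings asking for packed 4-bit output.
payload = json.dumps({
    "model": "my-embedding-model",   # hypothetical model id
    "input": ["hello world"],
    "encoding_format": "base64",
    "quantization": "turbo4",
})
```

Because the rotation pads to the next power of two, a 384-dimensional vector at 8 bits takes 514 bytes, roughly a 3× saving rather than the nominal 4×; dimensions that are already powers of two get close to the nominal ratio.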
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
class LoraConfig(BaseModel): class LoraConfig(BaseModel):
...@@ -26,41 +26,43 @@ class LoraConfig(BaseModel): ...@@ -26,41 +26,43 @@ class LoraConfig(BaseModel):
server-side, in priority) via `id` ("name:<registered>" or "sha256:<hex>"), server-side, in priority) via `id` ("name:<registered>" or "sha256:<hex>"),
inline `file`/`data` base64, a `url`, or the legacy `model`/`path` (local path inline `file`/`data` base64, a `url`, or the legacy `model`/`path` (local path
/ HF id) — so a remote client needn't share the server's filesystem.""" / HF id) — so a remote client needn't share the server's filesystem."""
model: Optional[str] = None model: Optional[str] = Field(None, description="Legacy: local path or HF id of the weights (shared-filesystem only).")
path: Optional[str] = None path: Optional[str] = Field(None, description="Alias of `model` — local path to the .safetensors weights.")
id: Optional[str] = None id: Optional[str] = Field(None, description='Registry/blob reference: "name:<registered-lora>" or "sha256:<hex>" (from /v1/loras/upload).')
url: Optional[str] = None url: Optional[str] = Field(None, description="HTTP(S) URL the server downloads and caches.")
file: Optional[str] = None file: Optional[str] = Field(None, description="Base64 of the .safetensors file (or a data: URI). Sent inline so no shared filesystem is needed.")
data: Optional[str] = None data: Optional[str] = Field(None, description="Alias of `file` — inline base64 weights.")
weight: float = 1.0 weight: float = Field(1.0, description="Adapter strength / scale.")
name: Optional[str] = None name: Optional[str] = Field(None, description="Optional adapter name.")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class ImageGenerationRequest(BaseModel): class ImageGenerationRequest(BaseModel):
model: str model: str = Field(..., description="Model id to generate with (must be a configured image model).")
prompt: str prompt: str = Field(..., description="Text prompt describing the image.")
n: int = 1 n: int = Field(1, description="Number of images to generate.")
size: Optional[str] = "1024x1024" size: Optional[str] = Field("1024x1024", description="Output size as 'WIDTHxHEIGHT'.")
steps: Optional[int] = None steps: Optional[int] = Field(None, description="Denoising steps (model/acceleration default if omitted).")
guidance_scale: Optional[float] = None guidance_scale: Optional[float] = Field(None, description="Classifier-free guidance scale (model/acceleration default if omitted).")
quality: Optional[str] = "standard" quality: Optional[str] = Field("standard", description="Quality hint: 'standard' or 'hd'.")
style: Optional[str] = None style: Optional[str] = Field(None, description="Optional style hint passed through to the model.")
response_format: Optional[str] = "url" response_format: Optional[str] = Field("url", description="How to return the result: 'url' or 'b64_json'.")
seed: Optional[int] = None seed: Optional[int] = Field(None, description="Random seed for reproducibility.")
user: Optional[str] = None user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
disable_safety_checker: Optional[bool] = False disable_safety_checker: Optional[bool] = Field(False, description=(
negative_prompt: Optional[str] = None "Null out the diffusers safety_checker so uncensored fine-tunes are not blocked. "
"Only affects SD 1.x/2.x (SDXL/Flux ship no checker)."))
negative_prompt: Optional[str] = Field(None, description="What to avoid in the output.")
# Per-request component overrides # Per-request component overrides
vae_model: Optional[str] = None # Override the VAE for this request vae_model: Optional[str] = Field(None, description="Override the VAE for this request.")
loras: Optional[List[LoraConfig]] = None # Additional LoRA weights for this request loras: Optional[List[LoraConfig]] = Field(None, description="Additional LoRA adapters to apply for this request.")
# Character consistency # Character consistency
character_profiles: Optional[List[str]] = None # saved profile names character_profiles: Optional[List[str]] = Field(None, description="Saved character profile names to apply (IP-Adapter).")
character_references: Optional[List[str]] = None # inline base64 images character_references: Optional[List[str]] = Field(None, description="Inline base64 reference images for character consistency.")
character_strength: Optional[float] = 0.6 # IP-Adapter scale character_strength: Optional[float] = Field(0.6, description="IP-Adapter scale for character references.")
environment_profiles: Optional[List[str]] = None # saved environment profile names (IP-Adapter) environment_profiles: Optional[List[str]] = Field(None, description="Saved environment profile names to apply (IP-Adapter).")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
......
...@@ -67,34 +67,32 @@ class ChatMessage(BaseModel): ...@@ -67,34 +67,32 @@ class ChatMessage(BaseModel):
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str = Field(..., description="Text/chat model id to use.")
messages: List[ChatMessage] messages: List[ChatMessage] = Field(..., description="Conversation messages (roles: system/user/assistant/tool). Content may include text and image parts for vision models.")
temperature: float = 0.7 temperature: float = Field(0.7, description="Sampling temperature; higher = more random.")
top_p: float = 1.0 top_p: float = Field(1.0, description="Nucleus sampling probability mass.")
n: int = 1 n: int = Field(1, description="Number of completions to generate.")
max_tokens: Optional[int] = None max_tokens: Optional[int] = Field(None, description="Max tokens to generate (model default if omitted).")
stream: bool = False stream: bool = Field(False, description="Stream the response as Server-Sent Events.")
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequence(s) that end generation.")
presence_penalty: float = 0.0 presence_penalty: float = Field(0.0, description="Penalize tokens already present (encourages new topics).")
frequency_penalty: float = 0.0 frequency_penalty: float = Field(0.0, description="Penalize frequent tokens (reduces repetition).")
repeat_penalty: float = 1.0 repeat_penalty: float = Field(1.0, description="llama.cpp repetition penalty.")
tools: Optional[List[Tool]] = None tools: Optional[List[Tool]] = Field(None, description="Tool/function definitions the model may call.")
tool_choice: Optional[Union[str, Dict]] = "auto" tool_choice: Optional[Union[str, Dict]] = Field("auto", description="Tool selection: 'auto', 'none', or a specific tool.")
# Extra fields that clients may send but we ignore seed: Optional[int] = Field(None, description="Random seed for reproducibility.")
seed: Optional[int] = None logprobs: Optional[bool] = Field(None, description="Return token log-probabilities (if supported).")
logprobs: Optional[bool] = None top_logprobs: Optional[int] = Field(None, description="Number of top log-probs to return per token.")
top_logprobs: Optional[int] = None response_format: Optional[Dict] = Field(None, description="Structured-output format, e.g. {'type': 'json_object'}.")
response_format: Optional[Dict] = None user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
user: Optional[str] = None enable_thinking: Optional[bool] = Field(False, description="Enable thinking/reasoning mode for models that support it.")
# Enable thinking/reasoning mode for supported models
enable_thinking: Optional[bool] = False
model_config = ConfigDict(extra="allow") # Allow extra fields to prevent 422 errors model_config = ConfigDict(extra="allow") # Allow extra fields to prevent 422 errors
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
model: str model: str = Field(..., description="Text model id to use.")
prompt: Union[str, List[str]] prompt: Union[str, List[str]] = Field(..., description="Prompt text (or list of prompts) to complete.")
temperature: float = 0.7 temperature: float = 0.7
top_p: float = 1.0 top_p: float = 1.0
n: int = 1 n: int = 1
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Pydantic models for video generation API.""" """Pydantic models for video generation API."""
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
class VideoLoraConfig(BaseModel): class VideoLoraConfig(BaseModel):
...@@ -28,14 +28,14 @@ class VideoLoraConfig(BaseModel): ...@@ -28,14 +28,14 @@ class VideoLoraConfig(BaseModel):
blob), inline `file`/`data` base64, a `url` to download, or the legacy blob), inline `file`/`data` base64, a `url` to download, or the legacy
`model`/`path` local path / HF id. This lets a client on a different machine `model`/`path` local path / HF id. This lets a client on a different machine
use a LoRA without sharing a filesystem with the server.""" use a LoRA without sharing a filesystem with the server."""
model: Optional[str] = None # legacy: local path or HF id of the weights model: Optional[str] = Field(None, description="Legacy: local path or HF id of the weights (shared-filesystem only).")
path: Optional[str] = None # alias of model path: Optional[str] = Field(None, description="Alias of `model` — local path to the .safetensors weights.")
id: Optional[str] = None # "name:<registered>" or "sha256:<hex>" id: Optional[str] = Field(None, description='Registry/blob reference: "name:<registered-lora>" or "sha256:<hex>" (from /v1/loras/upload).')
url: Optional[str] = None # http(s) URL the server downloads url: Optional[str] = Field(None, description="HTTP(S) URL the server downloads and caches.")
file: Optional[str] = None # base64 of the .safetensors (or data: URI) file: Optional[str] = Field(None, description="Base64 of the .safetensors file (or a data: URI). Sent inline so no shared filesystem is needed.")
data: Optional[str] = None # alias of file data: Optional[str] = Field(None, description="Alias of `file` — inline base64 weights.")
weight: float = 1.0 weight: float = Field(1.0, description="Adapter strength / scale applied when fusing the LoRA.")
name: Optional[str] = None # adapter name name: Optional[str] = Field(None, description="Optional adapter name (defaults derived from the reference).")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
...@@ -52,104 +52,98 @@ class CharacterDialogLine(BaseModel): ...@@ -52,104 +52,98 @@ class CharacterDialogLine(BaseModel):
class VideoGenerationRequest(BaseModel): class VideoGenerationRequest(BaseModel):
-    model: str
-    prompt: str = ""
-    negative_prompt: Optional[str] = None
+    model: str = Field(..., description="Model id to generate with (must be a configured video model).")
+    prompt: str = Field("", description="Text prompt describing the video. Optional for pure i2v.")
+    negative_prompt: Optional[str] = Field(None, description="What to avoid in the output.")
     # Dimensions
-    width: Optional[int] = 512
-    height: Optional[int] = 512
+    width: Optional[int] = Field(512, description="Output width in pixels.")
+    height: Optional[int] = Field(512, description="Output height in pixels.")
     # Temporal
-    num_frames: Optional[int] = None  # model default if None
-    fps: Optional[int] = None  # output FPS
+    num_frames: Optional[int] = Field(None, description="Number of frames to generate (model default if omitted).")
+    fps: Optional[int] = Field(None, description="Output frame rate (model default if omitted).")
     # Diffusion
-    num_inference_steps: Optional[int] = None
-    guidance_scale: Optional[float] = None
-    seed: Optional[int] = None
-    # Mode
-    # t2v – text-to-video
-    # i2v – image-to-video (init_image required)
-    # v2v – video-to-video (video required)
-    # ti2v – text + init image → video (like i2v but prompt is primary driver)
-    # interp – frame interpolation (init_image + end_image)
-    mode: Optional[str] = "t2v"
+    num_inference_steps: Optional[int] = Field(None, description="Denoising steps (model/acceleration default if omitted).")
+    guidance_scale: Optional[float] = Field(None, description="Classifier-free guidance scale (model/acceleration default if omitted).")
+    seed: Optional[int] = Field(None, description="Random seed for reproducibility.")
+    mode: Optional[str] = Field("t2v", description=(
+        "Generation mode: 't2v' text-to-video; 'i2v' image-to-video (init_image required, "
+        "prompt dropped); 'ti2v' text+init image (prompt is primary driver); 'v2v' "
+        "video-to-video (video required); 'interp' frame interpolation (init_image+end_image). "
+        "The server gracefully falls back between Wan t2v/i2v pipelines when a model only "
+        "supports one."))
     # Input media (base64 or URL)
-    image: Optional[str] = None  # alias for init_image
-    init_image: Optional[str] = None  # first/reference frame
-    end_image: Optional[str] = None  # last frame (for interp mode)
-    video: Optional[str] = None  # input video (v2v / audio manipulation)
-    strength: Optional[float] = None  # denoising strength for v2v
+    image: Optional[str] = Field(None, description="Alias for init_image (base64 or URL).")
+    init_image: Optional[str] = Field(None, description="First/reference frame for i2v/ti2v (base64 or URL).")
+    end_image: Optional[str] = Field(None, description="Last frame, for 'interp' mode (base64 or URL).")
+    video: Optional[str] = Field(None, description="Input video for v2v / audio manipulation (base64 or URL).")
+    strength: Optional[float] = Field(None, description="Denoising strength for v2v (0–1).")
     # Camera motion hint
     camera_motion: Optional[str] = None  # zoom-in | zoom-out | pan-left | pan-right | tilt-up | tilt-down | rotate
     # ── Character consistency ─────────────────────────────────────────────
-    # Each entry: {"name": "Alice", "images": ["b64...", ...]}
-    characters: Optional[List[dict]] = None
-    # Legacy flat list of base64/URL reference images (still accepted)
-    character_references: Optional[List[str]] = None
-    character_strength: Optional[float] = 0.8
-    character_names: Optional[List[str]] = None  # optional names per reference
-    # Named saved profiles to load (resolved server-side)
-    character_profiles: Optional[List[str]] = None
-    # Per-request LoRA adapters (e.g. trained per-character identity LoRAs).
-    # Applied to diffusers video pipelines that support load_lora_weights.
-    loras: Optional[List[VideoLoraConfig]] = None
+    characters: Optional[List[dict]] = Field(None, description='Per-character references, each {"name": "Alice", "images": ["b64...", ...]}.')
+    character_references: Optional[List[str]] = Field(None, description="Legacy flat list of base64/URL reference images.")
+    character_strength: Optional[float] = Field(0.8, description="IP-Adapter scale for character references.")
+    character_names: Optional[List[str]] = Field(None, description="Optional names aligned with character_references.")
+    character_profiles: Optional[List[str]] = Field(None, description="Saved character profile names to load (resolved server-side).")
+
+    loras: Optional[List[VideoLoraConfig]] = Field(None, description="Per-request LoRA adapters (e.g. trained per-character identity LoRAs) fused into the pipeline.")
     # ── Audio generation / manipulation ──────────────────────────────────
-    add_audio: Optional[bool] = False
-    audio_type: Optional[str] = None  # music | speech | sfx | ambient
-    audio_prompt: Optional[str] = None  # prompt for music/sfx generation
-    audio_file: Optional[str] = None  # existing audio to add (base64/URL)
-    tts_text: Optional[str] = None  # text for speech synthesis
-    tts_voice: Optional[str] = None  # TTS voice id
-    tts_speed: Optional[float] = 1.0
-    sync_audio: Optional[bool] = False  # sync audio timing to video
-    lip_sync: Optional[bool] = False  # warp mouth to match audio
-    lip_sync_method: Optional[str] = "wav2lip"  # wav2lip | sadtalker
+    add_audio: Optional[bool] = Field(False, description="Add an audio track to the generated video.")
+    audio_type: Optional[str] = Field(None, description="Audio to generate/add: 'music', 'speech', 'sfx' or 'ambient'.")
+    audio_prompt: Optional[str] = Field(None, description="Prompt for generated music/sfx.")
+    audio_file: Optional[str] = Field(None, description="Existing audio to add (base64/URL).")
+    tts_text: Optional[str] = Field(None, description="Text to synthesize as speech.")
+    tts_voice: Optional[str] = Field(None, description="TTS voice id for synthesized speech.")
+    tts_speed: Optional[float] = Field(1.0, description="TTS speaking speed multiplier.")
+    sync_audio: Optional[bool] = Field(False, description="Sync audio timing to the video.")
+    lip_sync: Optional[bool] = Field(False, description="Warp the mouth to match the audio.")
+    lip_sync_method: Optional[str] = Field("wav2lip", description="Lip-sync engine: 'wav2lip' or 'sadtalker'.")
     # ── Multi-character dialog ────────────────────────────────────────────
-    dialogs: Optional[List[CharacterDialogLine]] = None
+    dialogs: Optional[List[CharacterDialogLine]] = Field(None, description="Ordered spoken lines for multi-character dialog with per-line voice/lip-sync.")
     # ── Subtitles ────────────────────────────────────────────────────────
-    generate_subtitles: Optional[bool] = False
-    burn_subtitles: Optional[bool] = False
-    subtitle_language: Optional[str] = None  # source language hint
-    translate_subtitles: Optional[bool] = False
-    subtitle_target_lang: Optional[str] = None
-    subtitle_style: Optional[str] = "default"  # default | karaoke | minimal
-    whisper_model: Optional[str] = None  # which whisper variant to use
+    generate_subtitles: Optional[bool] = Field(False, description="Generate subtitles (SRT) for the video.")
+    burn_subtitles: Optional[bool] = Field(False, description="Burn the subtitles into the video frames.")
+    subtitle_language: Optional[str] = Field(None, description="Source language hint for subtitle transcription.")
+    translate_subtitles: Optional[bool] = Field(False, description="Translate subtitles to subtitle_target_lang.")
+    subtitle_target_lang: Optional[str] = Field(None, description="Target language for translated subtitles.")
+    subtitle_style: Optional[str] = Field("default", description="Subtitle style: 'default', 'karaoke' or 'minimal'.")
+    whisper_model: Optional[str] = Field(None, description="Whisper variant used for transcription.")
     # ── Video dubbing ─────────────────────────────────────────────────────
-    dub_video: Optional[bool] = False
-    dub_target_lang: Optional[str] = None
-    dub_source_lang: Optional[str] = None
-    voice_clone: Optional[bool] = False  # clone original speaker voice
+    dub_video: Optional[bool] = Field(False, description="Dub the video's speech into another language.")
+    dub_target_lang: Optional[str] = Field(None, description="Target language for dubbing.")
+    dub_source_lang: Optional[str] = Field(None, description="Source language of the original speech.")
+    voice_clone: Optional[bool] = Field(False, description="Clone the original speaker's voice when dubbing.")
     # ── Post-processing ───────────────────────────────────────────────────
-    upscale_output: Optional[bool] = False
-    upscale_factor: Optional[int] = 2
-    interpolate_output: Optional[bool] = False  # increase FPS after generation
-    fps_multiplier: Optional[int] = 2  # e.g. 2 → 2× FPS via frame interp
-    convert_to_3d: Optional[bool] = False
-    depth_method: Optional[str] = "midas"  # midas | zoe | depth-anything
+    upscale_output: Optional[bool] = Field(False, description="Upscale the generated video.")
+    upscale_factor: Optional[int] = Field(2, description="Upscaling factor.")
+    interpolate_output: Optional[bool] = Field(False, description="Increase FPS via frame interpolation after generation.")
+    fps_multiplier: Optional[int] = Field(2, description="FPS multiplier for interpolation (e.g. 2 → 2× FPS).")
+    convert_to_3d: Optional[bool] = Field(False, description="Produce a depth-based 3D/stereo version.")
+    depth_method: Optional[str] = Field("midas", description="Depth estimator: 'midas', 'zoe' or 'depth-anything'.")
     # ── Memory / offload ─────────────────────────────────────────────────
-    offload_strategy: Optional[str] = None  # sequential | model | none
-    # Nulls pipeline safety_checker / safety_concept so uncensored fine-tunes
-    # are not blocked. Has no effect on models without a safety checker.
-    disable_safety_checker: Optional[bool] = False
+    offload_strategy: Optional[str] = Field(None, description="VRAM offload strategy: 'sequential', 'model' or 'none'.")
+    disable_safety_checker: Optional[bool] = Field(False, description=(
+        "Null out the diffusers safety_checker/safety_concept so uncensored fine-tunes "
+        "are not blocked. No effect on models that ship no safety checker (SDXL/Flux/Wan)."))
     # ── Output ───────────────────────────────────────────────────────────
-    response_format: Optional[str] = "url"  # url | b64_mp4
-    n: int = 1
-    user: Optional[str] = None
+    response_format: Optional[str] = Field("url", description="How to return the result: 'url' or 'b64_mp4'.")
+    n: int = Field(1, description="Number of videos to generate.")
+    user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
     model_config = ConfigDict(extra="allow")
......
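The `mode` description in the hunk above lists which media inputs each mode needs. A minimal client-side sketch of that table follows; `REQUIRED_INPUTS` and `missing_inputs` are hypothetical names, not part of CoderAI, and treating `init_image` as required for `ti2v` is an assumption read off the field descriptions rather than confirmed server behaviour.

```python
from typing import Dict, List

# Hypothetical helper table (not CoderAI code): media inputs each mode needs,
# taken from the `mode` field description. The 'ti2v' entry is an assumption.
REQUIRED_INPUTS: Dict[str, List[str]] = {
    "t2v": [],
    "i2v": ["init_image"],
    "ti2v": ["init_image"],
    "v2v": ["video"],
    "interp": ["init_image", "end_image"],
}


def missing_inputs(payload: dict) -> List[str]:
    """Return the media fields a request body still lacks for its mode."""
    mode = payload.get("mode") or "t2v"  # 't2v' is the documented default
    if mode not in REQUIRED_INPUTS:
        raise ValueError(f"unknown mode: {mode!r}")
    present = dict(payload)
    # `image` is documented as an alias for `init_image`.
    if not present.get("init_image") and present.get("image"):
        present["init_image"] = present["image"]
    return [name for name in REQUIRED_INPUTS[mode] if not present.get(name)]


if __name__ == "__main__":
    body = {"model": "my-video-model", "mode": "interp", "image": "<base64>"}
    print(missing_inputs(body))
```

A check like this only catches obviously incomplete requests before they are sent; the server remains the authority on what it accepts, including the Wan t2v/i2v fallback mentioned in the description.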
@@ -169,6 +169,23 @@ class QueueManager:
                 return index
         return 0
 
+    def list_waiting(self) -> list:
+        """Best-effort snapshot of queued (waiting) requests for the Tasks view.
+        Read without the async lock — fine for a read-only UI snapshot."""
+        out = []
+        for w in list(self.waiting):
+            out.append({
+                "request_id": w.request_id,
+                "model_key": w.model_key,
+                "enqueued_at": w.enqueued_at,
+            })
+        return out
+
+    def list_active(self) -> list:
+        """Best-effort snapshot of in-flight leases for the Tasks view."""
+        return [{"request_id": rid, "model_key": lease.model_key}
+                for rid, lease in list(self.active_leases.items())]
+
     def get_metrics(self) -> Dict[str, object]:
         return {
             "active": len(self.active_leases),
......
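Both new `QueueManager` methods copy the live collection with `list(...)` before iterating, so the loop walks a private copy instead of the structure the queue keeps mutating. A minimal sketch of that pattern follows. The `Waiter` dataclass and the use of a `deque` are assumptions for illustration; the diff does not show how `self.waiting` is actually stored.

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque, List


@dataclass
class Waiter:
    """Hypothetical stand-in for one queued request."""
    request_id: str
    model_key: str
    enqueued_at: float


def snapshot_waiting(waiting: Deque[Waiter]) -> List[dict]:
    # list(...) copies the current contents up front; later pops or appends on
    # the live queue do not affect the loop or the rows already built.
    return [
        {"request_id": w.request_id, "model_key": w.model_key,
         "enqueued_at": w.enqueued_at}
        for w in list(waiting)
    ]


if __name__ == "__main__":
    queue: Deque[Waiter] = deque([
        Waiter("req-1", "model-a", 1.0),
        Waiter("req-2", "model-b", 2.0),
    ])
    snap = snapshot_waiting(queue)
    queue.popleft()  # the live queue moves on; the snapshot is unchanged
    print(len(snap), len(queue))
```

The snapshot can be slightly stale by the time the Tasks page renders it, which is the trade-off the docstring accepts for skipping the async lock.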
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Central task registry for long-running operations."""

from codai.tasks.registry import (
    Task,
    TaskCancelled,
    TaskRegistry,
    task_registry,
    raise_if_cancelled,
    wait_if_paused,
)

__all__ = [
    "Task",
    "TaskCancelled",
    "TaskRegistry",
    "task_registry",
    "raise_if_cancelled",
    "wait_if_paused",
]
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Central registry of long-running tasks (generation + training).

Lets the admin UI list every in-flight / recent task and cooperatively cancel
one. Cancellation is *cooperative*: the per-step diffusers callbacks
(``_step_cb`` / ``_vid_step_cb`` / ``_aud_step_cb`` / sd.cpp ``progress_callback``)
and the LoRA training loops call ``raise_if_cancelled(task_id)`` between steps,
which raises :class:`TaskCancelled` once the task is cancelled. This mirrors the
``codai.models.thermal.checkpoint()`` pause already invoked in those same hooks,
so the running model aborts at the next step boundary.

The registry is process-local and in-memory: it is the live view. Durable LoRA
training state lives separately in ``codai/api/loras.py`` (``_train_jobs.json``);
a task with a ``job_id`` links the two.
"""

import threading
import time
import uuid
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional


class TaskCancelled(Exception):
    """Raised inside a worker when its task has been cancelled by the user."""


ACTIVE_STATES = ("queued", "running")
TERMINAL_STATES = ("done", "error", "cancelled")


@dataclass
class Task:
    id: str
    kind: str  # training | image | video | audio | text | pipeline
    title: str = ""
    model: str = ""
    status: str = "queued"  # queued | running | done | error | cancelled
    step: int = 0
    total: int = 0
    message: str = ""
    job_id: Optional[str] = None  # link to a durable loras training job, if any
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    ended_at: Optional[float] = None
    cancellable: bool = True
    restartable: bool = False
    paused: bool = False

    def to_dict(self) -> dict:
        d = asdict(self)
        d["active"] = self.status in ACTIVE_STATES
        return d


class TaskRegistry:
    """Thread-safe registry. All public methods take the registry lock briefly;
    the per-task cancel flag is a ``threading.Event`` so ``is_cancelled`` is a
    cheap lock-free read suitable for a hot step loop."""

    def __init__(self, history: int = 50):
        self._lock = threading.Lock()
        self._tasks: Dict[str, Task] = {}
        self._events: Dict[str, threading.Event] = {}
        self._pause_events: Dict[str, threading.Event] = {}
        self._history = history

    def register(self, kind: str, *, title: str = "", model: str = "",
                 total: int = 0, job_id: Optional[str] = None,
                 status: str = "queued", cancellable: bool = True,
                 restartable: bool = False, task_id: Optional[str] = None) -> str:
        tid = task_id or f"task-{uuid.uuid4().hex[:12]}"
        with self._lock:
            self._tasks[tid] = Task(
                id=tid, kind=kind, title=title, model=model, total=total,
                job_id=job_id, status=status, cancellable=cancellable,
                restartable=restartable,
            )
            self._events[tid] = threading.Event()
            self._pause_events[tid] = threading.Event()
            self._prune_locked()
        return tid

    def start(self, tid: str) -> None:
        with self._lock:
            t = self._tasks.get(tid)
            if t and t.status in ACTIVE_STATES:
                t.status = "running"
                t.started_at = t.started_at or time.time()

    def update(self, tid: str, **fields) -> None:
        with self._lock:
            t = self._tasks.get(tid)
            if not t:
                return
            for k, v in fields.items():
                if hasattr(t, k):
                    setattr(t, k, v)

    def step(self, tid: str, step: int, total: Optional[int] = None) -> None:
        with self._lock:
            t = self._tasks.get(tid)
            if not t:
                return
            t.step = int(step)
            if total is not None:
                t.total = int(total)

    def finish(self, tid: str, status: str = "done", message: str = "") -> None:
        with self._lock:
            t = self._tasks.get(tid)
            if not t:
                return
            # A user cancel wins over a late 'done' from the worker unwinding.
            if not (t.status == "cancelled" and status == "done"):
                t.status = status
            if message:
                t.message = message
            t.ended_at = time.time()
            self._prune_locked()

    def cancel(self, tid: str) -> bool:
        with self._lock:
            t = self._tasks.get(tid)
            if not t:
                return False
            ev = self._events.get(tid)
            if ev:
                ev.set()
            # The pause event is *set* while a task is paused, so set() is a
            # no-op for an already-paused task: a worker blocked in
            # wait_if_paused() sees the cancel flag on its next poll and raises.
            # For an unpaused task it makes the next wait_if_paused() call act
            # as a cancel checkpoint too.
            pev = self._pause_events.get(tid)
            if pev:
                pev.set()
                t.paused = False
            was = t.status
            if was in ACTIVE_STATES:
                t.status = "cancelled"
                t.message = "cancelled"
                # A queued task never entered a worker, so it's terminal now;
                # a running task finalises when its worker observes the flag.
                if was == "queued":
                    t.ended_at = time.time()
            return True

    def is_cancelled(self, tid: Optional[str]) -> bool:
        if not tid:
            return False
        ev = self._events.get(tid)
        return bool(ev and ev.is_set())

    def raise_if_cancelled(self, tid: Optional[str]) -> None:
        if self.is_cancelled(tid):
            raise TaskCancelled(tid)

    # --- Pause / resume (cooperative, at the next step boundary) -------------

    def pause(self, tid: str) -> bool:
        with self._lock:
            t = self._tasks.get(tid)
            ev = self._pause_events.get(tid)
            if not t or ev is None or t.status not in ACTIVE_STATES:
                return False
            ev.set()
            t.paused = True
            return True

    def resume(self, tid: str) -> bool:
        with self._lock:
            t = self._tasks.get(tid)
            ev = self._pause_events.get(tid)
            if not t or ev is None:
                return False
            ev.clear()
            t.paused = False
            return True

    def is_paused(self, tid: Optional[str]) -> bool:
        if not tid:
            return False
        ev = self._pause_events.get(tid)
        return bool(ev and ev.is_set())

    def wait_if_paused(self, tid: Optional[str], poll: float = 0.2) -> None:
        """Block while the task is paused, returning when it is resumed.

        Stays responsive to cancellation: a paused task that is then cancelled
        raises :class:`TaskCancelled` instead of hanging. Safe to call from a hot
        step loop — a no-op unless the task is actually paused."""
        ev = self._pause_events.get(tid) if tid else None
        if ev is None or not ev.is_set():
            return
        cancel_ev = self._events.get(tid)
        while ev.is_set():
            if self.is_cancelled(tid):
                raise TaskCancelled(tid)
            # The pause event is set *while paused*, so ev.wait() would return
            # at once and spin the CPU. Sleep on the cancel flag instead: a
            # cancel wakes us immediately, a resume is seen within ``poll`` s.
            if cancel_ev is not None:
                cancel_ev.wait(timeout=poll)
            else:
                time.sleep(poll)

    def remove(self, tid: str) -> bool:
        """Drop a task from the registry entirely (used to dismiss a finished/
        cancelled task from the view). No-op on a missing id."""
        with self._lock:
            self._events.pop(tid, None)
            self._pause_events.pop(tid, None)
            return self._tasks.pop(tid, None) is not None

    def get(self, tid: str) -> Optional[dict]:
        with self._lock:
            t = self._tasks.get(tid)
            return t.to_dict() if t else None

    def list(self) -> List[dict]:
        with self._lock:
            return [t.to_dict() for t in sorted(
                self._tasks.values(), key=lambda x: x.created_at, reverse=True)]

    def _prune_locked(self) -> None:
        """Keep all active tasks + the most recent ``history`` terminal ones."""
        terminal = [t for t in self._tasks.values() if t.status in TERMINAL_STATES]
        if len(terminal) <= self._history:
            return
        terminal.sort(key=lambda x: x.ended_at or x.created_at)
        for t in terminal[:-self._history]:
            self._tasks.pop(t.id, None)
            self._events.pop(t.id, None)
            self._pause_events.pop(t.id, None)


# Process-wide singleton.
task_registry = TaskRegistry()


def raise_if_cancelled(task_id: Optional[str]) -> None:
    """Free helper mirroring ``thermal.checkpoint()`` — call from hot step loops.

    Raises :class:`TaskCancelled` if ``task_id`` has been cancelled; a falsy
    ``task_id`` is a no-op (so callers needn't guard)."""
    task_registry.raise_if_cancelled(task_id)


def wait_if_paused(task_id: Optional[str]) -> None:
    """Free helper — block at a step boundary while ``task_id`` is paused.

    Returns immediately when not paused; raises :class:`TaskCancelled` if the
    task is cancelled while paused. A falsy ``task_id`` is a no-op."""
    task_registry.wait_if_paused(task_id)
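
The module docstring describes cancellation and pause as cooperative: the worker checks two flags between steps and unwinds by raising. The sketch below shows that step-boundary pattern on its own, using only `threading.Event`. `checkpoint` and `run_steps` are hypothetical names, the local `TaskCancelled` is a stand-in for the class of the same name in the listing, and the cancel is triggered from inside the loop only to keep the example deterministic (in the real flow it would come from the admin UI on another thread).

```python
import threading
from typing import List, Optional, Tuple


class TaskCancelled(Exception):
    """Stand-in for the registry's exception of the same name."""


def checkpoint(cancel: threading.Event, pause: threading.Event,
               poll: float = 0.05) -> None:
    """Call between steps: block while paused, raise once cancelled."""
    while pause.is_set() and not cancel.is_set():
        # Sleep on the cancel flag: a cancel wakes us at once, and a resume
        # (pause flag cleared) is noticed within `poll` seconds.
        cancel.wait(timeout=poll)
    if cancel.is_set():
        raise TaskCancelled()


def run_steps(total: int, cancel_at: Optional[int] = None) -> Tuple[str, List[int]]:
    """Run `total` fake steps; optionally request a cancel after one of them."""
    cancel, pause = threading.Event(), threading.Event()
    done: List[int] = []
    try:
        for step in range(total):
            checkpoint(cancel, pause)
            done.append(step)      # stand-in for one denoising/training step
            if step == cancel_at:
                cancel.set()       # stand-in for the admin UI calling cancel()
        return "done", done
    except TaskCancelled:
        return "cancelled", done


if __name__ == "__main__":
    print(run_steps(3))
    print(run_steps(5, cancel_at=1))
```

Because the flag is only read at a step boundary, the step already in progress always completes; the work stops at the next `checkpoint` call rather than mid-step.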