Commit f55f6578 authored by Stefy Lanza (nextime / spora)

Merge feature/tasks-quant-thermal: task mgmt, quantization, Wan2.2 video...

Merge feature/tasks-quant-thermal: task mgmt, quantization, Wan2.2 video fixes, pipeline cache, smarter offload
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parents 9494d1bd dbc09b75
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
<a href="{{ root_path }}/docs" class="nav-link" target="_blank">API Docs</a> <a href="{{ root_path }}/docs" class="nav-link" target="_blank">API Docs</a>
{% if is_admin|default(false) %} {% if is_admin|default(false) %}
<a href="{{ root_path }}/admin/models" class="nav-link {% if '/models' in request.url.path %}active{% endif %}">Models</a> <a href="{{ root_path }}/admin/models" class="nav-link {% if '/models' in request.url.path %}active{% endif %}">Models</a>
<a href="{{ root_path }}/admin/tasks" class="nav-link {% if '/tasks' in request.url.path %}active{% endif %}">Tasks</a>
<a href="{{ root_path }}/admin/tokens" class="nav-link {% if '/tokens' in request.url.path %}active{% endif %}">Tokens</a> <a href="{{ root_path }}/admin/tokens" class="nav-link {% if '/tokens' in request.url.path %}active{% endif %}">Tokens</a>
<a href="{{ root_path }}/admin/users" class="nav-link {% if '/users' in request.url.path %}active{% endif %}">Users</a> <a href="{{ root_path }}/admin/users" class="nav-link {% if '/users' in request.url.path %}active{% endif %}">Users</a>
<a href="{{ root_path }}/admin/archive" class="nav-link {% if '/archive' in request.url.path %}active{% endif %}">Archive</a> <a href="{{ root_path }}/admin/archive" class="nav-link {% if '/archive' in request.url.path %}active{% endif %}">Archive</a>
......
This diff is collapsed.
...@@ -153,6 +153,24 @@ ...@@ -153,6 +153,24 @@
</div> </div>
</div> </div>
<!-- Background jobs -->
<div class="card mb-0" style="margin-top:1rem">
<div class="card-title">Background Jobs</div>
<span class="form-hint" style="display:block;margin-bottom:.75rem">
Controls how interrupted LoRA training is handled when CoderAI restarts.
Equivalent to the <code>--no-resume-jobs</code> launch flag.
</span>
<div class="form-row" style="margin:0">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer">
<input type="checkbox" id="s-jobs-resume">
<span style="font-size:13px;font-weight:500">Resume interrupted training on restart</span>
</label>
<span class="form-hint">When off, a training job that was running at restart is marked
<em>cancelled</em> instead of resuming. Its checkpoint is kept, so you can still
restart it manually from the Tasks page.</span>
</div>
</div>
<div class="card mb-0" style="margin-top:1rem"> <div class="card mb-0" style="margin-top:1rem">
<div class="card-title">AISBF Broker</div> <div class="card-title">AISBF Broker</div>
<div class="form-row"> <div class="form-row">
...@@ -328,6 +346,9 @@ async function loadSettings(){ ...@@ -328,6 +346,9 @@ async function loadSettings(){
document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87; document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87;
document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5; document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5;
toggleThermalFields(); toggleThermalFields();
// Background jobs
const jobs = d.jobs || {};
document.getElementById('s-jobs-resume').checked = jobs.resume_on_restart !== false;
}catch(e){ showAlert('error','Failed to load settings: '+e.message); } }catch(e){ showAlert('error','Failed to load settings: '+e.message); }
} }
...@@ -363,6 +384,9 @@ async function saveSettings(){ ...@@ -363,6 +384,9 @@ async function saveSettings(){
cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87, cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87,
poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5, poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5,
}, },
jobs:{
resume_on_restart: document.getElementById('s-jobs-resume').checked,
},
broker:{ broker:{
enabled: document.getElementById('s-broker-enabled').checked, enabled: document.getElementById('s-broker-enabled').checked,
base_url: document.getElementById('s-broker-base-url').value.trim(), base_url: document.getElementById('s-broker-base-url').value.trim(),
......
{% extends "base.html" %}
{% block title %}Tasks — CoderAI{% endblock %}
{% block content %}
<div class="page-header">
<div>
<h1>Tasks</h1>
<p>Live view of generations and LoRA training. Cancel, interrupt, or restart a job.</p>
</div>
<div class="header-actions">
<span id="queue-summary" class="dim small"></span>
</div>
</div>
<div id="thermal-banner" style="display:none;margin:0 0 1rem;padding:.6rem .85rem;border-radius:8px;
background:rgba(245,158,11,.12);border:1px solid rgba(245,158,11,.4);color:#f59e0b;font-size:13px">
<span style="font-weight:600">❄ Thermal cooldown</span>
<span id="thermal-banner-msg" class="mono"></span>
— running work is paused until the hardware cools.
</div>
<!-- Live hardware telemetry -->
<div id="sys-stats" style="display:grid;grid-template-columns:repeat(auto-fit,minmax(220px,1fr));
gap:.75rem;margin:0 0 1.25rem">
<div class="sys-tile" id="tile-cpu"></div>
<div class="sys-tile" id="tile-gpu"></div>
<div class="sys-tile" id="tile-ram"></div>
<div class="sys-tile" id="tile-vram"></div>
</div>
<style>
.sys-tile{border:1px solid var(--border,#2a2a2a);border-radius:10px;padding:.7rem .85rem;
background:var(--card-bg,rgba(255,255,255,.02))}
.sys-tile .sys-head{display:flex;justify-content:space-between;align-items:baseline;margin-bottom:.45rem}
.sys-tile .sys-name{font-size:12px;font-weight:600;letter-spacing:.03em;text-transform:uppercase;color:var(--text-muted,#9aa0a6)}
.sys-tile .sys-val{font-size:13px;font-weight:600}
.sys-tile .sys-sub{font-size:11px;color:var(--text-muted,#9aa0a6);margin-top:.35rem;display:flex;justify-content:space-between}
.sys-bar{height:8px;border-radius:5px;background:rgba(127,127,127,.18);overflow:hidden}
.sys-bar > span{display:block;height:100%;border-radius:5px;transition:width .4s ease,background .4s ease}
.sys-ok > span{background:#22c55e}.sys-warn > span{background:#f59e0b}.sys-hot > span{background:#ef4444}
.sys-temp-ok{color:#22c55e}.sys-temp-warn{color:#f59e0b}.sys-temp-hot{color:#ef4444}
</style>
<div class="table-wrap">
<table>
<thead>
<tr>
<th>Type</th><th>Name / Model</th><th>Status</th>
<th style="width:220px">Progress</th><th>Started</th><th style="text-align:right">Actions</th>
</tr>
</thead>
<tbody id="tasks-body">
<tr class="empty-row"><td colspan="6">No tasks yet</td></tr>
</tbody>
</table>
</div>
{% endblock %}
{% block scripts %}
<script>
/**
 * HTML-escape a value for insertion into markup built with template strings.
 *
 * null/undefined become the empty string; everything else is stringified.
 * Besides & < >, both quote characters are escaped so the result is also safe
 * inside a quoted HTML attribute value.
 *
 * NOTE: this is HTML escaping only. Entities are decoded before an inline
 * event handler is parsed, so it does NOT make a value safe inside a
 * JavaScript string literal in an on* attribute.
 */
function esc(s) {
  const entities = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'};
  // Single pass, so the '&' of an entity emitted earlier is never re-escaped.
  return String(s == null ? '' : s).replace(/[&<>"']/g, ch => entities[ch]);
}
/**
 * Format a server timestamp as a local wall-clock time (HH:MM:SS).
 *
 * @param s unix time in seconds (float), as sent by the server.
 * @returns the localized time string, or '' when the value is missing or
 *          cannot be interpreted as a timestamp.
 */
function fmtTime(s) {
  if (!s) return '';
  try {
    // started_at is unix seconds (float) from the server.
    const d = new Date(s * 1000);
    // An invalid Date does not throw; it would format as the literal text
    // "Invalid Date", so reject it explicitly.
    if (Number.isNaN(d.getTime())) return '';
    return d.toLocaleTimeString(undefined, {hour:'2-digit', minute:'2-digit', second:'2-digit'});
  } catch { return ''; }
}
// Display label for each task kind; the row renderer falls back to the raw
// kind string when a kind is not listed here.
const KIND_LABEL = {training:'Training', image:'Image', video:'Video', audio:'Audio', text:'Text', pipeline:'Pipeline', request:'Request', loading:'Loading'};
// CSS badge class for each task status; the row renderer falls back to
// 'badge-dim' for a status that is not listed here.
const STATUS_BADGE = {
running:'badge-admin', queued:'badge-user', done:'badge-ok', error:'badge-err',
cancelled:'badge-user', interrupted:'badge-warn'
};
/**
 * Build the progress cell for a task row.
 *
 * With a non-zero total: a filled bar plus a "step/total (pct%)" caption, the
 * percentage rounded and clamped to 0..100. Without one: a "working…"
 * placeholder for running tasks and a dash for everything else.
 */
function progressBar(t) {
  const outOf = t.total || 0;
  const done = t.step || 0;
  if (outOf) {
    const rounded = Math.round(done / outOf * 100);
    const pct = Math.max(0, Math.min(100, rounded));
    return `<div class="progress"><div class="progress-fill" style="width:${pct}%"></div></div>\n<span class="dim small">${done}/${outOf} (${pct}%)</span>`;
  }
  if (t.status === 'running') {
    return '<span class="dim small">working…</span>';
  }
  return '<span class="dim small">—</span>';
}
/**
 * Encode a value as a JavaScript string literal that is safe to place inside
 * a double-quoted inline event-handler attribute.
 *
 * JSON.stringify yields a valid double-quoted JS literal (quotes and
 * backslashes escaped); the result is then HTML-escaped, quotes included.
 * The browser decodes the entities before parsing the handler, so the handler
 * source ends up with the intact JS literal.
 */
function _jsArg(v) {
  return JSON.stringify(String(v == null ? '' : v))
    .replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;').replace(/'/g, '&#39;');
}
/**
 * Build the action buttons for a task row from its capability flags.
 *
 * - paused            -> Resume; otherwise pausable -> Pause
 * - cancellable       -> Interrupt while running, Cancel otherwise
 * - restartable       -> Restart
 * - not active        -> Remove
 *
 * Returns the buttons joined by spaces, or a dash placeholder when none apply.
 * The task id is passed to the handlers as a properly encoded JS literal, so
 * an id containing quotes or backslashes cannot break out of the handler.
 */
function actions(t) {
  const idArg = _jsArg(t.id);
  const btns = [];
  if (t.paused) {
    btns.push(`<button class="btn btn-primary btn-sm" onclick="taskAction(${idArg},'resume')">Resume</button>`);
  } else if (t.pausable) {
    btns.push(`<button class="btn btn-ghost btn-sm" onclick="taskAction(${idArg},'pause')">Pause</button>`);
  }
  if (t.cancellable) {
    const label = t.status === 'running' ? 'Interrupt' : 'Cancel';
    const act = t.status === 'running' ? 'interrupt' : 'cancel';
    btns.push(`<button class="btn btn-danger btn-sm" onclick="taskAction(${idArg},'${act}')">${label}</button>`);
  }
  if (t.restartable) {
    btns.push(`<button class="btn btn-ghost btn-sm" onclick="taskAction(${idArg},'restart')">Restart</button>`);
  }
  if (!t.active) {
    btns.push(`<button class="btn btn-ghost btn-sm" onclick="removeTask(${idArg})">Remove</button>`);
  }
  return btns.join(' ') || '<span class="dim small">—</span>';
}
// ---- Live hardware telemetry ----
/** Map a utilization percentage to its bar colour class (>=90 hot, >=70 warn, else ok; missing -> ok). */
function _utilClass(pct) {
  if (pct == null) return 'sys-ok';
  if (pct >= 90) return 'sys-hot';
  if (pct >= 70) return 'sys-warn';
  return 'sys-ok';
}
/** Map a temperature reading to its text colour class (>=90 hot, >=80 warn, else ok; missing -> no class). */
function _tempClass(t) {
  if (t == null) return '';
  if (t >= 90) return 'sys-temp-hot';
  if (t >= 80) return 'sys-temp-warn';
  return 'sys-temp-ok';
}
/**
 * Render a horizontal meter. The fill width is the percentage clamped to
 * 0..100 (0 when the value is missing); the colour class comes from
 * _utilClass applied to the unclamped value.
 */
function _bar(pct){
  let width = 0;
  if (pct != null) {
    width = Math.max(0, Math.min(100, pct));
  }
  const tone = _utilClass(pct);
  return `<div class="sys-bar ${tone}"><span style="width:${width}%"></span></div>`;
}
/**
 * Inner HTML for a utilization tile (CPU / GPU): a header with the name and
 * rounded percentage, a meter, and a footer showing the temperature.
 * Missing readings are shown as "n/a".
 */
function _utilTile(name, pct, temp){
  let valTxt = 'n/a';
  if (pct != null) {
    valTxt = `${Math.round(pct)}%`;
  }
  let tempTxt;
  if (temp == null) {
    tempTxt = '<span class="dim">temp n/a</span>';
  } else {
    tempTxt = `<span class="${_tempClass(temp)}">${Math.round(temp)}°C</span>`;
  }
  const head = `<div class="sys-head"><span class="sys-name">${name}</span><span class="sys-val">${valTxt}</span></div>`;
  const foot = `<div class="sys-sub"><span>utilization</span>${tempTxt}</div>`;
  return head + _bar(pct) + foot;
}
/**
 * Inner HTML for a memory tile (RAM / VRAM): a header with "used / total GB",
 * a meter, and a footer with the rounded percentage in use. When no
 * percentage is supplied it is derived from used/total if both are usable.
 */
function _memTile(name, used, total, pct){
  const haveAmounts = !(used == null || total == null);
  const valTxt = haveAmounts ? `${used.toFixed(1)} / ${total.toFixed(1)} GB` : 'n/a';
  let share = null;
  if (pct != null) {
    share = pct;
  } else if (used != null && total) {
    share = used/total*100;
  }
  const caption = share == null ? '' : Math.round(share)+'% used';
  const head = `<div class="sys-head"><span class="sys-name">${name}</span><span class="sys-val">${valTxt}</span></div>`;
  const foot = `<div class="sys-sub"><span>${caption}</span><span></span></div>`;
  return head + _bar(share) + foot;
}
/**
 * Fetch the hardware telemetry snapshot and redraw the four tiles
 * (CPU, GPU, RAM, VRAM).
 *
 * Any failure — network error, non-2xx status, unparsable body — leaves the
 * previously rendered tiles untouched.
 */
async function loadSystemStats(){
  try {
    const resp = await fetch(ROOT_PATH + '/admin/api/system-stats');
    // An error response may still carry a JSON body (e.g. {"detail": ...});
    // rendering it would blank every tile to "n/a", so keep the last render.
    if (!resp.ok) return;
    const s = await resp.json();
    const cpu = s.cpu || {}, gpu = s.gpu || {}, ram = s.ram || {}, vram = s.vram || {};
    document.getElementById('tile-cpu').innerHTML = _utilTile('CPU', cpu.util, cpu.temp);
    document.getElementById('tile-gpu').innerHTML = _utilTile('GPU', gpu.util, gpu.temp);
    document.getElementById('tile-ram').innerHTML = _memTile('RAM', ram.used, ram.total, ram.percent);
    document.getElementById('tile-vram').innerHTML =
      _memTile('VRAM', vram.used, vram.total, vram.percent);
  } catch(e){ /* keep last render on transient errors */ }
}
// Re-entrancy guard: a refresh that is still in flight makes the next poll a no-op.
let _refreshing = false;
/**
 * Fetch the task list and redraw the queue summary, the thermal banner and
 * the task table.
 *
 * Any failure — network error, non-2xx status, unparsable body — leaves the
 * previous render in place. Overlapping calls are skipped via _refreshing.
 */
async function loadTasks() {
  if (_refreshing) return;
  _refreshing = true;
  try {
    const resp = await fetch(ROOT_PATH + '/admin/api/tasks');
    // An error response may still carry a JSON body without a task list;
    // rendering it would replace the table with "No tasks yet".
    if (!resp.ok) return;
    const data = await resp.json();
    const tasks = data.tasks || [];
    const q = data.queue || {};
    document.getElementById('queue-summary').textContent =
      `${q.active || 0} active · ${q.waiting || 0} waiting · max ${q.max_parallel_requests || 0} parallel`;
    const therm = data.thermal || {};
    const banner = document.getElementById('thermal-banner');
    if (therm.active) {
      document.getElementById('thermal-banner-msg').textContent = ' ' + (therm.message || '');
      banner.style.display = '';
    } else {
      banner.style.display = 'none';
    }
    const tbody = document.getElementById('tasks-body');
    if (!tasks.length) {
      tbody.innerHTML = '<tr class="empty-row"><td colspan="6">No tasks yet</td></tr>';
      return;
    }
    tbody.innerHTML = tasks.map(t => {
      const badge = STATUS_BADGE[t.status] || 'badge-dim';
      const title = t.title || '(untitled)';
      // Status cell precedence: thermal cooldown, then manual pause, then the raw status.
      let statusCell;
      if (t.cooling) {
        statusCell = `<span class="badge badge-warn">❄ Cooling down</span>`
          + `<div class="dim small">${esc(t.cooling_message || 'paused for thermal cooldown')}</div>`;
      } else if (t.paused) {
        statusCell = `<span class="badge badge-warn">⏸ Paused</span>`
          + `<div class="dim small">suspended — click Resume to continue</div>`;
      } else {
        statusCell = `<span class="badge ${badge}">${esc(t.status)}</span>`
          + (t.message ? `<div class="dim small">${esc(t.message)}</div>` : '');
      }
      return `<tr>
<td><span class="badge badge-user">${esc(KIND_LABEL[t.kind] || t.kind)}</span></td>
<td><div class="td-name">${esc(title)}</div><div class="dim small mono">${esc(t.model || '')}</div></td>
<td>${statusCell}</td>
<td>${progressBar(t)}</td>
<td class="dim small">${fmtTime(t.started_at)}</td>
<td style="text-align:right">${actions(t)}</td>
</tr>`;
    }).join('');
  } catch (e) {
    // transient fetch errors during a model swap are fine; keep last render.
  } finally {
    // Always release the guard, including on the early returns above.
    _refreshing = false;
  }
}
/**
 * POST a lifecycle action (cancel / interrupt / restart / pause / resume) for
 * one task, then refresh the table.
 *
 * Cancel and interrupt ask for confirmation first; the other actions are sent
 * immediately. A failed request is reported with alert(), using the server's
 * `detail` field when the error body provides one.
 */
async function taskAction(id, action) {
  const verbs = {cancel:'Cancel', interrupt:'Interrupt', restart:'Restart', pause:'Pause', resume:'Resume'};
  const verb = verbs[action] || action;
  // Only confirm destructive actions; pause/resume/restart act immediately.
  const destructive = action === 'cancel' || action === 'interrupt';
  if (destructive) {
    if (!confirm(`${verb} this task?`)) return;
  }
  const url = ROOT_PATH + '/admin/api/tasks/' + encodeURIComponent(id) + '/' + action;
  try {
    const resp = await fetch(url, {method:'POST'});
    if (!resp.ok) {
      // The error body is optional; fall back to an empty object when it is not JSON.
      const body = await resp.json().catch(() => ({}));
      alert(body.detail || (verb + ' failed'));
    }
  } catch (err) {
    alert(err.message);
  }
  loadTasks();
}
/**
 * DELETE one task from the list, then refresh the table.
 *
 * A failed request is reported with alert(), using the server's `detail`
 * field when the error body provides one.
 */
async function removeTask(id) {
  const url = ROOT_PATH + '/admin/api/tasks/' + encodeURIComponent(id);
  try {
    const resp = await fetch(url, {method:'DELETE'});
    if (!resp.ok) {
      // The error body is optional; fall back to an empty object when it is not JSON.
      const body = await resp.json().catch(() => ({}));
      alert(body.detail || 'Remove failed');
    }
  } catch (err) {
    alert(err.message);
  }
  loadTasks();
}
// Render both panels once right away, then re-poll each of them every 2 seconds.
loadTasks();
loadSystemStats();
setInterval(loadTasks, 2000);
setInterval(loadSystemStats, 2000);
</script>
{% endblock %}
...@@ -189,25 +189,25 @@ if admin_static_dir.exists(): ...@@ -189,25 +189,25 @@ if admin_static_dir.exists():
app.mount("/static/admin", StaticFiles(directory=str(admin_static_dir)), name="admin_static") app.mount("/static/admin", StaticFiles(directory=str(admin_static_dir)), name="admin_static")
# Include routers from submodules # Include routers from submodules
app.include_router(transcriptions_router) app.include_router(transcriptions_router, tags=["Audio"])
app.include_router(images_router) app.include_router(images_router, tags=["Images"])
app.include_router(tts_router) app.include_router(tts_router, tags=["Audio"])
app.include_router(text_router) app.include_router(text_router, tags=["Text"])
app.include_router(video_router) app.include_router(video_router, tags=["Video"])
app.include_router(audio_gen_router) app.include_router(audio_gen_router, tags=["Audio"])
app.include_router(audio_stems_router) app.include_router(audio_stems_router, tags=["Audio"])
app.include_router(audio_clean_router) app.include_router(audio_clean_router, tags=["Audio"])
app.include_router(embeddings_router) app.include_router(embeddings_router, tags=["Embeddings"])
app.include_router(pipelines_router) app.include_router(pipelines_router, tags=["Pipelines"])
app.include_router(custom_pipelines_router) app.include_router(custom_pipelines_router, tags=["Pipelines"])
app.include_router(voice_clone_router) app.include_router(voice_clone_router, tags=["Audio"])
app.include_router(voice_convert_router) app.include_router(voice_convert_router, tags=["Audio"])
app.include_router(faceswap_router) app.include_router(faceswap_router, tags=["Images"])
app.include_router(characters_router) app.include_router(characters_router, tags=["Characters"])
app.include_router(loras_router) app.include_router(loras_router, tags=["LoRAs"])
app.include_router(environments_router) app.include_router(environments_router, tags=["Environments"])
app.include_router(spatial_router) app.include_router(spatial_router, tags=["Spatial / 3D"])
app.include_router(admin_router) app.include_router(admin_router, tags=["Admin"])
@app.exception_handler(401) @app.exception_handler(401)
...@@ -222,20 +222,35 @@ async def unauthorized_redirect(request: Request, exc: HTTPException): ...@@ -222,20 +222,35 @@ async def unauthorized_redirect(request: Request, exc: HTTPException):
return JSONResponse(status_code=401, content={"detail": exc.detail}) return JSONResponse(status_code=401, content={"detail": exc.detail})
@app.get("/v1/models", response_model=ModelList) from codai.tasks import TaskCancelled, task_registry
@app.exception_handler(TaskCancelled)
async def task_cancelled_handler(request: Request, exc: TaskCancelled):
    """Translate a TaskCancelled raised by a worker into an HTTP 499 reply.

    The task id is read from the first exception argument, so a worker only
    has to `raise` and needs no bookkeeping of its own. When an id is present
    the task is finished in the registry with status "cancelled"; the id
    (or None) is echoed back in the JSON body.
    """
    task_id = None
    if exc.args:
        task_id = exc.args[0]
    if task_id:
        task_registry.finish(task_id, "cancelled", "cancelled by user")
    payload = {"detail": "Task cancelled", "task_id": task_id}
    return JSONResponse(status_code=499, content=payload)
@app.get("/v1/models", response_model=ModelList, summary="List available models", tags=["Core"])
async def list_models(): async def list_models():
"""List available models.""" """List available models."""
models = multi_model_manager.list_models() models = multi_model_manager.list_models()
return ModelList(data=models) return ModelList(data=models)
@app.get("/coderai/capabilities") @app.get("/coderai/capabilities", summary="Server capability document", tags=["Core"])
async def get_broker_capabilities(): async def get_broker_capabilities():
"""Return broker capability metadata.""" """Return broker capability metadata."""
return build_capabilities_document(hardware=build_hardware_summary()) return build_capabilities_document(hardware=build_hardware_summary())
@app.get("/v1/files/{filename}") @app.get("/v1/files/{filename}", summary="Download a generated file", tags=["Files"])
async def get_file(filename: str): async def get_file(filename: str):
"""Serve uploaded/generated files.""" """Serve uploaded/generated files."""
if not global_file_path: if not global_file_path:
...@@ -256,7 +271,7 @@ _VIDEO_EXTS = {'.mp4', '.webm', '.avi', '.mov'} ...@@ -256,7 +271,7 @@ _VIDEO_EXTS = {'.mp4', '.webm', '.avi', '.mov'}
_AUDIO_EXTS = {'.wav', '.mp3', '.ogg', '.flac', '.aac', '.m4a'} _AUDIO_EXTS = {'.wav', '.mp3', '.ogg', '.flac', '.aac', '.m4a'}
@app.get("/v1/archive") @app.get("/v1/archive", summary="List archived generations", tags=["Files"])
async def list_archive(request: Request): async def list_archive(request: Request):
"""List all generated files in the output directory.""" """List all generated files in the output directory."""
if not global_file_path or not os.path.isdir(global_file_path): if not global_file_path or not os.path.isdir(global_file_path):
...@@ -292,7 +307,7 @@ async def list_archive(request: Request): ...@@ -292,7 +307,7 @@ async def list_archive(request: Request):
return {"files": files} return {"files": files}
@app.delete("/v1/archive/{filename}") @app.delete("/v1/archive/{filename}", summary="Delete an archived file", tags=["Files"])
async def delete_archive_file(filename: str): async def delete_archive_file(filename: str):
"""Delete a generated file from the output directory.""" """Delete a generated file from the output directory."""
if not global_file_path: if not global_file_path:
......
...@@ -116,8 +116,15 @@ class AudioCleanupRequest(BaseModel): ...@@ -116,8 +116,15 @@ class AudioCleanupRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/cleanup") @router.post("/v1/audio/cleanup", summary="Clean / restore audio")
async def cleanup_audio(request: AudioCleanupRequest, http_request: Request = None): async def cleanup_audio(request: AudioCleanupRequest, http_request: Request = None):
"""Restore/clean a noisy audio clip.
Applies any combination of noise reduction, loudness normalization, mains-hum
removal and click/crackle repair. Uses an ML restoration backend when available,
falling back to an ffmpeg-based best-effort path when `fallback_mode` is set.
Returns the cleaned audio plus the backend and quality tier that were used.
"""
try: try:
audio_bytes = _decode_audio(request.audio) audio_bytes = _decode_audio(request.audio)
except Exception as exc: except Exception as exc:
......
...@@ -31,6 +31,7 @@ from fastapi import APIRouter, HTTPException, Request ...@@ -31,6 +31,7 @@ from fastapi import APIRouter, HTTPException, Request
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
from codai.pydantic.audiogenrequest import AudioGenerationRequest, AudioGenerationResponse from codai.pydantic.audiogenrequest import AudioGenerationRequest, AudioGenerationResponse
from codai.tasks import task_registry, TaskCancelled
router = APIRouter() router = APIRouter()
...@@ -160,7 +161,7 @@ def _detect_audio_gen_type(model_name: str) -> str: ...@@ -160,7 +161,7 @@ def _detect_audio_gen_type(model_name: str) -> str:
return 'musicgen' return 'musicgen'
def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest): def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest, task_id=None):
"""Run generation and return (audio_bytes, ext).""" """Run generation and return (audio_bytes, ext)."""
import numpy as np, io as _io import numpy as np, io as _io
...@@ -191,6 +192,9 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest): ...@@ -191,6 +192,9 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
_aud_progress_reset(num_steps, unit="it") _aud_progress_reset(num_steps, unit="it")
def _aud_step_cb(pipe, step_index, timestep, callback_kwargs): def _aud_step_cb(pipe, step_index, timestep, callback_kwargs):
task_registry.raise_if_cancelled(task_id)
task_registry.wait_if_paused(task_id)
task_registry.step(task_id, step_index + 1)
_aud_progress_step(step_index + 1) _aud_progress_step(step_index + 1)
return callback_kwargs return callback_kwargs
...@@ -222,7 +226,7 @@ def _decode_b64_or_url(data: str) -> bytes: ...@@ -222,7 +226,7 @@ def _decode_b64_or_url(data: str) -> bytes:
return base64.b64decode(data) return base64.b64decode(data)
@router.get("/v1/audio/progress") @router.get("/v1/audio/progress", summary="Audio generation progress")
async def get_audio_progress(): async def get_audio_progress():
"""Return current audio generation progress including speed.""" """Return current audio generation progress including speed."""
elapsed = time.monotonic() - _aud_progress["started_at"] if _aud_progress["active"] else 0.0 elapsed = time.monotonic() - _aud_progress["started_at"] if _aud_progress["active"] else 0.0
...@@ -241,7 +245,7 @@ async def get_audio_progress(): ...@@ -241,7 +245,7 @@ async def get_audio_progress():
} }
@router.post("/v1/audio/generate", response_model=AudioGenerationResponse) @router.post("/v1/audio/generate", response_model=AudioGenerationResponse, summary="Generate audio, music or SFX")
async def audio_generate(request: AudioGenerationRequest, http_request: Request = None): async def audio_generate(request: AudioGenerationRequest, http_request: Request = None):
""" """
Generate music, sound effects, or ambient audio. Generate music, sound effects, or ambient audio.
...@@ -262,26 +266,36 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request ...@@ -262,26 +266,36 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
device = _derive_device() device = _derive_device()
model_type = _detect_audio_gen_type(model_name) model_type = _detect_audio_gen_type(model_name)
_ag_cfg = model_info.get('config') or {} _ag_cfg = model_info.get('config') or {}
from codai.tasks import loading_task
try: try:
if model_type in ('musicgen', 'audiogen'): with loading_task(model_name, model_type="audio"):
pipe = await asyncio.get_event_loop().run_in_executor( if model_type in ('musicgen', 'audiogen'):
None, _load_musicgen, model_name, device) pipe = await asyncio.get_event_loop().run_in_executor(
else: None, _load_musicgen, model_name, device)
pipe = await asyncio.get_event_loop().run_in_executor( else:
None, _load_audioldm, model_name, device, _ag_cfg) pipe = await asyncio.get_event_loop().run_in_executor(
None, _load_audioldm, model_name, device, _ag_cfg)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load audio gen model: {e}") raise HTTPException(status_code=500, detail=f"Failed to load audio gen model: {e}")
multi_model_manager.models[model_key] = pipe multi_model_manager.models[model_key] = pipe
multi_model_manager.current_model_key = model_key multi_model_manager.current_model_key = model_key
_tid = task_registry.register(
"audio", title=(request.prompt or "")[:80], model=model_name or "")
task_registry.start(_tid)
try: try:
audio_bytes, ext = await asyncio.get_event_loop().run_in_executor( audio_bytes, ext = await asyncio.get_event_loop().run_in_executor(
None, _generate_audio, pipe, model_name, request) None, _generate_audio, pipe, model_name, request, _tid)
except TaskCancelled:
_aud_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except Exception as e: except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_aud_progress_done() _aud_progress_done()
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
finally: finally:
_aud_progress_done() _aud_progress_done()
task_registry.finish(_tid, "done")
result = _save_audio_response(audio_bytes, ext, http_request) result = _save_audio_response(audio_bytes, ext, http_request)
......
...@@ -166,8 +166,15 @@ class AudioStemRequest(BaseModel): ...@@ -166,8 +166,15 @@ class AudioStemRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/stems") @router.post("/v1/audio/stems", summary="Separate audio into stems")
async def separate_stems(request: AudioStemRequest, http_request: Request = None): async def separate_stems(request: AudioStemRequest, http_request: Request = None):
"""Split a track into its component stems (source separation).
Separates an input clip according to `stem_mode` (e.g. vocals/instrumental, or a
full 4-stem split). Uses an ML separation provider when available, falling back to
an ffmpeg-based best-effort split when `fallback_mode` is set. Returns one audio
output per stem along with the backend and quality tier used.
"""
try: try:
audio_bytes = _decode_audio(request.audio) audio_bytes = _decode_audio(request.audio)
except Exception as exc: except Exception as exc:
......
...@@ -419,7 +419,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]: ...@@ -419,7 +419,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ───────────────────────────────────────────────────────────────── # ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/characters") @router.post("/v1/characters", summary="Create or replace a character profile")
async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)): async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named character profile.""" """Save or update a named character profile."""
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
...@@ -430,13 +430,13 @@ async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_a ...@@ -430,13 +430,13 @@ async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_a
return {"ok": True, "name": meta['name'], "image_count": meta['image_count']} return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/characters") @router.get("/v1/characters", summary="List character profiles")
async def list_characters(_auth=Depends(_require_api_auth)): async def list_characters(_auth=Depends(_require_api_auth)):
"""List all saved character profiles (metadata only, no images).""" """List all saved character profiles (metadata only, no images)."""
return {"characters": _list_characters()} return {"characters": _list_characters()}
@router.get("/v1/characters/{name}") @router.get("/v1/characters/{name}", summary="Get a character profile")
async def get_character(name: str, _auth=Depends(_require_api_auth)): async def get_character(name: str, _auth=Depends(_require_api_auth)):
"""Get a character profile including its reference images as base64.""" """Get a character profile including its reference images as base64."""
meta = _load_character_meta(name) meta = _load_character_meta(name)
...@@ -452,7 +452,7 @@ async def get_character(name: str, _auth=Depends(_require_api_auth)): ...@@ -452,7 +452,7 @@ async def get_character(name: str, _auth=Depends(_require_api_auth)):
} }
@router.delete("/v1/characters/{name}") @router.delete("/v1/characters/{name}", summary="Delete a character profile")
async def delete_character(name: str, _auth=Depends(_require_api_auth)): async def delete_character(name: str, _auth=Depends(_require_api_auth)):
"""Delete a character profile.""" """Delete a character profile."""
cdir = _char_dir(name) cdir = _char_dir(name)
...@@ -463,7 +463,7 @@ async def delete_character(name: str, _auth=Depends(_require_api_auth)): ...@@ -463,7 +463,7 @@ async def delete_character(name: str, _auth=Depends(_require_api_auth)):
return {"ok": True, "name": name} return {"ok": True, "name": name}
@router.patch("/v1/characters/{name}") @router.patch("/v1/characters/{name}", summary="Update a character profile")
async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)): async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)):
"""Update a character profile: description, add images, or remove images by index.""" """Update a character profile: description, add images, or remove images by index."""
meta = _load_character_meta(name) meta = _load_character_meta(name)
...@@ -512,7 +512,7 @@ async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_ ...@@ -512,7 +512,7 @@ async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_
return {"ok": True, "name": name, "image_count": meta['image_count']} return {"ok": True, "name": name, "image_count": meta['image_count']}
@router.post("/v1/characters/generate") @router.post("/v1/characters/generate", summary="Generate character reference images")
async def generate_character(req: CharacterGenerateRequest, request: Request): async def generate_character(req: CharacterGenerateRequest, request: Request):
""" """
Generate a character profile from a text prompt. Generate a character profile from a text prompt.
...@@ -585,7 +585,7 @@ async def generate_character(req: CharacterGenerateRequest, request: Request): ...@@ -585,7 +585,7 @@ async def generate_character(req: CharacterGenerateRequest, request: Request):
return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]} return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]}
@router.post("/v1/characters/extract") @router.post("/v1/characters/extract", summary="Extract a character from media")
async def extract_character(req: CharacterExtractRequest): async def extract_character(req: CharacterExtractRequest):
""" """
Extract a character profile from source images and/or videos. Extract a character profile from source images and/or videos.
......
...@@ -380,7 +380,7 @@ class AudioMusicDubRequest(BaseModel): ...@@ -380,7 +380,7 @@ class AudioMusicDubRequest(BaseModel):
model_config = ConfigDict(extra='allow') model_config = ConfigDict(extra='allow')
@router.get('/v1/pipelines/custom') @router.get('/v1/pipelines/custom', summary="List saved custom pipelines")
async def list_custom_pipelines(): async def list_custom_pipelines():
"""List all saved custom pipeline definitions.""" """List all saved custom pipeline definitions."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -389,7 +389,7 @@ async def list_custom_pipelines(): ...@@ -389,7 +389,7 @@ async def list_custom_pipelines():
return {'pipelines': config_manager.pipelines_data} return {'pipelines': config_manager.pipelines_data}
@router.get('/v1/pipelines/step-types') @router.get('/v1/pipelines/step-types', summary="List available pipeline step types")
async def list_step_types(): async def list_step_types():
"""List available step types with their parameter schemas.""" """List available step types with their parameter schemas."""
return { return {
...@@ -400,7 +400,7 @@ async def list_step_types(): ...@@ -400,7 +400,7 @@ async def list_step_types():
} }
@router.post('/v1/pipelines/custom') @router.post('/v1/pipelines/custom', summary="Create a custom pipeline")
async def create_custom_pipeline(pipeline: PipelineDefinition): async def create_custom_pipeline(pipeline: PipelineDefinition):
"""Save a new custom pipeline definition.""" """Save a new custom pipeline definition."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -416,7 +416,7 @@ async def create_custom_pipeline(pipeline: PipelineDefinition): ...@@ -416,7 +416,7 @@ async def create_custom_pipeline(pipeline: PipelineDefinition):
return {'created': True, 'pipeline': data} return {'created': True, 'pipeline': data}
@router.put('/v1/pipelines/custom/{pipeline_id}') @router.put('/v1/pipelines/custom/{pipeline_id}', summary="Update a custom pipeline")
async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition): async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition):
"""Update an existing custom pipeline.""" """Update an existing custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -433,7 +433,7 @@ async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition) ...@@ -433,7 +433,7 @@ async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition)
return {'updated': True, 'pipeline': data} return {'updated': True, 'pipeline': data}
@router.delete('/v1/pipelines/custom/{pipeline_id}') @router.delete('/v1/pipelines/custom/{pipeline_id}', summary="Delete a custom pipeline")
async def delete_custom_pipeline(pipeline_id: str): async def delete_custom_pipeline(pipeline_id: str):
"""Delete a custom pipeline.""" """Delete a custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -447,7 +447,7 @@ async def delete_custom_pipeline(pipeline_id: str): ...@@ -447,7 +447,7 @@ async def delete_custom_pipeline(pipeline_id: str):
return {'deleted': True, 'id': pipeline_id} return {'deleted': True, 'id': pipeline_id}
@router.post('/v1/pipelines/custom/{pipeline_id}/run') @router.post('/v1/pipelines/custom/{pipeline_id}/run', summary="Run a saved custom pipeline")
async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_request: Request = None): async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_request: Request = None):
"""Execute a saved custom pipeline.""" """Execute a saved custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -459,14 +459,20 @@ async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_r ...@@ -459,14 +459,20 @@ async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_r
return await _execute_pipeline(pipeline_def, body.input or '', http_request) return await _execute_pipeline(pipeline_def, body.input or '', http_request)
@router.post('/v1/pipelines/run', summary="Run an inline pipeline definition")
async def run_inline_pipeline(pipeline: PipelineDefinition, http_request: Request = None):
    """Run a pipeline definition supplied in the request body without saving it."""
    definition = pipeline.model_dump()
    # This route always passes an empty string as the pipeline input.
    return await _execute_pipeline(definition, '', http_request)
@router.post('/v1/pipelines/audio-understand') @router.post('/v1/pipelines/audio-understand', summary="Transcribe and analyze audio")
async def run_audio_understanding(request: AudioUnderstandRequest, http_request: Request = None): async def run_audio_understanding(request: AudioUnderstandRequest, http_request: Request = None):
"""Transcribe and analyze an audio clip in one pass.
Convenience pipeline that transcribes the input audio and then reasons over the
transcript (summary/understanding) using the configured text model. Returns the
transcript together with the model's analysis.
"""
if not request.audio: if not request.audio:
raise HTTPException(status_code=400, detail='Provide audio input') raise HTTPException(status_code=400, detail='Provide audio input')
...@@ -543,8 +549,14 @@ async def run_full_music_dub(request: AudioMusicDubRequest, http_request: Reques ...@@ -543,8 +549,14 @@ async def run_full_music_dub(request: AudioMusicDubRequest, http_request: Reques
} }
@router.post('/v1/pipelines/audio-music-dub') @router.post('/v1/pipelines/audio-music-dub', summary="Dub a song into another language")
async def run_audio_music_dub(request: AudioMusicDubRequest, http_request: Request = None): async def run_audio_music_dub(request: AudioMusicDubRequest, http_request: Request = None):
"""Dub a song into another language while preserving the backing music.
Splits the track into vocals and instrumental, transcribes and translates the
lyrics, re-sings/voice-converts the translated vocals, then remixes them over the
original instrumental. Returns every intermediate stem plus the final mixed result.
"""
if not request.audio: if not request.audio:
raise HTTPException(status_code=400, detail='Provide audio input') raise HTTPException(status_code=400, detail='Provide audio input')
......
...@@ -101,7 +101,7 @@ def _embed_texts(model_obj, texts: List[str], dimensions=None) -> List[List[floa ...@@ -101,7 +101,7 @@ def _embed_texts(model_obj, texts: List[str], dimensions=None) -> List[List[floa
return results return results
@router.post("/v1/embeddings", response_model=EmbeddingsResponse) @router.post("/v1/embeddings", response_model=EmbeddingsResponse, summary="Create embeddings")
async def create_embeddings(request: EmbeddingsRequest, http_request: Request = None): async def create_embeddings(request: EmbeddingsRequest, http_request: Request = None):
""" """
OpenAI-compatible embeddings endpoint. OpenAI-compatible embeddings endpoint.
...@@ -116,13 +116,16 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -116,13 +116,16 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
model_key = model_info['model_key'] model_key = model_info['model_key']
model_obj = model_info.get('model_object') model_obj = model_info.get('model_object')
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}")
or multi_model_manager.config.get(model_name) or {})
if model_obj is None: if model_obj is None:
device = _derive_device() device = _derive_device()
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}") from codai.tasks import loading_task
or multi_model_manager.config.get(model_name) or {})
try: try:
model_obj = await asyncio.get_event_loop().run_in_executor( with loading_task(model_name, model_type="embedding"):
None, _load_embedding_model, model_name, device, _emb_cfg) model_obj = await asyncio.get_event_loop().run_in_executor(
None, _load_embedding_model, model_name, device, _emb_cfg)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load embedding model: {e}") raise HTTPException(status_code=500, detail=f"Failed to load embedding model: {e}")
multi_model_manager.models[model_key] = model_obj multi_model_manager.models[model_key] = model_obj
...@@ -136,7 +139,59 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -136,7 +139,59 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Embedding failed: {e}") raise HTTPException(status_code=500, detail=f"Embedding failed: {e}")
if request.encoding_format == 'base64': # Optional TurboQuant vector quantization (data-free, inner-product preserving).
# The per-model config block (turboquant: {enabled, backend, bits}) is the
# source of truth for enable/disable + which implementation to use; the
# per-request `quantization` field triggers it and can override the bit width.
from codai.models import turboquant as _tq
_raw = _emb_cfg.get('_raw_cfg') if isinstance(_emb_cfg.get('_raw_cfg'), dict) else {}
tq_cfg = _emb_cfg.get('turboquant') or _raw.get('turboquant') or {}
tq_enabled = tq_cfg.get('enabled', None) # None = no explicit model setting
tq_backend = (tq_cfg.get('backend') or 'builtin')
quant_meta = None
quant_bits = None
req_spec = getattr(request, 'quantization', None)
if not req_spec and tq_enabled and tq_cfg.get('bits'):
req_spec = f"turbo{tq_cfg.get('bits')}" # model-configured default
if req_spec:
if tq_enabled is False:
raise HTTPException(
status_code=400,
detail="TurboQuant is disabled for this model (enable it in the "
"model configuration).")
quant_bits = _tq._parse_quant_spec(req_spec)
if quant_bits is None:
raise HTTPException(
status_code=400,
detail=f"Unsupported quantization '{req_spec}' "
"(use 'turbo'/'turbo8'/'turbo6'/'turbo4'/'turbo2')")
if quant_bits is not None and request.encoding_format == 'base64':
# Compact wire form: each embedding is base64 of [f16 norm][packed codes].
# The compact packing is the built-in wire format regardless of backend
# (the upstream library exposes its own opaque store, not per-vector blobs).
blobs, meta = await asyncio.get_event_loop().run_in_executor(
None, _tq.quantize_base64, vectors, quant_bits)
data = [EmbeddingObject(index=i, embedding=b) for i, b in enumerate(blobs)]
quant_meta = {
"method": meta.method, "bits": meta.bits, "seed": meta.seed,
"dim": meta.dim, "dim_padded": meta.dim_padded, "radius": meta.radius,
"bytes_per_vector": meta.bytes_per_vector, "backend": "builtin",
"layout": "base64([float16 norm][packbits(rotated b-bit codes, MSB-first per numpy.packbits)])",
}
elif quant_bits is not None:
# Lossy reconstruction returned as plain floats (quantized-store fidelity).
try:
vectors = await asyncio.get_event_loop().run_in_executor(
None, lambda: _tq.reconstruct(vectors, quant_bits, backend=tq_backend))
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]
eff_backend = tq_backend if tq_backend != 'auto' else _tq.backend_name()
quant_meta = {"method": "turboquant", "bits": quant_bits,
"encoding": "float-reconstruction", "backend": eff_backend}
elif request.encoding_format == 'base64':
import struct import struct
data = [EmbeddingObject( data = [EmbeddingObject(
index=i, index=i,
...@@ -146,8 +201,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -146,8 +201,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)] data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]
total_tokens = sum(len(t.split()) for t in texts) total_tokens = sum(len(t.split()) for t in texts)
return EmbeddingsResponse( resp = EmbeddingsResponse(
data=data, data=data,
model=request.model, model=request.model,
usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens}, usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
) )
\ No newline at end of file if quant_meta is not None:
resp.quantization = quant_meta
return resp
\ No newline at end of file
...@@ -307,7 +307,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]: ...@@ -307,7 +307,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ───────────────────────────────────────────────────────────────── # ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/environments") @router.post("/v1/environments", summary="Create or replace an environment profile")
async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)): async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named environment profile.""" """Save or update a named environment profile."""
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
...@@ -318,13 +318,13 @@ async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_a ...@@ -318,13 +318,13 @@ async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_a
return {"ok": True, "name": meta['name'], "image_count": meta['image_count']} return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/environments", summary="List environment profiles")
async def list_environments(_auth=Depends(_require_api_auth)):
    """Return every saved environment profile (metadata only)."""
    profiles = _list_environments()
    return {"environments": profiles}
@router.get("/v1/environments/{name}") @router.get("/v1/environments/{name}", summary="Get an environment profile")
async def get_environment(name: str, _auth=Depends(_require_api_auth)): async def get_environment(name: str, _auth=Depends(_require_api_auth)):
"""Get an environment profile including its reference images as base64.""" """Get an environment profile including its reference images as base64."""
meta = _load_environment_meta(name) meta = _load_environment_meta(name)
...@@ -340,7 +340,7 @@ async def get_environment(name: str, _auth=Depends(_require_api_auth)): ...@@ -340,7 +340,7 @@ async def get_environment(name: str, _auth=Depends(_require_api_auth)):
} }
@router.delete("/v1/environments/{name}") @router.delete("/v1/environments/{name}", summary="Delete an environment profile")
async def delete_environment(name: str, _auth=Depends(_require_api_auth)): async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
"""Delete an environment profile.""" """Delete an environment profile."""
edir = _env_dir(name) edir = _env_dir(name)
...@@ -351,7 +351,7 @@ async def delete_environment(name: str, _auth=Depends(_require_api_auth)): ...@@ -351,7 +351,7 @@ async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
return {"ok": True, "name": name} return {"ok": True, "name": name}
@router.patch("/v1/environments/{name}") @router.patch("/v1/environments/{name}", summary="Update an environment profile")
async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)): async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)):
"""Update an environment profile: description, add images, or remove images by index.""" """Update an environment profile: description, add images, or remove images by index."""
meta = _load_environment_meta(name) meta = _load_environment_meta(name)
...@@ -398,7 +398,7 @@ async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depen ...@@ -398,7 +398,7 @@ async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depen
return {"ok": True, "name": name, "image_count": meta['image_count']} return {"ok": True, "name": name, "image_count": meta['image_count']}
@router.post("/v1/environments/generate") @router.post("/v1/environments/generate", summary="Generate environment reference images")
async def generate_environment(req: EnvironmentGenerateRequest, request: Request): async def generate_environment(req: EnvironmentGenerateRequest, request: Request):
""" """
Generate an environment profile from a text prompt. Generate an environment profile from a text prompt.
...@@ -471,7 +471,7 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request ...@@ -471,7 +471,7 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request
return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]} return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]}
@router.post("/v1/environments/extract") @router.post("/v1/environments/extract", summary="Extract an environment from media")
async def extract_environment(req: EnvironmentExtractRequest): async def extract_environment(req: EnvironmentExtractRequest):
""" """
Extract an environment profile from source images and/or videos. Extract an environment profile from source images and/or videos.
......
...@@ -144,7 +144,7 @@ class FaceSwapRequest(BaseModel): ...@@ -144,7 +144,7 @@ class FaceSwapRequest(BaseModel):
# Endpoint # Endpoint
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post('/v1/images/faceswap') @router.post('/v1/images/faceswap', summary="Swap faces between images")
async def faceswap(request: FaceSwapRequest, http_request: Request = None): async def faceswap(request: FaceSwapRequest, http_request: Request = None):
""" """
Swap the face from source_face into every face found in target. Swap the face from source_face into every face found in target.
......
This diff is collapsed.
This diff is collapsed.
...@@ -117,7 +117,7 @@ class ImageToVideoPipelineRequest(BaseModel): ...@@ -117,7 +117,7 @@ class ImageToVideoPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/image-to-video") @router.post("/v1/pipelines/image-to-video", summary="Image-to-video pipeline")
async def pipeline_image_to_video(request: ImageToVideoPipelineRequest, http_request: Request = None): async def pipeline_image_to_video(request: ImageToVideoPipelineRequest, http_request: Request = None):
"""Generate an image then animate it into a video.""" """Generate an image then animate it into a video."""
steps = [] steps = []
...@@ -197,7 +197,7 @@ class VideoDubPipelineRequest(BaseModel): ...@@ -197,7 +197,7 @@ class VideoDubPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/video-dub") @router.post("/v1/pipelines/video-dub", summary="Video dubbing pipeline")
async def pipeline_video_dub(request: VideoDubPipelineRequest, http_request: Request = None): async def pipeline_video_dub(request: VideoDubPipelineRequest, http_request: Request = None):
"""Transcribe → translate → TTS dub → burn subtitles.""" """Transcribe → translate → TTS dub → burn subtitles."""
body = { body = {
...@@ -240,7 +240,7 @@ class StoryPipelineRequest(BaseModel): ...@@ -240,7 +240,7 @@ class StoryPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/story") @router.post("/v1/pipelines/story", summary="Story pipeline (multi-scene)")
async def pipeline_story(request: StoryPipelineRequest, http_request: Request = None): async def pipeline_story(request: StoryPipelineRequest, http_request: Request = None):
"""LLM generates script → image per scene → animate first scene → optional TTS narration.""" """LLM generates script → image per scene → animate first scene → optional TTS narration."""
n = min(request.num_scenes or 3, 6) n = min(request.num_scenes or 3, 6)
...@@ -377,7 +377,7 @@ class AudioDubPipelineRequest(BaseModel): ...@@ -377,7 +377,7 @@ class AudioDubPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/audio-dub") @router.post("/v1/pipelines/audio-dub", summary="Audio dubbing pipeline")
async def pipeline_audio_dub(request: AudioDubPipelineRequest, http_request: Request = None): async def pipeline_audio_dub(request: AudioDubPipelineRequest, http_request: Request = None):
"""Transcribe → (translate) → clone voice → replace audio track.""" """Transcribe → (translate) → clone voice → replace audio track."""
import os, tempfile, subprocess, base64 import os, tempfile, subprocess, base64
......
...@@ -499,7 +499,7 @@ class ImageTo3DRequest(BaseModel): ...@@ -499,7 +499,7 @@ class ImageTo3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/images/to3d") @router.post("/v1/images/to3d", summary="Image to 3D model")
async def image_to_3d(request: ImageTo3DRequest, http_request: Request = None): async def image_to_3d(request: ImageTo3DRequest, http_request: Request = None):
"""Convert a 2D image to a 3D representation. """Convert a 2D image to a 3D representation.
...@@ -567,7 +567,7 @@ class ImageFrom3DRequest(BaseModel): ...@@ -567,7 +567,7 @@ class ImageFrom3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/images/from3d") @router.post("/v1/images/from3d", summary="Render a 3D model to an image")
async def image_from_3d(request: ImageFrom3DRequest, http_request: Request = None): async def image_from_3d(request: ImageFrom3DRequest, http_request: Request = None):
"""Render a 3D model (GLB/OBJ) to a 2D PNG image from a specified camera angle.""" """Render a 3D model (GLB/OBJ) to a 2D PNG image from a specified camera angle."""
raw = _decode_b64(request.model_data) raw = _decode_b64(request.model_data)
...@@ -600,7 +600,7 @@ class VideoTo3DRequest(BaseModel): ...@@ -600,7 +600,7 @@ class VideoTo3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/video/to3d") @router.post("/v1/video/to3d", summary="Video to 3D model")
async def video_to_3d(request: VideoTo3DRequest, http_request: Request = None): async def video_to_3d(request: VideoTo3DRequest, http_request: Request = None):
"""Convert a 2D video to a 3D video frame-by-frame. """Convert a 2D video to a 3D video frame-by-frame.
...@@ -641,7 +641,7 @@ class VideoFrom3DRequest(BaseModel): ...@@ -641,7 +641,7 @@ class VideoFrom3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/video/from3d") @router.post("/v1/video/from3d", summary="Render a 3D model to a video")
async def video_from_3d(request: VideoFrom3DRequest, http_request: Request = None): async def video_from_3d(request: VideoFrom3DRequest, http_request: Request = None):
"""Render a 3D model as a 360° turntable video.""" """Render a 3D model as a 360° turntable video."""
raw = _decode_b64(request.model_data) raw = _decode_b64(request.model_data)
...@@ -674,7 +674,7 @@ class Generate3DRequest(BaseModel): ...@@ -674,7 +674,7 @@ class Generate3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/3d/generate") @router.post("/v1/3d/generate", summary="Generate a 3D model from a prompt")
async def generate_3d(request: Generate3DRequest, http_request: Request = None): async def generate_3d(request: Generate3DRequest, http_request: Request = None):
"""Generate a 3D model (GLB) from a text prompt and/or an image. """Generate a 3D model (GLB) from a text prompt and/or an image.
......
...@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__) ...@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
# Import from codai modules # Import from codai modules
from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager
from codai.queue.manager import QueueManager, queue_manager from codai.queue.manager import QueueManager, queue_manager
from codai.tasks import task_registry
from codai.api.prompt_cache import prompt_cache_manager from codai.api.prompt_cache import prompt_cache_manager
from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool
from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser
...@@ -92,7 +93,7 @@ def set_grammar_guided_gen(enabled: bool): ...@@ -92,7 +93,7 @@ def set_grammar_guided_gen(enabled: bool):
router = APIRouter() router = APIRouter()
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions", summary="Chat completions")
async def chat_completions(request: ChatCompletionRequest, http_request: Request = None): async def chat_completions(request: ChatCompletionRequest, http_request: Request = None):
"""Chat completions endpoint with streaming and tool support.""" """Chat completions endpoint with streaming and tool support."""
...@@ -1248,7 +1249,8 @@ async def stream_chat_response( ...@@ -1248,7 +1249,8 @@ async def stream_chat_response(
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time()) created = int(time.time())
request_id = f"req-{uuid.uuid4().hex[:8]}" request_id = f"req-{uuid.uuid4().hex[:8]}"
_tid = None
generated_text = "" generated_text = ""
# Check if model is loaded - if not, notify waiting clients # Check if model is loaded - if not, notify waiting clients
...@@ -1320,6 +1322,9 @@ async def stream_chat_response( ...@@ -1320,6 +1322,9 @@ async def stream_chat_response(
# Mark as starting processing # Mark as starting processing
await queue_manager.start_processing(request_id, model_name) await queue_manager.start_processing(request_id, model_name)
_tid = task_registry.register("text", title=(model_name or "chat"),
model=model_name or "", task_id=request_id)
task_registry.start(_tid)
# Send "Model starting" message # Send "Model starting" message
data = { data = {
...@@ -1374,6 +1379,9 @@ async def stream_chat_response( ...@@ -1374,6 +1379,9 @@ async def stream_chat_response(
response_format=response_format, response_format=response_format,
enable_thinking=enable_thinking, enable_thinking=enable_thinking,
): ):
# Cooperative cancellation: stop streaming if the task was cancelled.
if task_registry.is_cancelled(_tid):
break
chunk_count += 1 chunk_count += 1
# Always filter malformed content (regex-based, works per-chunk) # Always filter malformed content (regex-based, works per-chunk)
filtered_chunk = filter_malformed_content(chunk) filtered_chunk = filter_malformed_content(chunk)
...@@ -1580,6 +1588,9 @@ async def stream_chat_response( ...@@ -1580,6 +1588,9 @@ async def stream_chat_response(
finally: finally:
# Always clean up queue state # Always clean up queue state
await queue_manager.finish_processing() await queue_manager.finish_processing()
if _tid:
task_registry.finish(
_tid, "cancelled" if task_registry.is_cancelled(_tid) else "done")
async def generate_chat_response( async def generate_chat_response(
messages: List[Dict], messages: List[Dict],
...@@ -1789,7 +1800,7 @@ async def generate_chat_response( ...@@ -1789,7 +1800,7 @@ async def generate_chat_response(
from codai.pydantic.textrequest import CompletionRequest from codai.pydantic.textrequest import CompletionRequest
@router.post("/v1/completions") @router.post("/v1/completions", summary="Legacy text completions")
async def completions(request: CompletionRequest): async def completions(request: CompletionRequest):
"""Legacy text completions endpoint (for backward compatibility).""" """Legacy text completions endpoint (for backward compatibility)."""
# Get the model for this request # Get the model for this request
......
...@@ -119,7 +119,7 @@ def _format_response(fmt: str, text: str, segments: list): ...@@ -119,7 +119,7 @@ def _format_response(fmt: str, text: str, segments: list):
router = APIRouter() router = APIRouter()
@router.post("/v1/audio/transcriptions") @router.post("/v1/audio/transcriptions", summary="Transcribe audio to text")
async def create_transcription( async def create_transcription(
model: str = Form(...), model: str = Form(...),
file: UploadFile = File(...), file: UploadFile = File(...),
......
...@@ -64,7 +64,7 @@ class TTSResponse(BaseModel): ...@@ -64,7 +64,7 @@ class TTSResponse(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/speech") @router.post("/v1/audio/speech", summary="Text-to-speech synthesis")
async def create_speech(request: TTSRequest, http_request: Request = None): async def create_speech(request: TTSRequest, http_request: Request = None):
""" """
Text-to-speech endpoint (OpenAI-compatible). Text-to-speech endpoint (OpenAI-compatible).
......
This diff is collapsed.
...@@ -185,13 +185,13 @@ class VoicePatchRequest(BaseModel): ...@@ -185,13 +185,13 @@ class VoicePatchRequest(BaseModel):
# Voice profile management # Voice profile management
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/v1/audio/voices", summary="List voice profiles")
async def list_voices():
    """Return all saved voice profiles."""
    profiles = _list_voices()
    return {"voices": profiles}
@router.post("/v1/audio/voices") @router.post("/v1/audio/voices", summary="Create a voice profile")
async def create_voice( async def create_voice(
name: str = Form(...), name: str = Form(...),
transcript: str = Form(...), transcript: str = Form(...),
...@@ -216,7 +216,7 @@ async def create_voice( ...@@ -216,7 +216,7 @@ async def create_voice(
return {"created": True, "voice": meta} return {"created": True, "voice": meta}
@router.delete("/v1/audio/voices/{name}") @router.delete("/v1/audio/voices/{name}", summary="Delete a voice profile")
async def delete_voice(name: str): async def delete_voice(name: str):
"""Delete a saved voice profile.""" """Delete a saved voice profile."""
import shutil import shutil
...@@ -227,7 +227,7 @@ async def delete_voice(name: str): ...@@ -227,7 +227,7 @@ async def delete_voice(name: str):
return {"deleted": True, "name": name} return {"deleted": True, "name": name}
@router.patch("/v1/audio/voices/{name}") @router.patch("/v1/audio/voices/{name}", summary="Update a voice profile")
async def patch_voice(name: str, req: VoicePatchRequest): async def patch_voice(name: str, req: VoicePatchRequest):
"""Update description, transcript, or reference audio of a saved voice profile.""" """Update description, transcript, or reference audio of a saved voice profile."""
meta = _load_voice(name) meta = _load_voice(name)
...@@ -259,7 +259,7 @@ async def patch_voice(name: str, req: VoicePatchRequest): ...@@ -259,7 +259,7 @@ async def patch_voice(name: str, req: VoicePatchRequest):
return {"updated": True, "voice": meta} return {"updated": True, "voice": meta}
@router.get("/v1/audio/voices/{name}") @router.get("/v1/audio/voices/{name}", summary="Get a voice profile")
async def get_voice(name: str): async def get_voice(name: str):
"""Get a single voice profile metadata.""" """Get a single voice profile metadata."""
meta = _load_voice(name) meta = _load_voice(name)
...@@ -268,7 +268,7 @@ async def get_voice(name: str): ...@@ -268,7 +268,7 @@ async def get_voice(name: str):
return {"voice": meta} return {"voice": meta}
@router.post("/v1/audio/voices/extract") @router.post("/v1/audio/voices/extract", summary="Extract a voice profile from a sample")
async def extract_voice(req: VoiceExtractRequest): async def extract_voice(req: VoiceExtractRequest):
""" """
Extract a voice profile from a source audio or video file. Extract a voice profile from a source audio or video file.
...@@ -358,7 +358,7 @@ class VoiceCloneRequest(BaseModel): ...@@ -358,7 +358,7 @@ class VoiceCloneRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/clone") @router.post("/v1/audio/clone", summary="Clone a voice / synthesize cloned speech")
async def clone_voice(request: VoiceCloneRequest, http_request: Request = None): async def clone_voice(request: VoiceCloneRequest, http_request: Request = None):
""" """
Synthesize speech in a cloned voice using F5-TTS. Synthesize speech in a cloned voice using F5-TTS.
......
...@@ -94,7 +94,7 @@ class VoiceConvertRequest(BaseModel): ...@@ -94,7 +94,7 @@ class VoiceConvertRequest(BaseModel):
model_config = ConfigDict(extra='allow') model_config = ConfigDict(extra='allow')
@router.post('/v1/audio/convert') @router.post('/v1/audio/convert', summary="Voice conversion (speech-to-speech)")
async def convert_voice(request: VoiceConvertRequest, http_request: Request = None): async def convert_voice(request: VoiceConvertRequest, http_request: Request = None):
""" """
Voice conversion: preserves pitch/melody/expression, changes only timbre. Voice conversion: preserves pitch/melody/expression, changes only timbre.
......
...@@ -78,6 +78,40 @@ except ImportError: ...@@ -78,6 +78,40 @@ except ImportError:
_llama_cpp = None _llama_cpp = None
# Friendly KV-cache quant names -> name of the llama.cpp GGML type constant.
# q8_0 is near-lossless and the safe default; the q5/q4 types trade a little
# accuracy for ~2x less KV VRAM. Several spellings map to the same constant.
_KV_TYPE_ALIASES = {
    # Floating point (unquantized)
    'f16': 'GGML_TYPE_F16',
    'fp16': 'GGML_TYPE_F16',
    'f32': 'GGML_TYPE_F32',
    # 8-bit
    'q8_0': 'GGML_TYPE_Q8_0',
    'q8': 'GGML_TYPE_Q8_0',
    'q8_1': 'GGML_TYPE_Q8_1',
    # 5-bit (bare 'q5' resolves to the _1 variant)
    'q5_0': 'GGML_TYPE_Q5_0',
    'q5_1': 'GGML_TYPE_Q5_1',
    'q5': 'GGML_TYPE_Q5_1',
    # 4-bit (bare 'q4' resolves to the _1 variant)
    'q4_0': 'GGML_TYPE_Q4_0',
    'q4_1': 'GGML_TYPE_Q4_1',
    'q4': 'GGML_TYPE_Q4_1',
    'iq4_nl': 'GGML_TYPE_IQ4_NL',
}
# Sub-8-bit KV types that llama.cpp can only use with flash attention enabled.
_KV_NEEDS_FLASH = {
    'q5_0', 'q5_1', 'q5',
    'q4_0', 'q4_1', 'q4',
    'iq4_nl',
}
def _ggml_kv_type(name):
    """Resolve a KV-cache quantization name to the llama.cpp GGML type value.

    The name is normalized (stripped, lower-cased, '-' turned into '_', spaces
    removed), looked up in ``_KV_TYPE_ALIASES`` and the resulting constant is
    read off the ``_llama_cpp`` module.

    Returns None, meaning "leave the llama.cpp default alone", when ``name``
    is falsy, llama.cpp is not available, the name is one of the
    none/auto/default placeholders, the name is unknown, or this llama.cpp
    build does not expose the constant. The last two cases print a notice
    instead of raising.
    """
    if not name:
        return None
    if _llama_cpp is None:
        return None
    normalized = str(name).strip().lower().replace('-', '_').replace(' ', '')
    placeholders = ('', 'none', 'auto', 'default', 'f16default')
    if normalized in placeholders:
        return None
    const_name = _KV_TYPE_ALIASES.get(normalized)
    if const_name is None:
        print(f" KV cache type '{name}' not recognized — using default (f16)")
        return None
    # The constant may be missing on older/newer llama.cpp bindings.
    resolved = getattr(_llama_cpp, const_name, None)
    if resolved is None:
        print(f" KV cache type '{name}' unsupported by this llama.cpp build — using f16")
    return resolved
def _install_layer_log_callback(): def _install_layer_log_callback():
"""Replace llama.cpp's log callback with one that prints load-time layer/buffer """Replace llama.cpp's log callback with one that prints load-time layer/buffer
messages directly to stdout. Returns the callback object — keep a reference messages directly to stdout. Returns the callback object — keep a reference
...@@ -613,7 +647,33 @@ class VulkanBackend(ModelBackend): ...@@ -613,7 +647,33 @@ class VulkanBackend(ModelBackend):
llama_kwargs['rope_freq_base'] = kwargs['rope_freq_base'] llama_kwargs['rope_freq_base'] = kwargs['rope_freq_base']
if 'rope_freq_scale' in kwargs: if 'rope_freq_scale' in kwargs:
llama_kwargs['rope_freq_scale'] = kwargs['rope_freq_scale'] llama_kwargs['rope_freq_scale'] = kwargs['rope_freq_scale']
# KV-cache quantization (llama.cpp type_k / type_v). Shrinks the KV cache
# so long contexts fit in less VRAM. Read from the per-model config, with
# the raw models.json entry as a fallback (carried in _raw_cfg).
_raw_cfg = kwargs.get('_raw_cfg') or {}
_ck = kwargs.get('cache_type_k', _raw_cfg.get('cache_type_k'))
_cv = kwargs.get('cache_type_v', _raw_cfg.get('cache_type_v'))
_flash = bool(kwargs.get('flash_attn', _raw_cfg.get('flash_attn',
_raw_cfg.get('flash_attention', False))))
_tk = _ggml_kv_type(_ck)
_tv = _ggml_kv_type(_cv)
if _tk is not None:
llama_kwargs['type_k'] = _tk
if _tv is not None:
llama_kwargs['type_v'] = _tv
# A quantized V cache below 8 bits requires flash attention in llama.cpp;
# auto-enable it (with a note) so the config "just works".
_v_needs_flash = str(_cv or '').strip().lower().replace('-', '_') in _KV_NEEDS_FLASH
if (_tk is not None or _tv is not None):
if _v_needs_flash and not _flash:
_flash = True
print(" KV cache: sub-8-bit V cache needs flash attention — enabling it")
if _flash:
llama_kwargs['flash_attn'] = True
print(f" KV cache: type_k={_ck or 'f16'} type_v={_cv or 'f16'}"
f"{' (flash_attn on)' if _flash else ''}")
# Force CUDA if requested # Force CUDA if requested
if self.force_cuda: if self.force_cuda:
# Set environment variable to force CUDA # Set environment variable to force CUDA
......
...@@ -247,4 +247,25 @@ configuration directory (--config DIR, default: OS-specific CoderAI directory). ...@@ -247,4 +247,25 @@ configuration directory (--config DIR, default: OS-specific CoderAI directory).
action="store_true", action="store_true",
help="List available Vulkan GPU devices and exit", help="List available Vulkan GPU devices and exit",
) )
parser.add_argument(
"--no-resume-jobs",
action="store_true",
help="Do not resume/recover interrupted LoRA training jobs on restart. "
"Mid-flight jobs are marked 'cancelled' (checkpoints are kept, so they "
"can still be restarted manually from the Tasks page).",
)
parser.add_argument(
"--pipeline-cache",
action="store_true",
help="Cache quantized diffusers pipelines to disk after the first build "
"and reload them from that cache on later starts — skipping the "
"expensive re-download/re-quantization (e.g. the Wan2.2 A14B). The "
"fast acceleration LoRA fuse is re-applied per load. Uses extra disk.",
)
parser.add_argument(
"--rebuild-pipeline-cache",
action="store_true",
help="Ignore any existing pipeline cache and rebuild it from scratch this "
"run (use after changing a model's quantization/precision config).",
)
return parser.parse_args() return parser.parse_args()
...@@ -126,6 +126,17 @@ class ThermalConfig: ...@@ -126,6 +126,17 @@ class ThermalConfig:
poll_seconds: float = 5.0 # how often to re-check while cooling down poll_seconds: float = 5.0 # how often to re-check while cooling down
@dataclass
class JobsConfig:
    """Settings for background jobs (LoRA training)."""

    # Recovery policy for a training job that was cut short by a process
    # restart:
    #   True  -> the job is left 'interrupted' so it can resume from its
    #            on-disk checkpoint.
    #   False -> the job is marked 'cancelled' on startup and is not
    #            auto-resumed; its checkpoints are kept, so it can still be
    #            restarted manually from the Tasks page.
    # The --no-resume-jobs CLI flag forces this off for a single run.
    resume_on_restart: bool = True
@dataclass @dataclass
class Config: class Config:
"""Main configuration class.""" """Main configuration class."""
...@@ -139,6 +150,7 @@ class Config: ...@@ -139,6 +150,7 @@ class Config:
whisper: WhisperConfig = field(default_factory=WhisperConfig) whisper: WhisperConfig = field(default_factory=WhisperConfig)
archive: ArchiveConfig = field(default_factory=ArchiveConfig) archive: ArchiveConfig = field(default_factory=ArchiveConfig)
thermal: ThermalConfig = field(default_factory=ThermalConfig) thermal: ThermalConfig = field(default_factory=ThermalConfig)
jobs: JobsConfig = field(default_factory=JobsConfig)
broker: BrokerConfig = field(default_factory=BrokerConfig) broker: BrokerConfig = field(default_factory=BrokerConfig)
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
tools_closer_prompt: bool = False tools_closer_prompt: bool = False
...@@ -293,6 +305,7 @@ class ConfigManager: ...@@ -293,6 +305,7 @@ class ConfigManager:
whisper=WhisperConfig(**config_data.get("whisper", {})), whisper=WhisperConfig(**config_data.get("whisper", {})),
archive=ArchiveConfig(**config_data.get("archive", {})), archive=ArchiveConfig(**config_data.get("archive", {})),
thermal=ThermalConfig(**config_data.get("thermal", {})), thermal=ThermalConfig(**config_data.get("thermal", {})),
jobs=JobsConfig(**config_data.get("jobs", {})),
broker=BrokerConfig(**config_data.get("broker", {})), broker=BrokerConfig(**config_data.get("broker", {})),
system_prompt=config_data.get("system_prompt"), system_prompt=config_data.get("system_prompt"),
tools_closer_prompt=config_data.get("tools_closer_prompt", False), tools_closer_prompt=config_data.get("tools_closer_prompt", False),
...@@ -411,6 +424,9 @@ class ConfigManager: ...@@ -411,6 +424,9 @@ class ConfigManager:
"gpu_resume": self.config.thermal.gpu_resume, "gpu_resume": self.config.thermal.gpu_resume,
"poll_seconds": self.config.thermal.poll_seconds, "poll_seconds": self.config.thermal.poll_seconds,
}, },
"jobs": {
"resume_on_restart": self.config.jobs.resume_on_restart,
},
"broker": { "broker": {
"enabled": self.config.broker.enabled, "enabled": self.config.broker.enabled,
"base_url": self.config.broker.base_url, "base_url": self.config.broker.base_url,
......
...@@ -147,6 +147,8 @@ def build_runtime_kwargs(model_cfg, model_type): ...@@ -147,6 +147,8 @@ def build_runtime_kwargs(model_cfg, model_type):
} }
if model_type == "text": if model_type == "text":
kwargs['ctx'] = model_cfg.get('n_ctx', model_cfg.get('context_size')) kwargs['ctx'] = model_cfg.get('n_ctx', model_cfg.get('context_size'))
kwargs['cache_type_k'] = model_cfg.get('cache_type_k')
kwargs['cache_type_v'] = model_cfg.get('cache_type_v')
elif model_type == "image": elif model_type == "image":
kwargs['llm_path'] = model_cfg.get('llm_path') kwargs['llm_path'] = model_cfg.get('llm_path')
kwargs['vae_path'] = model_cfg.get('vae_path') kwargs['vae_path'] = model_cfg.get('vae_path')
...@@ -865,9 +867,24 @@ def main(): ...@@ -865,9 +867,24 @@ def main():
from codai.api.characters import set_global_args as set_chars_global_args from codai.api.characters import set_global_args as set_chars_global_args
set_chars_global_args(global_args) set_chars_global_args(global_args)
# Set LoRA training module global args # Set LoRA training module global args. Resolve job-recovery first (the
from codai.api.loras import set_global_args as set_loras_global_args # --no-resume-jobs flag overrides the persisted config setting), then call
# set_global_args, which runs _load_jobs_on_start and honours the flag.
from codai.api.loras import (set_global_args as set_loras_global_args,
set_resume_enabled as set_loras_resume_enabled)
_resume_jobs = bool(getattr(config.jobs, "resume_on_restart", True)) and not getattr(args, "no_resume_jobs", False)
set_loras_resume_enabled(_resume_jobs)
set_loras_global_args(global_args) set_loras_global_args(global_args)
if not _resume_jobs:
print("LoRA job recovery: DISABLED (interrupted training will be cancelled on restart)")
if getattr(args, "pipeline_cache", False):
try:
from codai.models.pipeline_cache import cache_root
_pc_extra = " (rebuilding this run)" if getattr(args, "rebuild_pipeline_cache", False) else ""
print(f"Pipeline cache: ENABLED{_pc_extra} — quantized pipelines cached at {cache_root()}")
except Exception:
print("Pipeline cache: ENABLED")
# Set environment profiles module global args # Set environment profiles module global args
from codai.api.environments import set_global_args as set_envs_global_args from codai.api.environments import set_global_args as set_envs_global_args
...@@ -955,13 +972,18 @@ def main(): ...@@ -955,13 +972,18 @@ def main():
if not _debug_web: if not _debug_web:
class _AccessNoiseFilter(logging.Filter): class _AccessNoiseFilter(logging.Filter):
# uvicorn.access record args: (client_addr, method, full_path, http_ver, status) # uvicorn.access record args: (client_addr, method, full_path, http_ver, status)
_NOISY = ("/v1/loras/progress",) _NOISY_PREFIX = ("/v1/loras/progress",)
# Exact-match only, so the live Tasks-page pollers are dropped but the
# user-initiated action endpoints (/admin/api/tasks/{id}/pause, …) still log.
_NOISY_EXACT = ("/admin/api/tasks", "/admin/api/system-stats")
def filter(self, record): def filter(self, record):
try: try:
args = record.args args = record.args
if isinstance(args, (tuple, list)) and len(args) >= 3: if isinstance(args, (tuple, list)) and len(args) >= 3:
path = str(args[2]).split("?", 1)[0] path = str(args[2]).split("?", 1)[0]
if any(path == p or path.startswith(p) for p in self._NOISY): if path in self._NOISY_EXACT:
return False
if any(path == p or path.startswith(p) for p in self._NOISY_PREFIX):
return False return False
except Exception: except Exception:
pass pass
......
...@@ -52,7 +52,15 @@ ACCEL_PRESETS: dict = { ...@@ -52,7 +52,15 @@ ACCEL_PRESETS: dict = {
"label": "Wan2.2 Lightning (4-step DMD)", "label": "Wan2.2 Lightning (4-step DMD)",
"family": "wan", "family": "wan",
"applies_to": ["video"], "applies_to": ["video"],
"lora": "lightx2v/Wan2.2-Lightning", # Wan2.2 A14B is a two-expert MoE: the distill LoRA must be fused into BOTH
# the high-noise (transformer) and low-noise (transformer_2) experts, or the
# clip collapses to a solid colour at 4 steps. These default to the locally
# installed lightx2v/Wan2.2-Lightning weights (resolved from cache — not a
# download); override per model in the Acceleration config for T2V or a
# different rank/version. `lora` stays None because the two experts differ.
"lora": None,
"lora_high": "lightx2v/Wan2.2-Lightning:Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1/high_noise_model.safetensors",
"lora_low": "lightx2v/Wan2.2-Lightning:Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1/low_noise_model.safetensors",
"lora_weight": 1.0, "lora_weight": 1.0,
"steps": 4, "steps": 4,
"guidance_scale": 1.0, "guidance_scale": 1.0,
...@@ -175,6 +183,12 @@ def resolve_acceleration(model_cfg: Optional[dict]) -> Optional[dict]: ...@@ -175,6 +183,12 @@ def resolve_acceleration(model_cfg: Optional[dict]) -> Optional[dict]:
out = { out = {
"preset": preset_key or "custom", "preset": preset_key or "custom",
"lora": _pick("lora"), "lora": _pick("lora"),
# Wan2.2 A14B is a two-expert MoE: the distill LoRA differs for the
# high-noise (transformer) and low-noise (transformer_2) experts. When
# these are set they take precedence over the single `lora` per expert;
# otherwise the single `lora` is applied to BOTH experts.
"lora_high": _pick("lora_high"),
"lora_low": _pick("lora_low"),
"lora_weight": _pick("lora_weight", 1.0), "lora_weight": _pick("lora_weight", 1.0),
"steps": _pick("steps"), "steps": _pick("steps"),
"guidance_scale": _pick("guidance_scale"), "guidance_scale": _pick("guidance_scale"),
...@@ -255,35 +269,94 @@ def apply_accel_to_pipeline(pipe, accel: Optional[dict]) -> None: ...@@ -255,35 +269,94 @@ def apply_accel_to_pipeline(pipe, accel: Optional[dict]) -> None:
log.warning("[accel] flow_shift apply failed: %s", e) log.warning("[accel] flow_shift apply failed: %s", e)
# 3. Fuse the distill LoRA (when one is configured — turbo has none). # 3. Fuse the distill LoRA (when one is configured — turbo has none).
lora_ref = accel.get("lora") # `_coderai_accel_fused` records whether a distill LoRA actually baked in,
if not lora_ref: # so the generator only drops to the preset's low step count when the model
# is genuinely distilled (running 4 steps un-distilled collapses the video
# to a solid colour — exactly the Wan2.2 dual-expert failure mode).
try:
pipe._coderai_accel_fused = False
except Exception:
pass
has_t2 = getattr(pipe, "transformer_2", None) is not None
lora_high = accel.get("lora_high") or accel.get("lora")
lora_low = accel.get("lora_low") or accel.get("lora")
if not lora_high and not lora_low:
# No LoRA (e.g. a full distilled model like SDXL-Turbo) — treat as distilled.
try:
pipe._coderai_accel_fused = True
except Exception:
pass
return return
if not hasattr(pipe, "load_lora_weights"): if not hasattr(pipe, "load_lora_weights"):
log.warning("[accel] pipeline %s has no load_lora_weights — cannot fuse " log.warning("[accel] pipeline %s has no load_lora_weights — cannot fuse "
"acceleration LoRA", type(pipe).__name__) "acceleration LoRA", type(pipe).__name__)
return return
repo, weight_name = _split_lora_ref(lora_ref)
weight = float(accel.get("lora_weight") or 1.0) weight = float(accel.get("lora_weight") or 1.0)
try:
load_kwargs = {"adapter_name": "__accel__"} def _load_one(ref, into_t2: bool, adapter: str) -> bool:
repo, weight_name = _split_lora_ref(ref)
kw = {"adapter_name": adapter}
if weight_name: if weight_name:
load_kwargs["weight_name"] = weight_name kw["weight_name"] = weight_name
pipe.load_lora_weights(repo, **load_kwargs) if into_t2:
kw["load_into_transformer_2"] = True
pipe.load_lora_weights(repo, **kw)
return True
loaded_adapters = []
try:
# High-noise expert (transformer) — always.
if lora_high and _load_one(lora_high, False, "__accel__"):
loaded_adapters.append("__accel__")
# Low-noise expert (transformer_2) — only on dual-expert Wan2.2 models.
if has_t2 and lora_low:
try:
if _load_one(lora_low, True, "__accel_2__"):
loaded_adapters.append("__accel_2__")
except Exception as e2:
log.warning("[accel] could not load distill LoRA into transformer_2 "
"(%s) — low-noise expert stays un-distilled", e2)
elif has_t2 and not lora_low:
log.warning("[accel] model has a second expert (transformer_2) but no "
"low-noise distill LoRA — set acceleration.lora_low")
if not loaded_adapters:
raise RuntimeError("no distill adapter registered on the pipeline")
try: try:
pipe.set_adapters(["__accel__"], [weight]) pipe.set_adapters(loaded_adapters, [weight] * len(loaded_adapters))
except Exception: except Exception:
pass pass
# Bake it in, then drop the adapter handle so per-request LoRAs are clean. # Bake them in, then drop the adapter handles so per-request LoRAs are clean.
pipe.fuse_lora(lora_scale=weight) # CRITICAL: diffusers' Wan fuse_lora defaults to components=["transformer"],
# so without naming transformer_2 the low-noise expert's distill adapter is
# never fused — and the subsequent unload strips it off, leaving that expert
# undistilled. At 4 steps that collapses the clip to a solid colour. Fuse
# BOTH experts explicitly.
_fuse_components = ["transformer"]
if has_t2:
_fuse_components.append("transformer_2")
try:
pipe.fuse_lora(components=_fuse_components, lora_scale=weight)
except TypeError:
# Older diffusers without the `components` kwarg — best effort.
pipe.fuse_lora(lora_scale=weight)
try: try:
pipe.unload_lora_weights() pipe.unload_lora_weights()
except Exception: except Exception:
pass pass
log.info("[accel] fused distillation LoRA %s (weight=%s) into %s", try:
repo, weight, type(pipe).__name__) pipe._coderai_accel_fused = True
except Exception:
pass
log.info("[accel] fused distillation LoRA(s) %s (weight=%s) into %s%s",
loaded_adapters, weight, type(pipe).__name__,
" (both experts)" if len(loaded_adapters) > 1 else "")
except Exception as e: except Exception as e:
log.warning("[accel] failed to fuse acceleration LoRA %s: %s — generating " log.warning("[accel] failed to fuse acceleration LoRA (high=%s low=%s): %s "
"without acceleration", lora_ref, e) "— generating WITHOUT acceleration (step count will fall back to "
"a safe default, not the preset's distilled count)",
lora_high, lora_low, e)
def accel_call_defaults(accel: Optional[dict]) -> dict: def accel_call_defaults(accel: Optional[dict]) -> dict:
......
...@@ -182,6 +182,15 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]], ...@@ -182,6 +182,15 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]],
bnb_4bit_compute_dtype=dtype, bnb_4bit_use_double_quant=True) bnb_4bit_compute_dtype=dtype, bnb_4bit_use_double_quant=True)
return BnB(load_in_8bit=True) return BnB(load_in_8bit=True)
def _bnb_incompatible(name: str) -> bool:
    """Return True when the component name identifies a VAE.

    bitsandbytes (4/8-bit) and optimum-quanto (2-bit) only quantize
    nn.Linear. A fully-convolutional component (the VAE) has no Linear
    layers, so applying them triggers a hard "no linear modules were found"
    error. Such components must stay full precision (a smaller VAE comes
    from a GGUF VAE instead, handled separately).

    The match is case-insensitive: any name starting with 'vae' (which
    includes exactly 'vae') or ending in '_vae' counts. A falsy name never
    matches.
    """
    lowered = (name or '').lower()
    if lowered.startswith('vae'):
        return True
    return lowered.endswith('_vae')
quant_mapping: Dict[str, Any] = {} quant_mapping: Dict[str, Any] = {}
descs = [] descs = []
if comp_q: if comp_q:
...@@ -189,6 +198,11 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]], ...@@ -189,6 +198,11 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]],
mode = _normalize_quant_mode(raw_mode) # GGUF/none → None here mode = _normalize_quant_mode(raw_mode) # GGUF/none → None here
if mode is None: if mode is None:
continue continue
if _bnb_incompatible(name):
print(f" Skipping {mode} for '{name}': it has no Linear layers "
f"(conv-only) — bitsandbytes/quanto cannot quantize it; "
f"leaving full precision (use a GGUF VAE to shrink the VAE).")
continue
cfg_obj = _mk(comp_lib.get(name, 'diffusers'), mode) cfg_obj = _mk(comp_lib.get(name, 'diffusers'), mode)
if cfg_obj is not None: if cfg_obj is not None:
quant_mapping[name] = cfg_obj quant_mapping[name] = cfg_obj
...@@ -198,6 +212,8 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]], ...@@ -198,6 +212,8 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]],
targets = [n for n in comp_lib if _is_heavy(n)] or \ targets = [n for n in comp_lib if _is_heavy(n)] or \
['transformer', 'transformer_2', 'text_encoder', 'unet'] ['transformer', 'transformer_2', 'text_encoder', 'unet']
for name in targets: for name in targets:
if _bnb_incompatible(name):
continue
cfg_obj = _mk(comp_lib.get(name, 'diffusers'), mode) cfg_obj = _mk(comp_lib.get(name, 'diffusers'), mode)
if cfg_obj is not None: if cfg_obj is not None:
quant_mapping[name] = cfg_obj quant_mapping[name] = cfg_obj
......
...@@ -790,6 +790,17 @@ class MultiModelManager: ...@@ -790,6 +790,17 @@ class MultiModelManager:
# build_kwargs_from_config populates it from the model's # build_kwargs_from_config populates it from the model's
# 'flash_attention' setting; CLI/global is NOT consulted here. # 'flash_attention' setting; CLI/global is NOT consulted here.
kwargs['flash_attn'] = bool(config.get('flash_attn', False)) kwargs['flash_attn'] = bool(config.get('flash_attn', False))
# KV-cache quantization (llama.cpp type_k/type_v) — pass through
# to the backend, with the raw models.json entry as a fallback.
_raw = config.get('_raw_cfg') if isinstance(config.get('_raw_cfg'), dict) else {}
for _kvk in ('cache_type_k', 'cache_type_v'):
_kvv = config.get(_kvk)
if _kvv is None:
_kvv = _raw.get(_kvk)
if _kvv:
kwargs[_kvk] = _kvv
if _raw and '_raw_cfg' not in kwargs:
kwargs['_raw_cfg'] = _raw
no_ram = _cfg_or_global('no_ram', 'no_ram', False) no_ram = _cfg_or_global('no_ram', 'no_ram', False)
kwargs['no_ram'] = bool(no_ram) kwargs['no_ram'] = bool(no_ram)
offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto') offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto')
...@@ -806,7 +817,9 @@ class MultiModelManager: ...@@ -806,7 +817,9 @@ class MultiModelManager:
print(f"Loading default model on demand: {self.default_model}") print(f"Loading default model on demand: {self.default_model}")
_snap = self.vram_before_load() _snap = self.vram_before_load()
kwargs['expected_vram_gb'] = self._get_model_used_vram_gb(self.default_model) kwargs['expected_vram_gb'] = self._get_model_used_vram_gb(self.default_model)
model_manager.load_model(self.default_model, backend_type=backend_type, **kwargs) from codai.tasks import loading_task
with loading_task(self.default_model, model_type="text"):
model_manager.load_model(self.default_model, backend_type=backend_type, **kwargs)
self.add_model(self.default_model, model_manager) self.add_model(self.default_model, model_manager)
self.record_vram_delta(self.default_model, _snap) self.record_vram_delta(self.default_model, _snap)
self.current_model_key = self.default_model self.current_model_key = self.default_model
...@@ -872,6 +885,17 @@ class MultiModelManager: ...@@ -872,6 +885,17 @@ class MultiModelManager:
# build_kwargs_from_config populates it from the model's # build_kwargs_from_config populates it from the model's
# 'flash_attention' setting; CLI/global is NOT consulted here. # 'flash_attention' setting; CLI/global is NOT consulted here.
kwargs['flash_attn'] = bool(config.get('flash_attn', False)) kwargs['flash_attn'] = bool(config.get('flash_attn', False))
# KV-cache quantization (llama.cpp type_k/type_v) — pass through
# to the backend, with the raw models.json entry as a fallback.
_raw = config.get('_raw_cfg') if isinstance(config.get('_raw_cfg'), dict) else {}
for _kvk in ('cache_type_k', 'cache_type_v'):
_kvv = config.get(_kvk)
if _kvv is None:
_kvv = _raw.get(_kvk)
if _kvv:
kwargs[_kvk] = _kvv
if _raw and '_raw_cfg' not in kwargs:
kwargs['_raw_cfg'] = _raw
no_ram = _cfg_or_global('no_ram', 'no_ram', False) no_ram = _cfg_or_global('no_ram', 'no_ram', False)
kwargs['no_ram'] = bool(no_ram) kwargs['no_ram'] = bool(no_ram)
offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto') offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto')
...@@ -894,7 +918,9 @@ class MultiModelManager: ...@@ -894,7 +918,9 @@ class MultiModelManager:
# it can decide whether Flash-Attention-2 is safe (FA2 requires the # it can decide whether Flash-Attention-2 is safe (FA2 requires the
# whole model on GPU; it device-side-asserts when layers offload). # whole model on GPU; it device-side-asserts when layers offload).
kwargs['expected_vram_gb'] = self._get_model_used_vram_gb(model_name) kwargs['expected_vram_gb'] = self._get_model_used_vram_gb(model_name)
model_manager.load_model(model_name, backend_type=backend_type, **kwargs) from codai.tasks import loading_task
with loading_task(model_name, model_type="text"):
model_manager.load_model(model_name, backend_type=backend_type, **kwargs)
self.add_model(model_name, model_manager) self.add_model(model_name, model_manager)
self.record_vram_delta(model_name, _snap) self.record_vram_delta(model_name, _snap)
self.current_model_key = model_name self.current_model_key = model_name
......
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""On-disk cache of *built* diffusers pipelines.
Building a large quantized video pipeline (e.g. Wan2.2 A14B at 4-bit) is slow:
download + bitsandbytes quantization of ~28B parameters. The weights don't change
between restarts, so once built we ``save_pretrained`` the pipeline to a local
cache keyed by ``(model, quantization, precision)``. A later start with
``--pipeline-cache`` reloads from there with a plain ``from_pretrained`` of the
already-quantized weights — no re-download, no re-quantization.
Scope: only the *base* pipeline is cached. The acceleration/distillation LoRA is
NOT baked into the cache — it is re-fused on every load (a fast operation), so the
cache stays independent of the (cheap to re-apply) ``acceleration`` config and we
avoid the fragile round-trip of serialising a fused + quantized model.
Everything here is best-effort: any failure (save or load) is swallowed and the
caller falls back to a normal build, so the cache can never break generation.
"""
import hashlib
import json
import os
import shutil
import time
from typing import Optional
# Bump when the cache layout / marker format changes so stale caches are ignored.
_CACHE_VERSION = 1
def _global_args():
    """Fetch the process-wide parsed CLI args, or None when unavailable.

    The import is done lazily inside the function; any failure (import error
    or an exception from the getter) yields None rather than propagating.
    """
    resolved = None
    try:
        from codai.api.state import get_global_args as _getter
        resolved = _getter()
    except Exception:
        resolved = None
    return resolved
def enabled() -> bool:
    """Report whether the pipeline cache is switched on (--pipeline-cache)."""
    args = _global_args()
    if args is None:
        return False
    return bool(getattr(args, "pipeline_cache", False))
def _force_rebuild() -> bool:
    """Report whether a cache rebuild was requested via ``rebuild_pipeline_cache``."""
    args = _global_args()
    if args is None:
        return False
    return bool(getattr(args, "rebuild_pipeline_cache", False))
def cache_root() -> str:
    """Return the root directory that holds all cached pipelines.

    When the global args carry an ``offload_dir``, the cache lives in a
    ``pipeline_cache`` directory next to it (same parent). Otherwise it
    defaults to ``~/.cache/coderai/pipeline_cache``.
    """
    args = _global_args()
    offload_dir = getattr(args, "offload_dir", None) if args else None
    if not offload_dir:
        home = os.path.expanduser("~")
        return os.path.join(home, ".cache", "coderai", "pipeline_cache")
    resolved_offload = os.path.abspath(os.path.expanduser(offload_dir))
    return os.path.join(os.path.dirname(resolved_offload), "pipeline_cache")
def _signature(model_name: str, model_cfg: Optional[dict]) -> str:
    """Return a 16-hex-char fingerprint of what determines the built weights.

    The fingerprint covers the cache version, the model id, the precision
    (defaulting to "bf16") and the quantization settings. Acceleration
    (re-applied per load) and offload (a runtime placement decision) are
    deliberately left out.
    """
    cfg = model_cfg or {}
    fingerprint = dict(
        v=_CACHE_VERSION,
        model=model_name,
        precision=cfg.get("precision") or "bf16",
        load_in_4bit=bool(cfg.get("load_in_4bit", False)),
        load_in_8bit=bool(cfg.get("load_in_8bit", False)),
        component_quantization=cfg.get("component_quantization") or {},
    )
    # sort_keys keeps the serialization (and thus the hash) order-independent.
    encoded = json.dumps(fingerprint, sort_keys=True, default=str).encode()
    digest = hashlib.sha256(encoded).hexdigest()
    return digest[:16]
def _safe_name(model_name: str) -> str:
    """Turn a model id into a filesystem-friendly name of at most 80 chars.

    Alphanumeric characters and ``-``, ``.``, ``_`` are kept; every other
    character (e.g. the ``/`` in a hub id) becomes ``_``.
    """
    cleaned = []
    for ch in model_name:
        keep = ch.isalnum() or ch in "-._"
        cleaned.append(ch if keep else "_")
    return "".join(cleaned)[:80]
def path(model_name: str, model_cfg: Optional[dict]) -> str:
    """Return the cache directory for this model and quant/precision signature.

    The directory name is ``<sanitized model name>__<signature>`` under the
    cache root.
    """
    dir_name = f"{_safe_name(model_name)}__{_signature(model_name, model_cfg)}"
    return os.path.join(cache_root(), dir_name)
def _marker(p: str) -> str:
    """Return the path of the completeness-marker file inside cache dir ``p``."""
    marker_file = ".coderai_pipeline_cache.json"
    return os.path.join(p, marker_file)
def valid(p: str) -> bool:
    """Tell whether ``p`` holds a complete, current cache that may be reused.

    Returns False when ``p`` is empty, a rebuild was forced, the directory
    lacks ``model_index.json``, or the marker file is missing, unreadable,
    from another cache version, or not flagged complete.
    """
    if not p or _force_rebuild():
        return False
    try:
        index_file = os.path.join(p, "model_index.json")
        if not os.path.isfile(index_file):
            return False
        with open(_marker(p)) as fh:
            marker = json.load(fh)
        is_current = marker.get("version") == _CACHE_VERSION
        return is_current and marker.get("complete") is True
    except Exception:
        # Best-effort: an unreadable or malformed marker just means "no cache".
        return False
def invalidate(model_name: str, model_cfg: Optional[dict]) -> None:
    """Remove this model's cache directory so the next build rewrites it.

    Intended for use after e.g. a failed cache load. Best-effort: nothing is
    raised, and nothing happens when the directory does not exist.
    """
    try:
        target = path(model_name, model_cfg)
        if not target or not os.path.isdir(target):
            return
        shutil.rmtree(target, ignore_errors=True)
        print(f" [pipeline-cache] invalidated {target}")
    except Exception:
        pass
def save(pipe, p: str, *, model_name: str = "", model_cfg: Optional[dict] = None) -> bool:
    """Write ``pipe`` into the cache directory ``p`` through a staging directory.

    The pipeline is first serialized to ``p + ".building"`` together with a
    marker file; only then is any existing ``p`` removed and the staging
    directory moved into place.

    Returns True on success.  Every failure is logged and reported as False so
    the caller can keep using the freshly built in-memory pipeline.
    """
    if not p:
        return False
    staging = p + ".building"
    try:
        os.makedirs(cache_root(), exist_ok=True)
        # Drop leftovers of an earlier, interrupted save.
        if os.path.exists(staging):
            shutil.rmtree(staging, ignore_errors=True)
        print(f" [pipeline-cache] saving quantized pipeline → {p}")
        started = time.time()
        pipe.save_pretrained(staging)
        # The marker is written last inside the staging dir, after the weights.
        with open(_marker(staging), "w") as fh:
            marker_payload = {
                "version": _CACHE_VERSION,
                "complete": True,
                "model": model_name,
                "saved_at": time.time(),
                "signature": _signature(model_name, model_cfg),
            }
            json.dump(marker_payload, fh)
        if os.path.exists(p):
            shutil.rmtree(p, ignore_errors=True)
        os.replace(staging, p)
        elapsed = time.time() - started
        print(f" [pipeline-cache] saved in {elapsed:.0f}s")
        return True
    except Exception as exc:
        print(f" [pipeline-cache] save failed ({exc}) — continuing without a cache")
        try:
            shutil.rmtree(staging, ignore_errors=True)
        except Exception:
            pass
        return False
...@@ -35,10 +35,62 @@ Semantics (per sensor, when enabled): ...@@ -35,10 +35,62 @@ Semantics (per sensor, when enabled):
import os import os
import shutil import shutil
import subprocess import subprocess
import threading
import time import time
from typing import Optional, Tuple from typing import Optional, Tuple
# ---------------------------------------------------------------------------
# Cooldown state (published for the admin Tasks view)
# ---------------------------------------------------------------------------
# A thermal pause is a *global* hardware event: every worker that reaches a
# checkpoint blocks until temps recover. We publish a single process-wide state
# so the Tasks page can show that running work is paused for cooldown. A waiter
# counter (not a bool) keeps the state correct when several workers pause at
# once — the state is "active" while any worker is still cooling.
# Guards every read and write of ``_cooldown_waiters`` and ``_cooldown_state``.
_cooldown_lock = threading.Lock()
# Number of callers currently between ``_cooldown_enter()`` and ``_cooldown_exit()``.
_cooldown_waiters = 0
# Published fields:
#   active  - True while at least one waiter is registered
#   since   - time.time() stamp taken when the current cooldown began (0.0 when idle)
#   waited  - value last reported through ``_cooldown_update()``
#   gpu/cpu - readings last reported through ``_cooldown_update()`` (None when idle)
#   message - text last reported through ``_cooldown_update()``
_cooldown_state: dict = {
    "active": False, "since": 0.0, "waited": 0.0,
    "gpu": None, "cpu": None, "message": "",
}
def get_cooldown_state() -> dict:
    """Return a copy of the current thermal cooldown state.

    ``active`` is True while at least one worker is paused waiting for the
    hardware to cool.  The returned dict is a shallow copy taken under the
    lock, so callers may modify it freely.
    """
    with _cooldown_lock:
        snapshot = dict(_cooldown_state)
    return snapshot
def _cooldown_enter() -> None:
    """Register one more paused worker and flag the cooldown as active.

    ``since`` is stamped only when it is still unset, so a worker joining a
    cooldown that is already in progress does not reset its start time.
    """
    global _cooldown_waiters
    with _cooldown_lock:
        _cooldown_waiters = _cooldown_waiters + 1
        _cooldown_state["active"] = True
        already_started = _cooldown_state.get("since")
        if already_started:
            return
        _cooldown_state["since"] = time.time()
def _cooldown_update(gpu, cpu, waited, message) -> None:
    """Publish the latest readings, elapsed wait and status text for the cooldown."""
    with _cooldown_lock:
        _cooldown_state.update(
            gpu=gpu,
            cpu=cpu,
            waited=waited,
            message=message,
        )
def _cooldown_exit() -> None:
    """Unregister one paused worker; reset the published state when none remain.

    The waiter count never drops below zero, so an unmatched exit is harmless.
    """
    global _cooldown_waiters
    with _cooldown_lock:
        remaining = _cooldown_waiters - 1
        _cooldown_waiters = remaining if remaining > 0 else 0
        if _cooldown_waiters != 0:
            # Somebody is still cooling: keep the state as published.
            return
        _cooldown_state["active"] = False
        _cooldown_state["since"] = 0.0
        _cooldown_state["waited"] = 0.0
        _cooldown_state["gpu"] = None
        _cooldown_state["cpu"] = None
        _cooldown_state["message"] = ""
# ---------------------------------------------------------------------------
# Temperature readers
# ---------------------------------------------------------------------------
...@@ -199,6 +251,57 @@ def read_cpu_temp() -> Optional[float]: ...@@ -199,6 +251,57 @@ def read_cpu_temp() -> Optional[float]:
return val return val
# Last GPU-utilization reading as (time.monotonic() stamp, value); read and
# replaced by read_gpu_util(). Starts at (0.0, None), i.e. no reading taken yet.
_gpu_util_cache: Tuple[float, Optional[float]] = (0.0, None)
def _read_gpu_util_uncached() -> Optional[float]:
    """Return the highest GPU utilization in %, or None if it cannot be read.

    ``nvidia-smi`` is tried first (one numeric value per output line); if it
    yields nothing usable, ``rocm-smi --showuse`` is tried and every numeric
    token on its "GPU use" lines is considered.
    """

    def _numbers(tokens):
        # Keep only the tokens that parse as floats; skip everything else.
        parsed = []
        for tok in tokens:
            try:
                parsed.append(float(tok))
            except ValueError:
                continue
        return parsed

    if _NVIDIA_SMI:
        nvidia_out = _run([
            _NVIDIA_SMI,
            "--query-gpu=utilization.gpu",
            "--format=csv,noheader,nounits",
        ])
        if nvidia_out:
            stripped_rows = (row.strip() for row in nvidia_out.splitlines())
            readings = _numbers(row for row in stripped_rows if row)
            if readings:
                return max(readings)
    if _ROCM_SMI:
        rocm_out = _run([_ROCM_SMI, "--showuse"])
        if rocm_out:
            readings = []
            for row in rocm_out.splitlines():
                if "%" in row and "gpu use" in row.lower():
                    # Turn "%" into whitespace so it never sticks to a number.
                    readings.extend(_numbers(row.replace("%", " ").split()))
            if readings:
                return max(readings)
    return None
def read_gpu_util() -> Optional[float]:
    """Return GPU utilization in %, or None if unreadable.

    A reading younger than ``_CACHE_TTL`` seconds is served from the cache;
    otherwise a fresh reading is taken and cached (including a None result).
    """
    global _gpu_util_cache
    now = time.monotonic()
    cached_at, reading = _gpu_util_cache
    if now - cached_at >= _CACHE_TTL:
        reading = _read_gpu_util_uncached()
        _gpu_util_cache = (now, reading)
    return reading
def read_cpu_temp_avg(samples: int = 3, max_seconds: float = 3.0) -> Optional[float]: def read_cpu_temp_avg(samples: int = 3, max_seconds: float = 3.0) -> Optional[float]:
"""Averaged CPU temperature for stable resume/cooldown decisions. """Averaged CPU temperature for stable resume/cooldown decisions.
...@@ -372,25 +475,30 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None, ...@@ -372,25 +475,30 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None,
f"until cooldown (GPU<={settings.gpu_resume:.0f}°C / " f"until cooldown (GPU<={settings.gpu_resume:.0f}°C / "
f"CPU<={settings.cpu_resume:.0f}°C)") f"CPU<={settings.cpu_resume:.0f}°C)")
waited = 0.0 waited = 0.0
while True: _cooldown_enter()
# Re-evaluate against resume thresholds (lower than trigger → hysteresis). try:
# CPU temps are noisy, so average a few samples for the resume decision while True:
# (the pause check above stays single-read to react fast to spikes). # Re-evaluate against resume thresholds (lower than trigger → hysteresis).
gt = read_gpu_temp() if settings.gpu_enabled else None # CPU temps are noisy, so average a few samples for the resume decision
ct = read_cpu_temp_avg() if settings.cpu_enabled else None # (the pause check above stays single-read to react fast to spikes).
still = [] gt = read_gpu_temp() if settings.gpu_enabled else None
if gt is not None and gt > settings.gpu_resume: ct = read_cpu_temp_avg() if settings.cpu_enabled else None
still.append(("GPU", gt, settings.gpu_resume)) still = []
if ct is not None and ct > settings.cpu_resume: if gt is not None and gt > settings.gpu_resume:
still.append(("CPU", ct, settings.cpu_resume)) still.append(("GPU", gt, settings.gpu_resume))
_dbg(f"cooldown{desc} {int(waited)}s: GPU {_fmt(gt)} CPU {_fmt(ct)} (avg-3) " if ct is not None and ct > settings.cpu_resume:
f"(still hot: {[s[0] for s in still] or 'none'})") still.append(("CPU", ct, settings.cpu_resume))
if not still: _dbg(f"cooldown{desc} {int(waited)}s: GPU {_fmt(gt)} CPU {_fmt(ct)} (avg-3) "
break f"(still hot: {[s[0] for s in still] or 'none'})")
msg = ", ".join(f"{lbl} {t:.0f}°C>{r:.0f}°C" for lbl, t, r in still) if not still:
print(f"[thermal] Cooling{desc}: {msg} — waiting " break
f"({int(waited)}s elapsed)") msg = ", ".join(f"{lbl} {t:.0f}°C>{r:.0f}°C" for lbl, t, r in still)
time.sleep(settings.poll_seconds) _cooldown_update(gt, ct, waited, msg)
waited += settings.poll_seconds print(f"[thermal] Cooling{desc}: {msg} — waiting "
f"({int(waited)}s elapsed)")
time.sleep(settings.poll_seconds)
waited += settings.poll_seconds
finally:
_cooldown_exit()
print(f"[thermal] Temperatures back within safe limits{desc} — resuming " print(f"[thermal] Temperatures back within safe limits{desc} — resuming "
f"after {int(waited)}s") f"after {int(waited)}s")
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment