Add task management, quantization, and hardware telemetry

Tasks / queue management:
- Central in-memory task registry with cooperative cancel, pause/resume,
  and step progress across image/video/audio/text generation + LoRA training
- Tasks admin page (live 2s poll): cancel, interrupt, pause/resume, restart,
  remove; done jobs auto-drop from the list; bounded persisted job history
- Disable interrupted-training recovery via --no-resume-jobs + settings toggle

Quantization / acceleration:
- TurboQuant embedding vector quantization (data-free, inner-product
  preserving): built-in NumPy backend + optional turboquant-py library,
  selectable per embedding model; /v1/embeddings `quantization` param
- llama.cpp KV-cache quantization (cache_type_k/v) for GGUF text models,
  configurable in the Models UI

Hardware telemetry:
- Thermal cooldown state surfaced on the Tasks page (banner + per-task badge)
- Live CPU/GPU/RAM/VRAM usage + temperature panel via /admin/api/system-stats

Docs: API documentation gaps/accuracy pass + Swagger overhaul; DISTRIBUTION.md
implementation spec. Plus I2V LoRA training channel-mismatch fix.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 9494d1bd
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -27,6 +27,7 @@
<a href="{{ root_path }}/docs" class="nav-link" target="_blank">API Docs</a>
{% if is_admin|default(false) %}
<a href="{{ root_path }}/admin/models" class="nav-link {% if '/models' in request.url.path %}active{% endif %}">Models</a>
<a href="{{ root_path }}/admin/tasks" class="nav-link {% if '/tasks' in request.url.path %}active{% endif %}">Tasks</a>
<a href="{{ root_path }}/admin/tokens" class="nav-link {% if '/tokens' in request.url.path %}active{% endif %}">Tokens</a>
<a href="{{ root_path }}/admin/users" class="nav-link {% if '/users' in request.url.path %}active{% endif %}">Users</a>
<a href="{{ root_path }}/admin/archive" class="nav-link {% if '/archive' in request.url.path %}active{% endif %}">Archive</a>
......
......@@ -616,6 +616,28 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<label class="form-label">Context size</label>
<input type="number" id="cfg-n-ctx" class="form-input" min="128" step="128" value="2048">
</div>
<div class="form-row" style="margin:0" id="cfg-kv-k-row">
<label class="form-label">KV cache — Keys <span class="muted">(GGUF text; shrinks KV VRAM)</span></label>
<select id="cfg-cache-type-k" class="form-input">
<option value="">Default (f16)</option>
<option value="q8_0">q8_0 (near-lossless, ~2×)</option>
<option value="q5_1">q5_1</option>
<option value="q5_0">q5_0</option>
<option value="q4_1">q4_1 (smallest)</option>
<option value="q4_0">q4_0 (smallest)</option>
</select>
</div>
<div class="form-row" style="margin:0" id="cfg-kv-v-row">
<label class="form-label">KV cache — Values <span class="muted">(sub-8-bit needs Flash Attn)</span></label>
<select id="cfg-cache-type-v" class="form-input">
<option value="">Default (f16)</option>
<option value="q8_0">q8_0 (near-lossless, ~2×)</option>
<option value="q5_1">q5_1</option>
<option value="q5_0">q5_0</option>
<option value="q4_1">q4_1 (smallest)</option>
<option value="q4_0">q4_0 (smallest)</option>
</select>
</div>
<div class="form-row" style="margin:0">
<label class="form-label">Max GPU % <span class="muted">(optional)</span></label>
<input type="number" id="cfg-max-gpu" class="form-input" min="1" max="100" placeholder="e.g. 90">
......@@ -732,6 +754,36 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
</div>
</div>
<!-- TurboQuant embedding vector quantization (embedding models) -->
<div id="cfg-turboquant-section" style="display:none">
<div class="card-title" style="margin-top:1.25rem">TurboQuant <span class="muted" style="font-weight:normal">(embedding vector quantization — data-free, inner-product preserving)</span></div>
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px;margin:.4rem 0">
<input type="checkbox" id="cfg-tq-enabled" onchange="onTurboQuantToggle()"> Enable TurboQuant
<span class="muted">compress embeddings to 2–8 bits/coord for smaller vector stores</span></label>
<div id="cfg-tq-fields" style="display:none">
<div style="display:flex;gap:1rem;flex-wrap:wrap">
<div class="form-row" style="max-width:260px">
<label class="form-label">Backend</label>
<select id="cfg-tq-backend" class="form-input">
<option value="builtin">Built-in (NumPy, always available)</option>
<option value="library">turboquant-py library (QJL)</option>
</select>
<span class="form-hint" id="cfg-tq-backend-hint"></span>
</div>
<div class="form-row" style="max-width:200px">
<label class="form-label">Default bits</label>
<select id="cfg-tq-bits" class="form-input">
<option value="8">8-bit (near-lossless, 3×)</option>
<option value="6">6-bit</option>
<option value="4">4-bit (6×)</option>
<option value="2">2-bit (12×)</option>
</select>
</div>
</div>
<span class="form-hint">A request's <code>quantization</code> field (e.g. <code>turbo4</code>) overrides the default bits. With <code>encoding_format:"base64"</code> the response is compact packed bytes; otherwise it returns the lossy reconstruction as floats.</span>
</div>
</div>
<!-- components -->
<div class="card-title" style="margin-top:1.25rem">Components</div>
<div class="form-row">
......@@ -2268,9 +2320,9 @@ async function refreshLocal(){
loadGlobalSettings();
refreshLocal();
// Toggle the acceleration section as image/video model types are checked/unchecked.
// Toggle the acceleration / TurboQuant sections as model types are checked/unchecked.
document.querySelectorAll('.cfg-type-cb').forEach(cb =>
cb.addEventListener('change', () => _refreshAccelVisibility()));
cb.addEventListener('change', () => { _refreshAccelVisibility(); _refreshTurboQuantVisibility(); }));
// ── Deep-link from Studio: /admin/models?tab=search&q=...&pipeline=...&gguf=...
// ── or: /admin/models?local_cap=CAPABILITY — highlight local models with that capability
......@@ -2633,6 +2685,8 @@ function openCfgModal(idx, cfgIdx){
document.getElementById('cfg-force-vram-update').checked = !!s.force_vram_update;
document.getElementById('cfg-gpu-layers').value = s.n_gpu_layers !== undefined ? s.n_gpu_layers : -1;
document.getElementById('cfg-n-ctx').value = nCtxForEst;
document.getElementById('cfg-cache-type-k').value = s.cache_type_k || '';
document.getElementById('cfg-cache-type-v').value = s.cache_type_v || '';
document.getElementById('cfg-max-instances').value = s.max_instances != null ? s.max_instances : 1;
document.getElementById('cfg-preload-all-instances').checked = !!s.preload_all_instances;
_updatePreloadAllVisibility();
......@@ -2678,6 +2732,7 @@ function openCfgModal(idx, cfgIdx){
document.getElementById('cfg-lora-dir').value = s.lora_model_dir || '';
document.getElementById('cfg-lora-train-base').value = s.lora_train_base_model || '';
_populateAccel(s.acceleration);
_populateTurboQuant(s.turboquant);
openModal('cfg-modal');
}
......@@ -2768,6 +2823,56 @@ function _collectAccel(){
};
}
// ---- TurboQuant (embedding vector quantization) ----
let _tqInfo = null;
// True only after the server has actually answered. The offline fallback below
// is never treated as a cached result, so a transient failure is retried on the
// next call instead of permanently hiding the library backend.
let _tqInfoLoaded = false;
// Load TurboQuant backend availability from /admin/api/turboquant-info.
// Resolves to the server's info object, or to {builtin:true, library:false}
// when the request fails, the response is not 2xx, or the payload is not an object.
// Never rejects. `_tqInfo` is always non-null once this has resolved.
async function _loadTurboQuantInfo(){
  if (_tqInfoLoaded && _tqInfo) return _tqInfo;
  try {
    const r = await fetch(ROOT_PATH + '/admin/api/turboquant-info');
    // An error response may still carry a JSON body (e.g. {detail: ...});
    // it must not be mistaken for backend info.
    if (!r.ok) throw new Error('HTTP ' + r.status);
    const info = await r.json();
    if (!info || typeof info !== 'object') throw new Error('unexpected payload');
    _tqInfo = info;
    _tqInfoLoaded = true;
  } catch(e){
    _tqInfo = {builtin:true, library:false};
  }
  return _tqInfo;
}
// True when the "embedding_models" type checkbox is currently ticked.
function _turboQuantApplies(){
  const ticked = Array.from(document.querySelectorAll('.cfg-type-cb:checked'));
  for (const box of ticked) {
    if (box.value === 'embedding_models') return true;
  }
  return false;
}
// Show the TurboQuant section only while _turboQuantApplies() is true.
// Does nothing when the section element is absent.
function _refreshTurboQuantVisibility(){
  const panel = document.getElementById('cfg-turboquant-section');
  if (!panel) return;
  panel.style.display = _turboQuantApplies() ? '' : 'none';
}
// Reveal or hide the backend/bits fields to match the "Enable TurboQuant" checkbox.
function onTurboQuantToggle(){
  const enabled = document.getElementById('cfg-tq-enabled').checked;
  const fields = document.getElementById('cfg-tq-fields');
  fields.style.display = enabled ? '' : 'none';
}
// Fill the TurboQuant controls from a saved config object `t` (may be null/undefined).
// Waits for backend availability first so the "library" option can be disabled
// when the server does not report it.
async function _populateTurboQuant(t){
  await _loadTurboQuantInfo();
  _refreshTurboQuantVisibility();

  // Reflect library availability in the backend dropdown and its hint.
  const libraryOption = document.querySelector('#cfg-tq-backend option[value="library"]');
  const backendHint = document.getElementById('cfg-tq-backend-hint');
  const libraryAvailable = Boolean(_tqInfo && _tqInfo.library);
  if (libraryOption) {
    libraryOption.disabled = !libraryAvailable;
    const suffix = libraryAvailable ? '' : ' — not installed';
    libraryOption.textContent = 'turboquant-py library (QJL)' + suffix;
  }
  if (backendHint) {
    if (libraryAvailable) {
      backendHint.textContent = 'turboquant-py detected.';
    } else {
      backendHint.textContent = 'Install "turboquant-py[torch]" to enable the library backend.';
    }
  }

  const cfg = t || {};
  document.getElementById('cfg-tq-enabled').checked = !!cfg.enabled;
  // A saved "library" choice falls back to "builtin" when the library is unavailable.
  const useLibrary = cfg.backend === 'library' && libraryAvailable;
  document.getElementById('cfg-tq-backend').value = useLibrary ? 'library' : 'builtin';
  document.getElementById('cfg-tq-bits').value = String(cfg.bits || 8);
  onTurboQuantToggle();
}
// Read the TurboQuant controls back into a config object.
// Returns null when TurboQuant is disabled, else {enabled, backend, bits}.
function _collectTurboQuant(){
  const byId = id => document.getElementById(id);
  if (!byId('cfg-tq-enabled').checked) return null;
  const backend = byId('cfg-tq-backend').value || 'builtin';
  const bits = parseInt(byId('cfg-tq-bits').value) || 8;
  return {enabled: true, backend: backend, bits: bits};
}
function _updatePreloadAllVisibility() {
const loadMode = document.getElementById('cfg-load-mode').value;
const maxInst = parseInt(document.getElementById('cfg-max-instances').value) || 1;
......@@ -2834,6 +2939,8 @@ async function saveModelConfig(){
preload_all_instances: document.getElementById('cfg-preload-all-instances').checked,
n_gpu_layers: parseInt(document.getElementById('cfg-gpu-layers').value) || -1,
n_ctx: parseInt(document.getElementById('cfg-n-ctx').value) || 2048,
cache_type_k: document.getElementById('cfg-cache-type-k').value || null,
cache_type_v: document.getElementById('cfg-cache-type-v').value || null,
max_gpu_percent: isNaN(maxGpu) ? null : maxGpu,
manual_ram_gb: isNaN(ramGb) ? null : ramGb,
load_in_4bit: document.getElementById('cfg-4bit').checked,
......@@ -2866,6 +2973,7 @@ async function saveModelConfig(){
balanced_gpu_percent: (document.getElementById('cfg-balanced-gpu-pct').value.trim() === ''
? null : parseFloat(document.getElementById('cfg-balanced-gpu-pct').value)),
acceleration: _collectAccel(),
turboquant: _collectTurboQuant(),
};
try{
const r = await fetch(ROOT_PATH + '/admin/api/model-configure',{
......
......@@ -153,6 +153,24 @@
</div>
</div>
<!-- Background jobs -->
<div class="card mb-0" style="margin-top:1rem">
<div class="card-title">Background Jobs</div>
<span class="form-hint" style="display:block;margin-bottom:.75rem">
Controls how interrupted LoRA training is handled when CoderAI restarts.
Equivalent to the <code>--no-resume-jobs</code> launch flag.
</span>
<div class="form-row" style="margin:0">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer">
<input type="checkbox" id="s-jobs-resume">
<span style="font-size:13px;font-weight:500">Resume interrupted training on restart</span>
</label>
<span class="form-hint">When off, a training job that was running at restart is marked
<em>cancelled</em> instead of resuming. Its checkpoint is kept, so you can still
restart it manually from the Tasks page.</span>
</div>
</div>
<div class="card mb-0" style="margin-top:1rem">
<div class="card-title">AISBF Broker</div>
<div class="form-row">
......@@ -328,6 +346,9 @@ async function loadSettings(){
document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87;
document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5;
toggleThermalFields();
// Background jobs
const jobs = d.jobs || {};
document.getElementById('s-jobs-resume').checked = jobs.resume_on_restart !== false;
}catch(e){ showAlert('error','Failed to load settings: '+e.message); }
}
......@@ -363,6 +384,9 @@ async function saveSettings(){
cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87,
poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5,
},
jobs:{
resume_on_restart: document.getElementById('s-jobs-resume').checked,
},
broker:{
enabled: document.getElementById('s-broker-enabled').checked,
base_url: document.getElementById('s-broker-base-url').value.trim(),
......
{% extends "base.html" %}
{% block title %}Tasks — CoderAI{% endblock %}
{% block content %}
<div class="page-header">
<div>
<h1>Tasks</h1>
<p>Live view of generations and LoRA training. Cancel, interrupt, or restart a job.</p>
</div>
<div class="header-actions">
<span id="queue-summary" class="dim small"></span>
</div>
</div>
<div id="thermal-banner" style="display:none;margin:0 0 1rem;padding:.6rem .85rem;border-radius:8px;
background:rgba(245,158,11,.12);border:1px solid rgba(245,158,11,.4);color:#f59e0b;font-size:13px">
<span style="font-weight:600">❄ Thermal cooldown</span>
<span id="thermal-banner-msg" class="mono"></span>
— running work is paused until the hardware cools.
</div>
<!-- Live hardware telemetry -->
<div id="sys-stats" style="display:grid;grid-template-columns:repeat(auto-fit,minmax(220px,1fr));
gap:.75rem;margin:0 0 1.25rem">
<div class="sys-tile" id="tile-cpu"></div>
<div class="sys-tile" id="tile-gpu"></div>
<div class="sys-tile" id="tile-ram"></div>
<div class="sys-tile" id="tile-vram"></div>
</div>
<style>
.sys-tile{border:1px solid var(--border,#2a2a2a);border-radius:10px;padding:.7rem .85rem;
background:var(--card-bg,rgba(255,255,255,.02))}
.sys-tile .sys-head{display:flex;justify-content:space-between;align-items:baseline;margin-bottom:.45rem}
.sys-tile .sys-name{font-size:12px;font-weight:600;letter-spacing:.03em;text-transform:uppercase;color:var(--text-muted,#9aa0a6)}
.sys-tile .sys-val{font-size:13px;font-weight:600}
.sys-tile .sys-sub{font-size:11px;color:var(--text-muted,#9aa0a6);margin-top:.35rem;display:flex;justify-content:space-between}
.sys-bar{height:8px;border-radius:5px;background:rgba(127,127,127,.18);overflow:hidden}
.sys-bar > span{display:block;height:100%;border-radius:5px;transition:width .4s ease,background .4s ease}
.sys-ok > span{background:#22c55e}.sys-warn > span{background:#f59e0b}.sys-hot > span{background:#ef4444}
.sys-temp-ok{color:#22c55e}.sys-temp-warn{color:#f59e0b}.sys-temp-hot{color:#ef4444}
</style>
<div class="table-wrap">
<table>
<thead>
<tr>
<th>Type</th><th>Name / Model</th><th>Status</th>
<th style="width:220px">Progress</th><th>Started</th><th style="text-align:right">Actions</th>
</tr>
</thead>
<tbody id="tasks-body">
<tr class="empty-row"><td colspan="6">No tasks yet</td></tr>
</tbody>
</table>
</div>
{% endblock %}
{% block scripts %}
<script>
// Escape a value for interpolation into HTML text or a quoted attribute value.
// null/undefined become ''. Any other value is stringified first.
// Quotes are escaped too, so the result cannot terminate a "..." or '...' attribute.
function esc(s) {
  const HTML_ESCAPES = {'&':'&amp;', '<':'&lt;', '>':'&gt;', '"':'&quot;', "'":'&#39;'};
  return String(s == null ? '' : s).replace(/[&<>"']/g, ch => HTML_ESCAPES[ch]);
}
// Format a timestamp as a local HH:MM:SS string.
// `s` is unix seconds (float) from the server. Returns '' for a missing/zero
// value and for anything that does not convert to a valid Date.
function fmtTime(s) {
  if (!s) return '';
  const d = new Date(s * 1000);
  // new Date() never throws on bad input; it yields an invalid Date whose
  // toLocaleTimeString() is the literal "Invalid Date". Reject it explicitly.
  if (Number.isNaN(d.getTime())) return '';
  try {
    return d.toLocaleTimeString(undefined, {hour:'2-digit', minute:'2-digit', second:'2-digit'});
  } catch { return ''; }
}
// Display label for each task kind; loadTasks falls back to the raw kind when a key is missing.
const KIND_LABEL = {training:'Training', image:'Image', video:'Video', audio:'Audio', text:'Text', pipeline:'Pipeline', request:'Request'};
// CSS badge class for each task status; loadTasks falls back to 'badge-dim' for unknown statuses.
const STATUS_BADGE = {
running:'badge-admin', queued:'badge-user', done:'badge-ok', error:'badge-err',
cancelled:'badge-user', interrupted:'badge-warn'
};
// Render the Progress cell for task `t`.
// With a known total: a filled bar plus "step/total (pct%)", pct clamped to 0..100.
// Without one: "working…" while running, otherwise an em dash.
function progressBar(t) {
  const done = t.step || 0;
  const outOf = t.total || 0;
  if (outOf) {
    const pct = Math.max(0, Math.min(100, Math.round(done / outOf * 100)));
    const bar = `<div class="progress"><div class="progress-fill" style="width:${pct}%"></div></div>`;
    const caption = `<span class="dim small">${done}/${outOf} (${pct}%)</span>`;
    return `${bar}\n${caption}`;
  }
  if (t.status === 'running') {
    return '<span class="dim small">working…</span>';
  }
  return '<span class="dim small">—</span>';
}
// Build the Actions cell for task `t`.
// Which buttons appear is driven by the task flags:
//   paused -> Resume, else pausable -> Pause
//   cancellable -> Interrupt (while running) or Cancel
//   restartable -> Restart
//   not active -> Remove
// Returns an em-dash placeholder when no button applies.
function actions(t) {
  // The id ends up inside a JS string inside a double-quoted HTML attribute.
  // HTML escaping alone is not enough there: the browser decodes entities
  // before the handler is parsed, so a quote or backslash in the id would
  // break out of the string. Emit a JSON string literal (JS-safe), then
  // escape that literal for the attribute (HTML-safe).
  const idLiteral = JSON.stringify(String(t.id == null ? '' : t.id))
    .replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;').replace(/'/g, '&#39;');
  const button = (cls, call, label) =>
    `<button class="btn ${cls} btn-sm" onclick="${call}">${label}</button>`;

  const btns = [];
  if (t.paused) {
    btns.push(button('btn-primary', `taskAction(${idLiteral},'resume')`, 'Resume'));
  } else if (t.pausable) {
    btns.push(button('btn-ghost', `taskAction(${idLiteral},'pause')`, 'Pause'));
  }
  if (t.cancellable) {
    const running = t.status === 'running';
    const label = running ? 'Interrupt' : 'Cancel';
    const act = running ? 'interrupt' : 'cancel';
    btns.push(button('btn-danger', `taskAction(${idLiteral},'${act}')`, label));
  }
  if (t.restartable) {
    btns.push(button('btn-ghost', `taskAction(${idLiteral},'restart')`, 'Restart'));
  }
  if (!t.active) {
    btns.push(button('btn-ghost', `removeTask(${idLiteral})`, 'Remove'));
  }
  return btns.join(' ') || '<span class="dim small">—</span>';
}
// ---- Live hardware telemetry ----
// Bar colour class for a utilization percentage: >=90 hot, >=70 warn, else ok.
// A missing value (null/undefined) is treated as ok.
function _utilClass(pct){
  if (pct == null) return 'sys-ok';
  if (pct >= 90) return 'sys-hot';
  if (pct >= 70) return 'sys-warn';
  return 'sys-ok';
}
// Text colour class for a temperature reading: >=90 hot, >=80 warn, else ok.
// A missing value (null/undefined) yields no class at all.
function _tempClass(t){
  if (t == null) return '';
  if (t >= 90) return 'sys-temp-hot';
  if (t >= 80) return 'sys-temp-warn';
  return 'sys-temp-ok';
}
// Render a horizontal usage bar. Width is `pct` clamped to 0..100 (0 when missing);
// the colour class comes from _utilClass(pct).
function _bar(pct){
  let width = 0;
  if (pct != null) {
    width = Math.max(0, Math.min(100, pct));
  }
  const colour = _utilClass(pct);
  return `<div class="sys-bar ${colour}"><span style="width:${width}%"></span></div>`;
}
// Inner HTML for a utilization tile (CPU/GPU): header with rounded percent,
// usage bar, and a footer with the temperature. Missing values show as "n/a".
function _utilTile(name, pct, temp){
  const valTxt = pct == null ? 'n/a' : `${Math.round(pct)}%`;
  let tempTxt;
  if (temp == null) {
    tempTxt = '<span class="dim">temp n/a</span>';
  } else {
    tempTxt = `<span class="${_tempClass(temp)}">${Math.round(temp)}°C</span>`;
  }
  const head = `<div class="sys-head"><span class="sys-name">${name}</span><span class="sys-val">${valTxt}</span></div>`;
  const foot = `<div class="sys-sub"><span>utilization</span>${tempTxt}</div>`;
  return head + _bar(pct) + foot;
}
// Inner HTML for a memory tile (RAM/VRAM): header with "used / total GB",
// usage bar, and a footer with the rounded percent used.
// When `pct` is missing it is derived from used/total if both are usable.
function _memTile(name, used, total, pct){
  const haveAmounts = !(used == null || total == null);
  const valTxt = haveAmounts ? `${used.toFixed(1)} / ${total.toFixed(1)} GB` : 'n/a';
  let p;
  if (pct != null) {
    p = pct;
  } else if (used != null && total) {
    p = used / total * 100;
  } else {
    p = null;
  }
  const usedTxt = p == null ? '' : Math.round(p) + '% used';
  const head = `<div class="sys-head"><span class="sys-name">${name}</span><span class="sys-val">${valTxt}</span></div>`;
  const foot = `<div class="sys-sub"><span>${usedTxt}</span><span></span></div>`;
  return head + _bar(p) + foot;
}
// Fetch /admin/api/system-stats and repaint the four telemetry tiles.
// On any failure (network error, non-2xx response, unparsable body) the
// previously rendered tiles are left untouched.
async function loadSystemStats(){
  try {
    const r = await fetch(ROOT_PATH + '/admin/api/system-stats');
    // An error response can still carry a JSON body; without this check it
    // would be rendered as four "n/a" tiles instead of keeping the last render.
    if (!r.ok) throw new Error('HTTP ' + r.status);
    const s = (await r.json()) || {};
    const cpu = s.cpu || {}, gpu = s.gpu || {}, ram = s.ram || {}, vram = s.vram || {};
    document.getElementById('tile-cpu').innerHTML = _utilTile('CPU', cpu.util, cpu.temp);
    document.getElementById('tile-gpu').innerHTML = _utilTile('GPU', gpu.util, gpu.temp);
    document.getElementById('tile-ram').innerHTML = _memTile('RAM', ram.used, ram.total, ram.percent);
    document.getElementById('tile-vram').innerHTML =
      _memTile('VRAM', vram.used, vram.total, vram.percent);
  } catch(e){ /* keep last render on transient errors */ }
}
// Guards loadTasks against overlapping runs (it is polled on a timer and also
// called after every task action).
let _refreshing = false;
// Fetch /admin/api/tasks and repaint the queue summary, the thermal banner and
// the task table. On any failure (network error, non-2xx response, unparsable
// body) the previous render is kept.
async function loadTasks() {
  if (_refreshing) return;
  _refreshing = true;
  try {
    const r = await fetch(ROOT_PATH + '/admin/api/tasks');
    // An error response can still carry a JSON body; without this check it
    // would be rendered as "No tasks yet" with a zeroed queue summary.
    if (!r.ok) throw new Error('HTTP ' + r.status);
    const data = (await r.json()) || {};
    const tasks = data.tasks || [];
    const q = data.queue || {};
    document.getElementById('queue-summary').textContent =
      `${q.active || 0} active · ${q.waiting || 0} waiting · max ${q.max_parallel_requests || 0} parallel`;
    const therm = data.thermal || {};
    const banner = document.getElementById('thermal-banner');
    if (therm.active) {
      document.getElementById('thermal-banner-msg').textContent = ' ' + (therm.message || '');
      banner.style.display = '';
    } else {
      banner.style.display = 'none';
    }
    const tbody = document.getElementById('tasks-body');
    if (!tasks.length) {
      tbody.innerHTML = '<tr class="empty-row"><td colspan="6">No tasks yet</td></tr>';
      return;
    }
    tbody.innerHTML = tasks.map(t => {
      const badge = STATUS_BADGE[t.status] || 'badge-dim';
      const title = t.title || '(untitled)';
      // Cooling takes precedence over paused, which takes precedence over the plain status.
      let statusCell;
      if (t.cooling) {
        statusCell = `<span class="badge badge-warn">❄ Cooling down</span>`
          + `<div class="dim small">${esc(t.cooling_message || 'paused for thermal cooldown')}</div>`;
      } else if (t.paused) {
        statusCell = `<span class="badge badge-warn">⏸ Paused</span>`
          + `<div class="dim small">suspended — click Resume to continue</div>`;
      } else {
        statusCell = `<span class="badge ${badge}">${esc(t.status)}</span>`
          + (t.message ? `<div class="dim small">${esc(t.message)}</div>` : '');
      }
      return `<tr>
<td><span class="badge badge-user">${esc(KIND_LABEL[t.kind] || t.kind)}</span></td>
<td><div class="td-name">${esc(title)}</div><div class="dim small mono">${esc(t.model || '')}</div></td>
<td>${statusCell}</td>
<td>${progressBar(t)}</td>
<td class="dim small">${fmtTime(t.started_at)}</td>
<td style="text-align:right">${actions(t)}</td>
</tr>`;
    }).join('');
  } catch (e) {
    // transient fetch errors during a model swap are fine; keep last render.
  } finally {
    _refreshing = false;
  }
}
// POST a task action (cancel / interrupt / restart / pause / resume) for task `id`,
// surface any failure via alert(), then refresh the table.
async function taskAction(id, action) {
  const VERBS = {cancel:'Cancel', interrupt:'Interrupt', restart:'Restart', pause:'Pause', resume:'Resume'};
  const verb = VERBS[action] || action;
  // Only confirm destructive actions; pause/resume/restart act immediately.
  const destructive = action === 'cancel' || action === 'interrupt';
  if (destructive && !confirm(`${verb} this task?`)) return;
  try {
    const url = ROOT_PATH + '/admin/api/tasks/' + encodeURIComponent(id) + '/' + action;
    const resp = await fetch(url, {method:'POST'});
    if (!resp.ok) {
      // The error body may not be JSON; fall back to a generic message.
      const body = await resp.json().catch(() => ({}));
      alert(body.detail || (verb + ' failed'));
    }
  } catch (err) {
    alert(err.message);
  }
  loadTasks();
}
// DELETE task `id` from the list, surface any failure via alert(), then refresh the table.
async function removeTask(id) {
  try {
    const url = ROOT_PATH + '/admin/api/tasks/' + encodeURIComponent(id);
    const resp = await fetch(url, {method:'DELETE'});
    if (!resp.ok) {
      // The error body may not be JSON; fall back to a generic message.
      const body = await resp.json().catch(() => ({}));
      alert(body.detail || 'Remove failed');
    }
  } catch (err) {
    alert(err.message);
  }
  loadTasks();
}
// Paint once immediately, then re-poll both endpoints every 2000 ms.
// loadTasks skips a tick while a previous run is still in flight (see _refreshing).
loadTasks();
loadSystemStats();
setInterval(loadTasks, 2000);
setInterval(loadSystemStats, 2000);
</script>
{% endblock %}
......@@ -189,25 +189,25 @@ if admin_static_dir.exists():
app.mount("/static/admin", StaticFiles(directory=str(admin_static_dir)), name="admin_static")
# Include routers from submodules
app.include_router(transcriptions_router)
app.include_router(images_router)
app.include_router(tts_router)
app.include_router(text_router)
app.include_router(video_router)
app.include_router(audio_gen_router)
app.include_router(audio_stems_router)
app.include_router(audio_clean_router)
app.include_router(embeddings_router)
app.include_router(pipelines_router)
app.include_router(custom_pipelines_router)
app.include_router(voice_clone_router)
app.include_router(voice_convert_router)
app.include_router(faceswap_router)
app.include_router(characters_router)
app.include_router(loras_router)
app.include_router(environments_router)
app.include_router(spatial_router)
app.include_router(admin_router)
app.include_router(transcriptions_router, tags=["Audio"])
app.include_router(images_router, tags=["Images"])
app.include_router(tts_router, tags=["Audio"])
app.include_router(text_router, tags=["Text"])
app.include_router(video_router, tags=["Video"])
app.include_router(audio_gen_router, tags=["Audio"])
app.include_router(audio_stems_router, tags=["Audio"])
app.include_router(audio_clean_router, tags=["Audio"])
app.include_router(embeddings_router, tags=["Embeddings"])
app.include_router(pipelines_router, tags=["Pipelines"])
app.include_router(custom_pipelines_router, tags=["Pipelines"])
app.include_router(voice_clone_router, tags=["Audio"])
app.include_router(voice_convert_router, tags=["Audio"])
app.include_router(faceswap_router, tags=["Images"])
app.include_router(characters_router, tags=["Characters"])
app.include_router(loras_router, tags=["LoRAs"])
app.include_router(environments_router, tags=["Environments"])
app.include_router(spatial_router, tags=["Spatial / 3D"])
app.include_router(admin_router, tags=["Admin"])
@app.exception_handler(401)
......@@ -222,20 +222,35 @@ async def unauthorized_redirect(request: Request, exc: HTTPException):
return JSONResponse(status_code=401, content={"detail": exc.detail})
@app.get("/v1/models", response_model=ModelList)
from codai.tasks import TaskCancelled, task_registry
@app.exception_handler(TaskCancelled)
async def task_cancelled_handler(request: Request, exc: TaskCancelled):
    """Turn a propagated TaskCancelled into an HTTP 499 response.

    The task id is read from the exception's first argument, so a worker can
    simply `raise` without doing its own bookkeeping. When an id is present
    the task is finished in the registry as "cancelled"; the id (or None) is
    echoed back in the JSON body.
    """
    task_id = exc.args[0] if exc.args else None
    if task_id:
        task_registry.finish(task_id, "cancelled", "cancelled by user")
    payload = {"detail": "Task cancelled", "task_id": task_id}
    return JSONResponse(status_code=499, content=payload)
@app.get("/v1/models", response_model=ModelList, summary="List available models", tags=["Core"])
async def list_models():
    """Return the models reported by the multi-model manager, wrapped in a ModelList."""
    return ModelList(data=multi_model_manager.list_models())
@app.get("/coderai/capabilities")
@app.get("/coderai/capabilities", summary="Server capability document", tags=["Core"])
async def get_broker_capabilities():
    """Return the capability document built from the current hardware summary."""
    hardware_summary = build_hardware_summary()
    return build_capabilities_document(hardware=hardware_summary)
@app.get("/v1/files/{filename}")
@app.get("/v1/files/{filename}", summary="Download a generated file", tags=["Files"])
async def get_file(filename: str):
"""Serve uploaded/generated files."""
if not global_file_path:
......@@ -256,7 +271,7 @@ _VIDEO_EXTS = {'.mp4', '.webm', '.avi', '.mov'}
_AUDIO_EXTS = {'.wav', '.mp3', '.ogg', '.flac', '.aac', '.m4a'}
@app.get("/v1/archive")
@app.get("/v1/archive", summary="List archived generations", tags=["Files"])
async def list_archive(request: Request):
"""List all generated files in the output directory."""
if not global_file_path or not os.path.isdir(global_file_path):
......@@ -292,7 +307,7 @@ async def list_archive(request: Request):
return {"files": files}
@app.delete("/v1/archive/{filename}")
@app.delete("/v1/archive/{filename}", summary="Delete an archived file", tags=["Files"])
async def delete_archive_file(filename: str):
"""Delete a generated file from the output directory."""
if not global_file_path:
......
......@@ -116,8 +116,15 @@ class AudioCleanupRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/cleanup")
@router.post("/v1/audio/cleanup", summary="Clean / restore audio")
async def cleanup_audio(request: AudioCleanupRequest, http_request: Request = None):
"""Restore/clean a noisy audio clip.
Applies any combination of noise reduction, loudness normalization, mains-hum
removal and click/crackle repair. Uses an ML restoration backend when available,
falling back to an ffmpeg-based best-effort path when `fallback_mode` is set.
Returns the cleaned audio plus the backend and quality tier that were used.
"""
try:
audio_bytes = _decode_audio(request.audio)
except Exception as exc:
......
......@@ -31,6 +31,7 @@ from fastapi import APIRouter, HTTPException, Request
from codai.models.manager import multi_model_manager
from codai.pydantic.audiogenrequest import AudioGenerationRequest, AudioGenerationResponse
from codai.tasks import task_registry, TaskCancelled
router = APIRouter()
......@@ -160,7 +161,7 @@ def _detect_audio_gen_type(model_name: str) -> str:
return 'musicgen'
def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest, task_id=None):
"""Run generation and return (audio_bytes, ext)."""
import numpy as np, io as _io
......@@ -191,6 +192,9 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
_aud_progress_reset(num_steps, unit="it")
def _aud_step_cb(pipe, step_index, timestep, callback_kwargs):
task_registry.raise_if_cancelled(task_id)
task_registry.wait_if_paused(task_id)
task_registry.step(task_id, step_index + 1)
_aud_progress_step(step_index + 1)
return callback_kwargs
......@@ -222,7 +226,7 @@ def _decode_b64_or_url(data: str) -> bytes:
return base64.b64decode(data)
@router.get("/v1/audio/progress")
@router.get("/v1/audio/progress", summary="Audio generation progress")
async def get_audio_progress():
"""Return current audio generation progress including speed."""
elapsed = time.monotonic() - _aud_progress["started_at"] if _aud_progress["active"] else 0.0
......@@ -241,7 +245,7 @@ async def get_audio_progress():
}
@router.post("/v1/audio/generate", response_model=AudioGenerationResponse)
@router.post("/v1/audio/generate", response_model=AudioGenerationResponse, summary="Generate audio, music or SFX")
async def audio_generate(request: AudioGenerationRequest, http_request: Request = None):
"""
Generate music, sound effects, or ambient audio.
......@@ -274,14 +278,22 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
multi_model_manager.models[model_key] = pipe
multi_model_manager.current_model_key = model_key
_tid = task_registry.register(
"audio", title=(request.prompt or "")[:80], model=model_name or "")
task_registry.start(_tid)
try:
audio_bytes, ext = await asyncio.get_event_loop().run_in_executor(
None, _generate_audio, pipe, model_name, request)
None, _generate_audio, pipe, model_name, request, _tid)
except TaskCancelled:
_aud_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_aud_progress_done()
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
finally:
_aud_progress_done()
task_registry.finish(_tid, "done")
result = _save_audio_response(audio_bytes, ext, http_request)
......
......@@ -166,8 +166,15 @@ class AudioStemRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/stems")
@router.post("/v1/audio/stems", summary="Separate audio into stems")
async def separate_stems(request: AudioStemRequest, http_request: Request = None):
"""Split a track into its component stems (source separation).
Separates an input clip according to `stem_mode` (e.g. vocals/instrumental, or a
full 4-stem split). Uses an ML separation provider when available, falling back to
an ffmpeg-based best-effort split when `fallback_mode` is set. Returns one audio
output per stem along with the backend and quality tier used.
"""
try:
audio_bytes = _decode_audio(request.audio)
except Exception as exc:
......
......@@ -419,7 +419,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/characters")
@router.post("/v1/characters", summary="Create or replace a character profile")
async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named character profile."""
if not req.name or '/' in req.name or '..' in req.name:
......@@ -430,13 +430,13 @@ async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_a
return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/characters")
@router.get("/v1/characters", summary="List character profiles")
async def list_characters(_auth=Depends(_require_api_auth)):
    """List all saved character profiles (metadata only, no images)."""
    profiles = _list_characters()
    return {"characters": profiles}
@router.get("/v1/characters/{name}")
@router.get("/v1/characters/{name}", summary="Get a character profile")
async def get_character(name: str, _auth=Depends(_require_api_auth)):
"""Get a character profile including its reference images as base64."""
meta = _load_character_meta(name)
......@@ -452,7 +452,7 @@ async def get_character(name: str, _auth=Depends(_require_api_auth)):
}
@router.delete("/v1/characters/{name}")
@router.delete("/v1/characters/{name}", summary="Delete a character profile")
async def delete_character(name: str, _auth=Depends(_require_api_auth)):
"""Delete a character profile."""
cdir = _char_dir(name)
......@@ -463,7 +463,7 @@ async def delete_character(name: str, _auth=Depends(_require_api_auth)):
return {"ok": True, "name": name}
@router.patch("/v1/characters/{name}")
@router.patch("/v1/characters/{name}", summary="Update a character profile")
async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)):
"""Update a character profile: description, add images, or remove images by index."""
meta = _load_character_meta(name)
......@@ -512,7 +512,7 @@ async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_
return {"ok": True, "name": name, "image_count": meta['image_count']}
@router.post("/v1/characters/generate")
@router.post("/v1/characters/generate", summary="Generate character reference images")
async def generate_character(req: CharacterGenerateRequest, request: Request):
"""
Generate a character profile from a text prompt.
......@@ -585,7 +585,7 @@ async def generate_character(req: CharacterGenerateRequest, request: Request):
return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]}
@router.post("/v1/characters/extract")
@router.post("/v1/characters/extract", summary="Extract a character from media")
async def extract_character(req: CharacterExtractRequest):
"""
Extract a character profile from source images and/or videos.
......
......@@ -380,7 +380,7 @@ class AudioMusicDubRequest(BaseModel):
model_config = ConfigDict(extra='allow')
@router.get('/v1/pipelines/custom')
@router.get('/v1/pipelines/custom', summary="List saved custom pipelines")
async def list_custom_pipelines():
"""List all saved custom pipeline definitions."""
from codai.admin.routes import config_manager
......@@ -389,7 +389,7 @@ async def list_custom_pipelines():
return {'pipelines': config_manager.pipelines_data}
@router.get('/v1/pipelines/step-types')
@router.get('/v1/pipelines/step-types', summary="List available pipeline step types")
async def list_step_types():
"""List available step types with their parameter schemas."""
return {
......@@ -400,7 +400,7 @@ async def list_step_types():
}
@router.post('/v1/pipelines/custom')
@router.post('/v1/pipelines/custom', summary="Create a custom pipeline")
async def create_custom_pipeline(pipeline: PipelineDefinition):
"""Save a new custom pipeline definition."""
from codai.admin.routes import config_manager
......@@ -416,7 +416,7 @@ async def create_custom_pipeline(pipeline: PipelineDefinition):
return {'created': True, 'pipeline': data}
@router.put('/v1/pipelines/custom/{pipeline_id}')
@router.put('/v1/pipelines/custom/{pipeline_id}', summary="Update a custom pipeline")
async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition):
"""Update an existing custom pipeline."""
from codai.admin.routes import config_manager
......@@ -433,7 +433,7 @@ async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition)
return {'updated': True, 'pipeline': data}
@router.delete('/v1/pipelines/custom/{pipeline_id}')
@router.delete('/v1/pipelines/custom/{pipeline_id}', summary="Delete a custom pipeline")
async def delete_custom_pipeline(pipeline_id: str):
"""Delete a custom pipeline."""
from codai.admin.routes import config_manager
......@@ -447,7 +447,7 @@ async def delete_custom_pipeline(pipeline_id: str):
return {'deleted': True, 'id': pipeline_id}
@router.post('/v1/pipelines/custom/{pipeline_id}/run')
@router.post('/v1/pipelines/custom/{pipeline_id}/run', summary="Run a saved custom pipeline")
async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_request: Request = None):
"""Execute a saved custom pipeline."""
from codai.admin.routes import config_manager
......@@ -459,14 +459,20 @@ async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_r
return await _execute_pipeline(pipeline_def, body.input or '', http_request)
@router.post('/v1/pipelines/run', summary="Run an inline pipeline definition")
async def run_inline_pipeline(pipeline: PipelineDefinition, http_request: Request = None):
    """Execute an inline pipeline definition without saving it.

    The route is registered exactly once; stacking a second ``@router.post``
    for the same path would add a duplicate route entry for this handler.

    Args:
        pipeline: The pipeline definition supplied in the request body. It is
            converted to a plain dict with ``model_dump()`` before execution.
        http_request: The incoming request, forwarded unchanged to
            ``_execute_pipeline``.

    Returns:
        Whatever ``_execute_pipeline`` returns for this definition.
    """
    # Inline runs carry no separate input payload, so an empty string is
    # passed in the input position.
    return await _execute_pipeline(pipeline.model_dump(), '', http_request)
@router.post('/v1/pipelines/audio-understand')
@router.post('/v1/pipelines/audio-understand', summary="Transcribe and analyze audio")
async def run_audio_understanding(request: AudioUnderstandRequest, http_request: Request = None):
"""Transcribe and analyze an audio clip in one pass.
Convenience pipeline that transcribes the input audio and then reasons over the
transcript (summary/understanding) using the configured text model. Returns the
transcript together with the model's analysis.
"""
if not request.audio:
raise HTTPException(status_code=400, detail='Provide audio input')
......@@ -543,8 +549,14 @@ async def run_full_music_dub(request: AudioMusicDubRequest, http_request: Reques
}
@router.post('/v1/pipelines/audio-music-dub')
@router.post('/v1/pipelines/audio-music-dub', summary="Dub a song into another language")
async def run_audio_music_dub(request: AudioMusicDubRequest, http_request: Request = None):
"""Dub a song into another language while preserving the backing music.
Splits the track into vocals and instrumental, transcribes and translates the
lyrics, re-sings/voice-converts the translated vocals, then remixes them over the
original instrumental. Returns every intermediate stem plus the final mixed result.
"""
if not request.audio:
raise HTTPException(status_code=400, detail='Provide audio input')
......
......@@ -101,7 +101,7 @@ def _embed_texts(model_obj, texts: List[str], dimensions=None) -> List[List[floa
return results
@router.post("/v1/embeddings", response_model=EmbeddingsResponse)
@router.post("/v1/embeddings", response_model=EmbeddingsResponse, summary="Create embeddings")
async def create_embeddings(request: EmbeddingsRequest, http_request: Request = None):
"""
OpenAI-compatible embeddings endpoint.
......@@ -116,10 +116,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
model_key = model_info['model_key']
model_obj = model_info.get('model_object')
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}")
or multi_model_manager.config.get(model_name) or {})
if model_obj is None:
device = _derive_device()
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}")
or multi_model_manager.config.get(model_name) or {})
try:
model_obj = await asyncio.get_event_loop().run_in_executor(
None, _load_embedding_model, model_name, device, _emb_cfg)
......@@ -136,7 +137,59 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
except Exception as e:
raise HTTPException(status_code=500, detail=f"Embedding failed: {e}")
if request.encoding_format == 'base64':
# Optional TurboQuant vector quantization (data-free, inner-product preserving).
# The per-model config block (turboquant: {enabled, backend, bits}) is the
# source of truth for enable/disable + which implementation to use; the
# per-request `quantization` field triggers it and can override the bit width.
from codai.models import turboquant as _tq
_raw = _emb_cfg.get('_raw_cfg') if isinstance(_emb_cfg.get('_raw_cfg'), dict) else {}
tq_cfg = _emb_cfg.get('turboquant') or _raw.get('turboquant') or {}
tq_enabled = tq_cfg.get('enabled', None) # None = no explicit model setting
tq_backend = (tq_cfg.get('backend') or 'builtin')
quant_meta = None
quant_bits = None
req_spec = getattr(request, 'quantization', None)
if not req_spec and tq_enabled and tq_cfg.get('bits'):
req_spec = f"turbo{tq_cfg.get('bits')}" # model-configured default
if req_spec:
if tq_enabled is False:
raise HTTPException(
status_code=400,
detail="TurboQuant is disabled for this model (enable it in the "
"model configuration).")
quant_bits = _tq._parse_quant_spec(req_spec)
if quant_bits is None:
raise HTTPException(
status_code=400,
detail=f"Unsupported quantization '{req_spec}' "
"(use 'turbo'/'turbo8'/'turbo6'/'turbo4'/'turbo2')")
if quant_bits is not None and request.encoding_format == 'base64':
# Compact wire form: each embedding is base64 of [f16 norm][packed codes].
# The compact packing is the built-in wire format regardless of backend
# (the upstream library exposes its own opaque store, not per-vector blobs).
blobs, meta = await asyncio.get_event_loop().run_in_executor(
None, _tq.quantize_base64, vectors, quant_bits)
data = [EmbeddingObject(index=i, embedding=b) for i, b in enumerate(blobs)]
quant_meta = {
"method": meta.method, "bits": meta.bits, "seed": meta.seed,
"dim": meta.dim, "dim_padded": meta.dim_padded, "radius": meta.radius,
"bytes_per_vector": meta.bytes_per_vector, "backend": "builtin",
"layout": "base64([float16 norm][packbits(rotated b-bit codes, MSB-first per numpy.packbits)])",
}
elif quant_bits is not None:
# Lossy reconstruction returned as plain floats (quantized-store fidelity).
try:
vectors = await asyncio.get_event_loop().run_in_executor(
None, lambda: _tq.reconstruct(vectors, quant_bits, backend=tq_backend))
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]
eff_backend = tq_backend if tq_backend != 'auto' else _tq.backend_name()
quant_meta = {"method": "turboquant", "bits": quant_bits,
"encoding": "float-reconstruction", "backend": eff_backend}
elif request.encoding_format == 'base64':
import struct
data = [EmbeddingObject(
index=i,
......@@ -146,8 +199,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]
total_tokens = sum(len(t.split()) for t in texts)
return EmbeddingsResponse(
resp = EmbeddingsResponse(
data=data,
model=request.model,
usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
)
\ No newline at end of file
)
if quant_meta is not None:
resp.quantization = quant_meta
return resp
\ No newline at end of file
......@@ -307,7 +307,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/environments")
@router.post("/v1/environments", summary="Create or replace an environment profile")
async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named environment profile."""
if not req.name or '/' in req.name or '..' in req.name:
......@@ -318,13 +318,13 @@ async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_a
return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/environments", summary="List environment profiles")
async def list_environments(_auth=Depends(_require_api_auth)):
    """List all saved environment profiles (metadata only).

    The route is registered exactly once; stacking a second ``@router.get``
    for the same path would add a duplicate route entry for this handler.

    Args:
        _auth: Result of the ``_require_api_auth`` dependency; unused in the
            body, declared only so the dependency runs for this endpoint.

    Returns:
        A dict with one key, ``"environments"``, holding whatever
        ``_list_environments()`` returns.
    """
    return {"environments": _list_environments()}
@router.get("/v1/environments/{name}")
@router.get("/v1/environments/{name}", summary="Get an environment profile")
async def get_environment(name: str, _auth=Depends(_require_api_auth)):
"""Get an environment profile including its reference images as base64."""
meta = _load_environment_meta(name)
......@@ -340,7 +340,7 @@ async def get_environment(name: str, _auth=Depends(_require_api_auth)):
}
@router.delete("/v1/environments/{name}")
@router.delete("/v1/environments/{name}", summary="Delete an environment profile")
async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
"""Delete an environment profile."""
edir = _env_dir(name)
......@@ -351,7 +351,7 @@ async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
return {"ok": True, "name": name}
@router.patch("/v1/environments/{name}")
@router.patch("/v1/environments/{name}", summary="Update an environment profile")
async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)):
"""Update an environment profile: description, add images, or remove images by index."""
meta = _load_environment_meta(name)
......@@ -398,7 +398,7 @@ async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depen
return {"ok": True, "name": name, "image_count": meta['image_count']}
@router.post("/v1/environments/generate")
@router.post("/v1/environments/generate", summary="Generate environment reference images")
async def generate_environment(req: EnvironmentGenerateRequest, request: Request):
"""
Generate an environment profile from a text prompt.
......@@ -471,7 +471,7 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request
return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]}
@router.post("/v1/environments/extract")
@router.post("/v1/environments/extract", summary="Extract an environment from media")
async def extract_environment(req: EnvironmentExtractRequest):
"""
Extract an environment profile from source images and/or videos.
......
......@@ -144,7 +144,7 @@ class FaceSwapRequest(BaseModel):
# Endpoint
# ---------------------------------------------------------------------------
@router.post('/v1/images/faceswap')
@router.post('/v1/images/faceswap', summary="Swap faces between images")
async def faceswap(request: FaceSwapRequest, http_request: Request = None):
"""
Swap the face from source_face into every face found in target.
......
......@@ -37,6 +37,7 @@ from pydantic import BaseModel, ConfigDict
from codai.models.manager import multi_model_manager
from codai.pydantic.imagerequest import ImageGenerationRequest
from codai.api.state import get_load_mode
from codai.tasks import task_registry, TaskCancelled
# =============================================================================
......@@ -756,6 +757,13 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
_progress_reset(num_steps)
# Register this generation as a cancellable task (live view + cooperative
# cancel via the step callback below).
_tid = task_registry.register(
"image", title=(request.prompt or "")[:80],
model=getattr(request, 'model', '') or '', total=num_steps)
task_registry.start(_tid)
# ------------------------------------------------------------------
# Prompt embedding cache
# Try to encode the prompt once and reuse the embeddings.
......@@ -830,6 +838,11 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
embed_kwargs = {}
def _step_cb(pipe, step_index, timestep, callback_kwargs):
# Cooperative cancellation: abort at the next step boundary if cancelled.
task_registry.raise_if_cancelled(_tid)
# Cooperative pause: block here while the user has paused this task.
task_registry.wait_if_paused(_tid)
task_registry.step(_tid, step_index + 1)
_progress_step(step_index + 1)
# Mid-generation thermal checkpoint: pause between denoise steps if too hot.
try:
......@@ -912,10 +925,21 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
try:
result = await asyncio.to_thread(pipeline, **call_kwargs)
except TaskCancelled:
_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except TypeError:
# Older pipeline that doesn't support callback_on_step_end
call_kwargs.pop('callback_on_step_end', None)
result = await asyncio.to_thread(pipeline, **call_kwargs)
try:
result = await asyncio.to_thread(pipeline, **call_kwargs)
except TaskCancelled:
_progress_done()
raise
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_progress_done()
raise
finally:
_progress_done()
......@@ -967,6 +991,7 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
except Exception:
pass
task_registry.finish(_tid, "done")
return {
"created": timestamp,
"data": images,
......@@ -1014,7 +1039,17 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
_progress_reset(steps)
# sd.cpp runs the whole diffusion inside one C call, so it can't be aborted
# mid-step (raising from its progress callback won't reliably unwind the C
# extension). We still register the task for visibility + step progress; a
# cancel takes effect when control returns to Python.
_tid = task_registry.register(
"image", title=(request.prompt or "")[:80],
model=getattr(request, 'model', '') or '', total=steps)
task_registry.start(_tid)
def _sdcpp_progress(step: int, total: int, elapsed: float):
task_registry.step(_tid, step)
_progress_step(step)
# Use request seed if provided, otherwise use CLI default seed
......@@ -1045,9 +1080,13 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
seed=seed if seed is not None else 42,
batch_count=request.n if request.n else 1,
)
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_progress_done()
raise
finally:
_progress_done()
# Small delay to let Vulkan driver settle after generation
time.sleep(0.1)
......@@ -1087,6 +1126,7 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
except Exception:
pass
task_registry.finish(_tid, "done")
return {
"created": int(time.time()),
"data": images
......@@ -1185,7 +1225,7 @@ def _load_sdcpp_model(model_path: str, global_args, model_config: dict = None):
router = APIRouter()
@router.get("/v1/images/progress")
@router.get("/v1/images/progress", summary="Image generation progress")
async def get_image_progress():
"""Return current image generation step progress including speed."""
elapsed = _time.monotonic() - _gen_progress["started_at"] if _gen_progress["active"] else 0.0
......@@ -1202,7 +1242,7 @@ async def get_image_progress():
}
@router.post("/v1/images/generations")
@router.post("/v1/images/generations", summary="Generate images (text-to-image)")
async def create_image_generation(request: ImageGenerationRequest, http_request: Request = None):
"""
Image generation endpoint (OpenAI-compatible).
......@@ -1497,7 +1537,7 @@ def _load_img2img_pipeline(model_name: str, global_args):
raise
@router.post("/v1/images/edits")
@router.post("/v1/images/edits", summary="Edit an image (instruction / img2img)")
async def create_image_edit(request: ImageEditRequest, http_request: Request = None):
"""
Image-to-image editing endpoint (OpenAI-compatible).
......@@ -1638,7 +1678,7 @@ def _load_inpaint_pipeline(model_name: str, global_args):
raise
@router.post("/v1/images/inpaint")
@router.post("/v1/images/inpaint", summary="Inpaint a masked region")
async def create_image_inpaint(request: ImageInpaintRequest, http_request: Request = None):
"""Inpaint a masked region of an image (OpenAI-compatible extension)."""
global global_args
......@@ -1750,7 +1790,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int):
return img.resize((w * scale, h * scale), PILImage.LANCZOS)
@router.post("/v1/images/upscale")
@router.post("/v1/images/upscale", summary="Upscale an image")
async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None):
"""Upscale an image using Real-ESRGAN or PIL LANCZOS fallback."""
global global_args
......@@ -1862,7 +1902,7 @@ def _resolve_spatial_model(requested: Optional[str], capability: str) -> Optiona
return None
@router.post("/v1/images/depth")
@router.post("/v1/images/depth", summary="Estimate a depth map")
async def create_image_depth(request: ImageDepthRequest, http_request: Request = None):
"""Estimate depth map from an image."""
global global_args
......@@ -1968,7 +2008,7 @@ def _run_segmentation(seg_model, image_bytes: bytes, points, boxes):
return PILImage.fromarray(out)
@router.post("/v1/images/segment")
@router.post("/v1/images/segment", summary="Segment an image")
async def create_image_segment(request: ImageSegmentRequest, http_request: Request = None):
"""Segment objects in an image using SAM or similar models."""
global global_args
......@@ -2035,7 +2075,7 @@ def _run_deblur(image_bytes: bytes, strength: float) -> "PILImage.Image":
return PILImage.fromarray((sharpened * 255).astype(np.uint8))
@router.post("/v1/images/deblur")
@router.post("/v1/images/deblur", summary="Deblur an image")
async def create_image_deblur(request: ImageDeblurRequest, http_request: Request = None):
"""Remove blur from an image using Wiener deconvolution and unsharp masking."""
raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image)
......@@ -2093,7 +2133,7 @@ def _run_unpixelate(image_bytes: bytes, scale: int, model_path: Optional[str]) -
return PILImage.fromarray(out_arr)
@router.post("/v1/images/unpixelate")
@router.post("/v1/images/unpixelate", summary="Restore a pixelated image")
async def create_image_unpixelate(request: ImageUnpixelateRequest, http_request: Request = None):
"""Remove pixelation / upscale with detail recovery using Real-ESRGAN."""
raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image)
......@@ -2155,7 +2195,7 @@ def _generate_clothing_mask(img_arr) -> "np.ndarray":
return fg_mask
@router.post("/v1/images/outfit")
@router.post("/v1/images/outfit", summary="Change outfit / clothing")
async def create_image_outfit(request: ImageOutfitRequest, http_request: Request = None):
"""Change the outfit/clothing in an image or video using inpainting."""
global global_args
......
This diff is collapsed.
......@@ -117,7 +117,7 @@ class ImageToVideoPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/image-to-video")
@router.post("/v1/pipelines/image-to-video", summary="Image-to-video pipeline")
async def pipeline_image_to_video(request: ImageToVideoPipelineRequest, http_request: Request = None):
"""Generate an image then animate it into a video."""
steps = []
......@@ -197,7 +197,7 @@ class VideoDubPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/video-dub")
@router.post("/v1/pipelines/video-dub", summary="Video dubbing pipeline")
async def pipeline_video_dub(request: VideoDubPipelineRequest, http_request: Request = None):
"""Transcribe → translate → TTS dub → burn subtitles."""
body = {
......@@ -240,7 +240,7 @@ class StoryPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/story")
@router.post("/v1/pipelines/story", summary="Story pipeline (multi-scene)")
async def pipeline_story(request: StoryPipelineRequest, http_request: Request = None):
"""LLM generates script → image per scene → animate first scene → optional TTS narration."""
n = min(request.num_scenes or 3, 6)
......@@ -377,7 +377,7 @@ class AudioDubPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/audio-dub")
@router.post("/v1/pipelines/audio-dub", summary="Audio dubbing pipeline")
async def pipeline_audio_dub(request: AudioDubPipelineRequest, http_request: Request = None):
"""Transcribe → (translate) → clone voice → replace audio track."""
import os, tempfile, subprocess, base64
......
......@@ -499,7 +499,7 @@ class ImageTo3DRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/images/to3d")
@router.post("/v1/images/to3d", summary="Image to 3D model")
async def image_to_3d(request: ImageTo3DRequest, http_request: Request = None):
"""Convert a 2D image to a 3D representation.
......@@ -567,7 +567,7 @@ class ImageFrom3DRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/images/from3d")
@router.post("/v1/images/from3d", summary="Render a 3D model to an image")
async def image_from_3d(request: ImageFrom3DRequest, http_request: Request = None):
"""Render a 3D model (GLB/OBJ) to a 2D PNG image from a specified camera angle."""
raw = _decode_b64(request.model_data)
......@@ -600,7 +600,7 @@ class VideoTo3DRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/video/to3d")
@router.post("/v1/video/to3d", summary="Video to 3D model")
async def video_to_3d(request: VideoTo3DRequest, http_request: Request = None):
"""Convert a 2D video to a 3D video frame-by-frame.
......@@ -641,7 +641,7 @@ class VideoFrom3DRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/video/from3d")
@router.post("/v1/video/from3d", summary="Render a 3D model to a video")
async def video_from_3d(request: VideoFrom3DRequest, http_request: Request = None):
"""Render a 3D model as a 360° turntable video."""
raw = _decode_b64(request.model_data)
......@@ -674,7 +674,7 @@ class Generate3DRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/3d/generate")
@router.post("/v1/3d/generate", summary="Generate a 3D model from a prompt")
async def generate_3d(request: Generate3DRequest, http_request: Request = None):
"""Generate a 3D model (GLB) from a text prompt and/or an image.
......
......@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
# Import from codai modules
from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager
from codai.queue.manager import QueueManager, queue_manager
from codai.tasks import task_registry
from codai.api.prompt_cache import prompt_cache_manager
from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool
from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser
......@@ -92,7 +93,7 @@ def set_grammar_guided_gen(enabled: bool):
router = APIRouter()
@router.post("/v1/chat/completions")
@router.post("/v1/chat/completions", summary="Chat completions")
async def chat_completions(request: ChatCompletionRequest, http_request: Request = None):
"""Chat completions endpoint with streaming and tool support."""
......@@ -1248,7 +1249,8 @@ async def stream_chat_response(
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time())
request_id = f"req-{uuid.uuid4().hex[:8]}"
_tid = None
generated_text = ""
# Check if model is loaded - if not, notify waiting clients
......@@ -1320,6 +1322,9 @@ async def stream_chat_response(
# Mark as starting processing
await queue_manager.start_processing(request_id, model_name)
_tid = task_registry.register("text", title=(model_name or "chat"),
model=model_name or "", task_id=request_id)
task_registry.start(_tid)
# Send "Model starting" message
data = {
......@@ -1374,6 +1379,9 @@ async def stream_chat_response(
response_format=response_format,
enable_thinking=enable_thinking,
):
# Cooperative cancellation: stop streaming if the task was cancelled.
if task_registry.is_cancelled(_tid):
break
chunk_count += 1
# Always filter malformed content (regex-based, works per-chunk)
filtered_chunk = filter_malformed_content(chunk)
......@@ -1580,6 +1588,9 @@ async def stream_chat_response(
finally:
# Always clean up queue state
await queue_manager.finish_processing()
if _tid:
task_registry.finish(
_tid, "cancelled" if task_registry.is_cancelled(_tid) else "done")
async def generate_chat_response(
messages: List[Dict],
......@@ -1789,7 +1800,7 @@ async def generate_chat_response(
from codai.pydantic.textrequest import CompletionRequest
@router.post("/v1/completions")
@router.post("/v1/completions", summary="Legacy text completions")
async def completions(request: CompletionRequest):
"""Legacy text completions endpoint (for backward compatibility)."""
# Get the model for this request
......
......@@ -119,7 +119,7 @@ def _format_response(fmt: str, text: str, segments: list):
router = APIRouter()
@router.post("/v1/audio/transcriptions")
@router.post("/v1/audio/transcriptions", summary="Transcribe audio to text")
async def create_transcription(
model: str = Form(...),
file: UploadFile = File(...),
......
......@@ -64,7 +64,7 @@ class TTSResponse(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/speech")
@router.post("/v1/audio/speech", summary="Text-to-speech synthesis")
async def create_speech(request: TTSRequest, http_request: Request = None):
"""
Text-to-speech endpoint (OpenAI-compatible).
......
......@@ -45,6 +45,7 @@ from codai.pydantic.videorequest import (
CharacterDialogLine,
)
from codai.api.images import _disable_safety_checker
from codai.tasks import task_registry, TaskCancelled
router = APIRouter()
......@@ -627,7 +628,15 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None):
_vid_progress_reset(steps)
# sd.cpp runs the whole diffusion in one C call → not interruptible mid-step;
# register for visibility + step progress (cancel applies once back in Python).
_tid = task_registry.register(
"video", title=(prompt or mode or "")[:80],
model=getattr(request, 'model', '') or '', total=steps)
task_registry.start(_tid)
def _progress_cb(step: int, total: int, elapsed: float):
task_registry.step(_tid, step)
_vid_progress_step(step)
kw = {
......@@ -654,8 +663,14 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None):
kw['init_image'] = _pil_from_b64(init_src)
kw['end_image'] = _pil_from_b64(request.end_image)
frames = sd_model.generate_video(**kw)
try:
frames = sd_model.generate_video(**kw)
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_vid_progress_done()
raise
_vid_progress_done()
task_registry.finish(_tid, "done")
return list(frames), fps
......@@ -1483,7 +1498,17 @@ def _generate_video(pipe, request: VideoGenerationRequest):
_vid_progress_reset(kw['num_inference_steps'])
_tid = task_registry.register(
"video", title=(request.prompt or mode or "")[:80],
model=getattr(request, 'model', '') or '', total=kw['num_inference_steps'])
task_registry.start(_tid)
def _vid_step_cb(pipe, step_index, timestep, callback_kwargs):
# Cooperative cancellation: abort at the next step boundary if cancelled.
task_registry.raise_if_cancelled(_tid)
# Cooperative pause: block here while the user has paused this task.
task_registry.wait_if_paused(_tid)
task_registry.step(_tid, step_index + 1)
_vid_progress_step(step_index + 1)
# Mid-generation thermal checkpoint: pause between denoise steps if the
# CPU/GPU went over the limit during this (multi-minute) generation.
......@@ -1547,8 +1572,17 @@ def _generate_video(pipe, request: VideoGenerationRequest):
# previous clip's (common within a match) and only swapping when they differ.
# Left loaded after the run so the next clip with the same set pays nothing.
_sync_video_loras(pipe, getattr(request, 'loras', None))
frames = _run_pipeline(pipe, kw)
try:
frames = _run_pipeline(pipe, kw)
except TaskCancelled:
_vid_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_vid_progress_done()
raise
_vid_progress_done()
task_registry.finish(_tid, "done")
return frames, fps
......@@ -1979,7 +2013,7 @@ def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str:
# Progress endpoint
# =============================================================================
@router.get("/v1/video/progress")
@router.get("/v1/video/progress", summary="Video generation progress")
async def get_video_progress():
"""Return current video generation step progress including speed."""
elapsed = time.monotonic() - _vid_progress["started_at"] if _vid_progress["active"] else 0.0
......@@ -2000,7 +2034,7 @@ async def get_video_progress():
# Main generation endpoint
# =============================================================================
@router.post("/v1/video/generations", response_model=VideoGenerationResponse)
@router.post("/v1/video/generations", response_model=VideoGenerationResponse, summary="Generate video")
async def video_generations(request: VideoGenerationRequest,
http_request: Request = None):
"""
......@@ -2269,7 +2303,7 @@ async def video_generations(request: VideoGenerationRequest,
# Video upscale endpoint
# =============================================================================
@router.post("/v1/video/upscale")
@router.post("/v1/video/upscale", summary="Upscale a video")
async def video_upscale(request: VideoUpscaleRequest, http_request: Request = None):
"""
Upscale a video using ffmpeg lanczos or Real-ESRGAN.
......@@ -2299,7 +2333,7 @@ async def video_upscale(request: VideoUpscaleRequest, http_request: Request = No
# Subtitle generation endpoint
# =============================================================================
@router.post("/v1/video/subtitle")
@router.post("/v1/video/subtitle", summary="Subtitle / caption a video")
async def video_subtitle(request: VideoSubtitleRequest, http_request: Request = None):
"""
Generate subtitles for a video.
......@@ -2353,7 +2387,7 @@ async def video_subtitle(request: VideoSubtitleRequest, http_request: Request =
# Frame interpolation endpoint
# =============================================================================
@router.post("/v1/video/interpolate")
@router.post("/v1/video/interpolate", summary="Interpolate video frames")
async def video_interpolate(request: VideoInterpolateRequest, http_request: Request = None):
"""
Increase video FPS via frame interpolation.
......@@ -2400,7 +2434,7 @@ async def video_interpolate(request: VideoInterpolateRequest, http_request: Requ
# Video dubbing endpoint
# =============================================================================
@router.post("/v1/video/dub")
@router.post("/v1/video/dub", summary="Dub a video")
async def video_dub(request: VideoDubRequest, http_request: Request = None):
"""
Translate and re-dub a video.
......
......@@ -185,13 +185,13 @@ class VoicePatchRequest(BaseModel):
# Voice profile management
# ---------------------------------------------------------------------------
@router.get("/v1/audio/voices")
@router.get("/v1/audio/voices", summary="List voice profiles")
async def list_voices():
"""List all saved voice profiles."""
return {"voices": _list_voices()}
@router.post("/v1/audio/voices")
@router.post("/v1/audio/voices", summary="Create a voice profile")
async def create_voice(
name: str = Form(...),
transcript: str = Form(...),
......@@ -216,7 +216,7 @@ async def create_voice(
return {"created": True, "voice": meta}
@router.delete("/v1/audio/voices/{name}")
@router.delete("/v1/audio/voices/{name}", summary="Delete a voice profile")
async def delete_voice(name: str):
"""Delete a saved voice profile."""
import shutil
......@@ -227,7 +227,7 @@ async def delete_voice(name: str):
return {"deleted": True, "name": name}
@router.patch("/v1/audio/voices/{name}")
@router.patch("/v1/audio/voices/{name}", summary="Update a voice profile")
async def patch_voice(name: str, req: VoicePatchRequest):
"""Update description, transcript, or reference audio of a saved voice profile."""
meta = _load_voice(name)
......@@ -259,7 +259,7 @@ async def patch_voice(name: str, req: VoicePatchRequest):
return {"updated": True, "voice": meta}
@router.get("/v1/audio/voices/{name}")
@router.get("/v1/audio/voices/{name}", summary="Get a voice profile")
async def get_voice(name: str):
"""Get a single voice profile metadata."""
meta = _load_voice(name)
......@@ -268,7 +268,7 @@ async def get_voice(name: str):
return {"voice": meta}
@router.post("/v1/audio/voices/extract")
@router.post("/v1/audio/voices/extract", summary="Extract a voice profile from a sample")
async def extract_voice(req: VoiceExtractRequest):
"""
Extract a voice profile from a source audio or video file.
......@@ -358,7 +358,7 @@ class VoiceCloneRequest(BaseModel):
model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/clone")
@router.post("/v1/audio/clone", summary="Clone a voice / synthesize cloned speech")
async def clone_voice(request: VoiceCloneRequest, http_request: Request = None):
"""
Synthesize speech in a cloned voice using F5-TTS.
......
......@@ -94,7 +94,7 @@ class VoiceConvertRequest(BaseModel):
model_config = ConfigDict(extra='allow')
@router.post('/v1/audio/convert')
@router.post('/v1/audio/convert', summary="Voice conversion (speech-to-speech)")
async def convert_voice(request: VoiceConvertRequest, http_request: Request = None):
"""
Voice conversion: preserves pitch/melody/expression, changes only timbre.
......
......@@ -78,6 +78,40 @@ except ImportError:
_llama_cpp = None
# Friendly KV-cache quant names → llama.cpp GGML type. q8_0 is near-lossless and
# the safe default; the q5/q4 types trade a little accuracy for ~2x less KV VRAM.
_KV_TYPE_ALIASES = {
'f16': 'GGML_TYPE_F16', 'fp16': 'GGML_TYPE_F16', 'f32': 'GGML_TYPE_F32',
'q8_0': 'GGML_TYPE_Q8_0', 'q8': 'GGML_TYPE_Q8_0', 'q8_1': 'GGML_TYPE_Q8_1',
'q5_0': 'GGML_TYPE_Q5_0', 'q5_1': 'GGML_TYPE_Q5_1', 'q5': 'GGML_TYPE_Q5_1',
'q4_0': 'GGML_TYPE_Q4_0', 'q4_1': 'GGML_TYPE_Q4_1', 'q4': 'GGML_TYPE_Q4_1',
'iq4_nl': 'GGML_TYPE_IQ4_NL',
}
# Sub-8-bit KV types that llama.cpp can only use with flash attention enabled.
_KV_NEEDS_FLASH = {'q5_0', 'q5_1', 'q5', 'q4_0', 'q4_1', 'q4', 'iq4_nl'}
def _ggml_kv_type(name):
    """Resolve a KV-cache quantization name to a llama.cpp GGML type value.

    The name is normalised (trimmed, lower-cased, '-' turned into '_', spaces
    removed) and looked up in ``_KV_TYPE_ALIASES``; the resulting constant
    name is then read off the ``_llama_cpp`` module.

    Returns None (meaning: leave llama.cpp's own default in place) when
    ``name`` is falsy, when ``_llama_cpp`` is None, when the name is one of
    the "no override" spellings, when the alias is unknown, or when the
    loaded llama.cpp module does not expose the constant. The last two cases
    print a notice rather than raising.
    """
    if not name or _llama_cpp is None:
        return None

    normalized = str(name).strip().lower().replace('-', '_').replace(' ', '')
    # Spellings that explicitly ask for "no override".
    if normalized in ('', 'none', 'auto', 'default', 'f16default'):
        return None

    ggml_const = _KV_TYPE_ALIASES.get(normalized)
    if ggml_const is None:
        print(f" KV cache type '{name}' not recognized — using default (f16)")
        return None

    # The constant may be absent from the installed llama.cpp binding.
    ggml_value = getattr(_llama_cpp, ggml_const, None)
    if ggml_value is None:
        print(f" KV cache type '{name}' unsupported by this llama.cpp build — using f16")
    return ggml_value
def _install_layer_log_callback():
"""Replace llama.cpp's log callback with one that prints load-time layer/buffer
messages directly to stdout. Returns the callback object — keep a reference
......@@ -613,7 +647,33 @@ class VulkanBackend(ModelBackend):
llama_kwargs['rope_freq_base'] = kwargs['rope_freq_base']
if 'rope_freq_scale' in kwargs:
llama_kwargs['rope_freq_scale'] = kwargs['rope_freq_scale']
# KV-cache quantization (llama.cpp type_k / type_v). Shrinks the KV cache
# so long contexts fit in less VRAM. Read from the per-model config, with
# the raw models.json entry as a fallback (carried in _raw_cfg).
_raw_cfg = kwargs.get('_raw_cfg') or {}
_ck = kwargs.get('cache_type_k', _raw_cfg.get('cache_type_k'))
_cv = kwargs.get('cache_type_v', _raw_cfg.get('cache_type_v'))
_flash = bool(kwargs.get('flash_attn', _raw_cfg.get('flash_attn',
_raw_cfg.get('flash_attention', False))))
_tk = _ggml_kv_type(_ck)
_tv = _ggml_kv_type(_cv)
if _tk is not None:
llama_kwargs['type_k'] = _tk
if _tv is not None:
llama_kwargs['type_v'] = _tv
# A quantized V cache below 8 bits requires flash attention in llama.cpp;
# auto-enable it (with a note) so the config "just works".
_v_needs_flash = str(_cv or '').strip().lower().replace('-', '_') in _KV_NEEDS_FLASH
if (_tk is not None or _tv is not None):
if _v_needs_flash and not _flash:
_flash = True
print(" KV cache: sub-8-bit V cache needs flash attention — enabling it")
if _flash:
llama_kwargs['flash_attn'] = True
print(f" KV cache: type_k={_ck or 'f16'} type_v={_cv or 'f16'}"
f"{' (flash_attn on)' if _flash else ''}")
# Force CUDA if requested
if self.force_cuda:
# Set environment variable to force CUDA
......
......@@ -247,4 +247,11 @@ configuration directory (--config DIR, default: OS-specific CoderAI directory).
action="store_true",
help="List available Vulkan GPU devices and exit",
)
parser.add_argument(
"--no-resume-jobs",
action="store_true",
help="Do not resume/recover interrupted LoRA training jobs on restart. "
"Mid-flight jobs are marked 'cancelled' (checkpoints are kept, so they "
"can still be restarted manually from the Tasks page).",
)
return parser.parse_args()
......@@ -126,6 +126,17 @@ class ThermalConfig:
poll_seconds: float = 5.0 # how often to re-check while cooling down
@dataclass
class JobsConfig:
    """Background-job (LoRA training) configuration.

    ``resume_on_restart``: when True, a training job interrupted by a process
    restart is left 'interrupted' so it can resume from its on-disk
    checkpoint. When False, such jobs are marked 'cancelled' on startup and
    are not auto-resumed; their checkpoints are kept, so they can still be
    restarted manually from the Tasks page. The ``--no-resume-jobs`` CLI flag
    forces this off for one run.
    """
    resume_on_restart: bool = True
@dataclass
class Config:
"""Main configuration class."""
......@@ -139,6 +150,7 @@ class Config:
whisper: WhisperConfig = field(default_factory=WhisperConfig)
archive: ArchiveConfig = field(default_factory=ArchiveConfig)
thermal: ThermalConfig = field(default_factory=ThermalConfig)
jobs: JobsConfig = field(default_factory=JobsConfig)
broker: BrokerConfig = field(default_factory=BrokerConfig)
system_prompt: Optional[str] = None
tools_closer_prompt: bool = False
......@@ -293,6 +305,7 @@ class ConfigManager:
whisper=WhisperConfig(**config_data.get("whisper", {})),
archive=ArchiveConfig(**config_data.get("archive", {})),
thermal=ThermalConfig(**config_data.get("thermal", {})),
jobs=JobsConfig(**config_data.get("jobs", {})),
broker=BrokerConfig(**config_data.get("broker", {})),
system_prompt=config_data.get("system_prompt"),
tools_closer_prompt=config_data.get("tools_closer_prompt", False),
......@@ -411,6 +424,9 @@ class ConfigManager:
"gpu_resume": self.config.thermal.gpu_resume,
"poll_seconds": self.config.thermal.poll_seconds,
},
"jobs": {
"resume_on_restart": self.config.jobs.resume_on_restart,
},
"broker": {
"enabled": self.config.broker.enabled,
"base_url": self.config.broker.base_url,
......
......@@ -147,6 +147,8 @@ def build_runtime_kwargs(model_cfg, model_type):
}
if model_type == "text":
kwargs['ctx'] = model_cfg.get('n_ctx', model_cfg.get('context_size'))
kwargs['cache_type_k'] = model_cfg.get('cache_type_k')
kwargs['cache_type_v'] = model_cfg.get('cache_type_v')
elif model_type == "image":
kwargs['llm_path'] = model_cfg.get('llm_path')
kwargs['vae_path'] = model_cfg.get('vae_path')
......@@ -865,9 +867,16 @@ def main():
from codai.api.characters import set_global_args as set_chars_global_args
set_chars_global_args(global_args)
# Set LoRA training module global args
from codai.api.loras import set_global_args as set_loras_global_args
# Set LoRA training module global args. Resolve job-recovery first (the
# --no-resume-jobs flag overrides the persisted config setting), then call
# set_global_args, which runs _load_jobs_on_start and honours the flag.
from codai.api.loras import (set_global_args as set_loras_global_args,
set_resume_enabled as set_loras_resume_enabled)
_resume_jobs = bool(getattr(config.jobs, "resume_on_restart", True)) and not getattr(args, "no_resume_jobs", False)
set_loras_resume_enabled(_resume_jobs)
set_loras_global_args(global_args)
if not _resume_jobs:
print("LoRA job recovery: DISABLED (interrupted training will be cancelled on restart)")
# Set environment profiles module global args
from codai.api.environments import set_global_args as set_envs_global_args
......
......@@ -790,6 +790,17 @@ class MultiModelManager:
# build_kwargs_from_config populates it from the model's
# 'flash_attention' setting; CLI/global is NOT consulted here.
kwargs['flash_attn'] = bool(config.get('flash_attn', False))
# KV-cache quantization (llama.cpp type_k/type_v) — pass through
# to the backend, with the raw models.json entry as a fallback.
_raw = config.get('_raw_cfg') if isinstance(config.get('_raw_cfg'), dict) else {}
for _kvk in ('cache_type_k', 'cache_type_v'):
_kvv = config.get(_kvk)
if _kvv is None:
_kvv = _raw.get(_kvk)
if _kvv:
kwargs[_kvk] = _kvv
if _raw and '_raw_cfg' not in kwargs:
kwargs['_raw_cfg'] = _raw
no_ram = _cfg_or_global('no_ram', 'no_ram', False)
kwargs['no_ram'] = bool(no_ram)
offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto')
......@@ -872,6 +883,17 @@ class MultiModelManager:
# build_kwargs_from_config populates it from the model's
# 'flash_attention' setting; CLI/global is NOT consulted here.
kwargs['flash_attn'] = bool(config.get('flash_attn', False))
# KV-cache quantization (llama.cpp type_k/type_v) — pass through
# to the backend, with the raw models.json entry as a fallback.
_raw = config.get('_raw_cfg') if isinstance(config.get('_raw_cfg'), dict) else {}
for _kvk in ('cache_type_k', 'cache_type_v'):
_kvv = config.get(_kvk)
if _kvv is None:
_kvv = _raw.get(_kvk)
if _kvv:
kwargs[_kvk] = _kvv
if _raw and '_raw_cfg' not in kwargs:
kwargs['_raw_cfg'] = _raw
no_ram = _cfg_or_global('no_ram', 'no_ram', False)
kwargs['no_ram'] = bool(no_ram)
offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto')
......
......@@ -35,10 +35,62 @@ Semantics (per sensor, when enabled):
import os
import shutil
import subprocess
import threading
import time
from typing import Optional, Tuple
# ---------------------------------------------------------------------------
# Cooldown state (published for the admin Tasks view)
# ---------------------------------------------------------------------------
# A thermal pause is a *global* hardware event: every worker that reaches a
# checkpoint blocks until temps recover. We publish a single process-wide state
# so the Tasks page can show that running work is paused for cooldown. A waiter
# counter (not a bool) keeps the state correct when several workers pause at
# once — the state is "active" while any worker is still cooling.
_cooldown_lock = threading.Lock()
_cooldown_waiters = 0
_cooldown_state: dict = {
"active": False, "since": 0.0, "waited": 0.0,
"gpu": None, "cpu": None, "message": "",
}
def get_cooldown_state() -> dict:
    """Return a copy of the process-wide thermal cooldown state.

    The copy is taken while holding ``_cooldown_lock``, so callers receive a
    consistent snapshot they can read or mutate freely. ``active`` is True
    while at least one worker is paused waiting for the hardware to cool.
    """
    with _cooldown_lock:
        snapshot = _cooldown_state.copy()
    return snapshot
def _cooldown_enter() -> None:
    """Record that one more worker has started waiting for cooldown.

    Increments the waiter counter, stamps ``since`` with the wall-clock time
    if no start time is recorded yet, and marks the shared state active.
    """
    global _cooldown_waiters
    with _cooldown_lock:
        _cooldown_waiters = _cooldown_waiters + 1
        # Keep the original start time while a cooldown is already running.
        if not _cooldown_state.get("since"):
            _cooldown_state["since"] = time.time()
        _cooldown_state["active"] = True
def _cooldown_update(gpu, cpu, waited, message) -> None:
    """Publish the latest readings of an in-progress cooldown.

    Stores the given GPU/CPU values, the elapsed wait and the status message
    in the shared cooldown state under ``_cooldown_lock``.
    """
    with _cooldown_lock:
        _cooldown_state.update(
            gpu=gpu,
            cpu=cpu,
            waited=waited,
            message=message,
        )
def _cooldown_exit() -> None:
    """Record that one worker has stopped waiting for cooldown.

    Decrements the waiter counter (never below zero). When the last waiter
    leaves, the shared state is reset to its idle values.
    """
    global _cooldown_waiters
    with _cooldown_lock:
        remaining = _cooldown_waiters - 1
        _cooldown_waiters = remaining if remaining > 0 else 0
        if _cooldown_waiters != 0:
            return
        # Last waiter gone: clear everything back to the idle snapshot.
        _cooldown_state["active"] = False
        _cooldown_state["since"] = 0.0
        _cooldown_state["waited"] = 0.0
        _cooldown_state["gpu"] = None
        _cooldown_state["cpu"] = None
        _cooldown_state["message"] = ""
# ---------------------------------------------------------------------------
# Temperature readers
# ---------------------------------------------------------------------------
......@@ -199,6 +251,57 @@ def read_cpu_temp() -> Optional[float]:
return val
_gpu_util_cache: Tuple[float, Optional[float]] = (0.0, None)
def _read_gpu_util_uncached() -> Optional[float]:
    """Return the highest GPU utilization (%) reported, or None if unreadable.

    Tries ``nvidia-smi`` first (one numeric value per output line), then
    ``rocm-smi --showuse`` (numeric tokens on lines that mention "gpu use"
    and contain a '%'). The maximum of the parsed values from the first tool
    that yields any is returned.
    """

    def _parse_floats(tokens):
        # Keep only the tokens that parse as floats; silently skip the rest.
        parsed = []
        for token in tokens:
            try:
                parsed.append(float(token))
            except ValueError:
                pass
        return parsed

    if _NVIDIA_SMI:
        raw = _run([
            _NVIDIA_SMI,
            "--query-gpu=utilization.gpu",
            "--format=csv,noheader,nounits",
        ])
        if raw:
            stripped = (entry.strip() for entry in raw.splitlines())
            readings = _parse_floats(entry for entry in stripped if entry)
            if readings:
                return max(readings)

    if _ROCM_SMI:
        raw = _run([_ROCM_SMI, "--showuse"])
        if raw:
            readings = []
            for entry in raw.splitlines():
                if "gpu use" in entry.lower() and "%" in entry:
                    readings.extend(_parse_floats(entry.replace("%", " ").split()))
            if readings:
                return max(readings)

    return None
def read_gpu_util() -> Optional[float]:
    """Return GPU utilization in %, or None if unreadable.

    The value is cached in ``_gpu_util_cache`` together with the monotonic
    time it was read; a fresh reading is taken only once the cached one is at
    least ``_CACHE_TTL`` seconds old.
    """
    global _gpu_util_cache
    now = time.monotonic()
    cached_at, reading = _gpu_util_cache
    if now - cached_at >= _CACHE_TTL:
        reading = _read_gpu_util_uncached()
        _gpu_util_cache = (now, reading)
    return reading
def read_cpu_temp_avg(samples: int = 3, max_seconds: float = 3.0) -> Optional[float]:
"""Averaged CPU temperature for stable resume/cooldown decisions.
......@@ -372,25 +475,30 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None,
f"until cooldown (GPU<={settings.gpu_resume:.0f}°C / "
f"CPU<={settings.cpu_resume:.0f}°C)")
waited = 0.0
while True:
# Re-evaluate against resume thresholds (lower than trigger → hysteresis).
# CPU temps are noisy, so average a few samples for the resume decision
# (the pause check above stays single-read to react fast to spikes).
gt = read_gpu_temp() if settings.gpu_enabled else None
ct = read_cpu_temp_avg() if settings.cpu_enabled else None
still = []
if gt is not None and gt > settings.gpu_resume:
still.append(("GPU", gt, settings.gpu_resume))
if ct is not None and ct > settings.cpu_resume:
still.append(("CPU", ct, settings.cpu_resume))
_dbg(f"cooldown{desc} {int(waited)}s: GPU {_fmt(gt)} CPU {_fmt(ct)} (avg-3) "
f"(still hot: {[s[0] for s in still] or 'none'})")
if not still:
break
msg = ", ".join(f"{lbl} {t:.0f}°C>{r:.0f}°C" for lbl, t, r in still)
print(f"[thermal] Cooling{desc}: {msg} — waiting "
f"({int(waited)}s elapsed)")
time.sleep(settings.poll_seconds)
waited += settings.poll_seconds
_cooldown_enter()
try:
while True:
# Re-evaluate against resume thresholds (lower than trigger → hysteresis).
# CPU temps are noisy, so average a few samples for the resume decision
# (the pause check above stays single-read to react fast to spikes).
gt = read_gpu_temp() if settings.gpu_enabled else None
ct = read_cpu_temp_avg() if settings.cpu_enabled else None
still = []
if gt is not None and gt > settings.gpu_resume:
still.append(("GPU", gt, settings.gpu_resume))
if ct is not None and ct > settings.cpu_resume:
still.append(("CPU", ct, settings.cpu_resume))
_dbg(f"cooldown{desc} {int(waited)}s: GPU {_fmt(gt)} CPU {_fmt(ct)} (avg-3) "
f"(still hot: {[s[0] for s in still] or 'none'})")
if not still:
break
msg = ", ".join(f"{lbl} {t:.0f}°C>{r:.0f}°C" for lbl, t, r in still)
_cooldown_update(gt, ct, waited, msg)
print(f"[thermal] Cooling{desc}: {msg} — waiting "
f"({int(waited)}s elapsed)")
time.sleep(settings.poll_seconds)
waited += settings.poll_seconds
finally:
_cooldown_exit()
print(f"[thermal] Temperatures back within safe limits{desc} — resuming "
f"after {int(waited)}s")
This diff is collapsed.
......@@ -17,16 +17,17 @@
"""Pydantic models for embeddings API."""
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class EmbeddingsRequest(BaseModel):
model: str
input: Union[str, List[str]] # text(s) to embed
image: Optional[Union[str, List[str]]] = None # base64/URL image(s) for multimodal embed
encoding_format: Optional[str] = "float" # float | base64
dimensions: Optional[int] = None # truncate to N dims if supported
user: Optional[str] = None
model: str = Field(..., description="Embedding model id to use.")
input: Union[str, List[str]] = Field(..., description="Text or list of texts to embed.")
image: Optional[Union[str, List[str]]] = Field(None, description="Base64/URL image(s) for multimodal embedding models.")
encoding_format: Optional[str] = Field("float", description="Return embeddings as 'float' arrays or 'base64'.")
dimensions: Optional[int] = Field(None, description="Truncate embeddings to N dimensions (if the model supports it).")
quantization: Optional[str] = Field(None, description="Optional TurboQuant vector quantization: 'turbo' (8-bit), 'turbo8', 'turbo6', 'turbo4' or 'turbo2'. With encoding_format='float' the (lossy) reconstructed vectors are returned; with 'base64' the compact packed bytes are returned plus a 'quantization' metadata block describing how to decode them.")
user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
model_config = ConfigDict(extra="allow")
......
......@@ -18,7 +18,7 @@
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class LoraConfig(BaseModel):
......@@ -26,41 +26,43 @@ class LoraConfig(BaseModel):
server-side, in priority) via `id` ("name:<registered>" or "sha256:<hex>"),
inline `file`/`data` base64, a `url`, or the legacy `model`/`path` (local path
/ HF id) — so a remote client needn't share the server's filesystem."""
model: Optional[str] = None
path: Optional[str] = None
id: Optional[str] = None
url: Optional[str] = None
file: Optional[str] = None
data: Optional[str] = None
weight: float = 1.0
name: Optional[str] = None
model: Optional[str] = Field(None, description="Legacy: local path or HF id of the weights (shared-filesystem only).")
path: Optional[str] = Field(None, description="Alias of `model` — local path to the .safetensors weights.")
id: Optional[str] = Field(None, description='Registry/blob reference: "name:<registered-lora>" or "sha256:<hex>" (from /v1/loras/upload).')
url: Optional[str] = Field(None, description="HTTP(S) URL the server downloads and caches.")
file: Optional[str] = Field(None, description="Base64 of the .safetensors file (or a data: URI). Sent inline so no shared filesystem is needed.")
data: Optional[str] = Field(None, description="Alias of `file` — inline base64 weights.")
weight: float = Field(1.0, description="Adapter strength / scale.")
name: Optional[str] = Field(None, description="Optional adapter name.")
model_config = ConfigDict(extra="allow")
class ImageGenerationRequest(BaseModel):
model: str
prompt: str
n: int = 1
size: Optional[str] = "1024x1024"
steps: Optional[int] = None
guidance_scale: Optional[float] = None
quality: Optional[str] = "standard"
style: Optional[str] = None
response_format: Optional[str] = "url"
seed: Optional[int] = None
user: Optional[str] = None
disable_safety_checker: Optional[bool] = False
negative_prompt: Optional[str] = None
model: str = Field(..., description="Model id to generate with (must be a configured image model).")
prompt: str = Field(..., description="Text prompt describing the image.")
n: int = Field(1, description="Number of images to generate.")
size: Optional[str] = Field("1024x1024", description="Output size as 'WIDTHxHEIGHT'.")
steps: Optional[int] = Field(None, description="Denoising steps (model/acceleration default if omitted).")
guidance_scale: Optional[float] = Field(None, description="Classifier-free guidance scale (model/acceleration default if omitted).")
quality: Optional[str] = Field("standard", description="Quality hint: 'standard' or 'hd'.")
style: Optional[str] = Field(None, description="Optional style hint passed through to the model.")
response_format: Optional[str] = Field("url", description="How to return the result: 'url' or 'b64_json'.")
seed: Optional[int] = Field(None, description="Random seed for reproducibility.")
user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
disable_safety_checker: Optional[bool] = Field(False, description=(
"Null out the diffusers safety_checker so uncensored fine-tunes are not blocked. "
"Only affects SD 1.x/2.x (SDXL/Flux ship no checker)."))
negative_prompt: Optional[str] = Field(None, description="What to avoid in the output.")
# Per-request component overrides
vae_model: Optional[str] = None # Override the VAE for this request
loras: Optional[List[LoraConfig]] = None # Additional LoRA weights for this request
vae_model: Optional[str] = Field(None, description="Override the VAE for this request.")
loras: Optional[List[LoraConfig]] = Field(None, description="Additional LoRA adapters to apply for this request.")
# Character consistency
character_profiles: Optional[List[str]] = None # saved profile names
character_references: Optional[List[str]] = None # inline base64 images
character_strength: Optional[float] = 0.6 # IP-Adapter scale
environment_profiles: Optional[List[str]] = None # saved environment profile names (IP-Adapter)
character_profiles: Optional[List[str]] = Field(None, description="Saved character profile names to apply (IP-Adapter).")
character_references: Optional[List[str]] = Field(None, description="Inline base64 reference images for character consistency.")
character_strength: Optional[float] = Field(0.6, description="IP-Adapter scale for character references.")
environment_profiles: Optional[List[str]] = Field(None, description="Saved environment profile names to apply (IP-Adapter).")
model_config = ConfigDict(extra="allow")
......
......@@ -67,34 +67,32 @@ class ChatMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: float = 0.7
top_p: float = 1.0
n: int = 1
max_tokens: Optional[int] = None
stream: bool = False
stop: Optional[Union[str, List[str]]] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repeat_penalty: float = 1.0
tools: Optional[List[Tool]] = None
tool_choice: Optional[Union[str, Dict]] = "auto"
# Extra fields that clients may send but we ignore
seed: Optional[int] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
response_format: Optional[Dict] = None
user: Optional[str] = None
# Enable thinking/reasoning mode for supported models
enable_thinking: Optional[bool] = False
model: str = Field(..., description="Text/chat model id to use.")
messages: List[ChatMessage] = Field(..., description="Conversation messages (roles: system/user/assistant/tool). Content may include text and image parts for vision models.")
temperature: float = Field(0.7, description="Sampling temperature; higher = more random.")
top_p: float = Field(1.0, description="Nucleus sampling probability mass.")
n: int = Field(1, description="Number of completions to generate.")
max_tokens: Optional[int] = Field(None, description="Max tokens to generate (model default if omitted).")
stream: bool = Field(False, description="Stream the response as Server-Sent Events.")
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequence(s) that end generation.")
presence_penalty: float = Field(0.0, description="Penalize tokens already present (encourages new topics).")
frequency_penalty: float = Field(0.0, description="Penalize frequent tokens (reduces repetition).")
repeat_penalty: float = Field(1.0, description="llama.cpp repetition penalty.")
tools: Optional[List[Tool]] = Field(None, description="Tool/function definitions the model may call.")
tool_choice: Optional[Union[str, Dict]] = Field("auto", description="Tool selection: 'auto', 'none', or a specific tool.")
seed: Optional[int] = Field(None, description="Random seed for reproducibility.")
logprobs: Optional[bool] = Field(None, description="Return token log-probabilities (if supported).")
top_logprobs: Optional[int] = Field(None, description="Number of top log-probs to return per token.")
response_format: Optional[Dict] = Field(None, description="Structured-output format, e.g. {'type': 'json_object'}.")
user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
enable_thinking: Optional[bool] = Field(False, description="Enable thinking/reasoning mode for models that support it.")
model_config = ConfigDict(extra="allow") # Allow extra fields to prevent 422 errors
class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
model: str = Field(..., description="Text model id to use.")
prompt: Union[str, List[str]] = Field(..., description="Prompt text (or list of prompts) to complete.")
temperature: float = 0.7
top_p: float = 1.0
n: int = 1
......
This diff is collapsed.
......@@ -169,6 +169,23 @@ class QueueManager:
return index
return 0
def list_waiting(self) -> list:
    """Return a best-effort snapshot of the queued (waiting) requests.

    Each entry is a dict with ``request_id``, ``model_key`` and
    ``enqueued_at``. The queue is copied before iterating and no lock is
    taken, which is acceptable for a read-only Tasks-view snapshot.
    """
    return [
        {
            "request_id": waiter.request_id,
            "model_key": waiter.model_key,
            "enqueued_at": waiter.enqueued_at,
        }
        for waiter in list(self.waiting)
    ]
def list_active(self) -> list:
    """Return a best-effort snapshot of the in-flight leases.

    Each entry is a dict with the lease's ``request_id`` (the mapping key)
    and its ``model_key``. The items are copied before iterating.
    """
    snapshot = []
    for request_id, lease in list(self.active_leases.items()):
        snapshot.append({"request_id": request_id, "model_key": lease.model_key})
    return snapshot
def get_metrics(self) -> Dict[str, object]:
return {
"active": len(self.active_leases),
......
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Central task registry for long-running operations."""
from codai.tasks.registry import (
Task,
TaskCancelled,
TaskRegistry,
task_registry,
raise_if_cancelled,
wait_if_paused,
)
__all__ = [
"Task",
"TaskCancelled",
"TaskRegistry",
"task_registry",
"raise_if_cancelled",
"wait_if_paused",
]
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Central registry of long-running tasks (generation + training).
Lets the admin UI list every in-flight / recent task and cooperatively cancel
one. Cancellation is *cooperative*: the per-step diffusers callbacks
(``_step_cb`` / ``_vid_step_cb`` / ``_aud_step_cb`` / sd.cpp ``progress_callback``)
and the LoRA training loops call ``raise_if_cancelled(task_id)`` between steps,
which raises :class:`TaskCancelled` once the task is cancelled. This mirrors the
``codai.models.thermal.checkpoint()`` pause already invoked in those same hooks,
so the running model aborts at the next step boundary.
The registry is process-local and in-memory: it is the live view. Durable LoRA
training state lives separately in ``codai/api/loras.py`` (``_train_jobs.json``);
a task with a ``job_id`` links the two.
"""
import threading
import time
import uuid
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional
class TaskCancelled(Exception):
    """Signal, raised inside a worker, that the user cancelled its task."""
# Statuses of a task that is still pending or executing.
ACTIVE_STATES = ("queued", "running")
# Statuses of a task that has reached an end state.
TERMINAL_STATES = ("done", "error", "cancelled")


@dataclass
class Task:
    """One entry in the task registry: identity, status and step progress."""

    id: str
    kind: str  # training | image | video | audio | text | pipeline
    title: str = ""
    model: str = ""
    status: str = "queued"  # queued | running | done | error | cancelled
    step: int = 0
    total: int = 0
    message: str = ""
    job_id: Optional[str] = None  # link to a durable loras training job, if any
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    ended_at: Optional[float] = None
    cancellable: bool = True
    restartable: bool = False
    paused: bool = False

    def to_dict(self) -> dict:
        """Return the fields as a dict plus a derived ``active`` flag.

        ``active`` is True while ``status`` is one of ``ACTIVE_STATES``.
        """
        return {**asdict(self), "active": self.status in ACTIVE_STATES}
class TaskRegistry:
"""Thread-safe registry. All public methods take the registry lock briefly;
the per-task cancel flag is a ``threading.Event`` so ``is_cancelled`` is a
cheap lock-free read suitable for a hot step loop."""
def __init__(self, history: int = 50):
self._lock = threading.Lock()
self._tasks: Dict[str, Task] = {}
self._events: Dict[str, threading.Event] = {}
self._pause_events: Dict[str, threading.Event] = {}
self._history = history
def register(self, kind: str, *, title: str = "", model: str = "",
             total: int = 0, job_id: Optional[str] = None,
             status: str = "queued", cancellable: bool = True,
             restartable: bool = False, task_id: Optional[str] = None) -> str:
    """Add a task to the registry and return its id.

    Uses ``task_id`` when given, otherwise generates ``task-<12 hex chars>``.
    An existing entry with the same id is replaced, and fresh cancel and
    pause events are created for it. The registry is pruned before returning.
    """
    if task_id:
        tid = task_id
    else:
        tid = f"task-{uuid.uuid4().hex[:12]}"
    with self._lock:
        new_task = Task(
            id=tid,
            kind=kind,
            title=title,
            model=model,
            total=total,
            job_id=job_id,
            status=status,
            cancellable=cancellable,
            restartable=restartable,
        )
        self._tasks[tid] = new_task
        self._events[tid] = threading.Event()
        self._pause_events[tid] = threading.Event()
        self._prune_locked()
    return tid
def start(self, tid: str) -> None:
    """Mark an active task as running and stamp its start time once.

    Does nothing for unknown ids or for tasks whose status is not one of
    ``ACTIVE_STATES``. An already-set ``started_at`` is left untouched.
    """
    with self._lock:
        task = self._tasks.get(tid)
        if not task or task.status not in ACTIVE_STATES:
            return
        task.status = "running"
        if not task.started_at:
            task.started_at = time.time()
def update(self, tid: str, **fields) -> None:
    """Set one or more fields on a registered task.

    Unknown task ids and unknown field names are ignored. For dataclass
    tasks only declared fields can be written, so a stray keyword cannot
    replace a method such as ``to_dict`` or a dunder attribute. ``id`` is
    never changed here because it must stay equal to the registry key.

    Args:
        tid: id of the task to modify.
        **fields: field name -> new value.
    """
    # Aliased: the ``fields`` keyword-arguments dict shadows dataclasses.fields.
    from dataclasses import fields as _dataclass_fields, is_dataclass

    with self._lock:
        t = self._tasks.get(tid)
        if not t:
            return
        # None means "not a dataclass": fall back to the attribute check.
        allowed = ({f.name for f in _dataclass_fields(t)}
                   if is_dataclass(t) else None)
        for k, v in fields.items():
            if k == "id":
                continue
            writable = (k in allowed) if allowed is not None else hasattr(t, k)
            if writable:
                setattr(t, k, v)
def step(self, tid: str, step: int, total: Optional[int] = None) -> None:
with self._lock:
t = self._tasks.get(tid)
if not t:
return
t.step = int(step)
if total is not None:
t.total = int(total)
def finish(self, tid: str, status: str = "done", message: str = "") -> None:
with self._lock:
t = self._tasks.get(tid)
if not t:
return
# A user cancel wins over a late 'done' from the worker unwinding.
if not (t.status == "cancelled" and status == "done"):
t.status = status
if message:
t.message = message
t.ended_at = time.time()
self._prune_locked()
def cancel(self, tid: str) -> bool:
with self._lock:
t = self._tasks.get(tid)
if not t:
return False
ev = self._events.get(tid)
if ev:
ev.set()
# Release any pause so a paused→cancelled task unblocks immediately.
pev = self._pause_events.get(tid)
if pev:
pev.set() # wakes wait_if_paused; its is_cancelled check then raises
t.paused = False
was = t.status
if was in ACTIVE_STATES:
t.status = "cancelled"
t.message = "cancelled"
# A queued task never entered a worker, so it's terminal now;
# a running task finalises when its worker observes the flag.
if was == "queued":
t.ended_at = time.time()
return True
def is_cancelled(self, tid: Optional[str]) -> bool:
if not tid:
return False
ev = self._events.get(tid)
return bool(ev and ev.is_set())
def raise_if_cancelled(self, tid: Optional[str]) -> None:
if self.is_cancelled(tid):
raise TaskCancelled(tid)
# --- Pause / resume (cooperative, at the next step boundary) -------------
def pause(self, tid: str) -> bool:
with self._lock:
t = self._tasks.get(tid)
ev = self._pause_events.get(tid)
if not t or ev is None or t.status not in ACTIVE_STATES:
return False
ev.set()
t.paused = True
return True
def resume(self, tid: str) -> bool:
with self._lock:
t = self._tasks.get(tid)
ev = self._pause_events.get(tid)
if not t or ev is None:
return False
ev.clear()
t.paused = False
return True
def is_paused(self, tid: Optional[str]) -> bool:
if not tid:
return False
ev = self._pause_events.get(tid)
return bool(ev and ev.is_set())
def wait_if_paused(self, tid: Optional[str], poll: float = 0.2) -> None:
"""Block while the task is paused, returning when it is resumed.
Stays responsive to cancellation: a paused task that is then cancelled
raises :class:`TaskCancelled` instead of hanging. Safe to call from a hot
step loop — a no-op unless the task is actually paused."""
ev = self._pause_events.get(tid) if tid else None
if ev is None or not ev.is_set():
return
while ev.is_set():
if self.is_cancelled(tid):
raise TaskCancelled(tid)
ev.wait(timeout=poll)
def remove(self, tid: str) -> bool:
"""Drop a task from the registry entirely (used to dismiss a finished/
cancelled task from the view). No-op on a missing id."""
with self._lock:
self._events.pop(tid, None)
self._pause_events.pop(tid, None)
return self._tasks.pop(tid, None) is not None
def get(self, tid: str) -> Optional[dict]:
with self._lock:
t = self._tasks.get(tid)
return t.to_dict() if t else None
def list(self) -> List[dict]:
with self._lock:
return [t.to_dict() for t in sorted(
self._tasks.values(), key=lambda x: x.created_at, reverse=True)]
def _prune_locked(self) -> None:
"""Keep all active tasks + the most recent ``history`` terminal ones."""
terminal = [t for t in self._tasks.values() if t.status in TERMINAL_STATES]
if len(terminal) <= self._history:
return
terminal.sort(key=lambda x: x.ended_at or x.created_at)
for t in terminal[:-self._history]:
self._tasks.pop(t.id, None)
self._events.pop(t.id, None)
self._pause_events.pop(t.id, None)
# Process-wide singleton. The module-level helpers ``raise_if_cancelled`` and
# ``wait_if_paused`` below delegate to this instance; it is created with the
# default history size.
task_registry = TaskRegistry()
def raise_if_cancelled(task_id: Optional[str]) -> None:
    """Module-level shortcut for ``task_registry.raise_if_cancelled``.

    Meant to be dropped into hot step loops, in the same way as
    ``thermal.checkpoint()``. Raises :class:`TaskCancelled` once ``task_id``
    has been cancelled. A falsy ``task_id`` does nothing, so call sites need
    no guard of their own.
    """
    return task_registry.raise_if_cancelled(task_id)
def wait_if_paused(task_id: Optional[str]) -> None:
    """Module-level shortcut for ``task_registry.wait_if_paused``.

    Call at a step boundary: it blocks for as long as ``task_id`` is paused
    and returns straight away when it is not. If the task is cancelled while
    paused, :class:`TaskCancelled` is raised. A falsy ``task_id`` does
    nothing.
    """
    return task_registry.wait_if_paused(task_id)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment