Add task management, quantization, and hardware telemetry

Tasks / queue management:
- Central in-memory task registry with cooperative cancel, pause/resume,
  and step progress across image/video/audio/text generation + LoRA training
- Tasks admin page (live 2s poll): cancel, interrupt, pause/resume, restart,
  remove; done jobs auto-drop from the list; bounded persisted job history
- Disable interrupted-training recovery via --no-resume-jobs + settings toggle

Quantization / acceleration:
- TurboQuant embedding vector quantization (data-free, inner-product
  preserving): built-in NumPy backend + optional turboquant-py library,
  selectable per embedding model; /v1/embeddings `quantization` param
- llama.cpp KV-cache quantization (cache_type_k/v) for GGUF text models,
  configurable in the Models UI

Hardware telemetry:
- Thermal cooldown state surfaced on the Tasks page (banner + per-task badge)
- Live CPU/GPU/RAM/VRAM usage + temperature panel via /admin/api/system-stats

Docs: API documentation gaps/accuracy pass + Swagger overhaul; DISTRIBUTION.md
implementation spec. Plus I2V LoRA training channel-mismatch fix.
Co-Authored-By: 's avatarClaude Opus 4.8 <noreply@anthropic.com>
parent 9494d1bd
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
<a href="{{ root_path }}/docs" class="nav-link" target="_blank">API Docs</a> <a href="{{ root_path }}/docs" class="nav-link" target="_blank">API Docs</a>
{% if is_admin|default(false) %} {% if is_admin|default(false) %}
<a href="{{ root_path }}/admin/models" class="nav-link {% if '/models' in request.url.path %}active{% endif %}">Models</a> <a href="{{ root_path }}/admin/models" class="nav-link {% if '/models' in request.url.path %}active{% endif %}">Models</a>
<a href="{{ root_path }}/admin/tasks" class="nav-link {% if '/tasks' in request.url.path %}active{% endif %}">Tasks</a>
<a href="{{ root_path }}/admin/tokens" class="nav-link {% if '/tokens' in request.url.path %}active{% endif %}">Tokens</a> <a href="{{ root_path }}/admin/tokens" class="nav-link {% if '/tokens' in request.url.path %}active{% endif %}">Tokens</a>
<a href="{{ root_path }}/admin/users" class="nav-link {% if '/users' in request.url.path %}active{% endif %}">Users</a> <a href="{{ root_path }}/admin/users" class="nav-link {% if '/users' in request.url.path %}active{% endif %}">Users</a>
<a href="{{ root_path }}/admin/archive" class="nav-link {% if '/archive' in request.url.path %}active{% endif %}">Archive</a> <a href="{{ root_path }}/admin/archive" class="nav-link {% if '/archive' in request.url.path %}active{% endif %}">Archive</a>
......
...@@ -616,6 +616,28 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson } ...@@ -616,6 +616,28 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<label class="form-label">Context size</label> <label class="form-label">Context size</label>
<input type="number" id="cfg-n-ctx" class="form-input" min="128" step="128" value="2048"> <input type="number" id="cfg-n-ctx" class="form-input" min="128" step="128" value="2048">
</div> </div>
<div class="form-row" style="margin:0" id="cfg-kv-k-row">
<label class="form-label">KV cache — Keys <span class="muted">(GGUF text; shrinks KV VRAM)</span></label>
<select id="cfg-cache-type-k" class="form-input">
<option value="">Default (f16)</option>
<option value="q8_0">q8_0 (near-lossless, ~2×)</option>
<option value="q5_1">q5_1</option>
<option value="q5_0">q5_0</option>
<option value="q4_1">q4_1 (smallest)</option>
<option value="q4_0">q4_0 (smallest)</option>
</select>
</div>
<div class="form-row" style="margin:0" id="cfg-kv-v-row">
<label class="form-label">KV cache — Values <span class="muted">(sub-8-bit needs Flash Attn)</span></label>
<select id="cfg-cache-type-v" class="form-input">
<option value="">Default (f16)</option>
<option value="q8_0">q8_0 (near-lossless, ~2×)</option>
<option value="q5_1">q5_1</option>
<option value="q5_0">q5_0</option>
<option value="q4_1">q4_1 (smallest)</option>
<option value="q4_0">q4_0 (smallest)</option>
</select>
</div>
<div class="form-row" style="margin:0"> <div class="form-row" style="margin:0">
<label class="form-label">Max GPU % <span class="muted">(optional)</span></label> <label class="form-label">Max GPU % <span class="muted">(optional)</span></label>
<input type="number" id="cfg-max-gpu" class="form-input" min="1" max="100" placeholder="e.g. 90"> <input type="number" id="cfg-max-gpu" class="form-input" min="1" max="100" placeholder="e.g. 90">
...@@ -732,6 +754,36 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson } ...@@ -732,6 +754,36 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
</div> </div>
</div> </div>
<!-- TurboQuant embedding vector quantization (embedding models) -->
<div id="cfg-turboquant-section" style="display:none">
<div class="card-title" style="margin-top:1.25rem">TurboQuant <span class="muted" style="font-weight:normal">(embedding vector quantization — data-free, inner-product preserving)</span></div>
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px;margin:.4rem 0">
<input type="checkbox" id="cfg-tq-enabled" onchange="onTurboQuantToggle()"> Enable TurboQuant
<span class="muted">compress embeddings to 2–8 bits/coord for smaller vector stores</span></label>
<div id="cfg-tq-fields" style="display:none">
<div style="display:flex;gap:1rem;flex-wrap:wrap">
<div class="form-row" style="max-width:260px">
<label class="form-label">Backend</label>
<select id="cfg-tq-backend" class="form-input">
<option value="builtin">Built-in (NumPy, always available)</option>
<option value="library">turboquant-py library (QJL)</option>
</select>
<span class="form-hint" id="cfg-tq-backend-hint"></span>
</div>
<div class="form-row" style="max-width:200px">
<label class="form-label">Default bits</label>
<select id="cfg-tq-bits" class="form-input">
<option value="8">8-bit (near-lossless, 3×)</option>
<option value="6">6-bit</option>
<option value="4">4-bit (6×)</option>
<option value="2">2-bit (12×)</option>
</select>
</div>
</div>
<span class="form-hint">A request's <code>quantization</code> field (e.g. <code>turbo4</code>) overrides the default bits. With <code>encoding_format:"base64"</code> the response is compact packed bytes; otherwise it returns the lossy reconstruction as floats.</span>
</div>
</div>
<!-- components --> <!-- components -->
<div class="card-title" style="margin-top:1.25rem">Components</div> <div class="card-title" style="margin-top:1.25rem">Components</div>
<div class="form-row"> <div class="form-row">
...@@ -2268,9 +2320,9 @@ async function refreshLocal(){ ...@@ -2268,9 +2320,9 @@ async function refreshLocal(){
loadGlobalSettings(); loadGlobalSettings();
refreshLocal(); refreshLocal();
// Toggle the acceleration section as image/video model types are checked/unchecked. // Toggle the acceleration / TurboQuant sections as model types are checked/unchecked.
document.querySelectorAll('.cfg-type-cb').forEach(cb => document.querySelectorAll('.cfg-type-cb').forEach(cb =>
cb.addEventListener('change', () => _refreshAccelVisibility())); cb.addEventListener('change', () => { _refreshAccelVisibility(); _refreshTurboQuantVisibility(); }));
// ── Deep-link from Studio: /admin/models?tab=search&q=...&pipeline=...&gguf=... // ── Deep-link from Studio: /admin/models?tab=search&q=...&pipeline=...&gguf=...
// ── or: /admin/models?local_cap=CAPABILITY — highlight local models with that capability // ── or: /admin/models?local_cap=CAPABILITY — highlight local models with that capability
...@@ -2633,6 +2685,8 @@ function openCfgModal(idx, cfgIdx){ ...@@ -2633,6 +2685,8 @@ function openCfgModal(idx, cfgIdx){
document.getElementById('cfg-force-vram-update').checked = !!s.force_vram_update; document.getElementById('cfg-force-vram-update').checked = !!s.force_vram_update;
document.getElementById('cfg-gpu-layers').value = s.n_gpu_layers !== undefined ? s.n_gpu_layers : -1; document.getElementById('cfg-gpu-layers').value = s.n_gpu_layers !== undefined ? s.n_gpu_layers : -1;
document.getElementById('cfg-n-ctx').value = nCtxForEst; document.getElementById('cfg-n-ctx').value = nCtxForEst;
document.getElementById('cfg-cache-type-k').value = s.cache_type_k || '';
document.getElementById('cfg-cache-type-v').value = s.cache_type_v || '';
document.getElementById('cfg-max-instances').value = s.max_instances != null ? s.max_instances : 1; document.getElementById('cfg-max-instances').value = s.max_instances != null ? s.max_instances : 1;
document.getElementById('cfg-preload-all-instances').checked = !!s.preload_all_instances; document.getElementById('cfg-preload-all-instances').checked = !!s.preload_all_instances;
_updatePreloadAllVisibility(); _updatePreloadAllVisibility();
...@@ -2678,6 +2732,7 @@ function openCfgModal(idx, cfgIdx){ ...@@ -2678,6 +2732,7 @@ function openCfgModal(idx, cfgIdx){
document.getElementById('cfg-lora-dir').value = s.lora_model_dir || ''; document.getElementById('cfg-lora-dir').value = s.lora_model_dir || '';
document.getElementById('cfg-lora-train-base').value = s.lora_train_base_model || ''; document.getElementById('cfg-lora-train-base').value = s.lora_train_base_model || '';
_populateAccel(s.acceleration); _populateAccel(s.acceleration);
_populateTurboQuant(s.turboquant);
openModal('cfg-modal'); openModal('cfg-modal');
} }
...@@ -2768,6 +2823,56 @@ function _collectAccel(){ ...@@ -2768,6 +2823,56 @@ function _collectAccel(){
}; };
} }
// ---- TurboQuant (embedding vector quantization) ----
let _tqInfo = null;
async function _loadTurboQuantInfo(){
if (_tqInfo) return _tqInfo;
try {
const r = await fetch(ROOT_PATH + '/admin/api/turboquant-info');
_tqInfo = await r.json();
} catch(e){ _tqInfo = {builtin:true, library:false}; }
return _tqInfo;
}
function _turboQuantApplies(){
return [...document.querySelectorAll('.cfg-type-cb:checked')]
.some(cb => cb.value === 'embedding_models');
}
function _refreshTurboQuantVisibility(){
const section = document.getElementById('cfg-turboquant-section');
if (section) section.style.display = _turboQuantApplies() ? '' : 'none';
}
function onTurboQuantToggle(){
const on = document.getElementById('cfg-tq-enabled').checked;
document.getElementById('cfg-tq-fields').style.display = on ? '' : 'none';
}
async function _populateTurboQuant(t){
await _loadTurboQuantInfo();
_refreshTurboQuantVisibility();
// Reflect library availability in the backend dropdown + hint.
const libOpt = document.querySelector('#cfg-tq-backend option[value="library"]');
const hint = document.getElementById('cfg-tq-backend-hint');
const libOk = !!(_tqInfo && _tqInfo.library);
if (libOpt){ libOpt.disabled = !libOk; libOpt.textContent =
'turboquant-py library (QJL)' + (libOk ? '' : ' — not installed'); }
if (hint) hint.textContent = libOk
? 'turboquant-py detected.'
: 'Install "turboquant-py[torch]" to enable the library backend.';
t = t || {};
document.getElementById('cfg-tq-enabled').checked = !!t.enabled;
document.getElementById('cfg-tq-backend').value =
(t.backend === 'library' && libOk) ? 'library' : 'builtin';
document.getElementById('cfg-tq-bits').value = String(t.bits || 8);
onTurboQuantToggle();
}
function _collectTurboQuant(){
if (!document.getElementById('cfg-tq-enabled').checked) return null;
return {
enabled: true,
backend: document.getElementById('cfg-tq-backend').value || 'builtin',
bits: parseInt(document.getElementById('cfg-tq-bits').value) || 8,
};
}
function _updatePreloadAllVisibility() { function _updatePreloadAllVisibility() {
const loadMode = document.getElementById('cfg-load-mode').value; const loadMode = document.getElementById('cfg-load-mode').value;
const maxInst = parseInt(document.getElementById('cfg-max-instances').value) || 1; const maxInst = parseInt(document.getElementById('cfg-max-instances').value) || 1;
...@@ -2834,6 +2939,8 @@ async function saveModelConfig(){ ...@@ -2834,6 +2939,8 @@ async function saveModelConfig(){
preload_all_instances: document.getElementById('cfg-preload-all-instances').checked, preload_all_instances: document.getElementById('cfg-preload-all-instances').checked,
n_gpu_layers: parseInt(document.getElementById('cfg-gpu-layers').value) || -1, n_gpu_layers: parseInt(document.getElementById('cfg-gpu-layers').value) || -1,
n_ctx: parseInt(document.getElementById('cfg-n-ctx').value) || 2048, n_ctx: parseInt(document.getElementById('cfg-n-ctx').value) || 2048,
cache_type_k: document.getElementById('cfg-cache-type-k').value || null,
cache_type_v: document.getElementById('cfg-cache-type-v').value || null,
max_gpu_percent: isNaN(maxGpu) ? null : maxGpu, max_gpu_percent: isNaN(maxGpu) ? null : maxGpu,
manual_ram_gb: isNaN(ramGb) ? null : ramGb, manual_ram_gb: isNaN(ramGb) ? null : ramGb,
load_in_4bit: document.getElementById('cfg-4bit').checked, load_in_4bit: document.getElementById('cfg-4bit').checked,
...@@ -2866,6 +2973,7 @@ async function saveModelConfig(){ ...@@ -2866,6 +2973,7 @@ async function saveModelConfig(){
balanced_gpu_percent: (document.getElementById('cfg-balanced-gpu-pct').value.trim() === '' balanced_gpu_percent: (document.getElementById('cfg-balanced-gpu-pct').value.trim() === ''
? null : parseFloat(document.getElementById('cfg-balanced-gpu-pct').value)), ? null : parseFloat(document.getElementById('cfg-balanced-gpu-pct').value)),
acceleration: _collectAccel(), acceleration: _collectAccel(),
turboquant: _collectTurboQuant(),
}; };
try{ try{
const r = await fetch(ROOT_PATH + '/admin/api/model-configure',{ const r = await fetch(ROOT_PATH + '/admin/api/model-configure',{
......
...@@ -153,6 +153,24 @@ ...@@ -153,6 +153,24 @@
</div> </div>
</div> </div>
<!-- Background jobs -->
<div class="card mb-0" style="margin-top:1rem">
<div class="card-title">Background Jobs</div>
<span class="form-hint" style="display:block;margin-bottom:.75rem">
Controls how interrupted LoRA training is handled when CoderAI restarts.
Equivalent to the <code>--no-resume-jobs</code> launch flag.
</span>
<div class="form-row" style="margin:0">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer">
<input type="checkbox" id="s-jobs-resume">
<span style="font-size:13px;font-weight:500">Resume interrupted training on restart</span>
</label>
<span class="form-hint">When off, a training job that was running at restart is marked
<em>cancelled</em> instead of resuming. Its checkpoint is kept, so you can still
restart it manually from the Tasks page.</span>
</div>
</div>
<div class="card mb-0" style="margin-top:1rem"> <div class="card mb-0" style="margin-top:1rem">
<div class="card-title">AISBF Broker</div> <div class="card-title">AISBF Broker</div>
<div class="form-row"> <div class="form-row">
...@@ -328,6 +346,9 @@ async function loadSettings(){ ...@@ -328,6 +346,9 @@ async function loadSettings(){
document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87; document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87;
document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5; document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5;
toggleThermalFields(); toggleThermalFields();
// Background jobs
const jobs = d.jobs || {};
document.getElementById('s-jobs-resume').checked = jobs.resume_on_restart !== false;
}catch(e){ showAlert('error','Failed to load settings: '+e.message); } }catch(e){ showAlert('error','Failed to load settings: '+e.message); }
} }
...@@ -363,6 +384,9 @@ async function saveSettings(){ ...@@ -363,6 +384,9 @@ async function saveSettings(){
cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87, cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87,
poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5, poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5,
}, },
jobs:{
resume_on_restart: document.getElementById('s-jobs-resume').checked,
},
broker:{ broker:{
enabled: document.getElementById('s-broker-enabled').checked, enabled: document.getElementById('s-broker-enabled').checked,
base_url: document.getElementById('s-broker-base-url').value.trim(), base_url: document.getElementById('s-broker-base-url').value.trim(),
......
{% extends "base.html" %}
{% block title %}Tasks — CoderAI{% endblock %}
{% block content %}
<div class="page-header">
<div>
<h1>Tasks</h1>
<p>Live view of generations and LoRA training. Cancel, interrupt, or restart a job.</p>
</div>
<div class="header-actions">
<span id="queue-summary" class="dim small"></span>
</div>
</div>
<div id="thermal-banner" style="display:none;margin:0 0 1rem;padding:.6rem .85rem;border-radius:8px;
background:rgba(245,158,11,.12);border:1px solid rgba(245,158,11,.4);color:#f59e0b;font-size:13px">
<span style="font-weight:600">❄ Thermal cooldown</span>
<span id="thermal-banner-msg" class="mono"></span>
— running work is paused until the hardware cools.
</div>
<!-- Live hardware telemetry -->
<div id="sys-stats" style="display:grid;grid-template-columns:repeat(auto-fit,minmax(220px,1fr));
gap:.75rem;margin:0 0 1.25rem">
<div class="sys-tile" id="tile-cpu"></div>
<div class="sys-tile" id="tile-gpu"></div>
<div class="sys-tile" id="tile-ram"></div>
<div class="sys-tile" id="tile-vram"></div>
</div>
<style>
.sys-tile{border:1px solid var(--border,#2a2a2a);border-radius:10px;padding:.7rem .85rem;
background:var(--card-bg,rgba(255,255,255,.02))}
.sys-tile .sys-head{display:flex;justify-content:space-between;align-items:baseline;margin-bottom:.45rem}
.sys-tile .sys-name{font-size:12px;font-weight:600;letter-spacing:.03em;text-transform:uppercase;color:var(--text-muted,#9aa0a6)}
.sys-tile .sys-val{font-size:13px;font-weight:600}
.sys-tile .sys-sub{font-size:11px;color:var(--text-muted,#9aa0a6);margin-top:.35rem;display:flex;justify-content:space-between}
.sys-bar{height:8px;border-radius:5px;background:rgba(127,127,127,.18);overflow:hidden}
.sys-bar > span{display:block;height:100%;border-radius:5px;transition:width .4s ease,background .4s ease}
.sys-ok > span{background:#22c55e}.sys-warn > span{background:#f59e0b}.sys-hot > span{background:#ef4444}
.sys-temp-ok{color:#22c55e}.sys-temp-warn{color:#f59e0b}.sys-temp-hot{color:#ef4444}
</style>
<div class="table-wrap">
<table>
<thead>
<tr>
<th>Type</th><th>Name / Model</th><th>Status</th>
<th style="width:220px">Progress</th><th>Started</th><th style="text-align:right">Actions</th>
</tr>
</thead>
<tbody id="tasks-body">
<tr class="empty-row"><td colspan="6">No tasks yet</td></tr>
</tbody>
</table>
</div>
{% endblock %}
{% block scripts %}
<script>
function esc(s) { return String(s == null ? '' : s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;'); }
function fmtTime(s) {
if (!s) return '';
try {
// started_at is unix seconds (float) from the server.
const d = new Date(s * 1000);
return d.toLocaleTimeString(undefined, {hour:'2-digit', minute:'2-digit', second:'2-digit'});
} catch { return ''; }
}
const KIND_LABEL = {training:'Training', image:'Image', video:'Video', audio:'Audio', text:'Text', pipeline:'Pipeline', request:'Request'};
const STATUS_BADGE = {
running:'badge-admin', queued:'badge-user', done:'badge-ok', error:'badge-err',
cancelled:'badge-user', interrupted:'badge-warn'
};
function progressBar(t) {
const total = t.total || 0, step = t.step || 0;
if (!total) {
return t.status === 'running' ? '<span class="dim small">working…</span>' : '<span class="dim small">—</span>';
}
const pct = Math.max(0, Math.min(100, Math.round(step / total * 100)));
return `<div class="progress"><div class="progress-fill" style="width:${pct}%"></div></div>
<span class="dim small">${step}/${total} (${pct}%)</span>`;
}
function actions(t) {
const btns = [];
if (t.paused) {
btns.push(`<button class="btn btn-primary btn-sm" onclick="taskAction('${esc(t.id)}','resume')">Resume</button>`);
} else if (t.pausable) {
btns.push(`<button class="btn btn-ghost btn-sm" onclick="taskAction('${esc(t.id)}','pause')">Pause</button>`);
}
if (t.cancellable) {
const label = t.status === 'running' ? 'Interrupt' : 'Cancel';
const act = t.status === 'running' ? 'interrupt' : 'cancel';
btns.push(`<button class="btn btn-danger btn-sm" onclick="taskAction('${esc(t.id)}','${act}')">${label}</button>`);
}
if (t.restartable) {
btns.push(`<button class="btn btn-ghost btn-sm" onclick="taskAction('${esc(t.id)}','restart')">Restart</button>`);
}
if (!t.active) {
btns.push(`<button class="btn btn-ghost btn-sm" onclick="removeTask('${esc(t.id)}')">Remove</button>`);
}
return btns.join(' ') || '<span class="dim small">—</span>';
}
// ---- Live hardware telemetry ----
function _utilClass(pct){ return pct == null ? 'sys-ok' : (pct >= 90 ? 'sys-hot' : pct >= 70 ? 'sys-warn' : 'sys-ok'); }
function _tempClass(t){ return t == null ? '' : (t >= 90 ? 'sys-temp-hot' : t >= 80 ? 'sys-temp-warn' : 'sys-temp-ok'); }
function _bar(pct){
const p = pct == null ? 0 : Math.max(0, Math.min(100, pct));
return `<div class="sys-bar ${_utilClass(pct)}"><span style="width:${p}%"></span></div>`;
}
function _utilTile(name, pct, temp){
const valTxt = pct == null ? 'n/a' : `${Math.round(pct)}%`;
const tempTxt = temp == null ? '<span class="dim">temp n/a</span>'
: `<span class="${_tempClass(temp)}">${Math.round(temp)}°C</span>`;
return `<div class="sys-head"><span class="sys-name">${name}</span><span class="sys-val">${valTxt}</span></div>`
+ _bar(pct) + `<div class="sys-sub"><span>utilization</span>${tempTxt}</div>`;
}
function _memTile(name, used, total, pct){
const valTxt = (used == null || total == null) ? 'n/a' : `${used.toFixed(1)} / ${total.toFixed(1)} GB`;
const p = pct != null ? pct : (used != null && total ? used/total*100 : null);
return `<div class="sys-head"><span class="sys-name">${name}</span><span class="sys-val">${valTxt}</span></div>`
+ _bar(p) + `<div class="sys-sub"><span>${p == null ? '' : Math.round(p)+'% used'}</span><span></span></div>`;
}
async function loadSystemStats(){
try {
const s = await fetch(ROOT_PATH + '/admin/api/system-stats').then(r => r.json());
const cpu = s.cpu || {}, gpu = s.gpu || {}, ram = s.ram || {}, vram = s.vram || {};
document.getElementById('tile-cpu').innerHTML = _utilTile('CPU', cpu.util, cpu.temp);
document.getElementById('tile-gpu').innerHTML = _utilTile('GPU', gpu.util, gpu.temp);
document.getElementById('tile-ram').innerHTML = _memTile('RAM', ram.used, ram.total, ram.percent);
document.getElementById('tile-vram').innerHTML =
_memTile('VRAM', vram.used, vram.total, vram.percent);
} catch(e){ /* keep last render on transient errors */ }
}
let _refreshing = false;
async function loadTasks() {
if (_refreshing) return;
_refreshing = true;
try {
const data = await fetch(ROOT_PATH + '/admin/api/tasks').then(r => r.json());
const tasks = data.tasks || [];
const q = data.queue || {};
document.getElementById('queue-summary').textContent =
`${q.active || 0} active · ${q.waiting || 0} waiting · max ${q.max_parallel_requests || 0} parallel`;
const therm = data.thermal || {};
const banner = document.getElementById('thermal-banner');
if (therm.active) {
document.getElementById('thermal-banner-msg').textContent = ' ' + (therm.message || '');
banner.style.display = '';
} else {
banner.style.display = 'none';
}
const tbody = document.getElementById('tasks-body');
if (!tasks.length) {
tbody.innerHTML = '<tr class="empty-row"><td colspan="6">No tasks yet</td></tr>';
return;
}
tbody.innerHTML = tasks.map(t => {
const badge = STATUS_BADGE[t.status] || 'badge-dim';
const title = t.title || '(untitled)';
let statusCell;
if (t.cooling) {
statusCell = `<span class="badge badge-warn">❄ Cooling down</span>`
+ `<div class="dim small">${esc(t.cooling_message || 'paused for thermal cooldown')}</div>`;
} else if (t.paused) {
statusCell = `<span class="badge badge-warn">⏸ Paused</span>`
+ `<div class="dim small">suspended — click Resume to continue</div>`;
} else {
statusCell = `<span class="badge ${badge}">${esc(t.status)}</span>`
+ (t.message ? `<div class="dim small">${esc(t.message)}</div>` : '');
}
return `<tr>
<td><span class="badge badge-user">${esc(KIND_LABEL[t.kind] || t.kind)}</span></td>
<td><div class="td-name">${esc(title)}</div><div class="dim small mono">${esc(t.model || '')}</div></td>
<td>${statusCell}</td>
<td>${progressBar(t)}</td>
<td class="dim small">${fmtTime(t.started_at)}</td>
<td style="text-align:right">${actions(t)}</td>
</tr>`;
}).join('');
} catch (e) {
// transient fetch errors during a model swap are fine; keep last render.
} finally {
_refreshing = false;
}
}
async function taskAction(id, action) {
const verb = {cancel:'Cancel', interrupt:'Interrupt', restart:'Restart', pause:'Pause', resume:'Resume'}[action] || action;
// Only confirm destructive actions; pause/resume/restart act immediately.
if ((action === 'cancel' || action === 'interrupt') && !confirm(`${verb} this task?`)) return;
try {
const r = await fetch(ROOT_PATH + '/admin/api/tasks/' + encodeURIComponent(id) + '/' + action, {method:'POST'});
if (!r.ok) {
const e = await r.json().catch(() => ({}));
alert(e.detail || (verb + ' failed'));
}
} catch (e) { alert(e.message); }
loadTasks();
}
async function removeTask(id) {
try {
const r = await fetch(ROOT_PATH + '/admin/api/tasks/' + encodeURIComponent(id), {method:'DELETE'});
if (!r.ok) {
const e = await r.json().catch(() => ({}));
alert(e.detail || 'Remove failed');
}
} catch (e) { alert(e.message); }
loadTasks();
}
loadTasks();
loadSystemStats();
setInterval(loadTasks, 2000);
setInterval(loadSystemStats, 2000);
</script>
{% endblock %}
...@@ -189,25 +189,25 @@ if admin_static_dir.exists(): ...@@ -189,25 +189,25 @@ if admin_static_dir.exists():
app.mount("/static/admin", StaticFiles(directory=str(admin_static_dir)), name="admin_static") app.mount("/static/admin", StaticFiles(directory=str(admin_static_dir)), name="admin_static")
# Include routers from submodules # Include routers from submodules
app.include_router(transcriptions_router) app.include_router(transcriptions_router, tags=["Audio"])
app.include_router(images_router) app.include_router(images_router, tags=["Images"])
app.include_router(tts_router) app.include_router(tts_router, tags=["Audio"])
app.include_router(text_router) app.include_router(text_router, tags=["Text"])
app.include_router(video_router) app.include_router(video_router, tags=["Video"])
app.include_router(audio_gen_router) app.include_router(audio_gen_router, tags=["Audio"])
app.include_router(audio_stems_router) app.include_router(audio_stems_router, tags=["Audio"])
app.include_router(audio_clean_router) app.include_router(audio_clean_router, tags=["Audio"])
app.include_router(embeddings_router) app.include_router(embeddings_router, tags=["Embeddings"])
app.include_router(pipelines_router) app.include_router(pipelines_router, tags=["Pipelines"])
app.include_router(custom_pipelines_router) app.include_router(custom_pipelines_router, tags=["Pipelines"])
app.include_router(voice_clone_router) app.include_router(voice_clone_router, tags=["Audio"])
app.include_router(voice_convert_router) app.include_router(voice_convert_router, tags=["Audio"])
app.include_router(faceswap_router) app.include_router(faceswap_router, tags=["Images"])
app.include_router(characters_router) app.include_router(characters_router, tags=["Characters"])
app.include_router(loras_router) app.include_router(loras_router, tags=["LoRAs"])
app.include_router(environments_router) app.include_router(environments_router, tags=["Environments"])
app.include_router(spatial_router) app.include_router(spatial_router, tags=["Spatial / 3D"])
app.include_router(admin_router) app.include_router(admin_router, tags=["Admin"])
@app.exception_handler(401) @app.exception_handler(401)
...@@ -222,20 +222,35 @@ async def unauthorized_redirect(request: Request, exc: HTTPException): ...@@ -222,20 +222,35 @@ async def unauthorized_redirect(request: Request, exc: HTTPException):
return JSONResponse(status_code=401, content={"detail": exc.detail}) return JSONResponse(status_code=401, content={"detail": exc.detail})
@app.get("/v1/models", response_model=ModelList) from codai.tasks import TaskCancelled, task_registry
@app.exception_handler(TaskCancelled)
async def task_cancelled_handler(request: Request, exc: TaskCancelled):
"""A worker observed its task was cancelled and unwound. Finish the task
(cancelled) and return 499 (client-closed-request style). The task id is
carried on the exception so any generation/training worker can simply
`raise` without bookkeeping."""
tid = exc.args[0] if exc.args else None
if tid:
task_registry.finish(tid, "cancelled", "cancelled by user")
return JSONResponse(status_code=499, content={"detail": "Task cancelled", "task_id": tid})
@app.get("/v1/models", response_model=ModelList, summary="List available models", tags=["Core"])
async def list_models(): async def list_models():
"""List available models.""" """List available models."""
models = multi_model_manager.list_models() models = multi_model_manager.list_models()
return ModelList(data=models) return ModelList(data=models)
@app.get("/coderai/capabilities") @app.get("/coderai/capabilities", summary="Server capability document", tags=["Core"])
async def get_broker_capabilities(): async def get_broker_capabilities():
"""Return broker capability metadata.""" """Return broker capability metadata."""
return build_capabilities_document(hardware=build_hardware_summary()) return build_capabilities_document(hardware=build_hardware_summary())
@app.get("/v1/files/{filename}") @app.get("/v1/files/{filename}", summary="Download a generated file", tags=["Files"])
async def get_file(filename: str): async def get_file(filename: str):
"""Serve uploaded/generated files.""" """Serve uploaded/generated files."""
if not global_file_path: if not global_file_path:
...@@ -256,7 +271,7 @@ _VIDEO_EXTS = {'.mp4', '.webm', '.avi', '.mov'} ...@@ -256,7 +271,7 @@ _VIDEO_EXTS = {'.mp4', '.webm', '.avi', '.mov'}
_AUDIO_EXTS = {'.wav', '.mp3', '.ogg', '.flac', '.aac', '.m4a'} _AUDIO_EXTS = {'.wav', '.mp3', '.ogg', '.flac', '.aac', '.m4a'}
@app.get("/v1/archive") @app.get("/v1/archive", summary="List archived generations", tags=["Files"])
async def list_archive(request: Request): async def list_archive(request: Request):
"""List all generated files in the output directory.""" """List all generated files in the output directory."""
if not global_file_path or not os.path.isdir(global_file_path): if not global_file_path or not os.path.isdir(global_file_path):
...@@ -292,7 +307,7 @@ async def list_archive(request: Request): ...@@ -292,7 +307,7 @@ async def list_archive(request: Request):
return {"files": files} return {"files": files}
@app.delete("/v1/archive/{filename}") @app.delete("/v1/archive/{filename}", summary="Delete an archived file", tags=["Files"])
async def delete_archive_file(filename: str): async def delete_archive_file(filename: str):
"""Delete a generated file from the output directory.""" """Delete a generated file from the output directory."""
if not global_file_path: if not global_file_path:
......
...@@ -116,8 +116,15 @@ class AudioCleanupRequest(BaseModel): ...@@ -116,8 +116,15 @@ class AudioCleanupRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/cleanup") @router.post("/v1/audio/cleanup", summary="Clean / restore audio")
async def cleanup_audio(request: AudioCleanupRequest, http_request: Request = None): async def cleanup_audio(request: AudioCleanupRequest, http_request: Request = None):
"""Restore/clean a noisy audio clip.
Applies any combination of noise reduction, loudness normalization, mains-hum
removal and click/crackle repair. Uses an ML restoration backend when available,
falling back to an ffmpeg-based best-effort path when `fallback_mode` is set.
Returns the cleaned audio plus the backend and quality tier that were used.
"""
try: try:
audio_bytes = _decode_audio(request.audio) audio_bytes = _decode_audio(request.audio)
except Exception as exc: except Exception as exc:
......
...@@ -31,6 +31,7 @@ from fastapi import APIRouter, HTTPException, Request ...@@ -31,6 +31,7 @@ from fastapi import APIRouter, HTTPException, Request
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
from codai.pydantic.audiogenrequest import AudioGenerationRequest, AudioGenerationResponse from codai.pydantic.audiogenrequest import AudioGenerationRequest, AudioGenerationResponse
from codai.tasks import task_registry, TaskCancelled
router = APIRouter() router = APIRouter()
...@@ -160,7 +161,7 @@ def _detect_audio_gen_type(model_name: str) -> str: ...@@ -160,7 +161,7 @@ def _detect_audio_gen_type(model_name: str) -> str:
return 'musicgen' return 'musicgen'
def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest): def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest, task_id=None):
"""Run generation and return (audio_bytes, ext).""" """Run generation and return (audio_bytes, ext)."""
import numpy as np, io as _io import numpy as np, io as _io
...@@ -191,6 +192,9 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest): ...@@ -191,6 +192,9 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
_aud_progress_reset(num_steps, unit="it") _aud_progress_reset(num_steps, unit="it")
def _aud_step_cb(pipe, step_index, timestep, callback_kwargs): def _aud_step_cb(pipe, step_index, timestep, callback_kwargs):
task_registry.raise_if_cancelled(task_id)
task_registry.wait_if_paused(task_id)
task_registry.step(task_id, step_index + 1)
_aud_progress_step(step_index + 1) _aud_progress_step(step_index + 1)
return callback_kwargs return callback_kwargs
...@@ -222,7 +226,7 @@ def _decode_b64_or_url(data: str) -> bytes: ...@@ -222,7 +226,7 @@ def _decode_b64_or_url(data: str) -> bytes:
return base64.b64decode(data) return base64.b64decode(data)
@router.get("/v1/audio/progress") @router.get("/v1/audio/progress", summary="Audio generation progress")
async def get_audio_progress(): async def get_audio_progress():
"""Return current audio generation progress including speed.""" """Return current audio generation progress including speed."""
elapsed = time.monotonic() - _aud_progress["started_at"] if _aud_progress["active"] else 0.0 elapsed = time.monotonic() - _aud_progress["started_at"] if _aud_progress["active"] else 0.0
...@@ -241,7 +245,7 @@ async def get_audio_progress(): ...@@ -241,7 +245,7 @@ async def get_audio_progress():
} }
@router.post("/v1/audio/generate", response_model=AudioGenerationResponse) @router.post("/v1/audio/generate", response_model=AudioGenerationResponse, summary="Generate audio, music or SFX")
async def audio_generate(request: AudioGenerationRequest, http_request: Request = None): async def audio_generate(request: AudioGenerationRequest, http_request: Request = None):
""" """
Generate music, sound effects, or ambient audio. Generate music, sound effects, or ambient audio.
...@@ -274,14 +278,22 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request ...@@ -274,14 +278,22 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
multi_model_manager.models[model_key] = pipe multi_model_manager.models[model_key] = pipe
multi_model_manager.current_model_key = model_key multi_model_manager.current_model_key = model_key
_tid = task_registry.register(
"audio", title=(request.prompt or "")[:80], model=model_name or "")
task_registry.start(_tid)
try: try:
audio_bytes, ext = await asyncio.get_event_loop().run_in_executor( audio_bytes, ext = await asyncio.get_event_loop().run_in_executor(
None, _generate_audio, pipe, model_name, request) None, _generate_audio, pipe, model_name, request, _tid)
except TaskCancelled:
_aud_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except Exception as e: except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_aud_progress_done() _aud_progress_done()
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
finally: finally:
_aud_progress_done() _aud_progress_done()
task_registry.finish(_tid, "done")
result = _save_audio_response(audio_bytes, ext, http_request) result = _save_audio_response(audio_bytes, ext, http_request)
......
...@@ -166,8 +166,15 @@ class AudioStemRequest(BaseModel): ...@@ -166,8 +166,15 @@ class AudioStemRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/stems") @router.post("/v1/audio/stems", summary="Separate audio into stems")
async def separate_stems(request: AudioStemRequest, http_request: Request = None): async def separate_stems(request: AudioStemRequest, http_request: Request = None):
"""Split a track into its component stems (source separation).
Separates an input clip according to `stem_mode` (e.g. vocals/instrumental, or a
full 4-stem split). Uses an ML separation provider when available, falling back to
an ffmpeg-based best-effort split when `fallback_mode` is set. Returns one audio
output per stem along with the backend and quality tier used.
"""
try: try:
audio_bytes = _decode_audio(request.audio) audio_bytes = _decode_audio(request.audio)
except Exception as exc: except Exception as exc:
......
...@@ -419,7 +419,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]: ...@@ -419,7 +419,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ───────────────────────────────────────────────────────────────── # ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/characters") @router.post("/v1/characters", summary="Create or replace a character profile")
async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)): async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named character profile.""" """Save or update a named character profile."""
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
...@@ -430,13 +430,13 @@ async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_a ...@@ -430,13 +430,13 @@ async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_a
return {"ok": True, "name": meta['name'], "image_count": meta['image_count']} return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/characters") @router.get("/v1/characters", summary="List character profiles")
async def list_characters(_auth=Depends(_require_api_auth)): async def list_characters(_auth=Depends(_require_api_auth)):
"""List all saved character profiles (metadata only, no images).""" """List all saved character profiles (metadata only, no images)."""
return {"characters": _list_characters()} return {"characters": _list_characters()}
@router.get("/v1/characters/{name}") @router.get("/v1/characters/{name}", summary="Get a character profile")
async def get_character(name: str, _auth=Depends(_require_api_auth)): async def get_character(name: str, _auth=Depends(_require_api_auth)):
"""Get a character profile including its reference images as base64.""" """Get a character profile including its reference images as base64."""
meta = _load_character_meta(name) meta = _load_character_meta(name)
...@@ -452,7 +452,7 @@ async def get_character(name: str, _auth=Depends(_require_api_auth)): ...@@ -452,7 +452,7 @@ async def get_character(name: str, _auth=Depends(_require_api_auth)):
} }
@router.delete("/v1/characters/{name}") @router.delete("/v1/characters/{name}", summary="Delete a character profile")
async def delete_character(name: str, _auth=Depends(_require_api_auth)): async def delete_character(name: str, _auth=Depends(_require_api_auth)):
"""Delete a character profile.""" """Delete a character profile."""
cdir = _char_dir(name) cdir = _char_dir(name)
...@@ -463,7 +463,7 @@ async def delete_character(name: str, _auth=Depends(_require_api_auth)): ...@@ -463,7 +463,7 @@ async def delete_character(name: str, _auth=Depends(_require_api_auth)):
return {"ok": True, "name": name} return {"ok": True, "name": name}
@router.patch("/v1/characters/{name}") @router.patch("/v1/characters/{name}", summary="Update a character profile")
async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)): async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)):
"""Update a character profile: description, add images, or remove images by index.""" """Update a character profile: description, add images, or remove images by index."""
meta = _load_character_meta(name) meta = _load_character_meta(name)
...@@ -512,7 +512,7 @@ async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_ ...@@ -512,7 +512,7 @@ async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_
return {"ok": True, "name": name, "image_count": meta['image_count']} return {"ok": True, "name": name, "image_count": meta['image_count']}
@router.post("/v1/characters/generate") @router.post("/v1/characters/generate", summary="Generate character reference images")
async def generate_character(req: CharacterGenerateRequest, request: Request): async def generate_character(req: CharacterGenerateRequest, request: Request):
""" """
Generate a character profile from a text prompt. Generate a character profile from a text prompt.
...@@ -585,7 +585,7 @@ async def generate_character(req: CharacterGenerateRequest, request: Request): ...@@ -585,7 +585,7 @@ async def generate_character(req: CharacterGenerateRequest, request: Request):
return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]} return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]}
@router.post("/v1/characters/extract") @router.post("/v1/characters/extract", summary="Extract a character from media")
async def extract_character(req: CharacterExtractRequest): async def extract_character(req: CharacterExtractRequest):
""" """
Extract a character profile from source images and/or videos. Extract a character profile from source images and/or videos.
......
...@@ -380,7 +380,7 @@ class AudioMusicDubRequest(BaseModel): ...@@ -380,7 +380,7 @@ class AudioMusicDubRequest(BaseModel):
model_config = ConfigDict(extra='allow') model_config = ConfigDict(extra='allow')
@router.get('/v1/pipelines/custom') @router.get('/v1/pipelines/custom', summary="List saved custom pipelines")
async def list_custom_pipelines(): async def list_custom_pipelines():
"""List all saved custom pipeline definitions.""" """List all saved custom pipeline definitions."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -389,7 +389,7 @@ async def list_custom_pipelines(): ...@@ -389,7 +389,7 @@ async def list_custom_pipelines():
return {'pipelines': config_manager.pipelines_data} return {'pipelines': config_manager.pipelines_data}
@router.get('/v1/pipelines/step-types') @router.get('/v1/pipelines/step-types', summary="List available pipeline step types")
async def list_step_types(): async def list_step_types():
"""List available step types with their parameter schemas.""" """List available step types with their parameter schemas."""
return { return {
...@@ -400,7 +400,7 @@ async def list_step_types(): ...@@ -400,7 +400,7 @@ async def list_step_types():
} }
@router.post('/v1/pipelines/custom') @router.post('/v1/pipelines/custom', summary="Create a custom pipeline")
async def create_custom_pipeline(pipeline: PipelineDefinition): async def create_custom_pipeline(pipeline: PipelineDefinition):
"""Save a new custom pipeline definition.""" """Save a new custom pipeline definition."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -416,7 +416,7 @@ async def create_custom_pipeline(pipeline: PipelineDefinition): ...@@ -416,7 +416,7 @@ async def create_custom_pipeline(pipeline: PipelineDefinition):
return {'created': True, 'pipeline': data} return {'created': True, 'pipeline': data}
@router.put('/v1/pipelines/custom/{pipeline_id}') @router.put('/v1/pipelines/custom/{pipeline_id}', summary="Update a custom pipeline")
async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition): async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition):
"""Update an existing custom pipeline.""" """Update an existing custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -433,7 +433,7 @@ async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition) ...@@ -433,7 +433,7 @@ async def update_custom_pipeline(pipeline_id: str, pipeline: PipelineDefinition)
return {'updated': True, 'pipeline': data} return {'updated': True, 'pipeline': data}
@router.delete('/v1/pipelines/custom/{pipeline_id}') @router.delete('/v1/pipelines/custom/{pipeline_id}', summary="Delete a custom pipeline")
async def delete_custom_pipeline(pipeline_id: str): async def delete_custom_pipeline(pipeline_id: str):
"""Delete a custom pipeline.""" """Delete a custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -447,7 +447,7 @@ async def delete_custom_pipeline(pipeline_id: str): ...@@ -447,7 +447,7 @@ async def delete_custom_pipeline(pipeline_id: str):
return {'deleted': True, 'id': pipeline_id} return {'deleted': True, 'id': pipeline_id}
@router.post('/v1/pipelines/custom/{pipeline_id}/run') @router.post('/v1/pipelines/custom/{pipeline_id}/run', summary="Run a saved custom pipeline")
async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_request: Request = None): async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_request: Request = None):
"""Execute a saved custom pipeline.""" """Execute a saved custom pipeline."""
from codai.admin.routes import config_manager from codai.admin.routes import config_manager
...@@ -459,14 +459,20 @@ async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_r ...@@ -459,14 +459,20 @@ async def run_custom_pipeline(pipeline_id: str, body: PipelineRunRequest, http_r
return await _execute_pipeline(pipeline_def, body.input or '', http_request) return await _execute_pipeline(pipeline_def, body.input or '', http_request)
@router.post('/v1/pipelines/run') @router.post('/v1/pipelines/run', summary="Run an inline pipeline definition")
async def run_inline_pipeline(pipeline: PipelineDefinition, http_request: Request = None): async def run_inline_pipeline(pipeline: PipelineDefinition, http_request: Request = None):
"""Execute an inline pipeline definition without saving it.""" """Execute an inline pipeline definition without saving it."""
return await _execute_pipeline(pipeline.model_dump(), '', http_request) return await _execute_pipeline(pipeline.model_dump(), '', http_request)
@router.post('/v1/pipelines/audio-understand') @router.post('/v1/pipelines/audio-understand', summary="Transcribe and analyze audio")
async def run_audio_understanding(request: AudioUnderstandRequest, http_request: Request = None): async def run_audio_understanding(request: AudioUnderstandRequest, http_request: Request = None):
"""Transcribe and analyze an audio clip in one pass.
Convenience pipeline that transcribes the input audio and then reasons over the
transcript (summary/understanding) using the configured text model. Returns the
transcript together with the model's analysis.
"""
if not request.audio: if not request.audio:
raise HTTPException(status_code=400, detail='Provide audio input') raise HTTPException(status_code=400, detail='Provide audio input')
...@@ -543,8 +549,14 @@ async def run_full_music_dub(request: AudioMusicDubRequest, http_request: Reques ...@@ -543,8 +549,14 @@ async def run_full_music_dub(request: AudioMusicDubRequest, http_request: Reques
} }
@router.post('/v1/pipelines/audio-music-dub') @router.post('/v1/pipelines/audio-music-dub', summary="Dub a song into another language")
async def run_audio_music_dub(request: AudioMusicDubRequest, http_request: Request = None): async def run_audio_music_dub(request: AudioMusicDubRequest, http_request: Request = None):
"""Dub a song into another language while preserving the backing music.
Splits the track into vocals and instrumental, transcribes and translates the
lyrics, re-sings/voice-converts the translated vocals, then remixes them over the
original instrumental. Returns every intermediate stem plus the final mixed result.
"""
if not request.audio: if not request.audio:
raise HTTPException(status_code=400, detail='Provide audio input') raise HTTPException(status_code=400, detail='Provide audio input')
......
...@@ -101,7 +101,7 @@ def _embed_texts(model_obj, texts: List[str], dimensions=None) -> List[List[floa ...@@ -101,7 +101,7 @@ def _embed_texts(model_obj, texts: List[str], dimensions=None) -> List[List[floa
return results return results
@router.post("/v1/embeddings", response_model=EmbeddingsResponse) @router.post("/v1/embeddings", response_model=EmbeddingsResponse, summary="Create embeddings")
async def create_embeddings(request: EmbeddingsRequest, http_request: Request = None): async def create_embeddings(request: EmbeddingsRequest, http_request: Request = None):
""" """
OpenAI-compatible embeddings endpoint. OpenAI-compatible embeddings endpoint.
...@@ -116,10 +116,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -116,10 +116,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
model_key = model_info['model_key'] model_key = model_info['model_key']
model_obj = model_info.get('model_object') model_obj = model_info.get('model_object')
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}")
or multi_model_manager.config.get(model_name) or {})
if model_obj is None: if model_obj is None:
device = _derive_device() device = _derive_device()
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}")
or multi_model_manager.config.get(model_name) or {})
try: try:
model_obj = await asyncio.get_event_loop().run_in_executor( model_obj = await asyncio.get_event_loop().run_in_executor(
None, _load_embedding_model, model_name, device, _emb_cfg) None, _load_embedding_model, model_name, device, _emb_cfg)
...@@ -136,7 +137,59 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -136,7 +137,59 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Embedding failed: {e}") raise HTTPException(status_code=500, detail=f"Embedding failed: {e}")
if request.encoding_format == 'base64': # Optional TurboQuant vector quantization (data-free, inner-product preserving).
# The per-model config block (turboquant: {enabled, backend, bits}) is the
# source of truth for enable/disable + which implementation to use; the
# per-request `quantization` field triggers it and can override the bit width.
from codai.models import turboquant as _tq
_raw = _emb_cfg.get('_raw_cfg') if isinstance(_emb_cfg.get('_raw_cfg'), dict) else {}
tq_cfg = _emb_cfg.get('turboquant') or _raw.get('turboquant') or {}
tq_enabled = tq_cfg.get('enabled', None) # None = no explicit model setting
tq_backend = (tq_cfg.get('backend') or 'builtin')
quant_meta = None
quant_bits = None
req_spec = getattr(request, 'quantization', None)
if not req_spec and tq_enabled and tq_cfg.get('bits'):
req_spec = f"turbo{tq_cfg.get('bits')}" # model-configured default
if req_spec:
if tq_enabled is False:
raise HTTPException(
status_code=400,
detail="TurboQuant is disabled for this model (enable it in the "
"model configuration).")
quant_bits = _tq._parse_quant_spec(req_spec)
if quant_bits is None:
raise HTTPException(
status_code=400,
detail=f"Unsupported quantization '{req_spec}' "
"(use 'turbo'/'turbo8'/'turbo6'/'turbo4'/'turbo2')")
if quant_bits is not None and request.encoding_format == 'base64':
# Compact wire form: each embedding is base64 of [f16 norm][packed codes].
# The compact packing is the built-in wire format regardless of backend
# (the upstream library exposes its own opaque store, not per-vector blobs).
blobs, meta = await asyncio.get_event_loop().run_in_executor(
None, _tq.quantize_base64, vectors, quant_bits)
data = [EmbeddingObject(index=i, embedding=b) for i, b in enumerate(blobs)]
quant_meta = {
"method": meta.method, "bits": meta.bits, "seed": meta.seed,
"dim": meta.dim, "dim_padded": meta.dim_padded, "radius": meta.radius,
"bytes_per_vector": meta.bytes_per_vector, "backend": "builtin",
"layout": "base64([float16 norm][packbits(rotated b-bit codes, MSB-first per numpy.packbits)])",
}
elif quant_bits is not None:
# Lossy reconstruction returned as plain floats (quantized-store fidelity).
try:
vectors = await asyncio.get_event_loop().run_in_executor(
None, lambda: _tq.reconstruct(vectors, quant_bits, backend=tq_backend))
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]
eff_backend = tq_backend if tq_backend != 'auto' else _tq.backend_name()
quant_meta = {"method": "turboquant", "bits": quant_bits,
"encoding": "float-reconstruction", "backend": eff_backend}
elif request.encoding_format == 'base64':
import struct import struct
data = [EmbeddingObject( data = [EmbeddingObject(
index=i, index=i,
...@@ -146,8 +199,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -146,8 +199,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)] data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]
total_tokens = sum(len(t.split()) for t in texts) total_tokens = sum(len(t.split()) for t in texts)
return EmbeddingsResponse( resp = EmbeddingsResponse(
data=data, data=data,
model=request.model, model=request.model,
usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens}, usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
) )
\ No newline at end of file if quant_meta is not None:
resp.quantization = quant_meta
return resp
\ No newline at end of file
...@@ -307,7 +307,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]: ...@@ -307,7 +307,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ───────────────────────────────────────────────────────────────── # ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/environments") @router.post("/v1/environments", summary="Create or replace an environment profile")
async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)): async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named environment profile.""" """Save or update a named environment profile."""
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
...@@ -318,13 +318,13 @@ async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_a ...@@ -318,13 +318,13 @@ async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_a
return {"ok": True, "name": meta['name'], "image_count": meta['image_count']} return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/environments") @router.get("/v1/environments", summary="List environment profiles")
async def list_environments(_auth=Depends(_require_api_auth)): async def list_environments(_auth=Depends(_require_api_auth)):
"""List all saved environment profiles (metadata only).""" """List all saved environment profiles (metadata only)."""
return {"environments": _list_environments()} return {"environments": _list_environments()}
@router.get("/v1/environments/{name}") @router.get("/v1/environments/{name}", summary="Get an environment profile")
async def get_environment(name: str, _auth=Depends(_require_api_auth)): async def get_environment(name: str, _auth=Depends(_require_api_auth)):
"""Get an environment profile including its reference images as base64.""" """Get an environment profile including its reference images as base64."""
meta = _load_environment_meta(name) meta = _load_environment_meta(name)
...@@ -340,7 +340,7 @@ async def get_environment(name: str, _auth=Depends(_require_api_auth)): ...@@ -340,7 +340,7 @@ async def get_environment(name: str, _auth=Depends(_require_api_auth)):
} }
@router.delete("/v1/environments/{name}") @router.delete("/v1/environments/{name}", summary="Delete an environment profile")
async def delete_environment(name: str, _auth=Depends(_require_api_auth)): async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
"""Delete an environment profile.""" """Delete an environment profile."""
edir = _env_dir(name) edir = _env_dir(name)
...@@ -351,7 +351,7 @@ async def delete_environment(name: str, _auth=Depends(_require_api_auth)): ...@@ -351,7 +351,7 @@ async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
return {"ok": True, "name": name} return {"ok": True, "name": name}
@router.patch("/v1/environments/{name}") @router.patch("/v1/environments/{name}", summary="Update an environment profile")
async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)): async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)):
"""Update an environment profile: description, add images, or remove images by index.""" """Update an environment profile: description, add images, or remove images by index."""
meta = _load_environment_meta(name) meta = _load_environment_meta(name)
...@@ -398,7 +398,7 @@ async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depen ...@@ -398,7 +398,7 @@ async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depen
return {"ok": True, "name": name, "image_count": meta['image_count']} return {"ok": True, "name": name, "image_count": meta['image_count']}
@router.post("/v1/environments/generate") @router.post("/v1/environments/generate", summary="Generate environment reference images")
async def generate_environment(req: EnvironmentGenerateRequest, request: Request): async def generate_environment(req: EnvironmentGenerateRequest, request: Request):
""" """
Generate an environment profile from a text prompt. Generate an environment profile from a text prompt.
...@@ -471,7 +471,7 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request ...@@ -471,7 +471,7 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request
return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]} return {"ok": True, "name": meta["name"], "image_count": meta["image_count"]}
@router.post("/v1/environments/extract") @router.post("/v1/environments/extract", summary="Extract an environment from media")
async def extract_environment(req: EnvironmentExtractRequest): async def extract_environment(req: EnvironmentExtractRequest):
""" """
Extract an environment profile from source images and/or videos. Extract an environment profile from source images and/or videos.
......
...@@ -144,7 +144,7 @@ class FaceSwapRequest(BaseModel): ...@@ -144,7 +144,7 @@ class FaceSwapRequest(BaseModel):
# Endpoint # Endpoint
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post('/v1/images/faceswap') @router.post('/v1/images/faceswap', summary="Swap faces between images")
async def faceswap(request: FaceSwapRequest, http_request: Request = None): async def faceswap(request: FaceSwapRequest, http_request: Request = None):
""" """
Swap the face from source_face into every face found in target. Swap the face from source_face into every face found in target.
......
...@@ -37,6 +37,7 @@ from pydantic import BaseModel, ConfigDict ...@@ -37,6 +37,7 @@ from pydantic import BaseModel, ConfigDict
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
from codai.pydantic.imagerequest import ImageGenerationRequest from codai.pydantic.imagerequest import ImageGenerationRequest
from codai.api.state import get_load_mode from codai.api.state import get_load_mode
from codai.tasks import task_registry, TaskCancelled
# ============================================================================= # =============================================================================
...@@ -756,6 +757,13 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request= ...@@ -756,6 +757,13 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
_progress_reset(num_steps) _progress_reset(num_steps)
# Register this generation as a cancellable task (live view + cooperative
# cancel via the step callback below).
_tid = task_registry.register(
"image", title=(request.prompt or "")[:80],
model=getattr(request, 'model', '') or '', total=num_steps)
task_registry.start(_tid)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Prompt embedding cache # Prompt embedding cache
# Try to encode the prompt once and reuse the embeddings. # Try to encode the prompt once and reuse the embeddings.
...@@ -830,6 +838,11 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request= ...@@ -830,6 +838,11 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
embed_kwargs = {} embed_kwargs = {}
def _step_cb(pipe, step_index, timestep, callback_kwargs): def _step_cb(pipe, step_index, timestep, callback_kwargs):
# Cooperative cancellation: abort at the next step boundary if cancelled.
task_registry.raise_if_cancelled(_tid)
# Cooperative pause: block here while the user has paused this task.
task_registry.wait_if_paused(_tid)
task_registry.step(_tid, step_index + 1)
_progress_step(step_index + 1) _progress_step(step_index + 1)
# Mid-generation thermal checkpoint: pause between denoise steps if too hot. # Mid-generation thermal checkpoint: pause between denoise steps if too hot.
try: try:
...@@ -912,10 +925,21 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request= ...@@ -912,10 +925,21 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
try: try:
result = await asyncio.to_thread(pipeline, **call_kwargs) result = await asyncio.to_thread(pipeline, **call_kwargs)
except TaskCancelled:
_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except TypeError: except TypeError:
# Older pipeline that doesn't support callback_on_step_end # Older pipeline that doesn't support callback_on_step_end
call_kwargs.pop('callback_on_step_end', None) call_kwargs.pop('callback_on_step_end', None)
result = await asyncio.to_thread(pipeline, **call_kwargs) try:
result = await asyncio.to_thread(pipeline, **call_kwargs)
except TaskCancelled:
_progress_done()
raise
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_progress_done()
raise
finally: finally:
_progress_done() _progress_done()
...@@ -967,6 +991,7 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request= ...@@ -967,6 +991,7 @@ async def _generate_with_diffusers(pipeline, request, global_args, http_request=
except Exception: except Exception:
pass pass
task_registry.finish(_tid, "done")
return { return {
"created": timestamp, "created": timestamp,
"data": images, "data": images,
...@@ -1014,7 +1039,17 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None ...@@ -1014,7 +1039,17 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
_progress_reset(steps) _progress_reset(steps)
# sd.cpp runs the whole diffusion inside one C call, so it can't be aborted
# mid-step (raising from its progress callback won't reliably unwind the C
# extension). We still register the task for visibility + step progress; a
# cancel takes effect when control returns to Python.
_tid = task_registry.register(
"image", title=(request.prompt or "")[:80],
model=getattr(request, 'model', '') or '', total=steps)
task_registry.start(_tid)
def _sdcpp_progress(step: int, total: int, elapsed: float): def _sdcpp_progress(step: int, total: int, elapsed: float):
task_registry.step(_tid, step)
_progress_step(step) _progress_step(step)
# Use request seed if provided, otherwise use CLI default seed # Use request seed if provided, otherwise use CLI default seed
...@@ -1045,9 +1080,13 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None ...@@ -1045,9 +1080,13 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
seed=seed if seed is not None else 42, seed=seed if seed is not None else 42,
batch_count=request.n if request.n else 1, batch_count=request.n if request.n else 1,
) )
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_progress_done()
raise
finally: finally:
_progress_done() _progress_done()
# Small delay to let Vulkan driver settle after generation # Small delay to let Vulkan driver settle after generation
time.sleep(0.1) time.sleep(0.1)
...@@ -1087,6 +1126,7 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None ...@@ -1087,6 +1126,7 @@ async def _generate_with_sdcpp(sd_model, request, global_args, http_request=None
except Exception: except Exception:
pass pass
task_registry.finish(_tid, "done")
return { return {
"created": int(time.time()), "created": int(time.time()),
"data": images "data": images
...@@ -1185,7 +1225,7 @@ def _load_sdcpp_model(model_path: str, global_args, model_config: dict = None): ...@@ -1185,7 +1225,7 @@ def _load_sdcpp_model(model_path: str, global_args, model_config: dict = None):
router = APIRouter() router = APIRouter()
@router.get("/v1/images/progress") @router.get("/v1/images/progress", summary="Image generation progress")
async def get_image_progress(): async def get_image_progress():
"""Return current image generation step progress including speed.""" """Return current image generation step progress including speed."""
elapsed = _time.monotonic() - _gen_progress["started_at"] if _gen_progress["active"] else 0.0 elapsed = _time.monotonic() - _gen_progress["started_at"] if _gen_progress["active"] else 0.0
...@@ -1202,7 +1242,7 @@ async def get_image_progress(): ...@@ -1202,7 +1242,7 @@ async def get_image_progress():
} }
@router.post("/v1/images/generations") @router.post("/v1/images/generations", summary="Generate images (text-to-image)")
async def create_image_generation(request: ImageGenerationRequest, http_request: Request = None): async def create_image_generation(request: ImageGenerationRequest, http_request: Request = None):
""" """
Image generation endpoint (OpenAI-compatible). Image generation endpoint (OpenAI-compatible).
...@@ -1497,7 +1537,7 @@ def _load_img2img_pipeline(model_name: str, global_args): ...@@ -1497,7 +1537,7 @@ def _load_img2img_pipeline(model_name: str, global_args):
raise raise
@router.post("/v1/images/edits") @router.post("/v1/images/edits", summary="Edit an image (instruction / img2img)")
async def create_image_edit(request: ImageEditRequest, http_request: Request = None): async def create_image_edit(request: ImageEditRequest, http_request: Request = None):
""" """
Image-to-image editing endpoint (OpenAI-compatible). Image-to-image editing endpoint (OpenAI-compatible).
...@@ -1638,7 +1678,7 @@ def _load_inpaint_pipeline(model_name: str, global_args): ...@@ -1638,7 +1678,7 @@ def _load_inpaint_pipeline(model_name: str, global_args):
raise raise
@router.post("/v1/images/inpaint") @router.post("/v1/images/inpaint", summary="Inpaint a masked region")
async def create_image_inpaint(request: ImageInpaintRequest, http_request: Request = None): async def create_image_inpaint(request: ImageInpaintRequest, http_request: Request = None):
"""Inpaint a masked region of an image (OpenAI-compatible extension).""" """Inpaint a masked region of an image (OpenAI-compatible extension)."""
global global_args global global_args
...@@ -1750,7 +1790,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int): ...@@ -1750,7 +1790,7 @@ def _run_upscale(upscaler, image_bytes: bytes, scale: int):
return img.resize((w * scale, h * scale), PILImage.LANCZOS) return img.resize((w * scale, h * scale), PILImage.LANCZOS)
@router.post("/v1/images/upscale") @router.post("/v1/images/upscale", summary="Upscale an image")
async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None): async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None):
"""Upscale an image using Real-ESRGAN or PIL LANCZOS fallback.""" """Upscale an image using Real-ESRGAN or PIL LANCZOS fallback."""
global global_args global global_args
...@@ -1862,7 +1902,7 @@ def _resolve_spatial_model(requested: Optional[str], capability: str) -> Optiona ...@@ -1862,7 +1902,7 @@ def _resolve_spatial_model(requested: Optional[str], capability: str) -> Optiona
return None return None
@router.post("/v1/images/depth") @router.post("/v1/images/depth", summary="Estimate a depth map")
async def create_image_depth(request: ImageDepthRequest, http_request: Request = None): async def create_image_depth(request: ImageDepthRequest, http_request: Request = None):
"""Estimate depth map from an image.""" """Estimate depth map from an image."""
global global_args global global_args
...@@ -1968,7 +2008,7 @@ def _run_segmentation(seg_model, image_bytes: bytes, points, boxes): ...@@ -1968,7 +2008,7 @@ def _run_segmentation(seg_model, image_bytes: bytes, points, boxes):
return PILImage.fromarray(out) return PILImage.fromarray(out)
@router.post("/v1/images/segment") @router.post("/v1/images/segment", summary="Segment an image")
async def create_image_segment(request: ImageSegmentRequest, http_request: Request = None): async def create_image_segment(request: ImageSegmentRequest, http_request: Request = None):
"""Segment objects in an image using SAM or similar models.""" """Segment objects in an image using SAM or similar models."""
global global_args global global_args
...@@ -2035,7 +2075,7 @@ def _run_deblur(image_bytes: bytes, strength: float) -> "PILImage.Image": ...@@ -2035,7 +2075,7 @@ def _run_deblur(image_bytes: bytes, strength: float) -> "PILImage.Image":
return PILImage.fromarray((sharpened * 255).astype(np.uint8)) return PILImage.fromarray((sharpened * 255).astype(np.uint8))
@router.post("/v1/images/deblur") @router.post("/v1/images/deblur", summary="Deblur an image")
async def create_image_deblur(request: ImageDeblurRequest, http_request: Request = None): async def create_image_deblur(request: ImageDeblurRequest, http_request: Request = None):
"""Remove blur from an image using Wiener deconvolution and unsharp masking.""" """Remove blur from an image using Wiener deconvolution and unsharp masking."""
raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image) raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image)
...@@ -2093,7 +2133,7 @@ def _run_unpixelate(image_bytes: bytes, scale: int, model_path: Optional[str]) - ...@@ -2093,7 +2133,7 @@ def _run_unpixelate(image_bytes: bytes, scale: int, model_path: Optional[str]) -
return PILImage.fromarray(out_arr) return PILImage.fromarray(out_arr)
@router.post("/v1/images/unpixelate") @router.post("/v1/images/unpixelate", summary="Restore a pixelated image")
async def create_image_unpixelate(request: ImageUnpixelateRequest, http_request: Request = None): async def create_image_unpixelate(request: ImageUnpixelateRequest, http_request: Request = None):
"""Remove pixelation / upscale with detail recovery using Real-ESRGAN.""" """Remove pixelation / upscale with detail recovery using Real-ESRGAN."""
raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image) raw = base64.b64decode(request.image.split(',', 1)[-1] if ',' in request.image else request.image)
...@@ -2155,7 +2195,7 @@ def _generate_clothing_mask(img_arr) -> "np.ndarray": ...@@ -2155,7 +2195,7 @@ def _generate_clothing_mask(img_arr) -> "np.ndarray":
return fg_mask return fg_mask
@router.post("/v1/images/outfit") @router.post("/v1/images/outfit", summary="Change outfit / clothing")
async def create_image_outfit(request: ImageOutfitRequest, http_request: Request = None): async def create_image_outfit(request: ImageOutfitRequest, http_request: Request = None):
"""Change the outfit/clothing in an image or video using inpainting.""" """Change the outfit/clothing in an image or video using inpainting."""
global global_args global global_args
......
This diff is collapsed.
...@@ -117,7 +117,7 @@ class ImageToVideoPipelineRequest(BaseModel): ...@@ -117,7 +117,7 @@ class ImageToVideoPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/image-to-video") @router.post("/v1/pipelines/image-to-video", summary="Image-to-video pipeline")
async def pipeline_image_to_video(request: ImageToVideoPipelineRequest, http_request: Request = None): async def pipeline_image_to_video(request: ImageToVideoPipelineRequest, http_request: Request = None):
"""Generate an image then animate it into a video.""" """Generate an image then animate it into a video."""
steps = [] steps = []
...@@ -197,7 +197,7 @@ class VideoDubPipelineRequest(BaseModel): ...@@ -197,7 +197,7 @@ class VideoDubPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/video-dub") @router.post("/v1/pipelines/video-dub", summary="Video dubbing pipeline")
async def pipeline_video_dub(request: VideoDubPipelineRequest, http_request: Request = None): async def pipeline_video_dub(request: VideoDubPipelineRequest, http_request: Request = None):
"""Transcribe → translate → TTS dub → burn subtitles.""" """Transcribe → translate → TTS dub → burn subtitles."""
body = { body = {
...@@ -240,7 +240,7 @@ class StoryPipelineRequest(BaseModel): ...@@ -240,7 +240,7 @@ class StoryPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/story") @router.post("/v1/pipelines/story", summary="Story pipeline (multi-scene)")
async def pipeline_story(request: StoryPipelineRequest, http_request: Request = None): async def pipeline_story(request: StoryPipelineRequest, http_request: Request = None):
"""LLM generates script → image per scene → animate first scene → optional TTS narration.""" """LLM generates script → image per scene → animate first scene → optional TTS narration."""
n = min(request.num_scenes or 3, 6) n = min(request.num_scenes or 3, 6)
...@@ -377,7 +377,7 @@ class AudioDubPipelineRequest(BaseModel): ...@@ -377,7 +377,7 @@ class AudioDubPipelineRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/pipelines/audio-dub") @router.post("/v1/pipelines/audio-dub", summary="Audio dubbing pipeline")
async def pipeline_audio_dub(request: AudioDubPipelineRequest, http_request: Request = None): async def pipeline_audio_dub(request: AudioDubPipelineRequest, http_request: Request = None):
"""Transcribe → (translate) → clone voice → replace audio track.""" """Transcribe → (translate) → clone voice → replace audio track."""
import os, tempfile, subprocess, base64 import os, tempfile, subprocess, base64
......
...@@ -499,7 +499,7 @@ class ImageTo3DRequest(BaseModel): ...@@ -499,7 +499,7 @@ class ImageTo3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/images/to3d") @router.post("/v1/images/to3d", summary="Image to 3D model")
async def image_to_3d(request: ImageTo3DRequest, http_request: Request = None): async def image_to_3d(request: ImageTo3DRequest, http_request: Request = None):
"""Convert a 2D image to a 3D representation. """Convert a 2D image to a 3D representation.
...@@ -567,7 +567,7 @@ class ImageFrom3DRequest(BaseModel): ...@@ -567,7 +567,7 @@ class ImageFrom3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/images/from3d") @router.post("/v1/images/from3d", summary="Render a 3D model to an image")
async def image_from_3d(request: ImageFrom3DRequest, http_request: Request = None): async def image_from_3d(request: ImageFrom3DRequest, http_request: Request = None):
"""Render a 3D model (GLB/OBJ) to a 2D PNG image from a specified camera angle.""" """Render a 3D model (GLB/OBJ) to a 2D PNG image from a specified camera angle."""
raw = _decode_b64(request.model_data) raw = _decode_b64(request.model_data)
...@@ -600,7 +600,7 @@ class VideoTo3DRequest(BaseModel): ...@@ -600,7 +600,7 @@ class VideoTo3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/video/to3d") @router.post("/v1/video/to3d", summary="Video to 3D model")
async def video_to_3d(request: VideoTo3DRequest, http_request: Request = None): async def video_to_3d(request: VideoTo3DRequest, http_request: Request = None):
"""Convert a 2D video to a 3D video frame-by-frame. """Convert a 2D video to a 3D video frame-by-frame.
...@@ -641,7 +641,7 @@ class VideoFrom3DRequest(BaseModel): ...@@ -641,7 +641,7 @@ class VideoFrom3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/video/from3d") @router.post("/v1/video/from3d", summary="Render a 3D model to a video")
async def video_from_3d(request: VideoFrom3DRequest, http_request: Request = None): async def video_from_3d(request: VideoFrom3DRequest, http_request: Request = None):
"""Render a 3D model as a 360° turntable video.""" """Render a 3D model as a 360° turntable video."""
raw = _decode_b64(request.model_data) raw = _decode_b64(request.model_data)
...@@ -674,7 +674,7 @@ class Generate3DRequest(BaseModel): ...@@ -674,7 +674,7 @@ class Generate3DRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/3d/generate") @router.post("/v1/3d/generate", summary="Generate a 3D model from a prompt")
async def generate_3d(request: Generate3DRequest, http_request: Request = None): async def generate_3d(request: Generate3DRequest, http_request: Request = None):
"""Generate a 3D model (GLB) from a text prompt and/or an image. """Generate a 3D model (GLB) from a text prompt and/or an image.
......
...@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__) ...@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
# Import from codai modules # Import from codai modules
from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager
from codai.queue.manager import QueueManager, queue_manager from codai.queue.manager import QueueManager, queue_manager
from codai.tasks import task_registry
from codai.api.prompt_cache import prompt_cache_manager from codai.api.prompt_cache import prompt_cache_manager
from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool
from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser
...@@ -92,7 +93,7 @@ def set_grammar_guided_gen(enabled: bool): ...@@ -92,7 +93,7 @@ def set_grammar_guided_gen(enabled: bool):
router = APIRouter() router = APIRouter()
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions", summary="Chat completions")
async def chat_completions(request: ChatCompletionRequest, http_request: Request = None): async def chat_completions(request: ChatCompletionRequest, http_request: Request = None):
"""Chat completions endpoint with streaming and tool support.""" """Chat completions endpoint with streaming and tool support."""
...@@ -1248,7 +1249,8 @@ async def stream_chat_response( ...@@ -1248,7 +1249,8 @@ async def stream_chat_response(
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time()) created = int(time.time())
request_id = f"req-{uuid.uuid4().hex[:8]}" request_id = f"req-{uuid.uuid4().hex[:8]}"
_tid = None
generated_text = "" generated_text = ""
# Check if model is loaded - if not, notify waiting clients # Check if model is loaded - if not, notify waiting clients
...@@ -1320,6 +1322,9 @@ async def stream_chat_response( ...@@ -1320,6 +1322,9 @@ async def stream_chat_response(
# Mark as starting processing # Mark as starting processing
await queue_manager.start_processing(request_id, model_name) await queue_manager.start_processing(request_id, model_name)
_tid = task_registry.register("text", title=(model_name or "chat"),
model=model_name or "", task_id=request_id)
task_registry.start(_tid)
# Send "Model starting" message # Send "Model starting" message
data = { data = {
...@@ -1374,6 +1379,9 @@ async def stream_chat_response( ...@@ -1374,6 +1379,9 @@ async def stream_chat_response(
response_format=response_format, response_format=response_format,
enable_thinking=enable_thinking, enable_thinking=enable_thinking,
): ):
# Cooperative cancellation: stop streaming if the task was cancelled.
if task_registry.is_cancelled(_tid):
break
chunk_count += 1 chunk_count += 1
# Always filter malformed content (regex-based, works per-chunk) # Always filter malformed content (regex-based, works per-chunk)
filtered_chunk = filter_malformed_content(chunk) filtered_chunk = filter_malformed_content(chunk)
...@@ -1580,6 +1588,9 @@ async def stream_chat_response( ...@@ -1580,6 +1588,9 @@ async def stream_chat_response(
finally: finally:
# Always clean up queue state # Always clean up queue state
await queue_manager.finish_processing() await queue_manager.finish_processing()
if _tid:
task_registry.finish(
_tid, "cancelled" if task_registry.is_cancelled(_tid) else "done")
async def generate_chat_response( async def generate_chat_response(
messages: List[Dict], messages: List[Dict],
...@@ -1789,7 +1800,7 @@ async def generate_chat_response( ...@@ -1789,7 +1800,7 @@ async def generate_chat_response(
from codai.pydantic.textrequest import CompletionRequest from codai.pydantic.textrequest import CompletionRequest
@router.post("/v1/completions") @router.post("/v1/completions", summary="Legacy text completions")
async def completions(request: CompletionRequest): async def completions(request: CompletionRequest):
"""Legacy text completions endpoint (for backward compatibility).""" """Legacy text completions endpoint (for backward compatibility)."""
# Get the model for this request # Get the model for this request
......
...@@ -119,7 +119,7 @@ def _format_response(fmt: str, text: str, segments: list): ...@@ -119,7 +119,7 @@ def _format_response(fmt: str, text: str, segments: list):
router = APIRouter() router = APIRouter()
@router.post("/v1/audio/transcriptions") @router.post("/v1/audio/transcriptions", summary="Transcribe audio to text")
async def create_transcription( async def create_transcription(
model: str = Form(...), model: str = Form(...),
file: UploadFile = File(...), file: UploadFile = File(...),
......
...@@ -64,7 +64,7 @@ class TTSResponse(BaseModel): ...@@ -64,7 +64,7 @@ class TTSResponse(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/speech") @router.post("/v1/audio/speech", summary="Text-to-speech synthesis")
async def create_speech(request: TTSRequest, http_request: Request = None): async def create_speech(request: TTSRequest, http_request: Request = None):
""" """
Text-to-speech endpoint (OpenAI-compatible). Text-to-speech endpoint (OpenAI-compatible).
......
...@@ -45,6 +45,7 @@ from codai.pydantic.videorequest import ( ...@@ -45,6 +45,7 @@ from codai.pydantic.videorequest import (
CharacterDialogLine, CharacterDialogLine,
) )
from codai.api.images import _disable_safety_checker from codai.api.images import _disable_safety_checker
from codai.tasks import task_registry, TaskCancelled
router = APIRouter() router = APIRouter()
...@@ -627,7 +628,15 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None): ...@@ -627,7 +628,15 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None):
_vid_progress_reset(steps) _vid_progress_reset(steps)
# sd.cpp runs the whole diffusion in one C call → not interruptible mid-step;
# register for visibility + step progress (cancel applies once back in Python).
_tid = task_registry.register(
"video", title=(prompt or mode or "")[:80],
model=getattr(request, 'model', '') or '', total=steps)
task_registry.start(_tid)
def _progress_cb(step: int, total: int, elapsed: float): def _progress_cb(step: int, total: int, elapsed: float):
task_registry.step(_tid, step)
_vid_progress_step(step) _vid_progress_step(step)
kw = { kw = {
...@@ -654,8 +663,14 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None): ...@@ -654,8 +663,14 @@ def _generate_sdcpp_video(sd_model, request, model_cfg=None):
kw['init_image'] = _pil_from_b64(init_src) kw['init_image'] = _pil_from_b64(init_src)
kw['end_image'] = _pil_from_b64(request.end_image) kw['end_image'] = _pil_from_b64(request.end_image)
frames = sd_model.generate_video(**kw) try:
frames = sd_model.generate_video(**kw)
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_vid_progress_done()
raise
_vid_progress_done() _vid_progress_done()
task_registry.finish(_tid, "done")
return list(frames), fps return list(frames), fps
...@@ -1483,7 +1498,17 @@ def _generate_video(pipe, request: VideoGenerationRequest): ...@@ -1483,7 +1498,17 @@ def _generate_video(pipe, request: VideoGenerationRequest):
_vid_progress_reset(kw['num_inference_steps']) _vid_progress_reset(kw['num_inference_steps'])
_tid = task_registry.register(
"video", title=(request.prompt or mode or "")[:80],
model=getattr(request, 'model', '') or '', total=kw['num_inference_steps'])
task_registry.start(_tid)
def _vid_step_cb(pipe, step_index, timestep, callback_kwargs): def _vid_step_cb(pipe, step_index, timestep, callback_kwargs):
# Cooperative cancellation: abort at the next step boundary if cancelled.
task_registry.raise_if_cancelled(_tid)
# Cooperative pause: block here while the user has paused this task.
task_registry.wait_if_paused(_tid)
task_registry.step(_tid, step_index + 1)
_vid_progress_step(step_index + 1) _vid_progress_step(step_index + 1)
# Mid-generation thermal checkpoint: pause between denoise steps if the # Mid-generation thermal checkpoint: pause between denoise steps if the
# CPU/GPU went over the limit during this (multi-minute) generation. # CPU/GPU went over the limit during this (multi-minute) generation.
...@@ -1547,8 +1572,17 @@ def _generate_video(pipe, request: VideoGenerationRequest): ...@@ -1547,8 +1572,17 @@ def _generate_video(pipe, request: VideoGenerationRequest):
# previous clip's (common within a match) and only swapping when they differ. # previous clip's (common within a match) and only swapping when they differ.
# Left loaded after the run so the next clip with the same set pays nothing. # Left loaded after the run so the next clip with the same set pays nothing.
_sync_video_loras(pipe, getattr(request, 'loras', None)) _sync_video_loras(pipe, getattr(request, 'loras', None))
frames = _run_pipeline(pipe, kw) try:
frames = _run_pipeline(pipe, kw)
except TaskCancelled:
_vid_progress_done()
raise # global handler finishes the task (cancelled) + returns HTTP 499
except Exception as e:
task_registry.finish(_tid, "error", str(e)[:200])
_vid_progress_done()
raise
_vid_progress_done() _vid_progress_done()
task_registry.finish(_tid, "done")
return frames, fps return frames, fps
...@@ -1979,7 +2013,7 @@ def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str: ...@@ -1979,7 +2013,7 @@ def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str:
# Progress endpoint # Progress endpoint
# ============================================================================= # =============================================================================
@router.get("/v1/video/progress") @router.get("/v1/video/progress", summary="Video generation progress")
async def get_video_progress(): async def get_video_progress():
"""Return current video generation step progress including speed.""" """Return current video generation step progress including speed."""
elapsed = time.monotonic() - _vid_progress["started_at"] if _vid_progress["active"] else 0.0 elapsed = time.monotonic() - _vid_progress["started_at"] if _vid_progress["active"] else 0.0
...@@ -2000,7 +2034,7 @@ async def get_video_progress(): ...@@ -2000,7 +2034,7 @@ async def get_video_progress():
# Main generation endpoint # Main generation endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/generations", response_model=VideoGenerationResponse) @router.post("/v1/video/generations", response_model=VideoGenerationResponse, summary="Generate video")
async def video_generations(request: VideoGenerationRequest, async def video_generations(request: VideoGenerationRequest,
http_request: Request = None): http_request: Request = None):
""" """
...@@ -2269,7 +2303,7 @@ async def video_generations(request: VideoGenerationRequest, ...@@ -2269,7 +2303,7 @@ async def video_generations(request: VideoGenerationRequest,
# Video upscale endpoint # Video upscale endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/upscale") @router.post("/v1/video/upscale", summary="Upscale a video")
async def video_upscale(request: VideoUpscaleRequest, http_request: Request = None): async def video_upscale(request: VideoUpscaleRequest, http_request: Request = None):
""" """
Upscale a video using ffmpeg lanczos or Real-ESRGAN. Upscale a video using ffmpeg lanczos or Real-ESRGAN.
...@@ -2299,7 +2333,7 @@ async def video_upscale(request: VideoUpscaleRequest, http_request: Request = No ...@@ -2299,7 +2333,7 @@ async def video_upscale(request: VideoUpscaleRequest, http_request: Request = No
# Subtitle generation endpoint # Subtitle generation endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/subtitle") @router.post("/v1/video/subtitle", summary="Subtitle / caption a video")
async def video_subtitle(request: VideoSubtitleRequest, http_request: Request = None): async def video_subtitle(request: VideoSubtitleRequest, http_request: Request = None):
""" """
Generate subtitles for a video. Generate subtitles for a video.
...@@ -2353,7 +2387,7 @@ async def video_subtitle(request: VideoSubtitleRequest, http_request: Request = ...@@ -2353,7 +2387,7 @@ async def video_subtitle(request: VideoSubtitleRequest, http_request: Request =
# Frame interpolation endpoint # Frame interpolation endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/interpolate") @router.post("/v1/video/interpolate", summary="Interpolate video frames")
async def video_interpolate(request: VideoInterpolateRequest, http_request: Request = None): async def video_interpolate(request: VideoInterpolateRequest, http_request: Request = None):
""" """
Increase video FPS via frame interpolation. Increase video FPS via frame interpolation.
...@@ -2400,7 +2434,7 @@ async def video_interpolate(request: VideoInterpolateRequest, http_request: Requ ...@@ -2400,7 +2434,7 @@ async def video_interpolate(request: VideoInterpolateRequest, http_request: Requ
# Video dubbing endpoint # Video dubbing endpoint
# ============================================================================= # =============================================================================
@router.post("/v1/video/dub") @router.post("/v1/video/dub", summary="Dub a video")
async def video_dub(request: VideoDubRequest, http_request: Request = None): async def video_dub(request: VideoDubRequest, http_request: Request = None):
""" """
Translate and re-dub a video. Translate and re-dub a video.
......
...@@ -185,13 +185,13 @@ class VoicePatchRequest(BaseModel): ...@@ -185,13 +185,13 @@ class VoicePatchRequest(BaseModel):
# Voice profile management # Voice profile management
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/v1/audio/voices") @router.get("/v1/audio/voices", summary="List voice profiles")
async def list_voices(): async def list_voices():
"""List all saved voice profiles.""" """List all saved voice profiles."""
return {"voices": _list_voices()} return {"voices": _list_voices()}
@router.post("/v1/audio/voices") @router.post("/v1/audio/voices", summary="Create a voice profile")
async def create_voice( async def create_voice(
name: str = Form(...), name: str = Form(...),
transcript: str = Form(...), transcript: str = Form(...),
...@@ -216,7 +216,7 @@ async def create_voice( ...@@ -216,7 +216,7 @@ async def create_voice(
return {"created": True, "voice": meta} return {"created": True, "voice": meta}
@router.delete("/v1/audio/voices/{name}") @router.delete("/v1/audio/voices/{name}", summary="Delete a voice profile")
async def delete_voice(name: str): async def delete_voice(name: str):
"""Delete a saved voice profile.""" """Delete a saved voice profile."""
import shutil import shutil
...@@ -227,7 +227,7 @@ async def delete_voice(name: str): ...@@ -227,7 +227,7 @@ async def delete_voice(name: str):
return {"deleted": True, "name": name} return {"deleted": True, "name": name}
@router.patch("/v1/audio/voices/{name}") @router.patch("/v1/audio/voices/{name}", summary="Update a voice profile")
async def patch_voice(name: str, req: VoicePatchRequest): async def patch_voice(name: str, req: VoicePatchRequest):
"""Update description, transcript, or reference audio of a saved voice profile.""" """Update description, transcript, or reference audio of a saved voice profile."""
meta = _load_voice(name) meta = _load_voice(name)
...@@ -259,7 +259,7 @@ async def patch_voice(name: str, req: VoicePatchRequest): ...@@ -259,7 +259,7 @@ async def patch_voice(name: str, req: VoicePatchRequest):
return {"updated": True, "voice": meta} return {"updated": True, "voice": meta}
@router.get("/v1/audio/voices/{name}") @router.get("/v1/audio/voices/{name}", summary="Get a voice profile")
async def get_voice(name: str): async def get_voice(name: str):
"""Get a single voice profile metadata.""" """Get a single voice profile metadata."""
meta = _load_voice(name) meta = _load_voice(name)
...@@ -268,7 +268,7 @@ async def get_voice(name: str): ...@@ -268,7 +268,7 @@ async def get_voice(name: str):
return {"voice": meta} return {"voice": meta}
@router.post("/v1/audio/voices/extract") @router.post("/v1/audio/voices/extract", summary="Extract a voice profile from a sample")
async def extract_voice(req: VoiceExtractRequest): async def extract_voice(req: VoiceExtractRequest):
""" """
Extract a voice profile from a source audio or video file. Extract a voice profile from a source audio or video file.
...@@ -358,7 +358,7 @@ class VoiceCloneRequest(BaseModel): ...@@ -358,7 +358,7 @@ class VoiceCloneRequest(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/clone") @router.post("/v1/audio/clone", summary="Clone a voice / synthesize cloned speech")
async def clone_voice(request: VoiceCloneRequest, http_request: Request = None): async def clone_voice(request: VoiceCloneRequest, http_request: Request = None):
""" """
Synthesize speech in a cloned voice using F5-TTS. Synthesize speech in a cloned voice using F5-TTS.
......
...@@ -94,7 +94,7 @@ class VoiceConvertRequest(BaseModel): ...@@ -94,7 +94,7 @@ class VoiceConvertRequest(BaseModel):
model_config = ConfigDict(extra='allow') model_config = ConfigDict(extra='allow')
@router.post('/v1/audio/convert') @router.post('/v1/audio/convert', summary="Voice conversion (speech-to-speech)")
async def convert_voice(request: VoiceConvertRequest, http_request: Request = None): async def convert_voice(request: VoiceConvertRequest, http_request: Request = None):
""" """
Voice conversion: preserves pitch/melody/expression, changes only timbre. Voice conversion: preserves pitch/melody/expression, changes only timbre.
......
...@@ -78,6 +78,40 @@ except ImportError: ...@@ -78,6 +78,40 @@ except ImportError:
_llama_cpp = None _llama_cpp = None
# Friendly KV-cache quant names → llama.cpp GGML type. q8_0 is near-lossless and
# the safe default; the q5/q4 types trade a little accuracy for ~2x less KV VRAM.
_KV_TYPE_ALIASES = {
'f16': 'GGML_TYPE_F16', 'fp16': 'GGML_TYPE_F16', 'f32': 'GGML_TYPE_F32',
'q8_0': 'GGML_TYPE_Q8_0', 'q8': 'GGML_TYPE_Q8_0', 'q8_1': 'GGML_TYPE_Q8_1',
'q5_0': 'GGML_TYPE_Q5_0', 'q5_1': 'GGML_TYPE_Q5_1', 'q5': 'GGML_TYPE_Q5_1',
'q4_0': 'GGML_TYPE_Q4_0', 'q4_1': 'GGML_TYPE_Q4_1', 'q4': 'GGML_TYPE_Q4_1',
'iq4_nl': 'GGML_TYPE_IQ4_NL',
}
# Sub-8-bit KV types that llama.cpp can only use with flash attention enabled.
_KV_NEEDS_FLASH = {'q5_0', 'q5_1', 'q5', 'q4_0', 'q4_1', 'q4', 'iq4_nl'}
def _ggml_kv_type(name):
"""Map a KV-cache quant name to the llama.cpp GGML type int, or None.
Returns None for falsy / unknown / 'none' / 'auto' values (→ keep the
llama.cpp default, f16). Unknown names log a warning instead of failing."""
if not name or _llama_cpp is None:
return None
key = str(name).strip().lower().replace('-', '_').replace(' ', '')
if key in ('', 'none', 'auto', 'default', 'f16default'):
return None
const = _KV_TYPE_ALIASES.get(key)
if const is None:
print(f" KV cache type '{name}' not recognized — using default (f16)")
return None
val = getattr(_llama_cpp, const, None)
if val is None:
print(f" KV cache type '{name}' unsupported by this llama.cpp build — using f16")
return val
def _install_layer_log_callback(): def _install_layer_log_callback():
"""Replace llama.cpp's log callback with one that prints load-time layer/buffer """Replace llama.cpp's log callback with one that prints load-time layer/buffer
messages directly to stdout. Returns the callback object — keep a reference messages directly to stdout. Returns the callback object — keep a reference
...@@ -613,7 +647,33 @@ class VulkanBackend(ModelBackend): ...@@ -613,7 +647,33 @@ class VulkanBackend(ModelBackend):
llama_kwargs['rope_freq_base'] = kwargs['rope_freq_base'] llama_kwargs['rope_freq_base'] = kwargs['rope_freq_base']
if 'rope_freq_scale' in kwargs: if 'rope_freq_scale' in kwargs:
llama_kwargs['rope_freq_scale'] = kwargs['rope_freq_scale'] llama_kwargs['rope_freq_scale'] = kwargs['rope_freq_scale']
# KV-cache quantization (llama.cpp type_k / type_v). Shrinks the KV cache
# so long contexts fit in less VRAM. Read from the per-model config, with
# the raw models.json entry as a fallback (carried in _raw_cfg).
_raw_cfg = kwargs.get('_raw_cfg') or {}
_ck = kwargs.get('cache_type_k', _raw_cfg.get('cache_type_k'))
_cv = kwargs.get('cache_type_v', _raw_cfg.get('cache_type_v'))
_flash = bool(kwargs.get('flash_attn', _raw_cfg.get('flash_attn',
_raw_cfg.get('flash_attention', False))))
_tk = _ggml_kv_type(_ck)
_tv = _ggml_kv_type(_cv)
if _tk is not None:
llama_kwargs['type_k'] = _tk
if _tv is not None:
llama_kwargs['type_v'] = _tv
# A quantized V cache below 8 bits requires flash attention in llama.cpp;
# auto-enable it (with a note) so the config "just works".
_v_needs_flash = str(_cv or '').strip().lower().replace('-', '_') in _KV_NEEDS_FLASH
if (_tk is not None or _tv is not None):
if _v_needs_flash and not _flash:
_flash = True
print(" KV cache: sub-8-bit V cache needs flash attention — enabling it")
if _flash:
llama_kwargs['flash_attn'] = True
print(f" KV cache: type_k={_ck or 'f16'} type_v={_cv or 'f16'}"
f"{' (flash_attn on)' if _flash else ''}")
# Force CUDA if requested # Force CUDA if requested
if self.force_cuda: if self.force_cuda:
# Set environment variable to force CUDA # Set environment variable to force CUDA
......
...@@ -247,4 +247,11 @@ configuration directory (--config DIR, default: OS-specific CoderAI directory). ...@@ -247,4 +247,11 @@ configuration directory (--config DIR, default: OS-specific CoderAI directory).
action="store_true", action="store_true",
help="List available Vulkan GPU devices and exit", help="List available Vulkan GPU devices and exit",
) )
parser.add_argument(
"--no-resume-jobs",
action="store_true",
help="Do not resume/recover interrupted LoRA training jobs on restart. "
"Mid-flight jobs are marked 'cancelled' (checkpoints are kept, so they "
"can still be restarted manually from the Tasks page).",
)
return parser.parse_args() return parser.parse_args()
...@@ -126,6 +126,17 @@ class ThermalConfig: ...@@ -126,6 +126,17 @@ class ThermalConfig:
poll_seconds: float = 5.0 # how often to re-check while cooling down poll_seconds: float = 5.0 # how often to re-check while cooling down
@dataclass
class JobsConfig:
"""Background-job (LoRA training) configuration."""
# When True, an interrupted training job (process restart) is left
# 'interrupted' so it can resume from its on-disk checkpoint. When False,
# such jobs are marked 'cancelled' on startup and not auto-resumed (their
# checkpoints are kept, so they can be restarted manually from the Tasks
# page). The --no-resume-jobs CLI flag forces this off for one run.
resume_on_restart: bool = True
@dataclass @dataclass
class Config: class Config:
"""Main configuration class.""" """Main configuration class."""
...@@ -139,6 +150,7 @@ class Config: ...@@ -139,6 +150,7 @@ class Config:
whisper: WhisperConfig = field(default_factory=WhisperConfig) whisper: WhisperConfig = field(default_factory=WhisperConfig)
archive: ArchiveConfig = field(default_factory=ArchiveConfig) archive: ArchiveConfig = field(default_factory=ArchiveConfig)
thermal: ThermalConfig = field(default_factory=ThermalConfig) thermal: ThermalConfig = field(default_factory=ThermalConfig)
jobs: JobsConfig = field(default_factory=JobsConfig)
broker: BrokerConfig = field(default_factory=BrokerConfig) broker: BrokerConfig = field(default_factory=BrokerConfig)
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
tools_closer_prompt: bool = False tools_closer_prompt: bool = False
...@@ -293,6 +305,7 @@ class ConfigManager: ...@@ -293,6 +305,7 @@ class ConfigManager:
whisper=WhisperConfig(**config_data.get("whisper", {})), whisper=WhisperConfig(**config_data.get("whisper", {})),
archive=ArchiveConfig(**config_data.get("archive", {})), archive=ArchiveConfig(**config_data.get("archive", {})),
thermal=ThermalConfig(**config_data.get("thermal", {})), thermal=ThermalConfig(**config_data.get("thermal", {})),
jobs=JobsConfig(**config_data.get("jobs", {})),
broker=BrokerConfig(**config_data.get("broker", {})), broker=BrokerConfig(**config_data.get("broker", {})),
system_prompt=config_data.get("system_prompt"), system_prompt=config_data.get("system_prompt"),
tools_closer_prompt=config_data.get("tools_closer_prompt", False), tools_closer_prompt=config_data.get("tools_closer_prompt", False),
...@@ -411,6 +424,9 @@ class ConfigManager: ...@@ -411,6 +424,9 @@ class ConfigManager:
"gpu_resume": self.config.thermal.gpu_resume, "gpu_resume": self.config.thermal.gpu_resume,
"poll_seconds": self.config.thermal.poll_seconds, "poll_seconds": self.config.thermal.poll_seconds,
}, },
"jobs": {
"resume_on_restart": self.config.jobs.resume_on_restart,
},
"broker": { "broker": {
"enabled": self.config.broker.enabled, "enabled": self.config.broker.enabled,
"base_url": self.config.broker.base_url, "base_url": self.config.broker.base_url,
......
...@@ -147,6 +147,8 @@ def build_runtime_kwargs(model_cfg, model_type): ...@@ -147,6 +147,8 @@ def build_runtime_kwargs(model_cfg, model_type):
} }
if model_type == "text": if model_type == "text":
kwargs['ctx'] = model_cfg.get('n_ctx', model_cfg.get('context_size')) kwargs['ctx'] = model_cfg.get('n_ctx', model_cfg.get('context_size'))
kwargs['cache_type_k'] = model_cfg.get('cache_type_k')
kwargs['cache_type_v'] = model_cfg.get('cache_type_v')
elif model_type == "image": elif model_type == "image":
kwargs['llm_path'] = model_cfg.get('llm_path') kwargs['llm_path'] = model_cfg.get('llm_path')
kwargs['vae_path'] = model_cfg.get('vae_path') kwargs['vae_path'] = model_cfg.get('vae_path')
...@@ -865,9 +867,16 @@ def main(): ...@@ -865,9 +867,16 @@ def main():
from codai.api.characters import set_global_args as set_chars_global_args from codai.api.characters import set_global_args as set_chars_global_args
set_chars_global_args(global_args) set_chars_global_args(global_args)
# Set LoRA training module global args # Set LoRA training module global args. Resolve job-recovery first (the
from codai.api.loras import set_global_args as set_loras_global_args # --no-resume-jobs flag overrides the persisted config setting), then call
# set_global_args, which runs _load_jobs_on_start and honours the flag.
from codai.api.loras import (set_global_args as set_loras_global_args,
set_resume_enabled as set_loras_resume_enabled)
_resume_jobs = bool(getattr(config.jobs, "resume_on_restart", True)) and not getattr(args, "no_resume_jobs", False)
set_loras_resume_enabled(_resume_jobs)
set_loras_global_args(global_args) set_loras_global_args(global_args)
if not _resume_jobs:
print("LoRA job recovery: DISABLED (interrupted training will be cancelled on restart)")
# Set environment profiles module global args # Set environment profiles module global args
from codai.api.environments import set_global_args as set_envs_global_args from codai.api.environments import set_global_args as set_envs_global_args
......
...@@ -790,6 +790,17 @@ class MultiModelManager: ...@@ -790,6 +790,17 @@ class MultiModelManager:
# build_kwargs_from_config populates it from the model's # build_kwargs_from_config populates it from the model's
# 'flash_attention' setting; CLI/global is NOT consulted here. # 'flash_attention' setting; CLI/global is NOT consulted here.
kwargs['flash_attn'] = bool(config.get('flash_attn', False)) kwargs['flash_attn'] = bool(config.get('flash_attn', False))
# KV-cache quantization (llama.cpp type_k/type_v) — pass through
# to the backend, with the raw models.json entry as a fallback.
_raw = config.get('_raw_cfg') if isinstance(config.get('_raw_cfg'), dict) else {}
for _kvk in ('cache_type_k', 'cache_type_v'):
_kvv = config.get(_kvk)
if _kvv is None:
_kvv = _raw.get(_kvk)
if _kvv:
kwargs[_kvk] = _kvv
if _raw and '_raw_cfg' not in kwargs:
kwargs['_raw_cfg'] = _raw
no_ram = _cfg_or_global('no_ram', 'no_ram', False) no_ram = _cfg_or_global('no_ram', 'no_ram', False)
kwargs['no_ram'] = bool(no_ram) kwargs['no_ram'] = bool(no_ram)
offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto') offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto')
...@@ -872,6 +883,17 @@ class MultiModelManager: ...@@ -872,6 +883,17 @@ class MultiModelManager:
# build_kwargs_from_config populates it from the model's # build_kwargs_from_config populates it from the model's
# 'flash_attention' setting; CLI/global is NOT consulted here. # 'flash_attention' setting; CLI/global is NOT consulted here.
kwargs['flash_attn'] = bool(config.get('flash_attn', False)) kwargs['flash_attn'] = bool(config.get('flash_attn', False))
# KV-cache quantization (llama.cpp type_k/type_v) — pass through
# to the backend, with the raw models.json entry as a fallback.
_raw = config.get('_raw_cfg') if isinstance(config.get('_raw_cfg'), dict) else {}
for _kvk in ('cache_type_k', 'cache_type_v'):
_kvv = config.get(_kvk)
if _kvv is None:
_kvv = _raw.get(_kvk)
if _kvv:
kwargs[_kvk] = _kvv
if _raw and '_raw_cfg' not in kwargs:
kwargs['_raw_cfg'] = _raw
no_ram = _cfg_or_global('no_ram', 'no_ram', False) no_ram = _cfg_or_global('no_ram', 'no_ram', False)
kwargs['no_ram'] = bool(no_ram) kwargs['no_ram'] = bool(no_ram)
offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto') offload_strategy = _cfg_or_global('offload_strategy', 'offload_strategy', 'auto')
......
...@@ -35,10 +35,62 @@ Semantics (per sensor, when enabled): ...@@ -35,10 +35,62 @@ Semantics (per sensor, when enabled):
import os import os
import shutil import shutil
import subprocess import subprocess
import threading
import time import time
from typing import Optional, Tuple from typing import Optional, Tuple
# ---------------------------------------------------------------------------
# Cooldown state (published for the admin Tasks view)
# ---------------------------------------------------------------------------
# A thermal pause is a *global* hardware event: every worker that reaches a
# checkpoint blocks until temps recover. We publish a single process-wide state
# so the Tasks page can show that running work is paused for cooldown. A waiter
# counter (not a bool) keeps the state correct when several workers pause at
# once — the state is "active" while any worker is still cooling.
_cooldown_lock = threading.Lock()
_cooldown_waiters = 0
_cooldown_state: dict = {
"active": False, "since": 0.0, "waited": 0.0,
"gpu": None, "cpu": None, "message": "",
}
def get_cooldown_state() -> dict:
"""Snapshot of the current thermal cooldown (see module note). ``active`` is
True while at least one worker is paused waiting for the hardware to cool."""
with _cooldown_lock:
return dict(_cooldown_state)
def _cooldown_enter() -> None:
global _cooldown_waiters
with _cooldown_lock:
_cooldown_waiters += 1
_cooldown_state["active"] = True
if not _cooldown_state.get("since"):
_cooldown_state["since"] = time.time()
def _cooldown_update(gpu, cpu, waited, message) -> None:
with _cooldown_lock:
_cooldown_state["gpu"] = gpu
_cooldown_state["cpu"] = cpu
_cooldown_state["waited"] = waited
_cooldown_state["message"] = message
def _cooldown_exit() -> None:
global _cooldown_waiters
with _cooldown_lock:
_cooldown_waiters = max(0, _cooldown_waiters - 1)
if _cooldown_waiters == 0:
_cooldown_state.update({
"active": False, "since": 0.0, "waited": 0.0,
"gpu": None, "cpu": None, "message": "",
})
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Temperature readers # Temperature readers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -199,6 +251,57 @@ def read_cpu_temp() -> Optional[float]: ...@@ -199,6 +251,57 @@ def read_cpu_temp() -> Optional[float]:
return val return val
_gpu_util_cache: Tuple[float, Optional[float]] = (0.0, None)
def _read_gpu_util_uncached() -> Optional[float]:
"""Hottest GPU utilization in %, or None if unreadable."""
if _NVIDIA_SMI:
out = _run([
_NVIDIA_SMI,
"--query-gpu=utilization.gpu",
"--format=csv,noheader,nounits",
])
if out:
vals = []
for line in out.splitlines():
line = line.strip()
if line:
try:
vals.append(float(line))
except ValueError:
pass
if vals:
return max(vals)
if _ROCM_SMI:
out = _run([_ROCM_SMI, "--showuse"])
if out:
vals = []
for line in out.splitlines():
low = line.lower()
if "gpu use" in low and "%" in line:
for tok in line.replace("%", " ").split():
try:
vals.append(float(tok))
except ValueError:
continue
if vals:
return max(vals)
return None
def read_gpu_util() -> Optional[float]:
"""GPU utilization % (cached ~2s), or None if unreadable."""
global _gpu_util_cache
now = time.monotonic()
ts, val = _gpu_util_cache
if now - ts < _CACHE_TTL:
return val
val = _read_gpu_util_uncached()
_gpu_util_cache = (now, val)
return val
def read_cpu_temp_avg(samples: int = 3, max_seconds: float = 3.0) -> Optional[float]: def read_cpu_temp_avg(samples: int = 3, max_seconds: float = 3.0) -> Optional[float]:
"""Averaged CPU temperature for stable resume/cooldown decisions. """Averaged CPU temperature for stable resume/cooldown decisions.
...@@ -372,25 +475,30 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None, ...@@ -372,25 +475,30 @@ def wait_until_safe(settings: Optional[ThermalSettings] = None,
f"until cooldown (GPU<={settings.gpu_resume:.0f}°C / " f"until cooldown (GPU<={settings.gpu_resume:.0f}°C / "
f"CPU<={settings.cpu_resume:.0f}°C)") f"CPU<={settings.cpu_resume:.0f}°C)")
waited = 0.0 waited = 0.0
while True: _cooldown_enter()
# Re-evaluate against resume thresholds (lower than trigger → hysteresis). try:
# CPU temps are noisy, so average a few samples for the resume decision while True:
# (the pause check above stays single-read to react fast to spikes). # Re-evaluate against resume thresholds (lower than trigger → hysteresis).
gt = read_gpu_temp() if settings.gpu_enabled else None # CPU temps are noisy, so average a few samples for the resume decision
ct = read_cpu_temp_avg() if settings.cpu_enabled else None # (the pause check above stays single-read to react fast to spikes).
still = [] gt = read_gpu_temp() if settings.gpu_enabled else None
if gt is not None and gt > settings.gpu_resume: ct = read_cpu_temp_avg() if settings.cpu_enabled else None
still.append(("GPU", gt, settings.gpu_resume)) still = []
if ct is not None and ct > settings.cpu_resume: if gt is not None and gt > settings.gpu_resume:
still.append(("CPU", ct, settings.cpu_resume)) still.append(("GPU", gt, settings.gpu_resume))
_dbg(f"cooldown{desc} {int(waited)}s: GPU {_fmt(gt)} CPU {_fmt(ct)} (avg-3) " if ct is not None and ct > settings.cpu_resume:
f"(still hot: {[s[0] for s in still] or 'none'})") still.append(("CPU", ct, settings.cpu_resume))
if not still: _dbg(f"cooldown{desc} {int(waited)}s: GPU {_fmt(gt)} CPU {_fmt(ct)} (avg-3) "
break f"(still hot: {[s[0] for s in still] or 'none'})")
msg = ", ".join(f"{lbl} {t:.0f}°C>{r:.0f}°C" for lbl, t, r in still) if not still:
print(f"[thermal] Cooling{desc}: {msg} — waiting " break
f"({int(waited)}s elapsed)") msg = ", ".join(f"{lbl} {t:.0f}°C>{r:.0f}°C" for lbl, t, r in still)
time.sleep(settings.poll_seconds) _cooldown_update(gt, ct, waited, msg)
waited += settings.poll_seconds print(f"[thermal] Cooling{desc}: {msg} — waiting "
f"({int(waited)}s elapsed)")
time.sleep(settings.poll_seconds)
waited += settings.poll_seconds
finally:
_cooldown_exit()
print(f"[thermal] Temperatures back within safe limits{desc} — resuming " print(f"[thermal] Temperatures back within safe limits{desc} — resuming "
f"after {int(waited)}s") f"after {int(waited)}s")
This diff is collapsed.
...@@ -17,16 +17,17 @@ ...@@ -17,16 +17,17 @@
"""Pydantic models for embeddings API.""" """Pydantic models for embeddings API."""
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
class EmbeddingsRequest(BaseModel): class EmbeddingsRequest(BaseModel):
model: str model: str = Field(..., description="Embedding model id to use.")
input: Union[str, List[str]] # text(s) to embed input: Union[str, List[str]] = Field(..., description="Text or list of texts to embed.")
image: Optional[Union[str, List[str]]] = None # base64/URL image(s) for multimodal embed image: Optional[Union[str, List[str]]] = Field(None, description="Base64/URL image(s) for multimodal embedding models.")
encoding_format: Optional[str] = "float" # float | base64 encoding_format: Optional[str] = Field("float", description="Return embeddings as 'float' arrays or 'base64'.")
dimensions: Optional[int] = None # truncate to N dims if supported dimensions: Optional[int] = Field(None, description="Truncate embeddings to N dimensions (if the model supports it).")
user: Optional[str] = None quantization: Optional[str] = Field(None, description="Optional TurboQuant vector quantization: 'turbo' (8-bit), 'turbo8', 'turbo6', 'turbo4' or 'turbo2'. With encoding_format='float' the (lossy) reconstructed vectors are returned; with 'base64' the compact packed bytes are returned plus a 'quantization' metadata block describing how to decode them.")
user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
class LoraConfig(BaseModel): class LoraConfig(BaseModel):
...@@ -26,41 +26,43 @@ class LoraConfig(BaseModel): ...@@ -26,41 +26,43 @@ class LoraConfig(BaseModel):
server-side, in priority) via `id` ("name:<registered>" or "sha256:<hex>"), server-side, in priority) via `id` ("name:<registered>" or "sha256:<hex>"),
inline `file`/`data` base64, a `url`, or the legacy `model`/`path` (local path inline `file`/`data` base64, a `url`, or the legacy `model`/`path` (local path
/ HF id) — so a remote client needn't share the server's filesystem.""" / HF id) — so a remote client needn't share the server's filesystem."""
model: Optional[str] = None model: Optional[str] = Field(None, description="Legacy: local path or HF id of the weights (shared-filesystem only).")
path: Optional[str] = None path: Optional[str] = Field(None, description="Alias of `model` — local path to the .safetensors weights.")
id: Optional[str] = None id: Optional[str] = Field(None, description='Registry/blob reference: "name:<registered-lora>" or "sha256:<hex>" (from /v1/loras/upload).')
url: Optional[str] = None url: Optional[str] = Field(None, description="HTTP(S) URL the server downloads and caches.")
file: Optional[str] = None file: Optional[str] = Field(None, description="Base64 of the .safetensors file (or a data: URI). Sent inline so no shared filesystem is needed.")
data: Optional[str] = None data: Optional[str] = Field(None, description="Alias of `file` — inline base64 weights.")
weight: float = 1.0 weight: float = Field(1.0, description="Adapter strength / scale.")
name: Optional[str] = None name: Optional[str] = Field(None, description="Optional adapter name.")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class ImageGenerationRequest(BaseModel): class ImageGenerationRequest(BaseModel):
model: str model: str = Field(..., description="Model id to generate with (must be a configured image model).")
prompt: str prompt: str = Field(..., description="Text prompt describing the image.")
n: int = 1 n: int = Field(1, description="Number of images to generate.")
size: Optional[str] = "1024x1024" size: Optional[str] = Field("1024x1024", description="Output size as 'WIDTHxHEIGHT'.")
steps: Optional[int] = None steps: Optional[int] = Field(None, description="Denoising steps (model/acceleration default if omitted).")
guidance_scale: Optional[float] = None guidance_scale: Optional[float] = Field(None, description="Classifier-free guidance scale (model/acceleration default if omitted).")
quality: Optional[str] = "standard" quality: Optional[str] = Field("standard", description="Quality hint: 'standard' or 'hd'.")
style: Optional[str] = None style: Optional[str] = Field(None, description="Optional style hint passed through to the model.")
response_format: Optional[str] = "url" response_format: Optional[str] = Field("url", description="How to return the result: 'url' or 'b64_json'.")
seed: Optional[int] = None seed: Optional[int] = Field(None, description="Random seed for reproducibility.")
user: Optional[str] = None user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
disable_safety_checker: Optional[bool] = False disable_safety_checker: Optional[bool] = Field(False, description=(
negative_prompt: Optional[str] = None "Null out the diffusers safety_checker so uncensored fine-tunes are not blocked. "
"Only affects SD 1.x/2.x (SDXL/Flux ship no checker)."))
negative_prompt: Optional[str] = Field(None, description="What to avoid in the output.")
# Per-request component overrides # Per-request component overrides
vae_model: Optional[str] = None # Override the VAE for this request vae_model: Optional[str] = Field(None, description="Override the VAE for this request.")
loras: Optional[List[LoraConfig]] = None # Additional LoRA weights for this request loras: Optional[List[LoraConfig]] = Field(None, description="Additional LoRA adapters to apply for this request.")
# Character consistency # Character consistency
character_profiles: Optional[List[str]] = None # saved profile names character_profiles: Optional[List[str]] = Field(None, description="Saved character profile names to apply (IP-Adapter).")
character_references: Optional[List[str]] = None # inline base64 images character_references: Optional[List[str]] = Field(None, description="Inline base64 reference images for character consistency.")
character_strength: Optional[float] = 0.6 # IP-Adapter scale character_strength: Optional[float] = Field(0.6, description="IP-Adapter scale for character references.")
environment_profiles: Optional[List[str]] = None # saved environment profile names (IP-Adapter) environment_profiles: Optional[List[str]] = Field(None, description="Saved environment profile names to apply (IP-Adapter).")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
......
...@@ -67,34 +67,32 @@ class ChatMessage(BaseModel): ...@@ -67,34 +67,32 @@ class ChatMessage(BaseModel):
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str = Field(..., description="Text/chat model id to use.")
messages: List[ChatMessage] messages: List[ChatMessage] = Field(..., description="Conversation messages (roles: system/user/assistant/tool). Content may include text and image parts for vision models.")
temperature: float = 0.7 temperature: float = Field(0.7, description="Sampling temperature; higher = more random.")
top_p: float = 1.0 top_p: float = Field(1.0, description="Nucleus sampling probability mass.")
n: int = 1 n: int = Field(1, description="Number of completions to generate.")
max_tokens: Optional[int] = None max_tokens: Optional[int] = Field(None, description="Max tokens to generate (model default if omitted).")
stream: bool = False stream: bool = Field(False, description="Stream the response as Server-Sent Events.")
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequence(s) that end generation.")
presence_penalty: float = 0.0 presence_penalty: float = Field(0.0, description="Penalize tokens already present (encourages new topics).")
frequency_penalty: float = 0.0 frequency_penalty: float = Field(0.0, description="Penalize frequent tokens (reduces repetition).")
repeat_penalty: float = 1.0 repeat_penalty: float = Field(1.0, description="llama.cpp repetition penalty.")
tools: Optional[List[Tool]] = None tools: Optional[List[Tool]] = Field(None, description="Tool/function definitions the model may call.")
tool_choice: Optional[Union[str, Dict]] = "auto" tool_choice: Optional[Union[str, Dict]] = Field("auto", description="Tool selection: 'auto', 'none', or a specific tool.")
# Extra fields that clients may send but we ignore seed: Optional[int] = Field(None, description="Random seed for reproducibility.")
seed: Optional[int] = None logprobs: Optional[bool] = Field(None, description="Return token log-probabilities (if supported).")
logprobs: Optional[bool] = None top_logprobs: Optional[int] = Field(None, description="Number of top log-probs to return per token.")
top_logprobs: Optional[int] = None response_format: Optional[Dict] = Field(None, description="Structured-output format, e.g. {'type': 'json_object'}.")
response_format: Optional[Dict] = None user: Optional[str] = Field(None, description="Opaque end-user identifier (passthrough).")
user: Optional[str] = None enable_thinking: Optional[bool] = Field(False, description="Enable thinking/reasoning mode for models that support it.")
# Enable thinking/reasoning mode for supported models
enable_thinking: Optional[bool] = False
model_config = ConfigDict(extra="allow") # Allow extra fields to prevent 422 errors model_config = ConfigDict(extra="allow") # Allow extra fields to prevent 422 errors
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
model: str model: str = Field(..., description="Text model id to use.")
prompt: Union[str, List[str]] prompt: Union[str, List[str]] = Field(..., description="Prompt text (or list of prompts) to complete.")
temperature: float = 0.7 temperature: float = 0.7
top_p: float = 1.0 top_p: float = 1.0
n: int = 1 n: int = 1
......
This diff is collapsed.
...@@ -169,6 +169,23 @@ class QueueManager: ...@@ -169,6 +169,23 @@ class QueueManager:
return index return index
return 0 return 0
def list_waiting(self) -> list:
"""Best-effort snapshot of queued (waiting) requests for the Tasks view.
Read without the async lock — fine for a read-only UI snapshot."""
out = []
for w in list(self.waiting):
out.append({
"request_id": w.request_id,
"model_key": w.model_key,
"enqueued_at": w.enqueued_at,
})
return out
def list_active(self) -> list:
"""Best-effort snapshot of in-flight leases for the Tasks view."""
return [{"request_id": rid, "model_key": lease.model_key}
for rid, lease in list(self.active_leases.items())]
def get_metrics(self) -> Dict[str, object]: def get_metrics(self) -> Dict[str, object]:
return { return {
"active": len(self.active_leases), "active": len(self.active_leases),
......
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Central task registry for long-running operations."""
from codai.tasks.registry import (
Task,
TaskCancelled,
TaskRegistry,
task_registry,
raise_if_cancelled,
wait_if_paused,
)
__all__ = [
"Task",
"TaskCancelled",
"TaskRegistry",
"task_registry",
"raise_if_cancelled",
"wait_if_paused",
]
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Central registry of long-running tasks (generation + training).
Lets the admin UI list every in-flight / recent task and cooperatively cancel
one. Cancellation is *cooperative*: the per-step diffusers callbacks
(``_step_cb`` / ``_vid_step_cb`` / ``_aud_step_cb`` / sd.cpp ``progress_callback``)
and the LoRA training loops call ``raise_if_cancelled(task_id)`` between steps,
which raises :class:`TaskCancelled` once the task is cancelled. This mirrors the
``codai.models.thermal.checkpoint()`` pause already invoked in those same hooks,
so the running model aborts at the next step boundary.
The registry is process-local and in-memory: it is the live view. Durable LoRA
training state lives separately in ``codai/api/loras.py`` (``_train_jobs.json``);
a task with a ``job_id`` links the two.
"""
import threading
import time
import uuid
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional
class TaskCancelled(Exception):
"""Raised inside a worker when its task has been cancelled by the user."""
ACTIVE_STATES = ("queued", "running")
TERMINAL_STATES = ("done", "error", "cancelled")
@dataclass
class Task:
id: str
kind: str # training | image | video | audio | text | pipeline
title: str = ""
model: str = ""
status: str = "queued" # queued | running | done | error | cancelled
step: int = 0
total: int = 0
message: str = ""
job_id: Optional[str] = None # link to a durable loras training job, if any
created_at: float = field(default_factory=time.time)
started_at: Optional[float] = None
ended_at: Optional[float] = None
cancellable: bool = True
restartable: bool = False
paused: bool = False
def to_dict(self) -> dict:
d = asdict(self)
d["active"] = self.status in ACTIVE_STATES
return d
class TaskRegistry:
"""Thread-safe registry. All public methods take the registry lock briefly;
the per-task cancel flag is a ``threading.Event`` so ``is_cancelled`` is a
cheap lock-free read suitable for a hot step loop."""
def __init__(self, history: int = 50):
self._lock = threading.Lock()
self._tasks: Dict[str, Task] = {}
self._events: Dict[str, threading.Event] = {}
self._pause_events: Dict[str, threading.Event] = {}
self._history = history
def register(self, kind: str, *, title: str = "", model: str = "",
total: int = 0, job_id: Optional[str] = None,
status: str = "queued", cancellable: bool = True,
restartable: bool = False, task_id: Optional[str] = None) -> str:
tid = task_id or f"task-{uuid.uuid4().hex[:12]}"
with self._lock:
self._tasks[tid] = Task(
id=tid, kind=kind, title=title, model=model, total=total,
job_id=job_id, status=status, cancellable=cancellable,
restartable=restartable,
)
self._events[tid] = threading.Event()
self._pause_events[tid] = threading.Event()
self._prune_locked()
return tid
def start(self, tid: str) -> None:
with self._lock:
t = self._tasks.get(tid)
if t and t.status in ACTIVE_STATES:
t.status = "running"
t.started_at = t.started_at or time.time()
def update(self, tid: str, **fields) -> None:
with self._lock:
t = self._tasks.get(tid)
if not t:
return
for k, v in fields.items():
if hasattr(t, k):
setattr(t, k, v)
def step(self, tid: str, step: int, total: Optional[int] = None) -> None:
with self._lock:
t = self._tasks.get(tid)
if not t:
return
t.step = int(step)
if total is not None:
t.total = int(total)
def finish(self, tid: str, status: str = "done", message: str = "") -> None:
with self._lock:
t = self._tasks.get(tid)
if not t:
return
# A user cancel wins over a late 'done' from the worker unwinding.
if not (t.status == "cancelled" and status == "done"):
t.status = status
if message:
t.message = message
t.ended_at = time.time()
self._prune_locked()
def cancel(self, tid: str) -> bool:
with self._lock:
t = self._tasks.get(tid)
if not t:
return False
ev = self._events.get(tid)
if ev:
ev.set()
# Release any pause so a paused→cancelled task unblocks immediately.
pev = self._pause_events.get(tid)
if pev:
pev.set() # wakes wait_if_paused; its is_cancelled check then raises
t.paused = False
was = t.status
if was in ACTIVE_STATES:
t.status = "cancelled"
t.message = "cancelled"
# A queued task never entered a worker, so it's terminal now;
# a running task finalises when its worker observes the flag.
if was == "queued":
t.ended_at = time.time()
return True
def is_cancelled(self, tid: Optional[str]) -> bool:
if not tid:
return False
ev = self._events.get(tid)
return bool(ev and ev.is_set())
def raise_if_cancelled(self, tid: Optional[str]) -> None:
if self.is_cancelled(tid):
raise TaskCancelled(tid)
# --- Pause / resume (cooperative, at the next step boundary) -------------
def pause(self, tid: str) -> bool:
with self._lock:
t = self._tasks.get(tid)
ev = self._pause_events.get(tid)
if not t or ev is None or t.status not in ACTIVE_STATES:
return False
ev.set()
t.paused = True
return True
def resume(self, tid: str) -> bool:
with self._lock:
t = self._tasks.get(tid)
ev = self._pause_events.get(tid)
if not t or ev is None:
return False
ev.clear()
t.paused = False
return True
def is_paused(self, tid: Optional[str]) -> bool:
if not tid:
return False
ev = self._pause_events.get(tid)
return bool(ev and ev.is_set())
def wait_if_paused(self, tid: Optional[str], poll: float = 0.2) -> None:
"""Block while the task is paused, returning when it is resumed.
Stays responsive to cancellation: a paused task that is then cancelled
raises :class:`TaskCancelled` instead of hanging. Safe to call from a hot
step loop — a no-op unless the task is actually paused."""
ev = self._pause_events.get(tid) if tid else None
if ev is None or not ev.is_set():
return
while ev.is_set():
if self.is_cancelled(tid):
raise TaskCancelled(tid)
ev.wait(timeout=poll)
def remove(self, tid: str) -> bool:
"""Drop a task from the registry entirely (used to dismiss a finished/
cancelled task from the view). No-op on a missing id."""
with self._lock:
self._events.pop(tid, None)
self._pause_events.pop(tid, None)
return self._tasks.pop(tid, None) is not None
def get(self, tid: str) -> Optional[dict]:
with self._lock:
t = self._tasks.get(tid)
return t.to_dict() if t else None
def list(self) -> List[dict]:
with self._lock:
return [t.to_dict() for t in sorted(
self._tasks.values(), key=lambda x: x.created_at, reverse=True)]
def _prune_locked(self) -> None:
"""Keep all active tasks + the most recent ``history`` terminal ones."""
terminal = [t for t in self._tasks.values() if t.status in TERMINAL_STATES]
if len(terminal) <= self._history:
return
terminal.sort(key=lambda x: x.ended_at or x.created_at)
for t in terminal[:-self._history]:
self._tasks.pop(t.id, None)
self._events.pop(t.id, None)
self._pause_events.pop(t.id, None)
# Process-wide singleton.
task_registry = TaskRegistry()
def raise_if_cancelled(task_id: Optional[str]) -> None:
"""Free helper mirroring ``thermal.checkpoint()`` — call from hot step loops.
Raises :class:`TaskCancelled` if ``task_id`` has been cancelled; a falsy
``task_id`` is a no-op (so callers needn't guard)."""
task_registry.raise_if_cancelled(task_id)
def wait_if_paused(task_id: Optional[str]) -> None:
"""Free helper — block at a step boundary while ``task_id`` is paused.
Returns immediately when not paused; raises :class:`TaskCancelled` if the
task is cancelled while paused. A falsy ``task_id`` is a no-op."""
task_registry.wait_if_paused(task_id)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment