Backends, API, and tooling updates; gitignore township_output

- cuda/vulkan backend improvements and config plumbing
- API updates across characters, text, environments, audio, embeddings, tts
- admin chat/settings template updates
- add hf_loading helper, video request fields, platform paths
- new docs (CODERAI_API_DOCUMENTATION.md) and tools (review_outputs, video_dubber)
- ignore generated township_output/
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 7dc60f66
......@@ -26,3 +26,6 @@ test_*.py
# Local git worktrees
.worktrees/
# Generated township fighter outputs
township_output/
This diff is collapsed.
This diff is collapsed.
......@@ -173,6 +173,16 @@ if [ "$BACKEND" = "nvidia" ]; then
echo -e "${YELLOW}Note: audiocraft not installed (audio generation with MusicGen optional)${NC}"
}
# Optional quantization backends for diffusers image/video pipelines:
# optimum-quanto -> enables 2-bit (int2) per-component quantization
# gguf -> enables loading GGUF-quantized components (Q5_K/Q6_K, etc.)
# bitsandbytes (4-bit/8-bit) comes via requirements-nvidia.txt; these add the
# extra widths that bitsandbytes cannot do.
echo -e "${YELLOW}Installing optional quantization backends (2-bit / GGUF)...${NC}"
pip install optimum-quanto gguf || {
echo -e "${YELLOW}Note: optimum-quanto/gguf not installed (2-bit and GGUF 5/6-bit quantization optional)${NC}"
}
# Install Flash Attention 2 if requested
if [ "$FLASH" = true ]; then
echo ""
......
......@@ -14,6 +14,63 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Configure the CUDA caching allocator BEFORE torch is imported anywhere.
# expandable_segments lets the allocator return freed pages to the driver even
# from partially-used segments. Without it, a single small live tensor (e.g. a
# tied embedding weight) pins an entire large segment, so torch.cuda.empty_cache()
# cannot release the GBs of already-freed weights around it after a model is
# evicted — VRAM stays occupied and the next model can't load. Honour any value
# the user already set.
import os as _os

# Enable expandable_segments unless the user already configured that option
# (in which case their value, True or False, is left untouched).
_alloc_conf = _os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments" not in _alloc_conf:
    # Append to (never replace) the user's settings. Trailing separators and
    # surrounding whitespace are trimmed first so the joined value never
    # contains an empty entry such as "a:1,,expandable_segments:True".
    _alloc_base = _alloc_conf.strip().rstrip(", ")
    _os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
        (_alloc_base + ",") if _alloc_base else ""
    ) + "expandable_segments:True"
# Cap CPU threads BEFORE torch / OpenMP / MKL initialise. Loading and 4-bit
# dequantising large models is CPU-heavy; left uncapped, torch/OpenMP grab every
# core and the machine's load average spikes and it becomes sluggish. On boxes
# with >= 8 cores, limit to HALF the cores so model loads never saturate the
# machine. Smaller machines keep the default (don't cripple them). Honour any
# value the user already set.
try:
    # cpu_count() may return None; treat that as "unknown" and leave defaults.
    _ncpu = _os.cpu_count() or 0
    if _ncpu >= 8:
        # Half the cores, expressed as the string the thread-pool libraries read.
        _cap = str(max(1, _ncpu // 2))
        for _var in (
            "OMP_NUM_THREADS",
            "MKL_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "NUMEXPR_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
        ):
            # Only fill in variables the user has not set themselves.
            if _var not in _os.environ:
                _os.environ[_var] = _cap
except Exception:
    # Best-effort tuning only: never let it break package import.
    pass
# Silence ONE specific upstream FutureWarning from bitsandbytes' quant kernels:
# bitsandbytes/backends/cuda/ops.py: torch._check_is_size(blocksize)
# bitsandbytes (latest, 0.49.2) still calls the deprecated torch._check_is_size
# on bleeding-edge torch. We don't call it ourselves and can't fix their source,
# so suppress just this message (not warnings in general) to keep logs readable.
import warnings as _warnings

# Upstream / diagnostic-only warnings we can't fix from here. Each entry is
# (message regex, warning category); only these specific messages are hidden,
# warnings in general stay enabled.
#   - bitsandbytes: its quant kernels call the deprecated torch._check_is_size.
#   - huggingface_hub: diffusers/transformers pass the deprecated
#     `local_dir_use_symlinks` kwarg to hf_hub_download (not our code).
#   - torch.distributed.reduce_op: emitted while the debug leak-scanner walks
#     gc.get_objects(); unavoidable without dropping the scan.
for _pattern, _category in (
    (r".*_check_is_size will be removed.*", FutureWarning),
    (r".*local_dir_use_symlinks.*", UserWarning),
    (r".*reduce_op.*is deprecated.*", FutureWarning),
):
    _warnings.filterwarnings("ignore", message=_pattern, category=_category)
del _pattern, _category
# codai module - AI model parsing utilities
from .models.parser import (
ModelParserDispatcher,
......
This diff is collapsed.
......@@ -102,6 +102,57 @@
</div>
</div>
<!-- Thermal protection -->
<div class="card mb-0" style="margin-top:1rem">
<div class="card-title">Thermal Protection</div>
<span class="form-hint" style="display:block;margin-bottom:.75rem">
Before serving a request against a loaded model, wait until temperatures are
safe so a long sequence of heavy generations can't overheat the machine and
trip its power-off protection. The wait is non-blocking (other requests keep
being accepted) and takes effect immediately on save. Temperatures in °C.
</span>
<div class="form-row">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer">
<input type="checkbox" id="s-therm-gpu-enabled" onchange="toggleThermalFields()">
<span style="font-size:13px;font-weight:500">Enable GPU temperature protection</span>
</label>
</div>
<div id="therm-gpu-fields" class="form-row" style="display:grid;grid-template-columns:1fr 1fr;gap:1rem">
<div>
<label class="form-label">Pause when GPU reaches (°C)</label>
<input type="number" id="s-therm-gpu-high" class="form-input" min="40" max="120" step="1" placeholder="90">
</div>
<div>
<label class="form-label">Resume when GPU drops to (°C)</label>
<input type="number" id="s-therm-gpu-resume" class="form-input" min="30" max="120" step="1" placeholder="87">
</div>
</div>
<div class="form-row" style="margin-top:.5rem">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer">
<input type="checkbox" id="s-therm-cpu-enabled" onchange="toggleThermalFields()">
<span style="font-size:13px;font-weight:500">Enable CPU temperature protection</span>
</label>
</div>
<div id="therm-cpu-fields" class="form-row" style="display:grid;grid-template-columns:1fr 1fr;gap:1rem">
<div>
<label class="form-label">Pause when CPU reaches (°C)</label>
<input type="number" id="s-therm-cpu-high" class="form-input" min="40" max="120" step="1" placeholder="90">
</div>
<div>
<label class="form-label">Resume when CPU drops to (°C)</label>
<input type="number" id="s-therm-cpu-resume" class="form-input" min="30" max="120" step="1" placeholder="87">
</div>
</div>
<div class="form-row" style="margin:0">
<label class="form-label">Re-check interval while cooling down (seconds)</label>
<input type="number" id="s-therm-poll" class="form-input" style="max-width:200px" min="1" max="120" step="1" placeholder="5">
<span class="form-hint">How often to re-read temperatures while waiting for cooldown.</span>
</div>
</div>
<div class="card mb-0" style="margin-top:1rem">
<div class="card-title">AISBF Broker</div>
<div class="form-row">
......@@ -210,6 +261,13 @@ function toggleBrokerFields(){
}
}
// Show each thermal threshold grid (GPU, CPU) only while its matching
// "enable" checkbox is ticked; otherwise hide it.
function toggleThermalFields(){
  for (const kind of ['gpu', 'cpu']) {
    const fields = document.getElementById('therm-' + kind + '-fields');
    const toggle = document.getElementById('s-therm-' + kind + '-enabled');
    fields.style.display = toggle.checked ? 'grid' : 'none';
  }
}
function showAlert(type, msg){
const el = document.getElementById('settings-alert');
el.className = 'alert alert-' + (type === 'error' ? 'error' : 'info');
......@@ -260,6 +318,16 @@ async function loadSettings(){
document.getElementById('s-broker-reconnect-max').value = broker.reconnect_max_delay_seconds ?? 60;
document.getElementById('s-broker-ws-ping').value = broker.websocket_ping_interval ?? 20;
toggleBrokerFields();
// Thermal protection
const therm = d.thermal || {};
document.getElementById('s-therm-gpu-enabled').checked = therm.gpu_enabled !== false;
document.getElementById('s-therm-cpu-enabled').checked = therm.cpu_enabled !== false;
document.getElementById('s-therm-gpu-high').value = therm.gpu_high ?? 90;
document.getElementById('s-therm-gpu-resume').value = therm.gpu_resume ?? 87;
document.getElementById('s-therm-cpu-high').value = therm.cpu_high ?? 90;
document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87;
document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5;
toggleThermalFields();
}catch(e){ showAlert('error','Failed to load settings: '+e.message); }
}
......@@ -286,6 +354,15 @@ async function saveSettings(){
directory: document.getElementById('s-arc-dir').value.trim(),
retention: document.getElementById('s-arc-retention').value,
},
thermal:{
gpu_enabled: document.getElementById('s-therm-gpu-enabled').checked,
cpu_enabled: document.getElementById('s-therm-cpu-enabled').checked,
gpu_high: parseFloat(document.getElementById('s-therm-gpu-high').value) || 90,
gpu_resume: parseFloat(document.getElementById('s-therm-gpu-resume').value) || 87,
cpu_high: parseFloat(document.getElementById('s-therm-cpu-high').value) || 90,
cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87,
poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5,
},
broker:{
enabled: document.getElementById('s-broker-enabled').checked,
base_url: document.getElementById('s-broker-base-url').value.trim(),
......@@ -310,7 +387,7 @@ async function saveSettings(){
method:'POST', headers:{'Content-Type':'application/json'},
body: JSON.stringify(data)
});
if(r.ok) showAlert('info','Settings saved. Archive changes take effect immediately; restart CoderAI for other changes.');
if(r.ok) showAlert('info','Settings saved. Archive and thermal-protection changes take effect immediately; restart CoderAI for other changes.');
else{ const e=await r.json(); showAlert('error', e.detail||'Save failed'); }
}catch(e){ showAlert('error','Error: '+e.message); }
}
......
......@@ -139,6 +139,7 @@ from codai.api.voice_clone import router as voice_clone_router
from codai.api.voice_convert import router as voice_convert_router
from codai.api.faceswap import router as faceswap_router
from codai.api.characters import router as characters_router
from codai.api.loras import router as loras_router
from codai.api.spatial import router as spatial_router
from codai.api.environments import router as environments_router
from codai.admin.routes import router as admin_router
......@@ -203,6 +204,7 @@ app.include_router(voice_clone_router)
app.include_router(voice_convert_router)
app.include_router(faceswap_router)
app.include_router(characters_router)
app.include_router(loras_router)
app.include_router(environments_router)
app.include_router(spatial_router)
app.include_router(admin_router)
......
......@@ -119,11 +119,35 @@ def _load_musicgen(model_name: str, device: str):
return model
def _load_audioldm(model_name: str, device: str):
def _load_audioldm(model_name: str, device: str, model_config: dict = None):
import torch
from diffusers import AudioLDM2Pipeline
pipe = AudioLDM2Pipeline.from_pretrained(model_name, torch_dtype=torch.float16)
pipe = pipe.to(device)
from codai.models.hf_loading import resolve_dtype
dtype = resolve_dtype(model_config, default='f16')
_xtra = {}
# Apply 4-bit/8-bit quantization to the diffusion backbone when configured.
_mc = model_config or {}
if _mc.get('load_in_4bit') or _mc.get('load_in_8bit'):
_bits = 4 if _mc.get('load_in_4bit') else 8
try:
from diffusers.quantizers import PipelineQuantizationConfig
_qk = ({'load_in_4bit': True, 'bnb_4bit_compute_dtype': dtype}
if _mc.get('load_in_4bit') else {'load_in_8bit': True})
_xtra['quantization_config'] = PipelineQuantizationConfig(
quant_backend=f"bitsandbytes_{_bits}bit",
quant_kwargs=_qk,
components_to_quantize=["transformer", "unet"],
)
print(f"AudioLDM quantization: {_bits}-bit (bitsandbytes)")
except Exception as e:
print(f"AudioLDM quantization unavailable: {e}")
pipe = AudioLDM2Pipeline.from_pretrained(model_name, torch_dtype=dtype, **_xtra)
# CPU offload when configured; otherwise place on device (skip for quantized).
_off = _mc.get('offload_strategy')
if _off in ('cpu', 'sequential', 'model', 'disk') and hasattr(pipe, 'enable_model_cpu_offload'):
pipe.enable_model_cpu_offload()
elif 'quantization_config' not in _xtra:
pipe = pipe.to(device)
return pipe
......@@ -224,7 +248,8 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
Compatible models: MusicGen, AudioGen, AudioLDM2, StableAudio.
"""
_aud_progress_loading(request.model or "audio")
model_info = multi_model_manager.request_model(request.model, model_type="audio_gen")
model_info = await asyncio.to_thread(
multi_model_manager.request_model, request.model, model_type="audio_gen")
model_name = model_info.get('model_name')
if not model_name:
err = model_info.get('error', f"Model '{request.model}' not found")
......@@ -236,13 +261,14 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
if pipe is None:
device = _derive_device()
model_type = _detect_audio_gen_type(model_name)
_ag_cfg = model_info.get('config') or {}
try:
if model_type in ('musicgen', 'audiogen'):
pipe = await asyncio.get_event_loop().run_in_executor(
None, _load_musicgen, model_name, device)
else:
pipe = await asyncio.get_event_loop().run_in_executor(
None, _load_audioldm, model_name, device)
None, _load_audioldm, model_name, device, _ag_cfg)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load audio gen model: {e}")
multi_model_manager.models[model_key] = pipe
......
......@@ -37,9 +37,38 @@ import tempfile
import time
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, ConfigDict
def _require_api_auth(request: Request) -> None:
    """Reject the request with HTTP 401 unless it carries a valid credential.

    Credentials are checked in this order:

    1. An ``Authorization: Bearer <token>`` header whose token the session
       manager accepts via ``verify_token``.
    2. A ``session`` cookie the session manager accepts via
       ``validate_session``; a trailing ``.MUST_CHANGE`` suffix is stripped
       before validation.

    The check is skipped (request allowed) when the session manager cannot be
    obtained from ``codai.admin.routes`` or is ``None``.

    Raises:
        HTTPException: status 401 with a structured ``detail`` payload when no
            credential validates.
    """
    try:
        from codai.admin import routes as _admin_routes
        manager = _admin_routes.session_manager
    except Exception:
        # NOTE(review): deliberately fail-open — if the auth subsystem is
        # unavailable the request is allowed through.
        return
    if manager is None:
        # Auth not configured on this instance.
        return

    bearer_prefix = "bearer "
    header = request.headers.get("authorization", "")
    if header.lower().startswith(bearer_prefix):
        candidate = header[len(bearer_prefix):].strip()
        if manager.verify_token(candidate):
            return

    must_change_suffix = ".MUST_CHANGE"
    session = request.cookies.get("session", "")
    if session.endswith(must_change_suffix):
        session = session[:-len(must_change_suffix)]
    if session and manager.validate_session(session):
        return

    raise HTTPException(
        status_code=401,
        detail={
            "message": "Invalid API key. Provide a valid Bearer token.",
            "type": "invalid_request_error",
            "code": "invalid_api_key",
        },
    )
from codai.platform_paths import default_characters_dir, legacy_style_config_dir
router = APIRouter()
......@@ -211,7 +240,12 @@ def _decode_source(data: str) -> bytes:
def _detect_faces_cv2(img_bytes: bytes):
"""Return list of (x,y,w,h) face rects using Haar cascade, or [] if cv2 unavailable."""
"""
Return list of (x,y,w,h) face rects, largest first.
Tries MediaPipe (most accurate), then OpenCV DNN, then Haar cascade as fallback.
Detections smaller than 2% of image area are discarded as false positives.
Returns [] if no library is available or no plausible face is found.
"""
try:
import cv2
import numpy as np
......@@ -219,19 +253,56 @@ def _detect_faces_cv2(img_bytes: bytes):
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
return []
ih, iw = img.shape[:2]
img_area = ih * iw
min_face_area = img_area * 0.02 # reject anything < 2% of image
# ── Try MediaPipe first (most accurate, no model download needed) ──
try:
import mediapipe as mp
mp_face = mp.solutions.face_detection
with mp_face.FaceDetection(model_selection=1, min_detection_confidence=0.5) as det:
rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
results = det.process(rgb)
if results.detections:
rects = []
for d in results.detections:
bb = d.location_data.relative_bounding_box
x = int(bb.xmin * iw)
y = int(bb.ymin * ih)
w = int(bb.width * iw)
h = int(bb.height * ih)
if w * h >= min_face_area:
rects.append((x, y, w, h))
if rects:
rects.sort(key=lambda r: r[2]*r[3], reverse=True)
return rects
except ImportError:
pass
# ── Haar cascade fallback (stricter parameters to reduce false positives) ──
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray = cv2.equalizeHist(gray)
cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
cascade = cv2.CascadeClassifier(cascade_path)
faces = cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(40, 40))
# minSize scaled to image: at least 8% of the shorter dimension
min_dim = int(min(iw, ih) * 0.08)
faces = cascade.detectMultiScale(
gray, scaleFactor=1.05, minNeighbors=8,
minSize=(max(40, min_dim), max(40, min_dim)),
)
if len(faces) == 0:
return []
return [(int(x), int(y), int(w), int(h)) for x, y, w, h in faces]
rects = [(int(x), int(y), int(w), int(h)) for x, y, w, h in faces
if int(w) * int(h) >= min_face_area]
rects.sort(key=lambda r: r[2]*r[3], reverse=True)
return rects
except Exception:
return []
def _crop_face(img_bytes: bytes, rect) -> Optional[bytes]:
"""Crop a face rect (with padding) from an image, return PNG bytes."""
"""Crop a face rect with generous padding (head-and-shoulders), return PNG bytes."""
try:
import cv2
import numpy as np
......@@ -241,11 +312,15 @@ def _crop_face(img_bytes: bytes, rect) -> Optional[bytes]:
if img is None:
return None
ih, iw = img.shape[:2]
pad = int(max(w, h) * 0.4)
x1 = max(0, x - pad)
y1 = max(0, y - pad)
x2 = min(iw, x + w + pad)
y2 = min(ih, y + h + pad)
side = max(w, h)
# More padding on top to include hair/forehead, less at bottom
pad_sides = int(side * 0.5)
pad_top = int(side * 0.7)
pad_bot = int(side * 0.4)
x1 = max(0, x - pad_sides)
y1 = max(0, y - pad_top)
x2 = min(iw, x + w + pad_sides)
y2 = min(ih, y + h + pad_bot)
crop = img[y1:y2, x1:x2]
ok, buf = cv2.imencode('.png', crop)
return bytes(buf) if ok else None
......@@ -274,7 +349,7 @@ def _extract_from_image(img_bytes: bytes) -> List[bytes]:
crops = [c for f in faces for c in [_crop_face(img_bytes, f)] if c]
if crops:
return crops
# No face detected — use whole image as reference
# No face detected (or all detections filtered as false positives) — use whole image
try:
from PIL import Image as PILImage
img = PILImage.open(io.BytesIO(img_bytes)).convert('RGB')
......@@ -345,7 +420,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/characters")
async def save_character(req: CharacterSaveRequest):
async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named character profile."""
if not req.name or '/' in req.name or '..' in req.name:
raise HTTPException(status_code=400, detail="Invalid character name")
......@@ -356,13 +431,13 @@ async def save_character(req: CharacterSaveRequest):
@router.get("/v1/characters")
async def list_characters():
async def list_characters(_auth=Depends(_require_api_auth)):
"""List all saved character profiles (metadata only, no images)."""
return {"characters": _list_characters()}
@router.get("/v1/characters/{name}")
async def get_character(name: str):
async def get_character(name: str, _auth=Depends(_require_api_auth)):
"""Get a character profile including its reference images as base64."""
meta = _load_character_meta(name)
if not meta:
......@@ -378,7 +453,7 @@ async def get_character(name: str):
@router.delete("/v1/characters/{name}")
async def delete_character(name: str):
async def delete_character(name: str, _auth=Depends(_require_api_auth)):
"""Delete a character profile."""
cdir = _char_dir(name)
if not os.path.isdir(cdir):
......@@ -389,7 +464,7 @@ async def delete_character(name: str):
@router.patch("/v1/characters/{name}")
async def patch_character(name: str, req: CharacterPatchRequest):
async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)):
"""Update a character profile: description, add images, or remove images by index."""
meta = _load_character_meta(name)
if not meta:
......@@ -462,29 +537,24 @@ async def generate_character(req: CharacterGenerateRequest, request: Request):
if req.steps:
payload["steps"] = req.steps
# Forward the caller's auth token so rate-limit / auth middleware passes
auth_header = request.headers.get("authorization", "")
headers = {"Content-Type": "application/json"}
if auth_header:
headers["Authorization"] = auth_header
try:
from httpx import AsyncClient, ASGITransport
async with AsyncClient(
transport=ASGITransport(app=request.app),
base_url="http://internal",
timeout=300,
) as client:
r = await client.post("/v1/images/generations", json=payload, headers=headers)
if not r.is_success:
import json as _json
from codai.broker.asgi_bridge import execute_internal_request
resp = await execute_internal_request(
request.app,
method="POST",
path="/v1/images/generations",
headers={"Content-Type": "application/json"},
body=_json.dumps(payload).encode(),
)
if resp["status_code"] >= 400:
try:
detail = r.json().get("detail", r.text)
detail = _json.loads(resp["body"]).get("detail", resp["body"].decode())
except Exception:
detail = r.text
raise HTTPException(status_code=r.status_code, detail=f"Image generation failed: {detail}")
detail = resp["body"].decode()
raise HTTPException(status_code=resp["status_code"], detail=f"Image generation failed: {detail}")
images_data = r.json().get("data", [])
images_data = _json.loads(resp["body"]).get("data", [])
except HTTPException:
raise
except Exception as e:
......
......@@ -48,10 +48,16 @@ def _derive_device() -> str:
return "cuda:0"
def _load_embedding_model(model_name: str, device: str):
def _load_embedding_model(model_name: str, device: str, model_config: dict = None):
from codai.models.hf_loading import build_from_pretrained_kwargs
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, device=device)
# sentence-transformers honours quantization via model_kwargs.
fp = build_from_pretrained_kwargs(model_config)
st_kwargs = {}
if 'quantization_config' in fp:
st_kwargs['model_kwargs'] = {'quantization_config': fp['quantization_config']}
model = SentenceTransformer(model_name, device=device, **st_kwargs)
return ('sentence_transformers', model)
except ImportError:
pass
......@@ -59,8 +65,11 @@ def _load_embedding_model(model_name: str, device: str):
try:
from transformers import AutoTokenizer, AutoModel
import torch
fp = build_from_pretrained_kwargs(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model = AutoModel.from_pretrained(model_name, **fp)
if 'quantization_config' not in fp and 'device_map' not in fp:
model = model.to(device)
return ('transformers', (tokenizer, model, device))
except Exception as e:
raise RuntimeError(f"Cannot load embedding model '{model_name}': {e}")
......@@ -97,7 +106,8 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
"""
OpenAI-compatible embeddings endpoint.
"""
model_info = multi_model_manager.request_model(request.model, model_type="embedding")
model_info = await asyncio.to_thread(
multi_model_manager.request_model, request.model, model_type="embedding")
model_name = model_info.get('model_name')
if not model_name:
err = model_info.get('error', f"Model '{request.model}' not found")
......@@ -108,9 +118,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
if model_obj is None:
device = _derive_device()
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}")
or multi_model_manager.config.get(model_name) or {})
try:
model_obj = await asyncio.get_event_loop().run_in_executor(
None, _load_embedding_model, model_name, device)
None, _load_embedding_model, model_name, device, _emb_cfg)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load embedding model: {e}")
multi_model_manager.models[model_key] = model_obj
......
......@@ -39,7 +39,32 @@ import tempfile
import time
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request
def _require_api_auth(request: Request) -> None:
    """Raise HTTP 401 unless the request presents a valid credential.

    A request passes when either its ``Authorization: Bearer <token>`` header
    is accepted by the session manager's ``verify_token``, or its ``session``
    cookie (with any trailing ``.MUST_CHANGE`` suffix removed) is accepted by
    ``validate_session``. The header is tried first.

    No check is performed when the session manager cannot be obtained from
    ``codai.admin.routes`` or is ``None``.

    Raises:
        HTTPException: status 401 with a structured ``detail`` payload when
            neither credential validates.
    """
    try:
        from codai.admin import routes as _admin_routes
        manager = _admin_routes.session_manager
    except Exception:
        # NOTE(review): deliberately fail-open when the auth subsystem cannot
        # be loaded; the request is allowed through.
        return
    if manager is None:
        return

    scheme = "bearer "
    authorization = request.headers.get("authorization", "")
    if authorization.lower().startswith(scheme):
        if manager.verify_token(authorization[len(scheme):].strip()):
            return

    marker = ".MUST_CHANGE"
    session_cookie = request.cookies.get("session", "")
    if session_cookie.endswith(marker):
        session_cookie = session_cookie[:-len(marker)]
    if session_cookie and manager.validate_session(session_cookie):
        return

    raise HTTPException(
        status_code=401,
        detail={
            "message": "Invalid API key. Provide a valid Bearer token.",
            "type": "invalid_request_error",
            "code": "invalid_api_key",
        },
    )
from pydantic import BaseModel, ConfigDict
from codai.platform_paths import default_environments_dir, legacy_style_config_dir
......@@ -283,7 +308,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/environments")
async def save_environment(req: EnvironmentSaveRequest):
async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named environment profile."""
if not req.name or '/' in req.name or '..' in req.name:
raise HTTPException(status_code=400, detail="Invalid environment name")
......@@ -294,13 +319,13 @@ async def save_environment(req: EnvironmentSaveRequest):
@router.get("/v1/environments")
async def list_environments():
async def list_environments(_auth=Depends(_require_api_auth)):
"""List all saved environment profiles (metadata only)."""
return {"environments": _list_environments()}
@router.get("/v1/environments/{name}")
async def get_environment(name: str):
async def get_environment(name: str, _auth=Depends(_require_api_auth)):
"""Get an environment profile including its reference images as base64."""
meta = _load_environment_meta(name)
if not meta:
......@@ -316,7 +341,7 @@ async def get_environment(name: str):
@router.delete("/v1/environments/{name}")
async def delete_environment(name: str):
async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
"""Delete an environment profile."""
edir = _env_dir(name)
if not os.path.isdir(edir):
......@@ -327,7 +352,7 @@ async def delete_environment(name: str):
@router.patch("/v1/environments/{name}")
async def patch_environment(name: str, req: EnvironmentPatchRequest):
async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)):
"""Update an environment profile: description, add images, or remove images by index."""
meta = _load_environment_meta(name)
if not meta:
......@@ -398,28 +423,24 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request
if req.steps:
payload["steps"] = req.steps
auth_header = request.headers.get("authorization", "")
headers = {"Content-Type": "application/json"}
if auth_header:
headers["Authorization"] = auth_header
try:
from httpx import AsyncClient, ASGITransport
async with AsyncClient(
transport=ASGITransport(app=request.app),
base_url="http://internal",
timeout=300,
) as client:
r = await client.post("/v1/images/generations", json=payload, headers=headers)
if not r.is_success:
import json as _json
from codai.broker.asgi_bridge import execute_internal_request
resp = await execute_internal_request(
request.app,
method="POST",
path="/v1/images/generations",
headers={"Content-Type": "application/json"},
body=_json.dumps(payload).encode(),
)
if resp["status_code"] >= 400:
try:
detail = r.json().get("detail", r.text)
detail = _json.loads(resp["body"]).get("detail", resp["body"].decode())
except Exception:
detail = r.text
raise HTTPException(status_code=r.status_code, detail=f"Image generation failed: {detail}")
detail = resp["body"].decode()
raise HTTPException(status_code=resp["status_code"], detail=f"Image generation failed: {detail}")
images_data = r.json().get("data", [])
images_data = _json.loads(resp["body"]).get("data", [])
except HTTPException:
raise
except Exception as e:
......
......@@ -283,25 +283,69 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
# Continue with original implementation for 'auto' parser
# Get the model for this request
requested_model = request.model
# Use the manager to resolve the model and manage VRAM (handles ondemand unloading)
model_info = multi_model_manager.request_model(
requested_model=requested_model,
model_type="text"
)
# Check if the model was rejected as not allowed
if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error'])
# Acquire the least-busy instance (increments ref-count; released on response completion)
_model_key = model_info.get('model_key')
# Resolve and load the model, waiting if another model is currently loading.
# Retries up to ~5 minutes (60 × 5s) so requests queue behind long video loads
# rather than failing immediately with "No model loaded".
_MAX_WAIT_TRIES = 60
_model_key = None
_instance_idx = None
_acq = multi_model_manager.acquire_model_instance(_model_key) if _model_key else None
if _acq:
_instance_idx, mm = _acq
else:
mm = multi_model_manager.get_model_for_request(requested_model)
mm = None
model_info = {}
for _attempt in range(_MAX_WAIT_TRIES):
# Fail fast on a corrupted CUDA context — retrying 60× is pointless.
if getattr(multi_model_manager, 'cuda_context_poisoned', False):
raise HTTPException(status_code=503, detail=(
"CUDA context corrupted by an earlier device-side assert "
f"({multi_model_manager.cuda_poison_reason}). Restart coderai to recover."))
# If another model is loading, yield the event loop and wait for it to finish.
if not multi_model_manager._model_ready_event.is_set():
print(f"Text model '{requested_model}': waiting for model load to complete "
f"(attempt {_attempt + 1}/{_MAX_WAIT_TRIES})…")
await asyncio.to_thread(
multi_model_manager._model_ready_event.wait, 30.0
)
await asyncio.sleep(0)
# In a thread: request_model may block waiting for a busy model to go
# idle before evicting it; blocking the event loop here would deadlock.
model_info = await asyncio.to_thread(
multi_model_manager.request_model,
requested_model,
"text",
)
if model_info.get('error'):
# CUDA-poison errors are unrecoverable → 503; others (unknown model) → 404.
_status = 503 if 'CUDA context corrupted' in str(model_info['error']) else 404
raise HTTPException(status_code=_status, detail=model_info['error'])
_model_key = model_info.get('model_key')
_candidate = None
_acq = multi_model_manager.acquire_model_instance(_model_key) if _model_key else None
if _acq:
_instance_idx, _candidate = _acq
# Guard against stale pool entries (model evicted but pool not cleared)
if hasattr(_candidate, 'backend') and _candidate.backend is None:
multi_model_manager.release_model_instance(_model_key, _instance_idx)
_instance_idx = None
_candidate = None
if _candidate is None:
_candidate = multi_model_manager.get_model_for_request(requested_model)
if _candidate is None and model_manager.backend is not None:
_candidate = model_manager
# Validate the candidate has a working backend before accepting it
if _candidate is not None:
if hasattr(_candidate, 'backend') and _candidate.backend is None:
_candidate = None
if _candidate is not None:
mm = _candidate
break
print(f"Text model '{requested_model}' not ready, retrying in 5s "
f"(attempt {_attempt + 1}/{_MAX_WAIT_TRIES})…")
await asyncio.sleep(5)
def _release_instance():
if _instance_idx is not None and _model_key:
......@@ -309,12 +353,10 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
if mm is None:
_release_instance()
if model_manager.backend is not None:
current_manager = model_manager
else:
raise HTTPException(status_code=503, detail="Model not loaded")
else:
current_manager = mm
raise HTTPException(status_code=503,
detail=f"Model '{requested_model}' could not be loaded after waiting. "
"Another model may be using all available VRAM.")
current_manager = mm
# Inject system prompt if --system-prompt flag was provided
messages = request.messages
......@@ -1161,6 +1203,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
tool_parser,
request.response_format,
_prefix_key,
enable_thinking=reasoning_enabled,
):
yield chunk
finally:
......@@ -1182,6 +1225,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
tool_parser,
request.response_format,
force_reasoning_args,
enable_thinking=reasoning_enabled,
)
finally:
_release_instance()
......@@ -1198,6 +1242,7 @@ async def stream_chat_response(
tool_parser: ToolCallParser,
response_format: Optional[Dict] = None,
prefix_key: str = "",
enable_thinking: bool = False,
) -> AsyncGenerator[str, None]:
"""Stream chat completion response with queue notifications."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
......@@ -1327,6 +1372,7 @@ async def stream_chat_response(
stop=stop,
tools=tools,
response_format=response_format,
enable_thinking=enable_thinking,
):
chunk_count += 1
# Always filter malformed content (regex-based, works per-chunk)
......@@ -1547,6 +1593,7 @@ async def generate_chat_response(
tool_parser: ToolCallParser,
response_format: Optional[Dict] = None,
force_reasoning_args: Optional[List[str]] = None,
enable_thinking: bool = False,
) -> Dict:
"""Generate non-streaming chat completion response."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
......@@ -1583,6 +1630,7 @@ async def generate_chat_response(
stop=stop,
tools=tools,
response_format=response_format,
enable_thinking=enable_thinking,
)
# Always filter out malformed content
......@@ -1748,9 +1796,12 @@ async def completions(request: CompletionRequest):
requested_model = request.model
# Use the manager to resolve the model and manage VRAM (handles ondemand unloading)
model_info = multi_model_manager.request_model(
# In a thread: request_model may block (thermal cooldown / waiting for a busy
# model) and we must not stall the event loop.
model_info = await asyncio.to_thread(
multi_model_manager.request_model,
requested_model=requested_model,
model_type="text"
model_type="text",
)
# Check if the model was rejected as not allowed
......
......@@ -18,6 +18,7 @@
Audio transcription endpoint for the codai API.
"""
import asyncio
import io
import os
import tempfile
......@@ -143,7 +144,8 @@ async def create_transcription(
else multi_model_manager.whisper_servers.get(model)
)
if whisper_server is not None:
multi_model_manager.request_model(requested_model=model, model_type="audio")
await asyncio.to_thread(
multi_model_manager.request_model, requested_model=model, model_type="audio")
if not whisper_server.is_running():
whisper_server.start(
getattr(whisper_server, "_model_path", None),
......@@ -166,7 +168,8 @@ async def create_transcription(
return _format_response(response_format, result.get("text", ""), [])
# Use the manager to resolve the model and manage VRAM
model_info = multi_model_manager.request_model(
model_info = await asyncio.to_thread(
multi_model_manager.request_model,
requested_model=model,
model_type="audio"
)
......
......@@ -99,9 +99,10 @@ async def create_speech(request: TTSRequest, http_request: Request = None):
return {"audio": audio_base64}
# Use the manager to resolve the model and manage VRAM
model_info = multi_model_manager.request_model(
model_info = await asyncio.to_thread(
multi_model_manager.request_model,
requested_model=request.model,
model_type="tts"
model_type="tts",
)
# Check if the model was rejected as not allowed
......
This diff is collapsed.
......@@ -38,6 +38,34 @@ try:
except (ImportError, AttributeError):
_grammar_guided_gen = False
def _make_llama_thermal_criteria():
    """Build a llama.cpp StoppingCriteriaList that pauses generation while too hot.

    llama-cpp-python evaluates stopping criteria synchronously per token inside
    create_(chat_)completion, so blocking in the criterion pauses the GPU
    forward pass — mid-generation thermal protection for the
    GGUF/Vulkan/llama.cpp backend.

    The single criterion never stops generation (it always returns False); it
    only calls the thermal ``checkpoint`` with a 2-second throttle so sensors
    are not read on every token.

    Both imports are resolved once here, when the criteria are built, rather
    than inside the per-token callback.

    Returns:
        A ``StoppingCriteriaList`` containing the pause criterion, or None when
        llama_cpp or the thermal module cannot be imported or the list cannot
        be constructed. Callers treat None as "no stopping criteria".
    """
    # Broad catch kept from the original code: presumably loading the native
    # llama.cpp library can fail with errors other than ImportError.
    try:
        from llama_cpp import StoppingCriteriaList
        from codai.models.thermal import checkpoint
    except Exception:
        return None

    def _pause(input_ids, logits):
        # Deliberately best-effort: a failure inside the thermal check must
        # never abort a generation that is already in flight.
        try:
            checkpoint(context="text-gen", throttle_seconds=2.0)
        except Exception:
            pass
        # Never request a stop; this criterion exists only to block.
        return False

    try:
        return StoppingCriteriaList([_pause])
    except Exception:
        return None
try:
from llama_cpp import Llama
from llama_cpp.llama_chat_format import ChatFormatterResponse
......@@ -699,6 +727,7 @@ class VulkanBackend(ModelBackend):
try:
result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
......@@ -717,6 +746,7 @@ class VulkanBackend(ModelBackend):
print(f"Warning: Grammar-guided generation failed: {e}, falling back to normal generation")
try:
result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
......@@ -803,6 +833,7 @@ class VulkanBackend(ModelBackend):
prompt_len = len(prompt) if isinstance(prompt, str) else 0
for chunk in self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
......@@ -842,6 +873,7 @@ class VulkanBackend(ModelBackend):
prompt_len = len(prompt) if isinstance(prompt, str) else 0
for chunk in self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
......@@ -911,6 +943,7 @@ class VulkanBackend(ModelBackend):
prompt_len = len(prompt)
for chunk in self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
......@@ -937,6 +970,7 @@ class VulkanBackend(ModelBackend):
return {"stream": generate_stream(), "content": ""}
else:
result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
......@@ -1052,6 +1086,9 @@ class VulkanBackend(ModelBackend):
kwargs['stop'] = stop
if response_format and response_format.get('type') == 'json_object':
kwargs['response_format'] = {'type': 'json_object'}
_tc = _make_llama_thermal_criteria()
if _tc is not None:
kwargs['stopping_criteria'] = _tc
result = self.model.create_chat_completion(**kwargs)
usage = result.get('usage', {})
......@@ -1077,6 +1114,9 @@ class VulkanBackend(ModelBackend):
)
if stop:
kwargs['stop'] = stop
_tc = _make_llama_thermal_criteria()
if _tc is not None:
kwargs['stopping_criteria'] = _tc
prompt_tokens = 0
completion_tokens = 0
......
......@@ -108,6 +108,24 @@ class ArchiveConfig:
retention: str = "never" # one of: 1h 1d 2d 1w 1m 3m 6m 1y never
@dataclass
class ThermalConfig:
    """Settings for temperature-based protection of the host machine.

    Requests against a loaded model are held back until CPU/GPU temperatures
    are within safe limits, so that a long run of heavy generations cannot
    overheat the machine and trigger its power-off protection.

    All thresholds are expressed in degrees Celsius. CPU and GPU monitoring
    are switched on or off independently of each other.
    """

    # --- Per-sensor switches -------------------------------------------------
    cpu_enabled: bool = True
    gpu_enabled: bool = True

    # --- CPU thresholds ------------------------------------------------------
    # Pause once the CPU reaches `cpu_high`; resume when it has dropped back
    # to `cpu_resume` or lower.
    cpu_high: float = 90.0
    cpu_resume: float = 87.0

    # --- GPU thresholds ------------------------------------------------------
    # Pause once the GPU reaches `gpu_high`; resume when it has dropped back
    # to `gpu_resume` or lower.
    gpu_high: float = 90.0
    gpu_resume: float = 87.0

    # --- Cool-down polling ---------------------------------------------------
    # Interval, in seconds, between temperature re-checks while cooling down.
    poll_seconds: float = 5.0
@dataclass
class Config:
"""Main configuration class."""
......@@ -120,6 +138,7 @@ class Config:
image: ImageConfig = field(default_factory=ImageConfig)
whisper: WhisperConfig = field(default_factory=WhisperConfig)
archive: ArchiveConfig = field(default_factory=ArchiveConfig)
thermal: ThermalConfig = field(default_factory=ThermalConfig)
broker: BrokerConfig = field(default_factory=BrokerConfig)
system_prompt: Optional[str] = None
tools_closer_prompt: bool = False
......@@ -273,6 +292,7 @@ class ConfigManager:
image=ImageConfig(**config_data.get("image", {})),
whisper=WhisperConfig(**config_data.get("whisper", {})),
archive=ArchiveConfig(**config_data.get("archive", {})),
thermal=ThermalConfig(**config_data.get("thermal", {})),
broker=BrokerConfig(**config_data.get("broker", {})),
system_prompt=config_data.get("system_prompt"),
tools_closer_prompt=config_data.get("tools_closer_prompt", False),
......@@ -382,6 +402,15 @@ class ConfigManager:
"directory": self.config.archive.directory,
"retention": self.config.archive.retention,
},
"thermal": {
"cpu_enabled": self.config.thermal.cpu_enabled,
"gpu_enabled": self.config.thermal.gpu_enabled,
"cpu_high": self.config.thermal.cpu_high,
"cpu_resume": self.config.thermal.cpu_resume,
"gpu_high": self.config.thermal.gpu_high,
"gpu_resume": self.config.thermal.gpu_resume,
"poll_seconds": self.config.thermal.poll_seconds,
},
"broker": {
"enabled": self.config.broker.enabled,
"base_url": self.config.broker.base_url,
......
This diff is collapsed.
......@@ -112,6 +112,10 @@ def default_environments_dir() -> Path:
return ensure_dir(legacy_style_config_dir() / "environments")
def default_loras_dir() -> Path:
    """Return the default ``loras`` directory.

    The location is the ``loras`` entry under ``legacy_style_config_dir()``,
    passed through ``ensure_dir`` (presumably creating the directory when it
    is missing — confirm against that helper).
    """
    loras_path = legacy_style_config_dir() / "loras"
    return ensure_dir(loras_path)
def default_whisper_server_path() -> str:
if os.name == "nt":
local = _windows_dir("LOCALAPPDATA", _home_dir() / "AppData" / "Local")
......
......@@ -20,6 +20,14 @@ from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
class VideoLoraConfig(BaseModel):
    """Describes a single LoRA adapter applied to the video diffusion pipeline for one request."""

    # Fields beyond the ones declared below are allowed on this model.
    model_config = ConfigDict(extra="allow")

    # Where the LoRA weights come from: a path or an HF id.
    model: str
    # NOTE(review): presumably the adapter strength/scale — confirm against the
    # pipeline code that consumes it.
    weight: float = 1.0
    # Optional label for the adapter; how it is used is not visible here.
    name: Optional[str] = None
class CharacterDialogLine(BaseModel):
"""One spoken line in a multi-character dialog sequence."""
character: Optional[str] = None # character profile name (used for lip-sync face)
......@@ -78,6 +86,10 @@ class VideoGenerationRequest(BaseModel):
# Named saved profiles to load (resolved server-side)
character_profiles: Optional[List[str]] = None
# Per-request LoRA adapters (e.g. trained per-character identity LoRAs).
# Applied to diffusers video pipelines that support load_lora_weights.
loras: Optional[List[VideoLoraConfig]] = None
# ── Audio generation / manipulation ──────────────────────────────────
add_audio: Optional[bool] = False
audio_type: Optional[str] = None # music | speech | sfx | ambient
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment