Backends, API, and tooling updates; gitignore township_output

- cuda/vulkan backend improvements and config plumbing
- API updates across characters, text, environments, audio, embeddings, tts
- admin chat/settings template updates
- add hf_loading helper, video request fields, platform paths
- new docs (CODERAI_API_DOCUMENTATION.md) and tools (review_outputs, video_dubber)
- ignore generated township_output/
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 7dc60f66
...@@ -26,3 +26,6 @@ test_*.py ...@@ -26,3 +26,6 @@ test_*.py
# Local git worktrees # Local git worktrees
.worktrees/ .worktrees/
# Generated township fighter outputs
township_output/
This diff is collapsed.
This diff is collapsed.
...@@ -173,6 +173,16 @@ if [ "$BACKEND" = "nvidia" ]; then ...@@ -173,6 +173,16 @@ if [ "$BACKEND" = "nvidia" ]; then
echo -e "${YELLOW}Note: audiocraft not installed (audio generation with MusicGen optional)${NC}" echo -e "${YELLOW}Note: audiocraft not installed (audio generation with MusicGen optional)${NC}"
} }
# Optional quantization backends for diffusers image/video pipelines:
# optimum-quanto -> enables 2-bit (int2) per-component quantization
# gguf -> enables loading GGUF-quantized components (Q5_K/Q6_K, etc.)
# bitsandbytes (4-bit/8-bit) comes via requirements-nvidia.txt; these add the
# extra widths that bitsandbytes cannot do.
echo -e "${YELLOW}Installing optional quantization backends (2-bit / GGUF)...${NC}"
pip install optimum-quanto gguf || {
echo -e "${YELLOW}Note: optimum-quanto/gguf not installed (2-bit and GGUF 5/6-bit quantization optional)${NC}"
}
# Install Flash Attention 2 if requested # Install Flash Attention 2 if requested
if [ "$FLASH" = true ]; then if [ "$FLASH" = true ]; then
echo "" echo ""
......
...@@ -14,6 +14,63 @@ ...@@ -14,6 +14,63 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
# Configure the CUDA caching allocator BEFORE torch is imported anywhere.
# expandable_segments lets the allocator return freed pages to the driver even
# from partially-used segments. Without it, a single small live tensor (e.g. a
# tied embedding weight) pins an entire large segment, so torch.cuda.empty_cache()
# cannot release the GBs of already-freed weights around it after a model is
# evicted — VRAM stays occupied and the next model can't load. Honour any value
# the user already set.
import os as _os

# Start from whatever allocator configuration the user exported (empty if unset).
_alloc_conf = _os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments" not in _alloc_conf:
    # The variable is a comma-separated option list. Trim surrounding whitespace
    # and stray commas before appending so the joined value never contains an
    # empty entry (",,"), which the allocator's option parser may reject.
    _alloc_conf = _alloc_conf.strip(", \t")
    _os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
        (_alloc_conf + ",") if _alloc_conf else ""
    ) + "expandable_segments:True"
# Cap CPU threads BEFORE torch / OpenMP / MKL initialise. Loading and 4-bit
# dequantising large models is CPU-heavy; left uncapped, torch/OpenMP grab every
# core and the machine's load average spikes and it becomes sluggish. On boxes
# with >= 8 cores, limit to HALF the cores so model loads never saturate the
# machine. Smaller machines keep the default (don't cripple them). Honour any
# value the user already set.
# Count the CPUs this process is actually allowed to run on. The affinity mask
# reflects taskset / cpuset pinning, whereas os.cpu_count() always reports the
# whole machine; sched_getaffinity is missing on some platforms, so fall back.
try:
    _ncpu = len(_os.sched_getaffinity(0))
except (AttributeError, OSError):
    _ncpu = _os.cpu_count() or 0

if _ncpu >= 8:
    # Half the usable cores, never less than one thread.
    _cap = str(max(1, _ncpu // 2))
    for _var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS",
                 "NUMEXPR_NUM_THREADS", "VECLIB_MAXIMUM_THREADS"):
        # setdefault leaves any value the user already exported untouched.
        _os.environ.setdefault(_var, _cap)
# Silence ONE specific upstream FutureWarning from bitsandbytes' quant kernels:
# bitsandbytes/backends/cuda/ops.py: torch._check_is_size(blocksize)
# bitsandbytes (latest, 0.49.2) still calls the deprecated torch._check_is_size
# on bleeding-edge torch. We don't call it ourselves and can't fix their source,
# so suppress just this message (not warnings in general) to keep logs readable.
import warnings as _warnings

# Narrow "ignore" filters for upstream / diagnostic-only noise we can't fix from
# here. Each entry is (message regex, warning category); only warnings matching
# both are silenced, in this registration order:
#   1. bitsandbytes still calling the deprecated torch._check_is_size.
#   2. huggingface_hub: diffusers/transformers pass the deprecated
#      `local_dir_use_symlinks` kwarg to hf_hub_download (not our code).
#   3. torch.distributed.reduce_op: emitted while the debug leak-scanner walks
#      gc.get_objects(); unavoidable without dropping the scan.
for _noise_message, _noise_category in (
    (r".*_check_is_size will be removed.*", FutureWarning),
    (r".*local_dir_use_symlinks.*", UserWarning),
    (r".*reduce_op.*is deprecated.*", FutureWarning),
):
    _warnings.filterwarnings(
        "ignore",
        message=_noise_message,
        category=_noise_category,
    )
# Don't leave the loop variables behind as package attributes.
del _noise_message, _noise_category
# codai module - AI model parsing utilities # codai module - AI model parsing utilities
from .models.parser import ( from .models.parser import (
ModelParserDispatcher, ModelParserDispatcher,
......
This diff is collapsed.
...@@ -102,6 +102,57 @@ ...@@ -102,6 +102,57 @@
</div> </div>
</div> </div>
<!-- Thermal protection -->
<div class="card mb-0" style="margin-top:1rem">
<div class="card-title">Thermal Protection</div>
<span class="form-hint" style="display:block;margin-bottom:.75rem">
Before serving a request against a loaded model, wait until temperatures are
safe so a long sequence of heavy generations can't overheat the machine and
trip its power-off protection. The wait is non-blocking (other requests keep
being accepted) and takes effect immediately on save. Temperatures in °C.
</span>
<div class="form-row">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer">
<input type="checkbox" id="s-therm-gpu-enabled" onchange="toggleThermalFields()">
<span style="font-size:13px;font-weight:500">Enable GPU temperature protection</span>
</label>
</div>
<div id="therm-gpu-fields" class="form-row" style="display:grid;grid-template-columns:1fr 1fr;gap:1rem">
<div>
<label class="form-label">Pause when GPU reaches (°C)</label>
<input type="number" id="s-therm-gpu-high" class="form-input" min="40" max="120" step="1" placeholder="90">
</div>
<div>
<label class="form-label">Resume when GPU drops to (°C)</label>
<input type="number" id="s-therm-gpu-resume" class="form-input" min="30" max="120" step="1" placeholder="87">
</div>
</div>
<div class="form-row" style="margin-top:.5rem">
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer">
<input type="checkbox" id="s-therm-cpu-enabled" onchange="toggleThermalFields()">
<span style="font-size:13px;font-weight:500">Enable CPU temperature protection</span>
</label>
</div>
<div id="therm-cpu-fields" class="form-row" style="display:grid;grid-template-columns:1fr 1fr;gap:1rem">
<div>
<label class="form-label">Pause when CPU reaches (°C)</label>
<input type="number" id="s-therm-cpu-high" class="form-input" min="40" max="120" step="1" placeholder="90">
</div>
<div>
<label class="form-label">Resume when CPU drops to (°C)</label>
<input type="number" id="s-therm-cpu-resume" class="form-input" min="30" max="120" step="1" placeholder="87">
</div>
</div>
<div class="form-row" style="margin:0">
<label class="form-label">Re-check interval while cooling down (seconds)</label>
<input type="number" id="s-therm-poll" class="form-input" style="max-width:200px" min="1" max="120" step="1" placeholder="5">
<span class="form-hint">How often to re-read temperatures while waiting for cooldown.</span>
</div>
</div>
<div class="card mb-0" style="margin-top:1rem"> <div class="card mb-0" style="margin-top:1rem">
<div class="card-title">AISBF Broker</div> <div class="card-title">AISBF Broker</div>
<div class="form-row"> <div class="form-row">
...@@ -210,6 +261,13 @@ function toggleBrokerFields(){ ...@@ -210,6 +261,13 @@ function toggleBrokerFields(){
} }
} }
// Show each thermal threshold grid only while its "enable" checkbox is ticked;
// otherwise hide it. Called from the checkboxes' onchange handlers and after
// settings are loaded.
function toggleThermalFields(){
  const sections = [
    ['s-therm-gpu-enabled', 'therm-gpu-fields'],
    ['s-therm-cpu-enabled', 'therm-cpu-fields'],
  ];
  for (const [checkboxId, gridId] of sections) {
    const grid = document.getElementById(gridId);
    const checkbox = document.getElementById(checkboxId);
    if (checkbox.checked) {
      grid.style.display = 'grid';
    } else {
      grid.style.display = 'none';
    }
  }
}
function showAlert(type, msg){ function showAlert(type, msg){
const el = document.getElementById('settings-alert'); const el = document.getElementById('settings-alert');
el.className = 'alert alert-' + (type === 'error' ? 'error' : 'info'); el.className = 'alert alert-' + (type === 'error' ? 'error' : 'info');
...@@ -260,6 +318,16 @@ async function loadSettings(){ ...@@ -260,6 +318,16 @@ async function loadSettings(){
document.getElementById('s-broker-reconnect-max').value = broker.reconnect_max_delay_seconds ?? 60; document.getElementById('s-broker-reconnect-max').value = broker.reconnect_max_delay_seconds ?? 60;
document.getElementById('s-broker-ws-ping').value = broker.websocket_ping_interval ?? 20; document.getElementById('s-broker-ws-ping').value = broker.websocket_ping_interval ?? 20;
toggleBrokerFields(); toggleBrokerFields();
// Thermal protection
const therm = d.thermal || {};
document.getElementById('s-therm-gpu-enabled').checked = therm.gpu_enabled !== false;
document.getElementById('s-therm-cpu-enabled').checked = therm.cpu_enabled !== false;
document.getElementById('s-therm-gpu-high').value = therm.gpu_high ?? 90;
document.getElementById('s-therm-gpu-resume').value = therm.gpu_resume ?? 87;
document.getElementById('s-therm-cpu-high').value = therm.cpu_high ?? 90;
document.getElementById('s-therm-cpu-resume').value = therm.cpu_resume ?? 87;
document.getElementById('s-therm-poll').value = therm.poll_seconds ?? 5;
toggleThermalFields();
}catch(e){ showAlert('error','Failed to load settings: '+e.message); } }catch(e){ showAlert('error','Failed to load settings: '+e.message); }
} }
...@@ -286,6 +354,15 @@ async function saveSettings(){ ...@@ -286,6 +354,15 @@ async function saveSettings(){
directory: document.getElementById('s-arc-dir').value.trim(), directory: document.getElementById('s-arc-dir').value.trim(),
retention: document.getElementById('s-arc-retention').value, retention: document.getElementById('s-arc-retention').value,
}, },
thermal:{
gpu_enabled: document.getElementById('s-therm-gpu-enabled').checked,
cpu_enabled: document.getElementById('s-therm-cpu-enabled').checked,
gpu_high: parseFloat(document.getElementById('s-therm-gpu-high').value) || 90,
gpu_resume: parseFloat(document.getElementById('s-therm-gpu-resume').value) || 87,
cpu_high: parseFloat(document.getElementById('s-therm-cpu-high').value) || 90,
cpu_resume: parseFloat(document.getElementById('s-therm-cpu-resume').value) || 87,
poll_seconds: parseFloat(document.getElementById('s-therm-poll').value) || 5,
},
broker:{ broker:{
enabled: document.getElementById('s-broker-enabled').checked, enabled: document.getElementById('s-broker-enabled').checked,
base_url: document.getElementById('s-broker-base-url').value.trim(), base_url: document.getElementById('s-broker-base-url').value.trim(),
...@@ -310,7 +387,7 @@ async function saveSettings(){ ...@@ -310,7 +387,7 @@ async function saveSettings(){
method:'POST', headers:{'Content-Type':'application/json'}, method:'POST', headers:{'Content-Type':'application/json'},
body: JSON.stringify(data) body: JSON.stringify(data)
}); });
if(r.ok) showAlert('info','Settings saved. Archive changes take effect immediately; restart CoderAI for other changes.'); if(r.ok) showAlert('info','Settings saved. Archive and thermal-protection changes take effect immediately; restart CoderAI for other changes.');
else{ const e=await r.json(); showAlert('error', e.detail||'Save failed'); } else{ const e=await r.json(); showAlert('error', e.detail||'Save failed'); }
}catch(e){ showAlert('error','Error: '+e.message); } }catch(e){ showAlert('error','Error: '+e.message); }
} }
......
...@@ -139,6 +139,7 @@ from codai.api.voice_clone import router as voice_clone_router ...@@ -139,6 +139,7 @@ from codai.api.voice_clone import router as voice_clone_router
from codai.api.voice_convert import router as voice_convert_router from codai.api.voice_convert import router as voice_convert_router
from codai.api.faceswap import router as faceswap_router from codai.api.faceswap import router as faceswap_router
from codai.api.characters import router as characters_router from codai.api.characters import router as characters_router
from codai.api.loras import router as loras_router
from codai.api.spatial import router as spatial_router from codai.api.spatial import router as spatial_router
from codai.api.environments import router as environments_router from codai.api.environments import router as environments_router
from codai.admin.routes import router as admin_router from codai.admin.routes import router as admin_router
...@@ -203,6 +204,7 @@ app.include_router(voice_clone_router) ...@@ -203,6 +204,7 @@ app.include_router(voice_clone_router)
app.include_router(voice_convert_router) app.include_router(voice_convert_router)
app.include_router(faceswap_router) app.include_router(faceswap_router)
app.include_router(characters_router) app.include_router(characters_router)
app.include_router(loras_router)
app.include_router(environments_router) app.include_router(environments_router)
app.include_router(spatial_router) app.include_router(spatial_router)
app.include_router(admin_router) app.include_router(admin_router)
......
...@@ -119,11 +119,35 @@ def _load_musicgen(model_name: str, device: str): ...@@ -119,11 +119,35 @@ def _load_musicgen(model_name: str, device: str):
return model return model
def _load_audioldm(model_name: str, device: str): def _load_audioldm(model_name: str, device: str, model_config: dict = None):
import torch import torch
from diffusers import AudioLDM2Pipeline from diffusers import AudioLDM2Pipeline
pipe = AudioLDM2Pipeline.from_pretrained(model_name, torch_dtype=torch.float16) from codai.models.hf_loading import resolve_dtype
pipe = pipe.to(device) dtype = resolve_dtype(model_config, default='f16')
_xtra = {}
# Apply 4-bit/8-bit quantization to the diffusion backbone when configured.
_mc = model_config or {}
if _mc.get('load_in_4bit') or _mc.get('load_in_8bit'):
_bits = 4 if _mc.get('load_in_4bit') else 8
try:
from diffusers.quantizers import PipelineQuantizationConfig
_qk = ({'load_in_4bit': True, 'bnb_4bit_compute_dtype': dtype}
if _mc.get('load_in_4bit') else {'load_in_8bit': True})
_xtra['quantization_config'] = PipelineQuantizationConfig(
quant_backend=f"bitsandbytes_{_bits}bit",
quant_kwargs=_qk,
components_to_quantize=["transformer", "unet"],
)
print(f"AudioLDM quantization: {_bits}-bit (bitsandbytes)")
except Exception as e:
print(f"AudioLDM quantization unavailable: {e}")
pipe = AudioLDM2Pipeline.from_pretrained(model_name, torch_dtype=dtype, **_xtra)
# CPU offload when configured; otherwise place on device (skip for quantized).
_off = _mc.get('offload_strategy')
if _off in ('cpu', 'sequential', 'model', 'disk') and hasattr(pipe, 'enable_model_cpu_offload'):
pipe.enable_model_cpu_offload()
elif 'quantization_config' not in _xtra:
pipe = pipe.to(device)
return pipe return pipe
...@@ -224,7 +248,8 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request ...@@ -224,7 +248,8 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
Compatible models: MusicGen, AudioGen, AudioLDM2, StableAudio. Compatible models: MusicGen, AudioGen, AudioLDM2, StableAudio.
""" """
_aud_progress_loading(request.model or "audio") _aud_progress_loading(request.model or "audio")
model_info = multi_model_manager.request_model(request.model, model_type="audio_gen") model_info = await asyncio.to_thread(
multi_model_manager.request_model, request.model, model_type="audio_gen")
model_name = model_info.get('model_name') model_name = model_info.get('model_name')
if not model_name: if not model_name:
err = model_info.get('error', f"Model '{request.model}' not found") err = model_info.get('error', f"Model '{request.model}' not found")
...@@ -236,13 +261,14 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request ...@@ -236,13 +261,14 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
if pipe is None: if pipe is None:
device = _derive_device() device = _derive_device()
model_type = _detect_audio_gen_type(model_name) model_type = _detect_audio_gen_type(model_name)
_ag_cfg = model_info.get('config') or {}
try: try:
if model_type in ('musicgen', 'audiogen'): if model_type in ('musicgen', 'audiogen'):
pipe = await asyncio.get_event_loop().run_in_executor( pipe = await asyncio.get_event_loop().run_in_executor(
None, _load_musicgen, model_name, device) None, _load_musicgen, model_name, device)
else: else:
pipe = await asyncio.get_event_loop().run_in_executor( pipe = await asyncio.get_event_loop().run_in_executor(
None, _load_audioldm, model_name, device) None, _load_audioldm, model_name, device, _ag_cfg)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load audio gen model: {e}") raise HTTPException(status_code=500, detail=f"Failed to load audio gen model: {e}")
multi_model_manager.models[model_key] = pipe multi_model_manager.models[model_key] = pipe
......
...@@ -37,9 +37,38 @@ import tempfile ...@@ -37,9 +37,38 @@ import tempfile
import time import time
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
def _require_api_auth(request: Request) -> None:
    """Raise 401 if auth is enabled and the request carries no valid credential.

    Used as a FastAPI dependency on the character endpoints. Credentials are
    tried in this order:

    1. An ``Authorization: Bearer <token>`` header (scheme matched
       case-insensitively), checked with ``session_manager.verify_token``.
    2. A ``session`` cookie, checked with ``session_manager.validate_session``.
       A trailing ``.MUST_CHANGE`` marker is removed before validation.

    The request is allowed through without any check when the admin routes
    module cannot be imported or its ``session_manager`` is ``None``.

    Args:
        request: Incoming request whose headers and cookies are inspected.

    Raises:
        HTTPException: Status 401 with a ``message``/``type``/``code`` detail
            dict when neither credential validates.
    """
    bearer_prefix = "bearer "
    must_change_suffix = ".MUST_CHANGE"

    try:
        from codai.admin import routes as _admin_routes
        sm = _admin_routes.session_manager
    except Exception:
        # NOTE(review): this fails open — any import error disables auth for
        # these endpoints. Confirm that is the intended trade-off.
        return  # auth subsystem unavailable — allow through
    if sm is None:
        return  # auth not configured on this instance

    auth = request.headers.get("authorization", "")
    if auth.lower().startswith(bearer_prefix):
        # Slice by the prefix length rather than a hard-coded offset so the
        # two can never drift apart.
        token = auth[len(bearer_prefix):].strip()
        if sm.verify_token(token):
            return

    cookie = request.cookies.get("session", "")
    if cookie.endswith(must_change_suffix):
        # presumably marks a session that still has to change its password;
        # verify against the code that issues the cookie
        cookie = cookie[:-len(must_change_suffix)]
    if cookie and sm.validate_session(cookie):
        return

    raise HTTPException(
        status_code=401,
        detail={"message": "Invalid API key. Provide a valid Bearer token.",
                "type": "invalid_request_error", "code": "invalid_api_key"},
    )
from codai.platform_paths import default_characters_dir, legacy_style_config_dir from codai.platform_paths import default_characters_dir, legacy_style_config_dir
router = APIRouter() router = APIRouter()
...@@ -211,7 +240,12 @@ def _decode_source(data: str) -> bytes: ...@@ -211,7 +240,12 @@ def _decode_source(data: str) -> bytes:
def _detect_faces_cv2(img_bytes: bytes): def _detect_faces_cv2(img_bytes: bytes):
"""Return list of (x,y,w,h) face rects using Haar cascade, or [] if cv2 unavailable.""" """
Return list of (x,y,w,h) face rects, largest first.
Tries MediaPipe (most accurate), then OpenCV DNN, then Haar cascade as fallback.
Detections smaller than 2% of image area are discarded as false positives.
Returns [] if no library is available or no plausible face is found.
"""
try: try:
import cv2 import cv2
import numpy as np import numpy as np
...@@ -219,19 +253,56 @@ def _detect_faces_cv2(img_bytes: bytes): ...@@ -219,19 +253,56 @@ def _detect_faces_cv2(img_bytes: bytes):
img = cv2.imdecode(arr, cv2.IMREAD_COLOR) img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None: if img is None:
return [] return []
ih, iw = img.shape[:2]
img_area = ih * iw
min_face_area = img_area * 0.02 # reject anything < 2% of image
# ── Try MediaPipe first (most accurate, no model download needed) ──
try:
import mediapipe as mp
mp_face = mp.solutions.face_detection
with mp_face.FaceDetection(model_selection=1, min_detection_confidence=0.5) as det:
rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
results = det.process(rgb)
if results.detections:
rects = []
for d in results.detections:
bb = d.location_data.relative_bounding_box
x = int(bb.xmin * iw)
y = int(bb.ymin * ih)
w = int(bb.width * iw)
h = int(bb.height * ih)
if w * h >= min_face_area:
rects.append((x, y, w, h))
if rects:
rects.sort(key=lambda r: r[2]*r[3], reverse=True)
return rects
except ImportError:
pass
# ── Haar cascade fallback (stricter parameters to reduce false positives) ──
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray = cv2.equalizeHist(gray)
cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml' cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
cascade = cv2.CascadeClassifier(cascade_path) cascade = cv2.CascadeClassifier(cascade_path)
faces = cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(40, 40)) # minSize scaled to image: at least 8% of the shorter dimension
min_dim = int(min(iw, ih) * 0.08)
faces = cascade.detectMultiScale(
gray, scaleFactor=1.05, minNeighbors=8,
minSize=(max(40, min_dim), max(40, min_dim)),
)
if len(faces) == 0: if len(faces) == 0:
return [] return []
return [(int(x), int(y), int(w), int(h)) for x, y, w, h in faces] rects = [(int(x), int(y), int(w), int(h)) for x, y, w, h in faces
if int(w) * int(h) >= min_face_area]
rects.sort(key=lambda r: r[2]*r[3], reverse=True)
return rects
except Exception: except Exception:
return [] return []
def _crop_face(img_bytes: bytes, rect) -> Optional[bytes]: def _crop_face(img_bytes: bytes, rect) -> Optional[bytes]:
"""Crop a face rect (with padding) from an image, return PNG bytes.""" """Crop a face rect with generous padding (head-and-shoulders), return PNG bytes."""
try: try:
import cv2 import cv2
import numpy as np import numpy as np
...@@ -241,11 +312,15 @@ def _crop_face(img_bytes: bytes, rect) -> Optional[bytes]: ...@@ -241,11 +312,15 @@ def _crop_face(img_bytes: bytes, rect) -> Optional[bytes]:
if img is None: if img is None:
return None return None
ih, iw = img.shape[:2] ih, iw = img.shape[:2]
pad = int(max(w, h) * 0.4) side = max(w, h)
x1 = max(0, x - pad) # More padding on top to include hair/forehead, less at bottom
y1 = max(0, y - pad) pad_sides = int(side * 0.5)
x2 = min(iw, x + w + pad) pad_top = int(side * 0.7)
y2 = min(ih, y + h + pad) pad_bot = int(side * 0.4)
x1 = max(0, x - pad_sides)
y1 = max(0, y - pad_top)
x2 = min(iw, x + w + pad_sides)
y2 = min(ih, y + h + pad_bot)
crop = img[y1:y2, x1:x2] crop = img[y1:y2, x1:x2]
ok, buf = cv2.imencode('.png', crop) ok, buf = cv2.imencode('.png', crop)
return bytes(buf) if ok else None return bytes(buf) if ok else None
...@@ -274,7 +349,7 @@ def _extract_from_image(img_bytes: bytes) -> List[bytes]: ...@@ -274,7 +349,7 @@ def _extract_from_image(img_bytes: bytes) -> List[bytes]:
crops = [c for f in faces for c in [_crop_face(img_bytes, f)] if c] crops = [c for f in faces for c in [_crop_face(img_bytes, f)] if c]
if crops: if crops:
return crops return crops
# No face detected — use whole image as reference # No face detected (or all detections filtered as false positives) — use whole image
try: try:
from PIL import Image as PILImage from PIL import Image as PILImage
img = PILImage.open(io.BytesIO(img_bytes)).convert('RGB') img = PILImage.open(io.BytesIO(img_bytes)).convert('RGB')
...@@ -345,7 +420,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]: ...@@ -345,7 +420,7 @@ def resolve_character_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ───────────────────────────────────────────────────────────────── # ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/characters") @router.post("/v1/characters")
async def save_character(req: CharacterSaveRequest): async def save_character(req: CharacterSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named character profile.""" """Save or update a named character profile."""
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
raise HTTPException(status_code=400, detail="Invalid character name") raise HTTPException(status_code=400, detail="Invalid character name")
...@@ -356,13 +431,13 @@ async def save_character(req: CharacterSaveRequest): ...@@ -356,13 +431,13 @@ async def save_character(req: CharacterSaveRequest):
@router.get("/v1/characters") @router.get("/v1/characters")
async def list_characters(): async def list_characters(_auth=Depends(_require_api_auth)):
"""List all saved character profiles (metadata only, no images).""" """List all saved character profiles (metadata only, no images)."""
return {"characters": _list_characters()} return {"characters": _list_characters()}
@router.get("/v1/characters/{name}") @router.get("/v1/characters/{name}")
async def get_character(name: str): async def get_character(name: str, _auth=Depends(_require_api_auth)):
"""Get a character profile including its reference images as base64.""" """Get a character profile including its reference images as base64."""
meta = _load_character_meta(name) meta = _load_character_meta(name)
if not meta: if not meta:
...@@ -378,7 +453,7 @@ async def get_character(name: str): ...@@ -378,7 +453,7 @@ async def get_character(name: str):
@router.delete("/v1/characters/{name}") @router.delete("/v1/characters/{name}")
async def delete_character(name: str): async def delete_character(name: str, _auth=Depends(_require_api_auth)):
"""Delete a character profile.""" """Delete a character profile."""
cdir = _char_dir(name) cdir = _char_dir(name)
if not os.path.isdir(cdir): if not os.path.isdir(cdir):
...@@ -389,7 +464,7 @@ async def delete_character(name: str): ...@@ -389,7 +464,7 @@ async def delete_character(name: str):
@router.patch("/v1/characters/{name}") @router.patch("/v1/characters/{name}")
async def patch_character(name: str, req: CharacterPatchRequest): async def patch_character(name: str, req: CharacterPatchRequest, _auth=Depends(_require_api_auth)):
"""Update a character profile: description, add images, or remove images by index.""" """Update a character profile: description, add images, or remove images by index."""
meta = _load_character_meta(name) meta = _load_character_meta(name)
if not meta: if not meta:
...@@ -462,29 +537,24 @@ async def generate_character(req: CharacterGenerateRequest, request: Request): ...@@ -462,29 +537,24 @@ async def generate_character(req: CharacterGenerateRequest, request: Request):
if req.steps: if req.steps:
payload["steps"] = req.steps payload["steps"] = req.steps
# Forward the caller's auth token so rate-limit / auth middleware passes
auth_header = request.headers.get("authorization", "")
headers = {"Content-Type": "application/json"}
if auth_header:
headers["Authorization"] = auth_header
try: try:
from httpx import AsyncClient, ASGITransport import json as _json
async with AsyncClient( from codai.broker.asgi_bridge import execute_internal_request
transport=ASGITransport(app=request.app), resp = await execute_internal_request(
base_url="http://internal", request.app,
timeout=300, method="POST",
) as client: path="/v1/images/generations",
r = await client.post("/v1/images/generations", json=payload, headers=headers) headers={"Content-Type": "application/json"},
body=_json.dumps(payload).encode(),
if not r.is_success: )
if resp["status_code"] >= 400:
try: try:
detail = r.json().get("detail", r.text) detail = _json.loads(resp["body"]).get("detail", resp["body"].decode())
except Exception: except Exception:
detail = r.text detail = resp["body"].decode()
raise HTTPException(status_code=r.status_code, detail=f"Image generation failed: {detail}") raise HTTPException(status_code=resp["status_code"], detail=f"Image generation failed: {detail}")
images_data = r.json().get("data", []) images_data = _json.loads(resp["body"]).get("data", [])
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
......
...@@ -48,10 +48,16 @@ def _derive_device() -> str: ...@@ -48,10 +48,16 @@ def _derive_device() -> str:
return "cuda:0" return "cuda:0"
def _load_embedding_model(model_name: str, device: str): def _load_embedding_model(model_name: str, device: str, model_config: dict = None):
from codai.models.hf_loading import build_from_pretrained_kwargs
try: try:
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, device=device) # sentence-transformers honours quantization via model_kwargs.
fp = build_from_pretrained_kwargs(model_config)
st_kwargs = {}
if 'quantization_config' in fp:
st_kwargs['model_kwargs'] = {'quantization_config': fp['quantization_config']}
model = SentenceTransformer(model_name, device=device, **st_kwargs)
return ('sentence_transformers', model) return ('sentence_transformers', model)
except ImportError: except ImportError:
pass pass
...@@ -59,8 +65,11 @@ def _load_embedding_model(model_name: str, device: str): ...@@ -59,8 +65,11 @@ def _load_embedding_model(model_name: str, device: str):
try: try:
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
import torch import torch
fp = build_from_pretrained_kwargs(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device) model = AutoModel.from_pretrained(model_name, **fp)
if 'quantization_config' not in fp and 'device_map' not in fp:
model = model.to(device)
return ('transformers', (tokenizer, model, device)) return ('transformers', (tokenizer, model, device))
except Exception as e: except Exception as e:
raise RuntimeError(f"Cannot load embedding model '{model_name}': {e}") raise RuntimeError(f"Cannot load embedding model '{model_name}': {e}")
...@@ -97,7 +106,8 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -97,7 +106,8 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
""" """
OpenAI-compatible embeddings endpoint. OpenAI-compatible embeddings endpoint.
""" """
model_info = multi_model_manager.request_model(request.model, model_type="embedding") model_info = await asyncio.to_thread(
multi_model_manager.request_model, request.model, model_type="embedding")
model_name = model_info.get('model_name') model_name = model_info.get('model_name')
if not model_name: if not model_name:
err = model_info.get('error', f"Model '{request.model}' not found") err = model_info.get('error', f"Model '{request.model}' not found")
...@@ -108,9 +118,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request = ...@@ -108,9 +118,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
if model_obj is None: if model_obj is None:
device = _derive_device() device = _derive_device()
_emb_cfg = (multi_model_manager.config.get(f"embedding:{model_name}")
or multi_model_manager.config.get(model_name) or {})
try: try:
model_obj = await asyncio.get_event_loop().run_in_executor( model_obj = await asyncio.get_event_loop().run_in_executor(
None, _load_embedding_model, model_name, device) None, _load_embedding_model, model_name, device, _emb_cfg)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load embedding model: {e}") raise HTTPException(status_code=500, detail=f"Failed to load embedding model: {e}")
multi_model_manager.models[model_key] = model_obj multi_model_manager.models[model_key] = model_obj
......
...@@ -39,7 +39,32 @@ import tempfile ...@@ -39,7 +39,32 @@ import tempfile
import time import time
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
def _require_api_auth(request: Request) -> None:
    """Raise 401 if auth is enabled and the request carries no valid credential.

    Used as a FastAPI dependency on the environment endpoints. Credentials are
    tried in this order:

    1. An ``Authorization: Bearer <token>`` header (scheme matched
       case-insensitively), checked with ``session_manager.verify_token``.
    2. A ``session`` cookie, checked with ``session_manager.validate_session``.
       A trailing ``.MUST_CHANGE`` marker is removed before validation.

    The request is allowed through without any check when the admin routes
    module cannot be imported or its ``session_manager`` is ``None``.

    Args:
        request: Incoming request whose headers and cookies are inspected.

    Raises:
        HTTPException: Status 401 with a ``message``/``type``/``code`` detail
            dict when neither credential validates.
    """
    bearer_prefix = "bearer "
    must_change_suffix = ".MUST_CHANGE"

    try:
        from codai.admin import routes as _admin_routes
        sm = _admin_routes.session_manager
    except Exception:
        # NOTE(review): this fails open — any import error disables auth for
        # these endpoints. Confirm that is the intended trade-off.
        return
    if sm is None:
        # No session manager configured on this instance: nothing to enforce.
        return

    auth = request.headers.get("authorization", "")
    if auth.lower().startswith(bearer_prefix):
        # Slice by the prefix length rather than a hard-coded offset so the
        # two can never drift apart.
        if sm.verify_token(auth[len(bearer_prefix):].strip()):
            return

    cookie = request.cookies.get("session", "")
    if cookie.endswith(must_change_suffix):
        # presumably marks a session that still has to change its password;
        # verify against the code that issues the cookie
        cookie = cookie[:-len(must_change_suffix)]
    if cookie and sm.validate_session(cookie):
        return

    raise HTTPException(
        status_code=401,
        detail={"message": "Invalid API key. Provide a valid Bearer token.",
                "type": "invalid_request_error", "code": "invalid_api_key"},
    )
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from codai.platform_paths import default_environments_dir, legacy_style_config_dir from codai.platform_paths import default_environments_dir, legacy_style_config_dir
...@@ -283,7 +308,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]: ...@@ -283,7 +308,7 @@ def resolve_environment_profiles(profile_names: List[str]) -> List[str]:
# ── Endpoints ───────────────────────────────────────────────────────────────── # ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/environments") @router.post("/v1/environments")
async def save_environment(req: EnvironmentSaveRequest): async def save_environment(req: EnvironmentSaveRequest, _auth=Depends(_require_api_auth)):
"""Save or update a named environment profile.""" """Save or update a named environment profile."""
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
raise HTTPException(status_code=400, detail="Invalid environment name") raise HTTPException(status_code=400, detail="Invalid environment name")
...@@ -294,13 +319,13 @@ async def save_environment(req: EnvironmentSaveRequest): ...@@ -294,13 +319,13 @@ async def save_environment(req: EnvironmentSaveRequest):
@router.get("/v1/environments") @router.get("/v1/environments")
async def list_environments(): async def list_environments(_auth=Depends(_require_api_auth)):
"""List all saved environment profiles (metadata only).""" """List all saved environment profiles (metadata only)."""
return {"environments": _list_environments()} return {"environments": _list_environments()}
@router.get("/v1/environments/{name}") @router.get("/v1/environments/{name}")
async def get_environment(name: str): async def get_environment(name: str, _auth=Depends(_require_api_auth)):
"""Get an environment profile including its reference images as base64.""" """Get an environment profile including its reference images as base64."""
meta = _load_environment_meta(name) meta = _load_environment_meta(name)
if not meta: if not meta:
...@@ -316,7 +341,7 @@ async def get_environment(name: str): ...@@ -316,7 +341,7 @@ async def get_environment(name: str):
@router.delete("/v1/environments/{name}") @router.delete("/v1/environments/{name}")
async def delete_environment(name: str): async def delete_environment(name: str, _auth=Depends(_require_api_auth)):
"""Delete an environment profile.""" """Delete an environment profile."""
edir = _env_dir(name) edir = _env_dir(name)
if not os.path.isdir(edir): if not os.path.isdir(edir):
...@@ -327,7 +352,7 @@ async def delete_environment(name: str): ...@@ -327,7 +352,7 @@ async def delete_environment(name: str):
@router.patch("/v1/environments/{name}") @router.patch("/v1/environments/{name}")
async def patch_environment(name: str, req: EnvironmentPatchRequest): async def patch_environment(name: str, req: EnvironmentPatchRequest, _auth=Depends(_require_api_auth)):
"""Update an environment profile: description, add images, or remove images by index.""" """Update an environment profile: description, add images, or remove images by index."""
meta = _load_environment_meta(name) meta = _load_environment_meta(name)
if not meta: if not meta:
...@@ -398,28 +423,24 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request ...@@ -398,28 +423,24 @@ async def generate_environment(req: EnvironmentGenerateRequest, request: Request
if req.steps: if req.steps:
payload["steps"] = req.steps payload["steps"] = req.steps
auth_header = request.headers.get("authorization", "")
headers = {"Content-Type": "application/json"}
if auth_header:
headers["Authorization"] = auth_header
try: try:
from httpx import AsyncClient, ASGITransport import json as _json
async with AsyncClient( from codai.broker.asgi_bridge import execute_internal_request
transport=ASGITransport(app=request.app), resp = await execute_internal_request(
base_url="http://internal", request.app,
timeout=300, method="POST",
) as client: path="/v1/images/generations",
r = await client.post("/v1/images/generations", json=payload, headers=headers) headers={"Content-Type": "application/json"},
body=_json.dumps(payload).encode(),
if not r.is_success: )
if resp["status_code"] >= 400:
try: try:
detail = r.json().get("detail", r.text) detail = _json.loads(resp["body"]).get("detail", resp["body"].decode())
except Exception: except Exception:
detail = r.text detail = resp["body"].decode()
raise HTTPException(status_code=r.status_code, detail=f"Image generation failed: {detail}") raise HTTPException(status_code=resp["status_code"], detail=f"Image generation failed: {detail}")
images_data = r.json().get("data", []) images_data = _json.loads(resp["body"]).get("data", [])
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
......
...@@ -283,25 +283,69 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -283,25 +283,69 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
# Continue with original implementation for 'auto' parser # Continue with original implementation for 'auto' parser
# Get the model for this request # Get the model for this request
requested_model = request.model requested_model = request.model
# Use the manager to resolve the model and manage VRAM (handles ondemand unloading) # Resolve and load the model, waiting if another model is currently loading.
model_info = multi_model_manager.request_model( # Retries up to ~5 minutes (60 × 5s) so requests queue behind long video loads
requested_model=requested_model, # rather than failing immediately with "No model loaded".
model_type="text" _MAX_WAIT_TRIES = 60
) _model_key = None
# Check if the model was rejected as not allowed
if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error'])
# Acquire the least-busy instance (increments ref-count; released on response completion)
_model_key = model_info.get('model_key')
_instance_idx = None _instance_idx = None
_acq = multi_model_manager.acquire_model_instance(_model_key) if _model_key else None mm = None
if _acq: model_info = {}
_instance_idx, mm = _acq
else: for _attempt in range(_MAX_WAIT_TRIES):
mm = multi_model_manager.get_model_for_request(requested_model) # Fail fast on a corrupted CUDA context — retrying 60× is pointless.
if getattr(multi_model_manager, 'cuda_context_poisoned', False):
raise HTTPException(status_code=503, detail=(
"CUDA context corrupted by an earlier device-side assert "
f"({multi_model_manager.cuda_poison_reason}). Restart coderai to recover."))
# If another model is loading, yield the event loop and wait for it to finish.
if not multi_model_manager._model_ready_event.is_set():
print(f"Text model '{requested_model}': waiting for model load to complete "
f"(attempt {_attempt + 1}/{_MAX_WAIT_TRIES})…")
await asyncio.to_thread(
multi_model_manager._model_ready_event.wait, 30.0
)
await asyncio.sleep(0)
# In a thread: request_model may block waiting for a busy model to go
# idle before evicting it; blocking the event loop here would deadlock.
model_info = await asyncio.to_thread(
multi_model_manager.request_model,
requested_model,
"text",
)
if model_info.get('error'):
# CUDA-poison errors are unrecoverable → 503; others (unknown model) → 404.
_status = 503 if 'CUDA context corrupted' in str(model_info['error']) else 404
raise HTTPException(status_code=_status, detail=model_info['error'])
_model_key = model_info.get('model_key')
_candidate = None
_acq = multi_model_manager.acquire_model_instance(_model_key) if _model_key else None
if _acq:
_instance_idx, _candidate = _acq
# Guard against stale pool entries (model evicted but pool not cleared)
if hasattr(_candidate, 'backend') and _candidate.backend is None:
multi_model_manager.release_model_instance(_model_key, _instance_idx)
_instance_idx = None
_candidate = None
if _candidate is None:
_candidate = multi_model_manager.get_model_for_request(requested_model)
if _candidate is None and model_manager.backend is not None:
_candidate = model_manager
# Validate the candidate has a working backend before accepting it
if _candidate is not None:
if hasattr(_candidate, 'backend') and _candidate.backend is None:
_candidate = None
if _candidate is not None:
mm = _candidate
break
print(f"Text model '{requested_model}' not ready, retrying in 5s "
f"(attempt {_attempt + 1}/{_MAX_WAIT_TRIES})…")
await asyncio.sleep(5)
def _release_instance(): def _release_instance():
if _instance_idx is not None and _model_key: if _instance_idx is not None and _model_key:
...@@ -309,12 +353,10 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -309,12 +353,10 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
if mm is None: if mm is None:
_release_instance() _release_instance()
if model_manager.backend is not None: raise HTTPException(status_code=503,
current_manager = model_manager detail=f"Model '{requested_model}' could not be loaded after waiting. "
else: "Another model may be using all available VRAM.")
raise HTTPException(status_code=503, detail="Model not loaded") current_manager = mm
else:
current_manager = mm
# Inject system prompt if --system-prompt flag was provided # Inject system prompt if --system-prompt flag was provided
messages = request.messages messages = request.messages
...@@ -1161,6 +1203,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -1161,6 +1203,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
tool_parser, tool_parser,
request.response_format, request.response_format,
_prefix_key, _prefix_key,
enable_thinking=reasoning_enabled,
): ):
yield chunk yield chunk
finally: finally:
...@@ -1182,6 +1225,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -1182,6 +1225,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
tool_parser, tool_parser,
request.response_format, request.response_format,
force_reasoning_args, force_reasoning_args,
enable_thinking=reasoning_enabled,
) )
finally: finally:
_release_instance() _release_instance()
...@@ -1198,6 +1242,7 @@ async def stream_chat_response( ...@@ -1198,6 +1242,7 @@ async def stream_chat_response(
tool_parser: ToolCallParser, tool_parser: ToolCallParser,
response_format: Optional[Dict] = None, response_format: Optional[Dict] = None,
prefix_key: str = "", prefix_key: str = "",
enable_thinking: bool = False,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Stream chat completion response with queue notifications.""" """Stream chat completion response with queue notifications."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
...@@ -1327,6 +1372,7 @@ async def stream_chat_response( ...@@ -1327,6 +1372,7 @@ async def stream_chat_response(
stop=stop, stop=stop,
tools=tools, tools=tools,
response_format=response_format, response_format=response_format,
enable_thinking=enable_thinking,
): ):
chunk_count += 1 chunk_count += 1
# Always filter malformed content (regex-based, works per-chunk) # Always filter malformed content (regex-based, works per-chunk)
...@@ -1547,6 +1593,7 @@ async def generate_chat_response( ...@@ -1547,6 +1593,7 @@ async def generate_chat_response(
tool_parser: ToolCallParser, tool_parser: ToolCallParser,
response_format: Optional[Dict] = None, response_format: Optional[Dict] = None,
force_reasoning_args: Optional[List[str]] = None, force_reasoning_args: Optional[List[str]] = None,
enable_thinking: bool = False,
) -> Dict: ) -> Dict:
"""Generate non-streaming chat completion response.""" """Generate non-streaming chat completion response."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
...@@ -1583,6 +1630,7 @@ async def generate_chat_response( ...@@ -1583,6 +1630,7 @@ async def generate_chat_response(
stop=stop, stop=stop,
tools=tools, tools=tools,
response_format=response_format, response_format=response_format,
enable_thinking=enable_thinking,
) )
# Always filter out malformed content # Always filter out malformed content
...@@ -1748,9 +1796,12 @@ async def completions(request: CompletionRequest): ...@@ -1748,9 +1796,12 @@ async def completions(request: CompletionRequest):
requested_model = request.model requested_model = request.model
# Use the manager to resolve the model and manage VRAM (handles ondemand unloading) # Use the manager to resolve the model and manage VRAM (handles ondemand unloading)
model_info = multi_model_manager.request_model( # In a thread: request_model may block (thermal cooldown / waiting for a busy
# model) and we must not stall the event loop.
model_info = await asyncio.to_thread(
multi_model_manager.request_model,
requested_model=requested_model, requested_model=requested_model,
model_type="text" model_type="text",
) )
# Check if the model was rejected as not allowed # Check if the model was rejected as not allowed
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
Audio transcription endpoint for the codai API. Audio transcription endpoint for the codai API.
""" """
import asyncio
import io import io
import os import os
import tempfile import tempfile
...@@ -143,7 +144,8 @@ async def create_transcription( ...@@ -143,7 +144,8 @@ async def create_transcription(
else multi_model_manager.whisper_servers.get(model) else multi_model_manager.whisper_servers.get(model)
) )
if whisper_server is not None: if whisper_server is not None:
multi_model_manager.request_model(requested_model=model, model_type="audio") await asyncio.to_thread(
multi_model_manager.request_model, requested_model=model, model_type="audio")
if not whisper_server.is_running(): if not whisper_server.is_running():
whisper_server.start( whisper_server.start(
getattr(whisper_server, "_model_path", None), getattr(whisper_server, "_model_path", None),
...@@ -166,7 +168,8 @@ async def create_transcription( ...@@ -166,7 +168,8 @@ async def create_transcription(
return _format_response(response_format, result.get("text", ""), []) return _format_response(response_format, result.get("text", ""), [])
# Use the manager to resolve the model and manage VRAM # Use the manager to resolve the model and manage VRAM
model_info = multi_model_manager.request_model( model_info = await asyncio.to_thread(
multi_model_manager.request_model,
requested_model=model, requested_model=model,
model_type="audio" model_type="audio"
) )
......
...@@ -99,9 +99,10 @@ async def create_speech(request: TTSRequest, http_request: Request = None): ...@@ -99,9 +99,10 @@ async def create_speech(request: TTSRequest, http_request: Request = None):
return {"audio": audio_base64} return {"audio": audio_base64}
# Use the manager to resolve the model and manage VRAM # Use the manager to resolve the model and manage VRAM
model_info = multi_model_manager.request_model( model_info = await asyncio.to_thread(
multi_model_manager.request_model,
requested_model=request.model, requested_model=request.model,
model_type="tts" model_type="tts",
) )
# Check if the model was rejected as not allowed # Check if the model was rejected as not allowed
......
This diff is collapsed.
...@@ -38,6 +38,34 @@ try: ...@@ -38,6 +38,34 @@ try:
except (ImportError, AttributeError): except (ImportError, AttributeError):
_grammar_guided_gen = False _grammar_guided_gen = False
def _make_llama_thermal_criteria():
    """Build a llama.cpp ``StoppingCriteriaList`` that pauses generation while hot.

    llama-cpp-python evaluates stopping criteria synchronously per token inside
    create_(chat_)completion, so blocking inside the criterion pauses the GPU
    forward pass — mid-generation thermal protection for the
    GGUF/Vulkan/llama.cpp backend.

    The single criterion calls ``checkpoint(context="text-gen",
    throttle_seconds=2.0)`` and always returns ``False``, so it never stops
    generation by itself. The thermal helper is resolved once here rather than
    inside the per-token callback, so the import is not re-executed for every
    generated token.

    Returns:
        A ``StoppingCriteriaList`` holding the pause criterion, or ``None`` when
        llama-cpp-python or the thermal helper cannot be imported, or the list
        cannot be constructed. Callers pass the value straight through as
        ``stopping_criteria`` or skip it when it is ``None``.
    """
    try:
        from llama_cpp import StoppingCriteriaList
        from codai.models.thermal import checkpoint
    except Exception:
        # No llama.cpp bindings or no thermal module: thermal pausing is
        # simply unavailable, which callers treat as "no extra criteria".
        return None

    def _pause(input_ids, logits):
        try:
            checkpoint(context="text-gen", throttle_seconds=2.0)
        except Exception:
            # Best effort: a failing temperature check must never abort or
            # corrupt an in-flight generation.
            pass
        return False

    try:
        return StoppingCriteriaList([_pause])
    except Exception:
        return None
try: try:
from llama_cpp import Llama from llama_cpp import Llama
from llama_cpp.llama_chat_format import ChatFormatterResponse from llama_cpp.llama_chat_format import ChatFormatterResponse
...@@ -699,6 +727,7 @@ class VulkanBackend(ModelBackend): ...@@ -699,6 +727,7 @@ class VulkanBackend(ModelBackend):
try: try:
result = self.model.create_completion( result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
...@@ -717,6 +746,7 @@ class VulkanBackend(ModelBackend): ...@@ -717,6 +746,7 @@ class VulkanBackend(ModelBackend):
print(f"Warning: Grammar-guided generation failed: {e}, falling back to normal generation") print(f"Warning: Grammar-guided generation failed: {e}, falling back to normal generation")
try: try:
result = self.model.create_completion( result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
...@@ -803,6 +833,7 @@ class VulkanBackend(ModelBackend): ...@@ -803,6 +833,7 @@ class VulkanBackend(ModelBackend):
prompt_len = len(prompt) if isinstance(prompt, str) else 0 prompt_len = len(prompt) if isinstance(prompt, str) else 0
for chunk in self.model.create_completion( for chunk in self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
...@@ -842,6 +873,7 @@ class VulkanBackend(ModelBackend): ...@@ -842,6 +873,7 @@ class VulkanBackend(ModelBackend):
prompt_len = len(prompt) if isinstance(prompt, str) else 0 prompt_len = len(prompt) if isinstance(prompt, str) else 0
for chunk in self.model.create_completion( for chunk in self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
...@@ -911,6 +943,7 @@ class VulkanBackend(ModelBackend): ...@@ -911,6 +943,7 @@ class VulkanBackend(ModelBackend):
prompt_len = len(prompt) prompt_len = len(prompt)
for chunk in self.model.create_completion( for chunk in self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
...@@ -937,6 +970,7 @@ class VulkanBackend(ModelBackend): ...@@ -937,6 +970,7 @@ class VulkanBackend(ModelBackend):
return {"stream": generate_stream(), "content": ""} return {"stream": generate_stream(), "content": ""}
else: else:
result = self.model.create_completion( result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
...@@ -1052,6 +1086,9 @@ class VulkanBackend(ModelBackend): ...@@ -1052,6 +1086,9 @@ class VulkanBackend(ModelBackend):
kwargs['stop'] = stop kwargs['stop'] = stop
if response_format and response_format.get('type') == 'json_object': if response_format and response_format.get('type') == 'json_object':
kwargs['response_format'] = {'type': 'json_object'} kwargs['response_format'] = {'type': 'json_object'}
_tc = _make_llama_thermal_criteria()
if _tc is not None:
kwargs['stopping_criteria'] = _tc
result = self.model.create_chat_completion(**kwargs) result = self.model.create_chat_completion(**kwargs)
usage = result.get('usage', {}) usage = result.get('usage', {})
...@@ -1077,6 +1114,9 @@ class VulkanBackend(ModelBackend): ...@@ -1077,6 +1114,9 @@ class VulkanBackend(ModelBackend):
) )
if stop: if stop:
kwargs['stop'] = stop kwargs['stop'] = stop
_tc = _make_llama_thermal_criteria()
if _tc is not None:
kwargs['stopping_criteria'] = _tc
prompt_tokens = 0 prompt_tokens = 0
completion_tokens = 0 completion_tokens = 0
......
...@@ -108,6 +108,24 @@ class ArchiveConfig: ...@@ -108,6 +108,24 @@ class ArchiveConfig:
retention: str = "never" # one of: 1h 1d 2d 1w 1m 3m 6m 1y never retention: str = "never" # one of: 1h 1d 2d 1w 1m 3m 6m 1y never
@dataclass
class ThermalConfig:
    """Thermal-protection configuration.

    Before running a request against a loaded model, wait until CPU/GPU
    temperatures are within safe limits so a long sequence of heavy
    generations can't overheat the machine and trip its power-off protection.
    Thresholds are in degrees Celsius. CPU and GPU can be toggled separately.

    Values are normalised after construction so an inconsistent config cannot
    silently defeat the protection:

    * a ``*_resume`` threshold above its ``*_high`` threshold is lowered to
      ``*_high`` (resume must never sit above the pause point);
    * a non-positive ``poll_seconds`` falls back to the default interval.
    """
    cpu_enabled: bool = True
    gpu_enabled: bool = True
    cpu_high: float = 90.0    # pause when CPU reaches this temperature
    cpu_resume: float = 87.0  # resume once CPU drops back to/below this
    gpu_high: float = 90.0    # pause when GPU reaches this temperature
    gpu_resume: float = 87.0  # resume once GPU drops back to/below this
    poll_seconds: float = 5.0  # how often to re-check while cooling down

    def __post_init__(self) -> None:
        # Keep the hysteresis band well-formed: resume <= high.
        self.cpu_resume = min(self.cpu_resume, self.cpu_high)
        self.gpu_resume = min(self.gpu_resume, self.gpu_high)
        # A zero/negative interval would turn the cooldown wait into a busy
        # loop (or an invalid sleep); the class attribute still holds the
        # declared field default, so reuse it instead of repeating the literal.
        if self.poll_seconds <= 0:
            self.poll_seconds = type(self).poll_seconds
@dataclass @dataclass
class Config: class Config:
"""Main configuration class.""" """Main configuration class."""
...@@ -120,6 +138,7 @@ class Config: ...@@ -120,6 +138,7 @@ class Config:
image: ImageConfig = field(default_factory=ImageConfig) image: ImageConfig = field(default_factory=ImageConfig)
whisper: WhisperConfig = field(default_factory=WhisperConfig) whisper: WhisperConfig = field(default_factory=WhisperConfig)
archive: ArchiveConfig = field(default_factory=ArchiveConfig) archive: ArchiveConfig = field(default_factory=ArchiveConfig)
thermal: ThermalConfig = field(default_factory=ThermalConfig)
broker: BrokerConfig = field(default_factory=BrokerConfig) broker: BrokerConfig = field(default_factory=BrokerConfig)
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
tools_closer_prompt: bool = False tools_closer_prompt: bool = False
...@@ -273,6 +292,7 @@ class ConfigManager: ...@@ -273,6 +292,7 @@ class ConfigManager:
image=ImageConfig(**config_data.get("image", {})), image=ImageConfig(**config_data.get("image", {})),
whisper=WhisperConfig(**config_data.get("whisper", {})), whisper=WhisperConfig(**config_data.get("whisper", {})),
archive=ArchiveConfig(**config_data.get("archive", {})), archive=ArchiveConfig(**config_data.get("archive", {})),
thermal=ThermalConfig(**config_data.get("thermal", {})),
broker=BrokerConfig(**config_data.get("broker", {})), broker=BrokerConfig(**config_data.get("broker", {})),
system_prompt=config_data.get("system_prompt"), system_prompt=config_data.get("system_prompt"),
tools_closer_prompt=config_data.get("tools_closer_prompt", False), tools_closer_prompt=config_data.get("tools_closer_prompt", False),
...@@ -382,6 +402,15 @@ class ConfigManager: ...@@ -382,6 +402,15 @@ class ConfigManager:
"directory": self.config.archive.directory, "directory": self.config.archive.directory,
"retention": self.config.archive.retention, "retention": self.config.archive.retention,
}, },
"thermal": {
"cpu_enabled": self.config.thermal.cpu_enabled,
"gpu_enabled": self.config.thermal.gpu_enabled,
"cpu_high": self.config.thermal.cpu_high,
"cpu_resume": self.config.thermal.cpu_resume,
"gpu_high": self.config.thermal.gpu_high,
"gpu_resume": self.config.thermal.gpu_resume,
"poll_seconds": self.config.thermal.poll_seconds,
},
"broker": { "broker": {
"enabled": self.config.broker.enabled, "enabled": self.config.broker.enabled,
"base_url": self.config.broker.base_url, "base_url": self.config.broker.base_url,
......
This diff is collapsed.
...@@ -112,6 +112,10 @@ def default_environments_dir() -> Path: ...@@ -112,6 +112,10 @@ def default_environments_dir() -> Path:
return ensure_dir(legacy_style_config_dir() / "environments") return ensure_dir(legacy_style_config_dir() / "environments")
def default_loras_dir() -> Path:
    """Return the ``loras`` directory located under the legacy-style config dir.

    The joined path is handed to ``ensure_dir`` and that call's result is
    returned.
    """
    loras_path = legacy_style_config_dir() / "loras"
    return ensure_dir(loras_path)
def default_whisper_server_path() -> str: def default_whisper_server_path() -> str:
if os.name == "nt": if os.name == "nt":
local = _windows_dir("LOCALAPPDATA", _home_dir() / "AppData" / "Local") local = _windows_dir("LOCALAPPDATA", _home_dir() / "AppData" / "Local")
......
...@@ -20,6 +20,14 @@ from typing import Dict, List, Optional ...@@ -20,6 +20,14 @@ from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
class VideoLoraConfig(BaseModel):
    """Describes one LoRA adapter applied to the video diffusion pipeline for a single request."""
    # Keys not declared below are accepted instead of being rejected.
    model_config = ConfigDict(extra="allow")

    # Location of the LoRA weights: a filesystem path or a Hugging Face id.
    model: str
    # Defaults to 1.0 — presumably the adapter strength; confirm against the
    # pipeline code that consumes it.
    weight: float = 1.0
    # Optional label for the adapter; its exact use is not visible here.
    name: Optional[str] = None
class CharacterDialogLine(BaseModel): class CharacterDialogLine(BaseModel):
"""One spoken line in a multi-character dialog sequence.""" """One spoken line in a multi-character dialog sequence."""
character: Optional[str] = None # character profile name (used for lip-sync face) character: Optional[str] = None # character profile name (used for lip-sync face)
...@@ -78,6 +86,10 @@ class VideoGenerationRequest(BaseModel): ...@@ -78,6 +86,10 @@ class VideoGenerationRequest(BaseModel):
# Named saved profiles to load (resolved server-side) # Named saved profiles to load (resolved server-side)
character_profiles: Optional[List[str]] = None character_profiles: Optional[List[str]] = None
# Per-request LoRA adapters (e.g. trained per-character identity LoRAs).
# Applied to diffusers video pipelines that support load_lora_weights.
loras: Optional[List[VideoLoraConfig]] = None
# ── Audio generation / manipulation ────────────────────────────────── # ── Audio generation / manipulation ──────────────────────────────────
add_audio: Optional[bool] = False add_audio: Optional[bool] = False
audio_type: Optional[str] = None # music | speech | sfx | ambient audio_type: Optional[str] = None # music | speech | sfx | ambient
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment