wip: snapshot in-progress platform updates

parent 8fd1c5c2
This diff is collapsed.
......@@ -522,7 +522,14 @@ elif [ "$BACKEND" = "all" ]; then
pip install setproctitle || echo -e "${YELLOW}Warning: setproctitle failed (optional)${NC}"
# Try stable-diffusion-cpp-python (disable WebM to avoid missing libwebm cmake submodule)
# Use CUDA if available (detected later in this block, check nvcc now)
if command -v nvcc &> /dev/null || [ -d "/usr/local/cuda" ]; then
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_CUDA=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
CMAKE_ARGS="$SD_CMAKE_ARGS" pip install stable-diffusion-cpp-python || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python failed (optional)${NC}"
else
CMAKE_ARGS="$SD_CMAKE_ARGS" pip install stable-diffusion-cpp-python || echo -e "${YELLOW}Warning: stable-diffusion-cpp-python failed (optional)${NC}"
fi
}
# Install PyTorch with CUDA support (for nvidia backend)
......@@ -622,14 +629,28 @@ elif [ "$BACKEND" = "all" ]; then
echo -e "${YELLOW}Warning: Some Vulkan packages failed to install${NC}"
}
# Try to install stable-diffusion-cpp-python with OpenCL
if [ "$OPENCL_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with OpenCL support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS" pip install stable-diffusion-cpp-python || {
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available (requires CMake and build tools)${NC}"
# Try to install stable-diffusion-cpp-python with CUDA+Vulkan (preferred) or fallbacks
if [ "$CUDA_AVAILABLE" = true ] && [ "$VULKAN_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with CUDA+Vulkan support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_CUDA=ON -DSD_VULKAN=ON" pip install stable-diffusion-cpp-python --no-cache-dir || {
echo -e "${YELLOW}CUDA+Vulkan build failed, trying CUDA only...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_CUDA=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available${NC}"
}
elif [ "$CUDA_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with CUDA support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_CUDA=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available${NC}"
elif [ "$VULKAN_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with Vulkan support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_VULKAN=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available${NC}"
elif [ "$OPENCL_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with OpenCL support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_OPENCL=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available${NC}"
else
echo -e "${YELLOW}Skipping OpenCL (stable-diffusion-cpp-python) - OpenCL not available${NC}"
echo -e "${YELLOW}Skipping GPU-accelerated stable-diffusion-cpp-python - no GPU backend available${NC}"
fi
# Install additional requirements
......@@ -667,8 +688,11 @@ elif [ "$BACKEND" = "all" ]; then
echo "Available backends:"
[ "$CUDA_AVAILABLE" = true ] && echo " ✓ NVIDIA/CUDA (PyTorch)"
[ "$CUDA_AVAILABLE" = true ] && echo " ✓ CUDA (llama-cpp-python)"
[ "$CUDA_AVAILABLE" = true ] && [ "$VULKAN_AVAILABLE" = true ] && echo " ✓ CUDA+Vulkan (stable-diffusion-cpp-python)"
[ "$CUDA_AVAILABLE" = true ] && [ "$VULKAN_AVAILABLE" != true ] && echo " ✓ CUDA (stable-diffusion-cpp-python)"
[ "$CUDA_AVAILABLE" != true ] && [ "$VULKAN_AVAILABLE" = true ] && echo " ✓ Vulkan (stable-diffusion-cpp-python)"
[ "$VULKAN_AVAILABLE" = true ] && echo " ✓ Vulkan (llama-cpp-python)"
[ "$OPENCL_AVAILABLE" = true ] && echo " ✓ OpenCL (stable-diffusion-cpp-python)"
[ "$OPENCL_AVAILABLE" = true ] && [ "$CUDA_AVAILABLE" != true ] && [ "$VULKAN_AVAILABLE" != true ] && echo " ✓ OpenCL (stable-diffusion-cpp-python)"
echo " ✓ CPU (fallback for all)"
if [ "$FLASH" = true ] && [ "$CUDA_AVAILABLE" = true ]; then
echo ""
......
......@@ -15,10 +15,13 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Authentication and session management for admin dashboard."""
import base64
import hashlib
import hmac
import json
import os
import secrets
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional
......@@ -43,35 +46,62 @@ def get_or_create_secret(config_dir: Path) -> bytes:
def hash_password(password: str) -> str:
    """Hash a password using argon2 (preferred) or scrypt as fallback.

    New hashes are always produced with a proper key-derivation function and
    a per-password random salt. The legacy SHA-256/static-salt format is
    only retained for *verification* of pre-existing hashes; it is never
    produced here.

    Args:
        password: The clear-text password to hash.

    Returns:
        The encoded argon2 hash when the ``argon2`` package is importable,
        otherwise a string of the form ``"scrypt:<b64salt>:<b64key>"``.
    """
    try:
        from argon2 import PasswordHasher
    except ImportError:
        pass  # argon2 is optional; fall back to the stdlib scrypt KDF below
    else:
        return PasswordHasher().hash(password)

    # scrypt fallback: encode as "scrypt:<b64salt>:<b64key>"
    salt = os.urandom(16)
    key = hashlib.scrypt(password.encode(), salt=salt, n=2**14, r=8, p=1)
    return "scrypt:" + base64.b64encode(salt).decode() + ":" + base64.b64encode(key).decode()
def verify_password(password: str, password_hash: str) -> bool:
    """Verify a password against its hash.

    Supports argon2, scrypt (``"scrypt:<b64salt>:<b64key>"``), and the legacy
    SHA-256/static-salt format so that old stored hashes continue to work.

    Args:
        password: The clear-text password supplied by the user.
        password_hash: The stored hash in any of the supported formats.

    Returns:
        True when the password matches the stored hash, False otherwise
        (including when the stored hash is malformed).
    """
    # --- argon2 ---
    try:
        from argon2 import PasswordHasher
        from argon2.exceptions import VerifyMismatchError
    except ImportError:
        pass  # argon2 not installed; try the other formats
    else:
        try:
            return PasswordHasher().verify(password_hash, password)
        except VerifyMismatchError:
            return False
        except Exception:
            # Not an argon2 hash (or otherwise unverifiable by argon2):
            # fall through to the remaining formats.
            pass

    # --- scrypt ---
    if password_hash.startswith("scrypt:"):
        parts = password_hash.split(":")
        if len(parts) != 3:
            return False
        try:
            salt = base64.b64decode(parts[1])
            stored_key = base64.b64decode(parts[2])
            new_key = hashlib.scrypt(password.encode(), salt=salt, n=2**14, r=8, p=1)
        except (ValueError, TypeError):
            # Malformed base64 or parameters rejected by scrypt.
            return False
        return hmac.compare_digest(new_key, stored_key)

    # --- legacy SHA-256 with static salt (read-only; never written for new passwords) ---
    legacy = hashlib.sha256(b'static_salt_' + password.encode()).hexdigest()
    # Compare as bytes: compare_digest() on str raises TypeError when either
    # argument contains non-ASCII characters.
    return hmac.compare_digest(legacy.encode(), password_hash.encode())
class SessionManager:
......@@ -81,7 +111,7 @@ class SessionManager:
self.config_dir = config_dir
self.secret = get_or_create_secret(config_dir)
self.session_timeout = timedelta(minutes=session_timeout_minutes)
self._lock = __import__('threading').Lock()
self._lock = threading.Lock()
def _load_auth_data(self) -> Dict[str, Any]:
"""Load auth.json data."""
......
This diff is collapsed.
......@@ -8,8 +8,8 @@
--border: #1A1D28;
--border-2: #252836;
--text: #DDE1F0;
--text-2: #636880;
--text-3: #2E3145;
--text-2: #8B90A8;
--text-3: #555A72;
--accent: #6366F1;
--accent-s: rgba(99,102,241,.12);
--green: #34D399;
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -28,15 +28,17 @@
<div class="stat-value" id="req-total">0</div>
<div class="stat-sub"><span id="req-active">0</span> active</div>
</div>
<div class="stat">
<div class="stat" id="vram-card" style="display:none">
<div class="stat-label">VRAM</div>
<div class="stat-value" id="vram-pct"></div>
<div class="stat-value" id="vram-pct" style="font-size:2rem"></div>
<div class="progress" style="margin-top:.625rem">
<div class="progress-fill" id="vram-bar" style="width:0%"></div>
</div>
<div class="progress-labels">
<span id="vram-used"></span><span id="vram-total"></span>
<div class="progress-labels" style="color:var(--text-1);font-size:12px;margin-top:.4rem">
<span id="vram-used"></span><span id="vram-free"></span>
</div>
<div style="font-size:11.5px;color:var(--text-2);margin-top:.2rem;font-family:var(--mono)" id="vram-total-line"></div>
<div class="stat-sub" id="vram-gpu" style="margin-top:.25rem"></div>
</div>
</div>
......@@ -85,13 +87,25 @@ async function poll() {
document.getElementById('active-models').innerHTML = html || '<span class="muted small">No models loaded</span>';
if (d.vram) {
const pct = Math.round(d.vram.used / d.vram.total * 100);
document.getElementById('vram-pct').textContent = pct + '%';
document.getElementById('vram-bar').style.width = pct + '%';
document.getElementById('vram-used').textContent = d.vram.used.toFixed(1) + ' GB';
document.getElementById('vram-total').textContent = d.vram.total.toFixed(1) + ' GB';
document.getElementById('vram-card').style.display = '';
if (d.vram.free != null && d.vram.total) {
const usedPct = Math.round(d.vram.used / d.vram.total * 100);
document.getElementById('vram-pct').textContent = usedPct + '%';
document.getElementById('vram-bar').style.width = usedPct + '%';
document.getElementById('vram-used').textContent = d.vram.used.toFixed(1) + ' GB used';
document.getElementById('vram-free').textContent = d.vram.free.toFixed(1) + ' GB free';
document.getElementById('vram-total-line').textContent = d.vram.total.toFixed(1) + ' GB total';
} else {
document.getElementById('vram-pct').textContent = d.vram.total ? d.vram.total.toFixed(1) + ' GB' : '—';
document.getElementById('vram-bar').style.width = '0%';
document.getElementById('vram-used').textContent = '';
document.getElementById('vram-free').textContent = '';
document.getElementById('vram-total-line').textContent = '';
}
const gpuName = d.vram.gpu || '';
document.getElementById('vram-gpu').textContent = gpuName.length > 32 ? gpuName.slice(0, 32) + '…' : gpuName;
} else {
document.getElementById('vram-pct').textContent = 'N/A';
document.getElementById('vram-card').style.display = 'none';
}
if (d.requests) {
......
This diff is collapsed.
......@@ -45,6 +45,11 @@
<input type="text" id="s-cert" class="form-input" placeholder="/path/to/cert.pem">
</div>
</div>
<div class="form-row" style="margin-top:1rem;margin-bottom:0">
<label class="form-label">Request queue max size</label>
<input type="number" id="s-queue-max" class="form-input" placeholder="6" min="1" max="1000" style="max-width:160px">
<span class="form-hint">Maximum number of concurrent queued requests. Authenticated requests arriving when the queue is full receive a 429 response.</span>
</div>
</div>
<!-- Storage -->
......@@ -64,6 +69,48 @@
<span class="form-hint">Models will inherit this as default when configured</span>
</div>
</div>
<!-- Whisper Server -->
<div class="card mb-0" style="margin-top:1rem">
<div style="display:flex;align-items:center;justify-content:space-between;flex-wrap:wrap;gap:.5rem;margin-bottom:1rem">
<div class="card-title" style="margin:0">Whisper Server <span class="muted" style="font-size:11px;font-weight:400">(whisper.cpp native binary — recommended for AMD/Vulkan)</span></div>
<div style="display:flex;align-items:center;gap:.5rem">
<span id="ws-badge" class="muted small"></span>
<button class="btn btn-sm btn-secondary" onclick="wsStart()">Start</button>
<button class="btn btn-sm btn-danger" onclick="wsStop()">Stop</button>
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 160px;gap:1rem;align-items:start">
<div class="form-row" style="margin:0">
<label class="form-label">Model ID <span class="muted">(used in API calls, e.g. whisper-base)</span></label>
<input type="text" id="ws-id" class="form-input" placeholder="whisper-server">
<span class="form-hint">The name clients use in the <code>model</code> field of transcription requests</span>
</div>
<div class="form-row" style="margin:0">
<label class="form-label">Port</label>
<input type="number" id="ws-port" class="form-input" placeholder="8744" min="1024" max="65535">
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 160px;gap:1rem;align-items:start;margin-top:1rem">
<div class="form-row" style="margin:0">
<label class="form-label">whisper-server binary path</label>
<input type="text" id="ws-path" class="form-input" placeholder="/usr/local/bin/whisper-server">
</div>
<div class="form-row" style="margin:0">
<label class="form-label">GPU device index</label>
<input type="number" id="ws-gpu" class="form-input" placeholder="0" min="0">
</div>
</div>
<div class="form-row" style="margin-top:1rem;margin-bottom:0">
<label class="form-label">Model path <span class="muted">(GGUF whisper model, e.g. ggml-base.bin)</span></label>
<input type="text" id="ws-model" class="form-input" placeholder="/path/to/ggml-base.bin">
<span class="form-hint">Configure multiple instances by adding entries to <code>models.json</code> with <code>"backend": "whisper-server"</code></span>
</div>
<p class="form-hint" style="margin-top:.75rem;margin-bottom:0">
When configured, the transcription endpoint uses this subprocess instead of the Python faster-whisper module.
Saves settings to <code>config.json</code> and takes effect immediately (no restart needed).
</p>
</div>
{% endblock %}
{% block scripts %}
......@@ -89,13 +136,69 @@ async function loadSettings(){
document.getElementById('s-https').checked = !!d.server?.https;
document.getElementById('s-key').value = d.server?.https_key_path ?? '';
document.getElementById('s-cert').value = d.server?.https_cert_path ?? '';
document.getElementById('s-queue-max').value = d.server?.queue_max_size ?? 6;
document.getElementById('s-hf-cache').value = d.models?.hf_cache_dir ?? '';
document.getElementById('s-gguf-cache').value = d.models?.gguf_cache_dir ?? '';
document.getElementById('s-offload-dir').value = d.offload?.directory ?? './offload';
document.getElementById('ws-path').value = d.whisper?.server_path ?? '';
document.getElementById('ws-port').value = d.whisper?.server_port ?? 8744;
toggleHttps();
}catch(e){ showAlert('error','Failed to load settings: '+e.message); }
}
// Refresh the whisper-server status badge from /admin/api/whisper-server/status.
// The response is treated as a map of {model_id: {running, ...}}; any failure
// (network, JSON, missing element) is silently ignored.
async function loadWsStatus(){
  try{
    const resp = await fetch('/admin/api/whisper-server/status');
    const status = await resp.json();
    const badge = document.getElementById('ws-badge');
    const paint = (text, color) => {
      badge.textContent = text;
      badge.style.color = color;
    };
    const instances = Object.values(status);
    if(instances.length === 0){
      paint('○ not configured', 'var(--text-2)');
      return;
    }
    const runningCount = instances.filter(inst => inst.running).length;
    if(runningCount > 0){
      paint(`● ${runningCount} running`, 'var(--green, #4ade80)');
    } else {
      paint('○ stopped', 'var(--text-2)');
    }
  }catch(e){}
}
// Start a whisper-server instance using the values currently in the form.
// Requires a binary path; the other fields fall back to defaults
// ('whisper-server', null model path, port 8744, GPU 0).
async function wsStart(){
  const trimmed = id => document.getElementById(id).value.trim();
  const binaryPath = trimmed('ws-path');
  if(!binaryPath){
    showAlert('error','Binary path required');
    return;
  }
  try{
    const payload = {
      model_id: trimmed('ws-id') || 'whisper-server',
      server_path: binaryPath,
      model_path: trimmed('ws-model') || null,
      port: parseInt(document.getElementById('ws-port').value) || 8744,
      gpu_device: parseInt(document.getElementById('ws-gpu').value) || 0,
    };
    const resp = await fetch('/admin/api/whisper-server/start',{
      method:'POST',
      headers:{'Content-Type':'application/json'},
      body: JSON.stringify(payload)
    });
    const result = await resp.json();
    const [level, message] = result.success
      ? ['info','whisper-server started']
      : ['error','Failed to start whisper-server'];
    showAlert(level, message);
    loadWsStatus();
  }catch(e){
    showAlert('error','Error: '+e.message);
  }
}
// Stop the whisper-server instance identified by the "Model ID" field
// (defaults to 'whisper-server'). Reports failure instead of unconditionally
// claiming success, and handles network errors the same way wsStart() does.
async function wsStop(){
  const modelId = document.getElementById('ws-id').value.trim() || 'whisper-server';
  try{
    const r = await fetch('/admin/api/whisper-server/stop',{
      method:'POST', headers:{'Content-Type':'application/json'},
      body: JSON.stringify({model_id: modelId})
    });
    // Only announce success when the server accepted the request.
    if(r.ok) showAlert('info','whisper-server stopped');
    else showAlert('error','Failed to stop whisper-server');
  }catch(e){
    showAlert('error','Error: '+e.message);
  }
  // Refresh the badge regardless of outcome so it reflects the real state.
  loadWsStatus();
}
async function saveSettings(){
const strOrNull = id => document.getElementById(id).value.trim() || null;
const data = {
......@@ -105,6 +208,7 @@ async function saveSettings(){
https: document.getElementById('s-https').checked,
https_key_path: strOrNull('s-key'),
https_cert_path: strOrNull('s-cert'),
queue_max_size: parseInt(document.getElementById('s-queue-max').value) || 6,
},
models:{
hf_cache_dir: strOrNull('s-hf-cache'),
......@@ -112,7 +216,11 @@ async function saveSettings(){
},
offload:{
directory: document.getElementById('s-offload-dir').value.trim() || './offload',
}
},
whisper:{
server_path: document.getElementById('ws-path').value.trim() || null,
server_port: parseInt(document.getElementById('ws-port').value) || 8744,
},
};
try{
const r = await fetch('/admin/api/settings',{
......@@ -125,5 +233,7 @@ async function saveSettings(){
}
loadSettings();
loadWsStatus();
setInterval(loadWsStatus, 5000);
</script>
{% endblock %}
......@@ -19,12 +19,16 @@ FastAPI application module for codai API.
Contains the FastAPI app initialization, lifespan, and core endpoints.
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import List
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, JSONResponse
logger = logging.getLogger(__name__)
# Import from codai modules
from codai.pydantic.textrequest import ModelList
from codai.models.manager import model_manager, multi_model_manager
......@@ -89,11 +93,19 @@ from codai.api.text import router as text_router
from codai.api.video import router as video_router
from codai.api.audio_gen import router as audio_gen_router
from codai.api.embeddings import router as embeddings_router
from codai.api.pipelines import router as pipelines_router
from codai.api.custom_pipelines import router as custom_pipelines_router
from codai.api.voice_clone import router as voice_clone_router
from codai.api.voice_convert import router as voice_convert_router
from codai.api.faceswap import router as faceswap_router
from codai.api.characters import router as characters_router
from codai.admin.routes import router as admin_router
# Import and add middleware
from codai.api.log import log_requests
from codai.api.ratelimit import RateLimitMiddleware
app.middleware("http")(log_requests)
app.add_middleware(RateLimitMiddleware)
# Mount static files for admin dashboard
from fastapi.staticfiles import StaticFiles
......@@ -110,6 +122,12 @@ app.include_router(text_router)
app.include_router(video_router)
app.include_router(audio_gen_router)
app.include_router(embeddings_router)
app.include_router(pipelines_router)
app.include_router(custom_pipelines_router)
app.include_router(voice_clone_router)
app.include_router(voice_convert_router)
app.include_router(faceswap_router)
app.include_router(characters_router)
app.include_router(admin_router)
......@@ -133,11 +151,14 @@ async def list_models():
@app.get("/v1/files/{filename}")
async def get_file(filename: str):
    """Serve uploaded/generated files.

    The requested name is resolved against the configured file directory and
    rejected when the resolved path falls outside of it.

    Raises:
        HTTPException: 404 when no file directory is configured or the file
            does not exist; 403 when the resolved path escapes the directory.
    """
    if not global_file_path:
        raise HTTPException(status_code=404, detail="File not found")

    # Prevent path traversal: resolve to real paths and confirm the result
    # stays inside the configured directory.
    safe_base = os.path.realpath(global_file_path)
    candidate = os.path.realpath(os.path.join(global_file_path, filename))
    try:
        # commonpath() also behaves correctly when the base is the filesystem
        # root, where a "base + os.sep" prefix test would never match.
        inside = os.path.commonpath([safe_base, candidate]) == safe_base
    except ValueError:
        # Raised e.g. for paths on different drives; treat as outside.
        inside = False
    if not inside:
        raise HTTPException(status_code=403, detail="Access denied")

    if not os.path.isfile(candidate):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(candidate)
\ No newline at end of file
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Character profile endpoints.
Saved character profiles are named collections of reference images used to
maintain visual consistency of a character across multiple video generations.
POST /v1/characters – save / update a character profile
GET /v1/characters – list all saved profiles (no images)
GET /v1/characters/{name} – get a profile including base64 images
DELETE /v1/characters/{name} – delete a profile
"""
import base64
import json
import os
import time
from typing import List, Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict
# Router that the character-profile endpoints below are registered on.
router = APIRouter()
# Directory holding one sub-directory per saved character; assigned by
# set_global_args(). While it is None, _chars_dir() uses ~/.coderai/characters.
_CHARS_DIR: Optional[str] = None
def set_global_args(args):
    """Derive and create the characters directory from the parsed arguments.

    Uses ``args.file_path`` when set (falling back to ``~/.coderai``). If that
    path is an existing directory the characters folder is created inside it,
    otherwise inside its parent directory.
    """
    global _CHARS_DIR
    base = getattr(args, 'file_path', None) or os.path.expanduser('~/.coderai')
    if os.path.isdir(base):
        root = base
    else:
        root = os.path.dirname(base)
    _CHARS_DIR = os.path.join(root, 'characters')
    os.makedirs(_CHARS_DIR, exist_ok=True)
def set_global_file_path(path: str):
    """Accept the file path setter call; this module does nothing with it."""
    return None  # not needed for characters
def _chars_dir() -> str:
    """Return the characters root directory.

    Prefers the directory set by ``set_global_args``; otherwise falls back to
    ``~/.coderai/characters``, creating it if necessary.
    """
    if not _CHARS_DIR:
        fallback = os.path.expanduser('~/.coderai/characters')
        os.makedirs(fallback, exist_ok=True)
        return fallback
    return _CHARS_DIR
def _char_dir(name: str) -> str:
    """Return the directory path for the character called *name*.

    The name is joined onto the characters root without any validation, so
    callers are responsible for rejecting names that contain path components.
    """
    root = _chars_dir()
    return os.path.join(root, name)
# ── Pydantic models ───────────────────────────────────────────────────────────
class CharacterImage(BaseModel):
    # A single reference image of a character.
    # Extra, undeclared fields are accepted.
    model_config = ConfigDict(extra="allow")

    # Optional human-readable tag, e.g. "front", "side", "close-up".
    label: Optional[str] = None
    # Base64 image payload (with or without a "data:" URI prefix).
    data: str
class CharacterSaveRequest(BaseModel):
    # Request body for POST /v1/characters.
    # Extra, undeclared fields are accepted.
    model_config = ConfigDict(extra="allow")

    # Profile name; also used as the on-disk directory name.
    name: str
    # Free-text description of the character.
    description: Optional[str] = ""
    # One or more reference images.
    images: List[CharacterImage]
class CharacterProfile(BaseModel):
    # Shape of a stored character profile (name, description, image count,
    # creation timestamp and, optionally, the images themselves).
    # Extra, undeclared fields are accepted.
    model_config = ConfigDict(extra="allow")

    name: str
    description: Optional[str] = ""
    image_count: int
    created_at: int
    # Only populated on GET /{name}.
    images: Optional[List[CharacterImage]] = None
# ── Helpers ───────────────────────────────────────────────────────────────────
def _save_character(name: str, description: str, images: List[CharacterImage]) -> dict:
    """Write a character profile (reference images plus ``meta.json``) to disk.

    All images are decoded before anything is written, so malformed input
    aborts without leaving a partially written profile. When an existing
    profile is overwritten, reference image files left over from the previous
    save that are no longer part of the profile are removed.

    Args:
        name: Profile name; used as the directory name (caller must validate).
        description: Free-text description stored in the metadata.
        images: Reference images as base64 strings or ``data:`` URIs.

    Returns:
        The metadata dict that was written to ``meta.json``.

    Raises:
        ValueError: If an image is a ``data:`` URI without a comma or its
            base64 payload cannot be decoded.
    """
    # Phase 1: decode every image up-front (no disk access yet).
    decoded = []
    for i, img in enumerate(images):
        raw = img.data
        if raw.startswith('data:'):
            _, b64 = raw.split(',', 1)
        else:
            b64 = raw
        img_bytes = base64.b64decode(b64)
        # Detect PNG vs JPEG from magic bytes
        ext = '.png' if img_bytes[:4] == b'\x89PNG' else '.jpg'
        decoded.append((f"ref{i:02d}{ext}", img.label or f'ref{i}', img_bytes))

    # Phase 2: write the files.
    cdir = _char_dir(name)
    os.makedirs(cdir, exist_ok=True)
    img_files = []
    for fname, label, img_bytes in decoded:
        with open(os.path.join(cdir, fname), 'wb') as f:
            f.write(img_bytes)
        img_files.append({'file': fname, 'label': label})

    # Phase 3: drop reference images from an earlier save that are no longer
    # listed (fewer images than before, or a changed extension).
    keep = {entry['file'] for entry in img_files}
    for existing in os.listdir(cdir):
        stale = (
            existing.startswith('ref')
            and existing.endswith(('.png', '.jpg'))
            and existing not in keep
        )
        existing_path = os.path.join(cdir, existing)
        if stale and os.path.isfile(existing_path):
            os.remove(existing_path)

    meta = {
        'name': name,
        'description': description,
        'images': img_files,
        'image_count': len(img_files),
        'created_at': int(time.time()),
    }
    with open(os.path.join(cdir, 'meta.json'), 'w') as f:
        json.dump(meta, f)
    return meta
def _load_character_meta(name: str) -> Optional[dict]:
    """Return the parsed ``meta.json`` of character *name*, or None if absent."""
    meta_path = os.path.join(_char_dir(name), 'meta.json')
    if os.path.exists(meta_path):
        with open(meta_path) as fh:
            return json.load(fh)
    return None
def _load_character_images(name: str) -> List[CharacterImage]:
    """Load the reference images of character *name* as ``data:`` URIs.

    Returns an empty list when the profile has no metadata. Image files that
    are listed in the metadata but missing on disk are skipped.
    """
    meta = _load_character_meta(name)
    if not meta:
        return []

    profile_dir = _char_dir(name)
    loaded = []
    for entry in meta.get('images', []):
        image_path = os.path.join(profile_dir, entry['file'])
        if os.path.exists(image_path):
            with open(image_path, 'rb') as fh:
                payload = fh.read()
            # The MIME type follows the stored file extension.
            is_png = entry['file'].rsplit('.', 1)[-1] == 'png'
            mime = 'image/png' if is_png else 'image/jpeg'
            encoded = base64.b64encode(payload).decode()
            loaded.append(CharacterImage(
                label=entry.get('label'),
                data=f"data:{mime};base64,{encoded}",
            ))
    return loaded
def _list_characters() -> list:
    """Return metadata (without the ``images`` key) of every saved profile.

    Sub-directories without a ``meta.json`` are ignored. The result is ordered
    by ``created_at`` ascending.
    """
    summaries = []
    for entry in os.scandir(_chars_dir()):
        if not entry.is_dir():
            continue
        meta = _load_character_meta(entry.name)
        if not meta:
            continue
        summary = {key: value for key, value in meta.items() if key != 'images'}
        summaries.append(summary)
    summaries.sort(key=lambda profile: profile.get('created_at', 0))
    return summaries
def resolve_character_profiles(profile_names: List[str]) -> List[str]:
    """Resolve saved profile names → flat list of base64 image strings.

    Images appear in the order of *profile_names*, and within a profile in
    the order returned by ``_load_character_images``.
    """
    return [
        img.data
        for profile in profile_names
        for img in _load_character_images(profile)
    ]
# ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/characters")
async def save_character(req: CharacterSaveRequest):
    """Save or update a named character profile.

    Raises:
        HTTPException: 400 when the name is empty or could address a path
            outside the character's own directory, when no images are
            supplied, or when an image payload cannot be decoded.
    """
    name = req.name
    # The name becomes a directory name, so refuse anything that is not a
    # plain single path component.
    invalid_name = (
        not name
        or name == '.'
        or '..' in name
        or '/' in name
        or '\\' in name
        or '\x00' in name
    )
    if invalid_name:
        raise HTTPException(status_code=400, detail="Invalid character name")
    if not req.images:
        raise HTTPException(status_code=400, detail="At least one reference image required")
    try:
        meta = _save_character(name, req.description or '', req.images)
    except ValueError as exc:
        # Covers malformed data URIs and base64 decoding errors
        # (binascii.Error is a ValueError subclass): client error, not a 500.
        raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc
    return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/characters")
async def list_characters():
    """List all saved character profiles (metadata only, no images)."""
    profiles = _list_characters()
    return {"characters": profiles}
@router.get("/v1/characters/{name}")
async def get_character(name: str):
    """Get a character profile including its reference images as base64.

    Raises:
        HTTPException: 400 when the name is not a plain single path component;
            404 when no profile with that name exists.
    """
    # Same rule as when saving: the name is used as a directory name, so
    # reject anything that could address a path outside the profile folder.
    if not name or name == '.' or '..' in name or '/' in name or '\\' in name or '\x00' in name:
        raise HTTPException(status_code=400, detail="Invalid character name")
    meta = _load_character_meta(name)
    if not meta:
        raise HTTPException(status_code=404, detail=f"Character '{name}' not found")
    images = _load_character_images(name)
    return {
        "name": meta['name'],
        "description": meta.get('description', ''),
        "image_count": meta['image_count'],
        "created_at": meta['created_at'],
        "images": [img.model_dump() for img in images],
    }
@router.delete("/v1/characters/{name}")
async def delete_character(name: str):
    """Delete a character profile.

    Raises:
        HTTPException: 400 when the name is not a plain single path component
            or does not resolve to a direct child of the characters
            directory; 404 when no profile with that name exists.
    """
    import shutil

    # A name such as ".." would otherwise make rmtree() remove the parent of
    # the characters directory.
    if not name or name == '.' or '..' in name or '/' in name or '\\' in name or '\x00' in name:
        raise HTTPException(status_code=400, detail="Invalid character name")

    cdir = _char_dir(name)
    # Second line of defence: the target must be a direct child of the
    # characters root once both paths are resolved.
    root = os.path.realpath(_chars_dir())
    if os.path.dirname(os.path.realpath(cdir)) != root:
        raise HTTPException(status_code=400, detail="Invalid character name")

    if not os.path.isdir(cdir):
        raise HTTPException(status_code=404, detail=f"Character '{name}' not found")
    shutil.rmtree(cdir)
    return {"ok": True, "name": name}
This diff is collapsed.
"""
Face swap endpoint.
POST /v1/images/faceswap — swap face in image or video frames
"""
import asyncio
import base64
import io
import os
import subprocess
import tempfile
import time
from typing import Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from PIL import Image
from pydantic import BaseModel, ConfigDict
from codai.api.images import save_image_response
# Router that the face swap endpoint below is registered on.
router = APIRouter()
# Set through set_global_args() / set_global_file_path().
global_args = None
global_file_path = None
# Local path where the inswapper ONNX model is expected; _ensure_model()
# downloads it from the HuggingFace repo/file below when it is missing.
_INSWAPPER_MODEL_PATH = os.path.expanduser('~/.insightface/models/inswapper_128.onnx')
_INSWAPPER_HF_REPO = 'deepinsight/inswapper'
_INSWAPPER_HF_FILE = 'inswapper_128.onnx'
# Lazily created singletons, see _get_face_app() and _get_swapper().
_face_app = None # FaceAnalysis singleton
_swapper = None # INSwapper singleton
def set_global_args(args):
    """Store the parsed arguments object in the module-level ``global_args``."""
    global global_args
    global_args = args
    return None
def set_global_file_path(path):
    """Store the output directory in the module-level ``global_file_path``."""
    global global_file_path
    global_file_path = path
    return None
def _ensure_model():
    """Download inswapper_128.onnx if not present.

    Does nothing when the model file already exists. Otherwise the file is
    fetched with ``huggingface_hub.hf_hub_download`` into the model directory
    and moved to the expected path if it was stored elsewhere.

    Raises:
        RuntimeError: If the download (or the import of ``huggingface_hub``)
            fails; the original exception is attached as ``__cause__``.
    """
    if os.path.exists(_INSWAPPER_MODEL_PATH):
        return
    model_dir = os.path.dirname(_INSWAPPER_MODEL_PATH)
    os.makedirs(model_dir, exist_ok=True)
    print('Downloading inswapper_128.onnx from HuggingFace…')
    try:
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(
            repo_id=_INSWAPPER_HF_REPO,
            filename=_INSWAPPER_HF_FILE,
            local_dir=model_dir,
        )
        if path != _INSWAPPER_MODEL_PATH:
            import shutil
            shutil.move(path, _INSWAPPER_MODEL_PATH)
    except Exception as e:
        # Chain the cause so the underlying failure stays visible in tracebacks.
        raise RuntimeError(f'Failed to download inswapper model: {e}') from e
def _get_face_app():
    """Return the shared FaceAnalysis instance, building it on first call.

    The instance uses the 'buffalo_l' model pack with the CUDA provider listed
    before the CPU provider, prepared with a 640x640 detection size.
    """
    global _face_app
    if _face_app is not None:
        return _face_app
    from insightface.app import FaceAnalysis
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    _face_app = FaceAnalysis(name='buffalo_l', providers=providers)
    _face_app.prepare(ctx_id=0, det_size=(640, 640))
    return _face_app
def _get_swapper():
    """Return the shared inswapper model, loading it on first call.

    Ensures the ONNX file is present (downloading it if needed) before the
    model is loaded from the local path.
    """
    global _swapper
    if _swapper is not None:
        return _swapper
    _ensure_model()
    from insightface.model_zoo import get_model
    _swapper = get_model(_INSWAPPER_MODEL_PATH, download=False)
    _swapper.prepare(ctx_id=0)
    return _swapper
def _decode_image(data: str) -> np.ndarray:
    """Decode base64 or data-URI image to BGR numpy array.

    Raises:
        ValueError: If OpenCV cannot decode the bytes as an image.
    """
    payload = data
    if payload.startswith('data:'):
        # Keep only the part after the first comma of the data URI.
        _, payload = payload.split(',', 1)
    buffer = np.frombuffer(base64.b64decode(payload), np.uint8)
    decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if decoded is None:
        raise ValueError('Could not decode image')
    return decoded
def _swap_faces(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
    """Swap all faces in target_img with the face from source_img.

    The first face detected in *source_img* is used as the donor. When no
    face is found in *target_img*, the target is returned unchanged.

    Raises:
        ValueError: If no face is detected in *source_img*.
    """
    analyser = _get_face_app()
    swapper = _get_swapper()

    donors = analyser.get(source_img)
    if not donors:
        raise ValueError('No face detected in source image')
    donor = donors[0]

    recipients = analyser.get(target_img)
    if not recipients:
        # no face to swap in target, return as-is
        return target_img

    swapped = target_img.copy()
    for recipient in recipients:
        swapped = swapper.get(swapped, recipient, donor, paste_back=True)
    return swapped
def _decode_b64_or_url(data: str) -> bytes:
    """Return the raw bytes referenced by *data*.

    *data* may be a ``data:`` URI, an ``http://`` / ``https://`` URL (fetched
    with a 30 second timeout), or a bare base64 string.

    Only a full URL scheme prefix is treated as a URL: the base64 alphabet
    has no ``:``, so a bare base64 payload that merely begins with the
    letters "http" is decoded instead of being handed to ``urlopen``.
    """
    if data.startswith('data:'):
        _, b64 = data.split(',', 1)
        return base64.b64decode(b64)
    if data.startswith(('http://', 'https://')):
        # NOTE(review): the URL is supplied by the client and fetched
        # server-side without any host restrictions — confirm this is
        # acceptable for the deployment (SSRF exposure).
        import urllib.request
        with urllib.request.urlopen(data, timeout=30) as r:
            return r.read()
    return base64.b64decode(data)
# ---------------------------------------------------------------------------
# Request model
# ---------------------------------------------------------------------------
class FaceSwapRequest(BaseModel):
    # Request body for POST /v1/images/faceswap.
    # Extra, undeclared fields are accepted.
    model_config = ConfigDict(extra='allow')

    # base64/data-URI image containing the source face
    source_face: str
    # base64/data-URI image OR video to swap into
    target: str
    # 'image' or 'video'
    target_type: Optional[str] = 'image'
    response_format: Optional[str] = 'url'
# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------
@router.post('/v1/images/faceswap')
async def faceswap(request: FaceSwapRequest, http_request: Request = None):
    """
    Swap the face from source_face into every face found in target.
    target_type: 'image' (default) or 'video'.

    Responds with 503 when the swapper model cannot be obtained and with 400
    when source_face cannot be decoded. Any target_type other than 'video' is
    handled as an image.
    """
    try:
        _ensure_model()
    except RuntimeError as err:
        raise HTTPException(status_code=503, detail=str(err))

    try:
        src_img = _decode_image(request.source_face)
    except Exception as err:
        raise HTTPException(status_code=400, detail=f'Invalid source_face: {err}')

    handler = _faceswap_video if request.target_type == 'video' else _faceswap_image
    return await handler(src_img, request, http_request)
async def _faceswap_image(src_img, request, http_request):
    """Swap faces in a single target image and build the response payload.

    The swap itself runs in the default executor so the event loop is not
    blocked while it computes.

    Raises:
        HTTPException: 400 when the target cannot be decoded, 422 when the
            swap raises ValueError (e.g. no face in the source image), 500
            for any other failure during the swap.
    """
    try:
        tgt_img = _decode_image(request.target)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f'Invalid target: {e}')

    try:
        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _swap_faces, src_img, tgt_img)
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f'Face swap failed: {e}')

    # OpenCV arrays are BGR; convert to RGB before handing the image to PIL.
    pil_img = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
    img_data = save_image_response(pil_img, request.response_format, http_request)
    return {'created': int(time.time()), 'data': [img_data]}
async def _faceswap_video(src_img, request, http_request):
    """Swap the source face into every frame of a target video.

    The video is split into PNG frames with ffmpeg, each frame is processed,
    and the frames are re-encoded with libx264 (audio is copied from the input
    when present).

    Args:
        src_img: Already-decoded source image containing the face to apply.
        request: Request object; ``target`` holds the video (base64 or URL).
        http_request: Incoming HTTP request, used only to derive the host name
            for the returned download URL.

    Returns:
        ``{'created': <unix time>, 'data': [...]}`` where the single data item
        is ``{'url': ...}`` when ``global_file_path`` is set, otherwise
        ``{'b64_mp4': ...}``.

    Raises:
        HTTPException: 400 if the target cannot be decoded, 422 if no face is
            found in the source image, 500 for ffmpeg or other failures.
    """
    try:
        raw = _decode_b64_or_url(request.target)
    except HTTPException:
        raise
    except Exception as e:
        # Mirrors the image path: an undecodable target is a client error.
        raise HTTPException(status_code=400, detail=f'Invalid target: {e}') from e

    temps = []

    def _render():
        """Blocking pipeline: extract frames, swap faces, re-encode; return bytes."""
        # mkstemp creates the file atomically (mktemp only returns a name and
        # is race-prone).
        fd, in_path = tempfile.mkstemp(suffix='.mp4')
        temps.append(in_path)
        with os.fdopen(fd, 'wb') as fh:
            fh.write(raw)

        # Extract frames
        frames_dir = tempfile.mkdtemp()
        temps.append(frames_dir)
        subprocess.run(
            ['ffmpeg', '-y', '-i', in_path, f'{frames_dir}/%08d.png'],
            capture_output=True, check=True)

        # Get FPS for reassembly; fall back to 25 when ffprobe gives nothing
        # usable (empty output, several lines, "0/0", ...).
        probe = subprocess.run(
            ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
             '-show_entries', 'stream=r_frame_rate', '-of', 'default=nw=1:nk=1', in_path],
            capture_output=True, text=True)
        fps = 25.0
        fields = probe.stdout.split()
        if fields:
            num, _, den = fields[0].partition('/')
            try:
                parsed = float(num) / float(den or 1)
            except (ValueError, ZeroDivisionError):
                parsed = 0.0
            if 0 < parsed < float('inf'):
                fps = parsed

        # Swap faces in each frame
        app = _get_face_app()
        swapper = _get_swapper()
        src_faces = app.get(src_img)
        if not src_faces:
            raise ValueError('No face detected in source image')
        src_face = src_faces[0]
        for frame_name in sorted(os.listdir(frames_dir)):
            frame_path = os.path.join(frames_dir, frame_name)
            frame = cv2.imread(frame_path)
            if frame is None:
                continue  # unreadable frame: leave it untouched
            for tgt_face in app.get(frame):
                frame = swapper.get(frame, tgt_face, src_face, paste_back=True)
            cv2.imwrite(frame_path, frame)

        # Reassemble video (copy original audio)
        fd, out_path = tempfile.mkstemp(suffix='_swapped.mp4')
        os.close(fd)  # ffmpeg overwrites the placeholder because of -y
        temps.append(out_path)
        subprocess.run(
            ['ffmpeg', '-y', '-framerate', str(fps), '-i', f'{frames_dir}/%08d.png',
             '-i', in_path, '-map', '0:v', '-map', '1:a?',
             '-c:v', 'libx264', '-c:a', 'copy', '-shortest', out_path],
            capture_output=True, check=True)
        with open(out_path, 'rb') as fh:
            return fh.read()

    try:
        # ffmpeg and per-frame inference are blocking; keep them off the loop.
        out_bytes = await asyncio.get_running_loop().run_in_executor(None, _render)

        if global_file_path:
            import uuid
            out_name = f'{uuid.uuid4().hex}_swapped.mp4'
            os.makedirs(global_file_path, exist_ok=True)
            with open(os.path.join(global_file_path, out_name), 'wb') as fh:
                fh.write(out_bytes)
            host = http_request.headers.get('host', '127.0.0.1') if http_request else '127.0.0.1'
            if ':' in host:
                parts = host.split(':')
                if len(parts) == 2 and parts[1].isdigit():
                    host = parts[0]
            proto = 'https' if getattr(global_args, 'https', False) else 'http'
            port = getattr(global_args, 'port', 8000) if global_args else 8000
            data = [{'url': f'{proto}://{host}:{port}/v1/files/{out_name}'}]
        else:
            data = [{'b64_mp4': base64.b64encode(out_bytes).decode()}]
        return {'created': int(time.time()), 'data': data}
    except HTTPException:
        raise
    except subprocess.CalledProcessError as e:
        stderr = e.stderr or b''
        if isinstance(stderr, bytes):
            stderr = stderr.decode(errors='replace')
        # ffmpeg prints its banner first; the actual error is at the end.
        raise HTTPException(status_code=500, detail=f'ffmpeg error: {stderr[-200:]}') from e
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f'Video face swap failed: {e}') from e
    finally:
        import shutil
        # Best-effort cleanup of every temporary file/directory created above.
        for t in temps:
            try:
                if os.path.isdir(t):
                    shutil.rmtree(t)
                else:
                    os.unlink(t)
            except OSError:
                pass
This diff is collapsed.
This diff is collapsed.
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
"""Simple in-process fixed-window rate limiter middleware.
Each distinct (client-IP, route-prefix) pair gets its own bucket.
Limits are configured via RateLimitConfig. The defaults below are
intentionally generous; tighten them through the config file or CLI.
Endpoints covered:
/v1/chat/completions — expensive LLM inference
/v1/images/ — image generation
/v1/audio/ — TTS / STT / audio generation
/v1/video/ — video generation
/v1/embeddings — embedding
/v1/completions — legacy completions
"""
import time
import threading
from collections import defaultdict
from typing import Dict, Tuple
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
# Per-route-prefix defaults: (max_requests, window_seconds)
# A request path is matched against these keys with str.startswith, in the
# order listed here; the first matching key selects the limit. Paths that
# match no key are not rate limited.
_DEFAULT_LIMITS: Dict[str, Tuple[int, int]] = {
"/v1/chat/completions": (60, 60),
"/v1/completions": (60, 60),
"/v1/images/": (30, 60),
"/v1/audio/": (60, 60),
"/v1/video/": (10, 60),
"/v1/embeddings": (120, 60),
}
# API prefixes that count against the request queue
# Requests whose path starts with one of these are answered with 429 while
# the queue manager reports that the queue is full.
_QUEUED_PREFIXES = ("/v1/",)
# Global toggle — set to False to disable rate limiting entirely.
# When False the middleware forwards every request untouched, which also
# skips the queue-full check.
RATE_LIMITING_ENABLED = True
class _Bucket:
    """Request tally for one fixed time window.

    ``window_start`` is the timestamp at which the current window opened and
    ``count`` is the number of requests recorded since then.
    """

    __slots__ = ("count", "window_start")

    def __init__(self, now: float):
        self.window_start = now
        self.count = 0
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Apply per-IP, per-route-prefix rate limiting to API endpoints.

    Each ``(client_ip, route_prefix)`` pair owns a fixed-window counter. Once
    a counter exceeds the configured maximum, further requests in that window
    get a 429 response. Counters whose window has elapsed are removed
    periodically so the table does not grow without bound.

    Args:
        app: The ASGI application to wrap.
        limits: Mapping of route prefix to ``(max_requests, window_seconds)``.
            ``None`` or an empty mapping selects ``_DEFAULT_LIMITS``.
    """

    def __init__(self, app, limits: Dict[str, Tuple[int, int]] = None):
        super().__init__(app)
        self._limits = limits or _DEFAULT_LIMITS
        # (client_ip, prefix) → _Bucket
        self._buckets: Dict[Tuple[str, str], _Bucket] = {}
        self._lock = threading.Lock()
        # Expired buckets are swept at most once per longest window.
        self._sweep_every = max((window for _, window in self._limits.values()), default=60)
        self._next_sweep = time.monotonic() + self._sweep_every

    def _get_prefix(self, path: str) -> str:
        """Return the first configured prefix that *path* starts with, or ''."""
        for prefix in self._limits:
            if path.startswith(prefix):
                return prefix
        return ""

    def _sweep(self, now: float) -> None:
        """Drop buckets whose window has fully elapsed. Caller holds the lock."""
        expired = [
            key for key, bucket in self._buckets.items()
            if now - bucket.window_start >= self._limits.get(key[1], (0, 0))[1]
        ]
        for key in expired:
            del self._buckets[key]
        self._next_sweep = now + self._sweep_every

    def _hit(self, key: Tuple[str, str], window: int, now: float) -> Tuple[int, float]:
        """Record one request for *key*.

        Returns ``(count, seconds_left)``: the number of requests seen in the
        current window including this one, and the time until the window
        resets. Both values are read under the lock so they are consistent.
        """
        with self._lock:
            if now >= self._next_sweep:
                self._sweep(now)
            bucket = self._buckets.get(key)
            if bucket is None or now - bucket.window_start >= window:
                bucket = self._buckets[key] = _Bucket(now)
            bucket.count += 1
            return bucket.count, window - (now - bucket.window_start)

    async def dispatch(self, request: Request, call_next):
        """Enforce the queue-size and per-client limits, then forward the request."""
        if not RATE_LIMITING_ENABLED:
            return await call_next(request)

        path = request.url.path

        # Queue-size enforcement for API requests (no authentication check
        # happens here; every path under a queued prefix is subject to it).
        if any(path.startswith(p) for p in _QUEUED_PREFIXES):
            from codai.queue.manager import queue_manager
            if await queue_manager.is_full():
                return JSONResponse(
                    status_code=429,
                    content={
                        "error": {
                            "message": "Server queue is full. Please retry later.",
                            "type": "rate_limit_error",
                            "code": 429,
                        }
                    },
                    headers={"Retry-After": "5"},
                )

        prefix = self._get_prefix(path)
        if not prefix:
            return await call_next(request)

        max_req, window = self._limits[prefix]
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # proxy overwrites it; a direct client can spoof it to dodge its
        # bucket. Confirm the deployment always sits behind such a proxy.
        client_ip = (
            request.headers.get("x-forwarded-for", "").split(",")[0].strip()
            or (request.client.host if request.client else "unknown")
        )

        count, seconds_left = self._hit((client_ip, prefix), window, time.monotonic())
        remaining = max(0, max_req - count)
        reset_at = int(time.time() + seconds_left)

        if count > max_req:
            # Tell the client how long is left in this window (rounded up,
            # at least one second) rather than the full window length.
            whole = int(seconds_left)
            retry_after = max(1, whole if whole == seconds_left else whole + 1)
            return JSONResponse(
                status_code=429,
                content={
                    "error": {
                        "message": "Rate limit exceeded. Please slow down.",
                        "type": "rate_limit_error",
                        "code": 429,
                    }
                },
                headers={
                    "X-RateLimit-Limit": str(max_req),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(reset_at),
                    "Retry-After": str(retry_after),
                },
            )

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(max_req)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(reset_at)
        return response
This diff is collapsed.
......@@ -23,8 +23,15 @@ import os
import tempfile
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import PlainTextResponse
from typing import Optional
# Maximum upload size: 100 MB
_MAX_AUDIO_BYTES = 100 * 1024 * 1024
# Safe audio extensions (user-supplied extension is NOT trusted for the suffix)
_SAFE_EXTENSIONS = {'.wav', '.mp3', '.ogg', '.flac', '.m4a', '.webm', '.mp4'}
# Import from codai modules
from codai.models.manager import multi_model_manager
......@@ -39,6 +46,71 @@ def set_global_args(args):
global_args = args
# =============================================================================
# Response formatting helpers
# =============================================================================
def _seconds_to_srt_time(s: float) -> str:
    """Format an offset in seconds as an SRT timestamp (``HH:MM:SS,mmm``).

    The value is rounded to whole milliseconds *before* it is split into
    fields, so a value such as 59.9996 carries into the minutes field instead
    of rendering an invalid ``60,000`` seconds field. Negative offsets are
    clamped to zero.
    """
    total_ms = max(0, int(round(s * 1000)))
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    secs, millis = divmod(rest, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def _seconds_to_vtt_time(s: float) -> str:
    """Format an offset in seconds as a WebVTT timestamp (``HH:MM:SS.mmm``).

    The value is rounded to whole milliseconds *before* it is split into
    fields, so a value such as 59.9996 carries into the minutes field instead
    of rendering an invalid ``60.000`` seconds field. Negative offsets are
    clamped to zero.
    """
    total_ms = max(0, int(round(s * 1000)))
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    secs, millis = divmod(rest, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"
def _format_response(fmt: str, text: str, segments: list):
    """Format a transcription result according to the requested response_format.

    Args:
        fmt: One of ``json`` (default), ``text``, ``srt``, ``vtt`` or
            ``verbose_json``; case-insensitive. ``None``/empty and unknown
            values fall back to ``json``.
        text: The full transcript.
        segments: List of dicts with optional ``start``, ``end`` and ``text``
            keys. May be empty, in which case subtitle formats emit a single
            zero-length cue holding the whole transcript.

    Returns:
        A ``PlainTextResponse`` for ``text``/``srt``/``vtt``; a plain dict for
        ``json``/``verbose_json``.
    """
    fmt = (fmt or "json").lower()

    def _caption(seg) -> str:
        # Tolerate a missing or None text, as the verbose_json branch does.
        return (seg.get("text") or "").strip()

    if fmt == "text":
        return PlainTextResponse(text)

    if fmt == "srt":
        if segments:
            cues = [
                f"{idx}\n{_seconds_to_srt_time(seg.get('start', 0))} --> "
                f"{_seconds_to_srt_time(seg.get('end', 0))}\n{_caption(seg)}\n"
                for idx, seg in enumerate(segments, 1)
            ]
            srt_body = "\n".join(cues)
        else:
            srt_body = f"1\n00:00:00,000 --> 00:00:00,000\n{text}\n"
        return PlainTextResponse(srt_body, media_type="text/plain")

    if fmt == "vtt":
        lines = ["WEBVTT\n"]
        lines.extend(
            f"{_seconds_to_vtt_time(seg.get('start', 0))} --> "
            f"{_seconds_to_vtt_time(seg.get('end', 0))}\n{_caption(seg)}\n"
            for seg in segments
        )
        if not segments:
            lines.append(f"00:00:00.000 --> 00:00:00.000\n{text}\n")
        return PlainTextResponse("\n".join(lines), media_type="text/vtt")

    if fmt == "verbose_json":
        return {
            "task": "transcribe",
            "language": "unknown",
            # Duration is taken from the end of the last segment.
            "duration": segments[-1].get("end", 0) if segments else 0,
            "text": text,
            "segments": [
                {
                    "id": idx,
                    "start": seg.get("start", 0),
                    "end": seg.get("end", 0),
                    "text": _caption(seg),
                }
                for idx, seg in enumerate(segments)
            ],
        }

    # Default: json
    return {"text": text}
# =============================================================================
# Router and Endpoints
# =============================================================================
......@@ -58,17 +130,37 @@ async def create_transcription(
"""
Audio transcription endpoint (OpenAI-compatible).
"""
# Check if whisper-server is available FIRST
if multi_model_manager.whisper_server and multi_model_manager.whisper_server.is_running():
file_content = await file.read()
result = multi_model_manager.whisper_server.transcribe(
file_content,
language=language,
prompt=prompt
)
if len(file_content) > _MAX_AUDIO_BYTES:
raise HTTPException(status_code=413, detail="Audio file too large (max 100 MB)")
# Check if the requested model is a whisper-server instance
wsm = multi_model_manager.whisper_servers.get(model)
if wsm is None and multi_model_manager.whisper_server is not None:
# Legacy single-instance fallback: use it if no specific match
if not multi_model_manager.whisper_servers:
wsm = multi_model_manager.whisper_server
if wsm is not None:
ws_key = f"audio:{model}" if model in multi_model_manager.whisper_servers else "audio:whisper-server"
# Let the VRAM manager evict other models if needed
multi_model_manager.request_model(requested_model=model, model_type="audio")
# Start the subprocess if it isn't running (on-demand)
if not wsm.is_running():
wsm.start(getattr(wsm, '_model_path', None), gpu_device=getattr(wsm, '_gpu_device', 0))
if wsm.is_running():
multi_model_manager.models[ws_key] = wsm
multi_model_manager.active_in_vram = ws_key
multi_model_manager.models_in_vram.add(ws_key)
if wsm.is_running():
result = wsm.transcribe(file_content, language=language, prompt=prompt)
if "error" in result:
raise HTTPException(status_code=500, detail=result["error"])
return {"text": result.get("text", "")}
return _format_response(response_format, result.get("text", ""), [])
# Fall through to Python backends if subprocess failed to start
# Use the manager to resolve the model and manage VRAM
model_info = multi_model_manager.request_model(
......@@ -90,11 +182,13 @@ async def create_transcription(
detail="Audio transcription not configured. Use --audio-model or --whisper-server."
)
# Read the uploaded file
file_content = await file.read()
# Determine a safe file extension from the upload's content-type or filename,
# never trusting the raw user-supplied value for arbitrary suffixes.
raw_ext = os.path.splitext(file.filename or '')[1].lower()
safe_ext = raw_ext if raw_ext in _SAFE_EXTENSIONS else '.wav'
# Save to temp file (needed for some backends)
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
with tempfile.NamedTemporaryFile(delete=False, suffix=safe_ext) as tmp:
tmp.write(file_content)
tmp_path = tmp.name
......@@ -104,41 +198,27 @@ async def create_transcription(
from faster_whisper import WhisperModel
if whisper_model is None:
print(f"Loading faster-whisper model: {model_name}")
# Determine compute type - always use int8 for CPU
compute_type = "int8"
# Load the model
whisper_model = WhisperModel(
model_name,
device="cpu", # Always use CPU - faster-whisper CUDA doesn't work with AMD
compute_type=compute_type,
device="cpu",
compute_type="int8",
)
# Cache the model
multi_model_manager.add_model(model_key, whisper_model)
multi_model_manager.current_model_key = model_key
print(f"Loaded faster-whisper model: {model_name}")
# Run transcription
segments, info = whisper_model.transcribe(
raw_segments, _ = whisper_model.transcribe(
tmp_path,
language=language,
initial_prompt=prompt,
temperature=temperature,
)
# Collect all segments
text_parts = []
for segment in segments:
text_parts.append(segment.text)
full_text = "".join(text_parts)
return {
"text": full_text.strip()
}
# Materialise the generator so we have all segment data
segments = [
{"start": s.start, "end": s.end, "text": s.text}
for s in raw_segments
]
full_text = "".join(s["text"] for s in segments)
return _format_response(response_format, full_text.strip(), segments)
except ImportError:
pass
......@@ -148,41 +228,26 @@ async def create_transcription(
import whispercpp
if whisper_model is None:
print(f"Loading whispercpp model: {model_name}")
# Check if it's a built-in model name
if model_name in ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large']:
# It's a built-in model name
whisper_model = whispercpp.Whisper.from_pretrained(model_name)
else:
# It's a path to a GGUF file
whisper_model = whispercpp.Whisper.from_pretrained(model_name)
# Cache the model
multi_model_manager.add_model(model_key, whisper_model)
multi_model_manager.current_model_key = model_key
print(f"Loaded whispercpp model: {model_name}")
# Run transcription
result = whisper_model.transcribe(tmp_path)
# Extract text from result
text = ""
if hasattr(result, 'text'):
text = result.text
elif isinstance(result, dict):
text = result.get('text', '')
elif isinstance(result, list):
# Some versions return a list of segments
for segment in result:
if hasattr(segment, 'text'):
text += segment.text
elif isinstance(segment, dict):
text += segment.get('text', '')
return {
"text": text.strip()
}
# whispercpp does not expose per-segment timestamps easily
return _format_response(response_format, text.strip(), [])
except ImportError as e:
raise HTTPException(
......
......@@ -263,7 +263,39 @@ def _apply_camera_motion(kw: dict, camera_motion: str):
kw['camera_motion'] = camera_motion
def _apply_character_refs(kw: dict, character_references: List[str], strength: float):
def _resolve_character_inputs(request) -> tuple[List[str], List[str]]:
    """Return (flat_image_list, name_list) from any combination of request fields.

    Sources are merged in this order:
      1. ``character_profiles`` — named saved profiles, expanded through
         ``codai.api.characters.resolve_character_profiles``.
      2. ``characters`` — slots of the form ``{"name": ..., "images": [...]}``.
      3. ``character_references`` / ``character_names`` — legacy flat lists.

    Profile resolution is best-effort: when it fails the profiles are skipped,
    a warning is logged, and the remaining sources are still used.
    """
    images: List[str] = []
    names: List[str] = []

    # 1. Expand named saved profiles
    if request.character_profiles:
        try:
            from codai.api.characters import resolve_character_profiles
            profile_images = list(resolve_character_profiles(request.character_profiles))
        except Exception as exc:
            # Keep going without the profiles, but leave a trace of why.
            import logging
            logging.getLogger(__name__).warning(
                "Could not resolve character profiles %r: %s",
                request.character_profiles, exc,
            )
        else:
            images += profile_images
            names += list(request.character_profiles)

    # 2. Named character slots [{name, images:[...]}, ...]
    for slot in request.characters or []:
        images += slot.get('images') or []
        if slot.get('name'):
            names.append(slot['name'])

    # 3. Legacy flat list
    if request.character_references:
        images += list(request.character_references)
    if request.character_names:
        names += list(request.character_names)

    return images, names
def _apply_character_refs(kw: dict, character_references: List[str], strength: float,
names: Optional[List[str]] = None):
"""Apply character reference images to pipeline kwargs."""
if not character_references:
return
......@@ -291,8 +323,13 @@ def _generate_video(pipe, request: VideoGenerationRequest):
_apply_camera_motion(kw, request.camera_motion)
if request.character_references:
_apply_character_refs(kw, request.character_references, request.character_strength or 0.8)
char_images, char_names = _resolve_character_inputs(request)
if char_images:
_apply_character_refs(kw, char_images, request.character_strength or 0.8, char_names)
# Prepend character names to prompt for better conditioning
if char_names and kw.get('prompt'):
names_hint = ', '.join(char_names)
kw['prompt'] = f"{names_hint}. {kw['prompt']}"
init_src = request.init_image or request.image
......@@ -359,35 +396,49 @@ def _ffmpeg_upscale(path: str, factor: int, temps: list) -> str:
scale = f"scale=iw*{factor}:ih*{factor}:flags=lanczos"
cmd = ['ffmpeg', '-y', '-i', path, '-vf', scale, '-c:a', 'copy', out]
r = subprocess.run(cmd, capture_output=True)
if r.returncode == 0:
return out
if r.returncode != 0:
import logging
logging.getLogger(__name__).warning(
"ffmpeg upscale failed (rc=%d): %s", r.returncode, r.stderr.decode(errors='replace')
)
return path # fallback to original if ffmpeg fails
return out
def _rife_interpolate(path: str, multiplier: int, temps: list) -> str:
out = tempfile.mktemp(suffix='_rife.mp4')
temps.append(out)
# Try rife-ncnn-vulkan binary if available
import shutil
import logging, shutil
_log = logging.getLogger(__name__)
if shutil.which('rife-ncnn-vulkan'):
frames_dir = tempfile.mkdtemp()
out_dir = tempfile.mkdtemp()
temps += [frames_dir, out_dir]
subprocess.run(['ffmpeg', '-y', '-i', path, f'{frames_dir}/%08d.png'],
r = subprocess.run(['ffmpeg', '-y', '-i', path, f'{frames_dir}/%08d.png'],
capture_output=True)
subprocess.run(['rife-ncnn-vulkan', '-i', frames_dir, '-o', out_dir,
'-m', f'rife-v4'], capture_output=True)
subprocess.run(['ffmpeg', '-y', '-r', str(multiplier * 8), '-i',
if r.returncode != 0:
_log.warning("ffmpeg frame extraction failed: %s", r.stderr.decode(errors='replace'))
else:
r = subprocess.run(['rife-ncnn-vulkan', '-i', frames_dir, '-o', out_dir,
'-m', 'rife-v4'], capture_output=True)
if r.returncode != 0:
_log.warning("rife-ncnn-vulkan failed: %s", r.stderr.decode(errors='replace'))
else:
r = subprocess.run(['ffmpeg', '-y', '-r', str(multiplier * 8), '-i',
f'{out_dir}/%08d.png', '-c:v', 'libx264', out],
capture_output=True)
if os.path.exists(out):
if r.returncode != 0:
_log.warning("ffmpeg reassembly failed: %s", r.stderr.decode(errors='replace'))
elif os.path.exists(out):
return out
# Simple ffmpeg minterpolate fallback
fps_expr = f"fps=fps={multiplier}*source_fps"
cmd = ['ffmpeg', '-y', '-i', path, '-filter:v',
f'minterpolate=fps={multiplier * 8}', '-c:a', 'copy', out]
r = subprocess.run(cmd, capture_output=True)
return out if r.returncode == 0 else path
if r.returncode != 0:
_log.warning("ffmpeg minterpolate failed: %s", r.stderr.decode(errors='replace'))
return path
return out
def _add_audio_to_video(path: str, request: VideoGenerationRequest,
......
"""
Voice cloning endpoints.
POST /v1/audio/clone — synthesize speech in a cloned voice
GET /v1/audio/voices — list saved voice profiles
POST /v1/audio/voices — save a named voice profile (ref audio + transcript)
DELETE /v1/audio/voices/{name} — delete a voice profile
"""
import asyncio
import base64
import io
import json
import os
import tempfile
import time
from typing import Optional
from fastapi import APIRouter, HTTPException, Request, UploadFile, File, Form
from pydantic import BaseModel, ConfigDict
router = APIRouter()
global_args = None
global_file_path = None
# Directory where voice profiles are stored
_VOICES_DIR: Optional[str] = None
def set_global_args(args):
    """Store the parsed CLI arguments and pick the voice-profile directory.

    When ``args.file_path`` is set, profiles live in a ``voices`` directory
    under it (or next to it when it is not an existing directory). Otherwise
    they live in ``~/.coderai/voices``. The directory is created if missing.

    The default location is computed without looking at what already exists
    on disk, so it is the same on the first run and every later run.
    """
    global global_args, _VOICES_DIR
    global_args = args

    file_path = getattr(args, 'file_path', None)
    if file_path:
        root = file_path if os.path.isdir(file_path) else os.path.dirname(file_path)
        _VOICES_DIR = os.path.join(root, 'voices')
    else:
        _VOICES_DIR = os.path.expanduser('~/.coderai/voices')
    os.makedirs(_VOICES_DIR, exist_ok=True)
def set_global_file_path(path):
    """Record the directory where generated audio files are written.

    While this is unset (falsy), audio is returned inline as base64 instead
    of being saved and served by URL.
    """
    global global_file_path
    global_file_path = path
def _voices_dir() -> str:
    """Return the directory that holds the saved voice profiles.

    Uses the configured ``_VOICES_DIR`` when it is set; otherwise falls back
    to ``~/.coderai/voices``, creating that directory if needed.
    """
    if not _VOICES_DIR:
        fallback = os.path.expanduser('~/.coderai/voices')
        os.makedirs(fallback, exist_ok=True)
        return fallback
    return _VOICES_DIR
def _voice_path(name: str) -> str:
    """Return the directory of the voice profile called *name*.

    The name must consist of letters/digits with optional hyphens and
    underscores — the same rule the create endpoint applies. Anything else
    (empty, ``..``, path separators, dots) is rejected so a caller can never
    be handed a path outside the voices directory.

    Raises:
        ValueError: if *name* is not a valid profile name.
    """
    if not isinstance(name, str) or not name.replace('-', '').replace('_', '').isalnum():
        raise ValueError(f"Invalid voice name: {name!r}")
    return os.path.join(_voices_dir(), name)
def _list_voices() -> list:
    """Return the metadata of every saved voice profile, oldest first.

    A profile is a sub-directory of the voices directory that contains a
    ``meta.json``. Directories without one are ignored. A ``meta.json`` that
    cannot be read or parsed is skipped with a warning, so one damaged
    profile does not break the whole listing.
    """
    import logging

    voices = []
    # Using scandir as a context manager closes the directory handle promptly.
    with os.scandir(_voices_dir()) as entries:
        for entry in entries:
            if not entry.is_dir():
                continue
            meta_path = os.path.join(entry.path, 'meta.json')
            try:
                with open(meta_path) as f:
                    meta = json.load(f)
            except FileNotFoundError:
                continue
            except (OSError, ValueError) as exc:
                logging.getLogger(__name__).warning(
                    "Skipping unreadable voice profile %s: %s", meta_path, exc)
                continue
            if isinstance(meta, dict):
                voices.append(meta)
    return sorted(voices, key=lambda v: v.get('created_at', 0))
def _save_voice(name: str, audio_bytes: bytes, audio_ext: str, transcript: str, description: str = '') -> dict:
    """Write a voice profile to disk and return its metadata.

    The profile directory receives the reference audio as ``ref<audio_ext>``
    and a ``meta.json`` describing it. Existing files with the same names are
    overwritten.
    """
    profile_dir = _voice_path(name)
    os.makedirs(profile_dir, exist_ok=True)

    ref_path = os.path.join(profile_dir, f'ref{audio_ext}')
    with open(ref_path, 'wb') as audio_out:
        audio_out.write(audio_bytes)

    meta = dict(
        name=name,
        description=description,
        transcript=transcript,
        audio_file=ref_path,
        audio_ext=audio_ext,
        created_at=int(time.time()),
    )
    meta_path = os.path.join(profile_dir, 'meta.json')
    with open(meta_path, 'w') as meta_out:
        meta_out.write(json.dumps(meta))
    return meta
def _load_voice(name: str) -> Optional[dict]:
    """Return the stored metadata for voice *name*, or ``None`` if it has none."""
    meta_path = os.path.join(_voice_path(name), 'meta.json')
    if os.path.exists(meta_path):
        with open(meta_path) as meta_file:
            return json.load(meta_file)
    return None
def _decode_audio(data: str) -> tuple[bytes, str]:
    """Decode base64 audio data, return (bytes, ext).

    *data* is either bare base64 or a ``data:`` URI such as
    ``data:audio/mpeg;base64,<payload>``. For a data URI the extension is
    derived from the MIME subtype (``.mpeg`` in the example). It falls back
    to ``.wav`` when the subtype is missing or contains characters other than
    letters, digits, ``-``, ``+`` and ``.``. Bare base64 always yields ``.wav``.

    Raises:
        ValueError: if a data URI has no ``,`` separating header and payload,
            or the payload is not valid base64 (``binascii.Error`` is a
            ``ValueError`` subclass).
    """
    ext = '.wav'
    payload = data
    if data.startswith('data:'):
        header, comma, payload = data.partition(',')
        if not comma:
            raise ValueError('Malformed data URI: missing "," before the payload')
        mime = header[len('data:'):].split(';', 1)[0]
        subtype = mime.partition('/')[2]
        # The extension ends up in a file name, so only accept tame subtypes.
        if subtype and all(ch.isalnum() or ch in '-+.' for ch in subtype):
            ext = '.' + subtype
    return base64.b64decode(payload), ext
def _f5tts_clone(ref_audio_path: str, ref_text: str, gen_text: str,
                 speed: float = 1.0, seed: Optional[int] = None) -> bytes:
    """Run F5-TTS voice cloning, return WAV bytes.

    Args:
        ref_audio_path: Path to the reference recording of the voice.
        ref_text: Transcript of the reference recording.
        gen_text: Text to synthesize in the cloned voice.
        speed: Speed value passed through to F5-TTS.
        seed: Optional seed passed through to F5-TTS.

    Raises:
        ImportError: if ``f5_tts`` or ``soundfile`` (or ``torch``, when it is
            probed) is not installed.
    """
    # Heavy optional dependencies are imported lazily so the module loads
    # without them.
    from f5_tts.api import F5TTS
    import soundfile as sf

    # Only probe for CUDA once the server arguments have been set; otherwise
    # leave the device choice to F5-TTS.
    device = None
    if global_args:
        import torch
        if torch.cuda.is_available():
            device = 'cuda'

    # NOTE(review): a new F5TTS instance is built on every call — presumably
    # this reloads the model each time; confirm whether caching is wanted.
    tts = F5TTS(device=device)
    wav, sr, _ = tts.infer(
        ref_file=ref_audio_path,
        ref_text=ref_text,
        gen_text=gen_text,
        speed=speed,
        seed=seed,
        show_info=lambda x: None,      # silence informational output
        progress=lambda x, **kw: x,    # pass-through instead of a progress bar
    )
    buf = io.BytesIO()
    sf.write(buf, wav, sr, format='WAV')
    return buf.getvalue()
def _save_audio_response(audio_bytes: bytes, http_request: Request) -> dict:
    """Package generated WAV audio for an API response.

    When ``global_file_path`` is set the audio is written there under a random
    name and ``{"url": ...}`` is returned, pointing at ``/v1/files/<name>`` on
    the host the client used and the server's configured port. Otherwise the
    audio is returned inline as ``{"b64_wav": ...}``.
    """
    import uuid

    if not global_file_path:
        return {"b64_wav": base64.b64encode(audio_bytes).decode()}

    filename = f"{uuid.uuid4().hex}.wav"
    os.makedirs(global_file_path, exist_ok=True)
    with open(os.path.join(global_file_path, filename), 'wb') as f:
        f.write(audio_bytes)

    # NOTE(review): the Host header is client-supplied and is echoed into the
    # returned URL; confirm that is acceptable for this deployment.
    host = (http_request.headers.get('host') if http_request else None) or '127.0.0.1'
    # Strip the client-side port; the server's own port is appended below.
    if host.startswith('['):
        # Bracketed IPv6 literal, e.g. "[::1]:9000" -> "[::1]"
        closing = host.find(']')
        if closing != -1:
            host = host[:closing + 1]
    else:
        name, colon, port_part = host.rpartition(':')
        if colon and port_part.isdigit() and ':' not in name:
            host = name

    use_https = getattr(global_args, 'https', False) if global_args else False
    proto = 'https' if use_https else 'http'
    port = getattr(global_args, 'port', 8000) if global_args else 8000
    return {"url": f"{proto}://{host}:{port}/v1/files/{filename}"}
# ---------------------------------------------------------------------------
# Voice profile management
# ---------------------------------------------------------------------------
@router.get("/v1/audio/voices")
async def list_voices():
    """Return the metadata of all saved voice profiles under the ``voices`` key."""
    profiles = _list_voices()
    return {"voices": profiles}
@router.post("/v1/audio/voices")
async def create_voice(
    name: str = Form(...),
    transcript: str = Form(...),
    description: str = Form(''),
    audio: UploadFile = File(...),
):
    """Save a named voice profile from a reference audio file + transcript.

    Raises:
        HTTPException: 400 if the name is not alphanumeric (hyphens and
            underscores allowed) or the upload is not readable audio; 501 if
            ``soundfile`` is not installed.
    """
    if not name.replace('-', '').replace('_', '').isalnum():
        raise HTTPException(status_code=400, detail="Voice name must be alphanumeric (hyphens/underscores allowed)")

    audio_bytes = await audio.read()

    # The upload may carry no filename, and its extension is user-controlled:
    # keep it only when it is a plain alphanumeric suffix, else assume WAV.
    ext = os.path.splitext(audio.filename or '')[1]
    if len(ext) < 2 or not ext[1:].isalnum():
        ext = '.wav'

    # Validate audio is readable
    try:
        import soundfile as sf
    except ImportError as e:
        raise HTTPException(status_code=501, detail="soundfile not installed. Run: pip install soundfile") from e
    try:
        sf.info(io.BytesIO(audio_bytes))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid audio file: {e}") from e

    meta = _save_voice(name, audio_bytes, ext, transcript, description)
    return {"created": True, "voice": meta}
@router.delete("/v1/audio/voices/{name}")
async def delete_voice(name: str):
    """Delete a saved voice profile.

    Raises:
        HTTPException: 400 if the name is not a valid profile name, 404 if no
            such profile directory exists.
    """
    import shutil

    # The name comes straight from the URL. Without this check a value such
    # as ".." would make rmtree remove the parent of the voices directory.
    if not name.replace('-', '').replace('_', '').isalnum():
        raise HTTPException(status_code=400, detail="Voice name must be alphanumeric (hyphens/underscores allowed)")

    vdir = _voice_path(name)
    if not os.path.isdir(vdir):
        raise HTTPException(status_code=404, detail=f"Voice '{name}' not found")
    shutil.rmtree(vdir)
    return {"deleted": True, "name": name}
# ---------------------------------------------------------------------------
# Voice cloning TTS
# ---------------------------------------------------------------------------
class VoiceCloneRequest(BaseModel):
# Request body of POST /v1/audio/clone.
# Supply either voice_name (a saved profile) or ref_audio together with
# ref_text; when both are given, voice_name takes precedence.
text: str # text to synthesize
voice_name: Optional[str] = None # use a saved voice profile
ref_audio: Optional[str] = None # base64 reference audio (if not using saved voice)
# For a saved voice, an empty ref_text falls back to the profile's transcript.
ref_text: Optional[str] = None # transcript of ref_audio
# A falsy speed (None or 0) is replaced by 1.0 before synthesis.
speed: Optional[float] = 1.0
seed: Optional[int] = None
# NOTE(review): response_format is not read by the clone endpoint in this
# file — confirm whether it is meant to be honoured.
response_format: Optional[str] = "url"
# Unknown extra fields are accepted rather than rejected.
model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/clone")
async def clone_voice(request: VoiceCloneRequest, http_request: Request = None):
    """
    Synthesize speech in a cloned voice using F5-TTS.

    Provide either:
      - voice_name: name of a saved voice profile
      - ref_audio (base64) + ref_text: inline reference audio

    Returns ``{"created": <unix time>, "data": [<audio payload>]}``.

    Raises:
        HTTPException: 400 for a malformed voice name, undecodable
            ``ref_audio``, or missing reference audio/transcript; 404 if the
            named voice does not exist; 501 if f5-tts is not installed; 500
            if synthesis fails.
    """
    # Resolve reference audio
    ref_audio_path = None
    ref_text = request.ref_text or ''
    temps = []
    try:
        if request.voice_name:
            # voice_name becomes part of a filesystem path, so apply the same
            # rule as profile creation before touching the disk.
            if not request.voice_name.replace('-', '').replace('_', '').isalnum():
                raise HTTPException(status_code=400, detail="Voice name must be alphanumeric (hyphens/underscores allowed)")
            meta = _load_voice(request.voice_name)
            if not meta:
                raise HTTPException(status_code=404, detail=f"Voice '{request.voice_name}' not found")
            ref_audio_path = meta['audio_file']
            ref_text = ref_text or meta.get('transcript', '')
        elif request.ref_audio:
            try:
                audio_bytes, ext = _decode_audio(request.ref_audio)
            except (ValueError, IndexError) as e:
                # Bad base64 / malformed data URI is a client error, not a 500.
                raise HTTPException(status_code=400, detail=f"Invalid ref_audio: {e}") from e
            with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
                temps.append(tmp.name)  # registered first so it is always cleaned up
                tmp.write(audio_bytes)
            ref_audio_path = tmp.name
        else:
            raise HTTPException(status_code=400, detail="Provide voice_name or ref_audio")

        if not ref_text:
            raise HTTPException(status_code=400, detail="ref_text (transcript of reference audio) is required for voice cloning")

        try:
            # Synthesis is blocking; run it in the default executor.
            audio_bytes = await asyncio.get_running_loop().run_in_executor(
                None, _f5tts_clone,
                ref_audio_path, ref_text, request.text,
                request.speed or 1.0, request.seed,
            )
        except ImportError as e:
            raise HTTPException(status_code=501, detail="f5-tts not installed. Run: pip install f5-tts") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Voice cloning failed: {e}") from e

        result = _save_audio_response(audio_bytes, http_request)
        return {"created": int(time.time()), "data": [result]}
    finally:
        # Best-effort removal of the temporary reference audio.
        for t in temps:
            try:
                os.unlink(t)
            except OSError:
                pass
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# codai.openai — optional LiteLLM integration layer
This diff is collapsed.
......@@ -57,9 +57,14 @@ class VideoGenerationRequest(BaseModel):
camera_motion: Optional[str] = None # zoom-in | zoom-out | pan-left | pan-right | tilt-up | tilt-down | rotate
# ── Character consistency ─────────────────────────────────────────────
character_references: Optional[List[str]] = None # list of base64/URL reference images
# Each entry: {"name": "Alice", "images": ["b64...", ...]}
characters: Optional[List[dict]] = None
# Legacy flat list of base64/URL reference images (still accepted)
character_references: Optional[List[str]] = None
character_strength: Optional[float] = 0.8
character_names: Optional[List[str]] = None # optional names per reference
# Named saved profiles to load (resolved server-side)
character_profiles: Optional[List[str]] = None
# ── Audio generation / manipulation ──────────────────────────────────
add_audio: Optional[bool] = False
......
......@@ -33,6 +33,12 @@ class QueueManager:
self.model_loading: bool = False
self.model_name: Optional[str] = None
self.lock = asyncio.Lock()
self.max_size: int = 6
async def is_full(self) -> bool:
    """Report whether the number of waiting requests has reached ``max_size``."""
    async with self.lock:
        pending = len(self.waiting_requests)
        return pending >= self.max_size
async def add_waiting(self, request_id: str) -> None:
"""Add a request to the waiting queue."""
......
videogen @ 04778e17
Subproject commit 04778e172a9a83d0778f566045f995828c6c3556
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment