wip: snapshot in-progress platform updates

parent 8fd1c5c2
This diff is collapsed.
...@@ -522,7 +522,14 @@ elif [ "$BACKEND" = "all" ]; then ...@@ -522,7 +522,14 @@ elif [ "$BACKEND" = "all" ]; then
pip install setproctitle || echo -e "${YELLOW}Warning: setproctitle failed (optional)${NC}" pip install setproctitle || echo -e "${YELLOW}Warning: setproctitle failed (optional)${NC}"
# Try stable-diffusion-cpp-python (disable WebM to avoid missing libwebm cmake submodule) # Try stable-diffusion-cpp-python (disable WebM to avoid missing libwebm cmake submodule)
# Use CUDA if available (detected later in this block, check nvcc now)
if command -v nvcc &> /dev/null || [ -d "/usr/local/cuda" ]; then
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_CUDA=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
CMAKE_ARGS="$SD_CMAKE_ARGS" pip install stable-diffusion-cpp-python || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python failed (optional)${NC}"
else
CMAKE_ARGS="$SD_CMAKE_ARGS" pip install stable-diffusion-cpp-python || echo -e "${YELLOW}Warning: stable-diffusion-cpp-python failed (optional)${NC}" CMAKE_ARGS="$SD_CMAKE_ARGS" pip install stable-diffusion-cpp-python || echo -e "${YELLOW}Warning: stable-diffusion-cpp-python failed (optional)${NC}"
fi
} }
# Install PyTorch with CUDA support (for nvidia backend) # Install PyTorch with CUDA support (for nvidia backend)
...@@ -622,14 +629,28 @@ elif [ "$BACKEND" = "all" ]; then ...@@ -622,14 +629,28 @@ elif [ "$BACKEND" = "all" ]; then
echo -e "${YELLOW}Warning: Some Vulkan packages failed to install${NC}" echo -e "${YELLOW}Warning: Some Vulkan packages failed to install${NC}"
} }
# Try to install stable-diffusion-cpp-python with OpenCL # Try to install stable-diffusion-cpp-python with CUDA+Vulkan (preferred) or fallbacks
if [ "$OPENCL_AVAILABLE" = true ]; then if [ "$CUDA_AVAILABLE" = true ] && [ "$VULKAN_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with OpenCL support...${NC}" echo -e "${YELLOW}Installing stable-diffusion-cpp-python with CUDA+Vulkan support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS" pip install stable-diffusion-cpp-python || { CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_CUDA=ON -DSD_VULKAN=ON" pip install stable-diffusion-cpp-python --no-cache-dir || {
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available (requires CMake and build tools)${NC}" echo -e "${YELLOW}CUDA+Vulkan build failed, trying CUDA only...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_CUDA=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available${NC}"
} }
elif [ "$CUDA_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with CUDA support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_CUDA=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available${NC}"
elif [ "$VULKAN_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with Vulkan support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_VULKAN=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available${NC}"
elif [ "$OPENCL_AVAILABLE" = true ]; then
echo -e "${YELLOW}Installing stable-diffusion-cpp-python with OpenCL support...${NC}"
CMAKE_ARGS="$SD_CMAKE_ARGS -DSD_OPENCL=ON" pip install stable-diffusion-cpp-python --no-cache-dir || \
echo -e "${YELLOW}Warning: stable-diffusion-cpp-python not available${NC}"
else else
echo -e "${YELLOW}Skipping OpenCL (stable-diffusion-cpp-python) - OpenCL not available${NC}" echo -e "${YELLOW}Skipping GPU-accelerated stable-diffusion-cpp-python - no GPU backend available${NC}"
fi fi
# Install additional requirements # Install additional requirements
...@@ -667,8 +688,11 @@ elif [ "$BACKEND" = "all" ]; then ...@@ -667,8 +688,11 @@ elif [ "$BACKEND" = "all" ]; then
echo "Available backends:" echo "Available backends:"
[ "$CUDA_AVAILABLE" = true ] && echo " ✓ NVIDIA/CUDA (PyTorch)" [ "$CUDA_AVAILABLE" = true ] && echo " ✓ NVIDIA/CUDA (PyTorch)"
[ "$CUDA_AVAILABLE" = true ] && echo " ✓ CUDA (llama-cpp-python)" [ "$CUDA_AVAILABLE" = true ] && echo " ✓ CUDA (llama-cpp-python)"
[ "$CUDA_AVAILABLE" = true ] && [ "$VULKAN_AVAILABLE" = true ] && echo " ✓ CUDA+Vulkan (stable-diffusion-cpp-python)"
[ "$CUDA_AVAILABLE" = true ] && [ "$VULKAN_AVAILABLE" != true ] && echo " ✓ CUDA (stable-diffusion-cpp-python)"
[ "$CUDA_AVAILABLE" != true ] && [ "$VULKAN_AVAILABLE" = true ] && echo " ✓ Vulkan (stable-diffusion-cpp-python)"
[ "$VULKAN_AVAILABLE" = true ] && echo " ✓ Vulkan (llama-cpp-python)" [ "$VULKAN_AVAILABLE" = true ] && echo " ✓ Vulkan (llama-cpp-python)"
[ "$OPENCL_AVAILABLE" = true ] && echo " ✓ OpenCL (stable-diffusion-cpp-python)" [ "$OPENCL_AVAILABLE" = true ] && [ "$CUDA_AVAILABLE" != true ] && [ "$VULKAN_AVAILABLE" != true ] && echo " ✓ OpenCL (stable-diffusion-cpp-python)"
echo " ✓ CPU (fallback for all)" echo " ✓ CPU (fallback for all)"
if [ "$FLASH" = true ] && [ "$CUDA_AVAILABLE" = true ]; then if [ "$FLASH" = true ] && [ "$CUDA_AVAILABLE" = true ]; then
echo "" echo ""
......
...@@ -15,10 +15,13 @@ ...@@ -15,10 +15,13 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Authentication and session management for admin dashboard.""" """Authentication and session management for admin dashboard."""
import base64
import hashlib import hashlib
import hmac import hmac
import json import json
import os
import secrets import secrets
import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
...@@ -43,35 +46,62 @@ def get_or_create_secret(config_dir: Path) -> bytes: ...@@ -43,35 +46,62 @@ def get_or_create_secret(config_dir: Path) -> bytes:
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
"""Hash a password using SHA-256 with salt. """Hash a password using argon2 (preferred) or scrypt as fallback.
In production, use argon2 or bcrypt. This is a minimal implementation New hashes are always produced with a proper key-derivation function and
for environments where those libraries aren't available. a per-password random salt. The legacy SHA-256/static-salt format is
only retained for *verification* of pre-existing hashes.
""" """
# Use SHA-256 with a pepper-like secret for basic hashing try:
# Real implementation should use argon2 from main.py from argon2 import PasswordHasher
salt = b'static_salt_' # In production, use per-user random salt ph = PasswordHasher()
return hashlib.sha256(salt + password.encode()).hexdigest() return ph.hash(password)
except ImportError:
pass
# scrypt fallback: encode as "scrypt:<b64salt>:<b64key>"
salt = os.urandom(16)
key = hashlib.scrypt(password.encode(), salt=salt, n=2**14, r=8, p=1)
return "scrypt:" + base64.b64encode(salt).decode() + ":" + base64.b64encode(key).decode()
def verify_password(password: str, password_hash: str) -> bool: def verify_password(password: str, password_hash: str) -> bool:
"""Verify a password against its hash.""" """Verify a password against its hash.
# Try argon2 first
Supports argon2, scrypt (new format), and the legacy SHA-256/static-salt
format so that old stored hashes continue to work.
"""
# --- argon2 ---
try: try:
from argon2 import PasswordHasher from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError, InvalidHashError
ph = PasswordHasher() ph = PasswordHasher()
try: try:
return ph.verify(password_hash, password) return ph.verify(password_hash, password)
except VerifyMismatchError: except VerifyMismatchError:
return False return False
except InvalidHashError:
pass # not an argon2 hash; fall through
except Exception: except Exception:
pass pass
except ImportError: except ImportError:
pass pass
# Fallback to simple hash # --- scrypt ---
return hash_password(password) == password_hash if password_hash.startswith("scrypt:"):
try:
parts = password_hash.split(":")
if len(parts) == 3:
salt = base64.b64decode(parts[1])
stored_key = base64.b64decode(parts[2])
new_key = hashlib.scrypt(password.encode(), salt=salt, n=2**14, r=8, p=1)
return hmac.compare_digest(new_key, stored_key)
except Exception:
pass
return False
# --- legacy SHA-256 with static salt (read-only; never written for new passwords) ---
legacy = hashlib.sha256(b'static_salt_' + password.encode()).hexdigest()
return hmac.compare_digest(legacy, password_hash)
class SessionManager: class SessionManager:
...@@ -81,7 +111,7 @@ class SessionManager: ...@@ -81,7 +111,7 @@ class SessionManager:
self.config_dir = config_dir self.config_dir = config_dir
self.secret = get_or_create_secret(config_dir) self.secret = get_or_create_secret(config_dir)
self.session_timeout = timedelta(minutes=session_timeout_minutes) self.session_timeout = timedelta(minutes=session_timeout_minutes)
self._lock = __import__('threading').Lock() self._lock = threading.Lock()
def _load_auth_data(self) -> Dict[str, Any]: def _load_auth_data(self) -> Dict[str, Any]:
"""Load auth.json data.""" """Load auth.json data."""
......
This diff is collapsed.
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
--border: #1A1D28; --border: #1A1D28;
--border-2: #252836; --border-2: #252836;
--text: #DDE1F0; --text: #DDE1F0;
--text-2: #636880; --text-2: #8B90A8;
--text-3: #2E3145; --text-3: #555A72;
--accent: #6366F1; --accent: #6366F1;
--accent-s: rgba(99,102,241,.12); --accent-s: rgba(99,102,241,.12);
--green: #34D399; --green: #34D399;
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -28,15 +28,17 @@ ...@@ -28,15 +28,17 @@
<div class="stat-value" id="req-total">0</div> <div class="stat-value" id="req-total">0</div>
<div class="stat-sub"><span id="req-active">0</span> active</div> <div class="stat-sub"><span id="req-active">0</span> active</div>
</div> </div>
<div class="stat"> <div class="stat" id="vram-card" style="display:none">
<div class="stat-label">VRAM</div> <div class="stat-label">VRAM</div>
<div class="stat-value" id="vram-pct"></div> <div class="stat-value" id="vram-pct" style="font-size:2rem"></div>
<div class="progress" style="margin-top:.625rem"> <div class="progress" style="margin-top:.625rem">
<div class="progress-fill" id="vram-bar" style="width:0%"></div> <div class="progress-fill" id="vram-bar" style="width:0%"></div>
</div> </div>
<div class="progress-labels"> <div class="progress-labels" style="color:var(--text-1);font-size:12px;margin-top:.4rem">
<span id="vram-used"></span><span id="vram-total"></span> <span id="vram-used"></span><span id="vram-free"></span>
</div> </div>
<div style="font-size:11.5px;color:var(--text-2);margin-top:.2rem;font-family:var(--mono)" id="vram-total-line"></div>
<div class="stat-sub" id="vram-gpu" style="margin-top:.25rem"></div>
</div> </div>
</div> </div>
...@@ -85,13 +87,25 @@ async function poll() { ...@@ -85,13 +87,25 @@ async function poll() {
document.getElementById('active-models').innerHTML = html || '<span class="muted small">No models loaded</span>'; document.getElementById('active-models').innerHTML = html || '<span class="muted small">No models loaded</span>';
if (d.vram) { if (d.vram) {
const pct = Math.round(d.vram.used / d.vram.total * 100); document.getElementById('vram-card').style.display = '';
document.getElementById('vram-pct').textContent = pct + '%'; if (d.vram.free != null && d.vram.total) {
document.getElementById('vram-bar').style.width = pct + '%'; const usedPct = Math.round(d.vram.used / d.vram.total * 100);
document.getElementById('vram-used').textContent = d.vram.used.toFixed(1) + ' GB'; document.getElementById('vram-pct').textContent = usedPct + '%';
document.getElementById('vram-total').textContent = d.vram.total.toFixed(1) + ' GB'; document.getElementById('vram-bar').style.width = usedPct + '%';
document.getElementById('vram-used').textContent = d.vram.used.toFixed(1) + ' GB used';
document.getElementById('vram-free').textContent = d.vram.free.toFixed(1) + ' GB free';
document.getElementById('vram-total-line').textContent = d.vram.total.toFixed(1) + ' GB total';
} else {
document.getElementById('vram-pct').textContent = d.vram.total ? d.vram.total.toFixed(1) + ' GB' : '—';
document.getElementById('vram-bar').style.width = '0%';
document.getElementById('vram-used').textContent = '';
document.getElementById('vram-free').textContent = '';
document.getElementById('vram-total-line').textContent = '';
}
const gpuName = d.vram.gpu || '';
document.getElementById('vram-gpu').textContent = gpuName.length > 32 ? gpuName.slice(0, 32) + '…' : gpuName;
} else { } else {
document.getElementById('vram-pct').textContent = 'N/A'; document.getElementById('vram-card').style.display = 'none';
} }
if (d.requests) { if (d.requests) {
......
This diff is collapsed.
...@@ -45,6 +45,11 @@ ...@@ -45,6 +45,11 @@
<input type="text" id="s-cert" class="form-input" placeholder="/path/to/cert.pem"> <input type="text" id="s-cert" class="form-input" placeholder="/path/to/cert.pem">
</div> </div>
</div> </div>
<div class="form-row" style="margin-top:1rem;margin-bottom:0">
<label class="form-label">Request queue max size</label>
<input type="number" id="s-queue-max" class="form-input" placeholder="6" min="1" max="1000" style="max-width:160px">
<span class="form-hint">Maximum number of concurrent queued requests. Authenticated requests arriving when the queue is full receive a 429 response.</span>
</div>
</div> </div>
<!-- Storage --> <!-- Storage -->
...@@ -64,6 +69,48 @@ ...@@ -64,6 +69,48 @@
<span class="form-hint">Models will inherit this as default when configured</span> <span class="form-hint">Models will inherit this as default when configured</span>
</div> </div>
</div> </div>
<!-- Whisper Server -->
<div class="card mb-0" style="margin-top:1rem">
<div style="display:flex;align-items:center;justify-content:space-between;flex-wrap:wrap;gap:.5rem;margin-bottom:1rem">
<div class="card-title" style="margin:0">Whisper Server <span class="muted" style="font-size:11px;font-weight:400">(whisper.cpp native binary — recommended for AMD/Vulkan)</span></div>
<div style="display:flex;align-items:center;gap:.5rem">
<span id="ws-badge" class="muted small"></span>
<button class="btn btn-sm btn-secondary" onclick="wsStart()">Start</button>
<button class="btn btn-sm btn-danger" onclick="wsStop()">Stop</button>
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 160px;gap:1rem;align-items:start">
<div class="form-row" style="margin:0">
<label class="form-label">Model ID <span class="muted">(used in API calls, e.g. whisper-base)</span></label>
<input type="text" id="ws-id" class="form-input" placeholder="whisper-server">
<span class="form-hint">The name clients use in the <code>model</code> field of transcription requests</span>
</div>
<div class="form-row" style="margin:0">
<label class="form-label">Port</label>
<input type="number" id="ws-port" class="form-input" placeholder="8744" min="1024" max="65535">
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 160px;gap:1rem;align-items:start;margin-top:1rem">
<div class="form-row" style="margin:0">
<label class="form-label">whisper-server binary path</label>
<input type="text" id="ws-path" class="form-input" placeholder="/usr/local/bin/whisper-server">
</div>
<div class="form-row" style="margin:0">
<label class="form-label">GPU device index</label>
<input type="number" id="ws-gpu" class="form-input" placeholder="0" min="0">
</div>
</div>
<div class="form-row" style="margin-top:1rem;margin-bottom:0">
<label class="form-label">Model path <span class="muted">(GGUF whisper model, e.g. ggml-base.bin)</span></label>
<input type="text" id="ws-model" class="form-input" placeholder="/path/to/ggml-base.bin">
<span class="form-hint">Configure multiple instances by adding entries to <code>models.json</code> with <code>"backend": "whisper-server"</code></span>
</div>
<p class="form-hint" style="margin-top:.75rem;margin-bottom:0">
When configured, the transcription endpoint uses this subprocess instead of the Python faster-whisper module.
Saves settings to <code>config.json</code> and takes effect immediately (no restart needed).
</p>
</div>
{% endblock %} {% endblock %}
{% block scripts %} {% block scripts %}
...@@ -89,13 +136,69 @@ async function loadSettings(){ ...@@ -89,13 +136,69 @@ async function loadSettings(){
document.getElementById('s-https').checked = !!d.server?.https; document.getElementById('s-https').checked = !!d.server?.https;
document.getElementById('s-key').value = d.server?.https_key_path ?? ''; document.getElementById('s-key').value = d.server?.https_key_path ?? '';
document.getElementById('s-cert').value = d.server?.https_cert_path ?? ''; document.getElementById('s-cert').value = d.server?.https_cert_path ?? '';
document.getElementById('s-queue-max').value = d.server?.queue_max_size ?? 6;
document.getElementById('s-hf-cache').value = d.models?.hf_cache_dir ?? ''; document.getElementById('s-hf-cache').value = d.models?.hf_cache_dir ?? '';
document.getElementById('s-gguf-cache').value = d.models?.gguf_cache_dir ?? ''; document.getElementById('s-gguf-cache').value = d.models?.gguf_cache_dir ?? '';
document.getElementById('s-offload-dir').value = d.offload?.directory ?? './offload'; document.getElementById('s-offload-dir').value = d.offload?.directory ?? './offload';
document.getElementById('ws-path').value = d.whisper?.server_path ?? '';
document.getElementById('ws-port').value = d.whisper?.server_port ?? 8744;
toggleHttps(); toggleHttps();
}catch(e){ showAlert('error','Failed to load settings: '+e.message); } }catch(e){ showAlert('error','Failed to load settings: '+e.message); }
} }
async function loadWsStatus(){
try{
const s = await fetch('/admin/api/whisper-server/status').then(r=>r.json());
const badge = document.getElementById('ws-badge');
// s is now a dict of {model_id: {running, model, url}}
const entries = Object.entries(s);
if(!entries.length){
badge.textContent = '○ not configured';
badge.style.color = 'var(--text-2)';
return;
}
const running = entries.filter(([,v])=>v.running);
if(running.length){
badge.textContent = `● ${running.length} running`;
badge.style.color = 'var(--green, #4ade80)';
} else {
badge.textContent = '○ stopped';
badge.style.color = 'var(--text-2)';
}
}catch(e){}
}
async function wsStart(){
const path = document.getElementById('ws-path').value.trim();
if(!path){ showAlert('error','Binary path required'); return; }
try{
const r = await fetch('/admin/api/whisper-server/start',{
method:'POST', headers:{'Content-Type':'application/json'},
body: JSON.stringify({
model_id: document.getElementById('ws-id').value.trim() || 'whisper-server',
server_path: path,
model_path: document.getElementById('ws-model').value.trim() || null,
port: parseInt(document.getElementById('ws-port').value) || 8744,
gpu_device: parseInt(document.getElementById('ws-gpu').value) || 0,
})
});
const d = await r.json();
if(d.success) showAlert('info','whisper-server started');
else showAlert('error','Failed to start whisper-server');
loadWsStatus();
}catch(e){ showAlert('error','Error: '+e.message); }
}
async function wsStop(){
const modelId = document.getElementById('ws-id').value.trim() || 'whisper-server';
await fetch('/admin/api/whisper-server/stop',{
method:'POST', headers:{'Content-Type':'application/json'},
body: JSON.stringify({model_id: modelId})
});
showAlert('info','whisper-server stopped');
loadWsStatus();
}
async function saveSettings(){ async function saveSettings(){
const strOrNull = id => document.getElementById(id).value.trim() || null; const strOrNull = id => document.getElementById(id).value.trim() || null;
const data = { const data = {
...@@ -105,6 +208,7 @@ async function saveSettings(){ ...@@ -105,6 +208,7 @@ async function saveSettings(){
https: document.getElementById('s-https').checked, https: document.getElementById('s-https').checked,
https_key_path: strOrNull('s-key'), https_key_path: strOrNull('s-key'),
https_cert_path: strOrNull('s-cert'), https_cert_path: strOrNull('s-cert'),
queue_max_size: parseInt(document.getElementById('s-queue-max').value) || 6,
}, },
models:{ models:{
hf_cache_dir: strOrNull('s-hf-cache'), hf_cache_dir: strOrNull('s-hf-cache'),
...@@ -112,7 +216,11 @@ async function saveSettings(){ ...@@ -112,7 +216,11 @@ async function saveSettings(){
}, },
offload:{ offload:{
directory: document.getElementById('s-offload-dir').value.trim() || './offload', directory: document.getElementById('s-offload-dir').value.trim() || './offload',
} },
whisper:{
server_path: document.getElementById('ws-path').value.trim() || null,
server_port: parseInt(document.getElementById('ws-port').value) || 8744,
},
}; };
try{ try{
const r = await fetch('/admin/api/settings',{ const r = await fetch('/admin/api/settings',{
...@@ -125,5 +233,7 @@ async function saveSettings(){ ...@@ -125,5 +233,7 @@ async function saveSettings(){
} }
loadSettings(); loadSettings();
loadWsStatus();
setInterval(loadWsStatus, 5000);
</script> </script>
{% endblock %} {% endblock %}
...@@ -19,12 +19,16 @@ FastAPI application module for codai API. ...@@ -19,12 +19,16 @@ FastAPI application module for codai API.
Contains the FastAPI app initialization, lifespan, and core endpoints. Contains the FastAPI app initialization, lifespan, and core endpoints.
""" """
import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import List from typing import List
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, JSONResponse from fastapi.responses import FileResponse, JSONResponse
logger = logging.getLogger(__name__)
# Import from codai modules # Import from codai modules
from codai.pydantic.textrequest import ModelList from codai.pydantic.textrequest import ModelList
from codai.models.manager import model_manager, multi_model_manager from codai.models.manager import model_manager, multi_model_manager
...@@ -89,11 +93,19 @@ from codai.api.text import router as text_router ...@@ -89,11 +93,19 @@ from codai.api.text import router as text_router
from codai.api.video import router as video_router from codai.api.video import router as video_router
from codai.api.audio_gen import router as audio_gen_router from codai.api.audio_gen import router as audio_gen_router
from codai.api.embeddings import router as embeddings_router from codai.api.embeddings import router as embeddings_router
from codai.api.pipelines import router as pipelines_router
from codai.api.custom_pipelines import router as custom_pipelines_router
from codai.api.voice_clone import router as voice_clone_router
from codai.api.voice_convert import router as voice_convert_router
from codai.api.faceswap import router as faceswap_router
from codai.api.characters import router as characters_router
from codai.admin.routes import router as admin_router from codai.admin.routes import router as admin_router
# Import and add middleware # Import and add middleware
from codai.api.log import log_requests from codai.api.log import log_requests
from codai.api.ratelimit import RateLimitMiddleware
app.middleware("http")(log_requests) app.middleware("http")(log_requests)
app.add_middleware(RateLimitMiddleware)
# Mount static files for admin dashboard # Mount static files for admin dashboard
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
...@@ -110,6 +122,12 @@ app.include_router(text_router) ...@@ -110,6 +122,12 @@ app.include_router(text_router)
app.include_router(video_router) app.include_router(video_router)
app.include_router(audio_gen_router) app.include_router(audio_gen_router)
app.include_router(embeddings_router) app.include_router(embeddings_router)
app.include_router(pipelines_router)
app.include_router(custom_pipelines_router)
app.include_router(voice_clone_router)
app.include_router(voice_convert_router)
app.include_router(faceswap_router)
app.include_router(characters_router)
app.include_router(admin_router) app.include_router(admin_router)
...@@ -133,11 +151,14 @@ async def list_models(): ...@@ -133,11 +151,14 @@ async def list_models():
@app.get("/v1/files/{filename}") @app.get("/v1/files/{filename}")
async def get_file(filename: str): async def get_file(filename: str):
"""Serve uploaded/generated files.""" """Serve uploaded/generated files."""
print(f"DEBUG get_file: filename={filename}, global_file_path={global_file_path}") if not global_file_path:
if global_file_path: raise HTTPException(status_code=404, detail="File not found")
import os # Prevent path traversal: resolve to real paths and confirm the result
file_path = os.path.join(global_file_path, filename) # stays inside the configured directory.
print(f"DEBUG get_file: full path={file_path}, exists={os.path.exists(file_path)}") safe_base = os.path.realpath(global_file_path)
if os.path.exists(file_path): candidate = os.path.realpath(os.path.join(global_file_path, filename))
return FileResponse(file_path) if not (candidate == safe_base or candidate.startswith(safe_base + os.sep)):
raise HTTPException(status_code=403, detail="Access denied")
if not os.path.isfile(candidate):
raise HTTPException(status_code=404, detail="File not found") raise HTTPException(status_code=404, detail="File not found")
return FileResponse(candidate)
\ No newline at end of file
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Character profile endpoints.
Saved character profiles are named collections of reference images used to
maintain visual consistency of a character across multiple video generations.
POST /v1/characters – save / update a character profile
GET /v1/characters – list all saved profiles (no images)
GET /v1/characters/{name} – get a profile including base64 images
DELETE /v1/characters/{name} – delete a profile
"""
import base64
import json
import os
import time
from typing import List, Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict
# Router that collects the /v1/characters endpoints defined in this module.
router = APIRouter()
# Directory holding one sub-directory per saved character.  It is assigned by
# set_global_args(); while it is still None, _chars_dir() falls back to
# ~/.coderai/characters.
_CHARS_DIR: Optional[str] = None
def set_global_args(args):
    """Pick and create the character storage directory from the server args.

    The base location is ``args.file_path`` when that attribute exists and is
    truthy, otherwise ``~/.coderai``.  If the base is an existing directory,
    profiles live in ``<base>/characters``; otherwise they live in a
    ``characters`` directory inside the base's parent.  The chosen path is
    stored in the module-level ``_CHARS_DIR`` and created when missing.
    """
    global _CHARS_DIR
    base = getattr(args, 'file_path', None) or os.path.expanduser('~/.coderai')
    if os.path.isdir(base):
        root = base
    else:
        # NOTE(review): a base that simply does not exist yet also ends up
        # here, so a missing ~/.coderai resolves to ~/characters -- confirm
        # that this is intended.
        root = os.path.dirname(base)
    target = os.path.join(root, 'characters')
    _CHARS_DIR = target
    os.makedirs(target, exist_ok=True)
def set_global_file_path(path: str):
    """Accept the global file path; character profiles do not use it."""
    return None
def _chars_dir() -> str:
    """Return the directory under which all character profiles are stored.

    Uses ``_CHARS_DIR`` when it has been configured; otherwise falls back to
    ``~/.coderai/characters`` and creates that directory on demand.
    """
    if not _CHARS_DIR:
        fallback = os.path.expanduser('~/.coderai/characters')
        os.makedirs(fallback, exist_ok=True)
        return fallback
    return _CHARS_DIR
def _char_dir(name: str) -> str:
    """Return the directory that stores the profile called *name*.

    *name* must be a single directory name.  It is joined onto the storage
    root, and callers read, write and recursively delete the result, so names
    that would resolve to the root itself or outside of it are rejected.

    Raises:
        HTTPException: 400 when *name* is empty, is ``.`` or ``..``, contains
            a path separator or NUL byte, or carries a drive prefix.
    """
    separators = [sep for sep in ('/', os.sep, os.altsep) if sep]
    if (not name or name in ('.', '..') or '\x00' in name
            or os.path.splitdrive(name)[0]
            or any(sep in name for sep in separators)):
        raise HTTPException(status_code=400, detail="Invalid character name")
    return os.path.join(_chars_dir(), name)
# ── Pydantic models ───────────────────────────────────────────────────────────
class CharacterImage(BaseModel):
    # A single reference image of a character; unknown extra fields are kept.
    model_config = ConfigDict(extra="allow")

    # Optional tag describing the shot, e.g. "front", "side", "close-up".
    label: Optional[str] = None
    # Base64-encoded image, either bare or carrying a "data:" URI prefix.
    data: str
class CharacterSaveRequest(BaseModel):
    # Request body of POST /v1/characters; unknown extra fields are kept.
    model_config = ConfigDict(extra="allow")

    # Profile name; it is also used as the on-disk directory name.
    name: str
    # Free-text description, empty by default.
    description: Optional[str] = ""
    # Reference images of the character (the endpoint requires at least one).
    images: List[CharacterImage]
class CharacterProfile(BaseModel):
    # Shape of a stored character profile; unknown extra fields are kept.
    model_config = ConfigDict(extra="allow")

    # Profile name.
    name: str
    # Free-text description, empty by default.
    description: Optional[str] = ""
    # Number of reference images stored for the profile.
    image_count: int
    # Integer timestamp of when the profile was written.
    created_at: int
    # Reference images; only populated on GET /{name}.
    images: Optional[List[CharacterImage]] = None
# ── Helpers ───────────────────────────────────────────────────────────────────
def _save_character(name: str, description: str, images: List[CharacterImage]) -> dict:
    """Write a character profile to disk and return its metadata.

    Every image is decoded before anything is written, so a malformed image
    rejects the whole request instead of leaving a half-written profile.
    Images are stored as ``refNN.png`` / ``refNN.jpg`` inside the profile
    directory, next to a ``meta.json`` that describes them.

    Args:
        name: Profile name, used as the directory name.
        description: Free-text description stored in the metadata.
        images: Reference images as base64 strings, with or without a
            ``data:`` URI prefix.

    Returns:
        The metadata dict that was written to ``meta.json``.

    Raises:
        HTTPException: 400 when an image is not valid base64 or is a
            ``data:`` URI without a comma-separated payload.
    """
    decoded = []
    for index, img in enumerate(images):
        payload = img.data
        if payload.startswith('data:'):
            # Drop the "data:<mime>;base64," prefix of a data URI.
            _, comma, payload = payload.partition(',')
            if not comma:
                raise HTTPException(
                    status_code=400,
                    detail=f"Image {index} is a data URI without a payload",
                )
        try:
            img_bytes = base64.b64decode(payload)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise HTTPException(
                status_code=400,
                detail=f"Image {index} is not valid base64",
            ) from exc
        decoded.append((index, img.label, img_bytes))

    cdir = _char_dir(name)
    os.makedirs(cdir, exist_ok=True)
    img_files = []
    for index, label, img_bytes in decoded:
        # Detect PNG from its magic bytes; everything else is stored as JPEG.
        ext = '.png' if img_bytes[:4] == b'\x89PNG' else '.jpg'
        fname = f"ref{index:02d}{ext}"
        with open(os.path.join(cdir, fname), 'wb') as f:
            f.write(img_bytes)
        img_files.append({'file': fname, 'label': label or f'ref{index}'})

    meta = {
        'name': name,
        'description': description,
        'images': img_files,
        'image_count': len(img_files),
        'created_at': int(time.time()),
    }
    with open(os.path.join(cdir, 'meta.json'), 'w') as f:
        json.dump(meta, f)
    return meta
def _load_character_meta(name: str) -> Optional[dict]:
    """Return the parsed ``meta.json`` of profile *name*, or None if absent."""
    meta_file = os.path.join(_char_dir(name), 'meta.json')
    if os.path.exists(meta_file):
        with open(meta_file) as handle:
            return json.load(handle)
    return None
def _load_character_images(name: str) -> List[CharacterImage]:
    """Load the reference images of profile *name* as base64 data URIs.

    Returns an empty list when the profile has no metadata.  Images listed in
    the metadata but missing on disk are skipped.
    """
    meta = _load_character_meta(name)
    if not meta:
        return []
    profile_dir = _char_dir(name)
    loaded = []
    for entry in meta.get('images', []):
        filename = entry['file']
        image_path = os.path.join(profile_dir, filename)
        if not os.path.exists(image_path):
            continue
        with open(image_path, 'rb') as handle:
            payload = base64.b64encode(handle.read()).decode()
        # The stored extension decides the MIME type: png, otherwise jpeg.
        suffix = filename.rsplit('.', 1)[-1]
        mime = 'image/png' if suffix == 'png' else 'image/jpeg'
        loaded.append(CharacterImage(
            label=entry.get('label'),
            data=f"data:{mime};base64,{payload}",
        ))
    return loaded
def _list_characters() -> list:
    """Return the metadata of all stored profiles, oldest first.

    Each entry is the profile's ``meta.json`` content without its ``images``
    key.  Directories without metadata are ignored.
    """
    summaries = []
    for entry in os.scandir(_chars_dir()):
        if not entry.is_dir():
            continue
        meta = _load_character_meta(entry.name)
        if not meta:
            continue
        summaries.append(
            {key: value for key, value in meta.items() if key != 'images'}
        )
    summaries.sort(key=lambda profile: profile.get('created_at', 0))
    return summaries
def resolve_character_profiles(profile_names: List[str]) -> List[str]:
    """Flatten saved profile names into one list of base64 image strings.

    Images keep the order of *profile_names* and, within a profile, the order
    in which they are stored.
    """
    return [
        image.data
        for profile in profile_names
        for image in _load_character_images(profile)
    ]
# ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/v1/characters")
async def save_character(req: CharacterSaveRequest):
    """Save or update a named character profile.

    Raises:
        HTTPException: 400 when the name is not usable as a single directory
            name or when no reference image was supplied.
    """
    name = req.name
    # The name becomes a directory below the storage root.  Besides "/" and
    # "..", also refuse ".", which would write into the root itself, NUL
    # bytes, and the platform's own separators and drive prefixes.
    separators = [sep for sep in ('/', os.sep, os.altsep) if sep]
    if (not name or name == '.' or '..' in name or '\x00' in name
            or os.path.splitdrive(name)[0]
            or any(sep in name for sep in separators)):
        raise HTTPException(status_code=400, detail="Invalid character name")
    if not req.images:
        raise HTTPException(status_code=400, detail="At least one reference image required")
    meta = _save_character(name, req.description or '', req.images)
    return {"ok": True, "name": meta['name'], "image_count": meta['image_count']}
@router.get("/v1/characters")
async def list_characters():
    """Return the metadata of every saved character profile, without images."""
    profiles = _list_characters()
    return {"characters": profiles}
@router.get("/v1/characters/{name}")
async def get_character(name: str):
    """Return one character profile together with its images as base64.

    Raises:
        HTTPException: 400 when *name* is not a plain directory name, 404
            when no profile with that name has metadata.
    """
    # The name comes straight from the URL and is joined onto the storage
    # directory, so refuse anything that could resolve outside of it.
    separators = [sep for sep in ('/', os.sep, os.altsep) if sep]
    if (name in ('', '.', '..') or '\x00' in name
            or os.path.splitdrive(name)[0]
            or any(sep in name for sep in separators)):
        raise HTTPException(status_code=400, detail="Invalid character name")
    meta = _load_character_meta(name)
    if not meta:
        raise HTTPException(status_code=404, detail=f"Character '{name}' not found")
    images = _load_character_images(name)
    return {
        "name": meta['name'],
        "description": meta.get('description', ''),
        "image_count": meta['image_count'],
        "created_at": meta['created_at'],
        "images": [img.model_dump() for img in images],
    }
@router.delete("/v1/characters/{name}")
async def delete_character(name: str):
    """Delete a character profile and all files stored for it.

    Raises:
        HTTPException: 400 when *name* is not a plain directory name, 404
            when no directory exists for that name.
    """
    import shutil

    # The directory is removed recursively, so the name must never resolve to
    # the storage root (".") or to anything outside of it ("..", separators,
    # drive prefixes).
    separators = [sep for sep in ('/', os.sep, os.altsep) if sep]
    if (name in ('', '.', '..') or '\x00' in name
            or os.path.splitdrive(name)[0]
            or any(sep in name for sep in separators)):
        raise HTTPException(status_code=400, detail="Invalid character name")
    cdir = _char_dir(name)
    if not os.path.isdir(cdir):
        raise HTTPException(status_code=404, detail=f"Character '{name}' not found")
    shutil.rmtree(cdir)
    return {"ok": True, "name": name}
This diff is collapsed.
"""
Face swap endpoint.
POST /v1/images/faceswap — swap face in image or video frames
"""
import asyncio
import base64
import io
import os
import subprocess
import tempfile
import time
from typing import Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from PIL import Image
from pydantic import BaseModel, ConfigDict
from codai.api.images import save_image_response
# Router the face-swap endpoint below is registered on.
router = APIRouter()
# Both stay None until set_global_args() / set_global_file_path() are called.
global_args = None
global_file_path = None
# Local path of the inswapper ONNX model, and the HuggingFace repo/file it is
# downloaded from by _ensure_model() when that path does not exist.
_INSWAPPER_MODEL_PATH = os.path.expanduser('~/.insightface/models/inswapper_128.onnx')
_INSWAPPER_HF_REPO = 'deepinsight/inswapper'
_INSWAPPER_HF_FILE = 'inswapper_128.onnx'
# Lazily created by _get_face_app() / _get_swapper() and reused afterwards.
_face_app = None   # FaceAnalysis singleton
_swapper = None    # INSwapper singleton
def set_global_args(args):
    """Store *args* in the module-level ``global_args``.

    The video path later reads the ``https`` and ``port`` attributes from it
    (via ``getattr`` with defaults) when building a download URL.
    """
    global global_args
    global_args = args
def set_global_file_path(path):
    """Store *path* in the module-level ``global_file_path``.

    When it is set, swapped videos are written into that directory and
    returned as a URL; otherwise they are returned inline as base64.
    """
    global global_file_path
    global_file_path = path
def _ensure_model():
    """Download inswapper_128.onnx if it is not already present.

    Does nothing when ``_INSWAPPER_MODEL_PATH`` exists. Otherwise the file is
    fetched from HuggingFace into the model directory and moved into place if
    the hub returned a different path.

    Raises:
        RuntimeError: if the directory cannot be created, ``huggingface_hub``
            is missing, or the download/move fails. The original exception is
            attached as ``__cause__``.
    """
    if os.path.exists(_INSWAPPER_MODEL_PATH):
        return
    model_dir = os.path.dirname(_INSWAPPER_MODEL_PATH)
    print('Downloading inswapper_128.onnx from HuggingFace…')
    try:
        import shutil

        from huggingface_hub import hf_hub_download

        # Inside the try so that a permissions problem surfaces as the same
        # RuntimeError the endpoint already maps to a 503.
        os.makedirs(model_dir, exist_ok=True)
        downloaded = hf_hub_download(
            repo_id=_INSWAPPER_HF_REPO,
            filename=_INSWAPPER_HF_FILE,
            local_dir=model_dir,
        )
        if downloaded != _INSWAPPER_MODEL_PATH:
            shutil.move(downloaded, _INSWAPPER_MODEL_PATH)
    except Exception as e:
        raise RuntimeError(f'Failed to download inswapper model: {e}') from e
def _get_face_app():
    """Return the shared FaceAnalysis instance, creating it on first use.

    The instance is published to the module-level ``_face_app`` only after
    ``prepare()`` has succeeded, so a failed initialisation is retried on the
    next call instead of handing out an unprepared object.
    """
    global _face_app
    if _face_app is None:
        from insightface.app import FaceAnalysis
        analyser = FaceAnalysis(
            name='buffalo_l',
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        analyser.prepare(ctx_id=0, det_size=(640, 640))
        _face_app = analyser
    return _face_app
def _get_swapper():
    """Return the shared inswapper model, loading it on first use.

    Ensures the ONNX file is present (downloading it if needed) and caches
    the model in ``_swapper`` only once ``prepare()`` has succeeded, so a
    failed load is retried on the next call.

    Raises:
        RuntimeError: propagated from ``_ensure_model`` when the model file
            cannot be obtained.
    """
    global _swapper
    if _swapper is None:
        _ensure_model()
        from insightface.model_zoo import get_model
        swapper = get_model(_INSWAPPER_MODEL_PATH, download=False)
        swapper.prepare(ctx_id=0)
        _swapper = swapper
    return _swapper
def _decode_image(data: str) -> np.ndarray:
    """Turn a base64 string (bare or wrapped in a data URI) into a BGR array.

    Raises:
        ValueError: if a data URI has no comma separator, the base64 is
            invalid, or OpenCV cannot decode the bytes as an image.
    """
    payload = data
    if payload.startswith('data:'):
        # Drop the "data:<mime>;base64" header and keep what follows the comma.
        _header, payload = payload.split(',', 1)
    encoded = np.frombuffer(base64.b64decode(payload), np.uint8)
    pixels = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if pixels is None:
        raise ValueError('Could not decode image')
    return pixels
def _swap_faces(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
    """Replace every face found in *target_img* with the first face of *source_img*.

    Returns *target_img* itself, untouched, when it contains no face;
    otherwise a modified copy.

    Raises:
        ValueError: if no face is detected in *source_img*.
    """
    analyser = _get_face_app()
    swap_model = _get_swapper()

    source_faces = analyser.get(source_img)
    if not source_faces:
        raise ValueError('No face detected in source image')
    donor = source_faces[0]

    target_faces = analyser.get(target_img)
    if not target_faces:
        # Nothing to swap: hand the input back as-is.
        return target_img

    swapped = target_img.copy()
    for face in target_faces:
        swapped = swap_model.get(swapped, face, donor, paste_back=True)
    return swapped
def _decode_b64_or_url(data: str) -> bytes:
    """Return the raw bytes behind *data*.

    Accepted forms:
      * ``data:<mime>;base64,<payload>``: the payload is base64-decoded;
      * ``http://`` / ``https://`` URL: the body is downloaded (30 s timeout);
      * anything else: treated as bare base64.

    Only a full scheme prefix selects the download path. The bare string
    "http" is a legal start of a base64 payload, whereas ':' never occurs in
    base64, so matching on the scheme separator cannot misroute real data.

    NOTE(review): the URL comes from the request body, so this fetch can be
    pointed at internal hosts (SSRF). Consider an allow-list before exposing
    the endpoint to untrusted clients.

    Raises:
        ValueError: malformed data URI or invalid base64
            (``binascii.Error`` is a ``ValueError`` subclass).
        urllib.error.URLError: the download failed.
    """
    if data.startswith('data:'):
        _header, b64 = data.split(',', 1)
        return base64.b64decode(b64)
    if data.startswith(('http://', 'https://')):
        import urllib.request
        with urllib.request.urlopen(data, timeout=30) as r:
            return r.read()
    return base64.b64decode(data)
# ---------------------------------------------------------------------------
# Request model
# ---------------------------------------------------------------------------
class FaceSwapRequest(BaseModel):
    """Request body for POST /v1/images/faceswap."""
    source_face: str  # base64/data-URI image containing the source face
    target: str  # base64/data-URI image OR video to swap into
    # The endpoint takes the video path only for the exact value 'video';
    # every other value is handled as an image.
    target_type: Optional[str] = 'image'  # 'image' or 'video'
    # Forwarded to save_image_response() on the image path.
    response_format: Optional[str] = 'url'
    # extra='allow': fields not declared above are accepted rather than rejected.
    model_config = ConfigDict(extra='allow')
# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------
@router.post('/v1/images/faceswap')
async def faceswap(request: FaceSwapRequest, http_request: Request = None):
    """
    Put the face from ``source_face`` onto every face detected in ``target``.

    ``target_type`` selects the pipeline: 'video' runs the frame-by-frame
    video path, anything else (default 'image') the single-image path.

    Raises HTTPException 503 when the swap model cannot be obtained and 400
    when ``source_face`` cannot be decoded.
    """
    try:
        _ensure_model()
    except RuntimeError as err:
        raise HTTPException(status_code=503, detail=str(err))

    try:
        src_img = _decode_image(request.source_face)
    except Exception as err:
        raise HTTPException(status_code=400, detail=f'Invalid source_face: {err}')

    handler = _faceswap_video if request.target_type == 'video' else _faceswap_image
    return await handler(src_img, request, http_request)
async def _faceswap_image(src_img, request, http_request):
    """Swap faces in a single target image and return it in API response shape.

    The swap itself runs in the default executor so the event loop stays
    responsive during inference.

    Raises:
        HTTPException: 400 if the target cannot be decoded, 422 if no face is
            found in the source image, 500 for any other swap failure.
    """
    try:
        tgt_img = _decode_image(request.target)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f'Invalid target: {e}') from e

    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    try:
        result = await loop.run_in_executor(None, _swap_faces, src_img, tgt_img)
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f'Face swap failed: {e}') from e

    # OpenCV arrays are BGR; PIL expects RGB.
    pil_img = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
    img_data = save_image_response(pil_img, request.response_format, http_request)
    return {'created': int(time.time()), 'data': [img_data]}
async def _faceswap_video(src_img, request, http_request):
    """Swap the source face into every frame of a video and return the result.

    Pipeline (all blocking steps run in the default executor):
      1. write the decoded target into a private scratch directory;
      2. dump its frames as PNGs with ffmpeg and read the frame rate with
         ffprobe (falling back to 25 fps when that output is unusable);
      3. swap faces frame by frame, overwriting each PNG;
      4. re-encode with libx264, copying the original audio when present.

    The result is saved under ``global_file_path`` and returned as a URL when
    that is configured, otherwise returned inline as base64.

    Raises:
        HTTPException: 400 if the target cannot be decoded or fetched, 422 if
            no face is found in the source image, 500 for ffmpeg or any other
            processing failure.
    """
    import math
    import shutil
    import uuid

    try:
        raw = _decode_b64_or_url(request.target)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f'Invalid target: {e}') from e

    # One private directory holds every intermediate file. This avoids the
    # race inherent in tempfile.mktemp() and makes cleanup a single rmtree.
    work_dir = tempfile.mkdtemp()

    def _pipeline() -> bytes:
        in_path = os.path.join(work_dir, 'input.mp4')
        frames_dir = os.path.join(work_dir, 'frames')
        out_path = os.path.join(work_dir, 'swapped.mp4')
        frame_pattern = os.path.join(frames_dir, '%08d.png')

        os.mkdir(frames_dir)
        with open(in_path, 'wb') as f:
            f.write(raw)

        # Extract frames
        subprocess.run(
            ['ffmpeg', '-y', '-i', in_path, frame_pattern],
            capture_output=True, check=True)

        # Get FPS for reassembly. ffprobe prints e.g. "30000/1001"; empty,
        # multi-line, slash-less or "0/0" output falls back to 25 fps instead
        # of failing the request.
        probe = subprocess.run(
            ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
             '-show_entries', 'stream=r_frame_rate', '-of', 'default=nw=1:nk=1', in_path],
            capture_output=True, text=True)
        fps = 25.0
        rate_tokens = (probe.stdout or '').split()
        if rate_tokens:
            num, _, den = rate_tokens[0].partition('/')
            try:
                parsed = float(num) / float(den or 1)
            except (ValueError, ZeroDivisionError):
                parsed = 0.0
            if math.isfinite(parsed) and parsed > 0:
                fps = parsed

        # Swap faces in each frame
        app = _get_face_app()
        swapper = _get_swapper()
        src_faces = app.get(src_img)
        if not src_faces:
            raise ValueError('No face detected in source image')
        src_face = src_faces[0]
        for fname in sorted(os.listdir(frames_dir)):
            fpath = os.path.join(frames_dir, fname)
            frame = cv2.imread(fpath)
            if frame is None:
                continue
            for tgt_face in app.get(frame):
                frame = swapper.get(frame, tgt_face, src_face, paste_back=True)
            cv2.imwrite(fpath, frame)

        # Reassemble video (copy original audio; "1:a?" makes it optional)
        subprocess.run(
            ['ffmpeg', '-y', '-framerate', str(fps), '-i', frame_pattern,
             '-i', in_path, '-map', '0:v', '-map', '1:a?',
             '-c:v', 'libx264', '-c:a', 'copy', '-shortest', out_path],
            capture_output=True, check=True)
        with open(out_path, 'rb') as f:
            return f.read()

    try:
        out_bytes = await asyncio.get_running_loop().run_in_executor(None, _pipeline)

        if global_file_path:
            fname = f'{uuid.uuid4().hex}_swapped.mp4'
            fpath = os.path.join(global_file_path, fname)
            os.makedirs(global_file_path, exist_ok=True)
            with open(fpath, 'wb') as f:
                f.write(out_bytes)
            host = http_request.headers.get('host', '127.0.0.1') if http_request else '127.0.0.1'
            # Strip a "host:port" suffix; the configured port is appended below.
            if ':' in host:
                parts = host.split(':')
                if len(parts) == 2 and parts[1].isdigit():
                    host = parts[0]
            proto = 'https' if getattr(global_args, 'https', False) else 'http'
            port = getattr(global_args, 'port', 8000) if global_args else 8000
            data = [{'url': f'{proto}://{host}:{port}/v1/files/{fname}'}]
        else:
            data = [{'b64_mp4': base64.b64encode(out_bytes).decode()}]

        return {'created': int(time.time()), 'data': data}

    except subprocess.CalledProcessError as e:
        # ffmpeg prints its banner first; the actual error is at the end.
        stderr_tail = (e.stderr or b'').decode(errors='replace')[-200:]
        raise HTTPException(status_code=500, detail=f'ffmpeg error: {stderr_tail}') from e
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f'Video face swap failed: {e}') from e
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
This diff is collapsed.
This diff is collapsed.
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
"""Simple in-process fixed-window rate limiter middleware.
Each distinct (client-IP, route-prefix) pair gets its own bucket.
Limits are configured via RateLimitConfig. The defaults below are
intentionally generous; tighten them through the config file or CLI.
Endpoints covered:
/v1/chat/completions — expensive LLM inference
/v1/images/ — image generation
/v1/audio/ — TTS / STT / audio generation
/v1/video/ — video generation
/v1/embeddings — embedding
/v1/completions — legacy completions
"""
import time
import threading
from collections import defaultdict
from typing import Dict, Tuple
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
# Per-route-prefix defaults: (max_requests, window_seconds)
# A request path is matched against these keys with str.startswith, in dict
# order, and the first matching prefix supplies the limit.
_DEFAULT_LIMITS: Dict[str, Tuple[int, int]] = {
    "/v1/chat/completions": (60, 60),
    "/v1/completions": (60, 60),
    "/v1/images/": (30, 60),
    "/v1/audio/": (60, 60),
    "/v1/video/": (10, 60),
    "/v1/embeddings": (120, 60),
}
# API prefixes that count against the request queue
# (requests under these prefixes are refused with 429 while the queue is full).
_QUEUED_PREFIXES = ("/v1/",)
# Global toggle — set to False to disable rate limiting entirely.
# Read on every request, so it can be flipped at runtime.
RATE_LIMITING_ENABLED = True
class _Bucket:
    """Fixed-window counter."""
    # count: requests seen in the current window.
    # window_start: clock reading (the middleware passes time.monotonic())
    # taken when the current window began.
    __slots__ = ("count", "window_start")

    def __init__(self, now: float):
        self.count = 0
        self.window_start = now
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Apply per-IP, per-route-prefix rate limiting to API endpoints.

    Each (client IP, route prefix) pair gets a fixed-window counter. Requests
    over the limit receive a 429 with ``X-RateLimit-*`` and ``Retry-After``
    headers; allowed requests get the ``X-RateLimit-*`` headers added to
    their response.

    NOTE(review): the client IP is taken from ``X-Forwarded-For`` whenever
    that header is present. Unless a trusted proxy overwrites it, clients can
    spoof it to dodge their limit. Confirm the deployment topology.
    """

    # Once the table holds this many buckets, expired ones are swept so that
    # a stream of distinct client addresses cannot grow it without bound.
    _MAX_BUCKETS = 10_000
    # Minimum seconds between two sweeps, so a table that is legitimately
    # large is not rescanned on every request.
    _PRUNE_INTERVAL = 1.0

    def __init__(self, app, limits: Optional[Dict[str, Tuple[int, int]]] = None):
        """
        Args:
            app: the ASGI application being wrapped.
            limits: mapping of route prefix to (max_requests, window_seconds);
                falls back to ``_DEFAULT_LIMITS`` when omitted or empty.
        """
        super().__init__(app)
        self._limits = limits or _DEFAULT_LIMITS
        # (client_ip, prefix) → _Bucket
        self._buckets: Dict[Tuple[str, str], _Bucket] = defaultdict(lambda: _Bucket(time.monotonic()))
        self._lock = threading.Lock()
        self._next_prune = 0.0

    def _get_prefix(self, path: str) -> str:
        """Return the first configured prefix *path* starts with, or ''."""
        for prefix in self._limits:
            if path.startswith(prefix):
                return prefix
        return ""

    def _prune_expired(self, now: float) -> None:
        """Drop buckets whose window has elapsed. Caller must hold ``_lock``."""
        if len(self._buckets) < self._MAX_BUCKETS or now < self._next_prune:
            return
        self._next_prune = now + self._PRUNE_INTERVAL
        expired = [
            key for key, bucket in self._buckets.items()
            if now - bucket.window_start >= self._limits[key[1]][1]
        ]
        for key in expired:
            del self._buckets[key]

    async def dispatch(self, request: Request, call_next):
        """Enforce the queue-size and per-client limits, then forward the request."""
        import math

        if not RATE_LIMITING_ENABLED:
            return await call_next(request)

        path = request.url.path

        # Queue-size enforcement for requests under the queued API prefixes
        if path.startswith(tuple(_QUEUED_PREFIXES)):
            from codai.queue.manager import queue_manager
            if await queue_manager.is_full():
                return JSONResponse(
                    status_code=429,
                    content={
                        "error": {
                            "message": "Server queue is full. Please retry later.",
                            "type": "rate_limit_error",
                            "code": 429,
                        }
                    },
                    headers={"Retry-After": "5"},
                )

        prefix = self._get_prefix(path)
        if not prefix:
            return await call_next(request)

        max_req, window = self._limits[prefix]
        client_ip = (
            request.headers.get("x-forwarded-for", "").split(",")[0].strip()
            or (request.client.host if request.client else "unknown")
        )
        key = (client_ip, prefix)
        now = time.monotonic()

        # Everything that reads or writes the bucket happens under the lock,
        # including the elapsed time used for the reset headers below.
        with self._lock:
            self._prune_expired(now)
            bucket = self._buckets[key]
            elapsed = max(0.0, now - bucket.window_start)
            if elapsed >= window:
                bucket.count = 0
                bucket.window_start = now
                elapsed = 0.0
            bucket.count += 1
            count = bucket.count

        window_left = window - elapsed
        remaining = max(0, max_req - count)
        reset_at = int(time.time() + window_left)

        if count > max_req:
            return JSONResponse(
                status_code=429,
                content={
                    "error": {
                        "message": "Rate limit exceeded. Please slow down.",
                        "type": "rate_limit_error",
                        "code": 429,
                    }
                },
                headers={
                    "X-RateLimit-Limit": str(max_req),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(reset_at),
                    # Time until the current window ends, not the full window.
                    "Retry-After": str(max(1, math.ceil(window_left))),
                },
            )

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(max_req)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(reset_at)
        return response
This diff is collapsed.
...@@ -23,8 +23,15 @@ import os ...@@ -23,8 +23,15 @@ import os
import tempfile import tempfile
from fastapi import APIRouter, HTTPException, UploadFile, File, Form from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import PlainTextResponse
from typing import Optional from typing import Optional
# Maximum upload size: 100 MB
_MAX_AUDIO_BYTES = 100 * 1024 * 1024
# Safe audio extensions (user-supplied extension is NOT trusted for the suffix)
_SAFE_EXTENSIONS = {'.wav', '.mp3', '.ogg', '.flac', '.m4a', '.webm', '.mp4'}
# Import from codai modules # Import from codai modules
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
...@@ -39,6 +46,71 @@ def set_global_args(args): ...@@ -39,6 +46,71 @@ def set_global_args(args):
global_args = args global_args = args
# =============================================================================
# Response formatting helpers
# =============================================================================
def _seconds_to_srt_time(s: float) -> str:
    """Format an offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds once, before being split into
    fields, so an input such as 59.9996 carries into the minute
    (``00:01:00,000``) rather than producing an invalid ``60,000`` seconds
    field. Negative offsets are clamped to zero.
    """
    total_ms = max(0, round(s * 1000))
    total_s, ms = divmod(total_ms, 1000)
    total_m, sec = divmod(total_s, 60)
    h, m = divmod(total_m, 60)
    return f"{h:02d}:{m:02d}:{sec:02d},{ms:03d}"
def _seconds_to_vtt_time(s: float) -> str:
    """Format an offset in seconds as a WebVTT timestamp ``HH:MM:SS.mmm``.

    The value is rounded to whole milliseconds once, before being split into
    fields, so an input such as 59.9996 carries into the minute
    (``00:01:00.000``) rather than producing an invalid ``60.000`` seconds
    field. Negative offsets are clamped to zero.
    """
    total_ms = max(0, round(s * 1000))
    total_s, ms = divmod(total_ms, 1000)
    total_m, sec = divmod(total_s, 60)
    h, m = divmod(total_m, 60)
    return f"{h:02d}:{m:02d}:{sec:02d}.{ms:03d}"
def _format_response(fmt: str, text: str, segments: list):
    """Format a transcription result according to the requested response_format.

    Args:
        fmt: 'text', 'srt', 'vtt', 'verbose_json' or 'json'
            (case-insensitive). ``None``, empty or unrecognised values fall
            back to 'json'.
        text: the full transcript.
        segments: dicts with optional 'start', 'end' (seconds) and 'text'
            keys; may be empty, in which case srt/vtt emit a single
            zero-length cue carrying *text*.

    Returns:
        A ``PlainTextResponse`` for text/srt/vtt, otherwise a plain dict.
    """
    fmt = (fmt or "json").lower()

    def _cue_text(seg) -> str:
        # A missing 'text' key is tolerated everywhere, as verbose_json does.
        return seg.get("text", "").strip()

    if fmt == "text":
        return PlainTextResponse(text)

    if fmt == "srt":
        if segments:
            cues = []
            for i, seg in enumerate(segments, 1):
                start = _seconds_to_srt_time(seg.get("start", 0))
                end = _seconds_to_srt_time(seg.get("end", 0))
                cues.append(f"{i}\n{start} --> {end}\n{_cue_text(seg)}\n")
            srt_body = "\n".join(cues)
        else:
            srt_body = f"1\n00:00:00,000 --> 00:00:00,000\n{text}\n"
        return PlainTextResponse(srt_body, media_type="text/plain")

    if fmt == "vtt":
        lines = ["WEBVTT\n"]
        for seg in segments:
            start = _seconds_to_vtt_time(seg.get("start", 0))
            end = _seconds_to_vtt_time(seg.get("end", 0))
            lines.append(f"{start} --> {end}\n{_cue_text(seg)}\n")
        if not segments:
            lines.append(f"00:00:00.000 --> 00:00:00.000\n{text}\n")
        return PlainTextResponse("\n".join(lines), media_type="text/vtt")

    if fmt == "verbose_json":
        return {
            "task": "transcribe",
            "language": "unknown",
            # The end of the last segment stands in for the audio duration.
            "duration": segments[-1].get("end", 0) if segments else 0,
            "text": text,
            "segments": [
                {
                    "id": i,
                    "start": seg.get("start", 0),
                    "end": seg.get("end", 0),
                    "text": _cue_text(seg),
                }
                for i, seg in enumerate(segments)
            ],
        }

    # Default: json
    return {"text": text}
# ============================================================================= # =============================================================================
# Router and Endpoints # Router and Endpoints
# ============================================================================= # =============================================================================
...@@ -58,17 +130,37 @@ async def create_transcription( ...@@ -58,17 +130,37 @@ async def create_transcription(
""" """
Audio transcription endpoint (OpenAI-compatible). Audio transcription endpoint (OpenAI-compatible).
""" """
# Check if whisper-server is available FIRST
if multi_model_manager.whisper_server and multi_model_manager.whisper_server.is_running():
file_content = await file.read() file_content = await file.read()
result = multi_model_manager.whisper_server.transcribe( if len(file_content) > _MAX_AUDIO_BYTES:
file_content, raise HTTPException(status_code=413, detail="Audio file too large (max 100 MB)")
language=language,
prompt=prompt # Check if the requested model is a whisper-server instance
) wsm = multi_model_manager.whisper_servers.get(model)
if wsm is None and multi_model_manager.whisper_server is not None:
# Legacy single-instance fallback: use it if no specific match
if not multi_model_manager.whisper_servers:
wsm = multi_model_manager.whisper_server
if wsm is not None:
ws_key = f"audio:{model}" if model in multi_model_manager.whisper_servers else "audio:whisper-server"
# Let the VRAM manager evict other models if needed
multi_model_manager.request_model(requested_model=model, model_type="audio")
# Start the subprocess if it isn't running (on-demand)
if not wsm.is_running():
wsm.start(getattr(wsm, '_model_path', None), gpu_device=getattr(wsm, '_gpu_device', 0))
if wsm.is_running():
multi_model_manager.models[ws_key] = wsm
multi_model_manager.active_in_vram = ws_key
multi_model_manager.models_in_vram.add(ws_key)
if wsm.is_running():
result = wsm.transcribe(file_content, language=language, prompt=prompt)
if "error" in result: if "error" in result:
raise HTTPException(status_code=500, detail=result["error"]) raise HTTPException(status_code=500, detail=result["error"])
return {"text": result.get("text", "")} return _format_response(response_format, result.get("text", ""), [])
# Fall through to Python backends if subprocess failed to start
# Use the manager to resolve the model and manage VRAM # Use the manager to resolve the model and manage VRAM
model_info = multi_model_manager.request_model( model_info = multi_model_manager.request_model(
...@@ -90,11 +182,13 @@ async def create_transcription( ...@@ -90,11 +182,13 @@ async def create_transcription(
detail="Audio transcription not configured. Use --audio-model or --whisper-server." detail="Audio transcription not configured. Use --audio-model or --whisper-server."
) )
# Read the uploaded file # Determine a safe file extension from the upload's content-type or filename,
file_content = await file.read() # never trusting the raw user-supplied value for arbitrary suffixes.
raw_ext = os.path.splitext(file.filename or '')[1].lower()
safe_ext = raw_ext if raw_ext in _SAFE_EXTENSIONS else '.wav'
# Save to temp file (needed for some backends) # Save to temp file (needed for some backends)
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp: with tempfile.NamedTemporaryFile(delete=False, suffix=safe_ext) as tmp:
tmp.write(file_content) tmp.write(file_content)
tmp_path = tmp.name tmp_path = tmp.name
...@@ -104,41 +198,27 @@ async def create_transcription( ...@@ -104,41 +198,27 @@ async def create_transcription(
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
if whisper_model is None: if whisper_model is None:
print(f"Loading faster-whisper model: {model_name}")
# Determine compute type - always use int8 for CPU
compute_type = "int8"
# Load the model
whisper_model = WhisperModel( whisper_model = WhisperModel(
model_name, model_name,
device="cpu", # Always use CPU - faster-whisper CUDA doesn't work with AMD device="cpu",
compute_type=compute_type, compute_type="int8",
) )
# Cache the model
multi_model_manager.add_model(model_key, whisper_model) multi_model_manager.add_model(model_key, whisper_model)
multi_model_manager.current_model_key = model_key multi_model_manager.current_model_key = model_key
print(f"Loaded faster-whisper model: {model_name}")
# Run transcription raw_segments, _ = whisper_model.transcribe(
segments, info = whisper_model.transcribe(
tmp_path, tmp_path,
language=language, language=language,
initial_prompt=prompt, initial_prompt=prompt,
temperature=temperature, temperature=temperature,
) )
# Materialise the generator so we have all segment data
# Collect all segments segments = [
text_parts = [] {"start": s.start, "end": s.end, "text": s.text}
for segment in segments: for s in raw_segments
text_parts.append(segment.text) ]
full_text = "".join(s["text"] for s in segments)
full_text = "".join(text_parts) return _format_response(response_format, full_text.strip(), segments)
return {
"text": full_text.strip()
}
except ImportError: except ImportError:
pass pass
...@@ -148,41 +228,26 @@ async def create_transcription( ...@@ -148,41 +228,26 @@ async def create_transcription(
import whispercpp import whispercpp
if whisper_model is None: if whisper_model is None:
print(f"Loading whispercpp model: {model_name}")
# Check if it's a built-in model name
if model_name in ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large']:
# It's a built-in model name
whisper_model = whispercpp.Whisper.from_pretrained(model_name)
else:
# It's a path to a GGUF file
whisper_model = whispercpp.Whisper.from_pretrained(model_name) whisper_model = whispercpp.Whisper.from_pretrained(model_name)
# Cache the model
multi_model_manager.add_model(model_key, whisper_model) multi_model_manager.add_model(model_key, whisper_model)
multi_model_manager.current_model_key = model_key multi_model_manager.current_model_key = model_key
print(f"Loaded whispercpp model: {model_name}")
# Run transcription
result = whisper_model.transcribe(tmp_path) result = whisper_model.transcribe(tmp_path)
# Extract text from result
text = "" text = ""
if hasattr(result, 'text'): if hasattr(result, 'text'):
text = result.text text = result.text
elif isinstance(result, dict): elif isinstance(result, dict):
text = result.get('text', '') text = result.get('text', '')
elif isinstance(result, list): elif isinstance(result, list):
# Some versions return a list of segments
for segment in result: for segment in result:
if hasattr(segment, 'text'): if hasattr(segment, 'text'):
text += segment.text text += segment.text
elif isinstance(segment, dict): elif isinstance(segment, dict):
text += segment.get('text', '') text += segment.get('text', '')
return { # whispercpp does not expose per-segment timestamps easily
"text": text.strip() return _format_response(response_format, text.strip(), [])
}
except ImportError as e: except ImportError as e:
raise HTTPException( raise HTTPException(
......
...@@ -263,7 +263,39 @@ def _apply_camera_motion(kw: dict, camera_motion: str): ...@@ -263,7 +263,39 @@ def _apply_camera_motion(kw: dict, camera_motion: str):
kw['camera_motion'] = camera_motion kw['camera_motion'] = camera_motion
def _apply_character_refs(kw: dict, character_references: List[str], strength: float): def _resolve_character_inputs(request) -> tuple[List[str], List[str]]:
"""Return (flat_image_list, name_list) from any combination of request fields."""
images: List[str] = []
names: List[str] = []
# 1. Expand named saved profiles
if request.character_profiles:
try:
from codai.api.characters import resolve_character_profiles
images += resolve_character_profiles(request.character_profiles)
names += list(request.character_profiles)
except Exception:
pass
# 2. Named character slots [{name, images:[...]}, ...]
if request.characters:
for slot in request.characters:
slot_imgs = slot.get('images') or []
images += slot_imgs
if slot.get('name'):
names.append(slot['name'])
# 3. Legacy flat list
if request.character_references:
images += list(request.character_references)
if request.character_names:
names += list(request.character_names)
return images, names
def _apply_character_refs(kw: dict, character_references: List[str], strength: float,
names: Optional[List[str]] = None):
"""Apply character reference images to pipeline kwargs.""" """Apply character reference images to pipeline kwargs."""
if not character_references: if not character_references:
return return
...@@ -291,8 +323,13 @@ def _generate_video(pipe, request: VideoGenerationRequest): ...@@ -291,8 +323,13 @@ def _generate_video(pipe, request: VideoGenerationRequest):
_apply_camera_motion(kw, request.camera_motion) _apply_camera_motion(kw, request.camera_motion)
if request.character_references: char_images, char_names = _resolve_character_inputs(request)
_apply_character_refs(kw, request.character_references, request.character_strength or 0.8) if char_images:
_apply_character_refs(kw, char_images, request.character_strength or 0.8, char_names)
# Prepend character names to prompt for better conditioning
if char_names and kw.get('prompt'):
names_hint = ', '.join(char_names)
kw['prompt'] = f"{names_hint}. {kw['prompt']}"
init_src = request.init_image or request.image init_src = request.init_image or request.image
...@@ -359,35 +396,49 @@ def _ffmpeg_upscale(path: str, factor: int, temps: list) -> str: ...@@ -359,35 +396,49 @@ def _ffmpeg_upscale(path: str, factor: int, temps: list) -> str:
scale = f"scale=iw*{factor}:ih*{factor}:flags=lanczos" scale = f"scale=iw*{factor}:ih*{factor}:flags=lanczos"
cmd = ['ffmpeg', '-y', '-i', path, '-vf', scale, '-c:a', 'copy', out] cmd = ['ffmpeg', '-y', '-i', path, '-vf', scale, '-c:a', 'copy', out]
r = subprocess.run(cmd, capture_output=True) r = subprocess.run(cmd, capture_output=True)
if r.returncode == 0: if r.returncode != 0:
return out import logging
logging.getLogger(__name__).warning(
"ffmpeg upscale failed (rc=%d): %s", r.returncode, r.stderr.decode(errors='replace')
)
return path # fallback to original if ffmpeg fails return path # fallback to original if ffmpeg fails
return out
def _rife_interpolate(path: str, multiplier: int, temps: list) -> str: def _rife_interpolate(path: str, multiplier: int, temps: list) -> str:
out = tempfile.mktemp(suffix='_rife.mp4') out = tempfile.mktemp(suffix='_rife.mp4')
temps.append(out) temps.append(out)
# Try rife-ncnn-vulkan binary if available import logging, shutil
import shutil _log = logging.getLogger(__name__)
if shutil.which('rife-ncnn-vulkan'): if shutil.which('rife-ncnn-vulkan'):
frames_dir = tempfile.mkdtemp() frames_dir = tempfile.mkdtemp()
out_dir = tempfile.mkdtemp() out_dir = tempfile.mkdtemp()
temps += [frames_dir, out_dir] temps += [frames_dir, out_dir]
subprocess.run(['ffmpeg', '-y', '-i', path, f'{frames_dir}/%08d.png'], r = subprocess.run(['ffmpeg', '-y', '-i', path, f'{frames_dir}/%08d.png'],
capture_output=True) capture_output=True)
subprocess.run(['rife-ncnn-vulkan', '-i', frames_dir, '-o', out_dir, if r.returncode != 0:
'-m', f'rife-v4'], capture_output=True) _log.warning("ffmpeg frame extraction failed: %s", r.stderr.decode(errors='replace'))
subprocess.run(['ffmpeg', '-y', '-r', str(multiplier * 8), '-i', else:
r = subprocess.run(['rife-ncnn-vulkan', '-i', frames_dir, '-o', out_dir,
'-m', 'rife-v4'], capture_output=True)
if r.returncode != 0:
_log.warning("rife-ncnn-vulkan failed: %s", r.stderr.decode(errors='replace'))
else:
r = subprocess.run(['ffmpeg', '-y', '-r', str(multiplier * 8), '-i',
f'{out_dir}/%08d.png', '-c:v', 'libx264', out], f'{out_dir}/%08d.png', '-c:v', 'libx264', out],
capture_output=True) capture_output=True)
if os.path.exists(out): if r.returncode != 0:
_log.warning("ffmpeg reassembly failed: %s", r.stderr.decode(errors='replace'))
elif os.path.exists(out):
return out return out
# Simple ffmpeg minterpolate fallback # Simple ffmpeg minterpolate fallback
fps_expr = f"fps=fps={multiplier}*source_fps"
cmd = ['ffmpeg', '-y', '-i', path, '-filter:v', cmd = ['ffmpeg', '-y', '-i', path, '-filter:v',
f'minterpolate=fps={multiplier * 8}', '-c:a', 'copy', out] f'minterpolate=fps={multiplier * 8}', '-c:a', 'copy', out]
r = subprocess.run(cmd, capture_output=True) r = subprocess.run(cmd, capture_output=True)
return out if r.returncode == 0 else path if r.returncode != 0:
_log.warning("ffmpeg minterpolate failed: %s", r.stderr.decode(errors='replace'))
return path
return out
def _add_audio_to_video(path: str, request: VideoGenerationRequest, def _add_audio_to_video(path: str, request: VideoGenerationRequest,
......
"""
Voice cloning endpoints.
POST /v1/audio/clone — synthesize speech in a cloned voice
GET /v1/audio/voices — list saved voice profiles
POST /v1/audio/voices — save a named voice profile (ref audio + transcript)
DELETE /v1/audio/voices/{name} — delete a voice profile
"""
import asyncio
import base64
import io
import json
import os
import tempfile
import time
from typing import Optional
from fastapi import APIRouter, HTTPException, Request, UploadFile, File, Form
from pydantic import BaseModel, ConfigDict
# Router the voice endpoints below are registered on.
router = APIRouter()
# Both stay None until set_global_args() / set_global_file_path() are called.
global_args = None
global_file_path = None
# Directory where voice profiles are stored
# (assigned by set_global_args(); _voices_dir() supplies a default until then).
_VOICES_DIR: Optional[str] = None
def set_global_args(args):
    """Store *args* and derive the directory voice profiles live in.

    Location rules:
      * ``args.file_path`` is an existing directory → ``<file_path>/voices``;
      * ``args.file_path`` is set but is not a directory →
        ``<dirname(file_path)>/voices``;
      * ``args.file_path`` is unset → ``~/.coderai/voices``.

    The default is computed from the fixed ``~/.coderai`` parent and does not
    depend on whether ``~/.coderai/voices`` already exists, so it is the same
    on every start. The directory is created if missing.
    """
    global global_args, _VOICES_DIR
    global_args = args
    file_path = getattr(args, 'file_path', None)
    if file_path:
        # Store voice profiles alongside the output files.
        parent = file_path if os.path.isdir(file_path) else os.path.dirname(file_path)
    else:
        parent = os.path.expanduser('~/.coderai')
    _VOICES_DIR = os.path.join(parent, 'voices')
    os.makedirs(_VOICES_DIR, exist_ok=True)
def set_global_file_path(path):
    """Store *path* in the module-level ``global_file_path``.

    When it is set, synthesized audio is written into that directory and
    returned as a URL; otherwise it is returned inline as base64.
    """
    global global_file_path
    global_file_path = path
def _voices_dir() -> str:
    """Return the voice-profile directory.

    Uses the configured ``_VOICES_DIR`` when there is one; otherwise falls
    back to ``~/.coderai/voices``, creating it if necessary.
    """
    if not _VOICES_DIR:
        fallback = os.path.expanduser('~/.coderai/voices')
        os.makedirs(fallback, exist_ok=True)
        return fallback
    return _VOICES_DIR
def _voice_path(name: str) -> str:
    """Return the directory path of the voice profile called *name*.

    The name is joined onto the voices directory as given; no validation is
    performed here.
    """
    voices_root = _voices_dir()
    return os.path.join(voices_root, name)
def _list_voices() -> list:
    """Return the metadata of every saved voice profile, oldest first.

    A profile is a sub-directory of the voices directory that contains a
    ``meta.json``. Directories without one, and ones whose metadata is
    unreadable or is not a JSON object, are skipped so that a single damaged
    profile cannot break the whole listing. Profiles lacking ``created_at``
    sort first.
    """
    voices = []
    # The context manager closes the directory handle promptly.
    with os.scandir(_voices_dir()) as entries:
        for entry in entries:
            if not entry.is_dir():
                continue
            meta_path = os.path.join(entry.path, 'meta.json')
            try:
                with open(meta_path) as f:
                    meta = json.load(f)
            except (OSError, ValueError):
                # Missing/unreadable file, or invalid JSON
                # (json.JSONDecodeError is a ValueError).
                continue
            if isinstance(meta, dict):
                voices.append(meta)
    return sorted(voices, key=lambda v: v.get('created_at', 0))
def _save_voice(name: str, audio_bytes: bytes, audio_ext: str, transcript: str, description: str = '') -> dict:
    """Persist a voice profile and return its metadata.

    Creates the profile directory if needed, writes the reference audio as
    ``ref<audio_ext>`` and then ``meta.json`` next to it. Existing files of
    the same name are overwritten.
    """
    profile_dir = _voice_path(name)
    os.makedirs(profile_dir, exist_ok=True)

    ref_path = os.path.join(profile_dir, f'ref{audio_ext}')
    with open(ref_path, 'wb') as audio_out:
        audio_out.write(audio_bytes)

    meta = dict(
        name=name,
        description=description,
        transcript=transcript,
        audio_file=ref_path,
        audio_ext=audio_ext,
        created_at=int(time.time()),
    )
    meta_path = os.path.join(profile_dir, 'meta.json')
    with open(meta_path, 'w') as meta_out:
        json.dump(meta, meta_out)
    return meta
def _load_voice(name: str) -> Optional[dict]:
    """Return the stored metadata of voice profile *name*, or None.

    None is returned when the profile has no ``meta.json`` and also when
    *name* could alias or escape the voices directory ('.', '..', path
    separators, NUL). Callers pass request-supplied names, and such a name
    can never denote a real profile.
    """
    if (
        not name
        or name == '.'
        or '..' in name
        or any(ch in name for ch in '/\\\0')
    ):
        return None
    meta_path = os.path.join(_voice_path(name), 'meta.json')
    try:
        with open(meta_path) as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return None
def _decode_audio(data: str) -> tuple[bytes, str]:
    """Decode base64 audio data, return (bytes, ext).

    *data* is either bare base64 (extension defaults to '.wav') or a data URI
    such as ``data:audio/wav;base64,<payload>``, in which case the extension
    is the MIME subtype with a leading dot. A data URI whose header carries
    no subtype also falls back to '.wav'.

    Raises:
        ValueError: if a data URI has no ',' separator or the payload is not
            valid base64 (``binascii.Error`` is a ``ValueError`` subclass).
    """
    if not data.startswith('data:'):
        return base64.b64decode(data), '.wav'

    header, sep, b64 = data.partition(',')
    if not sep:
        raise ValueError('Malformed data URI: missing "," separator')
    # "data:audio/wav;base64" -> "audio/wav" -> "wav". Only the segment
    # directly after the first '/' is used, so the result never contains a
    # path separator.
    mime_parts = header[len('data:'):].split(';', 1)[0].split('/')
    subtype = mime_parts[1] if len(mime_parts) > 1 else ''
    ext = f'.{subtype}' if subtype else '.wav'
    return base64.b64decode(b64), ext
def _f5tts_clone(ref_audio_path: str, ref_text: str, gen_text: str,
                 speed: float = 1.0, seed: Optional[int] = None) -> bytes:
    """Synthesize *gen_text* with F5-TTS in the voice of the reference clip.

    Args:
        ref_audio_path: path of the reference recording.
        ref_text: transcript of that recording.
        gen_text: text to speak.
        speed: forwarded to ``F5TTS.infer``.
        seed: forwarded to ``F5TTS.infer``.

    Returns:
        The generated audio encoded as WAV bytes.
    """
    from f5_tts.api import F5TTS
    import soundfile as sf
    import numpy as np

    # 'cuda' is only selected when global args have been set and torch
    # reports a CUDA device; otherwise F5TTS receives device=None.
    device = None
    if global_args:
        import torch
        device = 'cuda' if torch.cuda.is_available() else None

    def _quiet(x):
        # Handed to infer() as its show_info callback; discards its argument.
        return None

    def _passthrough(x, **kw):
        # Handed to infer() as its progress wrapper; returns the input as-is.
        return x

    engine = F5TTS(device=device)
    wav, sr, _ = engine.infer(
        ref_file=ref_audio_path,
        ref_text=ref_text,
        gen_text=gen_text,
        speed=speed,
        seed=seed,
        show_info=_quiet,
        progress=_passthrough,
    )

    out = io.BytesIO()
    sf.write(out, wav, sr, format='WAV')
    return out.getvalue()
def _save_audio_response(audio_bytes: bytes, http_request: Request) -> dict:
    """Wrap WAV bytes in the response shape the API returns.

    Without a configured ``global_file_path`` the audio is returned inline as
    ``{"b64_wav": ...}``. Otherwise it is written there under a random name
    and ``{"url": ...}`` is returned, built from the request's Host header
    (minus any port) and the configured scheme and port.
    """
    import uuid
    filename = f"{uuid.uuid4().hex}.wav"

    if not global_file_path:
        return {"b64_wav": base64.b64encode(audio_bytes).decode()}

    os.makedirs(global_file_path, exist_ok=True)
    target = os.path.join(global_file_path, filename)
    with open(target, 'wb') as out:
        out.write(audio_bytes)

    host = '127.0.0.1'
    if http_request:
        host = http_request.headers.get('host', '127.0.0.1')
    # Strip a "host:port" suffix; values with more than one ':' are left alone.
    if host.count(':') == 1:
        hostname, port_text = host.split(':')
        if port_text.isdigit():
            host = hostname

    if global_args:
        use_https = getattr(global_args, 'https', False)
        port = getattr(global_args, 'port', 8000)
    else:
        use_https = False
        port = 8000
    proto = 'https' if use_https else 'http'
    return {"url": f"{proto}://{host}:{port}/v1/files/{filename}"}
# ---------------------------------------------------------------------------
# Voice profile management
# ---------------------------------------------------------------------------
@router.get("/v1/audio/voices")
async def list_voices():
    """Return the metadata of every saved voice profile."""
    voices = _list_voices()
    return {"voices": voices}
@router.post("/v1/audio/voices")
async def create_voice(
    name: str = Form(...),
    transcript: str = Form(...),
    description: str = Form(''),
    audio: UploadFile = File(...),
):
    """Save a named voice profile from a reference audio file + transcript.

    Raises:
        HTTPException: 400 if the name is not alphanumeric (hyphens and
            underscores allowed) or the upload is not readable audio; 503 if
            the ``soundfile`` package needed for validation is missing.
    """
    if not name.replace('-', '').replace('_', '').isalnum():
        raise HTTPException(status_code=400, detail="Voice name must be alphanumeric (hyphens/underscores allowed)")

    audio_bytes = await audio.read()
    # An upload may arrive without a filename; treat that as "no extension".
    ext = os.path.splitext(audio.filename or '')[1] or '.wav'

    # A missing dependency is a server problem, not a bad upload, so it is
    # reported separately from the validation failure below.
    try:
        import soundfile as sf
    except ImportError as e:
        raise HTTPException(status_code=503, detail=f"Audio validation unavailable: {e}") from e

    # Validate audio is readable
    try:
        sf.info(io.BytesIO(audio_bytes))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid audio file: {e}") from e

    meta = _save_voice(name, audio_bytes, ext, transcript, description)
    return {"created": True, "voice": meta}
@router.delete("/v1/audio/voices/{name}")
async def delete_voice(name: str):
    """Delete a saved voice profile.

    Args:
        name: Profile name taken from the URL path.

    Returns:
        ``{"deleted": True, "name": name}`` once the profile directory is removed.

    Raises:
        HTTPException: 400 if the name is not a valid profile name, 404 if no
            such profile exists.
    """
    import shutil

    # The name comes straight from the URL and ends up in shutil.rmtree, so
    # apply the same rule create_voice enforces before touching the
    # filesystem. This keeps values such as ".." from ever being resolved.
    # NOTE(review): _voice_path is defined elsewhere; whether it sanitizes its
    # argument is not visible here, so this check does not rely on it.
    if not name.replace('-', '').replace('_', '').isalnum():
        raise HTTPException(status_code=400, detail="Voice name must be alphanumeric (hyphens/underscores allowed)")

    vdir = _voice_path(name)
    if not os.path.exists(vdir):
        raise HTTPException(status_code=404, detail=f"Voice '{name}' not found")
    shutil.rmtree(vdir)
    return {"deleted": True, "name": name}
# ---------------------------------------------------------------------------
# Voice cloning TTS
# ---------------------------------------------------------------------------
# Request body for POST /v1/audio/clone.
# The reference voice comes either from a saved profile (voice_name) or from
# inline audio (ref_audio together with ref_text, its transcript).
class VoiceCloneRequest(BaseModel):
    text: str  # text to synthesize
    voice_name: Optional[str] = None  # use a saved voice profile; checked before ref_audio
    ref_audio: Optional[str] = None  # base64 reference audio (if not using saved voice)
    # Transcript of ref_audio. With voice_name, clone_voice falls back to the
    # profile's stored transcript when this is empty.
    ref_text: Optional[str] = None
    speed: Optional[float] = 1.0  # clone_voice substitutes 1.0 when this is None or 0
    seed: Optional[int] = None  # forwarded unchanged to the synthesis call
    # NOTE(review): clone_voice, as visible in this file section, never reads
    # this field; the response shape is chosen by _save_audio_response.
    response_format: Optional[str] = "url"
    # extra="allow": fields not declared above are kept rather than rejected.
    model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/clone")
async def clone_voice(request: VoiceCloneRequest, http_request: Request = None):
    """
    Synthesize speech in a cloned voice using F5-TTS.

    Provide either:
    - voice_name: name of a saved voice profile
    - ref_audio (base64) + ref_text: inline reference audio

    Returns:
        ``{"created": <unix timestamp>, "data": [<item from _save_audio_response>]}``.

    Raises:
        HTTPException: 404 if the named voice does not exist; 400 if no
            reference is given or no transcript is available; 501 if f5-tts is
            not installed; 500 if synthesis fails for any other reason.
    """
    # Resolve reference audio
    ref_audio_path = None
    ref_text = request.ref_text or ''
    temps = []  # temp files to delete once the request finishes
    try:
        if request.voice_name:
            meta = _load_voice(request.voice_name)
            if not meta:
                raise HTTPException(status_code=404, detail=f"Voice '{request.voice_name}' not found")
            ref_audio_path = meta['audio_file']
            # An explicit ref_text wins over the transcript stored with the profile.
            ref_text = ref_text or meta.get('transcript', '')
        elif request.ref_audio:
            audio_bytes, ext = _decode_audio(request.ref_audio)
            with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
                # Register for cleanup before writing, so a failed write does
                # not leave the file behind.
                temps.append(tmp.name)
                tmp.write(audio_bytes)
            ref_audio_path = tmp.name
        else:
            raise HTTPException(status_code=400, detail="Provide voice_name or ref_audio")

        if not ref_text:
            raise HTTPException(status_code=400, detail="ref_text (transcript of reference audio) is required for voice cloning")

        try:
            # Synthesis is blocking; run it in the default executor so the
            # event loop stays responsive. get_running_loop() is the supported
            # way to obtain the loop from inside a coroutine.
            audio_bytes = await asyncio.get_running_loop().run_in_executor(
                None, _f5tts_clone,
                ref_audio_path, ref_text, request.text,
                request.speed or 1.0, request.seed,
            )
        except ImportError as e:
            raise HTTPException(status_code=501, detail="f5-tts not installed. Run: pip install f5-tts") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Voice cloning failed: {e}") from e

        result = _save_audio_response(audio_bytes, http_request)
        return {"created": int(time.time()), "data": [result]}
    finally:
        for t in temps:
            try:
                os.unlink(t)
            except OSError:
                # Best-effort cleanup: a temp file that is already gone or
                # cannot be removed must not mask the real response or error.
                pass
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# codai.openai — optional LiteLLM integration layer
This diff is collapsed.
...@@ -57,9 +57,14 @@ class VideoGenerationRequest(BaseModel): ...@@ -57,9 +57,14 @@ class VideoGenerationRequest(BaseModel):
camera_motion: Optional[str] = None # zoom-in | zoom-out | pan-left | pan-right | tilt-up | tilt-down | rotate camera_motion: Optional[str] = None # zoom-in | zoom-out | pan-left | pan-right | tilt-up | tilt-down | rotate
# ── Character consistency ───────────────────────────────────────────── # ── Character consistency ─────────────────────────────────────────────
character_references: Optional[List[str]] = None # list of base64/URL reference images # Each entry: {"name": "Alice", "images": ["b64...", ...]}
characters: Optional[List[dict]] = None
# Legacy flat list of base64/URL reference images (still accepted)
character_references: Optional[List[str]] = None
character_strength: Optional[float] = 0.8 character_strength: Optional[float] = 0.8
character_names: Optional[List[str]] = None # optional names per reference character_names: Optional[List[str]] = None # optional names per reference
# Named saved profiles to load (resolved server-side)
character_profiles: Optional[List[str]] = None
# ── Audio generation / manipulation ────────────────────────────────── # ── Audio generation / manipulation ──────────────────────────────────
add_audio: Optional[bool] = False add_audio: Optional[bool] = False
......
...@@ -33,6 +33,12 @@ class QueueManager: ...@@ -33,6 +33,12 @@ class QueueManager:
self.model_loading: bool = False self.model_loading: bool = False
self.model_name: Optional[str] = None self.model_name: Optional[str] = None
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.max_size: int = 6
async def is_full(self) -> bool:
    """Report whether the number of waiting requests has reached ``max_size``."""
    async with self.lock:
        queued = len(self.waiting_requests)
        return queued >= self.max_size
async def add_waiting(self, request_id: str) -> None: async def add_waiting(self, request_id: str) -> None:
"""Add a request to the waiting queue.""" """Add a request to the waiting queue."""
......
videogen @ 04778e17
Subproject commit 04778e172a9a83d0778f566045f995828c6c3556
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment