Add video and pipelines

parent 74e32470
...@@ -139,7 +139,20 @@ if [ "$BACKEND" = "nvidia" ]; then ...@@ -139,7 +139,20 @@ if [ "$BACKEND" = "nvidia" ]; then
pip install -r requirements-nvidia.txt || { pip install -r requirements-nvidia.txt || {
echo -e "${YELLOW}Warning: Some NVIDIA packages failed to install${NC}" echo -e "${YELLOW}Warning: Some NVIDIA packages failed to install${NC}"
} }
# Extended modality dependencies (video/audio/embeddings/upscaling)
echo -e "${YELLOW}Installing extended modality dependencies...${NC}"
pip install "imageio[ffmpeg]" scipy soundfile sentence-transformers \
openai-whisper argostranslate edge-tts kokoro-tts timm || {
echo -e "${YELLOW}Warning: Some optional modality packages failed${NC}"
}
pip install realesrgan basicsr || {
echo -e "${YELLOW}Warning: realesrgan/basicsr failed (image upscaling optional)${NC}"
}
pip install audiocraft 2>/dev/null || {
echo -e "${YELLOW}Note: audiocraft not installed (audio generation with MusicGen optional)${NC}"
}
# Install Flash Attention 2 if requested # Install Flash Attention 2 if requested
if [ "$FLASH" = true ]; then if [ "$FLASH" = true ]; then
echo "" echo ""
...@@ -510,7 +523,22 @@ elif [ "$BACKEND" = "all" ]; then ...@@ -510,7 +523,22 @@ elif [ "$BACKEND" = "all" ]; then
pip install -r requirements-nvidia.txt || { pip install -r requirements-nvidia.txt || {
echo -e "${YELLOW}Warning: Some NVIDIA packages failed to install${NC}" echo -e "${YELLOW}Warning: Some NVIDIA packages failed to install${NC}"
} }
# Extended modality dependencies (video/audio/embeddings/upscaling)
echo -e "${YELLOW}Installing extended modality dependencies...${NC}"
pip install "imageio[ffmpeg]" scipy soundfile sentence-transformers \
openai-whisper argostranslate edge-tts kokoro-tts timm || {
echo -e "${YELLOW}Warning: Some optional modality packages failed${NC}"
}
pip install realesrgan basicsr || {
echo -e "${YELLOW}Warning: realesrgan/basicsr failed (image upscaling optional)${NC}"
}
# audiocraft (MusicGen/AudioGen) — Meta package, may fail on some Python versions
pip install audiocraft 2>/dev/null || {
echo -e "${YELLOW}Note: audiocraft not installed (audio generation with MusicGen optional)${NC}"
echo -e "${YELLOW} Install manually: pip install audiocraft${NC}"
}
# Check for Vulkan development libraries # Check for Vulkan development libraries
VULKAN_AVAILABLE=false VULKAN_AVAILABLE=false
if pkg-config --exists vulkan 2>/dev/null; then if pkg-config --exists vulkan 2>/dev/null; then
......
...@@ -252,7 +252,8 @@ async def api_status(username: str = Depends(require_auth)): ...@@ -252,7 +252,8 @@ async def api_status(username: str = Depends(require_auth)):
try: try:
if config_manager: if config_manager:
md = config_manager.models_data md = config_manager.models_data
for cat in ("text_models", "image_models", "audio_models", "vision_models", "tts_models"): for cat in ("text_models", "image_models", "audio_models", "vision_models", "tts_models",
"video_models", "audio_gen_models", "embedding_models"):
for m in md.get(cat, []): for m in md.get(cat, []):
mid = (m.get("path") or m.get("id") or m) if isinstance(m, dict) else m mid = (m.get("path") or m.get("id") or m) if isinstance(m, dict) else m
if mid and mid not in enabled_models: if mid and mid not in enabled_models:
...@@ -712,7 +713,8 @@ def _scan_caches() -> dict: ...@@ -712,7 +713,8 @@ def _scan_caches() -> dict:
if config_manager: if config_manager:
md = config_manager.models_data md = config_manager.models_data
for cat in ("text_models", "image_models", "audio_models", for cat in ("text_models", "image_models", "audio_models",
"gguf_models", "tts_models", "vision_models"): "gguf_models", "tts_models", "vision_models", "video_models",
"audio_gen_models", "embedding_models"):
for m in md.get(cat, []): for m in md.get(cat, []):
if isinstance(m, str): if isinstance(m, str):
p = m p = m
...@@ -1011,7 +1013,8 @@ async def api_model_enable(request: Request, username: str = Depends(require_adm ...@@ -1011,7 +1013,8 @@ async def api_model_enable(request: Request, username: str = Depends(require_adm
data = await request.json() data = await request.json()
path = data.get("path") or data.get("model_id", "") path = data.get("path") or data.get("model_id", "")
model_type = data.get("model_type", "text_models") model_type = data.get("model_type", "text_models")
valid = {"text_models", "image_models", "audio_models", "gguf_models", "tts_models", "vision_models"} valid = {"text_models", "image_models", "audio_models", "gguf_models", "tts_models", "vision_models",
"video_models", "audio_gen_models", "embedding_models"}
if model_type not in valid: if model_type not in valid:
raise HTTPException(status_code=400, detail=f"model_type must be one of {valid}") raise HTTPException(status_code=400, detail=f"model_type must be one of {valid}")
lst = config_manager.models_data.setdefault(model_type, []) lst = config_manager.models_data.setdefault(model_type, [])
...@@ -1030,7 +1033,8 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad ...@@ -1030,7 +1033,8 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad
path = data.get("path") or data.get("model_id", "") path = data.get("path") or data.get("model_id", "")
changed = False changed = False
for cat in ("text_models", "image_models", "audio_models", for cat in ("text_models", "image_models", "audio_models",
"gguf_models", "tts_models", "vision_models"): "gguf_models", "tts_models", "vision_models", "video_models",
"audio_gen_models", "embedding_models"):
lst = config_manager.models_data.get(cat, []) lst = config_manager.models_data.get(cat, [])
new_lst = [m for m in lst new_lst = [m for m in lst
if (m if isinstance(m, str) else m.get("path", m.get("id", ""))) != path] if (m if isinstance(m, str) else m.get("path", m.get("id", ""))) != path]
...@@ -1063,7 +1067,10 @@ async def api_model_load(request: Request, username: str = Depends(require_admin ...@@ -1063,7 +1067,10 @@ async def api_model_load(request: Request, username: str = Depends(require_admin
if config_manager: if config_manager:
md = config_manager.models_data md = config_manager.models_data
for cat, mtype in (("image_models", "image"), ("audio_models", "audio"), for cat, mtype in (("image_models", "image"), ("audio_models", "audio"),
("vision_models", "vision"), ("tts_models", "tts")): ("vision_models", "vision"), ("tts_models", "tts"),
("video_models", "video"),
("audio_gen_models", "audio_gen"),
("embedding_models", "embedding")):
for m in md.get(cat, []): for m in md.get(cat, []):
mid = m if isinstance(m, str) else m.get("path") or m.get("id") or "" mid = m if isinstance(m, str) else m.get("path") or m.get("id") or ""
if mid == path: if mid == path:
...@@ -1158,7 +1165,8 @@ async def api_model_configure(request: Request, username: str = Depends(require_ ...@@ -1158,7 +1165,8 @@ async def api_model_configure(request: Request, username: str = Depends(require_
# Treat legacy gguf_models as text_models (GGUF is a format, not a type) # Treat legacy gguf_models as text_models (GGUF is a format, not a type)
if model_type == "gguf_models": if model_type == "gguf_models":
model_type = "text_models" model_type = "text_models"
valid = {"text_models", "image_models", "audio_models", "tts_models", "vision_models"} valid = {"text_models", "image_models", "audio_models", "tts_models", "vision_models", "video_models",
"audio_gen_models", "embedding_models"}
if not path: if not path:
raise HTTPException(status_code=400, detail="path is required") raise HTTPException(status_code=400, detail="path is required")
if model_type not in valid: if model_type not in valid:
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
</a> </a>
<div class="nav-links"> <div class="nav-links">
<a href="/admin" class="nav-link {% if request.url.path == '/admin' %}active{% endif %}">Overview</a> <a href="/admin" class="nav-link {% if request.url.path == '/admin' %}active{% endif %}">Overview</a>
<a href="/chat" class="nav-link {% if request.url.path == '/chat' %}active{% endif %}">Chat</a> <a href="/chat" class="nav-link {% if request.url.path == '/chat' %}active{% endif %}">Studio</a>
{% if is_admin|default(false) %} {% if is_admin|default(false) %}
<a href="/admin/models" class="nav-link {% if '/models' in request.url.path %}active{% endif %}">Models</a> <a href="/admin/models" class="nav-link {% if '/models' in request.url.path %}active{% endif %}">Models</a>
<a href="/admin/tokens" class="nav-link {% if '/tokens' in request.url.path %}active{% endif %}">Tokens</a> <a href="/admin/tokens" class="nav-link {% if '/tokens' in request.url.path %}active{% endif %}">Tokens</a>
......
{% extends "base.html" %} {% extends "base.html" %}
{% block title %}Chat — CoderAI{% endblock %} {% block title %}Studio — CoderAI{% endblock %}
{% block wrapper_class %}{% endblock %} {% block wrapper_class %}{% endblock %}
{% block head %}
<style>
/* ── Layout ─────────────────────────────────────────────────────── */
.studio { display:flex; height:calc(100vh - 56px); overflow:hidden; }
/* Sidebar */
.sidebar {
width:220px; min-width:180px; background:var(--surface-1);
border-right:1px solid var(--border); display:flex; flex-direction:column;
overflow:hidden; flex-shrink:0;
}
.sidebar-hd { padding:.6rem 1rem .15rem; font-size:10px; font-weight:700;
color:var(--text-3); letter-spacing:.07em; text-transform:uppercase; }
.model-list { flex:1; overflow-y:auto; padding:.2rem .4rem .5rem; }
.model-item {
display:flex; align-items:center; gap:.4rem; padding:.4rem .55rem;
border-radius:6px; cursor:pointer; font-size:12px; color:var(--text-2);
transition:background .1s;
}
.model-item:hover { background:var(--surface-2); }
.model-item.active { background:var(--accent-dim,#1d3354); color:var(--accent,#4e9cf5); font-weight:500; }
.mbadge {
font-size:9px; font-weight:700; padding:1px 5px; border-radius:20px;
letter-spacing:.03em; text-transform:uppercase; flex-shrink:0;
}
.mb-text { background:#1d3250; color:#7aaef7; }
.mb-vision { background:#1a2e3a; color:#5ed3f5; }
.mb-image { background:#1d3520; color:#6ecf7e; }
.mb-video { background:#301a40; color:#c07af5; }
.mb-audio { background:#3a2510; color:#f0a844; }
.mb-tts { background:#2d2010; color:#f0c060; }
.mb-audiogen{ background:#1a2535; color:#70b8f5; }
.mb-embed { background:#1e2e1e; color:#88c888; }
/* Main */
.studio-main { flex:1; display:flex; flex-direction:column; overflow:hidden; }
/* Two-level tab bar */
.tabbar1 {
display:flex; gap:.2rem; padding:.45rem .6rem .3rem;
border-bottom:1px solid var(--border); background:var(--surface-0); flex-shrink:0;
overflow-x:auto;
}
.tabbar2 {
display:none; gap:.15rem; padding:.3rem .6rem;
border-bottom:1px solid var(--border); background:var(--surface-1); flex-shrink:0;
overflow-x:auto;
}
.tabbar2.visible { display:flex; }
.t1btn, .t2btn {
padding:.28rem .65rem; border-radius:5px; font-size:12px; font-weight:500;
cursor:pointer; border:1px solid transparent; color:var(--text-3);
background:transparent; transition:all .1s; white-space:nowrap; flex-shrink:0;
}
.t1btn:hover, .t2btn:hover { background:var(--surface-2); color:var(--text-1); }
.t1btn.active { background:var(--accent,#4e9cf5); color:#fff; }
.t2btn.active { background:var(--surface-3,#333); color:var(--text-1); border-color:var(--border); }
.t1btn.hidden, .t2btn.hidden { display:none; }
/* Panels */
.panel { flex:1; display:none; flex-direction:column; overflow:hidden; }
.panel.active { display:flex; }
/* ── Chat ─────────────────────────────────────────────────────── */
.chat-msgs { flex:1; overflow-y:auto; padding:1rem 1.25rem; display:flex; flex-direction:column; gap:.75rem; }
.chat-empty { margin:auto; text-align:center; color:var(--text-3); }
.chat-empty h3 { font-size:1rem; margin-bottom:.25rem; }
.msg { display:flex; gap:.75rem; }
.msg.user { flex-direction:row-reverse; }
.av { width:28px; height:28px; border-radius:50%; display:flex; align-items:center;
justify-content:center; font-size:9px; font-weight:700; flex-shrink:0; }
.av.user { background:var(--accent,#4e9cf5); color:#fff; }
.av.ai { background:var(--surface-3,#2a2a2a); color:var(--text-2); }
.msg-body { max-width:70%; }
.msg.user .msg-body { text-align:right; }
.msg-meta { font-size:10px; color:var(--text-3); margin-bottom:.2rem; }
.msg-text { background:var(--surface-2); padding:.5rem .75rem; border-radius:8px;
font-size:13px; line-height:1.55; white-space:pre-wrap; word-break:break-word; }
.msg.user .msg-text { background:var(--accent-dim,#1d3354); }
.msg-img { max-width:280px; border-radius:8px; margin-top:.3rem; cursor:pointer; }
.chat-foot { flex-shrink:0; padding:.5rem .75rem .75rem; border-top:1px solid var(--border); }
.attach-bar { display:flex; align-items:center; gap:.5rem; margin-bottom:.3rem; }
.attach-thumb { width:36px; height:36px; border-radius:4px; object-fit:cover; }
.chat-row { display:flex; gap:.5rem; align-items:flex-end; }
.chat-ta { flex:1; resize:none; border-radius:6px; border:1px solid var(--border);
background:var(--surface-2); color:var(--text-1); padding:.45rem .7rem;
font-size:13px; font-family:inherit; outline:none; min-height:36px; max-height:140px; }
.chat-hint { font-size:10px; color:var(--text-3); margin-top:.2rem; text-align:right; }
/* ── Shared panel helpers ─────────────────────────────────────── */
.gen-wrap { flex:1; display:flex; overflow:hidden; }
.gen-ctrl { width:290px; padding:.9rem 1rem; overflow-y:auto; border-right:1px solid var(--border);
background:var(--surface-1); display:flex; flex-direction:column; gap:.65rem; flex-shrink:0; }
.gen-out { flex:1; display:flex; align-items:center; justify-content:center;
overflow:auto; padding:1.2rem; background:var(--surface-0); }
.gen-out-inner { display:flex; flex-direction:column; align-items:center; gap:.6rem; width:100%; max-width:960px; }
.gen-empty { color:var(--text-3); text-align:center; font-size:13px; }
.out-img { max-width:100%; max-height:calc(100vh - 200px); border-radius:8px; cursor:pointer; }
.out-video { max-width:100%; max-height:calc(100vh - 200px); border-radius:8px; }
.out-audio { width:100%; }
.fl { font-size:11px; font-weight:600; color:var(--text-2); margin-bottom:.15rem; display:block; }
.progress { font-size:11px; color:var(--text-3); min-height:14px; }
.fi, .fs, .fselect { background:var(--surface-2); border:1px solid var(--border); border-radius:5px;
color:var(--text-1); padding:.38rem .6rem; font-size:13px; font-family:inherit;
outline:none; width:100%; box-sizing:border-box; }
.fs { resize:vertical; min-height:70px; }
.fselect { appearance:auto; }
.g2 { display:grid; grid-template-columns:1fr 1fr; gap:.5rem; }
.g3 { display:grid; grid-template-columns:1fr 1fr 1fr; gap:.5rem; }
.frow { display:flex; flex-direction:column; gap:.15rem; }
.char-refs { display:flex; flex-wrap:wrap; gap:.3rem; margin-top:.25rem; }
.char-thumb { width:44px; height:44px; object-fit:cover; border-radius:4px; }
/* download link styled as button */
a.dl { display:inline-block; margin-top:.4rem; }
/* ── Pipeline ─────────────────────────────────────────────────── */
.pipe-panel { flex:1; overflow-y:auto; padding:1.2rem; display:flex; flex-direction:column; gap:1.2rem; }
.pipe-card { background:var(--surface-1); border:1px solid var(--border); border-radius:8px; padding:1rem; }
.pipe-title { font-weight:600; font-size:13px; margin-bottom:.75rem; }
.pipe-steps { display:flex; align-items:center; gap:.5rem; flex-wrap:wrap; font-size:12px; color:var(--text-3); }
.pipe-step { background:var(--surface-2); border-radius:5px; padding:.25rem .6rem; color:var(--text-2); }
.pipe-arrow { color:var(--text-3); }
</style>
{% endblock %}
{% block content %} {% block content %}
<div class="chat-wrap" style="margin:0 1.5rem 1rem;border-radius:8px"> <div class="studio">
<div class="chat-bar">
<h2>Chat</h2> <!-- Sidebar -->
<div class="chat-controls"> <aside class="sidebar">
<select id="model-sel" class="form-input" style="font-size:13px;padding:.3rem .625rem;min-width:200px"> <div class="sidebar-hd">Models</div>
<option value="">Select model…</option> <div class="model-list" id="model-list"><div class="muted small" style="padding:.5rem .6rem">Loading…</div></div>
</select> </aside>
<button class="btn btn-ghost btn-sm" onclick="newChat()">Clear</button>
<div class="studio-main">
<!-- Level-1 tabs (category) -->
<div class="tabbar1" id="tabbar1">
<button class="t1btn active" data-cat="chat" onclick="selectCat('chat')">Chat</button>
<button class="t1btn hidden" data-cat="image" onclick="selectCat('image')">Image</button>
<button class="t1btn hidden" data-cat="video" onclick="selectCat('video')">Video</button>
<button class="t1btn hidden" data-cat="audio" onclick="selectCat('audio')">Audio</button>
<button class="t1btn hidden" data-cat="pipe" onclick="selectCat('pipe')">Pipelines</button>
<button class="t1btn hidden" data-cat="embed" onclick="selectCat('embed')">Embed</button>
</div> </div>
</div>
<div class="chat-messages" id="chat-msgs"> <!-- Level-2 tabs (sub-mode) -->
<div class="chat-empty"> <div class="tabbar2" id="tabbar2">
<h3>CoderAI Chat</h3> <!-- Image sub-tabs -->
<p>Select a model and start typing</p> <button class="t2btn hidden" data-sub="img-gen" onclick="selectSub('img-gen')">Generate</button>
<button class="t2btn hidden" data-sub="img-edit" onclick="selectSub('img-edit')">Edit (i2i)</button>
<button class="t2btn hidden" data-sub="img-inpaint" onclick="selectSub('img-inpaint')">Inpaint</button>
<button class="t2btn hidden" data-sub="img-upscale" onclick="selectSub('img-upscale')">Upscale</button>
<button class="t2btn hidden" data-sub="img-depth" onclick="selectSub('img-depth')">Depth</button>
<button class="t2btn hidden" data-sub="img-seg" onclick="selectSub('img-seg')">Segment</button>
<!-- Video sub-tabs -->
<button class="t2btn hidden" data-sub="vid-t2v" onclick="selectSub('vid-t2v')">Text→Video</button>
<button class="t2btn hidden" data-sub="vid-i2v" onclick="selectSub('vid-i2v')">Image→Video</button>
<button class="t2btn hidden" data-sub="vid-v2v" onclick="selectSub('vid-v2v')">Vid→Vid</button>
<button class="t2btn hidden" data-sub="vid-ti2v" onclick="selectSub('vid-ti2v')">Ti2V</button>
<button class="t2btn hidden" data-sub="vid-interp" onclick="selectSub('vid-interp')">Interpolate</button>
<button class="t2btn hidden" data-sub="vid-sub" onclick="selectSub('vid-sub')">Subtitles</button>
<button class="t2btn hidden" data-sub="vid-dub" onclick="selectSub('vid-dub')">Dub</button>
<button class="t2btn hidden" data-sub="vid-up" onclick="selectSub('vid-up')">Upscale</button>
<!-- Audio sub-tabs -->
<button class="t2btn hidden" data-sub="aud-gen" onclick="selectSub('aud-gen')">Generate Music/SFX</button>
<button class="t2btn hidden" data-sub="aud-tts" onclick="selectSub('aud-tts')">TTS</button>
<button class="t2btn hidden" data-sub="aud-stt" onclick="selectSub('aud-stt')">Transcribe</button>
</div> </div>
</div>
<div class="chat-foot"> <!-- ═══════════════ CHAT ═══════════════ -->
<div id="typing" style="font-size:11px;color:var(--text-3);height:14px;margin-bottom:.3rem;font-family:var(--mono)"></div> <div class="panel active" id="panel-chat">
<div class="chat-input-row"> <div class="chat-msgs" id="chat-msgs">
<textarea id="chat-in" class="chat-textarea" placeholder="Send a message…" rows="1"></textarea> <div class="chat-empty"><h3>CoderAI Studio</h3><p>Select a model from the sidebar</p></div>
<button class="btn btn-primary" id="send-btn" onclick="send()" style="padding:.5rem .75rem;align-self:flex-end"> </div>
<svg viewBox="0 0 16 16" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="width:13px;height:13px"><line x1="14" y1="2" x2="7" y2="9"/><polygon points="14 2 10 14 7 9 2 6 14 2"/></svg> <div class="chat-foot">
</button> <div class="attach-bar" id="attach-bar" style="display:none">
<img id="attach-thumb" class="attach-thumb" src="" alt="">
<span style="font-size:11px;color:var(--text-3)" id="attach-name"></span>
<button class="btn btn-ghost btn-sm" onclick="clearAttach()" style="margin-left:auto">&times;</button>
</div>
<div id="typing" style="font-size:11px;color:var(--text-3);height:14px;margin-bottom:.2rem;font-family:var(--mono)"></div>
<div class="chat-row">
<label id="attach-btn" style="display:none;cursor:pointer;align-self:flex-end;padding:.35rem" title="Attach image">
<svg viewBox="0 0 16 16" fill="none" stroke="currentColor" stroke-width="1.8" style="width:15px;height:15px"><path d="M2 10l4-4 3 3 2-2 3 3"/><rect x="1" y="1" width="14" height="14" rx="2"/></svg>
<input type="file" accept="image/*" id="vision-input" style="display:none" onchange="attachImg(this)">
</label>
<textarea id="chat-in" class="chat-ta" placeholder="Send a message…" rows="1"></textarea>
<button class="btn btn-primary" id="send-btn" onclick="sendChat()" style="padding:.45rem .7rem;align-self:flex-end">
<svg viewBox="0 0 16 16" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="width:13px;height:13px"><line x1="14" y1="2" x2="7" y2="9"/><polygon points="14 2 10 14 7 9 2 6 14 2"/></svg>
</button>
</div>
<div class="chat-hint">Enter to send · Shift+Enter for newline</div>
</div>
</div> </div>
<div class="chat-hint">Enter to send · Shift+Enter for newline</div>
</div> <!-- ═══════════════ IMAGE ═══════════════ -->
</div>
<!-- Generate -->
<div class="panel" id="panel-img-gen">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Prompt</label><textarea id="ig-prompt" class="fs" placeholder="Describe the image…" rows="4"></textarea></div>
<div class="frow"><label class="fl">Negative prompt</label><textarea id="ig-neg" class="fs" placeholder="Things to avoid…" rows="2"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">Width</label><input type="number" id="ig-w" class="fi" value="1024" step="64"></div>
<div class="frow"><label class="fl">Height</label><input type="number" id="ig-h" class="fi" value="1024" step="64"></div>
</div>
<div class="g2">
<div class="frow"><label class="fl">Steps</label><input type="number" id="ig-steps" class="fi" value="30" min="1"></div>
<div class="frow"><label class="fl">CFG</label><input type="number" id="ig-cfg" class="fi" value="7.5" step="0.5"></div>
</div>
<div class="frow"><label class="fl">Seed</label><input type="number" id="ig-seed" class="fi" placeholder="random"></div>
<div class="frow"><label class="fl">Count (n)</label><input type="number" id="ig-n" class="fi" value="1" min="1" max="8"></div>
<div class="frow"><label class="fl" title="Removes the model's built-in safety checker. Only use with uncensored/NSFW fine-tunes.">Disable safety filter</label><input type="checkbox" id="ig-nosafe"></div>
<button class="btn btn-primary" onclick="genImage()">Generate</button>
<div class="progress" id="ig-prog"></div>
</div>
<div class="gen-out" id="ig-out"><div class="gen-empty">Image will appear here</div></div>
</div>
</div>
<!-- Edit (i2i) -->
<div class="panel" id="panel-img-edit">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source image</label><input type="file" id="ie-src" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">Prompt</label><textarea id="ie-prompt" class="fs" rows="4" placeholder="Describe the desired change…"></textarea></div>
<div class="frow"><label class="fl">Strength <span id="ie-sv">0.75</span></label>
<input type="range" id="ie-str" min="0" max="1" step="0.05" value="0.75" oninput="$('ie-sv').textContent=this.value" style="width:100%"></div>
<div class="g2">
<div class="frow"><label class="fl">Steps</label><input type="number" id="ie-steps" class="fi" value="30"></div>
<div class="frow"><label class="fl">CFG</label><input type="number" id="ie-cfg" class="fi" value="7.5" step="0.5"></div>
</div>
<div class="frow"><label class="fl">Seed</label><input type="number" id="ie-seed" class="fi" placeholder="random"></div>
<button class="btn btn-primary" onclick="genEdit()">Edit Image</button>
<div class="progress" id="ie-prog"></div>
</div>
<div class="gen-out" id="ie-out"><div class="gen-empty">Edited image will appear here</div></div>
</div>
</div>
<!-- Inpaint -->
<div class="panel" id="panel-img-inpaint">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source image</label><input type="file" id="ip-src" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">Mask image <span style="font-weight:400">(white = fill area)</span></label><input type="file" id="ip-mask" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">Prompt</label><textarea id="ip-prompt" class="fs" rows="4" placeholder="What to fill the masked area with…"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">Steps</label><input type="number" id="ip-steps" class="fi" value="30"></div>
<div class="frow"><label class="fl">CFG</label><input type="number" id="ip-cfg" class="fi" value="7.5" step="0.5"></div>
</div>
<div class="frow"><label class="fl">Strength <span id="ip-sv">0.99</span></label>
<input type="range" id="ip-str" min="0" max="1" step="0.05" value="0.99" oninput="$('ip-sv').textContent=this.value" style="width:100%"></div>
<div class="frow"><label class="fl">Seed</label><input type="number" id="ip-seed" class="fi" placeholder="random"></div>
<button class="btn btn-primary" onclick="genInpaint()">Inpaint</button>
<div class="progress" id="ip-prog"></div>
</div>
<div class="gen-out" id="ip-out"><div class="gen-empty">Result will appear here</div></div>
</div>
</div>
<!-- Upscale image -->
<div class="panel" id="panel-img-upscale">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source image</label><input type="file" id="iu-src" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">Scale factor</label>
<select id="iu-scale" class="fselect"><option value="2">2&times;</option><option value="4" selected>4&times;</option><option value="8">8&times;</option></select>
</div>
<button class="btn btn-primary" onclick="genImgUpscale()">Upscale</button>
<div class="progress" id="iu-prog"></div>
</div>
<div class="gen-out" id="iu-out"><div class="gen-empty">Upscaled image will appear here</div></div>
</div>
</div>
<!-- Depth -->
<div class="panel" id="panel-img-depth">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source image</label><input type="file" id="id-src" accept="image/*" class="fi"></div>
<button class="btn btn-primary" onclick="genDepth()">Estimate Depth</button>
<div class="progress" id="id-prog"></div>
</div>
<div class="gen-out" id="id-out"><div class="gen-empty">Depth map will appear here</div></div>
</div>
</div>
<!-- Segment -->
<div class="panel" id="panel-img-seg">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source image</label><input type="file" id="is-src" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">Point prompts (optional)</label>
<input type="text" id="is-pts" class="fi" placeholder='[[x,y], [x,y]]'></div>
<div class="frow"><label class="fl">Box prompts (optional)</label>
<input type="text" id="is-boxes" class="fi" placeholder='[[x1,y1,x2,y2]]'></div>
<button class="btn btn-primary" onclick="genSegment()">Segment</button>
<div class="progress" id="is-prog"></div>
</div>
<div class="gen-out" id="is-out"><div class="gen-empty">Segmented image will appear here</div></div>
</div>
</div>
<!-- ═══════════════ VIDEO ═══════════════ -->
<!-- T2V -->
<div class="panel" id="panel-vid-t2v"><div class="gen-wrap">
<div class="gen-ctrl" id="vid-shared-ctrl-t2v">
{# vid_ctrl and vid_postproc are injected by JS below #}
<button class="btn btn-primary" onclick="genVideo('t2v')">Generate Video</button>
<div class="progress" id="vt-prog"></div>
</div>
<div class="gen-out" id="vt-out"><div class="gen-empty">Video will appear here</div></div>
</div></div>
<!-- I2V -->
<div class="panel" id="panel-vid-i2v"><div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source image</label><input type="file" id="vi-src" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">Motion prompt (optional)</label><textarea id="vi-prompt" class="fs" rows="2" placeholder="How should the image animate?"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">Frames</label><input type="number" id="vi-frames" class="fi" value="16"></div>
<div class="frow"><label class="fl">FPS</label><input type="number" id="vi-fps" class="fi" value="8"></div>
</div>
<div class="g2">
<div class="frow"><label class="fl">Steps</label><input type="number" id="vi-steps" class="fi" value="25"></div>
<div class="frow"><label class="fl">Guidance</label><input type="number" id="vi-cfg" class="fi" value="7.5" step="0.5"></div>
</div>
<div class="frow"><label class="fl">Seed</label><input type="number" id="vi-seed" class="fi" placeholder="random"></div>
<button class="btn btn-primary" onclick="genVideo('i2v')">Animate</button>
<div class="progress" id="vi-prog"></div>
</div>
<div class="gen-out" id="vi-out"><div class="gen-empty">Animated video will appear here</div></div>
</div></div>
<!-- V2V -->
<div class="panel" id="panel-vid-v2v"><div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source video</label><input type="file" id="vv-src" accept="video/*" class="fi"></div>
<div class="frow"><label class="fl">Prompt</label><textarea id="vv-prompt" class="fs" rows="3" placeholder="Describe the transformation…"></textarea></div>
<div class="frow"><label class="fl">Strength <span id="vv-sv">0.70</span></label>
<input type="range" id="vv-str" min="0" max="1" step="0.05" value="0.70" oninput="$('vv-sv').textContent=(+this.value).toFixed(2)" style="width:100%"></div>
<div class="g2">
<div class="frow"><label class="fl">Steps</label><input type="number" id="vv-steps" class="fi" value="25"></div>
<div class="frow"><label class="fl">Guidance</label><input type="number" id="vv-cfg" class="fi" value="7.5" step="0.5"></div>
</div>
<div class="frow"><label class="fl">Seed</label><input type="number" id="vv-seed" class="fi" placeholder="random"></div>
<button class="btn btn-primary" onclick="genVideo('v2v')">Transform</button>
<div class="progress" id="vv-prog"></div>
</div>
<div class="gen-out" id="vv-out"><div class="gen-empty">Transformed video will appear here</div></div>
</div></div>
<!-- Ti2V: Text + Image → Video -->
<div class="panel" id="panel-vid-ti2v"><div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Initial image</label><input type="file" id="ti-init" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">End image (optional — for frame interpolation)</label><input type="file" id="ti-end" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">Prompt</label><textarea id="ti-prompt" class="fs" rows="3" placeholder="Describe the video motion / story…"></textarea></div>
<div class="frow"><label class="fl">Negative prompt</label><textarea id="ti-neg" class="fs" rows="2" placeholder="Things to avoid…"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">Width</label><input type="number" id="ti-w" class="fi" value="512" step="64"></div>
<div class="frow"><label class="fl">Height</label><input type="number" id="ti-h" class="fi" value="512" step="64"></div>
</div>
<div class="g2">
<div class="frow"><label class="fl">Frames</label><input type="number" id="ti-frames" class="fi" value="16"></div>
<div class="frow"><label class="fl">FPS</label><input type="number" id="ti-fps" class="fi" value="8"></div>
</div>
<div class="g2">
<div class="frow"><label class="fl">Steps</label><input type="number" id="ti-steps" class="fi" value="25"></div>
<div class="frow"><label class="fl">Guidance</label><input type="number" id="ti-cfg" class="fi" value="7.5" step="0.5"></div>
</div>
<div class="frow"><label class="fl">Camera motion</label>
<select id="ti-cam" class="fselect">
<option value="">None</option>
<option value="zoom-in">Zoom in</option><option value="zoom-out">Zoom out</option>
<option value="pan-left">Pan left</option><option value="pan-right">Pan right</option>
<option value="tilt-up">Tilt up</option><option value="tilt-down">Tilt down</option>
<option value="rotate">Rotate</option>
</select>
</div>
<div class="frow"><label class="fl">Seed</label><input type="number" id="ti-seed" class="fi" placeholder="random"></div>
<div class="frow"><label class="fl" title="Removes the model's built-in safety checker. Only use with uncensored/NSFW fine-tunes.">Disable safety filter</label><input type="checkbox" id="ti-nosafe"></div>
<!-- Character consistency -->
<details style="margin-top:.25rem">
<summary style="font-size:12px;font-weight:600;cursor:pointer;color:var(--text-2)">Character consistency</summary>
<div style="margin-top:.5rem;display:flex;flex-direction:column;gap:.4rem">
<div class="frow">
<label class="fl">Reference images</label>
<input type="file" id="ti-char" accept="image/*" multiple class="fi" onchange="addCharRefs('ti',this)">
</div>
<div class="char-refs" id="ti-char-thumbs"></div>
<div class="frow"><label class="fl">Strength <span id="ti-cs">0.80</span></label>
<input type="range" id="ti-cstr" min="0" max="1" step="0.05" value="0.8" oninput="$('ti-cs').textContent=(+this.value).toFixed(2)" style="width:100%"></div>
</div>
</details>
<!-- Audio -->
<details style="margin-top:.25rem">
<summary style="font-size:12px;font-weight:600;cursor:pointer;color:var(--text-2)">Add audio</summary>
<div style="margin-top:.5rem;display:flex;flex-direction:column;gap:.4rem">
<div class="frow"><label class="fl">Audio type</label>
<select id="ti-atype" class="fselect">
<option value="">None</option>
<option value="music">Music</option><option value="sfx">SFX</option>
<option value="speech">Speech / TTS</option><option value="ambient">Ambient</option>
</select>
</div>
<div class="frow"><label class="fl">Audio prompt</label><textarea id="ti-aprompt" class="fs" rows="2" placeholder="e.g. epic orchestral soundtrack"></textarea></div>
<div class="frow"><label class="fl">TTS text (for speech)</label><textarea id="ti-ttstext" class="fs" rows="2"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">TTS voice</label><input id="ti-voice" class="fi" placeholder="af_sky"></div>
<div class="frow"><label class="fl">Speed</label><input type="number" id="ti-speed" class="fi" value="1.0" step="0.1" min="0.5"></div>
</div>
<label style="font-size:12px;display:flex;align-items:center;gap:.4rem;cursor:pointer">
<input type="checkbox" id="ti-lipsync"> Lip-sync to speech
</label>
</div>
</details>
<!-- Subtitles -->
<details style="margin-top:.25rem">
<summary style="font-size:12px;font-weight:600;cursor:pointer;color:var(--text-2)">Subtitles</summary>
<div style="margin-top:.5rem;display:flex;flex-direction:column;gap:.4rem">
<label style="font-size:12px;display:flex;align-items:center;gap:.4rem;cursor:pointer"><input type="checkbox" id="ti-gensub"> Generate subtitles</label>
<label style="font-size:12px;display:flex;align-items:center;gap:.4rem;cursor:pointer"><input type="checkbox" id="ti-burnsub"> Burn subtitles into video</label>
<label style="font-size:12px;display:flex;align-items:center;gap:.4rem;cursor:pointer"><input type="checkbox" id="ti-transsub"> Translate subtitles</label>
<div class="frow"><label class="fl">Target language</label><input id="ti-sublang" class="fi" placeholder="es, fr, de, it…"></div>
</div>
</details>
<!-- Post-processing -->
<details style="margin-top:.25rem">
<summary style="font-size:12px;font-weight:600;cursor:pointer;color:var(--text-2)">Post-processing</summary>
<div style="margin-top:.5rem;display:flex;flex-direction:column;gap:.4rem">
<label style="font-size:12px;display:flex;align-items:center;gap:.4rem;cursor:pointer"><input type="checkbox" id="ti-upscale"> Upscale output</label>
<div class="frow"><label class="fl">Upscale factor</label>
<select id="ti-upfact" class="fselect"><option value="2">2×</option><option value="4">4×</option></select>
</div>
<label style="font-size:12px;display:flex;align-items:center;gap:.4rem;cursor:pointer"><input type="checkbox" id="ti-interp"> Frame interpolation</label>
<div class="frow"><label class="fl">FPS multiplier</label>
<select id="ti-fmult" class="fselect"><option value="2">2×</option><option value="4">4×</option></select>
</div>
</div>
</details>
<button class="btn btn-primary" onclick="genTi2V()" style="margin-top:.25rem">Generate</button>
<div class="progress" id="ti-prog"></div>
</div>
<div class="gen-out" id="ti-out"><div class="gen-empty">Video will appear here</div></div>
</div></div>
<!-- Interpolate -->
<div class="panel" id="panel-vid-interp"><div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source video (or use frame images below)</label><input type="file" id="vc-src" accept="video/*" class="fi"></div>
<div class="frow"><label class="fl">— OR — First frame</label><input type="file" id="vc-init" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">Last frame</label><input type="file" id="vc-end" accept="image/*" class="fi"></div>
<div class="frow"><label class="fl">FPS multiplier</label>
<select id="vc-mult" class="fselect"><option value="2">2×</option><option value="4">4×</option><option value="8">8×</option></select>
</div>
<button class="btn btn-primary" onclick="genInterp()">Interpolate</button>
<div class="progress" id="vc-prog"></div>
</div>
<div class="gen-out" id="vc-out"><div class="gen-empty">Interpolated video will appear here</div></div>
</div></div>
<!-- Subtitles -->
<div class="panel" id="panel-vid-sub"><div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source video</label><input type="file" id="vs-src" accept="video/*" class="fi"></div>
<div class="frow"><label class="fl">Language hint</label><input id="vs-lang" class="fi" placeholder="en, fr, de… (auto-detect if blank)"></div>
<div class="frow"><label class="fl">Output format</label>
<select id="vs-fmt" class="fselect"><option value="srt">SRT text</option><option value="burned_video">Burn into video</option></select>
</div>
<label style="font-size:12px;display:flex;align-items:center;gap:.4rem;cursor:pointer"><input type="checkbox" id="vs-translate"> Translate subtitles</label>
<div class="frow"><label class="fl">Target language</label><input id="vs-tlang" class="fi" placeholder="es, fr, de…"></div>
<button class="btn btn-primary" onclick="genSubtitles()">Generate Subtitles</button>
<div class="progress" id="vs-prog"></div>
</div>
<div class="gen-out" id="vs-out"><div class="gen-empty">Result will appear here</div></div>
</div></div>
<!-- Dub -->
<div class="panel" id="panel-vid-dub"><div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source video</label><input type="file" id="vd-src" accept="video/*" class="fi"></div>
<div class="g2">
<div class="frow"><label class="fl">Source lang</label><input id="vd-slang" class="fi" placeholder="en"></div>
<div class="frow"><label class="fl">Target lang</label><input id="vd-tlang" class="fi" placeholder="es, fr, de…"></div>
</div>
<label style="font-size:12px;display:flex;align-items:center;gap:.4rem;cursor:pointer"><input type="checkbox" id="vd-burn"> Burn subtitles into result</label>
<button class="btn btn-primary" onclick="genDub()">Dub Video</button>
<div class="progress" id="vd-prog"></div>
</div>
<div class="gen-out" id="vd-out"><div class="gen-empty">Dubbed video will appear here</div></div>
</div></div>
<!-- Video upscale -->
<div class="panel" id="panel-vid-up"><div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Source video</label><input type="file" id="vu-src" accept="video/*" class="fi"></div>
<div class="frow"><label class="fl">Scale factor</label>
<select id="vu-scale" class="fselect"><option value="2">2×</option><option value="4">4×</option></select>
</div>
<button class="btn btn-primary" onclick="genVidUpscale()">Upscale</button>
<div class="progress" id="vu-prog"></div>
</div>
<div class="gen-out" id="vu-out"><div class="gen-empty">Upscaled video will appear here</div></div>
</div></div>
<!-- ═══════════════ AUDIO ═══════════════ -->
<!-- Music/SFX generation -->
<div class="panel" id="panel-aud-gen">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Prompt</label><textarea id="ag-prompt" class="fs" rows="4" placeholder="e.g. upbeat electronic dance music with synth bass…"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">Duration (sec)</label><input type="number" id="ag-dur" class="fi" value="10" min="1" max="300"></div>
<div class="frow"><label class="fl">Temperature</label><input type="number" id="ag-temp" class="fi" value="1.0" step="0.1" min="0.1"></div>
</div>
<div class="frow"><label class="fl">Melody reference (optional)</label><input type="file" id="ag-melody" accept="audio/*" class="fi"></div>
<div class="frow"><label class="fl">Seed</label><input type="number" id="ag-seed" class="fi" placeholder="random"></div>
<button class="btn btn-primary" onclick="genAudio()">Generate</button>
<div class="progress" id="ag-prog"></div>
</div>
<div class="gen-out" id="ag-out"><div class="gen-empty">Generated audio will appear here</div></div>
</div>
</div>
<!-- TTS -->
<div class="panel" id="panel-aud-tts">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Text to speak</label><textarea id="at-text" class="fs" rows="6" placeholder="Enter text…"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">Voice</label><input id="at-voice" class="fi" placeholder="af_sky, en-US-Jenny…"></div>
<div class="frow"><label class="fl">Speed</label><input type="number" id="at-speed" class="fi" value="1.0" step="0.1"></div>
</div>
<button class="btn btn-primary" onclick="genTTS()">Synthesize</button>
<div class="progress" id="at-prog"></div>
</div>
<div class="gen-out" id="at-out"><div class="gen-empty">Audio will appear here</div></div>
</div>
</div>
<!-- STT -->
<div class="panel" id="panel-aud-stt">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Audio / video file</label><input type="file" id="as-file" accept="audio/*,video/*" class="fi"></div>
<div class="g2">
<div class="frow"><label class="fl">Language</label><input id="as-lang" class="fi" placeholder="en, fr…"></div>
<div class="frow"><label class="fl">Hint</label><input id="as-prompt" class="fi" placeholder="context hint…"></div>
</div>
<button class="btn btn-primary" onclick="genSTT()">Transcribe</button>
<div class="progress" id="as-prog"></div>
</div>
<div class="gen-out" id="as-out">
<div class="gen-out-inner"><div class="gen-empty">Transcript will appear here</div></div>
</div>
</div>
</div>
<!-- ═══════════════ PIPELINES ═══════════════ -->
<div class="panel" id="panel-pipe">
<div class="pipe-panel">
<div class="pipe-card">
<div class="pipe-title">Image → Video Pipeline</div>
<div class="pipe-steps">
<span class="pipe-step">1. Generate image</span><span class="pipe-arrow"></span>
<span class="pipe-step">2. Animate to video</span><span class="pipe-arrow"></span>
<span class="pipe-step">3. (Optional) Add audio</span>
</div>
<div style="margin-top:.75rem;display:flex;flex-direction:column;gap:.5rem">
<div class="frow"><label class="fl">Text prompt</label><textarea id="pp1-prompt" class="fs" rows="3" placeholder="Describe the scene…"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">Image model</label><input id="pp1-imodel" class="fi" placeholder="model id"></div>
<div class="frow"><label class="fl">Video model</label><input id="pp1-vmodel" class="fi" placeholder="model id"></div>
</div>
<div class="g2">
<div class="frow"><label class="fl">Frames</label><input type="number" id="pp1-frames" class="fi" value="16"></div>
<div class="frow"><label class="fl">FPS</label><input type="number" id="pp1-fps" class="fi" value="8"></div>
</div>
<button class="btn btn-primary" onclick="runPipeline1()">Run Pipeline</button>
<div class="progress" id="pp1-prog"></div>
<div id="pp1-out"></div>
</div>
</div>
<div class="pipe-card">
<div class="pipe-title">Video → Dub + Subtitle Pipeline</div>
<div class="pipe-steps">
<span class="pipe-step">1. Transcribe</span><span class="pipe-arrow"></span>
<span class="pipe-step">2. Translate</span><span class="pipe-arrow"></span>
<span class="pipe-step">3. TTS dub</span><span class="pipe-arrow"></span>
<span class="pipe-step">4. Burn subtitles</span>
</div>
<div style="margin-top:.75rem;display:flex;flex-direction:column;gap:.5rem">
<div class="frow"><label class="fl">Source video</label><input type="file" id="pp2-src" accept="video/*" class="fi"></div>
<div class="g2">
<div class="frow"><label class="fl">Source language</label><input id="pp2-slang" class="fi" placeholder="en"></div>
<div class="frow"><label class="fl">Target language</label><input id="pp2-tlang" class="fi" placeholder="es, fr, de…"></div>
</div>
<div class="frow"><label class="fl">Transcription model</label><input id="pp2-model" class="fi" placeholder="whisper model id"></div>
<button class="btn btn-primary" onclick="runPipeline2()">Run Pipeline</button>
<div class="progress" id="pp2-prog"></div>
<div id="pp2-out"></div>
</div>
</div>
<div class="pipe-card">
<div class="pipe-title">Full Story Pipeline</div>
<div class="pipe-steps">
<span class="pipe-step">1. LLM generates script</span><span class="pipe-arrow"></span>
<span class="pipe-step">2. Image gen per scene</span><span class="pipe-arrow"></span>
<span class="pipe-step">3. Animate each scene</span><span class="pipe-arrow"></span>
<span class="pipe-step">4. TTS narration</span><span class="pipe-arrow"></span>
<span class="pipe-step">5. Merge</span>
</div>
<div style="margin-top:.75rem;display:flex;flex-direction:column;gap:.5rem">
<div class="frow"><label class="fl">Story idea</label><textarea id="pp3-story" class="fs" rows="3" placeholder="A knight fights a dragon at sunset…"></textarea></div>
<div class="g3">
<div class="frow"><label class="fl">Text model</label><input id="pp3-tmodel" class="fi" placeholder="llm id"></div>
<div class="frow"><label class="fl">Image model</label><input id="pp3-imodel" class="fi" placeholder="sd id"></div>
<div class="frow"><label class="fl">Video model</label><input id="pp3-vmodel" class="fi" placeholder="video id"></div>
</div>
<div class="frow"><label class="fl">TTS model</label><input id="pp3-amodel" class="fi" placeholder="tts id"></div>
<button class="btn btn-primary" onclick="runPipeline3()">Run Full Pipeline</button>
<div class="progress" id="pp3-prog"></div>
<div id="pp3-out"></div>
</div>
</div>
</div>
</div>
<!-- ═══════════════ EMBEDDINGS ═══════════════ -->
<div class="panel" id="panel-embed">
<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Text(s) — one per line</label><textarea id="em-text" class="fs" rows="6" placeholder="Enter text to embed…"></textarea></div>
<div class="frow"><label class="fl">Encoding</label>
<select id="em-enc" class="fselect"><option value="float">Float array</option><option value="base64">Base64</option></select>
</div>
<div class="frow"><label class="fl">Dimensions (truncate, optional)</label><input type="number" id="em-dims" class="fi" placeholder="all"></div>
<button class="btn btn-primary" onclick="genEmbeddings()">Embed</button>
<div class="progress" id="em-prog"></div>
</div>
<div class="gen-out" id="em-out"><div class="gen-empty">Embedding vectors will appear here</div></div>
</div>
</div>
</div><!-- studio-main -->
</div><!-- studio -->
{% endblock %} {% endblock %}
{% block scripts %} {% block scripts %}
<script> <script>
let history = []; // ─────────────────────────────────────────────────────────────────
let busy = false; // State
// ─────────────────────────────────────────────────────────────────
// Page-wide mutable state.
let models = [];          // entries of the `data` array returned by /v1/models
let activeModel = null;   // model object currently selected in the sidebar
let chatHistory = [];     // messages sent with each chat completion request
let chatBusy = false;     // true while a chat request is in flight
let attachedImage = null; // data URL of the image attached to the next chat message
let charRefs = {};        // prefix → list of data-URL strings read from reference images

/** Look up an element by id. */
const $ = id => document.getElementById(id);

/** `.value` of the element with this id, or '' when no such element exists. */
const val = id => {
  const el = $(id);
  return el ? el.value : '';
};

/** Field value run through parseInt; 0 when missing, unparsable or zero. */
const ival = id => {
  const parsed = parseInt(val(id));
  return parsed || 0;
};

/** Field value run through parseFloat; 0 when missing, unparsable or zero. */
const fval = id => {
  const parsed = parseFloat(val(id));
  return parsed || 0;
};

/** `.checked` of the element with this id, or false when no such element exists. */
const chk = id => {
  const el = $(id);
  return el ? el.checked : false;
};
// Model capability → level-1 tab (category) that the capability unlocks.
const CAP_CAT = {
  text_generation: 'chat',
  image_to_text: 'chat',
  image_generation: 'image',
  image_to_image: 'image',
  inpainting: 'image',
  image_upscaling: 'image',
  depth_estimation: 'image',
  image_segmentation: 'image',
  video_generation: 'video',
  image_to_video: 'video',
  video_to_video: 'video',
  video_interpolation: 'video',
  video_upscaling: 'video',
  text_to_speech: 'audio',
  speech_to_text: 'audio',
  subtitle_generation: 'video',
  audio_generation: 'audio',
  embeddings: 'embed',
};

// Model capability → level-2 sub-tab that the capability unlocks.
// Chat and embedding capabilities have no sub-tab and are absent here.
const CAP_SUB = {
  image_generation: 'img-gen',
  image_to_image: 'img-edit',
  inpainting: 'img-inpaint',
  image_upscaling: 'img-upscale',
  depth_estimation: 'img-depth',
  image_segmentation: 'img-seg',
  video_generation: 'vid-t2v',
  image_to_video: 'vid-i2v',
  video_to_video: 'vid-v2v',
  video_interpolation: 'vid-interp',
  video_upscaling: 'vid-up',
  subtitle_generation: 'vid-sub',
  text_to_speech: 'aud-tts',
  speech_to_text: 'aud-stt',
  audio_generation: 'aud-gen',
};

// Level-2 sub-tab → the level-1 category it lives under.
const SUB_CAT = {
  'img-gen': 'image',
  'img-edit': 'image',
  'img-inpaint': 'image',
  'img-upscale': 'image',
  'img-depth': 'image',
  'img-seg': 'image',
  'vid-t2v': 'video',
  'vid-i2v': 'video',
  'vid-v2v': 'video',
  'vid-ti2v': 'video',
  'vid-interp': 'video',
  'vid-sub': 'video',
  'vid-dub': 'video',
  'vid-up': 'video',
  'aud-gen': 'audio',
  'aud-tts': 'audio',
  'aud-stt': 'audio',
};

// Sub-tabs that are switched on for every video model regardless of its
// listed capabilities (text+image→video and dubbing).
const VIDEO_EXTRA_SUBS = [
  'vid-ti2v',
  'vid-dub',
];
// ─────────────────────────────────────────────────────────────────
// Boot
// ─────────────────────────────────────────────────────────────────
async function loadModels() { async function loadModels() {
try { try {
const d = await fetch('/v1/models').then(r => r.json()); const d = await fetch('/v1/models').then(r => r.json());
const sel = document.getElementById('model-sel'); models = d.data || [];
sel.innerHTML = '<option value="">Select model…</option>'; renderSidebar();
(d.data || []).forEach(m => { if (models.length) selectModel(models[0]);
const o = document.createElement('option'); } catch(e) {
o.value = o.textContent = m.id; $('model-list').innerHTML = '<div class="muted small" style="padding:.5rem .6rem">Failed to load models</div>';
sel.appendChild(o); }
});
} catch {}
} }
function newChat() { const BADGE = {text:'mb-text',vision:'mb-vision',image:'mb-image',video:'mb-video',
history = []; audio:'mb-audio',tts:'mb-tts',audio_gen:'mb-audiogen',embedding:'mb-embed'};
document.getElementById('chat-msgs').innerHTML = '<div class="chat-empty"><h3>New conversation</h3><p>Start typing below</p></div>'; const BLABEL = {text:'LLM',vision:'VLM',image:'IMG',video:'VID',audio:'STT',
tts:'TTS',audio_gen:'MUS',embedding:'EMB'};
/**
 * Rebuild the sidebar model list from the global `models` array.
 *
 * Each entry becomes a `.model-item` row carrying the model id in
 * `data-id`, a type badge and the last path segment of the id as label.
 * Clicking a row calls `selectModel` with that model object.
 *
 * Rows are created with DOM APIs rather than an HTML string: model ids
 * and types come from the server, and interpolating them into markup or
 * into an inline `onclick` attribute lets characters such as `<`, `&` or
 * quotes break out of the attribute or inject markup.
 */
function renderSidebar() {
  const listEl = $('model-list');
  if (!models.length) {
    listEl.innerHTML = '<div class="muted small" style="padding:.5rem .6rem">No models</div>';
    return;
  }
  listEl.innerHTML = '';
  models.forEach(m => {
    const id = String(m.id);
    const t = m.type || 'text';

    const row = document.createElement('div');
    row.className = 'model-item';
    row.dataset.id = id;
    // Pass the model object through a closure instead of serialising it
    // into an attribute.
    row.addEventListener('click', () => selectModel(m));

    const badge = document.createElement('span');
    badge.className = 'mbadge ' + (BADGE[t] || 'mb-text');
    badge.textContent = BLABEL[t] || t;

    const label = document.createElement('span');
    label.style.cssText = 'overflow:hidden;text-overflow:ellipsis;white-space:nowrap;font-size:12px';
    label.title = id;
    label.textContent = id.split('/').pop();

    row.appendChild(badge);
    row.appendChild(label);
    listEl.appendChild(row);
  });
}
function addMsg(role, text) { function selectModel(m) {
const wrap = document.getElementById('chat-msgs'); activeModel = m;
document.querySelectorAll('.model-item').forEach(el =>
el.classList.toggle('active', el.dataset.id === m.id));
chatHistory = []; attachedImage = null; updateAttachBar();
$('chat-msgs').innerHTML = `<div class="chat-empty"><h3>${m.id.split('/').pop()}</h3><p>Start below</p></div>`;
updateTabs(m);
}
/**
 * Show only the tabs the given model can serve, then jump to the first
 * visible level-1 tab.
 *
 * Reads `m.capabilities` and `m.type`; toggles the `hidden` class on
 * every `.t1btn` / `.t2btn` and the display of the chat attach button.
 */
function updateTabs(m) {
  const caps = m.capabilities || [];
  const isVideoType = (m.type || 'text') === 'video';
  const cats = new Set();
  const subs = new Set();

  // Chat tab: text or vision-language models.
  const canChat = caps.includes('text_generation') || caps.includes('image_to_text');
  if (canChat) cats.add('chat');

  // Map every capability onto its category and sub-tab (when it has one).
  caps.forEach(cap => {
    const catName = CAP_CAT[cap];
    const subName = CAP_SUB[cap];
    if (catName) cats.add(catName);
    if (subName) subs.add(subName);
  });

  // Anything that is a video model additionally unlocks the extra video sub-tabs.
  if (cats.has('video') || isVideoType) {
    for (const extra of VIDEO_EXTRA_SUBS) subs.add(extra);
    cats.add('video');
  }

  // The pipelines tab is always shown; model ids are typed in by hand there.
  cats.add('pipe');

  // Apply visibility to level-1 and level-2 tab buttons.
  document.querySelectorAll('.t1btn').forEach(tab => {
    const visible = cats.has(tab.dataset.cat);
    tab.classList.toggle('hidden', !visible);
  });
  document.querySelectorAll('.t2btn').forEach(tab => {
    const visible = subs.has(tab.dataset.sub);
    tab.classList.toggle('hidden', !visible);
  });

  // The image attach button only makes sense for vision-capable models.
  const attachBtn = $('attach-btn');
  attachBtn.style.display = caps.includes('image_to_text') ? '' : 'none';

  // Land on the first category that is still visible.
  const firstVisible = document.querySelector('.t1btn:not(.hidden)');
  if (firstVisible) selectCat(firstVisible.dataset.cat);
}
/**
 * Activate a level-1 category tab.
 *
 * For categories without sub-tabs the panel `panel-<cat>` is shown
 * directly. For image / video / audio the second tab bar is made visible
 * and the first non-hidden sub-tab of that category is selected.
 *
 * @param {string} cat  value of the clicked button's `data-cat`
 */
function selectCat(cat) {
  document.querySelectorAll('.t1btn').forEach(b => b.classList.toggle('active', b.dataset.cat === cat));

  // Categories that own level-2 tabs, with the prefix their sub-tab ids
  // use. Spelled out explicitly because the prefix is not simply the first
  // three letters of the category ('image' → 'img', not 'ima').
  const subPrefix = {image: 'img', video: 'vid', audio: 'aud'};
  const hasL2 = Object.prototype.hasOwnProperty.call(subPrefix, cat);
  $('tabbar2').classList.toggle('visible', hasL2);

  if (!hasL2) {
    document.querySelectorAll('.panel').forEach(p => p.classList.remove('active'));
    const panel = $('panel-' + cat);
    if (panel) panel.classList.add('active');
    return;
  }

  // Activate the first visible sub-tab in this category.
  const firstSub = document.querySelector(`.t2btn:not(.hidden)[data-sub^="${subPrefix[cat]}"]`);
  if (firstSub) selectSub(firstSub.dataset.sub);
}
/**
 * Activate a level-2 sub-tab and show its panel (`panel-<sub>`), hiding
 * whichever panel was active before.
 */
function selectSub(sub) {
  for (const tab of document.querySelectorAll('.t2btn')) {
    tab.classList.toggle('active', tab.dataset.sub === sub);
  }
  // Only panels that currently carry the class need it removed.
  for (const shown of document.querySelectorAll('.panel.active')) {
    shown.classList.remove('active');
  }
  const target = $('panel-' + sub);
  if (!target) return;
  target.classList.add('active');
}
// ─────────────────────────────────────────────────────────────────
// Chat
// ─────────────────────────────────────────────────────────────────
function addMsg(role, text, imgSrc) {
const wrap = $('chat-msgs');
wrap.querySelector('.chat-empty')?.remove(); wrap.querySelector('.chat-empty')?.remove();
const t = new Date().toLocaleTimeString([],{hour:'2-digit',minute:'2-digit'}); const t = new Date().toLocaleTimeString([],{hour:'2-digit',minute:'2-digit'});
const d = document.createElement('div'); const d = document.createElement('div');
d.className = 'msg ' + role; d.className = 'msg ' + role;
d.innerHTML = ` const name = role === 'user' ? 'You' : (activeModel?.id?.split('/').pop() || 'AI');
<div class="msg-av ${role === 'user' ? 'user' : 'ai'}">${role === 'user' ? 'YOU' : 'AI'}</div> d.innerHTML = `<div class="av ${role==='user'?'user':'ai'}">${role==='user'?'YOU':'AI'}</div>
<div class="msg-body"> <div class="msg-body">
<div class="msg-meta">${role === 'user' ? 'You' : 'Assistant'} · ${t}</div> <div class="msg-meta">${name} · ${t}</div>
${imgSrc ? `<img src="${imgSrc}" class="msg-img" onclick="window.open(this.src)">` : ''}
<div class="msg-text">${String(text).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/\n/g,'<br>')}</div> <div class="msg-text">${String(text).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/\n/g,'<br>')}</div>
</div>`; </div>`;
wrap.appendChild(d); wrap.appendChild(d);
wrap.scrollTop = wrap.scrollHeight; wrap.scrollTop = wrap.scrollHeight;
} }
async function send() { function attachImg(input) {
if (busy) return; const f = input.files[0]; if (!f) return;
const model = document.getElementById('model-sel').value; const r = new FileReader();
if (!model) { document.getElementById('model-sel').focus(); return; } r.onload = e => {
const input = document.getElementById('chat-in'); attachedImage = e.target.result;
const text = input.value.trim(); $('attach-thumb').src = attachedImage;
if (!text) return; $('attach-name').textContent = f.name;
updateAttachBar();
addMsg('user', text); };
history.push({role:'user', content:text}); r.readAsDataURL(f);
input.value = ''; }
input.style.height = 'auto'; function clearAttach() { attachedImage = null; updateAttachBar(); }
function updateAttachBar() { $('attach-bar').style.display = attachedImage ? 'flex' : 'none'; }
busy = true;
document.getElementById('send-btn').disabled = true;
document.getElementById('typing').textContent = 'Assistant is typing…';
async function sendChat() {
if (chatBusy || !activeModel) return;
const input = $('chat-in');
const text = input.value.trim(); if (!text) return;
addMsg('user', text, attachedImage);
let content = text;
if (attachedImage && (activeModel.capabilities||[]).includes('image_to_text')) {
content = [{type:'image_url',image_url:{url:attachedImage}},{type:'text',text}];
}
chatHistory.push({role:'user',content});
input.value = ''; input.style.height = 'auto';
attachedImage = null; updateAttachBar();
chatBusy = true; $('send-btn').disabled = true;
$('typing').textContent = 'Thinking…';
try { try {
const r = await fetch('/v1/chat/completions', { const r = await fetch('/v1/chat/completions',{
method:'POST', headers:{'Content-Type':'application/json'}, method:'POST',headers:{'Content-Type':'application/json'},
body: JSON.stringify({model, messages: history, stream:false}) body:JSON.stringify({model:activeModel.id,messages:chatHistory,stream:false})
}); });
if (!r.ok) throw new Error('HTTP ' + r.status); if (!r.ok) throw new Error('HTTP '+r.status+': '+await r.text());
const d = await r.json(); const d = await r.json();
const reply = d.choices[0].message.content; const reply = d.choices[0].message.content;
addMsg('assistant', reply); addMsg('assistant',reply);
history.push({role:'assistant', content:reply}); chatHistory.push({role:'assistant',content:reply});
} catch (e) { } catch(e) { addMsg('assistant','Error: '+e.message); }
addMsg('assistant', 'Error: ' + e.message); finally { chatBusy=false; $('send-btn').disabled=false; $('typing').textContent=''; }
} finally { }
busy = false;
document.getElementById('send-btn').disabled = false; $('chat-in').addEventListener('keydown', e => { if(e.key==='Enter'&&!e.shiftKey){e.preventDefault();sendChat();} });
document.getElementById('typing').textContent = ''; $('chat-in').addEventListener('input', function(){ this.style.height='auto'; this.style.height=Math.min(this.scrollHeight,140)+'px'; });
// ─────────────────────────────────────────────────────────────────
// Utilities
// ─────────────────────────────────────────────────────────────────
/**
 * Read a File/Blob with FileReader.readAsDataURL.
 *
 * @returns {Promise<string>} resolves with the reader result (a data URL);
 *   rejects with whatever the reader's error handler receives.
 */
function fileToB64(file) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onerror = reject;
    reader.onload = evt => {
      resolve(evt.target.result);
    };
    reader.readAsDataURL(file);
  });
}
/** First file chosen in the file input with this id, or null when the input is missing or empty. */
function fileOrNull(id) {
  const input = $(id);
  if (!input) return null;
  const first = input.files[0];
  return first ? first : null;
}
/** Data URL of the file chosen in the input with this id, or null when nothing is selected. */
async function b64OrNull(id) {
  const picked = fileOrNull(id);
  if (!picked) return null;
  const dataUrl = await fileToB64(picked);
  return dataUrl;
}
/**
 * Render a result image plus a download link into the output container.
 *
 * @param {string} outId  id of the output container element
 * @param {?string} src   image URL or data URL; when empty an error is
 *                        reported instead of rendering a broken image
 * @param {string} [prog] id of the progress element to update
 *
 * The nodes are built with DOM APIs so that `src`, which comes from a
 * server response, is never parsed as HTML.
 */
function showImg(outId, src, prog) {
  const out = $(outId);
  if (!src) {
    out.innerHTML = '<div class="gen-empty">No image returned</div>';
    if (prog) $(prog).textContent = 'Error: empty response';
    return;
  }
  const wrap = document.createElement('div');
  wrap.className = 'gen-out-inner';

  const img = document.createElement('img');
  img.className = 'out-img';
  img.src = src;
  img.addEventListener('click', () => window.open(img.src));

  const dl = document.createElement('a');
  dl.href = src;
  dl.setAttribute('download', '');
  dl.className = 'btn btn-ghost btn-sm dl';
  dl.textContent = 'Download';

  wrap.appendChild(img);
  wrap.appendChild(dl);
  out.innerHTML = '';
  out.appendChild(wrap);
  if (prog) $(prog).textContent = 'Done ✓';
}
/**
 * Render a result video player plus a download link into the output container.
 *
 * @param {string} outId  id of the output container element
 * @param {?string} src   video URL or data URL; when empty an error is
 *                        reported instead of rendering an empty player
 * @param {string} [prog] id of the progress element to update
 *
 * The nodes are built with DOM APIs so that `src`, which comes from a
 * server response, is never parsed as HTML.
 */
function showVideo(outId, src, prog) {
  const out = $(outId);
  if (!src) {
    out.innerHTML = '<div class="gen-empty">No video returned</div>';
    if (prog) $(prog).textContent = 'Error: empty response';
    return;
  }
  const wrap = document.createElement('div');
  wrap.className = 'gen-out-inner';

  const video = document.createElement('video');
  video.className = 'out-video';
  video.controls = true;
  video.src = src;

  const dl = document.createElement('a');
  dl.href = src;
  dl.setAttribute('download', '');
  dl.className = 'btn btn-ghost btn-sm dl';
  dl.textContent = 'Download';

  wrap.appendChild(video);
  wrap.appendChild(dl);
  out.innerHTML = '';
  out.appendChild(wrap);
  if (prog) $(prog).textContent = 'Done ✓';
}
/**
 * Render a result audio player plus a download link into the output container.
 *
 * @param {string} outId  id of the output container element
 * @param {?string} src   audio URL or data URL; when empty an error is
 *                        reported instead of rendering an empty player
 * @param {string} [prog] id of the progress element to update
 * @param {string} [ext]  file extension for the download name (default 'wav')
 *
 * The nodes are built with DOM APIs so that `src` and `ext` are never
 * parsed as HTML.
 */
function showAudio(outId, src, prog, ext) {
  const out = $(outId);
  if (!src) {
    out.innerHTML = '<div class="gen-empty">No audio returned</div>';
    if (prog) $(prog).textContent = 'Error: empty response';
    return;
  }
  const wrap = document.createElement('div');
  wrap.className = 'gen-out-inner';
  wrap.style.width = '100%';

  const audio = document.createElement('audio');
  audio.className = 'out-audio';
  audio.controls = true;
  audio.src = src;

  const dl = document.createElement('a');
  dl.href = src;
  dl.setAttribute('download', 'audio.' + (ext || 'wav'));
  dl.className = 'btn btn-ghost btn-sm dl';
  dl.textContent = 'Download';

  wrap.appendChild(audio);
  wrap.appendChild(dl);
  out.innerHTML = '';
  out.appendChild(wrap);
  if (prog) $(prog).textContent = 'Done ✓';
}
/**
 * Pick a usable image source from one response item: its `url` when set,
 * otherwise a PNG data URL built from `b64_json`, otherwise null.
 */
function imgSrc(d) {
  if (d.url) return d.url;
  if (d.b64_json) return 'data:image/png;base64,' + d.b64_json;
  return null;
}
/**
 * Pick a usable video source from one response item: its `url` when set,
 * otherwise an MP4 data URL built from `b64_mp4`, otherwise null.
 */
function vidSrc(d) {
  if (d.url) return d.url;
  if (d.b64_mp4) return 'data:video/mp4;base64,' + d.b64_mp4;
  return null;
}
/**
 * Pick a usable audio source from one response item.
 *
 * Returns `d.url` when set. Otherwise the first key starting with `b64_`
 * is used: the text after that prefix becomes the `audio/<...>` subtype
 * of a data URL wrapping the key's value. Returns null when neither exists.
 */
function audSrc(d) {
  if (d.url) return d.url;
  const b64Key = Object.keys(d).find(name => name.startsWith('b64_'));
  if (b64Key === undefined) return null;
  const subtype = b64Key.slice(4);
  return 'data:audio/' + subtype + ';base64,' + d[b64Key];
}
/**
 * POST a JSON body and return the parsed JSON response.
 *
 * @param {string} path  request URL
 * @param {*} body       value serialised with JSON.stringify
 * @returns {Promise<*>} parsed response body
 * @throws {Error} on a non-2xx response; the message is
 *   `HTTP <status>: <response text>`, the same shape the chat request
 *   uses, so the message is never empty when the server sends no body.
 */
async function post(path, body) {
  const r = await fetch(path, {
    method: 'POST',
    headers: {'Content-Type': 'application/json'},
    body: JSON.stringify(body),
  });
  if (!r.ok) throw new Error('HTTP ' + r.status + ': ' + await r.text());
  return r.json();
}
/**
 * POST a form body (e.g. FormData) and return the parsed JSON response.
 * No Content-Type header is set here; it is left to fetch.
 *
 * @param {string} path  request URL
 * @param {*} fd         value passed straight through as the request body
 * @returns {Promise<*>} parsed response body
 * @throws {Error} on a non-2xx response; the message is
 *   `HTTP <status>: <response text>`, the same shape the chat request
 *   uses, so the message is never empty when the server sends no body.
 */
async function postForm(path, fd) {
  const r = await fetch(path, {method: 'POST', body: fd});
  if (!r.ok) throw new Error('HTTP ' + r.status + ': ' + await r.text());
  return r.json();
}
// Character reference helpers
/**
 * Read every file selected in `input` as a data URL, append each result to
 * `charRefs[prefix]` and add a thumbnail for it to `<prefix>-char-thumbs`.
 *
 * Results are appended as each read completes; earlier entries for the
 * same prefix are kept.
 */
function addCharRefs(prefix, input) {
  charRefs[prefix] = charRefs[prefix] || [];
  const strip = $(prefix + '-char-thumbs');

  const onLoaded = evt => {
    const dataUrl = evt.target.result;
    charRefs[prefix].push(dataUrl);
    const thumb = document.createElement('img');
    thumb.src = dataUrl;
    thumb.className = 'char-thumb';
    strip.appendChild(thumb);
  };

  for (const file of Array.from(input.files)) {
    const fr = new FileReader();
    fr.onload = onLoaded;
    fr.readAsDataURL(file);
  }
}
// ─────────────────────────────────────────────────────────────────
// Image Generation
// ─────────────────────────────────────────────────────────────────
/**
 * Text-to-image: send the `ig-*` form fields to /v1/images/generations
 * and render every returned image into `ig-out`.
 *
 * Does nothing when no model is selected. Progress and errors are written
 * to `ig-prog`. A response without any usable image is reported as an
 * error rather than rendering a download link that points nowhere.
 */
async function genImage() {
  if (!activeModel) return;
  $('ig-prog').textContent='Generating…';
  try {
    const d = await post('/v1/images/generations', {
      model:activeModel.id, prompt:val('ig-prompt'),
      size:val('ig-w')+'x'+val('ig-h'),
      steps:ival('ig-steps'), guidance_scale:fval('ig-cfg'),
      n:ival('ig-n')||1,
      // Optional fields are only sent when the user filled them in.
      ...(val('ig-seed') ? {seed:ival('ig-seed')} : {}),
      ...(val('ig-neg') ? {negative_prompt:val('ig-neg')} : {}),
      disable_safety_checker: chk('ig-nosafe'),
      response_format:'url',
    });
    const imgs = (d.data || []).map(imgSrc).filter(Boolean);
    if (!imgs.length) throw new Error('No images returned');

    // Build the nodes directly so response values are never parsed as HTML.
    const wrap = document.createElement('div');
    wrap.className = 'gen-out-inner';
    imgs.forEach(s => {
      const img = document.createElement('img');
      img.className = 'out-img';
      img.src = s;
      img.addEventListener('click', () => window.open(img.src));
      wrap.appendChild(img);
    });
    const dl = document.createElement('a');
    dl.href = imgs[0];
    dl.setAttribute('download', '');
    dl.className = 'btn btn-ghost btn-sm dl';
    dl.textContent = 'Download first';
    wrap.appendChild(dl);

    const out = $('ig-out');
    out.innerHTML = '';
    out.appendChild(wrap);
    $('ig-prog').textContent='Done ✓';
  } catch(e) { $('ig-prog').textContent='Error: '+e.message; }
}
// ─────────────────────────────────────────────────────────────────
// Image Edit (i2i)
// ─────────────────────────────────────────────────────────────────
/**
 * Image-to-image edit: send the `ie-*` form fields and the chosen source
 * image to /v1/images/edits and show the first returned image in `ie-out`.
 *
 * Does nothing when no model is selected. Progress and errors are written
 * to `ie-prog`. Reading the source file happens inside the try block so a
 * failed read is reported there instead of becoming an unhandled rejection.
 */
async function genEdit() {
  if (!activeModel) return;
  try {
    const img = await b64OrNull('ie-src');
    if (!img) { $('ie-prog').textContent='Select a source image.'; return; }
    $('ie-prog').textContent='Editing…';
    const d = await post('/v1/images/edits', {
      model:activeModel.id, prompt:val('ie-prompt'), image:img,
      strength:fval('ie-str'), steps:ival('ie-steps'), guidance_scale:fval('ie-cfg'),
      // Seed is only sent when the user entered one.
      ...(val('ie-seed') ? {seed:ival('ie-seed')} : {}), response_format:'url',
    });
    showImg('ie-out', imgSrc(d.data[0]), 'ie-prog');
  } catch(e) { $('ie-prog').textContent='Error: '+e.message; }
}
// ─────────────────────────────────────────────────────────────────
// Inpaint
// ─────────────────────────────────────────────────────────────────
/**
 * Inpainting: send the `ip-*` form fields, the source image and the mask
 * to /v1/images/inpaint and show the first returned image in `ip-out`.
 *
 * Does nothing when no model is selected. Progress and errors are written
 * to `ip-prog`. Reading the two files happens inside the try block so a
 * failed read is reported there instead of becoming an unhandled rejection.
 */
async function genInpaint() {
  if (!activeModel) return;
  try {
    // Both files are read concurrently.
    const [img, mask] = await Promise.all([b64OrNull('ip-src'), b64OrNull('ip-mask')]);
    if (!img||!mask) { $('ip-prog').textContent='Select source image and mask.'; return; }
    $('ip-prog').textContent='Inpainting…';
    const d = await post('/v1/images/inpaint', {
      model:activeModel.id, prompt:val('ip-prompt'), image:img, mask,
      strength:fval('ip-str'), steps:ival('ip-steps'), guidance_scale:fval('ip-cfg'),
      // Seed is only sent when the user entered one.
      ...(val('ip-seed') ? {seed:ival('ip-seed')} : {}), response_format:'url',
    });
    showImg('ip-out', imgSrc(d.data[0]), 'ip-prog');
  } catch(e) { $('ip-prog').textContent='Error: '+e.message; }
}
// ─────────────────────────────────────────────────────────────────
// Image Upscale
// ─────────────────────────────────────────────────────────────────
/**
 * Image upscaling: send the chosen image and the `iu-scale` factor to
 * /v1/images/upscale and show the first returned image in `iu-out`.
 *
 * Does nothing when no model is selected. Progress and errors are written
 * to `iu-prog`. Reading the file happens inside the try block so a failed
 * read is reported there instead of becoming an unhandled rejection.
 */
async function genImgUpscale() {
  if (!activeModel) return;
  try {
    const img = await b64OrNull('iu-src');
    if (!img) { $('iu-prog').textContent='Select an image.'; return; }
    $('iu-prog').textContent='Upscaling…';
    const d = await post('/v1/images/upscale', {model:activeModel.id, image:img, scale:ival('iu-scale'), response_format:'url'});
    showImg('iu-out', imgSrc(d.data[0]), 'iu-prog');
  } catch(e) { $('iu-prog').textContent='Error: '+e.message; }
}
// ─────────────────────────────────────────────────────────────────
// Depth
// ─────────────────────────────────────────────────────────────────
// Depth estimation: posts the selected image to /v1/images/depth and hands
// the first result to showImg. Progress and errors are written to #id-prog.
async function genDepth() {
  if (!activeModel) return;
  const setStatus = (msg) => { $('id-prog').textContent = msg; };
  const source = await b64OrNull('id-src');
  if (!source) {
    setStatus('Select an image.');
    return;
  }
  setStatus('Estimating depth…');
  try {
    const payload = {
      model: activeModel.id,
      image: source,
      response_format: 'url',
    };
    const reply = await post('/v1/images/depth', payload);
    showImg('id-out', imgSrc(reply.data[0]), 'id-prog');
  } catch (err) {
    setStatus('Error: ' + err.message);
  }
}
// ─────────────────────────────────────────────────────────────────
// Segment
// ─────────────────────────────────────────────────────────────────
// Segmentation: posts the selected image plus optional point/box prompts to
// /v1/images/segment and hands the first result to showImg.
// The #is-pts and #is-boxes fields are free-text JSON; they are parsed
// up front so that a typo is reported in #is-prog instead of surfacing as an
// unhandled promise rejection with the status stuck on "Segmenting…".
async function genSegment() {
  if (!activeModel) return;
  const setStatus = (msg) => { $('is-prog').textContent = msg; };
  const img = await b64OrNull('is-src');
  if (!img) { setStatus('Select an image.'); return; }

  let pts = null;
  let boxes = null;
  try {
    if (val('is-pts')) pts = JSON.parse(val('is-pts'));
    if (val('is-boxes')) boxes = JSON.parse(val('is-boxes'));
  } catch (e) {
    setStatus('Error: points/boxes must be valid JSON (' + e.message + ')');
    return;
  }

  setStatus('Segmenting…');
  try {
    const d = await post('/v1/images/segment', {
      model: activeModel.id, image: img, points: pts, boxes, response_format: 'url',
    });
    showImg('is-out', imgSrc(d.data[0]), 'is-prog');
  } catch (e) {
    setStatus('Error: ' + e.message);
  }
}
// ─────────────────────────────────────────────────────────────────
// Video Generation (t2v / i2v / v2v)
// ─────────────────────────────────────────────────────────────────
// Video generation for the three simple modes:
//   't2v' – text to video, 'i2v' – image to video, 'v2v' – video to video.
// Builds the request body from the mode's form controls (id prefixes vt-,
// vi-, vv-), posts it to /v1/video/generations and hands the first result to
// showVideo. Progress and errors go to the mode's own progress element.
async function genVideo(mode) {
  if (!activeModel) return;
  const progMap = {t2v: 'vt-prog', i2v: 'vi-prog', v2v: 'vv-prog'};
  const outMap = {t2v: 'vt-out', i2v: 'vi-out', v2v: 'vv-out'};
  const prog = progMap[mode], outId = outMap[mode];
  // An unknown mode has no progress/output element to report into.
  if (!prog) return;
  $(prog).textContent = 'Generating… (this may take several minutes)';

  const body = {model: activeModel.id, mode};
  if (mode === 't2v') {
    body.prompt = val('vt-prompt'); body.negative_prompt = val('vt-neg');
    body.width = ival('vt-w') || 512; body.height = ival('vt-h') || 512;
    body.num_frames = ival('vt-frames') || 16; body.fps = ival('vt-fps') || 8;
    body.num_inference_steps = ival('vt-steps') || 25; body.guidance_scale = fval('vt-cfg') || 7.5;
    if (val('vt-seed')) body.seed = ival('vt-seed');
  } else if (mode === 'i2v') {
    const img = await b64OrNull('vi-src');
    if (!img) { $(prog).textContent = 'Select a source image.'; return; }
    body.init_image = img; body.prompt = val('vi-prompt') || 'animate this image';
    body.num_frames = ival('vi-frames') || 16; body.fps = ival('vi-fps') || 8;
    body.num_inference_steps = ival('vi-steps') || 25; body.guidance_scale = fval('vi-cfg') || 7.5;
    if (val('vi-seed')) body.seed = ival('vi-seed');
  } else {
    const vid = await b64OrNull('vv-src');
    if (!vid) { $(prog).textContent = 'Select a source video.'; return; }
    body.video = vid; body.prompt = val('vv-prompt'); body.strength = fval('vv-str');
    body.num_inference_steps = ival('vv-steps') || 25; body.guidance_scale = fval('vv-cfg') || 7.5;
    if (val('vv-seed')) body.seed = ival('vv-seed');
  }
  // All three modes read the safety toggle from the T2V panel checkbox.
  // NOTE(review): i2v/v2v may have been meant to use their own checkbox —
  // confirm whether vi-/vv- variants exist in the markup.
  body.disable_safety_checker = chk('vt-nosafe');

  try {
    const d = await post('/v1/video/generations', body);
    showVideo(outId, vidSrc(d.data[0]), prog);
  } catch (e) {
    $(prog).textContent = 'Error: ' + e.message;
  }
}
// ─────────────────────────────────────────────────────────────────
// Ti2V (text + init image + end image)
// ─────────────────────────────────────────────────────────────────
// Ti2V: text + initial image (+ optional end image) to video, with the
// optional character, audio, subtitle and post-processing settings from the
// ti-* controls. Posts to /v1/video/generations and hands the first result
// to showVideo. Progress and errors are written to #ti-prog.
async function genTi2V() {
  if (!activeModel) return;
  const report = (msg) => { $('ti-prog').textContent = msg; };
  report('Generating… (may take several minutes)');

  const frames = await Promise.all([b64OrNull('ti-init'), b64OrNull('ti-end')]);
  const initImg = frames[0];
  const endImg = frames[1];
  if (!initImg) {
    report('Select an initial image.');
    return;
  }

  // An end image selects interpolation; otherwise a prompt selects ti2v and
  // no prompt falls back to plain i2v.
  let mode;
  if (endImg) mode = 'interp';
  else if (val('ti-prompt')) mode = 'ti2v';
  else mode = 'i2v';

  const generation = {
    model: activeModel.id,
    mode: mode,
    prompt: val('ti-prompt'),
    negative_prompt: val('ti-neg'),
    init_image: initImg,
    end_image: endImg || undefined,
    width: ival('ti-w') || 512,
    height: ival('ti-h') || 512,
    num_frames: ival('ti-frames') || 16,
    fps: ival('ti-fps') || 8,
    num_inference_steps: ival('ti-steps') || 25,
    guidance_scale: fval('ti-cfg') || 7.5,
    camera_motion: val('ti-cam') || undefined,
  };
  // The seed is only sent when the field has a value.
  const seed = val('ti-seed') ? {seed: ival('ti-seed')} : {};
  const character = {
    character_references: charRefs['ti']?.length ? charRefs['ti'] : undefined,
    character_strength: fval('ti-cstr') || 0.8,
  };
  const audio = {
    add_audio: !!val('ti-atype'),
    audio_type: val('ti-atype') || undefined,
    audio_prompt: val('ti-aprompt') || undefined,
    tts_text: val('ti-ttstext') || undefined,
    tts_voice: val('ti-voice') || undefined,
    tts_speed: fval('ti-speed') || 1.0,
    lip_sync: chk('ti-lipsync'),
  };
  const subtitles = {
    generate_subtitles: chk('ti-gensub'),
    burn_subtitles: chk('ti-burnsub'),
    translate_subtitles: chk('ti-transsub'),
    subtitle_target_lang: val('ti-sublang') || undefined,
  };
  const postProcessing = {
    upscale_output: chk('ti-upscale'),
    upscale_factor: ival('ti-upfact') || 2,
    interpolate_output: chk('ti-interp'),
    fps_multiplier: ival('ti-fmult') || 2,
  };
  const body = {
    ...generation,
    ...seed,
    ...character,
    ...audio,
    ...subtitles,
    ...postProcessing,
    disable_safety_checker: chk('ti-nosafe'),
  };

  try {
    const reply = await post('/v1/video/generations', body);
    showVideo('ti-out', vidSrc(reply.data[0]), 'ti-prog');
  } catch (err) {
    report('Error: ' + err.message);
  }
}
// ─────────────────────────────────────────────────────────────────
// Frame Interpolation
// ─────────────────────────────────────────────────────────────────
// Frame interpolation: accepts either a source video or a pair of start/end
// frame images, posts to /v1/video/interpolate and hands the first result to
// showVideo. A video takes precedence when both kinds of input are given.
// Progress and errors are written to #vc-prog.
async function genInterp() {
  if (!activeModel) return;
  const report = (msg) => { $('vc-prog').textContent = msg; };
  report('Interpolating…');

  const inputs = await Promise.all([
    b64OrNull('vc-src'),
    b64OrNull('vc-init'),
    b64OrNull('vc-end'),
  ]);
  const video = inputs[0];
  const firstFrame = inputs[1];
  const lastFrame = inputs[2];

  const body = {model: activeModel.id, fps_multiplier: ival('vc-mult') || 2};
  const hasFramePair = Boolean(firstFrame && lastFrame);
  if (!video && !hasFramePair) {
    report('Provide a video or both frame images.');
    return;
  }
  if (video) {
    body.video = video;
  } else {
    body.init_image = firstFrame;
    body.end_image = lastFrame;
  }

  try {
    const reply = await post('/v1/video/interpolate', body);
    showVideo('vc-out', vidSrc(reply.data[0]), 'vc-prog');
  } catch (err) {
    report('Error: ' + err.message);
  }
}
// ─────────────────────────────────────────────────────────────────
// Subtitles
// ─────────────────────────────────────────────────────────────────
// Subtitles: posts the selected video to /v1/video/subtitle. When the first
// result carries a video (url / b64_mp4) it is handed to showVideo; when it
// carries subtitle text, the text is shown in a <pre> with a Copy button.
// The text comes from the server, so it is inserted with textContent and the
// copy handler is attached in script rather than being interpolated into
// innerHTML / an inline onclick attribute (which broke on any double quote).
// Progress and errors are written to #vs-prog.
async function genSubtitles() {
  if (!activeModel) return;
  const vid = await b64OrNull('vs-src');
  if (!vid) { $('vs-prog').textContent = 'Select a video.'; return; }
  $('vs-prog').textContent = 'Generating subtitles…';
  const burn = val('vs-fmt') === 'burned_video';
  const body = {
    model: activeModel.id, video: vid,
    language: val('vs-lang') || undefined,
    burn, format: val('vs-fmt'),
    translate: chk('vs-translate'),
    target_lang: val('vs-tlang') || undefined,
  };
  try {
    const d = await post('/v1/video/subtitle', body);
    const item = d.data[0];
    if (item.url || item.b64_mp4) {
      showVideo('vs-out', vidSrc(item), 'vs-prog');
    } else if (item.text) {
      const wrap = document.createElement('div');
      wrap.className = 'gen-out-inner';
      wrap.style.cssText = 'width:100%;text-align:left';

      const pre = document.createElement('pre');
      pre.style.cssText = 'white-space:pre-wrap;font-size:12px;background:var(--surface-2);padding:.75rem;border-radius:6px;width:100%;box-sizing:border-box';
      pre.textContent = item.text;

      const copyBtn = document.createElement('button');
      copyBtn.className = 'btn btn-ghost btn-sm';
      copyBtn.textContent = 'Copy';
      copyBtn.addEventListener('click', () => navigator.clipboard.writeText(item.text));

      wrap.appendChild(pre);
      wrap.appendChild(copyBtn);
      const out = $('vs-out');
      out.innerHTML = '';
      out.appendChild(wrap);
      $('vs-prog').textContent = 'Done ✓';
    } else {
      // Neither a video nor text came back; do not leave the spinner text up.
      $('vs-prog').textContent = 'Error: no subtitle output returned.';
    }
  } catch (e) {
    $('vs-prog').textContent = 'Error: ' + e.message;
  }
}
// Chat input: plain Enter sends the message; Shift+Enter is left to the
// element's default behaviour.
document.getElementById('chat-in').addEventListener('keydown', e => {
  if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); send(); }
});
// Chat input: resize the element to fit its content, capped at 140px.
document.getElementById('chat-in').addEventListener('input', function() {
  this.style.height = 'auto';
  this.style.height = Math.min(this.scrollHeight, 140) + 'px';
});
// ─────────────────────────────────────────────────────────────────
// Dub
// ─────────────────────────────────────────────────────────────────
// Posts the selected video with source/target language and the
// burn-subtitles flag to /v1/video/dub and hands the first result to
// showVideo. Progress and errors are written to #vd-prog.
async function genDub() {
  if (!activeModel) return;
  const vid = await b64OrNull('vd-src');
  if (!vid) { $('vd-prog').textContent = 'Select a video.'; return; }
  $('vd-prog').textContent = 'Dubbing…';
  try {
    const d = await post('/v1/video/dub', {
      model: activeModel.id, video: vid,
      source_lang: val('vd-slang') || undefined, target_lang: val('vd-tlang'),
      burn_subtitles: chk('vd-burn'),
    });
    showVideo('vd-out', vidSrc(d.data[0]), 'vd-prog');
  } catch (e) {
    $('vd-prog').textContent = 'Error: ' + e.message;
  }
}
// ─────────────────────────────────────────────────────────────────
// Video upscale
// ─────────────────────────────────────────────────────────────────
// Video upscale: posts the selected video and the upscale factor (default 2)
// to /v1/video/upscale and hands the first result to showVideo.
// Progress and errors are written to #vu-prog.
async function genVidUpscale() {
  if (!activeModel) return;
  const report = (msg) => { $('vu-prog').textContent = msg; };
  const source = await b64OrNull('vu-src');
  if (!source) {
    report('Select a video.');
    return;
  }
  report('Upscaling…');
  try {
    const payload = {
      model: activeModel.id,
      video: source,
      upscale_factor: ival('vu-scale') || 2,
    };
    const reply = await post('/v1/video/upscale', payload);
    showVideo('vu-out', vidSrc(reply.data[0]), 'vu-prog');
  } catch (err) {
    report('Error: ' + err.message);
  }
}
// ─────────────────────────────────────────────────────────────────
// T2V controls note: the shared T2V form is written inline because a Jinja
// include isn't available here.
// ─────────────────────────────────────────────────────────────────
// The T2V panel (vid-t2v) controls are injected by the script near the end
// of this file. Their ids use the "vt-" prefix: vt-prompt, vt-neg, vt-w, etc.
// ─────────────────────────────────────────────────────────────────
// Audio Gen
// ─────────────────────────────────────────────────────────────────
// Audio generation: posts the prompt, duration, temperature and the optional
// seed / melody reference to /v1/audio/generate, then hands the first result
// to showAudio as a wav. Progress and errors are written to #ag-prog.
async function genAudio() {
  if (!activeModel) return;
  const report = (msg) => { $('ag-prog').textContent = msg; };
  report('Generating audio…');

  const melody = await b64OrNull('ag-melody');
  const body = {
    model: activeModel.id,
    prompt: val('ag-prompt'),
    duration: fval('ag-dur') || 10,
    temperature: fval('ag-temp') || 1.0,
  };
  // Optional fields are only present when the user supplied them.
  if (val('ag-seed')) body.seed = ival('ag-seed');
  if (melody) body.melody = melody;
  body.response_format = 'url';

  try {
    const reply = await post('/v1/audio/generate', body);
    showAudio('ag-out', audSrc(reply.data[0]), 'ag-prog', 'wav');
  } catch (err) {
    report('Error: ' + err.message);
  }
}
// ─────────────────────────────────────────────────────────────────
// TTS
// ─────────────────────────────────────────────────────────────────
// Text-to-speech: posts the text, speed and optional voice to
// /v1/audio/speech, reads the binary response as a Blob and hands an object
// URL for it to showAudio as an mp3. This endpoint is called with fetch
// directly because the response body is audio, not JSON.
// Progress and errors are written to #at-prog.
async function genTTS() {
  if (!activeModel) return;
  const report = (msg) => { $('at-prog').textContent = msg; };
  report('Synthesizing…');
  try {
    const payload = {
      model: activeModel.id,
      input: val('at-text'),
      speed: fval('at-speed') || 1.0,
      voice: val('at-voice') || undefined,
      response_format: 'mp3',
    };
    const response = await fetch('/v1/audio/speech', {
      method: 'POST',
      headers: {'Content-Type': 'application/json'},
      body: JSON.stringify(payload),
    });
    if (!response.ok) {
      const detail = await response.text();
      throw new Error(detail);
    }
    const audioBlob = await response.blob();
    showAudio('at-out', URL.createObjectURL(audioBlob), 'at-prog', 'mp3');
  } catch (err) {
    report('Error: ' + err.message);
  }
}
// ─────────────────────────────────────────────────────────────────
// STT
// ─────────────────────────────────────────────────────────────────
// Speech-to-text: uploads the selected audio file as multipart form data to
// /v1/audio/transcriptions and shows the transcript (or, when the response
// has no text, the pretty-printed JSON) with a Copy button.
// The transcript comes from the server, so it is inserted with textContent
// and the copy handler is attached in script rather than being interpolated
// into innerHTML / an inline onclick attribute (which broke on any double
// quote in the text). Progress and errors are written to #as-prog.
async function genSTT() {
  if (!activeModel) return;
  const f = fileOrNull('as-file');
  if (!f) { $('as-prog').textContent = 'Select an audio file.'; return; }
  $('as-prog').textContent = 'Transcribing…';
  const fd = new FormData();
  fd.append('file', f); fd.append('model', activeModel.id);
  if (val('as-lang')) fd.append('language', val('as-lang'));
  if (val('as-prompt')) fd.append('prompt', val('as-prompt'));
  try {
    const d = await postForm('/v1/audio/transcriptions', fd);

    const wrap = document.createElement('div');
    wrap.className = 'gen-out-inner';
    wrap.style.cssText = 'width:100%;text-align:left';

    const pre = document.createElement('pre');
    pre.style.cssText = 'white-space:pre-wrap;font-size:13px;line-height:1.6;background:var(--surface-2);padding:.75rem;border-radius:6px;width:100%;box-sizing:border-box';
    pre.textContent = d.text || JSON.stringify(d, null, 2);

    const copyBtn = document.createElement('button');
    copyBtn.className = 'btn btn-ghost btn-sm';
    copyBtn.textContent = 'Copy';
    copyBtn.addEventListener('click', () => navigator.clipboard.writeText(d.text || ''));

    wrap.appendChild(pre);
    wrap.appendChild(copyBtn);
    const out = $('as-out');
    out.innerHTML = '';
    out.appendChild(wrap);
    $('as-prog').textContent = 'Done ✓';
  } catch (e) {
    $('as-prog').textContent = 'Error: ' + e.message;
  }
}
// ─────────────────────────────────────────────────────────────────
// Embeddings
// ─────────────────────────────────────────────────────────────────
// Embeddings: each non-blank line of #em-text is one input. Posts to
// /v1/embeddings and renders a short preview per vector: the first 8
// components and the dimension for float arrays, or the first 60 characters
// for string-encoded embeddings. The label is the user's own text, so it is
// HTML-escaped before being placed into innerHTML.
// Progress and errors are written to #em-prog.
async function genEmbeddings() {
  if (!activeModel) return;
  const HTML_ESCAPES = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'};
  const esc = (s) => String(s).replace(/[&<>"']/g, (c) => HTML_ESCAPES[c]);

  $('em-prog').textContent = 'Embedding…';
  const lines = val('em-text').split('\n').filter(l => l.trim());
  if (!lines.length) { $('em-prog').textContent = 'Enter some text.'; return; }
  // A single line is sent as a plain string, several lines as an array.
  const input = lines.length === 1 ? lines[0] : lines;
  try {
    const d = await post('/v1/embeddings', {
      model: activeModel.id, input,
      encoding_format: val('em-enc'),
      ...(val('em-dims') ? {dimensions: ival('em-dims')} : {}),
    });
    const preview = d.data.map((e, i) => {
      const vec = Array.isArray(e.embedding)
        ? '[' + e.embedding.slice(0, 8).map(v => v.toFixed(4)).join(', ') + ', …] dim=' + e.embedding.length
        : e.embedding.substring(0, 60) + '…';
      // Fall back to the index when there is no matching input line.
      const label = lines[i]?.substring(0, 40) || i;
      return `<div style="margin-bottom:.4rem"><strong>${esc(label)}</strong><br><code style="font-size:11px">${esc(vec)}</code></div>`;
    }).join('');
    $('em-out').innerHTML = `<div class="gen-out-inner" style="width:100%;text-align:left">${preview}</div>`;
    $('em-prog').textContent = 'Done ✓';
  } catch (e) {
    $('em-prog').textContent = 'Error: ' + e.message;
  }
}
// ─────────────────────────────────────────────────────────────────
// Pipeline 1: Image → Video
// ─────────────────────────────────────────────────────────────────
// Pipeline 1 (image → video): generates an image from the prompt with the
// image model, then animates that image with the video model in i2v mode.
// Requires the prompt and both model ids. Progress and errors are written
// to #pp1-prog.
async function runPipeline1() {
  const report = (msg) => { $('pp1-prog').textContent = msg; };
  const prompt = val('pp1-prompt');
  const imodel = val('pp1-imodel');
  const vmodel = val('pp1-vmodel');
  const ready = prompt && imodel && vmodel;
  if (!ready) {
    report('Fill in prompt + both model IDs.');
    return;
  }

  report('Step 1/2: Generating image…');
  try {
    const imageReply = await post('/v1/images/generations', {
      model: imodel,
      prompt: prompt,
      response_format: 'url',
    });
    const stillSrc = imgSrc(imageReply.data[0]);

    report('Step 2/2: Animating image…');
    const videoReply = await post('/v1/video/generations', {
      model: vmodel,
      mode: 'i2v',
      init_image: stillSrc,
      prompt: prompt,
      num_frames: ival('pp1-frames') || 16,
      fps: ival('pp1-fps') || 8,
    });

    // Show the intermediate still, then pass the same output element to
    // showVideo for the final clip.
    $('pp1-out').innerHTML = `<img class="out-img" src="${stillSrc}" style="max-width:200px;border-radius:6px">`;
    showVideo('pp1-out', vidSrc(videoReply.data[0]), 'pp1-prog');
  } catch (err) {
    report('Error: ' + err.message);
  }
}
// ─────────────────────────────────────────────────────────────────
// Pipeline 2: Video → Dub + Subtitle
// ─────────────────────────────────────────────────────────────────
// Pipeline 2 (video → dub + subtitles): posts the selected video to
// /v1/video/dub with burned-in subtitles always enabled and hands the first
// result to showVideo. The model comes from #pp2-model, falling back to the
// active model; the request is not sent when neither is available.
// Progress and errors are written to #pp2-prog.
async function runPipeline2() {
  const f = fileOrNull('pp2-src');
  if (!f) { $('pp2-prog').textContent = 'Select a video.'; return; }
  const model = val('pp2-model') || (activeModel ? activeModel.id : '');
  // Without this guard an empty model id would be sent to the server.
  if (!model) { $('pp2-prog').textContent = 'Enter a model ID or select a model.'; return; }
  $('pp2-prog').textContent = 'Dubbing + subtitling…';
  try {
    const vid = await fileToB64(f);
    const d = await post('/v1/video/dub', {
      model, video: vid,
      source_lang: val('pp2-slang') || undefined,
      target_lang: val('pp2-tlang'), burn_subtitles: true,
    });
    showVideo('pp2-out', vidSrc(d.data[0]), 'pp2-prog');
  } catch (e) {
    $('pp2-prog').textContent = 'Error: ' + e.message;
  }
}
// ─────────────────────────────────────────────────────────────────
// Pipeline 3: Full Story
// ─────────────────────────────────────────────────────────────────
// Pipeline 3 (full story): asks the text model for a 3-scene script, makes
// one image per scene (at most three, generated one after another), then
// animates only the first scene's image with the video model.
// If the reply contains no "SCENE n: …" lines, the whole story text is used
// as a single scene. Progress and errors are written to #pp3-prog.
async function runPipeline3() {
  const report = (msg) => { $('pp3-prog').textContent = msg; };
  const stripLabel = (line) => line.replace(/SCENE \d+: /, '');

  const story = val('pp3-story');
  const tmodel = val('pp3-tmodel');
  const imodel = val('pp3-imodel');
  const vmodel = val('pp3-vmodel');
  const ready = story && tmodel && imodel && vmodel;
  if (!ready) {
    report('Fill in story and all model IDs.');
    return;
  }

  report('Step 1: Generating script with LLM…');
  try {
    // 1. Script: the LLM writes the scene descriptions.
    const chatReply = await post('/v1/chat/completions', {
      model: tmodel,
      messages: [{role: 'user', content: `Write a 3-scene visual script for this story. For each scene write: SCENE X: [brief visual description]. Story: ${story}`}],
      stream: false,
    });
    const scriptText = chatReply.choices[0].message.content;
    const scenes = scriptText.match(/SCENE \d+: [^\n]+/g) || [story];
    report(`Script: ${scenes.length} scenes. Generating images…`);

    // 2. Stills: one image per scene, requested sequentially.
    const stills = [];
    for (const sceneLine of scenes.slice(0, 3)) {
      const imageReply = await post('/v1/images/generations', {
        model: imodel,
        prompt: stripLabel(sceneLine),
        response_format: 'url',
      });
      stills.push(imgSrc(imageReply.data[0]));
    }

    // 3. Video: only the first still is animated.
    report('Animating first scene…');
    const videoReply = await post('/v1/video/generations', {
      model: vmodel,
      mode: 'i2v',
      init_image: stills[0],
      prompt: scenes[0]?.replace(/SCENE \d+: /, ''),
      num_frames: 16,
      fps: 8,
    });

    const thumbs = stills
      .map((src) => `<img src="${src}" style="max-width:120px;border-radius:4px">`)
      .join('');
    $('pp3-out').innerHTML = `<div>${thumbs}</div>`;
    showVideo('pp3-out', vidSrc(videoReply.data[0]), 'pp3-prog');
    report('Done ✓ (tip: animate remaining scenes manually)');
  } catch (err) {
    report('Error: ' + err.message);
  }
}
// Inject the text-to-video form into #panel-vid-t2v at load time. The markup
// is written inline here because Jinja fragment includes aren't available.
// The control ids (vt-prompt, vt-neg, vt-w, vt-h, vt-frames, vt-fps,
// vt-steps, vt-cfg, vt-seed, vt-nosafe, vt-prog, vt-out) are the ones
// genVideo('t2v') reads from and writes to, so they must stay in sync.
document.getElementById('panel-vid-t2v').innerHTML = `<div class="gen-wrap">
<div class="gen-ctrl">
<div class="frow"><label class="fl">Prompt</label><textarea id="vt-prompt" class="fs" rows="4" placeholder="Describe the video…"></textarea></div>
<div class="frow"><label class="fl">Negative prompt</label><textarea id="vt-neg" class="fs" rows="2" placeholder="Things to avoid…"></textarea></div>
<div class="g2">
<div class="frow"><label class="fl">Width</label><input type="number" id="vt-w" class="fi" value="512" step="64"></div>
<div class="frow"><label class="fl">Height</label><input type="number" id="vt-h" class="fi" value="512" step="64"></div>
</div>
<div class="g2">
<div class="frow"><label class="fl">Frames</label><input type="number" id="vt-frames" class="fi" value="16"></div>
<div class="frow"><label class="fl">FPS</label><input type="number" id="vt-fps" class="fi" value="8"></div>
</div>
<div class="g2">
<div class="frow"><label class="fl">Steps</label><input type="number" id="vt-steps" class="fi" value="25"></div>
<div class="frow"><label class="fl">Guidance</label><input type="number" id="vt-cfg" class="fi" value="7.5" step="0.5"></div>
</div>
<div class="frow"><label class="fl">Seed</label><input type="number" id="vt-seed" class="fi" placeholder="random"></div>
<div class="frow"><label class="fl" title="Removes the model's built-in safety checker. Only use with uncensored/NSFW fine-tunes.">Disable safety filter</label><input type="checkbox" id="vt-nosafe"></div>
<button class="btn btn-primary" onclick="genVideo('t2v')">Generate Video</button>
<div class="progress" id="vt-prog"></div>
</div>
<div class="gen-out" id="vt-out"><div class="gen-empty">Video will appear here</div></div>
</div>`;
loadModels(); loadModels();
</script> </script>
......
...@@ -127,6 +127,8 @@ ...@@ -127,6 +127,8 @@
<option value="image-to-text">Image-to-text</option> <option value="image-to-text">Image-to-text</option>
<option value="automatic-speech-recognition">Speech recog.</option> <option value="automatic-speech-recognition">Speech recog.</option>
<option value="text-to-speech">TTS</option> <option value="text-to-speech">TTS</option>
<option value="text-to-video">Text-to-video</option>
<option value="image-to-video">Image-to-video</option>
<option value="feature-extraction">Embeddings</option> <option value="feature-extraction">Embeddings</option>
</select> </select>
</div> </div>
...@@ -296,9 +298,12 @@ ...@@ -296,9 +298,12 @@
<select id="cfg-type" class="form-input"> <select id="cfg-type" class="form-input">
<option value="text_models">Text (LLM)</option> <option value="text_models">Text (LLM)</option>
<option value="image_models">Image generation</option> <option value="image_models">Image generation</option>
<option value="audio_models">Audio</option> <option value="video_models">Video generation</option>
<option value="tts_models">TTS</option> <option value="audio_models">Audio transcription (STT)</option>
<option value="vision_models">Vision</option> <option value="tts_models">Text-to-speech (TTS)</option>
<option value="vision_models">Vision / VLM</option>
<option value="audio_gen_models">Audio generation (Music/SFX)</option>
<option value="embedding_models">Embeddings</option>
</select> </select>
</div> </div>
<div class="form-row" style="margin:0"> <div class="form-row" style="margin:0">
......
...@@ -70,6 +70,9 @@ from codai.api.transcriptions import router as transcriptions_router ...@@ -70,6 +70,9 @@ from codai.api.transcriptions import router as transcriptions_router
from codai.api.images import router as images_router from codai.api.images import router as images_router
from codai.api.tts import router as tts_router from codai.api.tts import router as tts_router
from codai.api.text import router as text_router from codai.api.text import router as text_router
from codai.api.video import router as video_router
from codai.api.audio_gen import router as audio_gen_router
from codai.api.embeddings import router as embeddings_router
from codai.admin.routes import router as admin_router from codai.admin.routes import router as admin_router
# Import and add middleware # Import and add middleware
...@@ -88,6 +91,9 @@ app.include_router(transcriptions_router) ...@@ -88,6 +91,9 @@ app.include_router(transcriptions_router)
app.include_router(images_router) app.include_router(images_router)
app.include_router(tts_router) app.include_router(tts_router)
app.include_router(text_router) app.include_router(text_router)
app.include_router(video_router)
app.include_router(audio_gen_router)
app.include_router(embeddings_router)
app.include_router(admin_router) app.include_router(admin_router)
......
"""
Audio generation endpoints for the codai API.
Supports music, sound effects, and ambient audio via MusicGen, AudioLDM2, StableAudio, etc.
POST /v1/audio/generate
"""
import asyncio
import base64
import io
import os
import time
import uuid
from fastapi import APIRouter, HTTPException, Request
from codai.models.manager import multi_model_manager
from codai.pydantic.audiogenrequest import AudioGenerationRequest, AudioGenerationResponse
router = APIRouter()
global_args = None
global_file_path = None
def set_global_args(args) -> None:
    """Store the server's parsed arguments in this module's ``global_args``.

    The value is read later by the device and URL helpers in this module.
    """
    global global_args
    global_args = args
def set_global_file_path(path) -> None:
    """Store the output directory in this module's ``global_file_path``.

    When it is set, generated audio is written there and returned as a URL;
    when it is falsy, audio is returned inline as base64.
    """
    global global_file_path
    global_file_path = path
def _derive_device() -> str:
if global_args:
d = getattr(global_args, 'vulkan_device', None)
if d is not None:
return f"cuda:{d}"
return "cuda:0"
def _save_audio_response(audio_data: bytes, ext: str, http_request: Request) -> dict:
filename = f"{uuid.uuid4().hex}.{ext}"
if global_file_path:
os.makedirs(global_file_path, exist_ok=True)
fpath = os.path.join(global_file_path, filename)
with open(fpath, 'wb') as f:
f.write(audio_data)
url_setting = getattr(global_args, 'url', 'auto') if global_args else 'auto'
if url_setting == 'auto':
host = http_request.headers.get('host', '127.0.0.1') if http_request else '127.0.0.1'
if ':' in host:
parts = host.split(':')
if len(parts) == 2 and parts[1].isdigit():
host = parts[0]
use_https = getattr(global_args, 'https', False) if global_args else False
proto = 'https' if use_https else 'http'
port = getattr(global_args, 'port', 8000) if global_args else 8000
base_url = f"{proto}://{host}:{port}"
else:
base_url = url_setting.rstrip('/')
return {"url": f"{base_url}/v1/files/{filename}"}
else:
b64 = base64.b64encode(audio_data).decode()
return {f"b64_{ext}": b64}
def _load_musicgen(model_name: str, device: str):
    """Load an audiocraft MusicGen or AudioGen model onto ``device``.

    AudioGen is chosen when the model name contains "audiogen"
    (case-insensitive); everything else is loaded as MusicGen. The model is
    given a default generation length of 30 seconds; per-request parameters
    are applied again at generation time.

    Args:
        model_name: Pretrained model identifier passed to ``get_pretrained``.
        device: Torch device string, e.g. ``"cuda:0"``.
    """
    # Imported lazily: audiocraft is an optional dependency.
    from audiocraft.models import AudioGen, MusicGen

    model_cls = AudioGen if 'audiogen' in model_name.lower() else MusicGen
    # Pass the device through; previously it was accepted but ignored, so the
    # model always landed on audiocraft's default device.
    model = model_cls.get_pretrained(model_name, device=device)
    model.set_generation_params(duration=30)
    return model
def _load_audioldm(model_name: str, device: str):
    """Load a diffusers ``AudioLDM2Pipeline`` in float16 and move it to ``device``."""
    # Imported lazily so the module can be imported without torch/diffusers.
    import torch
    from diffusers import AudioLDM2Pipeline

    return AudioLDM2Pipeline.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
    ).to(device)
def _detect_audio_gen_type(model_name: str) -> str:
n = model_name.lower()
if 'audioldm' in n or 'stable-audio' in n:
return 'audioldm'
if 'audiogen' in n:
return 'audiogen'
return 'musicgen'
def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
    """Run generation and return ``(audio_bytes, ext)``.

    The backend is chosen from ``model_name`` via ``_detect_audio_gen_type``:

    * musicgen / audiogen (audiocraft): generation parameters are taken from
      the request. For musicgen, ``request.melody`` (base64, data URI or URL)
      is used as melody conditioning when present. Only the first channel of
      the first generated item is kept.
    * audioldm (diffusers): 50 inference steps; output is written at 16 kHz.

    Returns:
        A tuple of 16-bit PCM WAV file bytes and the extension ``'wav'``.

    Raises:
        ValueError: If the detected model type is not supported.
    """
    import io as _io

    import numpy as np
    import scipy.io.wavfile as wavfile

    model_type = _detect_audio_gen_type(model_name)
    if model_type in ('musicgen', 'audiogen'):
        pipe.set_generation_params(
            duration=request.duration,
            top_k=request.top_k,
            top_p=request.top_p,
            temperature=request.temperature,
            cfg_coef=request.cfg_coef,
        )
        if request.melody and model_type == 'musicgen':
            import torchaudio

            raw = _decode_b64_or_url(request.melody)
            melody_wav, melody_sr = torchaudio.load(_io.BytesIO(raw))
            # Add a leading batch dimension of 1 to the loaded melody.
            wav = pipe.generate_with_chroma(
                [request.prompt], melody_wav.unsqueeze(0), melody_sr)
        else:
            wav = pipe.generate([request.prompt])
        audio_np = wav[0, 0].cpu().numpy()
        sr = pipe.sample_rate
    elif model_type == 'audioldm':
        result = pipe(
            request.prompt,
            num_inference_steps=50,
            audio_length_in_s=request.duration,
        )
        audio_np = result.audios[0]
        sr = 16000
    else:
        # Previously this fell through to an unbound-variable error below.
        raise ValueError(f"Unsupported audio generation model type: {model_type}")

    # Clip before the int16 cast: samples slightly outside [-1, 1] would
    # otherwise wrap around and produce loud clicks.
    samples = np.clip(np.asarray(audio_np, dtype=np.float32), -1.0, 1.0)
    audio_int16 = (samples * 32767).astype(np.int16)
    buf = _io.BytesIO()
    wavfile.write(buf, sr, audio_int16)
    return buf.getvalue(), 'wav'
def _decode_b64_or_url(data: str) -> bytes:
if data.startswith("data:"):
_, enc = data.split(",", 1)
return base64.b64decode(enc)
if data.startswith("http"):
import urllib.request
with urllib.request.urlopen(data, timeout=30) as r:
return r.read()
return base64.b64decode(data)
@router.post("/v1/audio/generate", response_model=AudioGenerationResponse)
async def audio_generate(request: AudioGenerationRequest, http_request: Request = None):
    """
    Generate music, sound effects, or ambient audio.
    Compatible models: MusicGen, AudioGen, AudioLDM2, StableAudio.

    The model is resolved through ``multi_model_manager``; if it has no loaded
    object yet it is loaded here and registered as the current model. Loading
    and generation both run in the default executor so the event loop is not
    blocked.

    Raises:
        HTTPException: 404 if the model cannot be resolved, 500 if loading or
            generation fails.
    """
    model_info = multi_model_manager.request_model(request.model, model_type="audio_gen")
    model_name = model_info.get('model_name')
    if not model_name:
        err = model_info.get('error', f"Model '{request.model}' not found")
        raise HTTPException(status_code=404, detail=err)

    model_key = model_info['model_key']
    pipe = model_info.get('model_object')
    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine.
    loop = asyncio.get_running_loop()

    if pipe is None:
        device = _derive_device()
        if _detect_audio_gen_type(model_name) in ('musicgen', 'audiogen'):
            loader = _load_musicgen
        else:
            loader = _load_audioldm
        try:
            pipe = await loop.run_in_executor(None, loader, model_name, device)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to load audio gen model: {e}") from e
        # Cache the loaded model so later requests reuse it.
        multi_model_manager.models[model_key] = pipe
        multi_model_manager.current_model_key = model_key

    try:
        audio_bytes, ext = await loop.run_in_executor(
            None, _generate_audio, pipe, model_name, request)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Audio generation failed: {e}") from e

    result = _save_audio_response(audio_bytes, ext, http_request)
    return AudioGenerationResponse(created=int(time.time()), data=[result])
"""
Embeddings endpoint — OpenAI-compatible.
POST /v1/embeddings
Supports sentence-transformers, BGE, E5, nomic-embed, etc.
"""
import asyncio
import base64
import time
from typing import List
from fastapi import APIRouter, HTTPException, Request
from codai.models.manager import multi_model_manager
from codai.pydantic.embedrequest import EmbeddingsRequest, EmbeddingsResponse, EmbeddingObject
router = APIRouter()
global_args = None
def set_global_args(args) -> None:
    """Store the server's parsed arguments in this module's ``global_args``.

    The value is read later when choosing the device for embedding models.
    """
    global global_args
    global_args = args
def _derive_device() -> str:
if global_args:
d = getattr(global_args, 'vulkan_device', None)
if d is not None:
return f"cuda:{d}"
return "cuda:0"
def _load_embedding_model(model_name: str, device: str):
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, device=device)
return ('sentence_transformers', model)
except ImportError:
pass
try:
from transformers import AutoTokenizer, AutoModel
import torch
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
return ('transformers', (tokenizer, model, device))
except Exception as e:
raise RuntimeError(f"Cannot load embedding model '{model_name}': {e}")
def _embed_texts(model_obj, texts: List[str], dimensions=None) -> List[List[float]]:
backend, model = model_obj
if backend == 'sentence_transformers':
vecs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
results = [v.tolist() for v in vecs]
else:
import torch
tokenizer, hf_model, device = model
encoded = tokenizer(texts, padding=True, truncation=True,
return_tensors='pt', max_length=512)
encoded = {k: v.to(device) for k, v in encoded.items()}
with torch.no_grad():
out = hf_model(**encoded)
# mean-pool last hidden state
token_embs = out.last_hidden_state
attention = encoded['attention_mask'].unsqueeze(-1).float()
mean_emb = (token_embs * attention).sum(1) / attention.sum(1)
import torch.nn.functional as F
mean_emb = F.normalize(mean_emb, dim=-1)
results = [row.cpu().tolist() for row in mean_emb]
if dimensions:
results = [v[:dimensions] for v in results]
return results
@router.post("/v1/embeddings", response_model=EmbeddingsResponse)
async def create_embeddings(request: EmbeddingsRequest, http_request: Request = None):
    """
    OpenAI-compatible embeddings endpoint.

    The model is resolved through ``multi_model_manager``; if it has no loaded
    object yet it is loaded here and registered as the current model. Loading
    and embedding both run in the default executor so the event loop is not
    blocked.

    With ``encoding_format == 'base64'`` each vector is returned as base64 of
    packed little-endian float32 values; otherwise as a list of floats.
    The reported token counts are a whitespace-split approximation, not
    tokenizer counts.

    Raises:
        HTTPException: 404 if the model cannot be resolved, 500 if loading or
            embedding fails.
    """
    model_info = multi_model_manager.request_model(request.model, model_type="embedding")
    model_name = model_info.get('model_name')
    if not model_name:
        err = model_info.get('error', f"Model '{request.model}' not found")
        raise HTTPException(status_code=404, detail=err)

    model_key = model_info['model_key']
    model_obj = model_info.get('model_object')
    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine.
    loop = asyncio.get_running_loop()

    if model_obj is None:
        device = _derive_device()
        try:
            model_obj = await loop.run_in_executor(
                None, _load_embedding_model, model_name, device)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to load embedding model: {e}") from e
        # Cache the loaded model so later requests reuse it.
        multi_model_manager.models[model_key] = model_obj
        multi_model_manager.current_model_key = model_key

    # Normalise the input to a list of strings.
    texts = [request.input] if isinstance(request.input, str) else request.input
    try:
        vectors = await loop.run_in_executor(
            None, _embed_texts, model_obj, texts, request.dimensions)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {e}") from e

    if request.encoding_format == 'base64':
        import struct
        # '<' pins little-endian float32 so the payload does not depend on
        # the server's native byte order.
        data = [EmbeddingObject(
            index=i,
            embedding=base64.b64encode(struct.pack(f'<{len(v)}f', *v)).decode()
        ) for i, v in enumerate(vectors)]
    else:
        data = [EmbeddingObject(index=i, embedding=v) for i, v in enumerate(vectors)]

    total_tokens = sum(len(t.split()) for t in texts)
    return EmbeddingsResponse(
        data=data,
        model=request.model,
        usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
    )
...@@ -6,10 +6,13 @@ import asyncio ...@@ -6,10 +6,13 @@ import asyncio
import base64 import base64
import io import io
import os import os
import time
import uuid import uuid
from typing import Optional
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from PIL import Image from PIL import Image
from pydantic import BaseModel
# Import from codai modules # Import from codai modules
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
...@@ -205,6 +208,33 @@ def _derive_diffusers_device(global_args) -> str: ...@@ -205,6 +208,33 @@ def _derive_diffusers_device(global_args) -> str:
return "cuda:0" return "cuda:0"
def _disable_safety_checker(pipe):
    """Switch off the safety-related attributes of a diffusers pipeline.

    Only attributes that already exist on ``pipe`` are touched, so the call
    is harmless for pipelines that have none of them.  ``safety_checker``
    and ``feature_extractor`` are replaced by ``None`` when currently set,
    ``safety_concept`` is always reset to ``None`` when present, and
    ``requires_safety_checker`` is forced to ``False``.

    Returns the same ``pipe`` object that was passed in.
    """
    if getattr(pipe, 'safety_checker', None) is not None:
        pipe.safety_checker = None

    if getattr(pipe, 'feature_extractor', None) is not None:
        # Best effort: the attribute may refuse assignment.
        try:
            pipe.feature_extractor = None
        except Exception:
            pass

    if hasattr(pipe, 'safety_concept'):
        pipe.safety_concept = None

    if hasattr(pipe, 'requires_safety_checker'):
        # Best effort for the same reason as above.
        try:
            pipe.requires_safety_checker = False
        except Exception:
            pass

    return pipe
def _load_diffusers_pipeline(model_name: str, global_args): def _load_diffusers_pipeline(model_name: str, global_args):
""" """
Try to load a model using the diffusers library. Try to load a model using the diffusers library.
...@@ -343,7 +373,10 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None): ...@@ -343,7 +373,10 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
import torch import torch
import numpy as np import numpy as np
import time as time_module import time as time_module
if getattr(request, 'disable_safety_checker', False):
_disable_safety_checker(pipeline)
# Determine size # Determine size
width, height = 1024, 1024 width, height = 1024, 1024
if request.size: if request.size:
...@@ -714,3 +747,518 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -714,3 +747,518 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
status_code=400, status_code=400,
detail=f"Failed to load image model '{model_name}'. Errors: {'; '.join(error_details) if error_details else 'No compatible backend found'}" detail=f"Failed to load image model '{model_name}'. Errors: {'; '.join(error_details) if error_details else 'No compatible backend found'}"
) )
# =============================================================================
# Image-to-Image Endpoint (POST /v1/images/edits)
# OpenAI-compatible: accepts image + prompt, returns edited image
# =============================================================================
# Request body for POST /v1/images/edits (image-to-image editing).
# Fields the model does not declare are still accepted (Config.extra = "allow").
class ImageEditRequest(BaseModel):
    model: str  # resolved through multi_model_manager.request_model(..., model_type="image")
    prompt: str  # text prompt forwarded to the pipeline call
    image: str # base64-encoded PNG or "data:image/...;base64,..."
    mask: Optional[str] = None # optional inpaint mask (base64 PNG); NOTE(review): create_image_edit never reads this field
    n: int = 1  # forwarded as num_images_per_prompt
    size: Optional[str] = "1024x1024"  # "<width>x<height>"; the decoded source image is resized to it
    response_format: Optional[str] = "url"  # passed through to save_image_response
    strength: Optional[float] = 0.75 # denoising strength (0=keep original, 1=ignore)
    steps: Optional[int] = None  # when unset the endpoint uses 30 ("standard" quality) or 50
    guidance_scale: Optional[float] = None  # when unset the endpoint picks a quality-dependent default
    seed: Optional[int] = None  # when unset the endpoint falls back to global_args.image_seed
    quality: Optional[str] = "standard"
    user: Optional[str] = None  # not read by create_image_edit
    class Config:
        extra = "allow"
def _decode_b64_image(data: str):
    """Turn a base64 string (bare, or a ``data:`` URI) into an RGB PIL image."""
    from PIL import Image as PILImage

    payload = data
    if data.startswith("data:"):
        # Only the part after the first comma of a data URI is base64.
        _, payload = data.split(",", 1)
    image_bytes = base64.b64decode(payload)
    return PILImage.open(io.BytesIO(image_bytes)).convert("RGB")
def _load_img2img_pipeline(model_name: str, global_args):
    """Load a diffusers img2img pipeline, degrading gracefully on CUDA OOM.

    The pipeline class is chosen from substrings of ``model_name``; the dtype
    comes from ``global_args.image_precision`` (default ``bf16``).

    Up to three attempts are made.  Only "out of memory" ``RuntimeError``s are
    retried; anything else (and an OOM on the last attempt) propagates.

    * attempt 0 – plain ``pipe.to(device)``
    * attempt 1 – same, plus attention slicing
    * attempt 2 – sequential CPU offload *instead of* ``pipe.to(device)``,
      plus attention slicing

    Returns the loaded pipeline.
    """
    import gc
    import torch
    from diffusers import (
        StableDiffusionImg2ImgPipeline,
        StableDiffusionXLImg2ImgPipeline,
        DiffusionPipeline,
    )

    device = _derive_diffusers_device(global_args)
    precision = getattr(global_args, 'image_precision', 'bf16') if global_args else 'bf16'
    dtype_map = {'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32}
    torch_dtype = dtype_map.get(precision, torch.bfloat16)

    name_lower = model_name.lower()
    if 'xl' in name_lower or 'sdxl' in name_lower or 'flux' in name_lower:
        PipeClass = StableDiffusionXLImg2ImgPipeline
    elif 'stable-diffusion' in name_lower or 'sd' in name_lower:
        PipeClass = StableDiffusionImg2ImgPipeline
    else:
        PipeClass = DiffusionPipeline  # generic fallback

    for attempt in range(3):
        pipe = None
        try:
            pipe = PipeClass.from_pretrained(model_name, torch_dtype=torch_dtype)
            if attempt >= 2:
                # Last resort.  This must replace .to(device): moving the whole
                # pipeline to the GPU first is exactly what ran out of memory.
                pipe.enable_sequential_cpu_offload()
            else:
                pipe = pipe.to(device)
            if attempt >= 1:
                pipe.enable_attention_slicing()
            return pipe
        except RuntimeError as e:
            if 'out of memory' not in str(e).lower() or attempt >= 2:
                raise
        # Reached only after a retryable OOM.  Cleaning up here, outside the
        # except block, means neither the exception's traceback nor ``pipe``
        # still pins the half-loaded weights when the cache is emptied.
        pipe = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
@router.post("/v1/images/edits")
async def create_image_edit(request: ImageEditRequest, http_request: Request = None):
    """
    Image-to-image editing endpoint (OpenAI-compatible).
    Accepts a base64-encoded source image and returns an edited image.

    Raises HTTPException:
        400 - ``image`` is missing or cannot be decoded.
        404 - the model manager does not return a model name.
        500 - the pipeline fails to load or to run.
    """
    global global_args
    if not request.image:
        raise HTTPException(status_code=400, detail="image is required")

    model_info = multi_model_manager.request_model(request.model, model_type="image")
    model_name = model_info.get('model_name')
    if not model_name:
        err = model_info.get('error', f"Model '{request.model}' not found or not registered")
        raise HTTPException(status_code=404, detail=err)

    # Undecodable input is the caller's fault: report it as 400 and do so
    # before paying for a model load.
    try:
        source_img = _decode_b64_image(request.image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")

    loop = asyncio.get_running_loop()
    model_key = f"img2img:{model_name}"
    pipe = multi_model_manager.models.get(model_key)
    if pipe is None:
        try:
            pipe = await loop.run_in_executor(
                None, _load_img2img_pipeline, model_name, global_args
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load img2img model: {e}")
        multi_model_manager.models[model_key] = pipe

    try:
        import torch

        if request.size:
            parts = request.size.split("x")
            if len(parts) == 2:
                try:
                    width, height = int(parts[0]), int(parts[1])
                    source_img = source_img.resize((width, height))
                except ValueError:
                    pass  # malformed size: keep the source resolution

        # An explicit request seed is honoured even when it is 0.  The global
        # fallback keeps its original truthiness test because its "unset"
        # sentinel is not visible from here.
        if request.seed is not None:
            seed = request.seed
        else:
            seed = getattr(global_args, 'image_seed', None) or None
        generator = None
        if seed is not None:
            generator = torch.Generator(device=pipe.device).manual_seed(seed)

        quality = request.quality or "standard"
        num_steps = request.steps or (30 if quality == "standard" else 50)
        # guidance_scale=0.0 is a legitimate value, so test for None explicitly.
        if request.guidance_scale is not None:
            cfg_scale = request.guidance_scale
        elif quality == "standard":
            cfg_scale = getattr(global_args, 'image_cfg_scale', 7.5)
        else:
            cfg_scale = 9.0

        result = await loop.run_in_executor(
            None,
            lambda: pipe(
                prompt=request.prompt,
                image=source_img,
                strength=request.strength,
                num_inference_steps=num_steps,
                guidance_scale=cfg_scale,
                num_images_per_prompt=request.n,
                generator=generator,
            )
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Image editing failed: {e}")

    images = [
        save_image_response(img, request.response_format, http_request)
        for img in result.images
    ]
    return {"created": int(time.time()), "data": images}
# =============================================================================
# Inpainting Endpoint (POST /v1/images/inpaint)
# =============================================================================
# Request body for POST /v1/images/inpaint.
# Fields the model does not declare are still accepted (Config.extra = "allow").
class ImageInpaintRequest(BaseModel):
    model: str  # resolved through multi_model_manager.request_model(..., model_type="image")
    prompt: str  # text prompt forwarded to the pipeline call
    image: str # base64 source image
    mask: str # base64 mask (white = inpaint region)
    n: int = 1  # forwarded as num_images_per_prompt
    size: Optional[str] = "1024x1024"  # "<width>x<height>"; image and mask are both resized to it
    response_format: Optional[str] = "url"  # passed through to save_image_response
    steps: Optional[int] = None  # when unset the endpoint uses 30 ("standard" quality) or 50
    guidance_scale: Optional[float] = None  # when unset the endpoint picks a quality-dependent default
    strength: Optional[float] = 0.99  # forwarded unchanged as the pipeline's strength argument
    seed: Optional[int] = None  # when unset the endpoint falls back to global_args.image_seed
    quality: Optional[str] = "standard"
    class Config:
        extra = "allow"
def _load_inpaint_pipeline(model_name: str, global_args):
    """Load a diffusers inpainting pipeline, degrading gracefully on CUDA OOM.

    The pipeline class is chosen from substrings of ``model_name``; the dtype
    comes from ``global_args.image_precision`` (default ``bf16``).

    Up to three attempts are made.  Only "out of memory" ``RuntimeError``s are
    retried; anything else (and an OOM on the last attempt) propagates.

    * attempt 0 – plain ``pipe.to(device)``
    * attempt 1 – same, plus attention slicing
    * attempt 2 – sequential CPU offload *instead of* ``pipe.to(device)``,
      plus attention slicing

    Returns the loaded pipeline.
    """
    import gc
    import torch
    from diffusers import (
        StableDiffusionInpaintPipeline,
        StableDiffusionXLInpaintPipeline,
        DiffusionPipeline,
    )

    device = _derive_diffusers_device(global_args)
    precision = getattr(global_args, 'image_precision', 'bf16') if global_args else 'bf16'
    dtype_map = {'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32}
    torch_dtype = dtype_map.get(precision, torch.bfloat16)

    n = model_name.lower()
    if 'xl' in n or 'sdxl' in n:
        PClass = StableDiffusionXLInpaintPipeline
    elif 'stable-diffusion' in n or 'inpaint' in n:
        PClass = StableDiffusionInpaintPipeline
    else:
        PClass = DiffusionPipeline

    for attempt in range(3):
        pipe = None
        try:
            pipe = PClass.from_pretrained(model_name, torch_dtype=torch_dtype)
            if attempt >= 2:
                # Last resort.  This must replace .to(device): moving the whole
                # pipeline to the GPU first is exactly what ran out of memory.
                pipe.enable_sequential_cpu_offload()
            else:
                pipe = pipe.to(device)
            if attempt >= 1:
                pipe.enable_attention_slicing()
            return pipe
        except RuntimeError as e:
            if 'out of memory' not in str(e).lower() or attempt >= 2:
                raise
        # Reached only after a retryable OOM.  Cleaning up here, outside the
        # except block, means neither the exception's traceback nor ``pipe``
        # still pins the half-loaded weights when the cache is emptied.
        pipe = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
@router.post("/v1/images/inpaint")
async def create_image_inpaint(request: ImageInpaintRequest, http_request: Request = None):
    """Inpaint a masked region of an image (OpenAI-compatible extension).

    Raises HTTPException:
        400 - ``image``/``mask`` is missing or cannot be decoded.
        404 - the model manager does not return a model name.
        500 - the pipeline fails to load or to run.
    """
    global global_args
    if not request.image or not request.mask:
        raise HTTPException(status_code=400, detail="image and mask are required")

    model_info = multi_model_manager.request_model(request.model, model_type="image")
    model_name = model_info.get('model_name')
    if not model_name:
        raise HTTPException(status_code=404, detail=model_info.get('error', 'Model not found'))

    # Undecodable input is the caller's fault: report it as 400 and do so
    # before paying for a model load.
    try:
        source_img = _decode_b64_image(request.image)
        mask_img = _decode_b64_image(request.mask).convert("L")  # greyscale mask
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image or mask data: {e}")

    loop = asyncio.get_running_loop()
    model_key = f"inpaint:{model_name}"
    pipe = multi_model_manager.models.get(model_key)
    if pipe is None:
        try:
            pipe = await loop.run_in_executor(
                None, _load_inpaint_pipeline, model_name, global_args)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load inpaint model: {e}")
        multi_model_manager.models[model_key] = pipe

    try:
        import torch

        if request.size:
            parts = request.size.split("x")
            if len(parts) == 2:
                try:
                    w, h = int(parts[0]), int(parts[1])
                    source_img = source_img.resize((w, h))
                    mask_img = mask_img.resize((w, h))
                except ValueError:
                    pass  # malformed size: keep the source resolution

        # An explicit request seed is honoured even when it is 0.  The global
        # fallback keeps its original truthiness test because its "unset"
        # sentinel is not visible from here.
        if request.seed is not None:
            seed = request.seed
        else:
            seed = getattr(global_args, 'image_seed', None) or None
        generator = None
        if seed is not None:
            generator = torch.Generator(device=pipe.device).manual_seed(seed)

        quality = request.quality or "standard"
        num_steps = request.steps or (30 if quality == "standard" else 50)
        # guidance_scale=0.0 is a legitimate value, so test for None explicitly.
        if request.guidance_scale is not None:
            cfg_scale = request.guidance_scale
        elif quality == "standard":
            cfg_scale = getattr(global_args, 'image_cfg_scale', 7.5)
        else:
            cfg_scale = 9.0

        result = await loop.run_in_executor(
            None,
            lambda: pipe(
                prompt=request.prompt,
                image=source_img,
                mask_image=mask_img,
                strength=request.strength,
                num_inference_steps=num_steps,
                guidance_scale=cfg_scale,
                num_images_per_prompt=request.n,
                generator=generator,
            )
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inpainting failed: {e}")

    images = [save_image_response(img, request.response_format, http_request) for img in result.images]
    return {"created": int(time.time()), "data": images}
# =============================================================================
# Image Upscale Endpoint (POST /v1/images/upscale)
# =============================================================================
# Request body for POST /v1/images/upscale.
# Fields the model does not declare are still accepted (Config.extra = "allow").
class ImageUpscaleRequest(BaseModel):
    model: str  # used as the upscaler name when the model manager returns no model_name
    image: str  # base64 image; for comma-containing input only the part after the first comma is decoded
    scale: Optional[int] = 4  # the endpoint substitutes 4 when this is falsy (None or 0)
    response_format: Optional[str] = "url"  # passed through to save_image_response
    class Config:
        extra = "allow"
def _load_upscaler(model_name: str, global_args):
    """Pick an upscaling backend for ``model_name``.

    Backends are tried in order and the first that loads wins:

    1. ``('diffusers', pipeline)`` – ``StableDiffusionUpscalePipeline``, only
       attempted when the name contains ``upscal`` or ``esrgan``.
    2. ``('realesrgan', upsampler)`` – Real-ESRGAN with an RRDBNet backbone,
       using ``model_name`` as the weights path.
    3. ``('pil', None)`` – no model; the caller resizes with PIL.

    Loading stays best-effort (this function does not raise for backend
    failures), but each failure is now logged so a silent drop to the PIL
    fallback can be diagnosed.
    """
    import logging

    log = logging.getLogger(__name__)
    device = _derive_diffusers_device(global_args)
    n = model_name.lower()

    try:
        from diffusers import StableDiffusionUpscalePipeline
        if 'upscal' in n or 'esrgan' in n:
            pipe = StableDiffusionUpscalePipeline.from_pretrained(model_name)
            return ('diffusers', pipe.to(device))
    except Exception as e:
        log.warning("diffusers upscaler unavailable for %r: %s", model_name, e)

    # Try basicsr / Real-ESRGAN
    try:
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer
        model_obj = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=23, num_grow_ch=32, scale=4)
        upsampler = RealESRGANer(scale=4, model_path=model_name,
                                 model=model_obj, device=device)
        return ('realesrgan', upsampler)
    except Exception as e:
        log.warning("Real-ESRGAN upscaler unavailable for %r: %s", model_name, e)

    # Fallback: PIL LANCZOS
    log.warning("no upscaling model could be loaded for %r; using PIL LANCZOS", model_name)
    return ('pil', None)
def _run_upscale(upscaler, image_bytes: bytes, scale: int):
    """Upscale encoded image bytes with the ``(backend, model)`` pair given.

    ``realesrgan`` passes ``scale`` as ``outscale``; ``diffusers`` runs the
    pipeline with an empty prompt and 20 steps (``scale`` is not forwarded);
    any other backend resizes with PIL LANCZOS by ``scale``.
    Returns a PIL image.
    """
    import io as _io
    import numpy as np
    from PIL import Image as PILImage

    source = PILImage.open(_io.BytesIO(image_bytes)).convert("RGB")
    backend, model = upscaler

    if backend == 'diffusers':
        return model(prompt="", image=source, num_inference_steps=20).images[0]

    if backend == 'realesrgan':
        enhanced, _ = model.enhance(np.array(source), outscale=scale)
        return PILImage.fromarray(enhanced)

    # PIL fallback
    width, height = source.size
    target = (width * scale, height * scale)
    return source.resize(target, PILImage.LANCZOS)
@router.post("/v1/images/upscale")
async def create_image_upscale(request: ImageUpscaleRequest, http_request: Request = None):
    """Upscale an image using Real-ESRGAN or PIL LANCZOS fallback.

    Raises HTTPException:
        400 - ``image`` is not valid base64, or ``scale`` is below 1.
        500 - the upscaler fails to load or to run.
    """
    global global_args

    scale = request.scale or 4
    if scale < 1:
        raise HTTPException(status_code=400, detail="scale must be >= 1")

    # Everything after the first comma is the payload (data-URI form); a
    # string without a comma is decoded whole.  Bad input is a client error.
    try:
        raw = base64.b64decode(request.image.split(',', 1)[-1])
    except ValueError as e:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")

    model_info = multi_model_manager.request_model(request.model, model_type="image")
    model_name = model_info.get('model_name') or request.model
    model_key = f"upscale:{model_name}"

    loop = asyncio.get_running_loop()
    upscaler = multi_model_manager.models.get(model_key)
    if upscaler is None:
        try:
            upscaler = await loop.run_in_executor(
                None, _load_upscaler, model_name, global_args)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load upscaler: {e}")
        multi_model_manager.models[model_key] = upscaler

    try:
        out_img = await loop.run_in_executor(None, _run_upscale, upscaler, raw, scale)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upscaling failed: {e}")

    result = save_image_response(out_img, request.response_format, http_request)
    return {"created": int(time.time()), "data": [result]}
# =============================================================================
# Depth Estimation Endpoint (POST /v1/images/depth)
# =============================================================================
# Request body for POST /v1/images/depth.
# Fields the model does not declare are still accepted (Config.extra = "allow").
class ImageDepthRequest(BaseModel):
    model: str  # passed unchanged to _load_depth_model and used in the cache key "depth:<model>"
    image: str  # base64 image; for comma-containing input only the part after the first comma is decoded
    response_format: Optional[str] = "url"  # passed through to save_image_response
    class Config:
        extra = "allow"
def _load_depth_model(model_name: str, global_args):
    """Load a depth estimator and tag it with the backend that produced it.

    First choice is a transformers ``depth-estimation`` pipeline for
    ``model_name`` -> ``('transformers', pipeline)``.  If that fails for any
    reason, ``MiDaS_small`` is fetched through ``torch.hub`` (``model_name``
    is not used) -> ``('midas', (model, device))``.

    Raises RuntimeError when the MiDaS fallback fails as well.
    """
    device = _derive_diffusers_device(global_args)

    def _via_transformers():
        from transformers import pipeline as hf_pipeline
        estimator = hf_pipeline("depth-estimation", model=model_name, device=device)
        return ('transformers', estimator)

    def _via_midas():
        import torch
        import timm  # imported only so a missing timm fails here
        midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
        midas.eval().to(device)
        return ('midas', (midas, device))

    try:
        return _via_transformers()
    except Exception:
        pass  # fall through to MiDaS

    try:
        return _via_midas()
    except Exception as e:
        raise RuntimeError(f"Cannot load depth model: {e}")
def _run_depth(depth_model, image_bytes: bytes):
    """Compute a depth map for encoded image bytes; returns a PIL image.

    ``depth_model`` is a ``(backend, model)`` pair.  For ``'transformers'``
    the pipeline's ``'depth'`` output is used; any other backend is treated
    as the ``(model, device)`` MiDaS pair and run with the MiDaS small
    transform fetched via ``torch.hub``.  The result is min-max scaled to
    0-255 (left unscaled when constant) and cast to ``uint8``.
    """
    import io as _io
    import numpy as np
    from PIL import Image as PILImage

    rgb = PILImage.open(_io.BytesIO(image_bytes)).convert("RGB")
    backend, model = depth_model

    if backend == 'transformers':
        prediction = model(rgb)
        depth = np.array(prediction['depth'])
    else:
        import torch
        net, device = model
        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        batch = midas_transforms.small_transform(np.array(rgb)).to(device)
        with torch.no_grad():
            depth = net(batch).squeeze().cpu().numpy()

    # Normalise to 0-255
    lowest, highest = depth.min(), depth.max()
    if highest > lowest:
        depth = (depth - lowest) / (highest - lowest) * 255
    return PILImage.fromarray(depth.astype(np.uint8))
@router.post("/v1/images/depth")
async def create_image_depth(request: ImageDepthRequest, http_request: Request = None):
    """Estimate depth map from an image.

    Raises HTTPException:
        400 - ``image`` is not valid base64.
        500 - the depth model fails to load or to run.
    """
    global global_args

    # Everything after the first comma is the payload (data-URI form); a
    # string without a comma is decoded whole.  Bad input is a client error,
    # reported before any model is loaded.
    try:
        raw = base64.b64decode(request.image.split(',', 1)[-1])
    except ValueError as e:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")

    model_name = request.model
    model_key = f"depth:{model_name}"

    loop = asyncio.get_running_loop()
    depth_model = multi_model_manager.models.get(model_key)
    if depth_model is None:
        try:
            depth_model = await loop.run_in_executor(
                None, _load_depth_model, model_name, global_args)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load depth model: {e}")
        multi_model_manager.models[model_key] = depth_model

    try:
        depth_img = await loop.run_in_executor(None, _run_depth, depth_model, raw)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Depth estimation failed: {e}")

    result = save_image_response(depth_img, request.response_format, http_request)
    return {"created": int(time.time()), "data": [result]}
# =============================================================================
# Segmentation Endpoint (POST /v1/images/segment)
# =============================================================================
# Request body for POST /v1/images/segment.
# Fields the model does not declare are still accepted (Config.extra = "allow").
class ImageSegmentRequest(BaseModel):
    model: str  # passed unchanged to _load_segmentation_model; cache key is "segment:<model>"
    image: str  # base64 image; for comma-containing input only the part after the first comma is decoded
    points: Optional[list] = None # [[x,y], ...] positive prompt points for SAM
    boxes: Optional[list] = None # [[x1,y1,x2,y2], ...] box prompts
    response_format: Optional[str] = "url"  # passed through to save_image_response
    class Config:
        extra = "allow"
def _load_segmentation_model(model_name: str, global_args):
    """Load a segmentation model and tag it with its backend.

    ``model_name`` is first loaded as a SAM checkpoint
    -> ``('sam', (model, processor, device))``.  If that fails for any
    reason, a transformers ``image-segmentation`` pipeline is tried
    -> ``('transformers', pipeline)``.

    Raises RuntimeError when the pipeline fallback fails as well.
    """
    device = _derive_diffusers_device(global_args)

    def _as_sam():
        from transformers import SamModel, SamProcessor
        import torch  # imported only so a missing torch fails here
        sam = SamModel.from_pretrained(model_name).to(device)
        sam_processor = SamProcessor.from_pretrained(model_name)
        return ('sam', (sam, sam_processor, device))

    def _as_generic_pipeline():
        from transformers import pipeline as hf_pipeline
        segmenter = hf_pipeline("image-segmentation", model=model_name, device=device)
        return ('transformers', segmenter)

    try:
        return _as_sam()
    except Exception:
        pass  # not a SAM checkpoint (or SAM unavailable): try the generic pipeline

    try:
        return _as_generic_pipeline()
    except Exception as e:
        raise RuntimeError(f"Cannot load segmentation model: {e}")
def _run_segmentation(seg_model, image_bytes: bytes, points, boxes):
    """Segment encoded image bytes and return the image with the background dimmed.

    ``seg_model`` is a ``(backend, payload)`` pair.  For ``'sam'`` the
    optional ``points`` / ``boxes`` prompts are each wrapped in a one-element
    list before being handed to the processor, and the mask at index
    ``[0, 0]`` of the post-processed output is used.  For any other backend
    the payload is called on the image and the first result's ``'mask'`` is
    used (the image is returned untouched when there are no results).
    Pixels where the mask is 0 are halved in brightness.
    """
    import io as _io
    import numpy as np
    from PIL import Image as PILImage

    img = PILImage.open(_io.BytesIO(image_bytes)).convert("RGB")
    backend, model_data = seg_model

    if backend != 'sam':
        # transformers generic
        segments = model_data(img)
        canvas = np.array(img)
        if segments:
            outside = np.array(segments[0]['mask']) == 0
            canvas[outside] = canvas[outside] // 2
        return PILImage.fromarray(canvas)

    import torch
    sam_model, processor, device = model_data
    encoded = processor(
        img,
        input_points=[points] if points else None,
        input_boxes=[boxes] if boxes else None,
        return_tensors='pt',
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = sam_model(**encoded)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        encoded['original_sizes'].cpu(),
        encoded['reshaped_input_sizes'].cpu()
    )[0]

    # First candidate mask, expanded to 0/255, then overlaid on the image.
    mask_np = masks[0, 0].numpy().astype(np.uint8) * 255
    overlay = np.array(img.copy())
    background = mask_np == 0
    overlay[background] = overlay[background] // 2
    return PILImage.fromarray(overlay)
@router.post("/v1/images/segment")
async def create_image_segment(request: ImageSegmentRequest, http_request: Request = None):
    """Segment objects in an image using SAM or similar models.

    Raises HTTPException:
        400 - ``image`` is not valid base64.
        500 - the segmentation model fails to load or to run.
    """
    global global_args

    # Everything after the first comma is the payload (data-URI form); a
    # string without a comma is decoded whole.  Bad input is a client error,
    # reported before any model is loaded.
    try:
        raw = base64.b64decode(request.image.split(',', 1)[-1])
    except ValueError as e:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")

    model_name = request.model
    model_key = f"segment:{model_name}"

    loop = asyncio.get_running_loop()
    seg_model = multi_model_manager.models.get(model_key)
    if seg_model is None:
        try:
            seg_model = await loop.run_in_executor(
                None, _load_segmentation_model, model_name, global_args)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load segmentation model: {e}")
        multi_model_manager.models[model_key] = seg_model

    try:
        seg_img = await loop.run_in_executor(
            None, _run_segmentation, seg_model, raw,
            request.points, request.boxes)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Segmentation failed: {e}")

    result = save_image_response(seg_img, request.response_format, http_request)
    return {"created": int(time.time()), "data": [result]}
"""
Video generation and manipulation endpoints for the codai API.
Endpoints:
POST /v1/video/generations – t2v | i2v | v2v | ti2v | interp
POST /v1/video/upscale – video super-resolution
POST /v1/video/subtitle – subtitle generation / burn-in
POST /v1/video/interpolate – frame interpolation (increase FPS)
POST /v1/video/dub – translation + TTS dubbing
"""
import asyncio
import base64
import io
import os
import subprocess
import tempfile
import time
import uuid
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Request
from codai.models.manager import multi_model_manager
from codai.pydantic.videorequest import (
VideoGenerationRequest, VideoGenerationResponse,
VideoUpscaleRequest, VideoSubtitleRequest,
VideoInterpolateRequest, VideoDubRequest,
)
from codai.api.images import _disable_safety_checker
router = APIRouter()
global_args = None
global_file_path = None
def set_global_args(args):
    """Store *args* in the module-level ``global_args`` that the helpers in this module read."""
    global global_args
    global_args = args
def set_global_file_path(path):
    """Store *path* in the module-level ``global_file_path``; ``_save_file`` writes there when it is set."""
    global global_file_path
    global_file_path = path
# =============================================================================
# Shared helpers
# =============================================================================
def _derive_device() -> str:
    """Return the ``cuda:<index>`` device string for this module.

    The index is the first non-``None`` value among
    ``global_args.image_vulkan_device`` and ``global_args.vulkan_device``;
    ``"cuda:0"`` is returned when neither is set or ``global_args`` is falsy.
    """
    if not global_args:
        return "cuda:0"
    configured = (
        getattr(global_args, name, None)
        for name in ('image_vulkan_device', 'vulkan_device')
    )
    index = next((value for value in configured if value is not None), None)
    return "cuda:0" if index is None else f"cuda:{index}"
def _decode_b64_or_url(data: str) -> bytes:
    """Resolve *data* to raw bytes.

    Accepts a ``data:`` URI (base64 payload after the first comma), an
    ``http(s)://`` URL (fetched with a 60 second timeout) or a bare base64
    string.  Empty input yields ``b''``.
    """
    if not data:
        return b''

    if data.startswith(("http://", "https://")):
        # NOTE(review): the URL comes from the request body and is fetched as-is.
        import urllib.request
        with urllib.request.urlopen(data, timeout=60) as response:
            return response.read()

    encoded = data
    if data.startswith("data:"):
        _, encoded = data.split(",", 1)
    return base64.b64decode(encoded)
def _pil_from_b64(data: str):
    """Decode *data* with ``_decode_b64_or_url`` and open it as an RGB PIL image."""
    from PIL import Image as PILImage

    image_bytes = _decode_b64_or_url(data)
    opened = PILImage.open(io.BytesIO(image_bytes))
    return opened.convert("RGB")
def _build_url(filename: str, http_request) -> str:
    """Build the public ``/v1/files/<filename>`` URL for a stored file.

    When ``global_args.url`` is anything other than ``'auto'`` it is used as
    the base URL (trailing slashes removed).  Otherwise the base is assembled
    from the request's ``Host`` header (``127.0.0.1`` without a request), the
    scheme implied by ``global_args.https`` / ``global_args.pubkey`` and
    ``global_args.port`` (default 8000).

    Any port in the ``Host`` header is replaced by the configured one.  This
    also works for bracketed IPv6 literals such as ``[::1]:8000``, which the
    previous two-way ``split(':')`` left untouched and so produced a URL with
    two ports.
    """
    url_setting = getattr(global_args, 'url', 'auto') if global_args else 'auto'
    if url_setting != 'auto':
        return f"{url_setting.rstrip('/')}/v1/files/{filename}"

    host = http_request.headers.get('host', '127.0.0.1') if http_request else '127.0.0.1'
    name, sep, maybe_port = host.rpartition(':')
    # Strip "<name>:<digits>" only when <name> is a plain hostname/IPv4 or a
    # bracketed IPv6 literal; an unbracketed string full of colons is left
    # alone because its last group is not a port.
    if sep and maybe_port.isdigit() and (':' not in name or name.endswith(']')):
        host = name

    use_https = getattr(global_args, 'https', False) or getattr(global_args, 'pubkey', None)
    proto = 'https' if use_https else 'http'
    port = getattr(global_args, 'port', 8000) if global_args else 8000
    return f"{proto}://{host}:{port}/v1/files/{filename}"
def _save_file(data: bytes, ext: str, http_request) -> dict:
    """Persist *data* and describe where it went.

    With ``global_file_path`` set, the bytes are written there under a random
    ``<uuid hex>.<ext>`` name and ``{"url": ...}`` is returned.  Without it
    nothing is written and the payload is returned inline as
    ``{"b64_<ext>": <base64 text>}``.
    """
    filename = f"{uuid.uuid4().hex}.{ext}"

    if not global_file_path:
        inline = base64.b64encode(data).decode()
        return {f"b64_{ext}": inline}

    os.makedirs(global_file_path, exist_ok=True)
    destination = os.path.join(global_file_path, filename)
    with open(destination, 'wb') as sink:
        sink.write(data)
    return {"url": _build_url(filename, http_request)}
def _frames_to_mp4(frames, fps: int) -> bytes:
    """Encode a sequence of frames as H.264 MP4 and return the file's bytes.

    Each frame is converted with ``np.array`` and the list is written by
    ``imageio.mimsave`` (``libx264``, quality 8) to a temporary file, which is
    read back and then deleted.  The temporary file is removed even when
    encoding or reading fails; previously it was leaked in that case.
    """
    import imageio
    import numpy as np

    frames = [np.array(f) for f in frames]
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
        tmp_path = tmp.name
    try:
        imageio.mimsave(tmp_path, frames, fps=fps, codec='libx264', quality=8)
        with open(tmp_path, 'rb') as f:
            return f.read()
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass  # already gone: nothing left to clean up
def _video_bytes_to_path(video_b64: str) -> str:
    """Materialise a base64/URL video as a ``.mp4`` temp file and return its path.

    The file is created with ``delete=False``, so removing it is left to the
    caller.
    """
    payload = _decode_b64_or_url(video_b64)
    handle = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    target = handle.name
    handle.write(payload)
    handle.close()
    return target
# =============================================================================
# Pipeline loading
# =============================================================================
def _detect_pipeline_class(model_name: str, mode: str):
    """Return the diffusers pipeline class to use for ``model_name``.

    The model family is recognised from substrings of the lower-cased name;
    for CogVideoX and LTX the image-to-video variant is chosen when ``mode``
    is ``'i2v'`` or ``'ti2v'``.  Unrecognised names, and families whose class
    the installed diffusers does not provide, get ``DiffusionPipeline``.

    Each class is looked up individually.  The previous single
    ``from diffusers import (...)`` failed as a whole when any one of the
    listed classes was missing, which downgraded every model to the generic
    pipeline.

    Returns ``None`` when diffusers itself cannot be imported.
    """
    try:
        import diffusers
    except ImportError:
        return None

    n = model_name.lower()
    wants_image_input = mode in ('i2v', 'ti2v')

    if 'cogvideo' in n:  # also covers "cogvideox"
        wanted = 'CogVideoXImageToVideoPipeline' if wants_image_input else 'CogVideoXPipeline'
    elif 'ltx' in n:
        wanted = 'LTXImageToVideoPipeline' if wants_image_input else 'LTXPipeline'
    elif 'svd' in n or 'stable-video-diffusion' in n:
        wanted = 'StableVideoDiffusionPipeline'
    elif 'i2vgen' in n:
        wanted = 'I2VGenXLPipeline'
    elif 'animatediff' in n or 'animateddiff' in n:
        wanted = 'AnimateDiffPipeline'
    else:
        wanted = None

    for class_name in (wanted, 'DiffusionPipeline'):
        if class_name is None:
            continue
        try:
            return getattr(diffusers, class_name)
        except (AttributeError, ImportError, RuntimeError):
            # Class unknown to this diffusers version, or its lazy import
            # failed: fall through to the next candidate.
            continue
    return None
def _load_video_pipeline(model_name: str, device: str, mode: str):
    """Load the video pipeline for ``model_name``, backing off on CUDA OOM.

    The class comes from ``_detect_pipeline_class``; the dtype from
    ``global_args.image_precision`` (default ``bf16``).  Placement follows
    ``global_args.offload_strategy`` and the attempt number:

    * ``'sequential'`` or third attempt  -> sequential CPU offload
    * ``'model'`` or second attempt      -> model CPU offload
    * otherwise                          -> ``pipe.to(device)``

    Only "out of memory" ``RuntimeError``s are retried (twice at most).

    Raises RuntimeError when diffusers is unavailable, and re-raises the
    loader's own ``RuntimeError`` when it is not a retryable OOM.
    """
    import gc
    import torch

    PClass = _detect_pipeline_class(model_name, mode)
    if PClass is None:
        raise RuntimeError("diffusers not installed: pip install diffusers")

    precision = getattr(global_args, 'image_precision', 'bf16') if global_args else 'bf16'
    dtype_map = {'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32}
    torch_dtype = dtype_map.get(precision, torch.bfloat16)
    offload = getattr(global_args, 'offload_strategy', None) if global_args else None

    for attempt in range(3):
        pipe = None
        try:
            pipe = PClass.from_pretrained(model_name, torch_dtype=torch_dtype)
            if offload == 'sequential' or attempt >= 2:
                pipe.enable_sequential_cpu_offload()
            elif offload == 'model' or attempt >= 1:
                pipe.enable_model_cpu_offload()
            else:
                pipe = pipe.to(device)
            return pipe
        except RuntimeError as e:
            if 'out of memory' not in str(e).lower() or attempt >= 2:
                raise
        # Reached only after a retryable OOM.  The failed pipeline is dropped
        # *before* collecting: while ``pipe`` (or the exception's traceback)
        # still referenced it, gc.collect()/empty_cache() could not give the
        # memory back and the retry started from the same full GPU.
        pipe = None
        gc.collect()
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass  # cache cleanup is best-effort
# =============================================================================
# Frame interpolation model loading
# =============================================================================
def _load_rife(model_name: str, device: str):
    """Choose the frame-interpolation backend.

    Returns ``('rife_ncnn', None)`` when a ``rife-ncnn-vulkan`` executable is
    on ``PATH``, otherwise ``('rife_hf', model_name)``.  ``device`` is
    accepted for signature compatibility and is not used.

    The former ``from diffusers import IFPipeline`` probe was removed: its
    outcome never influenced the return value.
    """
    import shutil

    if shutil.which('rife-ncnn-vulkan'):
        return ('rife_ncnn', None)
    return ('rife_hf', model_name)
# =============================================================================
# Generation logic
# =============================================================================
def _build_call_kwargs(request: "VideoGenerationRequest") -> dict:
    """Translate the set fields of *request* into pipeline call kwargs.

    ``prompt``, ``negative_prompt``, ``num_inference_steps`` and
    ``num_frames`` are copied when truthy; ``width``/``height`` only when both
    are truthy.  ``guidance_scale`` is copied whenever it is not ``None`` so
    that an explicit ``0.0`` survives (the old truthiness test dropped it and
    the caller's default of 7.5 then replaced it).  A non-``None`` ``seed``
    becomes a seeded ``torch.Generator`` under ``'generator'``.
    """
    kw = {}
    if request.prompt:
        kw['prompt'] = request.prompt
    if request.negative_prompt:
        kw['negative_prompt'] = request.negative_prompt
    if request.num_inference_steps:
        kw['num_inference_steps'] = request.num_inference_steps
    if request.guidance_scale is not None:
        kw['guidance_scale'] = request.guidance_scale
    if request.num_frames:
        kw['num_frames'] = request.num_frames
    if request.width and request.height:
        kw['width'] = request.width
        kw['height'] = request.height
    if request.seed is not None:
        import torch
        kw['generator'] = torch.Generator().manual_seed(request.seed)
    return kw
def _apply_camera_motion(kw: dict, camera_motion: str):
    """Add ``camera_motion`` to the pipeline kwargs in place when one is given.

    NOTE(review): whether the target pipeline accepts a ``camera_motion``
    argument is model-dependent and not checked here.
    """
    if not camera_motion:
        return
    kw['camera_motion'] = camera_motion
def _apply_character_refs(kw: dict, character_references: List[str], strength: float):
    """Attach character reference images to the pipeline kwargs in place.

    Every reference is decoded with ``_pil_from_b64``.  A single image is
    stored bare, several as a list, under ``'ip_adapter_image'``;
    ``strength`` goes under ``'ip_adapter_scale'``.  Nothing happens for an
    empty or missing reference list.
    """
    if character_references:
        decoded = [_pil_from_b64(ref) for ref in character_references]
        kw['ip_adapter_image'] = decoded if len(decoded) != 1 else decoded[0]
        kw['ip_adapter_scale'] = strength
def _run_pipeline(pipe, kw: dict):
    """Call ``pipe(**kw)`` and return the frames of the first video as a list.

    Frames are read from ``result.frames`` when present, else from
    ``result[0]``.  A list of lists is treated as batched and its first
    element returned.

    Robustness fixes over the previous version:
    * ``frames or result[0]`` evaluated the truth value of the frames object,
      which raises for multi-element arrays/tensors; ``None`` is now tested
      explicitly.
    * an empty frame list no longer raises ``IndexError``.
    * a 5-D array/tensor is unbatched before being listed.  This assumes the
      (batch, frames, ...) layout - TODO confirm for every pipeline in use.
    """
    result = pipe(**kw)
    frames_raw = getattr(result, 'frames', None)
    if frames_raw is None:
        frames_raw = result[0]

    if isinstance(frames_raw, list):
        if frames_raw and isinstance(frames_raw[0], list):
            return frames_raw[0]
        return list(frames_raw)

    if getattr(frames_raw, 'ndim', None) == 5:
        frames_raw = frames_raw[0]
    return list(frames_raw)
def _generate_video(pipe, request: VideoGenerationRequest):
    """Run one video generation and return ``(frames, fps)``.

    The mode is ``request.mode`` when given; otherwise ``'i2v'`` if an image
    is supplied, ``'v2v'`` if a video is supplied, else ``'t2v'``.  Pipeline
    kwargs come from ``_build_call_kwargs`` with defaults of 25 steps,
    guidance 7.5 and 16 frames, plus the optional camera motion and character
    references.  ``fps`` defaults to 8.

    Raises ValueError in ``'interp'`` mode when either end frame is missing.
    """
    if request.mode:
        mode = request.mode
    elif request.image or request.init_image:
        mode = 'i2v'
    elif request.video:
        mode = 'v2v'
    else:
        mode = 't2v'
    fps = request.fps or 8

    kw = _build_call_kwargs(request)
    for key, fallback in (('num_inference_steps', 25),
                          ('guidance_scale', 7.5),
                          ('num_frames', 16)):
        kw.setdefault(key, fallback)

    _apply_camera_motion(kw, request.camera_motion)
    if request.character_references:
        _apply_character_refs(kw, request.character_references,
                              request.character_strength or 0.8)

    init_src = request.init_image or request.image
    if mode in ('i2v', 'ti2v') and init_src:
        kw['image'] = _pil_from_b64(init_src)
        if mode == 'i2v':
            kw.pop('prompt', None)  # SVD doesn't take text; 'ti2v' keeps the prompt
    elif mode == 'interp':
        if not (init_src and request.end_image):
            raise ValueError("interp mode requires both init_image and end_image")
        kw['image'] = _pil_from_b64(init_src)
        kw['image_end'] = _pil_from_b64(request.end_image)
        kw.pop('prompt', None)
    elif mode == 'v2v' and request.video:
        kw['video'] = _decode_b64_or_url(request.video)
        if request.strength is not None:
            kw['strength'] = request.strength

    return _run_pipeline(pipe, kw), fps
# =============================================================================
# Post-processing helpers
# =============================================================================
def _postprocess_video(mp4_bytes: bytes, request: VideoGenerationRequest,
                       http_request, temp_paths: list) -> bytes:
    """Run the optional post-processing stages over an MP4 blob and return the result.

    The blob is written to a temp file, then passed in order through
    upscaling, frame interpolation, audio muxing and subtitling - each only
    when the matching request flags are set.  Every temp path created is
    appended to ``temp_paths``.  ``http_request`` is not used.
    """
    current = _tmp_write(mp4_bytes, '.mp4')
    temp_paths.append(current)

    if request.upscale_output:
        factor = request.upscale_factor or 2
        current = _ffmpeg_upscale(current, factor, temp_paths)

    if request.interpolate_output and request.fps_multiplier:
        current = _rife_interpolate(current, request.fps_multiplier, temp_paths)

    if request.add_audio:
        current = _add_audio_to_video(current, request, temp_paths)

    wants_subtitles = request.generate_subtitles or request.burn_subtitles
    if wants_subtitles:
        current = _add_subtitles(current, request, temp_paths)

    with open(current, 'rb') as final_file:
        return final_file.read()
def _tmp_write(data: bytes, ext: str) -> str:
    """Write *data* to a new temp file whose name ends in *ext*; return its path.

    The file is created with ``delete=False``, so removing it is left to the
    caller.
    """
    handle = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
    target = handle.name
    handle.write(data)
    handle.close()
    return target
def _ffmpeg_upscale(path: str, factor: int, temps: list) -> str:
    """Upscale a video by *factor* with ffmpeg's lanczos scaler.

    Returns the path of the upscaled file, or the unchanged input *path*
    when ffmpeg fails or is not installed.  The output path is appended to
    ``temps`` either way so the caller can clean it up.

    Changes: the output name comes from ``mkstemp`` (``mktemp`` only reserves
    a name and is race-prone), and a missing ffmpeg binary now triggers the
    documented fallback instead of raising ``FileNotFoundError``.
    """
    fd, out = tempfile.mkstemp(suffix='_up.mp4')
    os.close(fd)  # ffmpeg reopens the file itself ('-y' overwrites it)
    temps.append(out)

    scale = f"scale=iw*{factor}:ih*{factor}:flags=lanczos"
    cmd = ['ffmpeg', '-y', '-i', path, '-vf', scale, '-c:a', 'copy', out]
    try:
        r = subprocess.run(cmd, capture_output=True)
    except OSError:
        return path  # ffmpeg not installed / not executable
    return out if r.returncode == 0 else path  # fallback to original if ffmpeg fails
def _rife_interpolate(path: str, multiplier: int, temps: list,
                      source_fps: int = 8) -> str:
    """Raise a video's frame rate to ``multiplier * source_fps``.

    Uses the ``rife-ncnn-vulkan`` binary when it is on ``PATH`` (extract
    frames -> interpolate -> re-encode); otherwise, or when that chain fails,
    falls back to ffmpeg's ``minterpolate`` filter.  Returns the new file's
    path, or the unchanged input *path* when nothing succeeded.  All temp
    files and directories are appended to ``temps``.

    ``source_fps`` defaults to 8, the value that used to be hard-coded.

    Changes: ``mkstemp`` instead of the race-prone ``mktemp``; success is
    judged by exit codes (the output file now always exists); a missing
    binary no longer raises; the unused ``fps_expr`` local is gone.
    """
    import shutil

    def _succeeds(cmd) -> bool:
        # True only when the command could be started and exited with 0.
        try:
            return subprocess.run(cmd, capture_output=True).returncode == 0
        except OSError:
            return False

    fd, out = tempfile.mkstemp(suffix='_rife.mp4')
    os.close(fd)
    temps.append(out)
    target_fps = multiplier * source_fps

    # Try rife-ncnn-vulkan binary if available
    if shutil.which('rife-ncnn-vulkan'):
        frames_dir = tempfile.mkdtemp()
        out_dir = tempfile.mkdtemp()
        temps.extend([frames_dir, out_dir])
        chain = (
            ['ffmpeg', '-y', '-i', path, f'{frames_dir}/%08d.png'],
            ['rife-ncnn-vulkan', '-i', frames_dir, '-o', out_dir, '-m', 'rife-v4'],
            ['ffmpeg', '-y', '-r', str(target_fps), '-i',
             f'{out_dir}/%08d.png', '-c:v', 'libx264', out],
        )
        # all() stops at the first failing step; later steps need its output.
        if all(_succeeds(step) for step in chain):
            return out

    # Simple ffmpeg minterpolate fallback
    cmd = ['ffmpeg', '-y', '-i', path, '-filter:v',
           f'minterpolate=fps={target_fps}', '-c:a', 'copy', out]
    return out if _succeeds(cmd) else path
def _add_audio_to_video(path: str, request: "VideoGenerationRequest",
                        temps: list) -> str:
    """Mux an audio track into the video at *path*.

    The audio comes from ``request.audio_file`` (base64/URL) or, failing
    that, from TTS over ``request.tts_text``.  Video is stream-copied, audio
    encoded as AAC, and the result cut to the shorter stream.

    Returns the new file's path, or the unchanged input *path* when there is
    no audio source, the audio file is missing, or ffmpeg fails or is not
    installed.  Temp files are appended to ``temps``.

    Changes: ``mkstemp`` instead of the race-prone ``mktemp``; the output
    file is only created once there is audio to add; a missing ffmpeg binary
    no longer raises.
    """
    if request.audio_file:
        audio_path = _tmp_write(_decode_b64_or_url(request.audio_file), '.wav')
        temps.append(audio_path)
    elif request.tts_text:
        audio_path = _generate_tts(request.tts_text, request.tts_voice,
                                   request.tts_speed or 1.0, temps)
    else:
        return path  # nothing to add

    if not audio_path or not os.path.exists(audio_path):
        return path

    fd, out = tempfile.mkstemp(suffix='_audio.mp4')
    os.close(fd)
    temps.append(out)
    cmd = ['ffmpeg', '-y', '-i', path, '-i', audio_path,
           '-c:v', 'copy', '-c:a', 'aac', '-shortest', out]
    try:
        r = subprocess.run(cmd, capture_output=True)
    except OSError:
        return path  # ffmpeg not installed / not executable
    return out if r.returncode == 0 else path
def _generate_tts(text: str, voice: Optional[str], speed: float,
                  temps: list) -> Optional[str]:
    """Synthesise *text* to an audio temp file and return its path.

    edge-tts is tried first (writes an ``.mp3``), then kokoro (writes a
    ``.wav``).  Returns ``None`` when neither package can be imported.  The
    created file is appended to ``temps``.

    Fixes in the edge-tts branch:
    * the rate is formatted with an explicit sign; the previous ``"+{n}%"``
      produced ``"+-20%"`` for speeds below 1.
    * the percentage is rounded rather than truncated (``1.2`` used to
      become 19 because of float representation).
    * the coroutine is driven by ``asyncio.run`` rather than
      ``get_event_loop().run_until_complete``, so the call does not rely on
      the calling thread having an event loop.  It must still not be called
      on a thread whose event loop is running.
    * ``mkstemp`` replaces the race-prone ``mktemp`` (both branches).
    """
    try:
        import edge_tts, asyncio as _aio
        voice_id = voice or 'en-US-JennyNeural'
        fd, out = tempfile.mkstemp(suffix='.mp3')
        os.close(fd)
        temps.append(out)
        percent = int(round((speed - 1) * 100))
        tts = edge_tts.Communicate(text, voice_id, rate=f"{percent:+d}%")
        _aio.run(tts.save(out))
        return out
    except ImportError:
        pass

    try:
        from kokoro import KPipeline
        import soundfile as sf, numpy as np
        pipe = KPipeline(lang_code='a')
        # NOTE(review): this unpacks the pipeline's return value as an
        # (audio, sample_rate) pair - confirm against the installed kokoro API.
        audio, sr = pipe(text, voice=voice or 'af_sky', speed=speed)
        fd, out = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        temps.append(out)
        sf.write(out, np.concatenate(audio), sr)
        return out
    except ImportError:
        pass

    return None
def _add_subtitles(path: str, request: VideoGenerationRequest, temps: list) -> str:
    """Transcribe the video's audio to subtitles and optionally burn them in.

    Steps: transcribe with ``_whisper_transcribe``; translate with
    ``_translate_srt`` when ``request.translate_subtitles`` and
    ``request.subtitle_target_lang`` are set; burn into the video with ffmpeg
    when ``request.burn_subtitles`` is set.  Temporary files are appended to
    ``temps``.

    Returns:
        Path of the video with burned-in subtitles, or ``path`` unchanged when
        whisper is unavailable, transcription fails, burning was not
        requested, or ffmpeg fails.
    """
    try:
        import whisper  # noqa: F401 -- availability probe only
    except ImportError:
        return path  # skip if whisper not available

    srt_path = _whisper_transcribe(path, request.subtitle_language,
                                   request.whisper_model, temps)
    if not srt_path:
        return path

    if request.translate_subtitles and request.subtitle_target_lang:
        srt_path = _translate_srt(srt_path, request.subtitle_target_lang, temps)

    if not request.burn_subtitles:
        return path

    vf = f"subtitles={srt_path}"
    if (request.subtitle_style or 'default') == 'karaoke':
        # The `ass` filter only reads ASS files, so convert the SRT first and
        # fall back to the plain `subtitles` filter if the conversion fails.
        fd, ass_path = tempfile.mkstemp(suffix='.ass')
        os.close(fd)
        temps.append(ass_path)
        conv = subprocess.run(['ffmpeg', '-y', '-i', srt_path, ass_path],
                              capture_output=True)
        if conv.returncode == 0:
            vf = f"ass={ass_path}"

    fd, out = tempfile.mkstemp(suffix='_sub.mp4')
    os.close(fd)
    temps.append(out)
    cmd = ['ffmpeg', '-y', '-i', path, '-vf', vf, '-c:a', 'copy', out]
    r = subprocess.run(cmd, capture_output=True)
    return out if r.returncode == 0 else path
def _whisper_transcribe(video_path: str, language: Optional[str],
model_name: Optional[str], temps: list) -> Optional[str]:
try:
import whisper as _whisper
model = _whisper.load_model(model_name or 'base')
result = model.transcribe(video_path, language=language)
srt_path = tempfile.mktemp(suffix='.srt')
temps.append(srt_path)
with open(srt_path, 'w') as f:
for i, seg in enumerate(result['segments'], 1):
def _fmt(t):
h = int(t // 3600); m = int((t % 3600) // 60); s = t % 60
return f"{h:02d}:{m:02d}:{s:06.3f}".replace('.', ',')
f.write(f"{i}\n{_fmt(seg['start'])} --> {_fmt(seg['end'])}\n{seg['text'].strip()}\n\n")
return srt_path
except Exception:
return None
def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str:
"""Translate SRT using argostranslate or fall back to original."""
try:
import argostranslate.package, argostranslate.translate
with open(srt_path) as f:
content = f.read()
lines = content.split('\n')
translated = []
for line in lines:
if line and not line[0].isdigit() and '-->' not in line:
line = argostranslate.translate.translate(line, 'en', target_lang)
translated.append(line)
out = tempfile.mktemp(suffix='.srt')
temps.append(out)
with open(out, 'w') as f:
f.write('\n'.join(translated))
return out
except Exception:
return srt_path
# =============================================================================
# Main generation endpoint
# =============================================================================
@router.post("/v1/video/generations", response_model=VideoGenerationResponse)
async def video_generations(request: VideoGenerationRequest,
                            http_request: Request = None):
    """
    Generate video.

    Modes (request.mode):
      t2v    – text-to-video
      i2v    – image-to-video (init_image required)
      v2v    – video-to-video (video required)
      ti2v   – text + image → video (prompt is primary driver)
      interp – frame interpolation (init_image + end_image required)

    When the mode is unset or 't2v', it is inferred from the supplied inputs.

    Raises:
        HTTPException: 400 if no model is given, 404 if the model cannot be
            resolved, 500 if loading, generation or encoding fails.
    """
    if not request.model:
        raise HTTPException(status_code=400, detail="model is required")

    # Infer mode from inputs if not set
    if not request.mode or request.mode == 't2v':
        has_start_image = bool(request.init_image or request.image)
        if has_start_image and request.end_image:
            # Both endpoints supplied: this is the documented 'interp' case.
            # Checking it first keeps end_image from being silently ignored.
            request.mode = 'interp'
        elif has_start_image:
            request.mode = 'ti2v' if request.prompt else 'i2v'
        elif request.end_image:
            request.mode = 'interp'
        elif request.video:
            request.mode = 'v2v'

    model_info = multi_model_manager.request_model(request.model, model_type="video")
    model_name = model_info.get('model_name')
    if not model_name:
        err = model_info.get('error', f"Model '{request.model}' not found")
        raise HTTPException(status_code=404, detail=err)
    model_key = model_info['model_key']

    loop = asyncio.get_running_loop()

    pipe = model_info.get('model_object')
    if pipe is None:
        device = _derive_device()
        try:
            pipe = await loop.run_in_executor(
                None, _load_video_pipeline, model_name, device, request.mode)
        except Exception as e:
            raise HTTPException(status_code=500,
                                detail=f"Failed to load video model: {e}") from e
        # Cache the loaded pipeline on the manager for later requests.
        multi_model_manager.models[model_key] = pipe
        multi_model_manager.current_model_key = model_key

    if getattr(request, 'disable_safety_checker', False):
        _disable_safety_checker(pipe)

    try:
        frames, fps = await loop.run_in_executor(None, _generate_video, pipe, request)
    except Exception as e:
        raise HTTPException(status_code=500,
                            detail=f"Video generation failed: {e}") from e

    # Encode raw frames to MP4
    try:
        import imageio  # noqa: F401 -- availability probe for the encoder
        import numpy as np
        frame_np = [np.array(f) for f in frames]
        mp4_bytes = _frames_to_mp4(frame_np, fps)
    except ImportError as e:
        raise HTTPException(status_code=500,
                            detail="imageio[ffmpeg] required: pip install imageio[ffmpeg]") from e

    # Post-processing pipeline (upscale, audio, subtitles, …)
    import shutil
    temps = []
    try:
        needs_post = any([
            request.upscale_output,
            request.interpolate_output,
            request.add_audio,
            request.generate_subtitles,
            request.burn_subtitles,
        ])
        if needs_post:
            mp4_bytes = await loop.run_in_executor(
                None, _postprocess_video, mp4_bytes, request, http_request, temps)
    finally:
        # Best-effort cleanup; temps may hold both files and directories.
        for t in temps:
            try:
                if os.path.isfile(t):
                    os.unlink(t)
                elif os.path.isdir(t):
                    shutil.rmtree(t, ignore_errors=True)
            except Exception:
                pass

    result = _save_file(mp4_bytes, 'mp4', http_request)
    return VideoGenerationResponse(created=int(time.time()), data=[result])
# =============================================================================
# Video upscale endpoint
# =============================================================================
@router.post("/v1/video/upscale")
async def video_upscale(request: VideoUpscaleRequest, http_request: Request = None):
    """
    Upscale an uploaded video.

    The video is decoded from ``request.video``, written to a temporary file
    and handed to ``_ffmpeg_upscale`` on the default executor, using
    ``request.upscale_factor`` (2 when unset).  The result is stored with
    ``_save_file``.  Temporary files are removed whether or not this succeeds.
    """
    video_bytes = _decode_b64_or_url(request.video)
    scratch = []
    try:
        src = _tmp_write(video_bytes, '.mp4')
        scratch.append(src)

        loop = asyncio.get_event_loop()
        factor = request.upscale_factor or 2
        dst = await loop.run_in_executor(None, _ffmpeg_upscale, src, factor, scratch)

        with open(dst, 'rb') as fh:
            payload = fh.read()
    finally:
        # Best-effort removal of everything registered along the way.
        for leftover in scratch:
            try:
                os.unlink(leftover)
            except Exception:
                pass

    saved = _save_file(payload, 'mp4', http_request)
    return {"created": int(time.time()), "data": [saved]}
# =============================================================================
# Subtitle generation endpoint
# =============================================================================
@router.post("/v1/video/subtitle")
async def video_subtitle(request: VideoSubtitleRequest, http_request: Request = None):
    """
    Generate subtitles for a video.

    Transcribes with ``_whisper_transcribe`` and optionally translates with
    ``_translate_srt``.  With ``request.burn`` the subtitles are burned into
    the video and the saved file is returned; otherwise the SRT text is
    returned.

    Raises:
        HTTPException: 500 when transcription yields nothing or the ffmpeg
            burn step fails.
    """
    raw = _decode_b64_or_url(request.video)
    loop = asyncio.get_running_loop()
    temps = []
    try:
        in_path = _tmp_write(raw, '.mp4')
        temps.append(in_path)

        srt_path = await loop.run_in_executor(
            None, _whisper_transcribe, in_path, request.language, None, temps)
        if not srt_path:
            raise HTTPException(status_code=500,
                                detail="Whisper not installed: pip install openai-whisper")

        if request.translate and request.target_lang:
            srt_path = await loop.run_in_executor(
                None, _translate_srt, srt_path, request.target_lang, temps)

        if request.burn:
            fd, out_path = tempfile.mkstemp(suffix='_sub.mp4')
            os.close(fd)
            temps.append(out_path)
            cmd = ['ffmpeg', '-y', '-i', in_path,
                   '-vf', f'subtitles={srt_path}',
                   '-c:a', 'copy', out_path]
            # Re-encoding can take a long time; keep it off the event loop.
            r = await loop.run_in_executor(
                None, lambda: subprocess.run(cmd, capture_output=True))
            if r.returncode != 0:
                raise HTTPException(
                    status_code=500,
                    detail=f"ffmpeg subtitle burn failed: {r.stderr.decode(errors='replace')}")
            with open(out_path, 'rb') as f:
                out_bytes = f.read()
            result = _save_file(out_bytes, 'mp4', http_request)
            return {"created": int(time.time()), "data": [result]}

        # Return raw subtitle text
        with open(srt_path, encoding='utf-8', errors='replace') as f:
            srt_text = f.read()
        return {"created": int(time.time()), "data": [{"text": srt_text, "format": "srt"}]}
    finally:
        for t in temps:
            try:
                os.unlink(t)
            except OSError:
                pass
# =============================================================================
# Frame interpolation endpoint
# =============================================================================
@router.post("/v1/video/interpolate")
async def video_interpolate(request: VideoInterpolateRequest, http_request: Request = None):
    """
    Increase video FPS via frame interpolation.

    Input is either ``request.video`` or a pair ``request.init_image`` +
    ``request.end_image``, which is first encoded as a two-frame clip.  The
    work is done by ``_rife_interpolate`` with ``request.fps_multiplier``
    (2 when unset).

    Raises:
        HTTPException: 400 when neither a video nor both images are supplied.
    """
    import shutil

    loop = asyncio.get_running_loop()
    temps = []
    try:
        if request.video:
            raw = _decode_b64_or_url(request.video)
            in_path = _tmp_write(raw, '.mp4')
            temps.append(in_path)
        elif request.init_image and request.end_image:
            # Build a 2-frame video from the two images, then interpolate
            import numpy as np
            import imageio
            img1 = _pil_from_b64(request.init_image).convert('RGB')
            img2 = _pil_from_b64(request.end_image).convert('RGB')
            if img2.size != img1.size:
                # All frames of one clip must share the same dimensions.
                img2 = img2.resize(img1.size)
            fd, in_path = tempfile.mkstemp(suffix='.mp4')
            os.close(fd)
            temps.append(in_path)
            imageio.mimsave(in_path, [np.array(img1), np.array(img2)],
                            fps=2, codec='libx264')
        else:
            raise HTTPException(status_code=400,
                                detail="Provide either video or init_image + end_image")

        mult = request.fps_multiplier or 2
        out_path = await loop.run_in_executor(
            None, _rife_interpolate, in_path, mult, temps)
        with open(out_path, 'rb') as f:
            out_bytes = f.read()
    finally:
        # The interpolation helper may register temp *directories* as well as
        # files; os.unlink alone would leave those directories behind.
        for t in temps:
            try:
                if os.path.isdir(t):
                    shutil.rmtree(t, ignore_errors=True)
                else:
                    os.unlink(t)
            except OSError:
                pass

    result = _save_file(out_bytes, 'mp4', http_request)
    return {"created": int(time.time()), "data": [result]}
# =============================================================================
# Video dubbing endpoint
# =============================================================================
@router.post("/v1/video/dub")
async def video_dub(request: VideoDubRequest, http_request: Request = None):
    """
    Translate and re-dub a video.

    Pipeline: Whisper transcription → optional subtitle translation → TTS of
    the subtitle text → replace the audio track → optionally burn subtitles.

    Raises:
        HTTPException: 500 when transcription, TTS or the audio merge fails.
    """
    raw = _decode_b64_or_url(request.video)
    loop = asyncio.get_running_loop()

    def _run(cmd):
        """Run an external command, capturing its output."""
        return subprocess.run(cmd, capture_output=True)

    temps = []
    try:
        in_path = _tmp_write(raw, '.mp4')
        temps.append(in_path)

        # 1. Transcribe
        srt_path = await loop.run_in_executor(
            None, _whisper_transcribe, in_path, request.source_lang, None, temps)
        if not srt_path:
            raise HTTPException(status_code=500, detail="Whisper not available")

        # 2. Translate subtitles
        if request.target_lang:
            srt_path = await loop.run_in_executor(
                None, _translate_srt, srt_path, request.target_lang, temps)

        # 3. Generate dubbed audio from translated text
        with open(srt_path, encoding='utf-8', errors='replace') as f:
            srt_content = f.read()
        # Keep only spoken text: drop blank lines, cue numbers (digits only)
        # and timing lines.  Text that merely starts with a digit is kept.
        plain_text = '\n'.join(
            line for line in srt_content.split('\n')
            if line.strip() and not line.strip().isdigit() and '-->' not in line
        )
        audio_path = await loop.run_in_executor(
            None, _generate_tts, plain_text, None, 1.0, temps)
        if not audio_path:
            raise HTTPException(status_code=500, detail="TTS generation failed (install edge-tts or kokoro)")

        # 4. Merge dubbed audio with video (off the event loop: ffmpeg blocks)
        fd, out_path = tempfile.mkstemp(suffix='_dubbed.mp4')
        os.close(fd)
        temps.append(out_path)
        cmd = ['ffmpeg', '-y', '-i', in_path, '-i', audio_path,
               '-map', '0:v', '-map', '1:a',
               '-c:v', 'copy', '-c:a', 'aac', '-shortest', out_path]
        r = await loop.run_in_executor(None, _run, cmd)
        if r.returncode != 0:
            raise HTTPException(status_code=500,
                                detail=f"Audio merge failed: {r.stderr.decode(errors='replace')}")

        # 5. Burn subtitles if requested (best effort: keep dubbed video on failure)
        if request.burn_subtitles:
            fd, sub_out = tempfile.mkstemp(suffix='_sub.mp4')
            os.close(fd)
            temps.append(sub_out)
            cmd2 = ['ffmpeg', '-y', '-i', out_path,
                    '-vf', f'subtitles={srt_path}',
                    '-c:a', 'copy', sub_out]
            r2 = await loop.run_in_executor(None, _run, cmd2)
            if r2.returncode == 0:
                out_path = sub_out

        with open(out_path, 'rb') as f:
            out_bytes = f.read()
    finally:
        for t in temps:
            try:
                os.unlink(t)
            except OSError:
                pass

    result = _save_file(out_bytes, 'mp4', http_request)
    return {"created": int(time.time()), "data": [result]}
...@@ -376,6 +376,27 @@ def main(): ...@@ -376,6 +376,27 @@ def main():
if mid: if mid:
multi_model_manager.set_tts_model(mid, config=_model_cfg(m, "tts") if isinstance(m, dict) else {}) multi_model_manager.set_tts_model(mid, config=_model_cfg(m, "tts") if isinstance(m, dict) else {})
# Video generation models
video_models = models_config.get("video_models", [])
for m in video_models:
mid = _model_id(m)
if mid:
multi_model_manager.set_video_model(mid, config=_model_cfg(m, "video") if isinstance(m, dict) else {})
# Audio generation models (MusicGen, AudioLDM2, …)
audio_gen_models = models_config.get("audio_gen_models", [])
for m in audio_gen_models:
mid = _model_id(m)
if mid:
multi_model_manager.set_audio_gen_model(mid, config=_model_cfg(m, "audio_gen") if isinstance(m, dict) else {})
# Embedding models
embedding_models = models_config.get("embedding_models", [])
for m in embedding_models:
mid = _model_id(m)
if mid:
multi_model_manager.set_embedding_model(mid, config=_model_cfg(m, "embedding") if isinstance(m, dict) else {})
# Register aliases # Register aliases
aliases = models_config.get("aliases", {}) aliases = models_config.get("aliases", {})
for alias, model in aliases.items(): for alias, model in aliases.items():
...@@ -387,7 +408,10 @@ def main(): ...@@ -387,7 +408,10 @@ def main():
[("audio", m) for m in audio_models] + [("audio", m) for m in audio_models] +
[("image", m) for m in image_models] + [("image", m) for m in image_models] +
[("vision", m) for m in vision_models] + [("vision", m) for m in vision_models] +
[("tts", m) for m in tts_models] [("tts", m) for m in tts_models] +
[("video", m) for m in video_models] +
[("audio_gen", m) for m in audio_gen_models] +
[("embedding", m) for m in embedding_models]
) )
for mtype, m in all_model_entries: for mtype, m in all_model_entries:
mid = _model_id(m) mid = _model_id(m)
...@@ -498,6 +522,22 @@ def main(): ...@@ -498,6 +522,22 @@ def main():
from codai.api.images import set_global_args as set_images_global_args from codai.api.images import set_global_args as set_images_global_args
set_images_global_args(global_args) set_images_global_args(global_args)
# Set video module global args
from codai.api.video import set_global_args as set_video_global_args, set_global_file_path as set_video_file_path
set_video_global_args(global_args)
if global_file_path:
set_video_file_path(global_file_path)
# Set audio_gen module global args
from codai.api.audio_gen import set_global_args as set_audiogen_global_args, set_global_file_path as set_audiogen_file_path
set_audiogen_global_args(global_args)
if global_file_path:
set_audiogen_file_path(global_file_path)
# Set embeddings module global args
from codai.api.embeddings import set_global_args as set_embed_global_args
set_embed_global_args(global_args)
# Pre-load image models marked as load_mode == "load" # Pre-load image models marked as load_mode == "load"
for m in image_models: for m in image_models:
mid = _model_id(m) mid = _model_id(m)
......
"""Model capabilities module.""" """Model capabilities module."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import List
@dataclass
class ModelCapabilities:
    """Represents what a model can do.

    Every field is a boolean flag that defaults to ``False``.
    """

    # Language / multimodal
    text_generation: bool = False      # LLM chat/completion
    image_to_text: bool = False        # VQA, captioning, vision LLMs
    embeddings: bool = False           # Text/image embeddings

    # Image generation & editing
    image_generation: bool = False     # Text-to-image (SD, Flux, …)
    image_to_image: bool = False       # img2img denoising
    inpainting: bool = False           # Inpaint with mask
    controlnet: bool = False           # ControlNet-guided generation

    # Image analysis & processing
    depth_estimation: bool = False     # Monocular depth (MiDaS, DPT, ZoeDepth)
    image_segmentation: bool = False   # SAM, Mask R-CNN
    image_upscaling: bool = False      # ESRGAN, SwinIR, Real-ESRGAN
    face_restoration: bool = False     # CodeFormer, GFPGAN
    object_detection: bool = False     # YOLO, DETR
    style_transfer: bool = False       # Neural style transfer

    # Video generation & editing
    video_generation: bool = False     # Text-to-video (CogVideoX, LTX, …)
    image_to_video: bool = False       # Image-to-video (SVD, I2VGen, …)
    video_to_video: bool = False       # Video style transfer / enhancement
    video_interpolation: bool = False  # Frame interpolation (FILM, RIFE)
    video_upscaling: bool = False      # Video super-resolution

    # Audio: speech
    speech_to_text: bool = False       # Whisper transcription
    text_to_speech: bool = False       # Kokoro, Bark, XTTS
    subtitle_generation: bool = False  # WhisperX / forced alignment subtitles

    # Audio: generation & manipulation
    audio_generation: bool = False     # MusicGen, AudioLDM2, StableAudio
    audio_to_audio: bool = False       # Denoising, source separation, …

    # Video + audio pipelines
    lip_sync: bool = False             # Wav2Lip, SadTalker
    video_dubbing: bool = False        # Translation + TTS + lip sync

    def to_list(self) -> List[str]:
        """Return the names of all enabled capabilities, in declaration order."""
        return [name for name in self.__dataclass_fields__ if getattr(self, name)]

    def __str__(self) -> str:
        """Return the enabled capability names comma-separated, or ``"none"``."""
        return ", ".join(self.to_list()) or "none"
# Ordered (keywords, capability flags) rules; the first rule with a keyword
# contained in the lower-cased model name wins.
#
# Video rules: the specific image-to-video and video-to-video keywords come
# BEFORE the generic text-to-video ones, otherwise 'cogvideox-i2v' and
# 'text2video-zero' would be captured by 'cogvideox' / 'text2video'.
_VIDEO_CAPABILITY_RULES = (
    (('stable-video-diffusion', 'svd', 'i2vgen-xl', 'i2vgen', 'cogvideox-i2v',
      'wan2.1-i2v', 'wan-i2v', 'img2vid', 'image2video', 'motionctrl'),
     ('image_to_video',)),
    (('tokenflow', 'text2video-zero', 'vid2vid', 'rerender-a-video',
      'controlvideo'),
     ('video_to_video',)),
    (('cogvideox', 'cogvideo', 'ltx-video', 'ltxvideo', 'hunyuan-video',
      'mochi-1', 'dynamicrafter', 'animatediff', 'text2video',
      'modelscope-t2v', 'zeroscope', 'lavie', 'wan2.1-t2v', 'wan-t2v'),
     ('video_generation',)),
)

# Rules tried after the video rules and the generic Wan check.
_GENERAL_CAPABILITY_RULES = (
    # Video interpolation / super-resolution
    (('film-net', 'rife', 'flavr', 'dain', 'frame-interp'),
     ('video_interpolation',)),
    (('real-basicvsr', 'basicvsr', 'edvr', 'video-enhance', 'videoswin-sr'),
     ('video_upscaling',)),
    # Audio
    (('musicgen', 'audiogen', 'audioldm', 'stable-audio', 'mustango',
      'noise2music', 'jukebox', 'audiocraft'),
     ('audio_generation',)),
    (('demucs', 'spleeter', 'asteroid', 'open-unmix'),
     ('audio_to_audio',)),
    (('whisper', 'faster-whisper', 'distil-whisper', 'wav2vec', 'hubert',
      'seamless'),
     ('speech_to_text', 'subtitle_generation')),
    (('kokoro', 'xtts', 'bark', 'tortoise', 'speecht5', 'matcha-tts',
      'voicebox'),
     ('text_to_speech',)),
    # Lip sync / dubbing
    (('wav2lip', 'sadtalker', 'dinet', 'videoretalking'),
     ('lip_sync',)),
    # Image generation
    (('inpaint', 'instruct-pix2pix', 'paint-by-example'),
     ('inpainting', 'image_generation', 'image_to_image')),
    (('controlnet',),
     ('controlnet', 'image_generation')),
    (('stable-diffusion', 'sd15', 'sdxl', 'sd-xl', 'playground', 'flux',
      'kandinsky', 'deepfloyd', 'pixart', 'dalle', 'waifu', 'pony',
      'realistic-vision', 'realistic_vision'),
     ('image_generation', 'image_to_image', 'inpainting')),
    # Image analysis / processing
    (('midas', 'dpt-depth', 'dpt-large', 'zoe-depth', 'depth-anything',
      'marigold'),
     ('depth_estimation',)),
    (('sam2', 'sam-', '-sam', 'segment-anything', 'mask-rcnn', 'fastsam'),
     ('image_segmentation',)),
    (('real-esrgan', 'esrgan', 'swinir', 'edsr', 'bsrgan', 'hat-', 'dat-'),
     ('image_upscaling',)),
    (('codeformer', 'gfpgan', 'restoreformer'),
     ('face_restoration', 'image_upscaling')),
    (('yolo', 'detr', 'owlvit', 'rtdetr', 'dino'),
     ('object_detection',)),
    # Vision / multimodal LLMs
    (('vision', 'vl-', '-vl', 'llava', 'qwen2-vl', 'qwen-vl', 'phi-4-mini',
      'pixtral', 'clip', 'blip', 'internvl', 'moondream', 'idefics',
      'cogvlm', 'minigpt', 'flamingo'),
     ('image_to_text', 'text_generation')),
    # Embeddings
    (('embed', 'bge-', 'e5-', 'minilm', 'sentence-transformer',
      'nomic-embed', 'instructor-', 'gte-', 'jina-embed'),
     ('embeddings',)),
)


def _first_matching_flags(name, rules):
    """Return the flags of the first rule with a keyword in ``name``, else None."""
    for keywords, flags in rules:
        if any(keyword in name for keyword in keywords):
            return flags
    return None


def detect_model_capabilities(model_name: str) -> ModelCapabilities:
    """
    Detect model capabilities from the model name/ID.

    Heuristic only — actual capabilities depend on the checkpoint.  Matching
    is by case-insensitive substring; the first matching rule decides the
    result.  An empty name yields no capabilities; a name matching no rule
    (including GGUF text models) is treated as text generation.
    """
    caps = ModelCapabilities()
    if not model_name:
        return caps

    n = model_name.lower()

    flags = _first_matching_flags(n, _VIDEO_CAPABILITY_RULES)
    if flags is None and 'wan' in n and ('video' in n or 'diffuser' in n):
        # Generic Wan checkpoints: tell the sub-variant apart via 'i2v'.
        flags = ('image_to_video',) if 'i2v' in n else ('video_generation',)
    if flags is None:
        flags = _first_matching_flags(n, _GENERAL_CAPABILITY_RULES)
    if flags is None:
        # Default: text generation
        flags = ('text_generation',)

    for flag in flags:
        setattr(caps, flag, True)
    return caps
...@@ -380,6 +380,9 @@ class MultiModelManager: ...@@ -380,6 +380,9 @@ class MultiModelManager:
self.tts_model: Optional[str] = None self.tts_model: Optional[str] = None
self.image_models: List[str] = [] self.image_models: List[str] = []
self.vision_models: List[str] = [] self.vision_models: List[str] = []
self.video_models: List[str] = [] # video generation models (t2v / i2v / v2v)
self.audio_gen_models: List[str] = [] # music / sfx generation (MusicGen, AudioLDM2…)
self.embedding_models: List[str] = [] # text / multimodal embeddings
self.config: Dict[str, Dict] = {} # Store model configurations self.config: Dict[str, Dict] = {} # Store model configurations
self.tool_parser = ModelParserAdapter() self.tool_parser = ModelParserAdapter()
self.current_model_key: Optional[str] = None self.current_model_key: Optional[str] = None
...@@ -660,15 +663,34 @@ class MultiModelManager: ...@@ -660,15 +663,34 @@ class MultiModelManager:
self.vision_models.append(model_name) self.vision_models.append(model_name)
self.config[f"vision:{model_name}"] = config or {} self.config[f"vision:{model_name}"] = config or {}
# Download/cache the model at startup if it's a URL or HF ID
resolved_model = self.load_model(model_name) resolved_model = self.load_model(model_name)
if resolved_model != model_name: if resolved_model != model_name:
# Model was downloaded/cached, update the stored name
idx = self.vision_models.index(model_name) idx = self.vision_models.index(model_name)
self.vision_models[idx] = resolved_model self.vision_models[idx] = resolved_model
self.config[f"vision:{resolved_model}"] = self.config.pop(f"vision:{model_name}") self.config[f"vision:{resolved_model}"] = self.config.pop(f"vision:{model_name}")
print(f"Vision model '{model_name}' cached as: {resolved_model}") print(f"Vision model '{model_name}' cached as: {resolved_model}")
def set_video_model(self, model_name: str, config: Dict = None):
"""Add a video generation model (t2v / i2v / v2v)."""
if model_name not in self.video_models:
self.video_models.append(model_name)
self.config[f"video:{model_name}"] = config or {}
print(f"Registered video model: {model_name}")
def set_audio_gen_model(self, model_name: str, config: Dict = None):
"""Add a music/audio generation model (MusicGen, AudioLDM2, …)."""
if model_name not in self.audio_gen_models:
self.audio_gen_models.append(model_name)
self.config[f"audio_gen:{model_name}"] = config or {}
print(f"Registered audio-gen model: {model_name}")
def set_embedding_model(self, model_name: str, config: Dict = None):
"""Add a text/image embedding model."""
if model_name not in self.embedding_models:
self.embedding_models.append(model_name)
self.config[f"embedding:{model_name}"] = config or {}
print(f"Registered embedding model: {model_name}")
def set_model_alias(self, alias: str, model_name: str): def set_model_alias(self, alias: str, model_name: str):
"""Register an alias for a model.""" """Register an alias for a model."""
self.model_aliases[alias] = model_name self.model_aliases[alias] = model_name
...@@ -714,6 +736,27 @@ class MultiModelManager: ...@@ -714,6 +736,27 @@ class MultiModelManager:
allowed.add(m) allowed.add(m)
allowed.add(f"vision:{m}") allowed.add(f"vision:{m}")
# Video models
if self.video_models:
allowed.add("video")
for m in self.video_models:
allowed.add(m)
allowed.add(f"video:{m}")
# Audio generation models
if self.audio_gen_models:
allowed.add("audio_gen")
for m in self.audio_gen_models:
allowed.add(m)
allowed.add(f"audio_gen:{m}")
# Embedding models
if self.embedding_models:
allowed.add("embedding")
for m in self.embedding_models:
allowed.add(m)
allowed.add(f"embedding:{m}")
# Custom aliases # Custom aliases
for alias in self.model_aliases: for alias in self.model_aliases:
allowed.add(alias) allowed.add(alias)
...@@ -724,7 +767,8 @@ class MultiModelManager: ...@@ -724,7 +767,8 @@ class MultiModelManager:
if config_manager is not None: if config_manager is not None:
md = config_manager.models_data md = config_manager.models_data
for cat in ("text_models", "image_models", "audio_models", for cat in ("text_models", "image_models", "audio_models",
"gguf_models", "tts_models", "vision_models"): "gguf_models", "tts_models", "vision_models",
"video_models", "audio_gen_models", "embedding_models"):
for m in md.get(cat, []): for m in md.get(cat, []):
mid = (m if isinstance(m, str) else mid = (m if isinstance(m, str) else
m.get("alias") or m.get("path") or m.get("id") or "") m.get("alias") or m.get("path") or m.get("id") or "")
...@@ -924,12 +968,18 @@ class MultiModelManager: ...@@ -924,12 +968,18 @@ class MultiModelManager:
# Handle "vision" alias # Handle "vision" alias
if requested_model == "vision": if requested_model == "vision":
return f"image:{self.vision_models[0]}" if self.vision_models else None return f"image:{self.vision_models[0]}" if self.vision_models else None
# Handle "video" alias
if requested_model == "video":
return f"video:{self.video_models[0]}" if self.video_models else None
# Handle prefixed models - normalize them # Handle prefixed models - normalize them
if requested_model.startswith("audio:"): if requested_model.startswith("audio:"):
return requested_model return requested_model
if requested_model.startswith("tts:"): if requested_model.startswith("tts:"):
return requested_model return requested_model
if requested_model.startswith("video:"):
return requested_model
if requested_model.startswith("image:") or requested_model.startswith("vision:"): if requested_model.startswith("image:") or requested_model.startswith("vision:"):
# Normalize vision: to image: # Normalize vision: to image:
if requested_model.startswith("vision:"): if requested_model.startswith("vision:"):
...@@ -1231,13 +1281,19 @@ class MultiModelManager: ...@@ -1231,13 +1281,19 @@ class MultiModelManager:
resolved_name = self.tts_model resolved_name = self.tts_model
elif model_type == "vision": elif model_type == "vision":
resolved_name = self.vision_models[0] if self.vision_models else None resolved_name = self.vision_models[0] if self.vision_models else None
elif model_type == "video":
resolved_name = self.video_models[0] if self.video_models else None
elif model_type == "audio_gen":
resolved_name = self.audio_gen_models[0] if self.audio_gen_models else None
elif model_type == "embedding":
resolved_name = self.embedding_models[0] if self.embedding_models else None
else: else:
resolved_name = self.default_model resolved_name = self.default_model
else: else:
# Resolve custom aliases # Resolve custom aliases
if requested_model in self.model_aliases: if requested_model in self.model_aliases:
requested_model = self.model_aliases[requested_model] requested_model = self.model_aliases[requested_model]
# Handle "default" alias # Handle "default" alias
if requested_model == "default": if requested_model == "default":
resolved_name = self.default_model resolved_name = self.default_model
...@@ -1250,6 +1306,12 @@ class MultiModelManager: ...@@ -1250,6 +1306,12 @@ class MultiModelManager:
resolved_name = self.tts_model resolved_name = self.tts_model
elif requested_model == "vision": elif requested_model == "vision":
resolved_name = self.vision_models[0] if self.vision_models else None resolved_name = self.vision_models[0] if self.vision_models else None
elif requested_model == "video":
resolved_name = self.video_models[0] if self.video_models else None
elif requested_model == "audio_gen":
resolved_name = self.audio_gen_models[0] if self.audio_gen_models else None
elif requested_model == "embedding":
resolved_name = self.embedding_models[0] if self.embedding_models else None
# Handle prefixed models (e.g., "image:model_name") # Handle prefixed models (e.g., "image:model_name")
elif requested_model.startswith("image:"): elif requested_model.startswith("image:"):
resolved_name = requested_model[6:] resolved_name = requested_model[6:]
...@@ -1259,6 +1321,12 @@ class MultiModelManager: ...@@ -1259,6 +1321,12 @@ class MultiModelManager:
resolved_name = requested_model[4:] resolved_name = requested_model[4:]
elif requested_model.startswith("vision:"): elif requested_model.startswith("vision:"):
resolved_name = requested_model[7:] resolved_name = requested_model[7:]
elif requested_model.startswith("video:"):
resolved_name = requested_model[6:]
elif requested_model.startswith("audio_gen:"):
resolved_name = requested_model[10:]
elif requested_model.startswith("embedding:"):
resolved_name = requested_model[10:]
else: else:
resolved_name = requested_model resolved_name = requested_model
...@@ -1610,14 +1678,35 @@ class MultiModelManager: ...@@ -1610,14 +1678,35 @@ class MultiModelManager:
return None return None
def list_models(self) -> List[ModelInfo]: def list_models(self) -> List[ModelInfo]:
"""List all available models (configured + runtime aliases).""" """List all available models (configured + runtime aliases) with type/capability metadata."""
from codai.models.capabilities import detect_model_capabilities
models = [] models = []
seen_ids: set = set() seen_ids: set = set()
def _add(model_id: str): CAT_TYPE = {
if model_id not in seen_ids: "text_models": "text",
seen_ids.add(model_id) "gguf_models": "text",
models.append(ModelInfo(id=model_id)) "vision_models": "vision",
"image_models": "image",
"audio_models": "audio",
"tts_models": "tts",
"video_models": "video",
"audio_gen_models": "audio_gen",
"embedding_models": "embedding",
}
def _add(model_id: str, model_type: str = None):
if model_id in seen_ids:
return
seen_ids.add(model_id)
caps = detect_model_capabilities(model_id)
resolved_type = model_type or (caps.to_list()[0].split("_")[0] if caps.to_list() else "text")
models.append(ModelInfo(
id=model_id,
type=resolved_type,
capabilities=caps.to_list(),
))
# --- Models from config (the authoritative source) --- # --- Models from config (the authoritative source) ---
try: try:
...@@ -1625,57 +1714,72 @@ class MultiModelManager: ...@@ -1625,57 +1714,72 @@ class MultiModelManager:
if config_manager is not None: if config_manager is not None:
md = config_manager.models_data md = config_manager.models_data
for cat in ("text_models", "vision_models", "image_models", for cat in ("text_models", "vision_models", "image_models",
"audio_models", "tts_models", "gguf_models"): "audio_models", "tts_models", "gguf_models",
"video_models", "audio_gen_models", "embedding_models"):
mtype = CAT_TYPE.get(cat, "text")
for m in md.get(cat, []): for m in md.get(cat, []):
if isinstance(m, str): if isinstance(m, str):
mid = m mid = m
else: else:
mid = m.get("alias") or m.get("path") or m.get("id") or "" mid = m.get("alias") or m.get("path") or m.get("id") or ""
# Also expose the raw path/id
raw = m.get("path") or m.get("id") or "" raw = m.get("path") or m.get("id") or ""
if raw and raw != mid: if raw and raw != mid:
_add(raw) _add(raw, mtype)
# Short name
short = raw.split("/")[-1] if "/" in raw else raw short = raw.split("/")[-1] if "/" in raw else raw
if short != raw: if short != raw:
_add(short) _add(short, mtype)
if mid: if mid:
_add(mid) _add(mid, mtype)
short = mid.split("/")[-1] if "/" in mid else mid short = mid.split("/")[-1] if "/" in mid else mid
if short != mid: if short != mid:
_add(short) _add(short, mtype)
except Exception: except Exception:
pass pass
# --- Fallback: runtime default_model (if config_manager unavailable) --- # --- Fallback: runtime default_model ---
if not models and self.default_model: if not models and self.default_model:
model_id = self.default_model model_id = self.default_model
if not (model_id.startswith("http://") or model_id.startswith("https://")): if not (model_id.startswith("http://") or model_id.startswith("https://")):
short_name = model_id.split("/")[-1] if "/" in model_id else model_id short_name = model_id.split("/")[-1] if "/" in model_id else model_id
if short_name != model_id: if short_name != model_id:
_add(short_name) _add(short_name, "text")
_add(model_id) _add(model_id, "text")
_add("default") _add("default", "text")
# --- Runtime-registered non-text models (image, audio, tts, vision) --- # --- Runtime-registered non-text models ---
if self.audio_models: if self.audio_models:
_add("audio") _add("audio", "audio")
for audio_id in self.audio_models: for m in self.audio_models:
_add(f"audio:{audio_id}") _add(f"audio:{m}", "audio")
if self.tts_model: if self.tts_model:
_add("tts") _add("tts", "tts")
_add(f"tts:{self.tts_model}") _add(f"tts:{self.tts_model}", "tts")
if self.image_models: if self.image_models:
_add("image") _add("image", "image")
for image_id in self.image_models: for m in self.image_models:
_add(f"image:{image_id}") _add(f"image:{m}", "image")
if self.vision_models: if self.vision_models:
_add("vision") _add("vision", "vision")
for vision_id in self.vision_models: for m in self.vision_models:
_add(f"vision:{vision_id}") _add(f"vision:{m}", "vision")
if self.video_models:
_add("video", "video")
for m in self.video_models:
_add(f"video:{m}", "video")
if self.audio_gen_models:
_add("audio_gen", "audio_gen")
for m in self.audio_gen_models:
_add(f"audio_gen:{m}", "audio_gen")
if self.embedding_models:
_add("embedding", "embedding")
for m in self.embedding_models:
_add(f"embedding:{m}", "embedding")
# --- Custom aliases --- # --- Custom aliases ---
for alias in self.model_aliases: for alias in self.model_aliases:
......
"""Pydantic models for audio generation API."""
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
class AudioGenerationRequest(BaseModel):
    """Request body for generating audio from a text prompt.

    Only ``model`` and ``prompt`` are required; every other field carries a
    default. Fields not declared here are accepted rather than rejected
    because the model is configured with ``extra="allow"``.
    """

    model: str
    prompt: str
    duration: Optional[float] = 10.0  # seconds
    # NOTE(review): top_k / top_p / temperature look like sampling controls
    # passed through to the generator -- confirm against the handler.
    top_k: Optional[int] = 250
    top_p: Optional[float] = 0.0
    temperature: Optional[float] = 1.0
    cfg_coef: Optional[float] = 3.0  # classifier-free guidance coefficient
    seed: Optional[int] = None  # None means the caller supplied no seed
    # Reference audio for melody conditioning (MusicGen Melody)
    melody: Optional[str] = None  # base64/URL
    # Output
    response_format: Optional[str] = "url"  # url | b64_wav | b64_mp3
    user: Optional[str] = None
    # Keep undeclared request fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
class AudioGenerationResponse(BaseModel):
    """Response body returned for an audio generation request."""

    created: int  # NOTE(review): presumably a Unix timestamp -- confirm against the handler
    data: List[Dict]  # one dict per result; the dict schema is not defined in this model
    # Keep undeclared fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
"""Pydantic models for embeddings API."""
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict
class EmbeddingsRequest(BaseModel):
    """Request body for the embeddings API.

    ``model`` and ``input`` are required. ``input`` may be a single string or
    a list of strings. Fields not declared here are accepted because the
    model is configured with ``extra="allow"``.
    """

    model: str
    input: Union[str, List[str]]  # text(s) to embed
    image: Optional[Union[str, List[str]]] = None  # base64/URL image(s) for multimodal embed
    encoding_format: Optional[str] = "float"  # float | base64
    dimensions: Optional[int] = None  # truncate to N dims if supported
    user: Optional[str] = None
    # Keep undeclared request fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
class EmbeddingObject(BaseModel):
    """A single embedding entry inside an embeddings response."""

    object: str = "embedding"  # fixed type tag for this entry
    index: int  # NOTE(review): presumably the position of the matching input -- confirm
    embedding: Union[List[float], str]  # float list or base64
class EmbeddingsResponse(BaseModel):
    """Response body for the embeddings API: a list of embedding entries."""

    object: str = "list"  # fixed type tag for the response envelope
    data: List[EmbeddingObject]
    model: str
    usage: Dict  # usage accounting; the key schema is not defined in this model
    # Keep undeclared fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
...@@ -10,14 +10,15 @@ class ImageGenerationRequest(BaseModel): ...@@ -10,14 +10,15 @@ class ImageGenerationRequest(BaseModel):
prompt: str prompt: str
n: int = 1 n: int = 1
size: Optional[str] = "1024x1024" size: Optional[str] = "1024x1024"
steps: Optional[int] = None # Number of inference steps (overrides quality-based default) steps: Optional[int] = None
guidance_scale: Optional[float] = None # CFG scale (overrides quality-based default) guidance_scale: Optional[float] = None
quality: Optional[str] = "standard" quality: Optional[str] = "standard"
style: Optional[str] = None style: Optional[str] = None
response_format: Optional[str] = "url" response_format: Optional[str] = "url"
seed: Optional[int] = None seed: Optional[int] = None
user: Optional[str] = None user: Optional[str] = None
disable_safety_checker: Optional[bool] = False
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
......
...@@ -103,6 +103,8 @@ class ModelInfo(BaseModel): ...@@ -103,6 +103,8 @@ class ModelInfo(BaseModel):
object: str = "model" object: str = "model"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "huggingface" owned_by: str = "huggingface"
type: Optional[str] = None # e.g. "text", "image", "video", "audio", "tts", "vision", "embedding"
capabilities: Optional[List[str]] = None # list of capability strings
class ModelList(BaseModel): class ModelList(BaseModel):
......
"""Pydantic models for video generation API."""
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
class VideoGenerationRequest(BaseModel):
    """Request body for video generation and its optional post-processing.

    Only ``model`` is required; ``prompt`` defaults to the empty string and
    every other field carries a default. The model declares no validators, so
    relationships described in the comments below (for example which inputs a
    given ``mode`` needs) are not enforced here. Fields not declared here are
    accepted because the model is configured with ``extra="allow"``.
    """

    model: str
    prompt: str = ""
    negative_prompt: Optional[str] = None
    # Dimensions
    width: Optional[int] = 512
    height: Optional[int] = 512
    # Temporal
    num_frames: Optional[int] = None  # model default if None
    fps: Optional[int] = None  # output FPS
    # Diffusion
    num_inference_steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    seed: Optional[int] = None
    # Mode
    # t2v – text-to-video
    # i2v – image-to-video (init_image required)
    # v2v – video-to-video (video required)
    # ti2v – text + init image → video (like i2v but prompt is primary driver)
    # interp – frame interpolation (init_image + end_image)
    mode: Optional[str] = "t2v"
    # Input media (base64 or URL)
    # NOTE(review): `image` is declared as a separate field; nothing in this
    # model copies it into `init_image`, so the aliasing is presumably done by
    # the request handler -- confirm.
    image: Optional[str] = None  # alias for init_image
    init_image: Optional[str] = None  # first/reference frame
    end_image: Optional[str] = None  # last frame (for interp mode)
    video: Optional[str] = None  # input video (v2v / audio manipulation)
    strength: Optional[float] = None  # denoising strength for v2v
    # Camera motion hint
    camera_motion: Optional[str] = None  # zoom-in | zoom-out | pan-left | pan-right | tilt-up | tilt-down | rotate
    # ── Character consistency ─────────────────────────────────────────────
    character_references: Optional[List[str]] = None  # list of base64/URL reference images
    character_strength: Optional[float] = 0.8
    character_names: Optional[List[str]] = None  # optional names per reference
    # ── Audio generation / manipulation ──────────────────────────────────
    add_audio: Optional[bool] = False
    audio_type: Optional[str] = None  # music | speech | sfx | ambient
    audio_prompt: Optional[str] = None  # prompt for music/sfx generation
    audio_file: Optional[str] = None  # existing audio to add (base64/URL)
    tts_text: Optional[str] = None  # text for speech synthesis
    tts_voice: Optional[str] = None  # TTS voice id
    tts_speed: Optional[float] = 1.0
    sync_audio: Optional[bool] = False  # sync audio timing to video
    lip_sync: Optional[bool] = False  # warp mouth to match audio
    lip_sync_method: Optional[str] = "wav2lip"  # wav2lip | sadtalker
    # ── Subtitles ────────────────────────────────────────────────────────
    generate_subtitles: Optional[bool] = False
    burn_subtitles: Optional[bool] = False
    subtitle_language: Optional[str] = None  # source language hint
    translate_subtitles: Optional[bool] = False
    subtitle_target_lang: Optional[str] = None
    subtitle_style: Optional[str] = "default"  # default | karaoke | minimal
    whisper_model: Optional[str] = None  # which whisper variant to use
    # ── Video dubbing ─────────────────────────────────────────────────────
    dub_video: Optional[bool] = False
    dub_target_lang: Optional[str] = None
    dub_source_lang: Optional[str] = None
    voice_clone: Optional[bool] = False  # clone original speaker voice
    # ── Post-processing ───────────────────────────────────────────────────
    upscale_output: Optional[bool] = False
    upscale_factor: Optional[int] = 2
    interpolate_output: Optional[bool] = False  # increase FPS after generation
    fps_multiplier: Optional[int] = 2  # e.g. 2 → 2× FPS via frame interp
    convert_to_3d: Optional[bool] = False
    depth_method: Optional[str] = "midas"  # midas | zoe | depth-anything
    # ── Memory / offload ─────────────────────────────────────────────────
    offload_strategy: Optional[str] = None  # sequential | model | none
    # Nulls pipeline safety_checker / safety_concept so uncensored fine-tunes
    # are not blocked. Has no effect on models without a safety checker.
    disable_safety_checker: Optional[bool] = False
    # ── Output ───────────────────────────────────────────────────────────
    response_format: Optional[str] = "url"  # url | b64_mp4
    n: int = 1  # NOTE(review): presumably the number of videos to generate -- confirm
    user: Optional[str] = None
    # Keep undeclared request fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
class VideoGenerationResponse(BaseModel):
    """Response body returned for a video generation request."""

    created: int  # NOTE(review): presumably a Unix timestamp -- confirm against the handler
    data: List[Dict]  # one dict per result; the dict schema is not defined in this model
    # Keep undeclared fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
# ── Standalone operation requests ─────────────────────────────────────────────
class VideoUpscaleRequest(BaseModel):
    """Request body for the standalone video upscaling operation.

    ``model`` and ``video`` are required; undeclared fields are accepted
    (``extra="allow"``).
    """

    model: str
    video: str  # base64/URL input video
    upscale_factor: Optional[int] = 2
    response_format: Optional[str] = "url"
    # Keep undeclared request fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
class VideoSubtitleRequest(BaseModel):
    """Request body for the standalone video subtitle operation.

    ``model`` and ``video`` are required; undeclared fields are accepted
    (``extra="allow"``).
    """

    model: str
    video: str  # base64/URL input video
    language: Optional[str] = None  # None means no language was specified by the caller
    translate: Optional[bool] = False
    target_lang: Optional[str] = None  # NOTE(review): presumably only used when translate is true -- confirm
    burn: Optional[bool] = False
    style: Optional[str] = "default"
    response_format: Optional[str] = "srt"  # srt | vtt | json | burned_video
    # Keep undeclared request fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
class VideoInterpolateRequest(BaseModel):
    """Request body for the standalone frame interpolation operation.

    Only ``model`` is required by this schema. The comment on ``video``
    describes it as mutually exclusive with ``init_image``/``end_image``, but
    no validator in this model enforces that, so any check must live in the
    caller. Undeclared fields are accepted (``extra="allow"``).
    """

    model: str
    video: Optional[str] = None  # base64/URL input video (mutually exclusive with init/end)
    init_image: Optional[str] = None  # first frame
    end_image: Optional[str] = None  # last frame
    fps_multiplier: Optional[int] = 2
    response_format: Optional[str] = "url"
    # Keep undeclared request fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
class VideoDubRequest(BaseModel):
    """Request body for the standalone video dubbing operation.

    ``model``, ``video`` and ``target_lang`` are required; undeclared fields
    are accepted (``extra="allow"``).
    """

    model: str
    video: str  # NOTE(review): presumably base64/URL like the other video requests -- confirm
    target_lang: str
    source_lang: Optional[str] = None  # None means the caller did not name a source language
    voice_clone: Optional[bool] = False
    burn_subtitles: Optional[bool] = False
    response_format: Optional[str] = "url"
    # Keep undeclared request fields instead of raising a validation error.
    model_config = ConfigDict(extra="allow")
...@@ -16,6 +16,21 @@ psutil>=5.9.0 ...@@ -16,6 +16,21 @@ psutil>=5.9.0
# Optional: Audio transcription dependencies # Optional: Audio transcription dependencies
faster-whisper>=0.10.0 # For NVIDIA/CUDA whisper transcription faster-whisper>=0.10.0 # For NVIDIA/CUDA whisper transcription
whispercpp>=0.0.17 # Alternative whisper library (works without PyTorch) whispercpp>=0.0.17 # Alternative whisper library (works without PyTorch)
openai-whisper>=20231117 # Whisper for subtitle generation
# Image/video/audio utilities
Pillow>=10.0.0
numpy>=1.24.0
imageio[ffmpeg]>=2.33.0 # frame I/O + ffmpeg bridge for video generation
scipy>=1.11.0
sentence-transformers>=2.7.0 # /v1/embeddings
argostranslate>=1.9.0 # subtitle translation
edge-tts>=6.1.9 # TTS dubbing (primary)
kokoro-tts>=0.9.0 # TTS dubbing (fallback)
soundfile>=0.12.0
realesrgan>=0.3.0
basicsr>=1.4.2
timm>=0.9.0
# Optional: for better performance with NVIDIA GPUs # Optional: for better performance with NVIDIA GPUs
bitsandbytes>=0.41.0 bitsandbytes>=0.41.0
......
...@@ -49,6 +49,32 @@ whispercpp>=0.0.17 # Alternative whisper library (works without PyTorch) ...@@ -49,6 +49,32 @@ whispercpp>=0.0.17 # Alternative whisper library (works without PyTorch)
# LiteLLM for standardized API responses # LiteLLM for standardized API responses
litellm>=1.40.0 litellm>=1.40.0
# Image/video processing utilities
Pillow>=10.0.0
numpy>=1.24.0
imageio[ffmpeg]>=2.33.0 # frame I/O + ffmpeg bridge for video generation
scipy>=1.11.0 # audio/signal processing (wav export in audio_gen)
# Embeddings
sentence-transformers>=2.7.0 # /v1/embeddings with sentence-transformer models
# Video/audio post-processing (all optional – features degrade gracefully if absent)
openai-whisper>=20231117 # subtitle generation via Whisper transcription
argostranslate>=1.9.0 # subtitle translation
edge-tts>=6.1.9 # TTS for video dubbing (primary)
kokoro-tts>=0.9.0 # TTS for video dubbing (fallback)
soundfile>=0.12.0 # audio file I/O for kokoro TTS output
# Image upscaling / restoration
realesrgan>=0.3.0 # Real-ESRGAN upscaler
basicsr>=1.4.2 # backbone required by realesrgan
timm>=0.9.0 # vision model backbones (depth/segment endpoints)
# Audio generation (optional – only needed for /v1/audio/generate)
# audiocraft is Meta's MusicGen/AudioGen library; install separately if desired:
# pip install audiocraft
# AudioLDM2 is available via diffusers (already listed above)
# Optional: for better performance # Optional: for better performance
# bitsandbytes>=0.41.0 # for 4-bit/8-bit quantization # bitsandbytes>=0.41.0 # for 4-bit/8-bit quantization
# sentencepiece>=0.1.99 # for some tokenizers # sentencepiece>=0.1.99 # for some tokenizers
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment