LoRA training: async kickoff, restart recovery, keyframe regen UI

Server (codai/api/loras.py):
- /v1/loras/train gains wait (default True) + session; wait=false detaches
  the job and returns a job_id, avoiding HTTP read-timeouts on multi-hour
  video trainings.
- Disk-persisted job registry keyed by job_id (carries session). Progress
  endpoint serves ?job=<id> / ?session=<tok> so a client only ever sees its
  own job — no cross-user spillover. Jobs left mid-flight at startup are
  marked interrupted.
- Mid-training PEFT checkpoints (SD1.5/SDXL/Wan) + train_state.json; a
  resubmit resumes from the last step when base/target/rank (and session)
  match, so a reboot no longer throws away hours of Wan training.

Township (tools/gen_township_fighters.py):
- Async training: per-run session token + persisted per-LoRA job_id; polls
  by job_id, re-attaches to a running server job after a restart, resubmits
  an interrupted one (server resumes from checkpoint).
- Dedicated train timeouts (24h video / 4h image).
- Match page: regenerate/clear keyframes (match-level + per-clip/outcome)
  via new keyframes/keyframe render + delete scopes.

tools/videogen.py: mirror the session-token + job-id recovery helpers.
Co-Authored-By: 's avatarClaude Opus 4.8 <noreply@anthropic.com>
parent f21c6185
......@@ -64,6 +64,94 @@ _progress = {
}
_train_lock = threading.Lock()
# ── Multi-client job registry ─────────────────────────────────────────────────
# Training is serialized server-wide (one GPU job at a time), but several clients
# may submit concurrently. Each submission gets a unique job_id and is recorded
# here so a client only ever polls *its own* job (no cross-user spillover). The
# registry is mirrored to disk so a recovering client (after a server OR client
# restart) can re-attach by job_id / session token. `_active_job_id` names the
# job currently executing, so _set_progress can mirror live progress into it.
_jobs_lock = threading.Lock()
_jobs: dict = {} # job_id -> record
_active_job_id: Optional[str] = None
_bg_tasks: set = set() # strong refs to detached train tasks
_JOB_ACTIVE_STATES = ("queued", "preparing", "training", "saving")
_MIRROR_FIELDS = ("active", "name", "step", "total", "status", "message",
"started_at", "path")
_last_job_persist = 0.0
def _jobs_file() -> str:
return os.path.join(_loras_dir(), "_train_jobs.json")
def _persist_jobs_locked(force: bool = False) -> None:
"""Write the registry to disk. Throttled to ~5s unless forced (status change)
so per-step progress updates don't hammer the disk."""
global _last_job_persist
now = time.time()
if not force and (now - _last_job_persist) < 5.0:
return
_last_job_persist = now
try:
p = _jobs_file()
tmp = p + ".tmp"
with open(tmp, "w") as f:
json.dump(_jobs, f)
os.replace(tmp, p)
except Exception:
pass
def _new_job(job_id: str, **fields) -> None:
with _jobs_lock:
rec = {"job_id": job_id, "updated_at": time.time()}
rec.update(fields)
_jobs[job_id] = rec
_persist_jobs_locked(force=True)
def _update_job(job_id: Optional[str], force: bool = False, **fields) -> None:
if not job_id:
return
with _jobs_lock:
rec = _jobs.get(job_id)
if rec is None:
rec = {"job_id": job_id}
_jobs[job_id] = rec
# A status change is significant — persist immediately.
if "status" in fields and fields["status"] != rec.get("status"):
force = True
rec.update(fields)
rec["updated_at"] = time.time()
_persist_jobs_locked(force=force)
def _load_jobs_on_start() -> None:
"""Load the persisted registry and mark any job that was mid-flight when the
process died as 'interrupted' (its in-memory GPU training is gone). Clients
polling such a job learn it must be resubmitted (server resumes from the last
on-disk checkpoint)."""
global _jobs
try:
with open(_jobs_file()) as f:
data = json.load(f)
except Exception:
return
if not isinstance(data, dict):
return
changed = False
for rec in data.values():
if rec.get("status") in _JOB_ACTIVE_STATES:
rec["status"] = "interrupted"
rec["active"] = False
rec["message"] = "interrupted by server restart — resubmit to resume"
changed = True
with _jobs_lock:
_jobs = data
if changed:
_persist_jobs_locked(force=True)
def set_global_args(args):
global _LORAS_DIR
......@@ -76,6 +164,7 @@ def set_global_args(args):
root = None
_LORAS_DIR = os.path.join(root, 'loras') if root else str(default_loras_dir())
os.makedirs(_LORAS_DIR, exist_ok=True)
_load_jobs_on_start()
def _loras_dir() -> str:
......@@ -153,6 +242,15 @@ class LoraTrainRequest(BaseModel):
learning_rate: Optional[float] = 1e-4
resolution: Optional[int] = 512
seed: Optional[int] = 42
# Run the job asynchronously: the POST returns immediately with a job_id and
# the client polls /v1/loras/progress?job=<id>. Avoids long read-timeouts on
# multi-hour video trainings. Default True keeps blocking behaviour for old
# clients that expect the result inline.
wait: Optional[bool] = True
# Caller-supplied session token. Recorded with the job so the owning client —
# and only it — can recover its job(s) after a restart. Auto-generated if
# omitted.
session: Optional[str] = None
model_config = ConfigDict(extra="allow")
......@@ -246,6 +344,13 @@ def _gather_images(req: LoraTrainRequest):
def _set_progress(**kw):
with _progress_lock:
_progress.update(kw)
job_id = _active_job_id
# Mirror live progress into the executing job's record so its owner (and only
# its owner) can poll it by job_id.
if job_id:
mirror = {k: kw[k] for k in kw if k in _MIRROR_FIELDS}
if mirror:
_update_job(job_id, **mirror)
def _lora_debug_enabled() -> bool:
......@@ -295,6 +400,83 @@ def _free_train_vram() -> None:
pass
# ── Mid-training checkpoints (resume across restarts) ─────────────────────────
# Every _CKPT_EVERY steps we snapshot the raw PEFT adapter state(s) + a
# train_state.json into the LoRA's own folder (name-scoped, so two clients
# training different LoRAs never collide). If the process dies, the next
# submission for the same name/base/target/rank reloads the snapshot and
# continues from the saved step instead of restarting from scratch. Optimizer
# momentum is not restored (re-warms in a few steps); the trained weights — the
# expensive part — are preserved. The snapshot records the owning session so a
# resume can verify ownership.
_CKPT_EVERY = 100
def _train_state_path(name: str) -> str:
return os.path.join(_lora_dir(name), "train_state.json")
def _save_train_checkpoint(name: str, state: dict, peft_states: dict) -> None:
from safetensors.torch import save_file
d = _lora_dir(name)
os.makedirs(d, exist_ok=True)
try:
for key, sd in peft_states.items():
cpu_sd = {k: v.detach().to("cpu") for k, v in sd.items()}
save_file(cpu_sd, os.path.join(d, f"_ckpt_{key}.safetensors"))
tmp = _train_state_path(name) + ".tmp"
with open(tmp, "w") as f:
json.dump(state, f)
os.replace(tmp, _train_state_path(name))
except Exception as e:
_dbg_lora(f"checkpoint save failed for '{name}': {e}")
def _load_train_state(name: str, *, base_path: str, target: str,
rank: int, session: Optional[str] = None) -> Optional[dict]:
"""Return a valid resumable checkpoint state for `name`, or None. Validates
that base/target/rank match (and session, when provided) so we never resume
one config's weights into another's run."""
p = _train_state_path(name)
if not os.path.isfile(p):
return None
try:
with open(p) as f:
st = json.load(f)
except Exception:
return None
if (st.get("base_path") != base_path or st.get("target") != target
or int(st.get("rank", -1)) != int(rank)):
return None
if session is not None and st.get("session") not in (None, session):
return None
step = int(st.get("step", 0))
if step <= 0 or step >= int(st.get("total", 0)):
return None
return st
def _apply_peft_checkpoint(name: str, key: str, model) -> None:
from safetensors.torch import load_file
from peft import set_peft_model_state_dict
f = os.path.join(_lora_dir(name), f"_ckpt_{key}.safetensors")
if not os.path.isfile(f):
raise FileNotFoundError(f)
set_peft_model_state_dict(model, load_file(f), adapter_name="default")
def _clear_train_checkpoint(name: str) -> None:
d = _lora_dir(name)
if not os.path.isdir(d):
return
for fn in os.listdir(d):
if fn.startswith("_ckpt_") or fn == "train_state.json":
try:
os.remove(os.path.join(d, fn))
except Exception:
pass
# ── Training base-model cache ────────────────────────────────────────────────
# The SD1.x/SDXL base used for LoRA training is expensive to load from disk.
# We keep its components cached *on CPU* between consecutive training jobs so a
......@@ -633,6 +815,21 @@ def _train_sd15(req, base_path, images, instance_prompt,
lora_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=lr)
# Resume from a mid-training checkpoint if one survives a prior restart.
start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="image", rank=rank,
session=getattr(req, "session", None))
if _ck:
try:
_apply_peft_checkpoint(name, "default", unet)
start_step = int(_ck["step"])
_set_progress(step=start_step,
message=f"resuming from step {start_step}/{steps}")
_dbg_lora(f"resumed SD1.5 '{name}' from checkpoint step {start_step}")
except Exception as e:
print(f" [lora] could not resume '{name}': {e}; starting fresh")
start_step = 0
# Pre-encode latents and the (single) instance-prompt embedding.
latents_list = _make_dataset(images, [tokenizer], [text_encoder], instance_prompt,
resolution, vae, device, torch.float32, is_sdxl=False)
......@@ -654,7 +851,7 @@ def _train_sd15(req, base_path, images, instance_prompt,
_set_progress(status="training", message="training (SD1.5)")
unet.train()
n = len(latents_list)
for step in range(steps):
for step in range(start_step, steps):
latents = latents_list[step % n].to(device)
noise = torch.randn_like(latents)
bsz = latents.shape[0]
......@@ -676,6 +873,12 @@ def _train_sd15(req, base_path, images, instance_prompt,
if step % 10 == 0 or step == steps - 1:
_set_progress(step=step + 1, message=f"step {step+1}/{steps} loss={loss.item():.4f}")
_dbg_lora(f"SD1.5 step {step+1}/{steps} loss={loss.item():.4f}")
if (step + 1) % _CKPT_EVERY == 0 and step + 1 < steps:
_save_train_checkpoint(name,
{"name": name, "base_path": base_path, "target": "image",
"rank": rank, "step": step + 1, "total": steps, "seed": seed,
"session": getattr(req, "session", None)},
{"default": get_peft_model_state_dict(unet)})
# Mid-training thermal checkpoint (pauses if CPU/GPU too hot).
try:
from codai.models.thermal import checkpoint as _thermal_checkpoint
......@@ -691,6 +894,7 @@ def _train_sd15(req, base_path, images, instance_prompt,
unet_lora_layers=unet_lora,
safe_serialization=True)
_write_meta(name, req, base_path, len(images), "sd15", instance_prompt)
_clear_train_checkpoint(name)
# Job done: drop this job's adapter + transients and move the UNet back to
# CPU. The base stays cached on CPU (reused by the next job); no VRAM is held
......@@ -754,6 +958,21 @@ def _train_sdxl(req, base_path, images, instance_prompt,
lora_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=lr)
# Resume from a mid-training checkpoint if one survives a prior restart.
start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="image", rank=rank,
session=getattr(req, "session", None))
if _ck:
try:
_apply_peft_checkpoint(name, "default", unet)
start_step = int(_ck["step"])
_set_progress(step=start_step,
message=f"resuming from step {start_step}/{steps}")
_dbg_lora(f"resumed SDXL '{name}' from checkpoint step {start_step}")
except Exception as e:
print(f" [lora] could not resume '{name}': {e}; starting fresh")
start_step = 0
tfm = transforms.Compose([
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(resolution),
......@@ -802,7 +1021,7 @@ def _train_sdxl(req, base_path, images, instance_prompt,
_set_progress(status="training", message="training (SDXL)")
unet.train()
n = len(latents_list)
for step in range(steps):
for step in range(start_step, steps):
latents = latents_list[step % n].to(device)
noise = torch.randn_like(latents)
bsz = latents.shape[0]
......@@ -825,6 +1044,12 @@ def _train_sdxl(req, base_path, images, instance_prompt,
if step % 10 == 0 or step == steps - 1:
_set_progress(step=step + 1, message=f"step {step+1}/{steps} loss={loss.item():.4f}")
_dbg_lora(f"SDXL step {step+1}/{steps} loss={loss.item():.4f}")
if (step + 1) % _CKPT_EVERY == 0 and step + 1 < steps:
_save_train_checkpoint(name,
{"name": name, "base_path": base_path, "target": "image",
"rank": rank, "step": step + 1, "total": steps, "seed": seed,
"session": getattr(req, "session", None)},
{"default": get_peft_model_state_dict(unet)})
try:
from codai.models.thermal import checkpoint as _thermal_checkpoint
_thermal_checkpoint(context="lora-train", throttle_seconds=2.0)
......@@ -839,6 +1064,7 @@ def _train_sdxl(req, base_path, images, instance_prompt,
unet_lora_layers=unet_lora,
safe_serialization=True)
_write_meta(name, req, base_path, len(images), "sdxl", instance_prompt)
_clear_train_checkpoint(name)
# Job done: drop this job's adapter + transients and move the UNet back to
# CPU. The base stays cached on CPU for the next job; no VRAM held afterwards
......@@ -1009,6 +1235,25 @@ def _train_wan(req, base_path, images, instance_prompt,
detail="Wan LoRA: no trainable adapter params were created")
optimizer = torch.optim.AdamW(lora_params, lr=lr)
# Resume from a mid-training checkpoint if one survives a prior restart. This
# is the big win for Wan: the A14B base reload alone is ~27 min, and training
# runs for hours — a reboot otherwise throws all of it away.
start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="video", rank=rank,
session=getattr(req, "session", None))
if _ck:
try:
_apply_peft_checkpoint(name, "t", experts[0][1])
if len(experts) > 1:
_apply_peft_checkpoint(name, "t2", experts[1][1])
start_step = int(_ck["step"])
_set_progress(step=start_step,
message=f"resuming from step {start_step}/{steps}")
_dbg_lora(f"resumed Wan '{name}' from checkpoint step {start_step}")
except Exception as e:
print(f" [lora] could not resume Wan '{name}': {e}; starting fresh")
start_step = 0
try:
sched = FlowMatchEulerDiscreteScheduler.from_pretrained(base_path, subfolder="scheduler")
shift = float(getattr(sched.config, "shift", 3.0) or 3.0)
......@@ -1024,7 +1269,7 @@ def _train_wan(req, base_path, images, instance_prompt,
_set_progress(status="training", message="training (Wan video LoRA)")
n = len(latents_list)
for step in range(steps):
for step in range(start_step, steps):
x0 = latents_list[step % n].to(device, dtype=compute_dtype)
noise = torch.randn_like(x0)
# Rectified-flow timestep with Wan resolution shift applied to sigma.
......@@ -1054,6 +1299,15 @@ def _train_wan(req, base_path, images, instance_prompt,
message=f"step {step+1}/{steps} loss={loss.item():.4f}")
if step % 10 == 0 or step == steps - 1:
_dbg_lora(f"Wan step {step+1}/{steps} loss={loss.item():.4f}")
if (step + 1) % _CKPT_EVERY == 0 and step + 1 < steps:
_ckpt_states = {"t": get_peft_model_state_dict(experts[0][1])}
if len(experts) > 1:
_ckpt_states["t2"] = get_peft_model_state_dict(experts[1][1])
_save_train_checkpoint(name,
{"name": name, "base_path": base_path, "target": "video",
"rank": rank, "step": step + 1, "total": steps, "seed": seed,
"session": getattr(req, "session", None)},
_ckpt_states)
try:
from codai.models.thermal import checkpoint as _thermal_checkpoint
_thermal_checkpoint(context="lora-train", throttle_seconds=2.0)
......@@ -1081,6 +1335,7 @@ def _train_wan(req, base_path, images, instance_prompt,
WanPipeline.save_lora_weights(save_directory=save_dir,
safe_serialization=True, **save_kwargs)
_write_meta(name, req, base_path, len(images), "wan", instance_prompt)
_clear_train_checkpoint(name)
# ── 5. Tear down THIS job's adapter, but KEEP the transformer(s) cached so a
# next training against the same base skips the (very slow) reload. Remove the
......@@ -1142,17 +1397,25 @@ def _write_meta(name, req, base_path, n_images, arch, instance_prompt):
_TRAIN_MODEL_KEY = "lora-train"
def _train_lora_blocking(req: LoraTrainRequest) -> dict:
def _train_lora_blocking(req: LoraTrainRequest, job_id: Optional[str] = None) -> dict:
"""Run one training job to completion (called inside a worker thread).
Holds _train_lock for the job's duration so a second training never overlaps
(also the signal _release_base_cache uses to know a job is in flight). The
central queue already serializes us, so this acquire returns immediately.
`_active_job_id` is set so live progress mirrors into this job's record (and
only this job's) for its owner to poll.
"""
global _active_job_id
_train_lock.acquire()
_active_job_id = job_id
try:
return _train_lora_sync(req)
except Exception:
result = _train_lora_sync(req)
if job_id:
_update_job(job_id, status="done", active=False, force=True,
message="done", path=result.get("path"))
return result
except Exception as e:
import traceback
traceback.print_exc()
# On error the base may be in a half-moved / inconsistent state — drop the
......@@ -1161,18 +1424,40 @@ def _train_lora_blocking(req: LoraTrainRequest) -> dict:
_set_progress(active=False, status="error", message="training failed")
except Exception:
pass
if job_id:
_update_job(job_id, status="error", active=False, force=True,
message=f"training failed: {e}"[:300])
_drop_base_cache()
_drop_wan_cache()
raise
finally:
_active_job_id = None
_train_lock.release()
async def _run_train_job(req: LoraTrainRequest, job_id: str) -> dict:
"""Acquire a scheduler slot, then run the blocking job in a worker thread.
Used for both the inline (wait=True) and detached (wait=False) paths so the
queue serialization and job bookkeeping are identical."""
import asyncio
request_id = f"lora-train-{job_id}"
lease = await queue_manager.acquire(request_id, _TRAIN_MODEL_KEY)
try:
return await asyncio.to_thread(_train_lora_blocking, req, job_id)
finally:
await queue_manager.release(lease)
@router.post("/v1/loras/train")
async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)):
"""Train a LoRA (blocking). Admitted through the central request scheduler,
so concurrent training requests queue and run one after another (instead of
being rejected) alongside all other model requests."""
"""Train a LoRA. Admitted through the central request scheduler so concurrent
trainings queue and run one-at-a-time alongside all other model requests.
With `wait=false` the call returns immediately with a `job_id`; the client
polls /v1/loras/progress?job=<job_id>. The job keeps running on the server
regardless of the client — so a client that dies/restarts can re-attach by
job_id (or list its session's jobs with ?session=<token>) rather than
restarting the training."""
import asyncio
import uuid
if not req.name or '/' in req.name or '..' in req.name:
......@@ -1180,23 +1465,72 @@ async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)):
if not req.base_model:
raise HTTPException(status_code=400, detail="base_model is required")
request_id = f"lora-train-{uuid.uuid4().hex[:8]}"
# Wait for a scheduler slot (queues behind other in-flight work; the constant
# model key keeps trainings strictly one-at-a-time).
lease = await queue_manager.acquire(request_id, _TRAIN_MODEL_KEY)
session = req.session or f"sess-{uuid.uuid4().hex[:12]}"
req.session = session
job_id = f"job-{uuid.uuid4().hex[:12]}"
_new_job(job_id, session=session, name=req.name, base_model=req.base_model,
target=(req.target or "image"), status="queued", active=True,
step=0, total=req.steps or 0, message="queued",
started_at=time.time(), path=None)
if not req.wait:
# Detach: run independently of this request's lifetime. Keep a strong
# reference so the loop doesn't garbage-collect (cancel) the task.
task = asyncio.create_task(_detached_train(req, job_id))
_bg_tasks.add(task)
task.add_done_callback(_bg_tasks.discard)
return {"ok": True, "async": True, "job_id": job_id, "session": session,
"status": "queued"}
try:
result = await asyncio.to_thread(_train_lora_blocking, req)
result = await _run_train_job(req, job_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"LoRA training failed: {e}")
finally:
await queue_manager.release(lease)
return {"ok": True, **result}
return {"ok": True, "job_id": job_id, "session": session, **result}
async def _detached_train(req: LoraTrainRequest, job_id: str) -> None:
"""Background runner for wait=false jobs. Errors are recorded on the job
record (the client polls for them); nothing is raised to a waiting request."""
try:
await _run_train_job(req, job_id)
except Exception as e:
_update_job(job_id, status="error", active=False, force=True,
message=f"training failed: {e}"[:300])
@router.get("/v1/loras/progress")
async def lora_progress():
async def lora_progress(job: Optional[str] = None, session: Optional[str] = None):
"""Training progress.
- `?job=<job_id>` → that job's record (a client polls only its own job).
- `?session=<tok>` → the most recent job for that session (recovery after a
client restart that lost the job_id).
- neither → the global active-job snapshot (back-compat).
"""
if job:
with _jobs_lock:
rec = _jobs.get(job)
rec = dict(rec) if rec else None
if rec is None:
return {"active": False, "status": "unknown", "job_id": job,
"name": None, "step": 0, "total": 0,
"message": "no such job"}
return rec
if session:
with _jobs_lock:
mine = [dict(r) for r in _jobs.values() if r.get("session") == session]
if not mine:
return {"active": False, "status": "unknown", "session": session,
"name": None, "step": 0, "total": 0,
"message": "no jobs for session"}
# Prefer an active job; otherwise the most recently updated.
active = [r for r in mine if r.get("status") in _JOB_ACTIVE_STATES]
pick = (active or mine)
pick.sort(key=lambda r: r.get("updated_at", 0))
return pick[-1]
with _progress_lock:
return dict(_progress)
......
......@@ -16,6 +16,7 @@ import sys
import tempfile
import threading
import time
import urllib.parse
from pathlib import Path
from typing import Optional
......@@ -560,8 +561,9 @@ class CoderAIClient:
if api_key:
self.session.headers["Authorization"] = f"Bearer {api_key}"
def _post(self, path: str, body: dict) -> dict:
r = self.session.post(f"{self.base}{path}", json=body, timeout=self.timeout)
def _post(self, path: str, body: dict, timeout: int = None) -> dict:
r = self.session.post(f"{self.base}{path}", json=body,
timeout=timeout if timeout is not None else self.timeout)
if not r.ok:
raise RuntimeError(f"POST {path} → {r.status_code}: {r.text[:400]}")
return r.json()
......@@ -696,7 +698,8 @@ class CoderAIClient:
environment: str = None, images: list = None,
steps: int = 800, rank: int = 16,
resolution: int = 512, train_base_model: str = None,
target: str = "image", quantize_4bit: bool = True) -> dict:
target: str = "image", quantize_4bit: bool = True,
wait: bool = True, session: str = None) -> dict:
"""Train a per-character or per-environment LoRA on the server.
Blocks until complete.
......@@ -705,7 +708,9 @@ class CoderAIClient:
so it loads directly on the video pipeline."""
body = {"name": name, "base_model": base_model,
"steps": int(steps), "rank": int(rank), "resolution": int(resolution),
"target": target}
"target": target, "wait": bool(wait)}
if session:
body["session"] = session
if target == "video":
body["quantize_4bit"] = bool(quantize_4bit)
if train_base_model:
......@@ -716,7 +721,11 @@ class CoderAIClient:
body["environment"] = environment
if images:
body["images"] = images
return self._post("/v1/loras/train", body)
# Video DiT training (e.g. Wan A14B) can take many hours including the
# one-off model load; image LoRA is quicker. Allow a long ceiling so the
# blocking POST doesn't read-timeout while the server is still training.
train_timeout = 24 * 3600 if target == "video" else 4 * 3600
return self._post("/v1/loras/train", body, timeout=train_timeout)
def list_loras(self) -> list:
try:
......@@ -724,9 +733,14 @@ class CoderAIClient:
except Exception:
return []
def lora_progress(self) -> dict:
def lora_progress(self, job: str = None, session: str = None) -> dict:
try:
return self._get("/v1/loras/progress")
q = ""
if job:
q = f"?job={urllib.parse.quote(job)}"
elif session:
q = f"?session={urllib.parse.quote(session)}"
return self._get(f"/v1/loras/progress{q}")
except Exception:
return {}
......@@ -1381,6 +1395,51 @@ def _load_json_map(path: Path) -> dict:
return {}
# ── Training session token + per-LoRA server job tracking ─────────────────────
# A stable per-output session token lets the server tag this client's training
# jobs so we (and only we) can recover them after a restart — no spillover into
# another user's concurrent run. Per-LoRA server job_ids are persisted so that
# after a township restart we re-attach to a still-running server job instead of
# launching a duplicate (the server keeps training regardless of this client).
def _session_token(out_dir: Path) -> str:
p = Path(out_dir) / ".train_session"
try:
if p.exists():
tok = p.read_text().strip()
if tok:
return tok
except Exception:
pass
import uuid as _uuid
tok = f"township-{_uuid.uuid4().hex[:12]}"
try:
p.write_text(tok)
except Exception:
pass
return tok
def _train_jobs_file(out_dir: Path) -> Path:
return Path(out_dir) / "lora_train_jobs.json"
def _get_lora_job_id(out_dir: Path, lora_name: str):
return _load_json_map(_train_jobs_file(out_dir)).get(lora_name)
def _set_lora_job_id(out_dir: Path, lora_name: str, job_id):
f = _train_jobs_file(out_dir)
m = _load_json_map(f)
if job_id:
m[lora_name] = job_id
else:
m.pop(lora_name, None)
try:
f.write_text(json.dumps(m, indent=2))
except Exception:
pass
def _lora_specs_for(fighters: list, lora_map: dict, weight: float) -> list:
"""Build the `loras` request list for the fighters appearing in a clip."""
specs = []
......@@ -2525,67 +2584,114 @@ def launch_web_ui(default_args):
f"'{lora_name}' ({steps} steps, rank {rank})"
+ (f" against {model}" if is_video else "") + "…")
# Run the blocking train call in an inner thread; poll progress here.
result, err = {}, {}
def _do():
try:
kwargs = dict(name=lora_name, base_model=model,
steps=int(steps), rank=int(rank))
if is_video:
kwargs["target"] = "video"
else:
_tbm = getattr(default_args, "lora_train_base_model", None) or None
if _tbm:
kwargs["train_base_model"] = _tbm
kwargs[kind] = name
result["res"] = client.train_lora(**kwargs)
except Exception as e:
err["e"] = str(e)
t = threading.Thread(target=_do, daemon=True)
t.start()
# Kick the training off ASYNCHRONOUSLY (wait=False): the server runs
# it independently and we poll its job record. This (a) avoids HTTP
# read-timeouts on multi-hour video trainings, and (b) survives a
# township restart — the server keeps training and we re-attach by
# job_id rather than launching a duplicate. The session token tags the
# job as ours so recovery never picks up another client's job.
session = _session_token(out_dir)
_ACTIVE = ("queued", "preparing", "training", "saving")
def _kickoff():
kwargs = dict(name=lora_name, base_model=model,
steps=int(steps), rank=int(rank),
wait=False, session=session)
if is_video:
kwargs["target"] = "video"
else:
_tbm = getattr(default_args, "lora_train_base_model", None) or None
if _tbm:
kwargs["train_base_model"] = _tbm
kwargs[kind] = name
resp = client.train_lora(**kwargs)
jid = resp.get("job_id")
_set_lora_job_id(out_dir, lora_name, jid)
return jid
# Re-attach to an existing server job if we have one recorded and the
# server still knows it (running OR finished while we were away).
job = _get_lora_job_id(out_dir, lora_name)
if job:
pj = client.lora_progress(job=job)
jstatus = (pj.get("status") or "").strip()
if jstatus == "done" and pj.get("path"):
_web_log(f" ↻ Re-attached: '{lora_name}' already trained.")
elif jstatus in _ACTIVE:
_web_log(f" ↻ Re-attached to running job for '{lora_name}'.")
else:
# interrupted / error / unknown → resubmit (server resumes
# from its last on-disk checkpoint if one exists).
_web_log(f" ↻ Previous job for '{lora_name}' was "
f"'{jstatus or 'lost'}' — resubmitting (resumes from "
f"checkpoint)…")
job = _kickoff()
else:
job = _kickoff()
_start_ts = time.time()
_prog(6, "preparing…")
while t.is_alive():
path = None
err_msg = None
_resubmits = 0
while True:
time.sleep(1.5)
_elapsed = int(time.time() - _start_ts)
_mm, _ss = divmod(_elapsed, 60)
_et = f"{_mm}m{_ss:02d}s" if _mm else f"{_ss}s"
try:
p = client.lora_progress()
p = client.lora_progress(job=job)
except Exception:
# Server busy (large video models can load for many minutes,
# starving the progress endpoint) — keep the UI visibly alive.
_prog(6, f"preparing — loading model… ({_et})")
continue
# The server has ONE global training progress (jobs run one at a
# time via the queue). Only mirror it onto THIS card when it's
# reporting OUR LoRA; otherwise another job is training and we're
# still queued — show that instead of its progress.
pname = p.get("name") or ""
if pname and pname != lora_name:
_prog(4, f"queued — '{pname}' training first… ({_et})")
continue
status = (p.get("status") or "").strip()
if status == "done":
path = p.get("path")
break
if status == "error":
err_msg = p.get("message") or "training failed"
break
if status in ("interrupted", "unknown"):
# Server restarted (or forgot the job) mid-train. Resubmit
# once to resume from checkpoint; give up if it keeps dying.
if _resubmits < 2:
_resubmits += 1
_web_log(f" ↻ Server job '{status}' — resubmitting "
f"'{lora_name}' (resume #{_resubmits})…")
try:
job = _kickoff()
except Exception as e:
err_msg = f"resubmit failed: {e}"
break
_prog(6, f"resuming after interruption… ({_et})")
continue
err_msg = f"training {status} and could not be resumed"
break
total = p.get("total") or steps
step = p.get("step") or 0
if status in ("preparing", "saving") or not step:
# No training steps yet (model loading / encoding / saving).
if status == "queued":
# Our job is admitted but another training holds the GPU.
_prog(4, f"queued — waiting for GPU… ({_et})")
elif status in ("preparing", "saving") or not step:
_prog(6, (p.get("message") or status or "preparing")
+ f" ({_et})")
else:
pct = 6 + int(90 * step / max(1, total))
_prog(pct, p.get("message") or status or "training")
t.join()
if err:
_web_log(f" ✗ LoRA training failed for {name}: {err['e']}")
_fail(err["e"])
if err_msg:
_web_log(f" ✗ LoRA training failed for {name}: {err_msg}")
_set_lora_job_id(out_dir, lora_name, None)
_fail(err_msg)
return
res = result.get("res") or {}
path = res.get("path")
if not path:
_fail(f"training returned no path: {res}")
_set_lora_job_id(out_dir, lora_name, None)
_fail("training returned no path")
return
# Done — clear the recorded job so a future retrain starts clean.
_set_lora_job_id(out_dir, lora_name, None)
# Record the trained LoRA in the on-disk map so video/keyframe runs
# reuse it. Image LoRAs → loras.json/env_loras.json (flat name→path).
......@@ -2721,6 +2827,78 @@ def launch_web_ui(default_args):
lw = float(getattr(default_args, "lora_weight", 0.85))
elw = float(getattr(default_args, "env_lora_weight", 0.8))
# ── Regenerate keyframes (image model) ─────────────────────────────
# Deletes the targeted keyframe PNG(s) then regenerates them so a
# subsequent clip re-render uses fresh keyframes (e.g. after a LoRA
# retrain or profile edit). Does NOT re-render the video itself —
# click "Re-render" afterwards to rebuild the clip from the new
# keyframe. scope "keyframes" = whole match; "keyframe" = one clip
# (idx) or one outcome (fighter+outcome).
if scope in ("keyframes", "keyframe"):
image_model = getattr(default_args, "image_model", None)
if not image_model:
try:
image_model = pick_model(client, "image", None)
except Exception as e:
_fail(f"no image model available: {e}")
return
kdir = vdir / "keyframes"
m = next((x for x in fight_plan
if x.get("match_name") == match_name), None)
# Build the (filtered) fight/outcome plans + the list of stems to
# drop so _generate_keyframes (which keeps existing PNGs) remakes
# exactly those.
fp, op, stems = [], [], []
if scope == "keyframe" and params.get("fighter") and params.get("outcome"):
fr, oc = params.get("fighter"), params.get("outcome")
o = next((x for x in outcome_plan
if x.get("fighter") == fr and x.get("outcome") == oc
and (x.get("match_name") in (None, match_name))), None)
if not o:
_fail("outcome not found in prompts.json")
return
op = [o]
stems = [_clip_stem_outcome(fr, oc, o.get("match_name"))]
else:
if not m:
_fail("match not found in prompts.json — render it first")
return
mm = dict(m)
if scope == "keyframe":
idx = int(params.get("idx"))
mm["clips"] = [c for c in m["clips"] if int(c["idx"]) == idx]
if not mm["clips"]:
_fail("clip not found in prompts.json")
return
fp = [mm]
stems = [_clip_stem_fight(match_name, c["idx"]) for c in mm["clips"]]
_set_items([f"keyframe {s}" for s in stems])
for i, s in enumerate(stems):
_item(i, "start")
try:
(kdir / f"{s}.png").unlink()
except Exception:
pass
_prog(10, f"regenerating {len(stems)} keyframe(s)…")
try:
_generate_keyframes(
client, image_model, kdir, fp, op,
consistency | {"keyframe"}, lora_map,
float(getattr(default_args, "character_strength", 0.7)),
int(getattr(default_args, "keyframe_steps", 28)),
getattr(default_args, "keyframe_size", "512x512"), lw,
env_lora_map=env_lora_map, env_lora_weight=elw)
except Exception as e:
_fail(f"keyframe regeneration failed: {e}")
return
# Mark each item done/failed by whether its PNG now exists.
for i, s in enumerate(stems):
_item(i, "end", (kdir / f"{s}.png").exists())
made = sum(1 for s in stems if (kdir / f"{s}.png").exists())
_done(f"regenerated {made}/{len(stems)} keyframe(s) — "
f"now click Re-render to rebuild the video(s)")
return
# ── Enhance: upscale / raise-FPS existing finals + outcome videos ──
if scope == "enhance":
try:
......@@ -3958,10 +4136,14 @@ async function reMatch(ev, scope, params){
'clip':'Re-render this single clip?',
'reassemble':'Reassemble the final short/long videos from the existing clips? (fast, no model)',
'outcomes':'Re-render all output clips for this fighter (uses the video model)?',
'outcome':'Re-render this output clip?'};
'outcome':'Re-render this output clip?',
'keyframes':'Regenerate ALL keyframe images for this match (uses the image model)? Existing keyframes are replaced; the clip videos are NOT re-rendered — click Re-render afterwards.',
'keyframe':'Regenerate this keyframe image (uses the image model)? The clip video is NOT re-rendered — click Re-render afterwards.'};
const kf=(scope==='keyframes'||scope==='keyframe');
if(!(await uiConfirm(labels[scope]||'Proceed?',
{title:'Regenerate', okText:(scope==='reassemble'?'Reassemble':'Re-render'),
danger:(scope!=='reassemble')})))return;
{title:(kf?'Regenerate keyframes':'Regenerate'),
okText:(scope==='reassemble'?'Reassemble':(kf?'Regenerate':'Re-render')),
danger:(scope!=='reassemble'&&!kf)})))return;
const stEl=_findStatus(ev);
const setSt=(c,t)=>{ if(stEl){ stEl.style.color=c; stEl.textContent=t; } };
const fd=new FormData(); fd.append('scope',scope);
......@@ -4003,7 +4185,9 @@ async function delVid(ev, scope, params){
'final':'Delete this assembled video file?',
'match':'Delete ALL video files for this match (clips + finals)? The plan/prompts are kept so you can re-render.',
'output':'Delete this output video file?',
'outputs':'Delete ALL output video files for this fighter?'};
'outputs':'Delete ALL output video files for this fighter?',
'keyframes':'Clear ALL keyframe images for this match? The next re-render will run keyframe-free until you regenerate them.',
'keyframe':'Clear this keyframe image? The next re-render of it will run keyframe-free until you regenerate it.'};
if(!(await uiConfirm(labels[scope]||'Delete?',{title:'Remove videos', okText:'Delete', danger:true})))return;
const stEl=_findStatus(ev);
const setSt=(c,t)=>{ if(stEl){ stEl.style.color=c; stEl.textContent=t; } };
......@@ -4226,7 +4410,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs);
f'<div class=card style="width:230px">'
f' <div class=hint style="display:flex;justify-content:space-between;align-items:center">'
f'<span>clip {idx:02d}</span>'
f'<span><a href="#" style="color:#7eb8f7" '
f'<span>'
f'<a href="#" style="color:#c79bf0" title="Regenerate this keyframe (image model)" '
f'onclick="reMatch(event,\'keyframe\',{{match:\'{_esc(name)}\',idx:\'{idx}\'}})">kf↻</a> '
f'<a href="#" style="color:#7eb8f7" '
f'onclick="reMatch(event,\'clip\',{{match:\'{_esc(name)}\',idx:\'{idx}\'}})">re-render</a> {rm_html}</span></div>'
f' {vid_html}'
f' <textarea data-clip="{idx}" rows=2 style="margin-top:.3rem">{_esc(c.get("prompt",""))}</textarea>'
......@@ -4260,7 +4447,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs);
f'<div class=card style="width:215px">'
f' <div class=hint style="display:flex;justify-content:space-between;align-items:center">'
f'<span>{_esc(oc)}</span>'
f'<span><a href="#" style="color:#7eb8f7" '
f'<span>'
f'<a href="#" style="color:#c79bf0" title="Regenerate this keyframe (image model)" '
f'onclick="reMatch(event,\'keyframe\',{{match:\'{_esc(name)}\',fighter:\'{_esc(fr)}\',outcome:\'{_esc(oc)}\'}})">kf↻</a> '
f'<a href="#" style="color:#7eb8f7" '
f'onclick="reMatch(event,\'outcome\',{{match:\'{_esc(name)}\',fighter:\'{_esc(fr)}\',outcome:\'{_esc(oc)}\'}})">{act}</a> {rm}</span></div>'
f' {vid}'
f' <textarea data-outc="{_esc(fr)}|{_esc(oc)}" rows=2 style="margin-top:.3rem">{_esc(o.get("prompt",""))}</textarea>'
......@@ -4303,6 +4493,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs);
f'onclick="reMatch(event,\'reassemble\',{{match:\'{_esc(name)}\'}})">🎞 Reassemble finals</button>'
f' <button class="btn btn-secondary" style="font-size:.82rem;padding:.35rem .9rem" '
f'onclick="reMatch(event,\'outcomes\',{{match:\'{_esc(name)}\'}})">♻ Re-render all outcomes</button>'
f' <button class="btn btn-secondary" style="font-size:.82rem;padding:.35rem .9rem" '
f'onclick="reMatch(event,\'keyframes\',{{match:\'{_esc(name)}\'}})">🖼 Regenerate keyframes</button>'
f' <button class="btn btn-danger" style="font-size:.82rem;padding:.35rem .9rem" '
f'onclick="delVid(event,\'keyframes\',{{match:\'{_esc(name)}\'}})">🧹 Clear keyframes</button>'
f' <button class="btn btn-danger" style="font-size:.82rem;padding:.35rem .9rem" '
f'onclick="delVid(event,\'match\',{{match:\'{_esc(name)}\'}})">🗑 Remove all videos</button>'
f' </div>'
......@@ -5110,6 +5304,35 @@ async function pollJob(){
if not _safe(fn) or not fn.endswith(".mp4"):
self._send(400, "application/json", _j.dumps({"error": "invalid file"})); return
_rm(vdir / fn)
elif scope == "keyframes":
# Clear ALL keyframe PNGs for a match (its clips + its outcomes).
mn = _fv("match")
if not _safe(mn):
self._send(400, "application/json", _j.dumps({"error": "invalid match"})); return
kdir = vdir / "keyframes"
for p in kdir.glob(f"{mn}_clip*.png"):
_rm(p)
_, _, _, _matches_map, _ = _scan_matches()
for (_f, _o, _p) in _matches_map.get(mn, {}).get("outcomes", []):
_rm(kdir / f"{Path(_p).stem}.png")
elif scope == "keyframe":
# Clear one clip's (match+idx) or one outcome's (fighter+outcome) keyframe.
mn = _fv("match")
kdir = vdir / "keyframes"
fr, oc = _fv("fighter"), _fv("outcome")
if fr and oc:
if not _safe(fr) or not _safe(oc):
self._send(400, "application/json", _j.dumps({"error": "invalid args"})); return
stem = f"{mn}_{fr}_{oc}" if (mn and _safe(mn)) else f"{fr}_{oc}"
_rm(kdir / f"{stem}.png")
else:
idx = _fv("idx")
if not _safe(mn):
self._send(400, "application/json", _j.dumps({"error": "invalid match"})); return
try:
_rm(kdir / f"{mn}_clip{int(idx):02d}.png")
except ValueError:
self._send(400, "application/json", _j.dumps({"error": "invalid idx"})); return
else:
self._send(400, "application/json", _j.dumps({"error": "invalid scope"})); return
self._send(200, "application/json", _j.dumps({"ok": True, "removed": removed}))
......
......@@ -72,6 +72,54 @@ def safe_slug(value: str) -> str:
return value.strip("_-. ") or f"item_{uuid.uuid4().hex[:8]}"
def load_json_map(path: Path) -> dict[str, Any]:
try:
if path.exists():
data = json.loads(path.read_text(encoding="utf-8"))
return data if isinstance(data, dict) else {}
except Exception:
pass
return {}
def session_token(out_dir: Path) -> str:
path = out_dir / ".train_session"
try:
if path.exists():
token = path.read_text(encoding="utf-8").strip()
if token:
return token
except Exception:
pass
token = f"videogen-{uuid.uuid4().hex[:12]}"
try:
path.write_text(token, encoding="utf-8")
except Exception:
pass
return token
def lora_jobs_file(out_dir: Path) -> Path:
return out_dir / "lora_train_jobs.json"
def get_lora_job_id(out_dir: Path, lora_name: str) -> str | None:
return load_json_map(lora_jobs_file(out_dir)).get(lora_name)
def set_lora_job_id(out_dir: Path, lora_name: str, job_id: str | None) -> None:
path = lora_jobs_file(out_dir)
data = load_json_map(path)
if job_id:
data[lora_name] = job_id
else:
data.pop(lora_name, None)
try:
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
except Exception:
pass
def data_uri_for_file(path: Path, mime: str | None = None) -> str:
if mime is None:
mime = mimetypes.guess_type(str(path))[0] or "application/octet-stream"
......@@ -150,6 +198,20 @@ def mux_background_audio(video_path: Path, audio_path: Path, out_path: Path, mus
])
def extract_video_frames(video_path: Path, out_dir: Path, max_frames: int = 6) -> list[Path]:
out_dir.mkdir(parents=True, exist_ok=True)
duration = max(0.1, video_duration(video_path))
count = max(1, int(max_frames))
frames = []
for idx in range(count):
ts = duration * (idx + 0.5) / count
out = out_dir / f"{video_path.stem}_frame_{idx:02d}.png"
run_cmd([ffmpeg(), "-y", "-ss", f"{ts:.3f}", "-i", str(video_path), "-frames:v", "1", str(out)])
if out.exists() and out.stat().st_size:
frames.append(out)
return frames
class CoderAIClient:
def __init__(self, base_url: str, api_key: str | None = None, timeout: int = 7200):
self.base = base_url.rstrip("/")
......@@ -158,8 +220,8 @@ class CoderAIClient:
if api_key:
self.session.headers["Authorization"] = f"Bearer {api_key}"
def _get(self, path: str) -> dict[str, Any]:
resp = self.session.get(f"{self.base}{path}", timeout=60)
def _get(self, path: str, timeout: int = 60) -> dict[str, Any]:
resp = self.session.get(f"{self.base}{path}", timeout=timeout)
if not resp.ok:
raise RuntimeError(f"GET {path} -> {resp.status_code}: {resp.text[:800]}")
return resp.json()
......@@ -170,6 +232,12 @@ class CoderAIClient:
raise RuntimeError(f"POST {path} -> {resp.status_code}: {resp.text[:1200]}")
return resp.json()
def _post_multipart(self, path: str, data: dict[str, Any], files: dict[str, Any]) -> dict[str, Any]:
resp = self.session.post(f"{self.base}{path}", data=data, files=files, timeout=self.timeout)
if not resp.ok:
raise RuntimeError(f"POST {path} -> {resp.status_code}: {resp.text[:1200]}")
return resp.json()
def _patch(self, path: str, body: dict[str, Any]) -> dict[str, Any]:
resp = self.session.patch(f"{self.base}{path}", json=body, timeout=self.timeout)
if not resp.ok:
......@@ -206,6 +274,51 @@ class CoderAIClient:
except Exception:
return []
def list_voices(self) -> list[dict[str, Any]]:
try:
return self._get("/v1/audio/voices").get("voices", [])
except Exception:
return []
def list_loras(self) -> list[dict[str, Any]]:
try:
return self._get("/v1/loras").get("loras", [])
except Exception:
return []
def create_voice(self, name: str, description: str, transcript: str, audio_path: Path) -> dict[str, Any]:
mime = mimetypes.guess_type(audio_path.name)[0] or "audio/wav"
with audio_path.open("rb") as handle:
return self._post_multipart(
"/v1/audio/voices",
{"name": name, "description": description, "transcript": transcript},
{"audio": (audio_path.name, handle, mime)},
)
def extract_voice(self, name: str, description: str, transcript: str, media_path: Path) -> dict[str, Any]:
mime = mimetypes.guess_type(media_path.name)[0] or "application/octet-stream"
key = "video" if mime.startswith("video/") or media_path.suffix.lower() in {".mp4", ".mov", ".webm", ".mkv"} else "audio"
return self._post("/v1/audio/voices/extract", {
"name": name,
"description": description,
"transcript": transcript,
key: data_uri_for_file(media_path, mime),
})
def train_lora(self, body: dict[str, Any]) -> dict[str, Any]:
return self._post("/v1/loras/train", body)
def lora_progress(self, job: str | None = None, session: str | None = None) -> dict[str, Any]:
query = ""
if job:
query = f"?job={urllib.parse.quote(job)}"
elif session:
query = f"?session={urllib.parse.quote(session)}"
try:
return self._get(f"/v1/loras/progress{query}", timeout=30)
except Exception:
return {}
def get_profile_images(self, kind: str, name: str) -> list[str]:
plural = "characters" if kind == "character" else "environments"
try:
......@@ -424,8 +537,168 @@ class VideoGenApp:
return {
"characters": [v for k, v in sorted(chars.items()) if k],
"environments": [v for k, v in sorted(envs.items()) if k],
"voices": self.client.list_voices(),
"loras": self.client.list_loras(),
}
def start_lora_job(self, payload: dict[str, Any]) -> str:
job_id = f"lora-{uuid.uuid4().hex[:10]}"
with self.lock:
self.jobs[job_id] = {"status": "queued", "progress": 0, "kind": "lora"}
thread = threading.Thread(target=self._lora_job, args=(job_id, payload), daemon=True)
thread.start()
return job_id
def _lora_job(self, job_id: str, payload: dict[str, Any]) -> None:
name = safe_slug(payload.get("name") or "movie_style")
base_model = payload.get("base_model") or self.args.video_model or pick_model(self.client.list_models(), "video_generation")
target = payload.get("target") or "video"
media_paths = [Path(p.strip()) for p in (payload.get("media_paths") or "").splitlines() if p.strip()]
style_prompt = payload.get("style_prompt") or f"{name} cinematic style"
instance_prompt = payload.get("instance_prompt") or f"in {name} style, {style_prompt}"
frames_per_video = int(payload.get("frames_per_video") or 6)
try:
if not base_model:
raise RuntimeError("No base model configured for LoRA training")
self._job_update(job_id, status="running", progress=5, message="collecting training media")
train_dir = self.out_dir / "lora_training" / name
frame_dir = train_dir / "frames"
train_dir.mkdir(parents=True, exist_ok=True)
images: list[str] = []
scene_manifest = []
for idx, path in enumerate(media_paths):
if not path.exists() or not path.is_file():
self.emit(f"Skipping missing LoRA media: {path}")
continue
ext = path.suffix.lower()
scene_manifest.append({"path": str(path), "description": payload.get("scene_description") or ""})
if ext in {".png", ".jpg", ".jpeg", ".webp"}:
images.append(data_uri_for_file(path, mimetypes.guess_type(path.name)[0] or "image/png"))
elif ext in {".mp4", ".mov", ".webm", ".mkv"}:
self.emit(f"Extracting frames from {path.name}")
for frame in extract_video_frames(path, frame_dir / safe_slug(path.stem), frames_per_video):
images.append(data_uri_for_file(frame, "image/png"))
if not images:
raise RuntimeError("No image frames found. Add image files or videos to train from.")
(train_dir / "manifest.json").write_text(json.dumps({"name": name, "style_prompt": style_prompt, "scenes": scene_manifest}, indent=2), encoding="utf-8")
self.emit(f"Training {target} LoRA '{name}' from {len(images)} image/frame sample(s)")
self._job_update(job_id, progress=25, message="training LoRA")
train_body = {
"name": name,
"base_model": base_model,
"train_base_model": payload.get("train_base_model") or None,
"target": target,
"images": images,
"instance_prompt": instance_prompt,
"steps": int(payload.get("steps") or 800),
"rank": int(payload.get("rank") or 16),
"learning_rate": float(payload.get("learning_rate") or 1e-4),
"resolution": int(payload.get("resolution") or 512),
"seed": int(payload.get("seed") or 42),
"num_frames": int(payload.get("num_frames") or 1),
"quantize_4bit": bool(payload.get("quantize_4bit", True)),
"wait": False,
"session": session_token(self.out_dir),
}
path = self._attach_or_start_lora(job_id, name, train_body)
self.emit(f"LoRA ready: {name}" + (f" -> {path}" if path else ""))
set_lora_job_id(self.out_dir, name, None)
self._job_update(job_id, status="done", progress=100, message="done", name=name, path=path)
except Exception as exc:
self.emit(f"LoRA training failed: {exc}")
self._job_update(job_id, status="error", error=str(exc), message=str(exc))
def _attach_or_start_lora(self, ui_job_id: str, lora_name: str, train_body: dict[str, Any]) -> str | None:
active_states = {"queued", "preparing", "training", "saving"}
def kickoff() -> str:
resp = self.client.train_lora({k: v for k, v in train_body.items() if v is not None})
server_job = resp.get("job_id")
if not server_job:
raise RuntimeError(f"LoRA training did not return a job_id: {resp}")
set_lora_job_id(self.out_dir, lora_name, server_job)
return server_job
server_job = get_lora_job_id(self.out_dir, lora_name)
if server_job:
progress = self.client.lora_progress(job=server_job)
status = (progress.get("status") or "").strip()
if status == "done" and progress.get("path"):
self.emit(f"Re-attached: LoRA '{lora_name}' was already trained")
return progress.get("path")
if status in active_states:
self.emit(f"Re-attached to running LoRA job for '{lora_name}'")
else:
self.emit(f"Previous LoRA job for '{lora_name}' was '{status or 'lost'}'; resubmitting to resume from checkpoint")
server_job = kickoff()
else:
server_job = kickoff()
started = time.time()
resubmits = 0
while True:
if self._is_cancelled(ui_job_id):
raise RuntimeError("cancelled")
time.sleep(2.0)
elapsed = int(time.time() - started)
mm, ss = divmod(elapsed, 60)
et = f"{mm}m{ss:02d}s" if mm else f"{ss}s"
progress = self.client.lora_progress(job=server_job)
status = (progress.get("status") or "").strip()
if status == "done":
return progress.get("path")
if status == "error":
raise RuntimeError(progress.get("message") or "LoRA training failed")
if status in {"interrupted", "unknown", ""}:
if resubmits < 2:
resubmits += 1
self.emit(f"LoRA job '{status or 'unknown'}' for '{lora_name}'; resubmitting resume #{resubmits}")
server_job = kickoff()
self._job_update(ui_job_id, progress=30, message=f"resuming after interruption ({et})")
continue
raise RuntimeError(f"training {status or 'unknown'} and could not be resumed")
total = progress.get("total") or train_body.get("steps") or 1
step = progress.get("step") or 0
if status == "queued":
pct = 25
msg = f"queued - waiting for GPU ({et})"
elif status in {"preparing", "saving"} or not step:
pct = 30 if status != "saving" else 94
msg = f"{progress.get('message') or status or 'preparing'} ({et})"
else:
pct = 30 + int(64 * step / max(1, total))
msg = progress.get("message") or f"training {step}/{total} ({et})"
self._job_update(ui_job_id, progress=min(98, pct), message=msg, server_job=server_job)
def start_voice_job(self, payload: dict[str, Any]) -> str:
job_id = f"voice-{uuid.uuid4().hex[:10]}"
with self.lock:
self.jobs[job_id] = {"status": "queued", "progress": 0, "kind": "voice"}
thread = threading.Thread(target=self._voice_job, args=(job_id, payload), daemon=True)
thread.start()
return job_id
def _voice_job(self, job_id: str, payload: dict[str, Any]) -> None:
name = safe_slug(payload.get("name") or "voice")
description = payload.get("description") or ""
transcript = payload.get("transcript") or ""
media_path = Path(payload.get("media_path") or "")
try:
self._job_update(job_id, status="running", progress=5, message="validating reference media")
if not media_path.exists() or not media_path.is_file():
raise RuntimeError(f"Reference audio/video file not found: {media_path}")
self.emit(f"Creating cloned voice profile '{name}' from {media_path.name}")
self._job_update(job_id, progress=25, message="uploading reference media")
if payload.get("extract", True):
self.client.extract_voice(name, description, transcript, media_path)
else:
self.client.create_voice(name, description, transcript, media_path)
self.emit(f"Voice profile ready: {name}")
self._job_update(job_id, status="done", progress=100, message="done", name=name)
except Exception as exc:
self.emit(f"Voice job failed: {exc}")
self._job_update(job_id, status="error", error=str(exc), message=str(exc))
def start_profile_job(self, payload: dict[str, Any]) -> str:
job_id = f"profile-{uuid.uuid4().hex[:10]}"
with self.lock:
......@@ -480,10 +753,32 @@ class VideoGenApp:
job_id = f"movie-{uuid.uuid4().hex[:10]}"
with self.lock:
self.jobs[job_id] = {"status": "queued", "progress": 0, "movie": payload.get("title") or "movie"}
thread = threading.Thread(target=self._movie_job, args=(job_id, payload), daemon=True)
target = self._movie_batch_job if int(payload.get("movie_count") or 1) > 1 else self._movie_job
thread = threading.Thread(target=target, args=(job_id, payload), daemon=True)
thread.start()
return job_id
def _movie_batch_job(self, job_id: str, payload: dict[str, Any]) -> None:
count = max(1, int(payload.get("movie_count") or 1))
outputs = []
base_title = payload.get("title") or "untitled_movie"
for idx in range(count):
if self._is_cancelled(job_id):
self._job_update(job_id, status="error", error="cancelled", message="cancelled")
return
variant = dict(payload)
variant["movie_count"] = 1
variant["title"] = f"{base_title}_variant_{idx + 1:02d}"
self.emit(f"Rendering movie variant {idx + 1}/{count}")
self._movie_job(job_id, variant)
with self.lock:
job = dict(self.jobs.get(job_id, {}))
if job.get("status") == "error":
return
if job.get("output_url"):
outputs.append(job["output_url"])
self._job_update(job_id, status="done", progress=100, message=f"rendered {len(outputs)} movie variant(s)", outputs=outputs, output_url=outputs[-1] if outputs else None)
def _movie_job(self, job_id: str, payload: dict[str, Any]) -> None:
title = payload.get("title") or "untitled_movie"
slug = safe_slug(title)
......@@ -508,6 +803,7 @@ class VideoGenApp:
height = int(payload.get("height") or 432)
default_frames = int(payload.get("num_frames") or 32)
use_keyframes = bool(payload.get("use_keyframes"))
selected_loras = self._selected_loras(payload)
clip_paths: list[Path] = []
total = len(clips)
self.emit(f"Starting movie '{title}' with {total} clip(s)")
......@@ -540,15 +836,20 @@ class VideoGenApp:
body["environment_profiles"] = environments
if clip.get("camera_motion"):
body["camera_motion"] = clip.get("camera_motion")
if selected_loras:
body["loras"] = selected_loras
if clip.get("dialogues"):
body["dialogs"] = self._normalize_dialogues(clip.get("dialogues"))
body["lip_sync"] = bool(clip.get("lip_sync", True))
body["lip_sync_method"] = clip.get("lip_sync_method") or payload.get("lip_sync_method") or "wav2lip"
body["generate_subtitles"] = bool(clip.get("subtitles", True))
body["burn_subtitles"] = bool(clip.get("burn_subtitles", False))
if clip.get("speech_text"):
body["tts_text"] = clip.get("speech_text")
body["tts_voice"] = clip.get("speech_voice") or payload.get("default_voice") or "af_sarah"
body["tts_speed"] = float(clip.get("speech_speed") or 1.0)
body["lip_sync"] = bool(clip.get("lip_sync", True))
body["lip_sync_method"] = clip.get("lip_sync_method") or payload.get("lip_sync_method") or "wav2lip"
body["add_audio"] = True
body["audio_type"] = "speech"
if clip.get("music_prompt") or clip.get("sfx_prompt"):
......@@ -600,6 +901,23 @@ class VideoGenApp:
self.emit(f"Movie failed: {exc}")
self._job_update(job_id, status="error", error=str(exc), message=str(exc))
def _selected_loras(self, payload: dict[str, Any]) -> list[dict[str, Any]]:
requested = payload.get("loras") or []
if not requested:
return []
by_name = {item.get("name"): item for item in self.client.list_loras() if item.get("name")}
out = []
for item in requested:
if isinstance(item, str):
name, weight = item, float(payload.get("lora_weight") or 0.8)
else:
name, weight = item.get("name") or item.get("model"), float(item.get("weight") or payload.get("lora_weight") or 0.8)
info = by_name.get(name, {})
model = info.get("path") or name
if model:
out.append({"model": model, "weight": weight, "name": name})
return out
def _build_clip_prompt(self, movie: dict[str, Any], clip: dict[str, Any], characters: list[str], environments: list[str]) -> str:
parts = []
if movie.get("style"):
......@@ -642,6 +960,7 @@ class VideoGenApp:
"text": row.get("text"),
"start_time": row.get("start_time") if row.get("start_time") not in ("", None) else None,
"lip_sync": bool(row.get("lip_sync", True)),
"lang": row.get("lang") or None,
"speed": float(row.get("speed") or 1.0),
})
return out
......@@ -673,7 +992,7 @@ HTML_PAGE = r"""
</style>
</head>
<body>
<header><h1>CoderAI VideoGen Studio</h1><div class="sub">Manage reusable characters and environments, write a multi-clip movie prompt, then render clips with speech, lip-sync, music, and sound effects.</div></header>
<header><h1>CoderAI VideoGen Studio</h1><div class="sub">Manage reusable characters, environments, cloned voices, and multi-clip movies with realistic dialogue, lip-sync, music, and sound effects.</div></header>
<div class="wrap">
<aside class="card">
<h2>Connection</h2>
......@@ -681,7 +1000,7 @@ HTML_PAGE = r"""
<label>Image model</label><select id="image_model"></select>
<label>Video model</label><select id="video_model"></select>
<label>Audio/Music model</label><select id="audio_model"></select>
<label>TTS voice id</label><input id="default_voice" value="af_sarah">
<label>Default voice / cloned profile</label><input id="default_voice" list="voice_list" value="af_sarah"><datalist id="voice_list"></datalist>
<hr style="border-color:var(--line);border-style:solid none none;margin:16px 0">
<h2>Live Log</h2>
<div class="log" id="log"></div>
......@@ -691,11 +1010,12 @@ HTML_PAGE = r"""
<div class="tabs">
<button class="tab active" data-tab="profiles">Profiles</button>
<button class="tab" data-tab="movie">Movie Builder</button>
<button class="tab" data-tab="loras">Style LoRAs</button>
<button class="tab" data-tab="gallery">Gallery</button>
</div>
<section id="profiles" class="section active">
<h2>Characters and Environments</h2>
<h2>Characters, Environments, and Voices</h2>
<div class="row">
<div class="card">
<h3>Create Character</h3>
......@@ -714,8 +1034,17 @@ HTML_PAGE = r"""
<button class="btn" onclick="createProfile('environment')">Generate Environment</button>
</div>
</div>
<div class="card" style="margin-top:12px">
<h3>Create Cloned Voice</h3>
<div class="muted">Use a clean 5-30 second reference clip plus an exact transcript for realistic cloned dialogue. Audio or video files are accepted.</div>
<div class="row"><div><label>Voice name</label><input id="voice_name" placeholder="alice_voice"></div><div><label>Reference audio/video path</label><input id="voice_media" placeholder="/path/to/reference.wav"></div></div>
<label>Description</label><input id="voice_desc" placeholder="Warm, calm adult voice; close mic">
<label>Exact reference transcript</label><textarea id="voice_transcript" placeholder="The exact words spoken in the reference clip. Required for best F5-TTS cloning."></textarea>
<button class="btn" onclick="createVoice()">Create Cloned Voice</button>
</div>
<h3>Saved Characters</h3><div class="grid" id="chars"></div>
<h3>Saved Environments</h3><div class="grid" id="envs"></div>
<h3>Saved Voices</h3><div class="grid" id="voices"></div>
</section>
<section id="movie" class="section">
......@@ -724,6 +1053,8 @@ HTML_PAGE = r"""
<div><label>Title</label><input id="title" value="my_little_movie"></div>
<div><label>Visual style</label><input id="style" value="cinematic, coherent character identity, natural motion, detailed lighting"></div>
</div>
<label>Style LoRAs to apply</label><select id="movie_loras" multiple size="5"></select>
<div class="row"><div><label>LoRA weight</label><input id="movie_lora_weight" type="number" step="0.05" value="0.8"></div><div><label>Movie variants to produce</label><input id="movie_count" type="number" min="1" value="1"></div></div>
<div class="row">
<div><label>Width</label><input id="width" type="number" value="768"></div>
<div><label>Height</label><input id="height" type="number" value="432"></div>
......@@ -738,12 +1069,28 @@ HTML_PAGE = r"""
</div>
<label>Global negative prompt</label><input id="negative_prompt" value="flicker, morphing faces, extra limbs, low quality, unreadable text">
<label><input id="use_keyframes" type="checkbox" style="width:auto"> Generate keyframe image before each video clip for stronger character/environment consistency</label>
<label>Lip-sync method</label><select id="lip_sync_method"><option value="wav2lip">wav2lip</option><option value="sadtalker">sadtalker</option></select>
<label>Final soundtrack prompt (optional; mixed under assembled movie)</label><textarea id="soundtrack_prompt" placeholder="tense orchestral pulse with soft percussion, no vocals"></textarea>
<div id="clips"></div>
<button class="btn secondary" onclick="addClip()">Add Clip</button>
<button class="btn" onclick="startMovie()">Render Movie</button>
</section>
<section id="loras" class="section">
<h2>Style LoRA Training</h2>
<div class="muted">Train a reusable style LoRA from image/video scene references. Add stills or videos, describe the visual style, then choose the trained LoRA in Movie Builder to produce one or more movies in that style.</div>
<div class="card" style="margin-top:12px">
<div class="row"><div><label>LoRA name</label><input id="lora_name" placeholder="noir_rain_style"></div><div><label>Target</label><select id="lora_target"><option value="video">video</option><option value="image">image</option></select></div></div>
<label>Training media paths, one per line</label><textarea id="lora_media" placeholder="/path/to/style_reference_01.png&#10;/path/to/scene_clip.mp4"></textarea>
<label>Scene/style description</label><textarea id="lora_style" placeholder="Describe the scene language: lighting, lens, color palette, texture, movement, costumes, production design..."></textarea>
<label>Instance prompt/token</label><input id="lora_instance" placeholder="in noir_rain_style cinematic style">
<div class="row"><div><label>Steps</label><input id="lora_steps" type="number" value="800"></div><div><label>Rank</label><input id="lora_rank" type="number" value="16"></div></div>
<div class="row"><div><label>Resolution</label><input id="lora_resolution" type="number" value="512"></div><div><label>Frames per video sample</label><input id="lora_frames" type="number" value="6"></div></div>
<button class="btn" onclick="trainLora()">Train Style LoRA</button>
</div>
<h3>Available LoRAs</h3><div class="grid" id="loras_grid"></div>
</section>
<section id="gallery" class="section">
<h2>Gallery</h2>
<button class="btn secondary" onclick="loadGallery()">Refresh Gallery</button>
......@@ -752,7 +1099,7 @@ HTML_PAGE = r"""
</main>
</div>
<script>
let models=[], profiles={characters:[], environments:[]};
let models=[], profiles={characters:[], environments:[], voices:[], loras:[]};
function $(id){return document.getElementById(id)}
function esc(s){return String(s||'').replace(/[&<>"']/g,m=>({'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;',"'":'&#39;'}[m]))}
async function api(path, opts={}){let r=await fetch(path,{headers:{'Content-Type':'application/json'},...opts}); if(!r.ok) throw new Error(await r.text()); return await r.json()}
......@@ -760,8 +1107,14 @@ function fillSelect(sel, cap, def){let s=$(sel); s.innerHTML=''; let filtered=mo
async function loadModels(){let d=await api('/api/models'); models=d.models||[]; fillSelect('image_model','image_generation',d.defaults.image_model); fillSelect('video_model','video_generation',d.defaults.video_model); fillSelect('audio_model','audio_generation',d.defaults.audio_model); $('conn').textContent=`Connected: ${models.length} model(s)`}
async function loadProfiles(){profiles=await api('/api/profiles'); renderProfiles()}
function profileCard(p,kind){return `<div class="profile"><img src="${p.thumbnail||''}" onerror="this.style.display='none'"><div class="p"><b>${esc(p.name)}</b><div class="muted">${esc(p.description||'')}</div><span class="pill">${kind}</span><span class="pill">${p.image_count||0} refs</span><span class="pill">${p.local?'local':'server'}</span></div></div>`}
function renderProfiles(){$('chars').innerHTML=profiles.characters.map(p=>profileCard(p,'character')).join('')||'<div class="muted">No characters yet.</div>'; $('envs').innerHTML=profiles.environments.map(p=>profileCard(p,'environment')).join('')||'<div class="muted">No environments yet.</div>'; renderClipSelectors()}
function voiceCard(v){return `<div class="profile"><div class="p"><b>${esc(v.name)}</b><div class="muted">${esc(v.description||'')}</div><span class="pill">cloned voice</span><span class="pill">${esc(v.audio_ext||'audio')}</span></div></div>`}
function loraCard(l){return `<div class="profile"><div class="p"><b>${esc(l.name)}</b><div class="muted">${esc(l.instance_prompt||l.path||'')}</div><span class="pill">${esc(l.target||'lora')}</span><span class="pill">rank ${esc(l.rank||'')}</span></div></div>`}
function renderVoiceList(){let voices=['af_sarah','af_sky','af_bella','am_adam','am_michael','en-US-JennyNeural',...(profiles.voices||[]).map(v=>v.name)].filter(Boolean); $('voice_list').innerHTML=[...new Set(voices)].map(v=>`<option value="${esc(v)}"></option>`).join('')}
function renderLoraList(){let opts=(profiles.loras||[]).map(l=>`<option value="${esc(l.name)}">${esc(l.name)}${l.target?' ('+esc(l.target)+')':''}</option>`).join(''); $('movie_loras').innerHTML=opts; $('loras_grid').innerHTML=(profiles.loras||[]).map(loraCard).join('')||'<div class="muted">No style LoRAs yet. Train one from media references.</div>'}
function renderProfiles(){$('chars').innerHTML=profiles.characters.map(p=>profileCard(p,'character')).join('')||'<div class="muted">No characters yet.</div>'; $('envs').innerHTML=profiles.environments.map(p=>profileCard(p,'environment')).join('')||'<div class="muted">No environments yet.</div>'; $('voices').innerHTML=(profiles.voices||[]).map(voiceCard).join('')||'<div class="muted">No cloned voices yet. Create one from a reference clip.</div>'; renderVoiceList(); renderLoraList(); renderClipSelectors()}
async function createProfile(kind){let isChar=kind==='character'; let name=$(isChar?'char_name':'env_name').value; let desc=$(isChar?'char_desc':'env_desc').value; let prompt=$(isChar?'char_prompt':'env_prompt').value||desc; let n=$(isChar?'char_n':'env_n').value; let [w,h]=($(isChar?'char_size':'env_size').value||'512x512').split('x').map(x=>parseInt(x,10)); let model=$('image_model').value; let d=await api('/api/profile/start',{method:'POST',body:JSON.stringify({kind,name,description:desc,prompt,model,n,width:w,height:h})}); watchJob(d.job_id);}
async function createVoice(){let d=await api('/api/voice/start',{method:'POST',body:JSON.stringify({name:$('voice_name').value,description:$('voice_desc').value,transcript:$('voice_transcript').value,media_path:$('voice_media').value,extract:true})}); watchJob(d.job_id)}
async function trainLora(){let d=await api('/api/lora/start',{method:'POST',body:JSON.stringify({name:$('lora_name').value,base_model:$('video_model').value,target:$('lora_target').value,media_paths:$('lora_media').value,style_prompt:$('lora_style').value,scene_description:$('lora_style').value,instance_prompt:$('lora_instance').value,steps:$('lora_steps').value,rank:$('lora_rank').value,resolution:$('lora_resolution').value,frames_per_video:$('lora_frames').value,quantize_4bit:true})}); watchJob(d.job_id)}
function options(items){return items.map(p=>`<option value="${esc(p.name)}">${esc(p.name)}</option>`).join('')}
function addClip(data={}){let idx=document.querySelectorAll('.clip').length+1; let div=document.createElement('div'); div.className='clip'; div.innerHTML=`<div class="clip-head"><h3>Clip ${idx}</h3><button class="btn bad" onclick="this.closest('.clip').remove()">Remove</button></div>
<label>Clip title</label><input class="c_title" value="${esc(data.title||'Shot '+idx)}">
......@@ -769,15 +1122,16 @@ function addClip(data={}){let idx=document.querySelectorAll('.clip').length+1; l
<div class="row"><div><label>Characters</label><select class="c_chars" multiple size="5">${options(profiles.characters)}</select></div><div><label>Environments</label><select class="c_envs" multiple size="5">${options(profiles.environments)}</select></div></div>
<div class="row"><div><label>Camera motion</label><input class="c_camera" placeholder="zoom-in, pan-left, handheld..."></div><div><label>Mood/action</label><input class="c_action" placeholder="what happens in this clip"></div></div>
<label>Speech text (simple one-speaker; optional)</label><input class="c_speech" placeholder="Line spoken in this shot">
<div class="row"><div><label>Speech voice</label><input class="c_voice" value="${esc($('default_voice').value||'af_sarah')}"></div><div><label><input class="c_lipsync" type="checkbox" checked style="width:auto"> Lip-sync speech/dialogue</label></div></div>
<div class="row"><div><label>Speech voice / cloned profile</label><input class="c_voice" list="voice_list" value="${esc($('default_voice').value||'af_sarah')}"></div><div><label>Speech speed</label><input class="c_speed" type="number" step="0.05" value="1.0"></div></div>
<label><input class="c_lipsync" type="checkbox" checked style="width:auto"> Lip-sync speech/dialogue</label>
<h3>Multi-character dialogue</h3><div class="dialogues"></div><button class="btn secondary" onclick="addDialogue(this)">Add Dialogue Line</button>
<label>Music prompt for this clip</label><input class="c_music" placeholder="short local music bed for this clip">
<label>Sound effects prompt for this clip</label><input class="c_sfx" placeholder="rain, footsteps, door creak, city ambience">
</div>`; $('clips').appendChild(div); renderClipSelectors(div)}
function renderClipSelectors(root=document){for(let s of root.querySelectorAll('.c_chars')){let vals=[...s.selectedOptions].map(o=>o.value); s.innerHTML=options(profiles.characters); for(let o of s.options) if(vals.includes(o.value)) o.selected=true} for(let s of root.querySelectorAll('.c_envs')){let vals=[...s.selectedOptions].map(o=>o.value); s.innerHTML=options(profiles.environments); for(let o of s.options) if(vals.includes(o.value)) o.selected=true} for(let s of root.querySelectorAll('.d_char')){let val=s.value; s.innerHTML='<option value="">(none)</option>'+options(profiles.characters); s.value=val}}
function addDialogue(btn){let box=btn.closest('.clip').querySelector('.dialogues'); let row=document.createElement('div'); row.className='dialogue'; row.innerHTML=`<select class="d_char"><option value="">(none)</option>${options(profiles.characters)}</select><input class="d_voice" placeholder="voice/profile" value="${esc($('default_voice').value||'af_sarah')}"><input class="d_text" placeholder="dialogue text"><input class="d_start" placeholder="start s"><button class="btn bad" onclick="this.parentElement.remove()">x</button>`; box.appendChild(row)}
function addDialogue(btn){let box=btn.closest('.clip').querySelector('.dialogues'); let row=document.createElement('div'); row.className='dialogue'; row.innerHTML=`<select class="d_char"><option value="">(none)</option>${options(profiles.characters)}</select><input class="d_voice" list="voice_list" placeholder="voice/profile" value="${esc($('default_voice').value||'af_sarah')}"><input class="d_text" placeholder="dialogue text"><input class="d_start" placeholder="start s"><input class="d_speed" type="number" step="0.05" value="1.0"><button class="btn bad" onclick="this.parentElement.remove()">x</button>`; box.appendChild(row)}
function selected(sel){return [...sel.selectedOptions].map(o=>o.value)}
function collectMovie(){let clips=[...document.querySelectorAll('.clip')].map(c=>({title:c.querySelector('.c_title').value,prompt:c.querySelector('.c_prompt').value,characters:selected(c.querySelector('.c_chars')),environments:selected(c.querySelector('.c_envs')),camera_motion:c.querySelector('.c_camera').value,action:c.querySelector('.c_action').value,speech_text:c.querySelector('.c_speech').value,speech_voice:c.querySelector('.c_voice').value,lip_sync:c.querySelector('.c_lipsync').checked,music_prompt:c.querySelector('.c_music').value,sfx_prompt:c.querySelector('.c_sfx').value,dialogues:[...c.querySelectorAll('.dialogue')].map(d=>({character:d.querySelector('.d_char').value,voice:d.querySelector('.d_voice').value,text:d.querySelector('.d_text').value,start_time:d.querySelector('.d_start').value}))})); return {title:$('title').value,style:$('style').value,image_model:$('image_model').value,video_model:$('video_model').value,audio_model:$('audio_model').value,default_voice:$('default_voice').value,width:+$('width').value,height:+$('height').value,fps:+$('fps').value,num_frames:+$('num_frames').value,steps:+$('steps').value,guidance_scale:+$('guidance_scale').value,negative_prompt:$('negative_prompt').value,use_keyframes:$('use_keyframes').checked,soundtrack_prompt:$('soundtrack_prompt').value,clips}}
function collectMovie(){let clips=[...document.querySelectorAll('.clip')].map(c=>({title:c.querySelector('.c_title').value,prompt:c.querySelector('.c_prompt').value,characters:selected(c.querySelector('.c_chars')),environments:selected(c.querySelector('.c_envs')),camera_motion:c.querySelector('.c_camera').value,action:c.querySelector('.c_action').value,speech_text:c.querySelector('.c_speech').value,speech_voice:c.querySelector('.c_voice').value,speech_speed:c.querySelector('.c_speed').value,lip_sync:c.querySelector('.c_lipsync').checked,lip_sync_method:$('lip_sync_method').value,music_prompt:c.querySelector('.c_music').value,sfx_prompt:c.querySelector('.c_sfx').value,dialogues:[...c.querySelectorAll('.dialogue')].map(d=>({character:d.querySelector('.d_char').value,voice:d.querySelector('.d_voice').value,text:d.querySelector('.d_text').value,start_time:d.querySelector('.d_start').value,speed:d.querySelector('.d_speed').value,lip_sync:c.querySelector('.c_lipsync').checked}))})); return {title:$('title').value,style:$('style').value,image_model:$('image_model').value,video_model:$('video_model').value,audio_model:$('audio_model').value,default_voice:$('default_voice').value,lip_sync_method:$('lip_sync_method').value,width:+$('width').value,height:+$('height').value,fps:+$('fps').value,num_frames:+$('num_frames').value,steps:+$('steps').value,guidance_scale:+$('guidance_scale').value,negative_prompt:$('negative_prompt').value,use_keyframes:$('use_keyframes').checked,soundtrack_prompt:$('soundtrack_prompt').value,loras:selected($('movie_loras')).map(n=>({name:n,weight:+$('movie_lora_weight').value})),lora_weight:+$('movie_lora_weight').value,movie_count:+$('movie_count').value,clips}}
async function startMovie(){let d=await api('/api/movie/start',{method:'POST',body:JSON.stringify(collectMovie())}); watchJob(d.job_id)}
async function watchJob(id){$('jobout').innerHTML=`<p>Job <span class="pill">${id}</span></p>`; let timer=setInterval(async()=>{let j=await api('/api/job/'+id); $('jobout').innerHTML=`<p><span class="pill">${esc(j.status)}</span> ${j.progress||0}% ${esc(j.message||'')}</p>`+(j.output_url?`<p><a href="${j.output_url}" target="_blank">Open output</a></p>`:'')+(j.error?`<p style="color:var(--bad)">${esc(j.error)}</p>`:''); if(j.status==='done'||j.status==='error'){clearInterval(timer); loadProfiles(); loadGallery()}},1500)}
async function loadGallery(){let d=await api('/api/gallery'); $('gallery_grid').innerHTML=(d.items||[]).map(it=>`<div class="profile">${it.type==='video'?`<video src="${it.url}" controls style="width:100%;height:130px;background:#000"></video>`:`<img src="${it.url}">`}<div class="p"><b>${esc(it.name)}</b><br><a href="${it.url}" target="_blank">open</a></div></div>`).join('')||'<div class="muted">No media yet.</div>'}
......@@ -871,6 +1225,10 @@ def make_handler(app: VideoGenApp):
payload = self._read_json()
if path == "/api/profile/start":
self._json({"job_id": app.start_profile_job(payload)})
elif path == "/api/voice/start":
self._json({"job_id": app.start_voice_job(payload)})
elif path == "/api/lora/start":
self._json({"job_id": app.start_lora_job(payload)})
elif path == "/api/movie/start":
self._json({"job_id": app.start_movie_job(payload)})
elif path.startswith("/api/job/") and path.endswith("/cancel"):
......@@ -909,6 +1267,7 @@ def build_parser() -> argparse.ArgumentParser:
parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="CoderAI base URL")
parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="Bearer token for CoderAI")
parser.add_argument("--out-dir", default=DEFAULT_OUT_DIR, help="Local output directory")
parser.add_argument("--host", default="0.0.0.0", help="Web UI listen host (default: 0.0.0.0)")
parser.add_argument("--web-port", type=int, default=7790, help="Local web UI port")
parser.add_argument("--image-model", default="", help="Default image model")
parser.add_argument("--video-model", default="", help="Default video model")
......@@ -921,8 +1280,8 @@ def build_parser() -> argparse.ArgumentParser:
def main(argv: list[str] | None = None) -> None:
args = build_parser().parse_args(argv)
app = VideoGenApp(args)
server = ThreadedHTTPServer(("127.0.0.1", args.web_port), make_handler(app))
url = f"http://127.0.0.1:{args.web_port}"
server = ThreadedHTTPServer((args.host, args.web_port), make_handler(app))
url = f"http://{args.host}:{args.web_port}"
log(f"VideoGen Studio running at {url}")
log(f"CoderAI: {args.base_url}")
log(f"Output: {Path(args.out_dir).resolve()}")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment