LoRA training: async kickoff, restart recovery, keyframe regen UI

Server (codai/api/loras.py):
- /v1/loras/train gains wait (default True) + session; wait=false detaches
  the job and returns a job_id, avoiding HTTP read-timeouts on multi-hour
  video trainings.
- Disk-persisted job registry keyed by job_id (carries session). Progress
  endpoint serves ?job=<id> / ?session=<tok> so a client only ever sees its
  own job — no cross-user spillover. Jobs left mid-flight at startup are
  marked interrupted.
- Mid-training PEFT checkpoints (SD1.5/SDXL/Wan) + train_state.json; a
  resubmit resumes from the last step when base/target/rank (and session)
  match, so a reboot no longer throws away hours of Wan training.

Township (tools/gen_township_fighters.py):
- Async training: per-run session token + persisted per-LoRA job_id; polls
  by job_id, re-attaches to a running server job after a restart, resubmits
  an interrupted one (server resumes from checkpoint).
- Dedicated train timeouts (24h video / 4h image).
- Match page: regenerate/clear keyframes (match-level + per-clip/outcome)
  via new keyframes/keyframe render + delete scopes.

tools/videogen.py: mirror the session-token + job-id recovery helpers.
Co-Authored-By: 's avatarClaude Opus 4.8 <noreply@anthropic.com>
parent f21c6185
...@@ -64,6 +64,94 @@ _progress = { ...@@ -64,6 +64,94 @@ _progress = {
} }
_train_lock = threading.Lock() _train_lock = threading.Lock()
# ── Multi-client job registry ─────────────────────────────────────────────────
# Training is serialized server-wide (one GPU job at a time), but several clients
# may submit concurrently. Each submission gets a unique job_id and is recorded
# here so a client only ever polls *its own* job (no cross-user spillover). The
# registry is mirrored to disk so a recovering client (after a server OR client
# restart) can re-attach by job_id / session token. `_active_job_id` names the
# job currently executing, so _set_progress can mirror live progress into it.
_jobs_lock = threading.Lock()
_jobs: dict = {} # job_id -> record
_active_job_id: Optional[str] = None
_bg_tasks: set = set() # strong refs to detached train tasks
_JOB_ACTIVE_STATES = ("queued", "preparing", "training", "saving")
_MIRROR_FIELDS = ("active", "name", "step", "total", "status", "message",
"started_at", "path")
_last_job_persist = 0.0
def _jobs_file() -> str:
return os.path.join(_loras_dir(), "_train_jobs.json")
def _persist_jobs_locked(force: bool = False) -> None:
"""Write the registry to disk. Throttled to ~5s unless forced (status change)
so per-step progress updates don't hammer the disk."""
global _last_job_persist
now = time.time()
if not force and (now - _last_job_persist) < 5.0:
return
_last_job_persist = now
try:
p = _jobs_file()
tmp = p + ".tmp"
with open(tmp, "w") as f:
json.dump(_jobs, f)
os.replace(tmp, p)
except Exception:
pass
def _new_job(job_id: str, **fields) -> None:
with _jobs_lock:
rec = {"job_id": job_id, "updated_at": time.time()}
rec.update(fields)
_jobs[job_id] = rec
_persist_jobs_locked(force=True)
def _update_job(job_id: Optional[str], force: bool = False, **fields) -> None:
if not job_id:
return
with _jobs_lock:
rec = _jobs.get(job_id)
if rec is None:
rec = {"job_id": job_id}
_jobs[job_id] = rec
# A status change is significant — persist immediately.
if "status" in fields and fields["status"] != rec.get("status"):
force = True
rec.update(fields)
rec["updated_at"] = time.time()
_persist_jobs_locked(force=force)
def _load_jobs_on_start() -> None:
"""Load the persisted registry and mark any job that was mid-flight when the
process died as 'interrupted' (its in-memory GPU training is gone). Clients
polling such a job learn it must be resubmitted (server resumes from the last
on-disk checkpoint)."""
global _jobs
try:
with open(_jobs_file()) as f:
data = json.load(f)
except Exception:
return
if not isinstance(data, dict):
return
changed = False
for rec in data.values():
if rec.get("status") in _JOB_ACTIVE_STATES:
rec["status"] = "interrupted"
rec["active"] = False
rec["message"] = "interrupted by server restart — resubmit to resume"
changed = True
with _jobs_lock:
_jobs = data
if changed:
_persist_jobs_locked(force=True)
def set_global_args(args): def set_global_args(args):
global _LORAS_DIR global _LORAS_DIR
...@@ -76,6 +164,7 @@ def set_global_args(args): ...@@ -76,6 +164,7 @@ def set_global_args(args):
root = None root = None
_LORAS_DIR = os.path.join(root, 'loras') if root else str(default_loras_dir()) _LORAS_DIR = os.path.join(root, 'loras') if root else str(default_loras_dir())
os.makedirs(_LORAS_DIR, exist_ok=True) os.makedirs(_LORAS_DIR, exist_ok=True)
_load_jobs_on_start()
def _loras_dir() -> str: def _loras_dir() -> str:
...@@ -153,6 +242,15 @@ class LoraTrainRequest(BaseModel): ...@@ -153,6 +242,15 @@ class LoraTrainRequest(BaseModel):
learning_rate: Optional[float] = 1e-4 learning_rate: Optional[float] = 1e-4
resolution: Optional[int] = 512 resolution: Optional[int] = 512
seed: Optional[int] = 42 seed: Optional[int] = 42
# Run the job asynchronously: the POST returns immediately with a job_id and
# the client polls /v1/loras/progress?job=<id>. Avoids long read-timeouts on
# multi-hour video trainings. Default True keeps blocking behaviour for old
# clients that expect the result inline.
wait: Optional[bool] = True
# Caller-supplied session token. Recorded with the job so the owning client —
# and only it — can recover its job(s) after a restart. Auto-generated if
# omitted.
session: Optional[str] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
...@@ -246,6 +344,13 @@ def _gather_images(req: LoraTrainRequest): ...@@ -246,6 +344,13 @@ def _gather_images(req: LoraTrainRequest):
def _set_progress(**kw): def _set_progress(**kw):
with _progress_lock: with _progress_lock:
_progress.update(kw) _progress.update(kw)
job_id = _active_job_id
# Mirror live progress into the executing job's record so its owner (and only
# its owner) can poll it by job_id.
if job_id:
mirror = {k: kw[k] for k in kw if k in _MIRROR_FIELDS}
if mirror:
_update_job(job_id, **mirror)
def _lora_debug_enabled() -> bool: def _lora_debug_enabled() -> bool:
...@@ -295,6 +400,83 @@ def _free_train_vram() -> None: ...@@ -295,6 +400,83 @@ def _free_train_vram() -> None:
pass pass
# ── Mid-training checkpoints (resume across restarts) ─────────────────────────
# Every _CKPT_EVERY steps we snapshot the raw PEFT adapter state(s) + a
# train_state.json into the LoRA's own folder (name-scoped, so two clients
# training different LoRAs never collide). If the process dies, the next
# submission for the same name/base/target/rank reloads the snapshot and
# continues from the saved step instead of restarting from scratch. Optimizer
# momentum is not restored (re-warms in a few steps); the trained weights — the
# expensive part — are preserved. The snapshot records the owning session so a
# resume can verify ownership.
_CKPT_EVERY = 100
def _train_state_path(name: str) -> str:
return os.path.join(_lora_dir(name), "train_state.json")
def _save_train_checkpoint(name: str, state: dict, peft_states: dict) -> None:
from safetensors.torch import save_file
d = _lora_dir(name)
os.makedirs(d, exist_ok=True)
try:
for key, sd in peft_states.items():
cpu_sd = {k: v.detach().to("cpu") for k, v in sd.items()}
save_file(cpu_sd, os.path.join(d, f"_ckpt_{key}.safetensors"))
tmp = _train_state_path(name) + ".tmp"
with open(tmp, "w") as f:
json.dump(state, f)
os.replace(tmp, _train_state_path(name))
except Exception as e:
_dbg_lora(f"checkpoint save failed for '{name}': {e}")
def _load_train_state(name: str, *, base_path: str, target: str,
rank: int, session: Optional[str] = None) -> Optional[dict]:
"""Return a valid resumable checkpoint state for `name`, or None. Validates
that base/target/rank match (and session, when provided) so we never resume
one config's weights into another's run."""
p = _train_state_path(name)
if not os.path.isfile(p):
return None
try:
with open(p) as f:
st = json.load(f)
except Exception:
return None
if (st.get("base_path") != base_path or st.get("target") != target
or int(st.get("rank", -1)) != int(rank)):
return None
if session is not None and st.get("session") not in (None, session):
return None
step = int(st.get("step", 0))
if step <= 0 or step >= int(st.get("total", 0)):
return None
return st
def _apply_peft_checkpoint(name: str, key: str, model) -> None:
from safetensors.torch import load_file
from peft import set_peft_model_state_dict
f = os.path.join(_lora_dir(name), f"_ckpt_{key}.safetensors")
if not os.path.isfile(f):
raise FileNotFoundError(f)
set_peft_model_state_dict(model, load_file(f), adapter_name="default")
def _clear_train_checkpoint(name: str) -> None:
d = _lora_dir(name)
if not os.path.isdir(d):
return
for fn in os.listdir(d):
if fn.startswith("_ckpt_") or fn == "train_state.json":
try:
os.remove(os.path.join(d, fn))
except Exception:
pass
# ── Training base-model cache ──────────────────────────────────────────────── # ── Training base-model cache ────────────────────────────────────────────────
# The SD1.x/SDXL base used for LoRA training is expensive to load from disk. # The SD1.x/SDXL base used for LoRA training is expensive to load from disk.
# We keep its components cached *on CPU* between consecutive training jobs so a # We keep its components cached *on CPU* between consecutive training jobs so a
...@@ -633,6 +815,21 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -633,6 +815,21 @@ def _train_sd15(req, base_path, images, instance_prompt,
lora_params = [p for p in unet.parameters() if p.requires_grad] lora_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=lr) optimizer = torch.optim.AdamW(lora_params, lr=lr)
# Resume from a mid-training checkpoint if one survives a prior restart.
start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="image", rank=rank,
session=getattr(req, "session", None))
if _ck:
try:
_apply_peft_checkpoint(name, "default", unet)
start_step = int(_ck["step"])
_set_progress(step=start_step,
message=f"resuming from step {start_step}/{steps}")
_dbg_lora(f"resumed SD1.5 '{name}' from checkpoint step {start_step}")
except Exception as e:
print(f" [lora] could not resume '{name}': {e}; starting fresh")
start_step = 0
# Pre-encode latents and the (single) instance-prompt embedding. # Pre-encode latents and the (single) instance-prompt embedding.
latents_list = _make_dataset(images, [tokenizer], [text_encoder], instance_prompt, latents_list = _make_dataset(images, [tokenizer], [text_encoder], instance_prompt,
resolution, vae, device, torch.float32, is_sdxl=False) resolution, vae, device, torch.float32, is_sdxl=False)
...@@ -654,7 +851,7 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -654,7 +851,7 @@ def _train_sd15(req, base_path, images, instance_prompt,
_set_progress(status="training", message="training (SD1.5)") _set_progress(status="training", message="training (SD1.5)")
unet.train() unet.train()
n = len(latents_list) n = len(latents_list)
for step in range(steps): for step in range(start_step, steps):
latents = latents_list[step % n].to(device) latents = latents_list[step % n].to(device)
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
bsz = latents.shape[0] bsz = latents.shape[0]
...@@ -676,6 +873,12 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -676,6 +873,12 @@ def _train_sd15(req, base_path, images, instance_prompt,
if step % 10 == 0 or step == steps - 1: if step % 10 == 0 or step == steps - 1:
_set_progress(step=step + 1, message=f"step {step+1}/{steps} loss={loss.item():.4f}") _set_progress(step=step + 1, message=f"step {step+1}/{steps} loss={loss.item():.4f}")
_dbg_lora(f"SD1.5 step {step+1}/{steps} loss={loss.item():.4f}") _dbg_lora(f"SD1.5 step {step+1}/{steps} loss={loss.item():.4f}")
if (step + 1) % _CKPT_EVERY == 0 and step + 1 < steps:
_save_train_checkpoint(name,
{"name": name, "base_path": base_path, "target": "image",
"rank": rank, "step": step + 1, "total": steps, "seed": seed,
"session": getattr(req, "session", None)},
{"default": get_peft_model_state_dict(unet)})
# Mid-training thermal checkpoint (pauses if CPU/GPU too hot). # Mid-training thermal checkpoint (pauses if CPU/GPU too hot).
try: try:
from codai.models.thermal import checkpoint as _thermal_checkpoint from codai.models.thermal import checkpoint as _thermal_checkpoint
...@@ -691,6 +894,7 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -691,6 +894,7 @@ def _train_sd15(req, base_path, images, instance_prompt,
unet_lora_layers=unet_lora, unet_lora_layers=unet_lora,
safe_serialization=True) safe_serialization=True)
_write_meta(name, req, base_path, len(images), "sd15", instance_prompt) _write_meta(name, req, base_path, len(images), "sd15", instance_prompt)
_clear_train_checkpoint(name)
# Job done: drop this job's adapter + transients and move the UNet back to # Job done: drop this job's adapter + transients and move the UNet back to
# CPU. The base stays cached on CPU (reused by the next job); no VRAM is held # CPU. The base stays cached on CPU (reused by the next job); no VRAM is held
...@@ -754,6 +958,21 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -754,6 +958,21 @@ def _train_sdxl(req, base_path, images, instance_prompt,
lora_params = [p for p in unet.parameters() if p.requires_grad] lora_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=lr) optimizer = torch.optim.AdamW(lora_params, lr=lr)
# Resume from a mid-training checkpoint if one survives a prior restart.
start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="image", rank=rank,
session=getattr(req, "session", None))
if _ck:
try:
_apply_peft_checkpoint(name, "default", unet)
start_step = int(_ck["step"])
_set_progress(step=start_step,
message=f"resuming from step {start_step}/{steps}")
_dbg_lora(f"resumed SDXL '{name}' from checkpoint step {start_step}")
except Exception as e:
print(f" [lora] could not resume '{name}': {e}; starting fresh")
start_step = 0
tfm = transforms.Compose([ tfm = transforms.Compose([
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(resolution), transforms.CenterCrop(resolution),
...@@ -802,7 +1021,7 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -802,7 +1021,7 @@ def _train_sdxl(req, base_path, images, instance_prompt,
_set_progress(status="training", message="training (SDXL)") _set_progress(status="training", message="training (SDXL)")
unet.train() unet.train()
n = len(latents_list) n = len(latents_list)
for step in range(steps): for step in range(start_step, steps):
latents = latents_list[step % n].to(device) latents = latents_list[step % n].to(device)
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
bsz = latents.shape[0] bsz = latents.shape[0]
...@@ -825,6 +1044,12 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -825,6 +1044,12 @@ def _train_sdxl(req, base_path, images, instance_prompt,
if step % 10 == 0 or step == steps - 1: if step % 10 == 0 or step == steps - 1:
_set_progress(step=step + 1, message=f"step {step+1}/{steps} loss={loss.item():.4f}") _set_progress(step=step + 1, message=f"step {step+1}/{steps} loss={loss.item():.4f}")
_dbg_lora(f"SDXL step {step+1}/{steps} loss={loss.item():.4f}") _dbg_lora(f"SDXL step {step+1}/{steps} loss={loss.item():.4f}")
if (step + 1) % _CKPT_EVERY == 0 and step + 1 < steps:
_save_train_checkpoint(name,
{"name": name, "base_path": base_path, "target": "image",
"rank": rank, "step": step + 1, "total": steps, "seed": seed,
"session": getattr(req, "session", None)},
{"default": get_peft_model_state_dict(unet)})
try: try:
from codai.models.thermal import checkpoint as _thermal_checkpoint from codai.models.thermal import checkpoint as _thermal_checkpoint
_thermal_checkpoint(context="lora-train", throttle_seconds=2.0) _thermal_checkpoint(context="lora-train", throttle_seconds=2.0)
...@@ -839,6 +1064,7 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -839,6 +1064,7 @@ def _train_sdxl(req, base_path, images, instance_prompt,
unet_lora_layers=unet_lora, unet_lora_layers=unet_lora,
safe_serialization=True) safe_serialization=True)
_write_meta(name, req, base_path, len(images), "sdxl", instance_prompt) _write_meta(name, req, base_path, len(images), "sdxl", instance_prompt)
_clear_train_checkpoint(name)
# Job done: drop this job's adapter + transients and move the UNet back to # Job done: drop this job's adapter + transients and move the UNet back to
# CPU. The base stays cached on CPU for the next job; no VRAM held afterwards # CPU. The base stays cached on CPU for the next job; no VRAM held afterwards
...@@ -1009,6 +1235,25 @@ def _train_wan(req, base_path, images, instance_prompt, ...@@ -1009,6 +1235,25 @@ def _train_wan(req, base_path, images, instance_prompt,
detail="Wan LoRA: no trainable adapter params were created") detail="Wan LoRA: no trainable adapter params were created")
optimizer = torch.optim.AdamW(lora_params, lr=lr) optimizer = torch.optim.AdamW(lora_params, lr=lr)
# Resume from a mid-training checkpoint if one survives a prior restart. This
# is the big win for Wan: the A14B base reload alone is ~27 min, and training
# runs for hours — a reboot otherwise throws all of it away.
start_step = 0
_ck = _load_train_state(name, base_path=base_path, target="video", rank=rank,
session=getattr(req, "session", None))
if _ck:
try:
_apply_peft_checkpoint(name, "t", experts[0][1])
if len(experts) > 1:
_apply_peft_checkpoint(name, "t2", experts[1][1])
start_step = int(_ck["step"])
_set_progress(step=start_step,
message=f"resuming from step {start_step}/{steps}")
_dbg_lora(f"resumed Wan '{name}' from checkpoint step {start_step}")
except Exception as e:
print(f" [lora] could not resume Wan '{name}': {e}; starting fresh")
start_step = 0
try: try:
sched = FlowMatchEulerDiscreteScheduler.from_pretrained(base_path, subfolder="scheduler") sched = FlowMatchEulerDiscreteScheduler.from_pretrained(base_path, subfolder="scheduler")
shift = float(getattr(sched.config, "shift", 3.0) or 3.0) shift = float(getattr(sched.config, "shift", 3.0) or 3.0)
...@@ -1024,7 +1269,7 @@ def _train_wan(req, base_path, images, instance_prompt, ...@@ -1024,7 +1269,7 @@ def _train_wan(req, base_path, images, instance_prompt,
_set_progress(status="training", message="training (Wan video LoRA)") _set_progress(status="training", message="training (Wan video LoRA)")
n = len(latents_list) n = len(latents_list)
for step in range(steps): for step in range(start_step, steps):
x0 = latents_list[step % n].to(device, dtype=compute_dtype) x0 = latents_list[step % n].to(device, dtype=compute_dtype)
noise = torch.randn_like(x0) noise = torch.randn_like(x0)
# Rectified-flow timestep with Wan resolution shift applied to sigma. # Rectified-flow timestep with Wan resolution shift applied to sigma.
...@@ -1054,6 +1299,15 @@ def _train_wan(req, base_path, images, instance_prompt, ...@@ -1054,6 +1299,15 @@ def _train_wan(req, base_path, images, instance_prompt,
message=f"step {step+1}/{steps} loss={loss.item():.4f}") message=f"step {step+1}/{steps} loss={loss.item():.4f}")
if step % 10 == 0 or step == steps - 1: if step % 10 == 0 or step == steps - 1:
_dbg_lora(f"Wan step {step+1}/{steps} loss={loss.item():.4f}") _dbg_lora(f"Wan step {step+1}/{steps} loss={loss.item():.4f}")
if (step + 1) % _CKPT_EVERY == 0 and step + 1 < steps:
_ckpt_states = {"t": get_peft_model_state_dict(experts[0][1])}
if len(experts) > 1:
_ckpt_states["t2"] = get_peft_model_state_dict(experts[1][1])
_save_train_checkpoint(name,
{"name": name, "base_path": base_path, "target": "video",
"rank": rank, "step": step + 1, "total": steps, "seed": seed,
"session": getattr(req, "session", None)},
_ckpt_states)
try: try:
from codai.models.thermal import checkpoint as _thermal_checkpoint from codai.models.thermal import checkpoint as _thermal_checkpoint
_thermal_checkpoint(context="lora-train", throttle_seconds=2.0) _thermal_checkpoint(context="lora-train", throttle_seconds=2.0)
...@@ -1081,6 +1335,7 @@ def _train_wan(req, base_path, images, instance_prompt, ...@@ -1081,6 +1335,7 @@ def _train_wan(req, base_path, images, instance_prompt,
WanPipeline.save_lora_weights(save_directory=save_dir, WanPipeline.save_lora_weights(save_directory=save_dir,
safe_serialization=True, **save_kwargs) safe_serialization=True, **save_kwargs)
_write_meta(name, req, base_path, len(images), "wan", instance_prompt) _write_meta(name, req, base_path, len(images), "wan", instance_prompt)
_clear_train_checkpoint(name)
# ── 5. Tear down THIS job's adapter, but KEEP the transformer(s) cached so a # ── 5. Tear down THIS job's adapter, but KEEP the transformer(s) cached so a
# next training against the same base skips the (very slow) reload. Remove the # next training against the same base skips the (very slow) reload. Remove the
...@@ -1142,17 +1397,25 @@ def _write_meta(name, req, base_path, n_images, arch, instance_prompt): ...@@ -1142,17 +1397,25 @@ def _write_meta(name, req, base_path, n_images, arch, instance_prompt):
_TRAIN_MODEL_KEY = "lora-train" _TRAIN_MODEL_KEY = "lora-train"
def _train_lora_blocking(req: LoraTrainRequest) -> dict: def _train_lora_blocking(req: LoraTrainRequest, job_id: Optional[str] = None) -> dict:
"""Run one training job to completion (called inside a worker thread). """Run one training job to completion (called inside a worker thread).
Holds _train_lock for the job's duration so a second training never overlaps Holds _train_lock for the job's duration so a second training never overlaps
(also the signal _release_base_cache uses to know a job is in flight). The (also the signal _release_base_cache uses to know a job is in flight). The
central queue already serializes us, so this acquire returns immediately. central queue already serializes us, so this acquire returns immediately.
`_active_job_id` is set so live progress mirrors into this job's record (and
only this job's) for its owner to poll.
""" """
global _active_job_id
_train_lock.acquire() _train_lock.acquire()
_active_job_id = job_id
try: try:
return _train_lora_sync(req) result = _train_lora_sync(req)
except Exception: if job_id:
_update_job(job_id, status="done", active=False, force=True,
message="done", path=result.get("path"))
return result
except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# On error the base may be in a half-moved / inconsistent state — drop the # On error the base may be in a half-moved / inconsistent state — drop the
...@@ -1161,18 +1424,40 @@ def _train_lora_blocking(req: LoraTrainRequest) -> dict: ...@@ -1161,18 +1424,40 @@ def _train_lora_blocking(req: LoraTrainRequest) -> dict:
_set_progress(active=False, status="error", message="training failed") _set_progress(active=False, status="error", message="training failed")
except Exception: except Exception:
pass pass
if job_id:
_update_job(job_id, status="error", active=False, force=True,
message=f"training failed: {e}"[:300])
_drop_base_cache() _drop_base_cache()
_drop_wan_cache() _drop_wan_cache()
raise raise
finally: finally:
_active_job_id = None
_train_lock.release() _train_lock.release()
async def _run_train_job(req: LoraTrainRequest, job_id: str) -> dict:
"""Acquire a scheduler slot, then run the blocking job in a worker thread.
Used for both the inline (wait=True) and detached (wait=False) paths so the
queue serialization and job bookkeeping are identical."""
import asyncio
request_id = f"lora-train-{job_id}"
lease = await queue_manager.acquire(request_id, _TRAIN_MODEL_KEY)
try:
return await asyncio.to_thread(_train_lora_blocking, req, job_id)
finally:
await queue_manager.release(lease)
@router.post("/v1/loras/train") @router.post("/v1/loras/train")
async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)): async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)):
"""Train a LoRA (blocking). Admitted through the central request scheduler, """Train a LoRA. Admitted through the central request scheduler so concurrent
so concurrent training requests queue and run one after another (instead of trainings queue and run one-at-a-time alongside all other model requests.
being rejected) alongside all other model requests."""
With `wait=false` the call returns immediately with a `job_id`; the client
polls /v1/loras/progress?job=<job_id>. The job keeps running on the server
regardless of the client — so a client that dies/restarts can re-attach by
job_id (or list its session's jobs with ?session=<token>) rather than
restarting the training."""
import asyncio import asyncio
import uuid import uuid
if not req.name or '/' in req.name or '..' in req.name: if not req.name or '/' in req.name or '..' in req.name:
...@@ -1180,23 +1465,72 @@ async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)): ...@@ -1180,23 +1465,72 @@ async def train_lora(req: LoraTrainRequest, _auth=Depends(_require_api_auth)):
if not req.base_model: if not req.base_model:
raise HTTPException(status_code=400, detail="base_model is required") raise HTTPException(status_code=400, detail="base_model is required")
request_id = f"lora-train-{uuid.uuid4().hex[:8]}" session = req.session or f"sess-{uuid.uuid4().hex[:12]}"
# Wait for a scheduler slot (queues behind other in-flight work; the constant req.session = session
# model key keeps trainings strictly one-at-a-time). job_id = f"job-{uuid.uuid4().hex[:12]}"
lease = await queue_manager.acquire(request_id, _TRAIN_MODEL_KEY) _new_job(job_id, session=session, name=req.name, base_model=req.base_model,
target=(req.target or "image"), status="queued", active=True,
step=0, total=req.steps or 0, message="queued",
started_at=time.time(), path=None)
if not req.wait:
# Detach: run independently of this request's lifetime. Keep a strong
# reference so the loop doesn't garbage-collect (cancel) the task.
task = asyncio.create_task(_detached_train(req, job_id))
_bg_tasks.add(task)
task.add_done_callback(_bg_tasks.discard)
return {"ok": True, "async": True, "job_id": job_id, "session": session,
"status": "queued"}
try: try:
result = await asyncio.to_thread(_train_lora_blocking, req) result = await _run_train_job(req, job_id)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"LoRA training failed: {e}") raise HTTPException(status_code=500, detail=f"LoRA training failed: {e}")
finally: return {"ok": True, "job_id": job_id, "session": session, **result}
await queue_manager.release(lease)
return {"ok": True, **result}
async def _detached_train(req: LoraTrainRequest, job_id: str) -> None:
"""Background runner for wait=false jobs. Errors are recorded on the job
record (the client polls for them); nothing is raised to a waiting request."""
try:
await _run_train_job(req, job_id)
except Exception as e:
_update_job(job_id, status="error", active=False, force=True,
message=f"training failed: {e}"[:300])
@router.get("/v1/loras/progress") @router.get("/v1/loras/progress")
async def lora_progress(): async def lora_progress(job: Optional[str] = None, session: Optional[str] = None):
"""Training progress.
- `?job=<job_id>` → that job's record (a client polls only its own job).
- `?session=<tok>` → the most recent job for that session (recovery after a
client restart that lost the job_id).
- neither → the global active-job snapshot (back-compat).
"""
if job:
with _jobs_lock:
rec = _jobs.get(job)
rec = dict(rec) if rec else None
if rec is None:
return {"active": False, "status": "unknown", "job_id": job,
"name": None, "step": 0, "total": 0,
"message": "no such job"}
return rec
if session:
with _jobs_lock:
mine = [dict(r) for r in _jobs.values() if r.get("session") == session]
if not mine:
return {"active": False, "status": "unknown", "session": session,
"name": None, "step": 0, "total": 0,
"message": "no jobs for session"}
# Prefer an active job; otherwise the most recently updated.
active = [r for r in mine if r.get("status") in _JOB_ACTIVE_STATES]
pick = (active or mine)
pick.sort(key=lambda r: r.get("updated_at", 0))
return pick[-1]
with _progress_lock: with _progress_lock:
return dict(_progress) return dict(_progress)
......
...@@ -16,6 +16,7 @@ import sys ...@@ -16,6 +16,7 @@ import sys
import tempfile import tempfile
import threading import threading
import time import time
import urllib.parse
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
...@@ -560,8 +561,9 @@ class CoderAIClient: ...@@ -560,8 +561,9 @@ class CoderAIClient:
if api_key: if api_key:
self.session.headers["Authorization"] = f"Bearer {api_key}" self.session.headers["Authorization"] = f"Bearer {api_key}"
def _post(self, path: str, body: dict) -> dict: def _post(self, path: str, body: dict, timeout: int = None) -> dict:
r = self.session.post(f"{self.base}{path}", json=body, timeout=self.timeout) r = self.session.post(f"{self.base}{path}", json=body,
timeout=timeout if timeout is not None else self.timeout)
if not r.ok: if not r.ok:
raise RuntimeError(f"POST {path} → {r.status_code}: {r.text[:400]}") raise RuntimeError(f"POST {path} → {r.status_code}: {r.text[:400]}")
return r.json() return r.json()
...@@ -696,7 +698,8 @@ class CoderAIClient: ...@@ -696,7 +698,8 @@ class CoderAIClient:
environment: str = None, images: list = None, environment: str = None, images: list = None,
steps: int = 800, rank: int = 16, steps: int = 800, rank: int = 16,
resolution: int = 512, train_base_model: str = None, resolution: int = 512, train_base_model: str = None,
target: str = "image", quantize_4bit: bool = True) -> dict: target: str = "image", quantize_4bit: bool = True,
wait: bool = True, session: str = None) -> dict:
"""Train a per-character or per-environment LoRA on the server. """Train a per-character or per-environment LoRA on the server.
Blocks until complete. Blocks until complete.
...@@ -705,7 +708,9 @@ class CoderAIClient: ...@@ -705,7 +708,9 @@ class CoderAIClient:
so it loads directly on the video pipeline.""" so it loads directly on the video pipeline."""
body = {"name": name, "base_model": base_model, body = {"name": name, "base_model": base_model,
"steps": int(steps), "rank": int(rank), "resolution": int(resolution), "steps": int(steps), "rank": int(rank), "resolution": int(resolution),
"target": target} "target": target, "wait": bool(wait)}
if session:
body["session"] = session
if target == "video": if target == "video":
body["quantize_4bit"] = bool(quantize_4bit) body["quantize_4bit"] = bool(quantize_4bit)
if train_base_model: if train_base_model:
...@@ -716,7 +721,11 @@ class CoderAIClient: ...@@ -716,7 +721,11 @@ class CoderAIClient:
body["environment"] = environment body["environment"] = environment
if images: if images:
body["images"] = images body["images"] = images
return self._post("/v1/loras/train", body) # Video DiT training (e.g. Wan A14B) can take many hours including the
# one-off model load; image LoRA is quicker. Allow a long ceiling so the
# blocking POST doesn't read-timeout while the server is still training.
train_timeout = 24 * 3600 if target == "video" else 4 * 3600
return self._post("/v1/loras/train", body, timeout=train_timeout)
def list_loras(self) -> list: def list_loras(self) -> list:
try: try:
...@@ -724,9 +733,14 @@ class CoderAIClient: ...@@ -724,9 +733,14 @@ class CoderAIClient:
except Exception: except Exception:
return [] return []
def lora_progress(self) -> dict: def lora_progress(self, job: str = None, session: str = None) -> dict:
try: try:
return self._get("/v1/loras/progress") q = ""
if job:
q = f"?job={urllib.parse.quote(job)}"
elif session:
q = f"?session={urllib.parse.quote(session)}"
return self._get(f"/v1/loras/progress{q}")
except Exception: except Exception:
return {} return {}
...@@ -1381,6 +1395,51 @@ def _load_json_map(path: Path) -> dict: ...@@ -1381,6 +1395,51 @@ def _load_json_map(path: Path) -> dict:
return {} return {}
# ── Training session token + per-LoRA server job tracking ─────────────────────
# A stable per-output session token lets the server tag this client's training
# jobs so we (and only we) can recover them after a restart — no spillover into
# another user's concurrent run. Per-LoRA server job_ids are persisted so that
# after a township restart we re-attach to a still-running server job instead of
# launching a duplicate (the server keeps training regardless of this client).
def _session_token(out_dir: Path) -> str:
p = Path(out_dir) / ".train_session"
try:
if p.exists():
tok = p.read_text().strip()
if tok:
return tok
except Exception:
pass
import uuid as _uuid
tok = f"township-{_uuid.uuid4().hex[:12]}"
try:
p.write_text(tok)
except Exception:
pass
return tok
def _train_jobs_file(out_dir: Path) -> Path:
return Path(out_dir) / "lora_train_jobs.json"
def _get_lora_job_id(out_dir: Path, lora_name: str):
return _load_json_map(_train_jobs_file(out_dir)).get(lora_name)
def _set_lora_job_id(out_dir: Path, lora_name: str, job_id):
f = _train_jobs_file(out_dir)
m = _load_json_map(f)
if job_id:
m[lora_name] = job_id
else:
m.pop(lora_name, None)
try:
f.write_text(json.dumps(m, indent=2))
except Exception:
pass
def _lora_specs_for(fighters: list, lora_map: dict, weight: float) -> list: def _lora_specs_for(fighters: list, lora_map: dict, weight: float) -> list:
"""Build the `loras` request list for the fighters appearing in a clip.""" """Build the `loras` request list for the fighters appearing in a clip."""
specs = [] specs = []
...@@ -2525,12 +2584,19 @@ def launch_web_ui(default_args): ...@@ -2525,12 +2584,19 @@ def launch_web_ui(default_args):
f"'{lora_name}' ({steps} steps, rank {rank})" f"'{lora_name}' ({steps} steps, rank {rank})"
+ (f" against {model}" if is_video else "") + "…") + (f" against {model}" if is_video else "") + "…")
# Run the blocking train call in an inner thread; poll progress here. # Kick the training off ASYNCHRONOUSLY (wait=False): the server runs
result, err = {}, {} # it independently and we poll its job record. This (a) avoids HTTP
def _do(): # read-timeouts on multi-hour video trainings, and (b) survives a
try: # township restart — the server keeps training and we re-attach by
# job_id rather than launching a duplicate. The session token tags the
# job as ours so recovery never picks up another client's job.
session = _session_token(out_dir)
_ACTIVE = ("queued", "preparing", "training", "saving")
def _kickoff():
kwargs = dict(name=lora_name, base_model=model, kwargs = dict(name=lora_name, base_model=model,
steps=int(steps), rank=int(rank)) steps=int(steps), rank=int(rank),
wait=False, session=session)
if is_video: if is_video:
kwargs["target"] = "video" kwargs["target"] = "video"
else: else:
...@@ -2538,54 +2604,94 @@ def launch_web_ui(default_args): ...@@ -2538,54 +2604,94 @@ def launch_web_ui(default_args):
if _tbm: if _tbm:
kwargs["train_base_model"] = _tbm kwargs["train_base_model"] = _tbm
kwargs[kind] = name kwargs[kind] = name
result["res"] = client.train_lora(**kwargs) resp = client.train_lora(**kwargs)
except Exception as e: jid = resp.get("job_id")
err["e"] = str(e) _set_lora_job_id(out_dir, lora_name, jid)
t = threading.Thread(target=_do, daemon=True) return jid
t.start()
# Re-attach to an existing server job if we have one recorded and the
# server still knows it (running OR finished while we were away).
job = _get_lora_job_id(out_dir, lora_name)
if job:
pj = client.lora_progress(job=job)
jstatus = (pj.get("status") or "").strip()
if jstatus == "done" and pj.get("path"):
_web_log(f" ↻ Re-attached: '{lora_name}' already trained.")
elif jstatus in _ACTIVE:
_web_log(f" ↻ Re-attached to running job for '{lora_name}'.")
else:
# interrupted / error / unknown → resubmit (server resumes
# from its last on-disk checkpoint if one exists).
_web_log(f" ↻ Previous job for '{lora_name}' was "
f"'{jstatus or 'lost'}' — resubmitting (resumes from "
f"checkpoint)…")
job = _kickoff()
else:
job = _kickoff()
_start_ts = time.time() _start_ts = time.time()
_prog(6, "preparing…") _prog(6, "preparing…")
while t.is_alive(): path = None
err_msg = None
_resubmits = 0
while True:
time.sleep(1.5) time.sleep(1.5)
_elapsed = int(time.time() - _start_ts) _elapsed = int(time.time() - _start_ts)
_mm, _ss = divmod(_elapsed, 60) _mm, _ss = divmod(_elapsed, 60)
_et = f"{_mm}m{_ss:02d}s" if _mm else f"{_ss}s" _et = f"{_mm}m{_ss:02d}s" if _mm else f"{_ss}s"
try: try:
p = client.lora_progress() p = client.lora_progress(job=job)
except Exception: except Exception:
# Server busy (large video models can load for many minutes, # Server busy (large video models can load for many minutes,
# starving the progress endpoint) — keep the UI visibly alive. # starving the progress endpoint) — keep the UI visibly alive.
_prog(6, f"preparing — loading model… ({_et})") _prog(6, f"preparing — loading model… ({_et})")
continue continue
# The server has ONE global training progress (jobs run one at a
# time via the queue). Only mirror it onto THIS card when it's
# reporting OUR LoRA; otherwise another job is training and we're
# still queued — show that instead of its progress.
pname = p.get("name") or ""
if pname and pname != lora_name:
_prog(4, f"queued — '{pname}' training first… ({_et})")
continue
status = (p.get("status") or "").strip() status = (p.get("status") or "").strip()
if status == "done":
path = p.get("path")
break
if status == "error":
err_msg = p.get("message") or "training failed"
break
if status in ("interrupted", "unknown"):
# Server restarted (or forgot the job) mid-train. Resubmit
# once to resume from checkpoint; give up if it keeps dying.
if _resubmits < 2:
_resubmits += 1
_web_log(f" ↻ Server job '{status}' — resubmitting "
f"'{lora_name}' (resume #{_resubmits})…")
try:
job = _kickoff()
except Exception as e:
err_msg = f"resubmit failed: {e}"
break
_prog(6, f"resuming after interruption… ({_et})")
continue
err_msg = f"training {status} and could not be resumed"
break
total = p.get("total") or steps total = p.get("total") or steps
step = p.get("step") or 0 step = p.get("step") or 0
if status in ("preparing", "saving") or not step: if status == "queued":
# No training steps yet (model loading / encoding / saving). # Our job is admitted but another training holds the GPU.
_prog(4, f"queued — waiting for GPU… ({_et})")
elif status in ("preparing", "saving") or not step:
_prog(6, (p.get("message") or status or "preparing") _prog(6, (p.get("message") or status or "preparing")
+ f" ({_et})") + f" ({_et})")
else: else:
pct = 6 + int(90 * step / max(1, total)) pct = 6 + int(90 * step / max(1, total))
_prog(pct, p.get("message") or status or "training") _prog(pct, p.get("message") or status or "training")
t.join()
if err: if err_msg:
_web_log(f" ✗ LoRA training failed for {name}: {err['e']}") _web_log(f" ✗ LoRA training failed for {name}: {err_msg}")
_fail(err["e"]) _set_lora_job_id(out_dir, lora_name, None)
_fail(err_msg)
return return
res = result.get("res") or {}
path = res.get("path")
if not path: if not path:
_fail(f"training returned no path: {res}") _set_lora_job_id(out_dir, lora_name, None)
_fail("training returned no path")
return return
# Done — clear the recorded job so a future retrain starts clean.
_set_lora_job_id(out_dir, lora_name, None)
# Record the trained LoRA in the on-disk map so video/keyframe runs # Record the trained LoRA in the on-disk map so video/keyframe runs
# reuse it. Image LoRAs → loras.json/env_loras.json (flat name→path). # reuse it. Image LoRAs → loras.json/env_loras.json (flat name→path).
...@@ -2721,6 +2827,78 @@ def launch_web_ui(default_args): ...@@ -2721,6 +2827,78 @@ def launch_web_ui(default_args):
lw = float(getattr(default_args, "lora_weight", 0.85)) lw = float(getattr(default_args, "lora_weight", 0.85))
elw = float(getattr(default_args, "env_lora_weight", 0.8)) elw = float(getattr(default_args, "env_lora_weight", 0.8))
# ── Regenerate keyframes (image model) ─────────────────────────────
# Deletes the targeted keyframe PNG(s) then regenerates them so a
# subsequent clip re-render uses fresh keyframes (e.g. after a LoRA
# retrain or profile edit). Does NOT re-render the video itself —
# click "Re-render" afterwards to rebuild the clip from the new
# keyframe. scope "keyframes" = whole match; "keyframe" = one clip
# (idx) or one outcome (fighter+outcome).
if scope in ("keyframes", "keyframe"):
image_model = getattr(default_args, "image_model", None)
if not image_model:
try:
image_model = pick_model(client, "image", None)
except Exception as e:
_fail(f"no image model available: {e}")
return
kdir = vdir / "keyframes"
m = next((x for x in fight_plan
if x.get("match_name") == match_name), None)
# Build the (filtered) fight/outcome plans + the list of stems to
# drop so _generate_keyframes (which keeps existing PNGs) remakes
# exactly those.
fp, op, stems = [], [], []
if scope == "keyframe" and params.get("fighter") and params.get("outcome"):
fr, oc = params.get("fighter"), params.get("outcome")
o = next((x for x in outcome_plan
if x.get("fighter") == fr and x.get("outcome") == oc
and (x.get("match_name") in (None, match_name))), None)
if not o:
_fail("outcome not found in prompts.json")
return
op = [o]
stems = [_clip_stem_outcome(fr, oc, o.get("match_name"))]
else:
if not m:
_fail("match not found in prompts.json — render it first")
return
mm = dict(m)
if scope == "keyframe":
idx = int(params.get("idx"))
mm["clips"] = [c for c in m["clips"] if int(c["idx"]) == idx]
if not mm["clips"]:
_fail("clip not found in prompts.json")
return
fp = [mm]
stems = [_clip_stem_fight(match_name, c["idx"]) for c in mm["clips"]]
_set_items([f"keyframe {s}" for s in stems])
for i, s in enumerate(stems):
_item(i, "start")
try:
(kdir / f"{s}.png").unlink()
except Exception:
pass
_prog(10, f"regenerating {len(stems)} keyframe(s)…")
try:
_generate_keyframes(
client, image_model, kdir, fp, op,
consistency | {"keyframe"}, lora_map,
float(getattr(default_args, "character_strength", 0.7)),
int(getattr(default_args, "keyframe_steps", 28)),
getattr(default_args, "keyframe_size", "512x512"), lw,
env_lora_map=env_lora_map, env_lora_weight=elw)
except Exception as e:
_fail(f"keyframe regeneration failed: {e}")
return
# Mark each item done/failed by whether its PNG now exists.
for i, s in enumerate(stems):
_item(i, "end", (kdir / f"{s}.png").exists())
made = sum(1 for s in stems if (kdir / f"{s}.png").exists())
_done(f"regenerated {made}/{len(stems)} keyframe(s) — "
f"now click Re-render to rebuild the video(s)")
return
# ── Enhance: upscale / raise-FPS existing finals + outcome videos ── # ── Enhance: upscale / raise-FPS existing finals + outcome videos ──
if scope == "enhance": if scope == "enhance":
try: try:
...@@ -3958,10 +4136,14 @@ async function reMatch(ev, scope, params){ ...@@ -3958,10 +4136,14 @@ async function reMatch(ev, scope, params){
'clip':'Re-render this single clip?', 'clip':'Re-render this single clip?',
'reassemble':'Reassemble the final short/long videos from the existing clips? (fast, no model)', 'reassemble':'Reassemble the final short/long videos from the existing clips? (fast, no model)',
'outcomes':'Re-render all output clips for this fighter (uses the video model)?', 'outcomes':'Re-render all output clips for this fighter (uses the video model)?',
'outcome':'Re-render this output clip?'}; 'outcome':'Re-render this output clip?',
'keyframes':'Regenerate ALL keyframe images for this match (uses the image model)? Existing keyframes are replaced; the clip videos are NOT re-rendered — click Re-render afterwards.',
'keyframe':'Regenerate this keyframe image (uses the image model)? The clip video is NOT re-rendered — click Re-render afterwards.'};
const kf=(scope==='keyframes'||scope==='keyframe');
if(!(await uiConfirm(labels[scope]||'Proceed?', if(!(await uiConfirm(labels[scope]||'Proceed?',
{title:'Regenerate', okText:(scope==='reassemble'?'Reassemble':'Re-render'), {title:(kf?'Regenerate keyframes':'Regenerate'),
danger:(scope!=='reassemble')})))return; okText:(scope==='reassemble'?'Reassemble':(kf?'Regenerate':'Re-render')),
danger:(scope!=='reassemble'&&!kf)})))return;
const stEl=_findStatus(ev); const stEl=_findStatus(ev);
const setSt=(c,t)=>{ if(stEl){ stEl.style.color=c; stEl.textContent=t; } }; const setSt=(c,t)=>{ if(stEl){ stEl.style.color=c; stEl.textContent=t; } };
const fd=new FormData(); fd.append('scope',scope); const fd=new FormData(); fd.append('scope',scope);
...@@ -4003,7 +4185,9 @@ async function delVid(ev, scope, params){ ...@@ -4003,7 +4185,9 @@ async function delVid(ev, scope, params){
'final':'Delete this assembled video file?', 'final':'Delete this assembled video file?',
'match':'Delete ALL video files for this match (clips + finals)? The plan/prompts are kept so you can re-render.', 'match':'Delete ALL video files for this match (clips + finals)? The plan/prompts are kept so you can re-render.',
'output':'Delete this output video file?', 'output':'Delete this output video file?',
'outputs':'Delete ALL output video files for this fighter?'}; 'outputs':'Delete ALL output video files for this fighter?',
'keyframes':'Clear ALL keyframe images for this match? The next re-render will run keyframe-free until you regenerate them.',
'keyframe':'Clear this keyframe image? The next re-render of it will run keyframe-free until you regenerate it.'};
if(!(await uiConfirm(labels[scope]||'Delete?',{title:'Remove videos', okText:'Delete', danger:true})))return; if(!(await uiConfirm(labels[scope]||'Delete?',{title:'Remove videos', okText:'Delete', danger:true})))return;
const stEl=_findStatus(ev); const stEl=_findStatus(ev);
const setSt=(c,t)=>{ if(stEl){ stEl.style.color=c; stEl.textContent=t; } }; const setSt=(c,t)=>{ if(stEl){ stEl.style.color=c; stEl.textContent=t; } };
...@@ -4226,7 +4410,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs); ...@@ -4226,7 +4410,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs);
f'<div class=card style="width:230px">' f'<div class=card style="width:230px">'
f' <div class=hint style="display:flex;justify-content:space-between;align-items:center">' f' <div class=hint style="display:flex;justify-content:space-between;align-items:center">'
f'<span>clip {idx:02d}</span>' f'<span>clip {idx:02d}</span>'
f'<span><a href="#" style="color:#7eb8f7" ' f'<span>'
f'<a href="#" style="color:#c79bf0" title="Regenerate this keyframe (image model)" '
f'onclick="reMatch(event,\'keyframe\',{{match:\'{_esc(name)}\',idx:\'{idx}\'}})">kf↻</a> '
f'<a href="#" style="color:#7eb8f7" '
f'onclick="reMatch(event,\'clip\',{{match:\'{_esc(name)}\',idx:\'{idx}\'}})">re-render</a> {rm_html}</span></div>' f'onclick="reMatch(event,\'clip\',{{match:\'{_esc(name)}\',idx:\'{idx}\'}})">re-render</a> {rm_html}</span></div>'
f' {vid_html}' f' {vid_html}'
f' <textarea data-clip="{idx}" rows=2 style="margin-top:.3rem">{_esc(c.get("prompt",""))}</textarea>' f' <textarea data-clip="{idx}" rows=2 style="margin-top:.3rem">{_esc(c.get("prompt",""))}</textarea>'
...@@ -4260,7 +4447,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs); ...@@ -4260,7 +4447,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs);
f'<div class=card style="width:215px">' f'<div class=card style="width:215px">'
f' <div class=hint style="display:flex;justify-content:space-between;align-items:center">' f' <div class=hint style="display:flex;justify-content:space-between;align-items:center">'
f'<span>{_esc(oc)}</span>' f'<span>{_esc(oc)}</span>'
f'<span><a href="#" style="color:#7eb8f7" ' f'<span>'
f'<a href="#" style="color:#c79bf0" title="Regenerate this keyframe (image model)" '
f'onclick="reMatch(event,\'keyframe\',{{match:\'{_esc(name)}\',fighter:\'{_esc(fr)}\',outcome:\'{_esc(oc)}\'}})">kf↻</a> '
f'<a href="#" style="color:#7eb8f7" '
f'onclick="reMatch(event,\'outcome\',{{match:\'{_esc(name)}\',fighter:\'{_esc(fr)}\',outcome:\'{_esc(oc)}\'}})">{act}</a> {rm}</span></div>' f'onclick="reMatch(event,\'outcome\',{{match:\'{_esc(name)}\',fighter:\'{_esc(fr)}\',outcome:\'{_esc(oc)}\'}})">{act}</a> {rm}</span></div>'
f' {vid}' f' {vid}'
f' <textarea data-outc="{_esc(fr)}|{_esc(oc)}" rows=2 style="margin-top:.3rem">{_esc(o.get("prompt",""))}</textarea>' f' <textarea data-outc="{_esc(fr)}|{_esc(oc)}" rows=2 style="margin-top:.3rem">{_esc(o.get("prompt",""))}</textarea>'
...@@ -4303,6 +4493,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs); ...@@ -4303,6 +4493,10 @@ document.addEventListener('DOMContentLoaded', resumeMatchJobs);
f'onclick="reMatch(event,\'reassemble\',{{match:\'{_esc(name)}\'}})">🎞 Reassemble finals</button>' f'onclick="reMatch(event,\'reassemble\',{{match:\'{_esc(name)}\'}})">🎞 Reassemble finals</button>'
f' <button class="btn btn-secondary" style="font-size:.82rem;padding:.35rem .9rem" ' f' <button class="btn btn-secondary" style="font-size:.82rem;padding:.35rem .9rem" '
f'onclick="reMatch(event,\'outcomes\',{{match:\'{_esc(name)}\'}})">♻ Re-render all outcomes</button>' f'onclick="reMatch(event,\'outcomes\',{{match:\'{_esc(name)}\'}})">♻ Re-render all outcomes</button>'
f' <button class="btn btn-secondary" style="font-size:.82rem;padding:.35rem .9rem" '
f'onclick="reMatch(event,\'keyframes\',{{match:\'{_esc(name)}\'}})">🖼 Regenerate keyframes</button>'
f' <button class="btn btn-danger" style="font-size:.82rem;padding:.35rem .9rem" '
f'onclick="delVid(event,\'keyframes\',{{match:\'{_esc(name)}\'}})">🧹 Clear keyframes</button>'
f' <button class="btn btn-danger" style="font-size:.82rem;padding:.35rem .9rem" ' f' <button class="btn btn-danger" style="font-size:.82rem;padding:.35rem .9rem" '
f'onclick="delVid(event,\'match\',{{match:\'{_esc(name)}\'}})">🗑 Remove all videos</button>' f'onclick="delVid(event,\'match\',{{match:\'{_esc(name)}\'}})">🗑 Remove all videos</button>'
f' </div>' f' </div>'
...@@ -5110,6 +5304,35 @@ async function pollJob(){ ...@@ -5110,6 +5304,35 @@ async function pollJob(){
if not _safe(fn) or not fn.endswith(".mp4"): if not _safe(fn) or not fn.endswith(".mp4"):
self._send(400, "application/json", _j.dumps({"error": "invalid file"})); return self._send(400, "application/json", _j.dumps({"error": "invalid file"})); return
_rm(vdir / fn) _rm(vdir / fn)
elif scope == "keyframes":
# Clear ALL keyframe PNGs for a match (its clips + its outcomes).
mn = _fv("match")
if not _safe(mn):
self._send(400, "application/json", _j.dumps({"error": "invalid match"})); return
kdir = vdir / "keyframes"
for p in kdir.glob(f"{mn}_clip*.png"):
_rm(p)
_, _, _, _matches_map, _ = _scan_matches()
for (_f, _o, _p) in _matches_map.get(mn, {}).get("outcomes", []):
_rm(kdir / f"{Path(_p).stem}.png")
elif scope == "keyframe":
# Clear one clip's (match+idx) or one outcome's (fighter+outcome) keyframe.
mn = _fv("match")
kdir = vdir / "keyframes"
fr, oc = _fv("fighter"), _fv("outcome")
if fr and oc:
if not _safe(fr) or not _safe(oc):
self._send(400, "application/json", _j.dumps({"error": "invalid args"})); return
stem = f"{mn}_{fr}_{oc}" if (mn and _safe(mn)) else f"{fr}_{oc}"
_rm(kdir / f"{stem}.png")
else:
idx = _fv("idx")
if not _safe(mn):
self._send(400, "application/json", _j.dumps({"error": "invalid match"})); return
try:
_rm(kdir / f"{mn}_clip{int(idx):02d}.png")
except ValueError:
self._send(400, "application/json", _j.dumps({"error": "invalid idx"})); return
else: else:
self._send(400, "application/json", _j.dumps({"error": "invalid scope"})); return self._send(400, "application/json", _j.dumps({"error": "invalid scope"})); return
self._send(200, "application/json", _j.dumps({"ok": True, "removed": removed})) self._send(200, "application/json", _j.dumps({"ok": True, "removed": removed}))
......
...@@ -72,6 +72,54 @@ def safe_slug(value: str) -> str: ...@@ -72,6 +72,54 @@ def safe_slug(value: str) -> str:
return value.strip("_-. ") or f"item_{uuid.uuid4().hex[:8]}" return value.strip("_-. ") or f"item_{uuid.uuid4().hex[:8]}"
def load_json_map(path: Path) -> dict[str, Any]:
try:
if path.exists():
data = json.loads(path.read_text(encoding="utf-8"))
return data if isinstance(data, dict) else {}
except Exception:
pass
return {}
def session_token(out_dir: Path) -> str:
path = out_dir / ".train_session"
try:
if path.exists():
token = path.read_text(encoding="utf-8").strip()
if token:
return token
except Exception:
pass
token = f"videogen-{uuid.uuid4().hex[:12]}"
try:
path.write_text(token, encoding="utf-8")
except Exception:
pass
return token
def lora_jobs_file(out_dir: Path) -> Path:
return out_dir / "lora_train_jobs.json"
def get_lora_job_id(out_dir: Path, lora_name: str) -> str | None:
return load_json_map(lora_jobs_file(out_dir)).get(lora_name)
def set_lora_job_id(out_dir: Path, lora_name: str, job_id: str | None) -> None:
path = lora_jobs_file(out_dir)
data = load_json_map(path)
if job_id:
data[lora_name] = job_id
else:
data.pop(lora_name, None)
try:
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
except Exception:
pass
def data_uri_for_file(path: Path, mime: str | None = None) -> str: def data_uri_for_file(path: Path, mime: str | None = None) -> str:
if mime is None: if mime is None:
mime = mimetypes.guess_type(str(path))[0] or "application/octet-stream" mime = mimetypes.guess_type(str(path))[0] or "application/octet-stream"
...@@ -150,6 +198,20 @@ def mux_background_audio(video_path: Path, audio_path: Path, out_path: Path, mus ...@@ -150,6 +198,20 @@ def mux_background_audio(video_path: Path, audio_path: Path, out_path: Path, mus
]) ])
def extract_video_frames(video_path: Path, out_dir: Path, max_frames: int = 6) -> list[Path]:
out_dir.mkdir(parents=True, exist_ok=True)
duration = max(0.1, video_duration(video_path))
count = max(1, int(max_frames))
frames = []
for idx in range(count):
ts = duration * (idx + 0.5) / count
out = out_dir / f"{video_path.stem}_frame_{idx:02d}.png"
run_cmd([ffmpeg(), "-y", "-ss", f"{ts:.3f}", "-i", str(video_path), "-frames:v", "1", str(out)])
if out.exists() and out.stat().st_size:
frames.append(out)
return frames
class CoderAIClient: class CoderAIClient:
def __init__(self, base_url: str, api_key: str | None = None, timeout: int = 7200): def __init__(self, base_url: str, api_key: str | None = None, timeout: int = 7200):
self.base = base_url.rstrip("/") self.base = base_url.rstrip("/")
...@@ -158,8 +220,8 @@ class CoderAIClient: ...@@ -158,8 +220,8 @@ class CoderAIClient:
if api_key: if api_key:
self.session.headers["Authorization"] = f"Bearer {api_key}" self.session.headers["Authorization"] = f"Bearer {api_key}"
def _get(self, path: str) -> dict[str, Any]: def _get(self, path: str, timeout: int = 60) -> dict[str, Any]:
resp = self.session.get(f"{self.base}{path}", timeout=60) resp = self.session.get(f"{self.base}{path}", timeout=timeout)
if not resp.ok: if not resp.ok:
raise RuntimeError(f"GET {path} -> {resp.status_code}: {resp.text[:800]}") raise RuntimeError(f"GET {path} -> {resp.status_code}: {resp.text[:800]}")
return resp.json() return resp.json()
...@@ -170,6 +232,12 @@ class CoderAIClient: ...@@ -170,6 +232,12 @@ class CoderAIClient:
raise RuntimeError(f"POST {path} -> {resp.status_code}: {resp.text[:1200]}") raise RuntimeError(f"POST {path} -> {resp.status_code}: {resp.text[:1200]}")
return resp.json() return resp.json()
def _post_multipart(self, path: str, data: dict[str, Any], files: dict[str, Any]) -> dict[str, Any]:
resp = self.session.post(f"{self.base}{path}", data=data, files=files, timeout=self.timeout)
if not resp.ok:
raise RuntimeError(f"POST {path} -> {resp.status_code}: {resp.text[:1200]}")
return resp.json()
def _patch(self, path: str, body: dict[str, Any]) -> dict[str, Any]: def _patch(self, path: str, body: dict[str, Any]) -> dict[str, Any]:
resp = self.session.patch(f"{self.base}{path}", json=body, timeout=self.timeout) resp = self.session.patch(f"{self.base}{path}", json=body, timeout=self.timeout)
if not resp.ok: if not resp.ok:
...@@ -206,6 +274,51 @@ class CoderAIClient: ...@@ -206,6 +274,51 @@ class CoderAIClient:
except Exception: except Exception:
return [] return []
def list_voices(self) -> list[dict[str, Any]]:
try:
return self._get("/v1/audio/voices").get("voices", [])
except Exception:
return []
def list_loras(self) -> list[dict[str, Any]]:
try:
return self._get("/v1/loras").get("loras", [])
except Exception:
return []
def create_voice(self, name: str, description: str, transcript: str, audio_path: Path) -> dict[str, Any]:
mime = mimetypes.guess_type(audio_path.name)[0] or "audio/wav"
with audio_path.open("rb") as handle:
return self._post_multipart(
"/v1/audio/voices",
{"name": name, "description": description, "transcript": transcript},
{"audio": (audio_path.name, handle, mime)},
)
def extract_voice(self, name: str, description: str, transcript: str, media_path: Path) -> dict[str, Any]:
mime = mimetypes.guess_type(media_path.name)[0] or "application/octet-stream"
key = "video" if mime.startswith("video/") or media_path.suffix.lower() in {".mp4", ".mov", ".webm", ".mkv"} else "audio"
return self._post("/v1/audio/voices/extract", {
"name": name,
"description": description,
"transcript": transcript,
key: data_uri_for_file(media_path, mime),
})
def train_lora(self, body: dict[str, Any]) -> dict[str, Any]:
return self._post("/v1/loras/train", body)
def lora_progress(self, job: str | None = None, session: str | None = None) -> dict[str, Any]:
query = ""
if job:
query = f"?job={urllib.parse.quote(job)}"
elif session:
query = f"?session={urllib.parse.quote(session)}"
try:
return self._get(f"/v1/loras/progress{query}", timeout=30)
except Exception:
return {}
def get_profile_images(self, kind: str, name: str) -> list[str]: def get_profile_images(self, kind: str, name: str) -> list[str]:
plural = "characters" if kind == "character" else "environments" plural = "characters" if kind == "character" else "environments"
try: try:
...@@ -424,7 +537,167 @@ class VideoGenApp: ...@@ -424,7 +537,167 @@ class VideoGenApp:
return { return {
"characters": [v for k, v in sorted(chars.items()) if k], "characters": [v for k, v in sorted(chars.items()) if k],
"environments": [v for k, v in sorted(envs.items()) if k], "environments": [v for k, v in sorted(envs.items()) if k],
"voices": self.client.list_voices(),
"loras": self.client.list_loras(),
}
def start_lora_job(self, payload: dict[str, Any]) -> str:
job_id = f"lora-{uuid.uuid4().hex[:10]}"
with self.lock:
self.jobs[job_id] = {"status": "queued", "progress": 0, "kind": "lora"}
thread = threading.Thread(target=self._lora_job, args=(job_id, payload), daemon=True)
thread.start()
return job_id
def _lora_job(self, job_id: str, payload: dict[str, Any]) -> None:
name = safe_slug(payload.get("name") or "movie_style")
base_model = payload.get("base_model") or self.args.video_model or pick_model(self.client.list_models(), "video_generation")
target = payload.get("target") or "video"
media_paths = [Path(p.strip()) for p in (payload.get("media_paths") or "").splitlines() if p.strip()]
style_prompt = payload.get("style_prompt") or f"{name} cinematic style"
instance_prompt = payload.get("instance_prompt") or f"in {name} style, {style_prompt}"
frames_per_video = int(payload.get("frames_per_video") or 6)
try:
if not base_model:
raise RuntimeError("No base model configured for LoRA training")
self._job_update(job_id, status="running", progress=5, message="collecting training media")
train_dir = self.out_dir / "lora_training" / name
frame_dir = train_dir / "frames"
train_dir.mkdir(parents=True, exist_ok=True)
images: list[str] = []
scene_manifest = []
for idx, path in enumerate(media_paths):
if not path.exists() or not path.is_file():
self.emit(f"Skipping missing LoRA media: {path}")
continue
ext = path.suffix.lower()
scene_manifest.append({"path": str(path), "description": payload.get("scene_description") or ""})
if ext in {".png", ".jpg", ".jpeg", ".webp"}:
images.append(data_uri_for_file(path, mimetypes.guess_type(path.name)[0] or "image/png"))
elif ext in {".mp4", ".mov", ".webm", ".mkv"}:
self.emit(f"Extracting frames from {path.name}")
for frame in extract_video_frames(path, frame_dir / safe_slug(path.stem), frames_per_video):
images.append(data_uri_for_file(frame, "image/png"))
if not images:
raise RuntimeError("No image frames found. Add image files or videos to train from.")
(train_dir / "manifest.json").write_text(json.dumps({"name": name, "style_prompt": style_prompt, "scenes": scene_manifest}, indent=2), encoding="utf-8")
self.emit(f"Training {target} LoRA '{name}' from {len(images)} image/frame sample(s)")
self._job_update(job_id, progress=25, message="training LoRA")
train_body = {
"name": name,
"base_model": base_model,
"train_base_model": payload.get("train_base_model") or None,
"target": target,
"images": images,
"instance_prompt": instance_prompt,
"steps": int(payload.get("steps") or 800),
"rank": int(payload.get("rank") or 16),
"learning_rate": float(payload.get("learning_rate") or 1e-4),
"resolution": int(payload.get("resolution") or 512),
"seed": int(payload.get("seed") or 42),
"num_frames": int(payload.get("num_frames") or 1),
"quantize_4bit": bool(payload.get("quantize_4bit", True)),
"wait": False,
"session": session_token(self.out_dir),
} }
path = self._attach_or_start_lora(job_id, name, train_body)
self.emit(f"LoRA ready: {name}" + (f" -> {path}" if path else ""))
set_lora_job_id(self.out_dir, name, None)
self._job_update(job_id, status="done", progress=100, message="done", name=name, path=path)
except Exception as exc:
self.emit(f"LoRA training failed: {exc}")
self._job_update(job_id, status="error", error=str(exc), message=str(exc))
def _attach_or_start_lora(self, ui_job_id: str, lora_name: str, train_body: dict[str, Any]) -> str | None:
active_states = {"queued", "preparing", "training", "saving"}
def kickoff() -> str:
resp = self.client.train_lora({k: v for k, v in train_body.items() if v is not None})
server_job = resp.get("job_id")
if not server_job:
raise RuntimeError(f"LoRA training did not return a job_id: {resp}")
set_lora_job_id(self.out_dir, lora_name, server_job)
return server_job
server_job = get_lora_job_id(self.out_dir, lora_name)
if server_job:
progress = self.client.lora_progress(job=server_job)
status = (progress.get("status") or "").strip()
if status == "done" and progress.get("path"):
self.emit(f"Re-attached: LoRA '{lora_name}' was already trained")
return progress.get("path")
if status in active_states:
self.emit(f"Re-attached to running LoRA job for '{lora_name}'")
else:
self.emit(f"Previous LoRA job for '{lora_name}' was '{status or 'lost'}'; resubmitting to resume from checkpoint")
server_job = kickoff()
else:
server_job = kickoff()
started = time.time()
resubmits = 0
while True:
if self._is_cancelled(ui_job_id):
raise RuntimeError("cancelled")
time.sleep(2.0)
elapsed = int(time.time() - started)
mm, ss = divmod(elapsed, 60)
et = f"{mm}m{ss:02d}s" if mm else f"{ss}s"
progress = self.client.lora_progress(job=server_job)
status = (progress.get("status") or "").strip()
if status == "done":
return progress.get("path")
if status == "error":
raise RuntimeError(progress.get("message") or "LoRA training failed")
if status in {"interrupted", "unknown", ""}:
if resubmits < 2:
resubmits += 1
self.emit(f"LoRA job '{status or 'unknown'}' for '{lora_name}'; resubmitting resume #{resubmits}")
server_job = kickoff()
self._job_update(ui_job_id, progress=30, message=f"resuming after interruption ({et})")
continue
raise RuntimeError(f"training {status or 'unknown'} and could not be resumed")
total = progress.get("total") or train_body.get("steps") or 1
step = progress.get("step") or 0
if status == "queued":
pct = 25
msg = f"queued - waiting for GPU ({et})"
elif status in {"preparing", "saving"} or not step:
pct = 30 if status != "saving" else 94
msg = f"{progress.get('message') or status or 'preparing'} ({et})"
else:
pct = 30 + int(64 * step / max(1, total))
msg = progress.get("message") or f"training {step}/{total} ({et})"
self._job_update(ui_job_id, progress=min(98, pct), message=msg, server_job=server_job)
def start_voice_job(self, payload: dict[str, Any]) -> str:
job_id = f"voice-{uuid.uuid4().hex[:10]}"
with self.lock:
self.jobs[job_id] = {"status": "queued", "progress": 0, "kind": "voice"}
thread = threading.Thread(target=self._voice_job, args=(job_id, payload), daemon=True)
thread.start()
return job_id
def _voice_job(self, job_id: str, payload: dict[str, Any]) -> None:
name = safe_slug(payload.get("name") or "voice")
description = payload.get("description") or ""
transcript = payload.get("transcript") or ""
media_path = Path(payload.get("media_path") or "")
try:
self._job_update(job_id, status="running", progress=5, message="validating reference media")
if not media_path.exists() or not media_path.is_file():
raise RuntimeError(f"Reference audio/video file not found: {media_path}")
self.emit(f"Creating cloned voice profile '{name}' from {media_path.name}")
self._job_update(job_id, progress=25, message="uploading reference media")
if payload.get("extract", True):
self.client.extract_voice(name, description, transcript, media_path)
else:
self.client.create_voice(name, description, transcript, media_path)
self.emit(f"Voice profile ready: {name}")
self._job_update(job_id, status="done", progress=100, message="done", name=name)
except Exception as exc:
self.emit(f"Voice job failed: {exc}")
self._job_update(job_id, status="error", error=str(exc), message=str(exc))
def start_profile_job(self, payload: dict[str, Any]) -> str: def start_profile_job(self, payload: dict[str, Any]) -> str:
job_id = f"profile-{uuid.uuid4().hex[:10]}" job_id = f"profile-{uuid.uuid4().hex[:10]}"
...@@ -480,10 +753,32 @@ class VideoGenApp: ...@@ -480,10 +753,32 @@ class VideoGenApp:
job_id = f"movie-{uuid.uuid4().hex[:10]}" job_id = f"movie-{uuid.uuid4().hex[:10]}"
with self.lock: with self.lock:
self.jobs[job_id] = {"status": "queued", "progress": 0, "movie": payload.get("title") or "movie"} self.jobs[job_id] = {"status": "queued", "progress": 0, "movie": payload.get("title") or "movie"}
thread = threading.Thread(target=self._movie_job, args=(job_id, payload), daemon=True) target = self._movie_batch_job if int(payload.get("movie_count") or 1) > 1 else self._movie_job
thread = threading.Thread(target=target, args=(job_id, payload), daemon=True)
thread.start() thread.start()
return job_id return job_id
def _movie_batch_job(self, job_id: str, payload: dict[str, Any]) -> None:
count = max(1, int(payload.get("movie_count") or 1))
outputs = []
base_title = payload.get("title") or "untitled_movie"
for idx in range(count):
if self._is_cancelled(job_id):
self._job_update(job_id, status="error", error="cancelled", message="cancelled")
return
variant = dict(payload)
variant["movie_count"] = 1
variant["title"] = f"{base_title}_variant_{idx + 1:02d}"
self.emit(f"Rendering movie variant {idx + 1}/{count}")
self._movie_job(job_id, variant)
with self.lock:
job = dict(self.jobs.get(job_id, {}))
if job.get("status") == "error":
return
if job.get("output_url"):
outputs.append(job["output_url"])
self._job_update(job_id, status="done", progress=100, message=f"rendered {len(outputs)} movie variant(s)", outputs=outputs, output_url=outputs[-1] if outputs else None)
def _movie_job(self, job_id: str, payload: dict[str, Any]) -> None: def _movie_job(self, job_id: str, payload: dict[str, Any]) -> None:
title = payload.get("title") or "untitled_movie" title = payload.get("title") or "untitled_movie"
slug = safe_slug(title) slug = safe_slug(title)
...@@ -508,6 +803,7 @@ class VideoGenApp: ...@@ -508,6 +803,7 @@ class VideoGenApp:
height = int(payload.get("height") or 432) height = int(payload.get("height") or 432)
default_frames = int(payload.get("num_frames") or 32) default_frames = int(payload.get("num_frames") or 32)
use_keyframes = bool(payload.get("use_keyframes")) use_keyframes = bool(payload.get("use_keyframes"))
selected_loras = self._selected_loras(payload)
clip_paths: list[Path] = [] clip_paths: list[Path] = []
total = len(clips) total = len(clips)
self.emit(f"Starting movie '{title}' with {total} clip(s)") self.emit(f"Starting movie '{title}' with {total} clip(s)")
...@@ -540,15 +836,20 @@ class VideoGenApp: ...@@ -540,15 +836,20 @@ class VideoGenApp:
body["environment_profiles"] = environments body["environment_profiles"] = environments
if clip.get("camera_motion"): if clip.get("camera_motion"):
body["camera_motion"] = clip.get("camera_motion") body["camera_motion"] = clip.get("camera_motion")
if selected_loras:
body["loras"] = selected_loras
if clip.get("dialogues"): if clip.get("dialogues"):
body["dialogs"] = self._normalize_dialogues(clip.get("dialogues")) body["dialogs"] = self._normalize_dialogues(clip.get("dialogues"))
body["lip_sync"] = bool(clip.get("lip_sync", True)) body["lip_sync"] = bool(clip.get("lip_sync", True))
body["lip_sync_method"] = clip.get("lip_sync_method") or payload.get("lip_sync_method") or "wav2lip"
body["generate_subtitles"] = bool(clip.get("subtitles", True)) body["generate_subtitles"] = bool(clip.get("subtitles", True))
body["burn_subtitles"] = bool(clip.get("burn_subtitles", False)) body["burn_subtitles"] = bool(clip.get("burn_subtitles", False))
if clip.get("speech_text"): if clip.get("speech_text"):
body["tts_text"] = clip.get("speech_text") body["tts_text"] = clip.get("speech_text")
body["tts_voice"] = clip.get("speech_voice") or payload.get("default_voice") or "af_sarah" body["tts_voice"] = clip.get("speech_voice") or payload.get("default_voice") or "af_sarah"
body["tts_speed"] = float(clip.get("speech_speed") or 1.0)
body["lip_sync"] = bool(clip.get("lip_sync", True)) body["lip_sync"] = bool(clip.get("lip_sync", True))
body["lip_sync_method"] = clip.get("lip_sync_method") or payload.get("lip_sync_method") or "wav2lip"
body["add_audio"] = True body["add_audio"] = True
body["audio_type"] = "speech" body["audio_type"] = "speech"
if clip.get("music_prompt") or clip.get("sfx_prompt"): if clip.get("music_prompt") or clip.get("sfx_prompt"):
...@@ -600,6 +901,23 @@ class VideoGenApp: ...@@ -600,6 +901,23 @@ class VideoGenApp:
self.emit(f"Movie failed: {exc}") self.emit(f"Movie failed: {exc}")
self._job_update(job_id, status="error", error=str(exc), message=str(exc)) self._job_update(job_id, status="error", error=str(exc), message=str(exc))
def _selected_loras(self, payload: dict[str, Any]) -> list[dict[str, Any]]:
requested = payload.get("loras") or []
if not requested:
return []
by_name = {item.get("name"): item for item in self.client.list_loras() if item.get("name")}
out = []
for item in requested:
if isinstance(item, str):
name, weight = item, float(payload.get("lora_weight") or 0.8)
else:
name, weight = item.get("name") or item.get("model"), float(item.get("weight") or payload.get("lora_weight") or 0.8)
info = by_name.get(name, {})
model = info.get("path") or name
if model:
out.append({"model": model, "weight": weight, "name": name})
return out
def _build_clip_prompt(self, movie: dict[str, Any], clip: dict[str, Any], characters: list[str], environments: list[str]) -> str: def _build_clip_prompt(self, movie: dict[str, Any], clip: dict[str, Any], characters: list[str], environments: list[str]) -> str:
parts = [] parts = []
if movie.get("style"): if movie.get("style"):
...@@ -642,6 +960,7 @@ class VideoGenApp: ...@@ -642,6 +960,7 @@ class VideoGenApp:
"text": row.get("text"), "text": row.get("text"),
"start_time": row.get("start_time") if row.get("start_time") not in ("", None) else None, "start_time": row.get("start_time") if row.get("start_time") not in ("", None) else None,
"lip_sync": bool(row.get("lip_sync", True)), "lip_sync": bool(row.get("lip_sync", True)),
"lang": row.get("lang") or None,
"speed": float(row.get("speed") or 1.0), "speed": float(row.get("speed") or 1.0),
}) })
return out return out
...@@ -673,7 +992,7 @@ HTML_PAGE = r""" ...@@ -673,7 +992,7 @@ HTML_PAGE = r"""
</style> </style>
</head> </head>
<body> <body>
<header><h1>CoderAI VideoGen Studio</h1><div class="sub">Manage reusable characters and environments, write a multi-clip movie prompt, then render clips with speech, lip-sync, music, and sound effects.</div></header> <header><h1>CoderAI VideoGen Studio</h1><div class="sub">Manage reusable characters, environments, cloned voices, and multi-clip movies with realistic dialogue, lip-sync, music, and sound effects.</div></header>
<div class="wrap"> <div class="wrap">
<aside class="card"> <aside class="card">
<h2>Connection</h2> <h2>Connection</h2>
...@@ -681,7 +1000,7 @@ HTML_PAGE = r""" ...@@ -681,7 +1000,7 @@ HTML_PAGE = r"""
<label>Image model</label><select id="image_model"></select> <label>Image model</label><select id="image_model"></select>
<label>Video model</label><select id="video_model"></select> <label>Video model</label><select id="video_model"></select>
<label>Audio/Music model</label><select id="audio_model"></select> <label>Audio/Music model</label><select id="audio_model"></select>
<label>TTS voice id</label><input id="default_voice" value="af_sarah"> <label>Default voice / cloned profile</label><input id="default_voice" list="voice_list" value="af_sarah"><datalist id="voice_list"></datalist>
<hr style="border-color:var(--line);border-style:solid none none;margin:16px 0"> <hr style="border-color:var(--line);border-style:solid none none;margin:16px 0">
<h2>Live Log</h2> <h2>Live Log</h2>
<div class="log" id="log"></div> <div class="log" id="log"></div>
...@@ -691,11 +1010,12 @@ HTML_PAGE = r""" ...@@ -691,11 +1010,12 @@ HTML_PAGE = r"""
<div class="tabs"> <div class="tabs">
<button class="tab active" data-tab="profiles">Profiles</button> <button class="tab active" data-tab="profiles">Profiles</button>
<button class="tab" data-tab="movie">Movie Builder</button> <button class="tab" data-tab="movie">Movie Builder</button>
<button class="tab" data-tab="loras">Style LoRAs</button>
<button class="tab" data-tab="gallery">Gallery</button> <button class="tab" data-tab="gallery">Gallery</button>
</div> </div>
<section id="profiles" class="section active"> <section id="profiles" class="section active">
<h2>Characters and Environments</h2> <h2>Characters, Environments, and Voices</h2>
<div class="row"> <div class="row">
<div class="card"> <div class="card">
<h3>Create Character</h3> <h3>Create Character</h3>
...@@ -714,8 +1034,17 @@ HTML_PAGE = r""" ...@@ -714,8 +1034,17 @@ HTML_PAGE = r"""
<button class="btn" onclick="createProfile('environment')">Generate Environment</button> <button class="btn" onclick="createProfile('environment')">Generate Environment</button>
</div> </div>
</div> </div>
<div class="card" style="margin-top:12px">
<h3>Create Cloned Voice</h3>
<div class="muted">Use a clean 5-30 second reference clip plus an exact transcript for realistic cloned dialogue. Audio or video files are accepted.</div>
<div class="row"><div><label>Voice name</label><input id="voice_name" placeholder="alice_voice"></div><div><label>Reference audio/video path</label><input id="voice_media" placeholder="/path/to/reference.wav"></div></div>
<label>Description</label><input id="voice_desc" placeholder="Warm, calm adult voice; close mic">
<label>Exact reference transcript</label><textarea id="voice_transcript" placeholder="The exact words spoken in the reference clip. Required for best F5-TTS cloning."></textarea>
<button class="btn" onclick="createVoice()">Create Cloned Voice</button>
</div>
<h3>Saved Characters</h3><div class="grid" id="chars"></div> <h3>Saved Characters</h3><div class="grid" id="chars"></div>
<h3>Saved Environments</h3><div class="grid" id="envs"></div> <h3>Saved Environments</h3><div class="grid" id="envs"></div>
<h3>Saved Voices</h3><div class="grid" id="voices"></div>
</section> </section>
<section id="movie" class="section"> <section id="movie" class="section">
...@@ -724,6 +1053,8 @@ HTML_PAGE = r""" ...@@ -724,6 +1053,8 @@ HTML_PAGE = r"""
<div><label>Title</label><input id="title" value="my_little_movie"></div> <div><label>Title</label><input id="title" value="my_little_movie"></div>
<div><label>Visual style</label><input id="style" value="cinematic, coherent character identity, natural motion, detailed lighting"></div> <div><label>Visual style</label><input id="style" value="cinematic, coherent character identity, natural motion, detailed lighting"></div>
</div> </div>
<label>Style LoRAs to apply</label><select id="movie_loras" multiple size="5"></select>
<div class="row"><div><label>LoRA weight</label><input id="movie_lora_weight" type="number" step="0.05" value="0.8"></div><div><label>Movie variants to produce</label><input id="movie_count" type="number" min="1" value="1"></div></div>
<div class="row"> <div class="row">
<div><label>Width</label><input id="width" type="number" value="768"></div> <div><label>Width</label><input id="width" type="number" value="768"></div>
<div><label>Height</label><input id="height" type="number" value="432"></div> <div><label>Height</label><input id="height" type="number" value="432"></div>
...@@ -738,12 +1069,28 @@ HTML_PAGE = r""" ...@@ -738,12 +1069,28 @@ HTML_PAGE = r"""
</div> </div>
<label>Global negative prompt</label><input id="negative_prompt" value="flicker, morphing faces, extra limbs, low quality, unreadable text"> <label>Global negative prompt</label><input id="negative_prompt" value="flicker, morphing faces, extra limbs, low quality, unreadable text">
<label><input id="use_keyframes" type="checkbox" style="width:auto"> Generate keyframe image before each video clip for stronger character/environment consistency</label> <label><input id="use_keyframes" type="checkbox" style="width:auto"> Generate keyframe image before each video clip for stronger character/environment consistency</label>
<label>Lip-sync method</label><select id="lip_sync_method"><option value="wav2lip">wav2lip</option><option value="sadtalker">sadtalker</option></select>
<label>Final soundtrack prompt (optional; mixed under assembled movie)</label><textarea id="soundtrack_prompt" placeholder="tense orchestral pulse with soft percussion, no vocals"></textarea> <label>Final soundtrack prompt (optional; mixed under assembled movie)</label><textarea id="soundtrack_prompt" placeholder="tense orchestral pulse with soft percussion, no vocals"></textarea>
<div id="clips"></div> <div id="clips"></div>
<button class="btn secondary" onclick="addClip()">Add Clip</button> <button class="btn secondary" onclick="addClip()">Add Clip</button>
<button class="btn" onclick="startMovie()">Render Movie</button> <button class="btn" onclick="startMovie()">Render Movie</button>
</section> </section>
<section id="loras" class="section">
<h2>Style LoRA Training</h2>
<div class="muted">Train a reusable style LoRA from image/video scene references. Add stills or videos, describe the visual style, then choose the trained LoRA in Movie Builder to produce one or more movies in that style.</div>
<div class="card" style="margin-top:12px">
<div class="row"><div><label>LoRA name</label><input id="lora_name" placeholder="noir_rain_style"></div><div><label>Target</label><select id="lora_target"><option value="video">video</option><option value="image">image</option></select></div></div>
<label>Training media paths, one per line</label><textarea id="lora_media" placeholder="/path/to/style_reference_01.png&#10;/path/to/scene_clip.mp4"></textarea>
<label>Scene/style description</label><textarea id="lora_style" placeholder="Describe the scene language: lighting, lens, color palette, texture, movement, costumes, production design..."></textarea>
<label>Instance prompt/token</label><input id="lora_instance" placeholder="in noir_rain_style cinematic style">
<div class="row"><div><label>Steps</label><input id="lora_steps" type="number" value="800"></div><div><label>Rank</label><input id="lora_rank" type="number" value="16"></div></div>
<div class="row"><div><label>Resolution</label><input id="lora_resolution" type="number" value="512"></div><div><label>Frames per video sample</label><input id="lora_frames" type="number" value="6"></div></div>
<button class="btn" onclick="trainLora()">Train Style LoRA</button>
</div>
<h3>Available LoRAs</h3><div class="grid" id="loras_grid"></div>
</section>
<section id="gallery" class="section"> <section id="gallery" class="section">
<h2>Gallery</h2> <h2>Gallery</h2>
<button class="btn secondary" onclick="loadGallery()">Refresh Gallery</button> <button class="btn secondary" onclick="loadGallery()">Refresh Gallery</button>
...@@ -752,7 +1099,7 @@ HTML_PAGE = r""" ...@@ -752,7 +1099,7 @@ HTML_PAGE = r"""
</main> </main>
</div> </div>
<script> <script>
let models=[], profiles={characters:[], environments:[]}; let models=[], profiles={characters:[], environments:[], voices:[], loras:[]};
function $(id){return document.getElementById(id)} function $(id){return document.getElementById(id)}
function esc(s){return String(s||'').replace(/[&<>"']/g,m=>({'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;',"'":'&#39;'}[m]))} function esc(s){return String(s||'').replace(/[&<>"']/g,m=>({'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;',"'":'&#39;'}[m]))}
async function api(path, opts={}){let r=await fetch(path,{headers:{'Content-Type':'application/json'},...opts}); if(!r.ok) throw new Error(await r.text()); return await r.json()} async function api(path, opts={}){let r=await fetch(path,{headers:{'Content-Type':'application/json'},...opts}); if(!r.ok) throw new Error(await r.text()); return await r.json()}
...@@ -760,8 +1107,14 @@ function fillSelect(sel, cap, def){let s=$(sel); s.innerHTML=''; let filtered=mo ...@@ -760,8 +1107,14 @@ function fillSelect(sel, cap, def){let s=$(sel); s.innerHTML=''; let filtered=mo
async function loadModels(){let d=await api('/api/models'); models=d.models||[]; fillSelect('image_model','image_generation',d.defaults.image_model); fillSelect('video_model','video_generation',d.defaults.video_model); fillSelect('audio_model','audio_generation',d.defaults.audio_model); $('conn').textContent=`Connected: ${models.length} model(s)`} async function loadModels(){let d=await api('/api/models'); models=d.models||[]; fillSelect('image_model','image_generation',d.defaults.image_model); fillSelect('video_model','video_generation',d.defaults.video_model); fillSelect('audio_model','audio_generation',d.defaults.audio_model); $('conn').textContent=`Connected: ${models.length} model(s)`}
async function loadProfiles(){profiles=await api('/api/profiles'); renderProfiles()} async function loadProfiles(){profiles=await api('/api/profiles'); renderProfiles()}
function profileCard(p,kind){return `<div class="profile"><img src="${p.thumbnail||''}" onerror="this.style.display='none'"><div class="p"><b>${esc(p.name)}</b><div class="muted">${esc(p.description||'')}</div><span class="pill">${kind}</span><span class="pill">${p.image_count||0} refs</span><span class="pill">${p.local?'local':'server'}</span></div></div>`} function profileCard(p,kind){return `<div class="profile"><img src="${p.thumbnail||''}" onerror="this.style.display='none'"><div class="p"><b>${esc(p.name)}</b><div class="muted">${esc(p.description||'')}</div><span class="pill">${kind}</span><span class="pill">${p.image_count||0} refs</span><span class="pill">${p.local?'local':'server'}</span></div></div>`}
function renderProfiles(){$('chars').innerHTML=profiles.characters.map(p=>profileCard(p,'character')).join('')||'<div class="muted">No characters yet.</div>'; $('envs').innerHTML=profiles.environments.map(p=>profileCard(p,'environment')).join('')||'<div class="muted">No environments yet.</div>'; renderClipSelectors()} function voiceCard(v){return `<div class="profile"><div class="p"><b>${esc(v.name)}</b><div class="muted">${esc(v.description||'')}</div><span class="pill">cloned voice</span><span class="pill">${esc(v.audio_ext||'audio')}</span></div></div>`}
function loraCard(l){return `<div class="profile"><div class="p"><b>${esc(l.name)}</b><div class="muted">${esc(l.instance_prompt||l.path||'')}</div><span class="pill">${esc(l.target||'lora')}</span><span class="pill">rank ${esc(l.rank||'')}</span></div></div>`}
function renderVoiceList(){let voices=['af_sarah','af_sky','af_bella','am_adam','am_michael','en-US-JennyNeural',...(profiles.voices||[]).map(v=>v.name)].filter(Boolean); $('voice_list').innerHTML=[...new Set(voices)].map(v=>`<option value="${esc(v)}"></option>`).join('')}
function renderLoraList(){let opts=(profiles.loras||[]).map(l=>`<option value="${esc(l.name)}">${esc(l.name)}${l.target?' ('+esc(l.target)+')':''}</option>`).join(''); $('movie_loras').innerHTML=opts; $('loras_grid').innerHTML=(profiles.loras||[]).map(loraCard).join('')||'<div class="muted">No style LoRAs yet. Train one from media references.</div>'}
function renderProfiles(){$('chars').innerHTML=profiles.characters.map(p=>profileCard(p,'character')).join('')||'<div class="muted">No characters yet.</div>'; $('envs').innerHTML=profiles.environments.map(p=>profileCard(p,'environment')).join('')||'<div class="muted">No environments yet.</div>'; $('voices').innerHTML=(profiles.voices||[]).map(voiceCard).join('')||'<div class="muted">No cloned voices yet. Create one from a reference clip.</div>'; renderVoiceList(); renderLoraList(); renderClipSelectors()}
async function createProfile(kind){let isChar=kind==='character'; let name=$(isChar?'char_name':'env_name').value; let desc=$(isChar?'char_desc':'env_desc').value; let prompt=$(isChar?'char_prompt':'env_prompt').value||desc; let n=$(isChar?'char_n':'env_n').value; let [w,h]=($(isChar?'char_size':'env_size').value||'512x512').split('x').map(x=>parseInt(x,10)); let model=$('image_model').value; let d=await api('/api/profile/start',{method:'POST',body:JSON.stringify({kind,name,description:desc,prompt,model,n,width:w,height:h})}); watchJob(d.job_id);} async function createProfile(kind){let isChar=kind==='character'; let name=$(isChar?'char_name':'env_name').value; let desc=$(isChar?'char_desc':'env_desc').value; let prompt=$(isChar?'char_prompt':'env_prompt').value||desc; let n=$(isChar?'char_n':'env_n').value; let [w,h]=($(isChar?'char_size':'env_size').value||'512x512').split('x').map(x=>parseInt(x,10)); let model=$('image_model').value; let d=await api('/api/profile/start',{method:'POST',body:JSON.stringify({kind,name,description:desc,prompt,model,n,width:w,height:h})}); watchJob(d.job_id);}
async function createVoice(){let d=await api('/api/voice/start',{method:'POST',body:JSON.stringify({name:$('voice_name').value,description:$('voice_desc').value,transcript:$('voice_transcript').value,media_path:$('voice_media').value,extract:true})}); watchJob(d.job_id)}
async function trainLora(){let d=await api('/api/lora/start',{method:'POST',body:JSON.stringify({name:$('lora_name').value,base_model:$('video_model').value,target:$('lora_target').value,media_paths:$('lora_media').value,style_prompt:$('lora_style').value,scene_description:$('lora_style').value,instance_prompt:$('lora_instance').value,steps:$('lora_steps').value,rank:$('lora_rank').value,resolution:$('lora_resolution').value,frames_per_video:$('lora_frames').value,quantize_4bit:true})}); watchJob(d.job_id)}
function options(items){return items.map(p=>`<option value="${esc(p.name)}">${esc(p.name)}</option>`).join('')} function options(items){return items.map(p=>`<option value="${esc(p.name)}">${esc(p.name)}</option>`).join('')}
function addClip(data={}){let idx=document.querySelectorAll('.clip').length+1; let div=document.createElement('div'); div.className='clip'; div.innerHTML=`<div class="clip-head"><h3>Clip ${idx}</h3><button class="btn bad" onclick="this.closest('.clip').remove()">Remove</button></div> function addClip(data={}){let idx=document.querySelectorAll('.clip').length+1; let div=document.createElement('div'); div.className='clip'; div.innerHTML=`<div class="clip-head"><h3>Clip ${idx}</h3><button class="btn bad" onclick="this.closest('.clip').remove()">Remove</button></div>
<label>Clip title</label><input class="c_title" value="${esc(data.title||'Shot '+idx)}"> <label>Clip title</label><input class="c_title" value="${esc(data.title||'Shot '+idx)}">
...@@ -769,15 +1122,16 @@ function addClip(data={}){let idx=document.querySelectorAll('.clip').length+1; l ...@@ -769,15 +1122,16 @@ function addClip(data={}){let idx=document.querySelectorAll('.clip').length+1; l
<div class="row"><div><label>Characters</label><select class="c_chars" multiple size="5">${options(profiles.characters)}</select></div><div><label>Environments</label><select class="c_envs" multiple size="5">${options(profiles.environments)}</select></div></div> <div class="row"><div><label>Characters</label><select class="c_chars" multiple size="5">${options(profiles.characters)}</select></div><div><label>Environments</label><select class="c_envs" multiple size="5">${options(profiles.environments)}</select></div></div>
<div class="row"><div><label>Camera motion</label><input class="c_camera" placeholder="zoom-in, pan-left, handheld..."></div><div><label>Mood/action</label><input class="c_action" placeholder="what happens in this clip"></div></div> <div class="row"><div><label>Camera motion</label><input class="c_camera" placeholder="zoom-in, pan-left, handheld..."></div><div><label>Mood/action</label><input class="c_action" placeholder="what happens in this clip"></div></div>
<label>Speech text (simple one-speaker; optional)</label><input class="c_speech" placeholder="Line spoken in this shot"> <label>Speech text (simple one-speaker; optional)</label><input class="c_speech" placeholder="Line spoken in this shot">
<div class="row"><div><label>Speech voice</label><input class="c_voice" value="${esc($('default_voice').value||'af_sarah')}"></div><div><label><input class="c_lipsync" type="checkbox" checked style="width:auto"> Lip-sync speech/dialogue</label></div></div> <div class="row"><div><label>Speech voice / cloned profile</label><input class="c_voice" list="voice_list" value="${esc($('default_voice').value||'af_sarah')}"></div><div><label>Speech speed</label><input class="c_speed" type="number" step="0.05" value="1.0"></div></div>
<label><input class="c_lipsync" type="checkbox" checked style="width:auto"> Lip-sync speech/dialogue</label>
<h3>Multi-character dialogue</h3><div class="dialogues"></div><button class="btn secondary" onclick="addDialogue(this)">Add Dialogue Line</button> <h3>Multi-character dialogue</h3><div class="dialogues"></div><button class="btn secondary" onclick="addDialogue(this)">Add Dialogue Line</button>
<label>Music prompt for this clip</label><input class="c_music" placeholder="short local music bed for this clip"> <label>Music prompt for this clip</label><input class="c_music" placeholder="short local music bed for this clip">
<label>Sound effects prompt for this clip</label><input class="c_sfx" placeholder="rain, footsteps, door creak, city ambience"> <label>Sound effects prompt for this clip</label><input class="c_sfx" placeholder="rain, footsteps, door creak, city ambience">
</div>`; $('clips').appendChild(div); renderClipSelectors(div)} </div>`; $('clips').appendChild(div); renderClipSelectors(div)}
function renderClipSelectors(root=document){for(let s of root.querySelectorAll('.c_chars')){let vals=[...s.selectedOptions].map(o=>o.value); s.innerHTML=options(profiles.characters); for(let o of s.options) if(vals.includes(o.value)) o.selected=true} for(let s of root.querySelectorAll('.c_envs')){let vals=[...s.selectedOptions].map(o=>o.value); s.innerHTML=options(profiles.environments); for(let o of s.options) if(vals.includes(o.value)) o.selected=true} for(let s of root.querySelectorAll('.d_char')){let val=s.value; s.innerHTML='<option value="">(none)</option>'+options(profiles.characters); s.value=val}} function renderClipSelectors(root=document){for(let s of root.querySelectorAll('.c_chars')){let vals=[...s.selectedOptions].map(o=>o.value); s.innerHTML=options(profiles.characters); for(let o of s.options) if(vals.includes(o.value)) o.selected=true} for(let s of root.querySelectorAll('.c_envs')){let vals=[...s.selectedOptions].map(o=>o.value); s.innerHTML=options(profiles.environments); for(let o of s.options) if(vals.includes(o.value)) o.selected=true} for(let s of root.querySelectorAll('.d_char')){let val=s.value; s.innerHTML='<option value="">(none)</option>'+options(profiles.characters); s.value=val}}
function addDialogue(btn){let box=btn.closest('.clip').querySelector('.dialogues'); let row=document.createElement('div'); row.className='dialogue'; row.innerHTML=`<select class="d_char"><option value="">(none)</option>${options(profiles.characters)}</select><input class="d_voice" placeholder="voice/profile" value="${esc($('default_voice').value||'af_sarah')}"><input class="d_text" placeholder="dialogue text"><input class="d_start" placeholder="start s"><button class="btn bad" onclick="this.parentElement.remove()">x</button>`; box.appendChild(row)} function addDialogue(btn){let box=btn.closest('.clip').querySelector('.dialogues'); let row=document.createElement('div'); row.className='dialogue'; row.innerHTML=`<select class="d_char"><option value="">(none)</option>${options(profiles.characters)}</select><input class="d_voice" list="voice_list" placeholder="voice/profile" value="${esc($('default_voice').value||'af_sarah')}"><input class="d_text" placeholder="dialogue text"><input class="d_start" placeholder="start s"><input class="d_speed" type="number" step="0.05" value="1.0"><button class="btn bad" onclick="this.parentElement.remove()">x</button>`; box.appendChild(row)}
function selected(sel){return [...sel.selectedOptions].map(o=>o.value)} function selected(sel){return [...sel.selectedOptions].map(o=>o.value)}
function collectMovie(){let clips=[...document.querySelectorAll('.clip')].map(c=>({title:c.querySelector('.c_title').value,prompt:c.querySelector('.c_prompt').value,characters:selected(c.querySelector('.c_chars')),environments:selected(c.querySelector('.c_envs')),camera_motion:c.querySelector('.c_camera').value,action:c.querySelector('.c_action').value,speech_text:c.querySelector('.c_speech').value,speech_voice:c.querySelector('.c_voice').value,lip_sync:c.querySelector('.c_lipsync').checked,music_prompt:c.querySelector('.c_music').value,sfx_prompt:c.querySelector('.c_sfx').value,dialogues:[...c.querySelectorAll('.dialogue')].map(d=>({character:d.querySelector('.d_char').value,voice:d.querySelector('.d_voice').value,text:d.querySelector('.d_text').value,start_time:d.querySelector('.d_start').value}))})); return {title:$('title').value,style:$('style').value,image_model:$('image_model').value,video_model:$('video_model').value,audio_model:$('audio_model').value,default_voice:$('default_voice').value,width:+$('width').value,height:+$('height').value,fps:+$('fps').value,num_frames:+$('num_frames').value,steps:+$('steps').value,guidance_scale:+$('guidance_scale').value,negative_prompt:$('negative_prompt').value,use_keyframes:$('use_keyframes').checked,soundtrack_prompt:$('soundtrack_prompt').value,clips}} function collectMovie(){let clips=[...document.querySelectorAll('.clip')].map(c=>({title:c.querySelector('.c_title').value,prompt:c.querySelector('.c_prompt').value,characters:selected(c.querySelector('.c_chars')),environments:selected(c.querySelector('.c_envs')),camera_motion:c.querySelector('.c_camera').value,action:c.querySelector('.c_action').value,speech_text:c.querySelector('.c_speech').value,speech_voice:c.querySelector('.c_voice').value,speech_speed:c.querySelector('.c_speed').value,lip_sync:c.querySelector('.c_lipsync').checked,lip_sync_method:$('lip_sync_method').value,music_prompt:c.querySelector('.c_music').value,sfx_prompt:c.querySelector('.c_sfx').value,dialogues:[...c.querySelectorAll('.dialogue')].map(d=>({character:d.querySelector('.d_char').value,voice:d.querySelector('.d_voice').value,text:d.querySelector('.d_text').value,start_time:d.querySelector('.d_start').value,speed:d.querySelector('.d_speed').value,lip_sync:c.querySelector('.c_lipsync').checked}))})); return {title:$('title').value,style:$('style').value,image_model:$('image_model').value,video_model:$('video_model').value,audio_model:$('audio_model').value,default_voice:$('default_voice').value,lip_sync_method:$('lip_sync_method').value,width:+$('width').value,height:+$('height').value,fps:+$('fps').value,num_frames:+$('num_frames').value,steps:+$('steps').value,guidance_scale:+$('guidance_scale').value,negative_prompt:$('negative_prompt').value,use_keyframes:$('use_keyframes').checked,soundtrack_prompt:$('soundtrack_prompt').value,loras:selected($('movie_loras')).map(n=>({name:n,weight:+$('movie_lora_weight').value})),lora_weight:+$('movie_lora_weight').value,movie_count:+$('movie_count').value,clips}}
async function startMovie(){let d=await api('/api/movie/start',{method:'POST',body:JSON.stringify(collectMovie())}); watchJob(d.job_id)} async function startMovie(){let d=await api('/api/movie/start',{method:'POST',body:JSON.stringify(collectMovie())}); watchJob(d.job_id)}
async function watchJob(id){$('jobout').innerHTML=`<p>Job <span class="pill">${id}</span></p>`; let timer=setInterval(async()=>{let j=await api('/api/job/'+id); $('jobout').innerHTML=`<p><span class="pill">${esc(j.status)}</span> ${j.progress||0}% ${esc(j.message||'')}</p>`+(j.output_url?`<p><a href="${j.output_url}" target="_blank">Open output</a></p>`:'')+(j.error?`<p style="color:var(--bad)">${esc(j.error)}</p>`:''); if(j.status==='done'||j.status==='error'){clearInterval(timer); loadProfiles(); loadGallery()}},1500)} async function watchJob(id){$('jobout').innerHTML=`<p>Job <span class="pill">${id}</span></p>`; let timer=setInterval(async()=>{let j=await api('/api/job/'+id); $('jobout').innerHTML=`<p><span class="pill">${esc(j.status)}</span> ${j.progress||0}% ${esc(j.message||'')}</p>`+(j.output_url?`<p><a href="${j.output_url}" target="_blank">Open output</a></p>`:'')+(j.error?`<p style="color:var(--bad)">${esc(j.error)}</p>`:''); if(j.status==='done'||j.status==='error'){clearInterval(timer); loadProfiles(); loadGallery()}},1500)}
async function loadGallery(){let d=await api('/api/gallery'); $('gallery_grid').innerHTML=(d.items||[]).map(it=>`<div class="profile">${it.type==='video'?`<video src="${it.url}" controls style="width:100%;height:130px;background:#000"></video>`:`<img src="${it.url}">`}<div class="p"><b>${esc(it.name)}</b><br><a href="${it.url}" target="_blank">open</a></div></div>`).join('')||'<div class="muted">No media yet.</div>'} async function loadGallery(){let d=await api('/api/gallery'); $('gallery_grid').innerHTML=(d.items||[]).map(it=>`<div class="profile">${it.type==='video'?`<video src="${it.url}" controls style="width:100%;height:130px;background:#000"></video>`:`<img src="${it.url}">`}<div class="p"><b>${esc(it.name)}</b><br><a href="${it.url}" target="_blank">open</a></div></div>`).join('')||'<div class="muted">No media yet.</div>'}
...@@ -871,6 +1225,10 @@ def make_handler(app: VideoGenApp): ...@@ -871,6 +1225,10 @@ def make_handler(app: VideoGenApp):
payload = self._read_json() payload = self._read_json()
if path == "/api/profile/start": if path == "/api/profile/start":
self._json({"job_id": app.start_profile_job(payload)}) self._json({"job_id": app.start_profile_job(payload)})
elif path == "/api/voice/start":
self._json({"job_id": app.start_voice_job(payload)})
elif path == "/api/lora/start":
self._json({"job_id": app.start_lora_job(payload)})
elif path == "/api/movie/start": elif path == "/api/movie/start":
self._json({"job_id": app.start_movie_job(payload)}) self._json({"job_id": app.start_movie_job(payload)})
elif path.startswith("/api/job/") and path.endswith("/cancel"): elif path.startswith("/api/job/") and path.endswith("/cancel"):
...@@ -909,6 +1267,7 @@ def build_parser() -> argparse.ArgumentParser: ...@@ -909,6 +1267,7 @@ def build_parser() -> argparse.ArgumentParser:
parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="CoderAI base URL") parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="CoderAI base URL")
parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="Bearer token for CoderAI") parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="Bearer token for CoderAI")
parser.add_argument("--out-dir", default=DEFAULT_OUT_DIR, help="Local output directory") parser.add_argument("--out-dir", default=DEFAULT_OUT_DIR, help="Local output directory")
parser.add_argument("--host", default="0.0.0.0", help="Web UI listen host (default: 0.0.0.0)")
parser.add_argument("--web-port", type=int, default=7790, help="Local web UI port") parser.add_argument("--web-port", type=int, default=7790, help="Local web UI port")
parser.add_argument("--image-model", default="", help="Default image model") parser.add_argument("--image-model", default="", help="Default image model")
parser.add_argument("--video-model", default="", help="Default video model") parser.add_argument("--video-model", default="", help="Default video model")
...@@ -921,8 +1280,8 @@ def build_parser() -> argparse.ArgumentParser: ...@@ -921,8 +1280,8 @@ def build_parser() -> argparse.ArgumentParser:
def main(argv: list[str] | None = None) -> None: def main(argv: list[str] | None = None) -> None:
args = build_parser().parse_args(argv) args = build_parser().parse_args(argv)
app = VideoGenApp(args) app = VideoGenApp(args)
server = ThreadedHTTPServer(("127.0.0.1", args.web_port), make_handler(app)) server = ThreadedHTTPServer((args.host, args.web_port), make_handler(app))
url = f"http://127.0.0.1:{args.web_port}" url = f"http://{args.host}:{args.web_port}"
log(f"VideoGen Studio running at {url}") log(f"VideoGen Studio running at {url}")
log(f"CoderAI: {args.base_url}") log(f"CoderAI: {args.base_url}")
log(f"Output: {Path(args.out_dir).resolve()}") log(f"Output: {Path(args.out_dir).resolve()}")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment