Wan2.2 video fixes, pipeline cache, smarter offload, model-load tasks

Wan2.2 A14B (dual-expert) generation fixes:
- Fuse the Lightning distill LoRA into BOTH experts (transformer +
  transformer_2); diffusers' fuse_lora defaults to ["transformer"] only, which
  left the low-noise expert undistilled → 4-step clips collapsed to a solid
  colour. Also load per-request fighter/env LoRAs into both experts.
- Pre-configure the wan22_lightning_4step preset with the local high/low-noise
  LoRAs (lora_high/lora_low). They are applied only when acceleration is
  enabled and are exposed as fields in the Acceleration UI.
- Safety net: only apply the preset's low step count when the distill LoRA
  actually fused, else fall back to safe steps.
- Skip bitsandbytes/quanto quant for the VAE (conv-only → "no linear modules").
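
A minimal sketch of the two ideas above (fuse every expert that exists, and only
trust the low step count when the fuse really happened). FakeDualExpertPipe is a
hypothetical stand-in, not a diffusers class, and the fallback of 30 steps is an
assumed value, not the project's actual default.

```python
from typing import List


class FakeDualExpertPipe:
    """Hypothetical stand-in for a two-expert pipeline (not a diffusers class)."""

    def __init__(self, dual: bool = True) -> None:
        self.transformer = object()
        self.transformer_2 = object() if dual else None
        self.fused: List[str] = []

    def fuse_lora(self, components: List[str], lora_scale: float = 1.0) -> None:
        # Record which components were asked to fuse.
        self.fused = list(components)


def fuse_components(pipe) -> List[str]:
    """Name every expert present so none is left un-distilled."""
    components = ["transformer"]
    if getattr(pipe, "transformer_2", None) is not None:
        components.append("transformer_2")
    return components


def safe_steps(preset_steps: int, fallback_steps: int, fused: bool) -> int:
    """Use the distilled step count only when the LoRA actually fused."""
    return preset_steps if fused else fallback_steps


pipe = FakeDualExpertPipe(dual=True)
pipe.fuse_lora(components=fuse_components(pipe), lora_scale=1.0)
print(pipe.fused)
print(safe_steps(4, 30, fused=False))  # assumed fallback of 30 steps
```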

VRAM / offload:
- Strategy auto-selection actually fires now ('auto' is normalised, not passed
  through as a no-op) and no longer double-counts the runtime/accel reserve.
- Graceful OOM degrade ladder: full-GPU → balanced @ configured% → 80 → 60 →
  40 → sequential → disk, respecting the model's balanced_gpu_percent as the
  starting cap. Expose 'balanced' as a selectable offload strategy.

Pipeline disk cache (--pipeline-cache / --rebuild-pipeline-cache):
- Cache the quantized base pipeline to disk and reload it on later starts,
  skipping re-download/re-quantization; accel LoRA re-fused per load. Fail-safe
  with self-healing invalidate-and-rebuild.

Tasks / misc:
- Show model loading as a (non-cancellable, non-pausable) Tasks entry.
- Filter the Tasks-page pollers from the access log unless --debug-web.
- Township gen script: per-image keyframe progress (no longer all-or-nothing).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 8ad15128
@@ -2052,7 +2052,7 @@ async def api_tasks(username: str = Depends(require_admin)):
         seen.add(t["id"])
         t = dict(t)
         t["cancellable"] = bool(t.get("cancellable", True) and t.get("active", False))
-        t["pausable"] = (t.get("status") == "running")
+        t["pausable"] = bool(t.get("pausable", True) and t.get("status") == "running")
         t["restartable"] = False
         tasks.append(t)
......
@@ -671,13 +671,15 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
 <div class="form-row" style="margin:0">
   <label class="form-label">Strategy</label>
   <select id="cfg-offload-strategy" class="form-input">
-    <option value="auto">Auto</option>
-    <option value="model">CPU offload (model)</option>
-    <option value="sequential">CPU offload (sequential)</option>
+    <option value="auto">Auto (pick from free VRAM)</option>
+    <option value="none">None (GPU only)</option>
+    <option value="balanced">Balanced (fill GPU, spill to CPU → disk)</option>
+    <option value="model">CPU offload (model — module-by-module)</option>
+    <option value="sequential">CPU offload (sequential — most aggressive)</option>
     <option value="cpu">CPU RAM (legacy)</option>
     <option value="disk">Disk</option>
-    <option value="none">None (GPU only)</option>
   </select>
+  <span class="form-hint">Auto picks full-GPU when the weights fit, else Balanced. Pick <b>Balanced</b> + lower the GPU % below for a model that's just over VRAM; <b>model</b>/<b>sequential</b> keep less on GPU (slower, but the safest fit).</span>
 </div>
 <div class="form-row" style="margin:0">
   <label class="form-label">Offload directory</label>
@@ -729,6 +731,12 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
   <label class="form-label">Distill LoRA <span class="muted">(path or HF repo, optionally repo:weight_name.safetensors; blank for turbo full-models)</span></label>
   <input type="text" id="cfg-accel-lora" class="form-input" placeholder="e.g. ByteDance/SDXL-Lightning:sdxl_lightning_4step_lora.safetensors">
 </div>
+<div class="form-row" style="max-width:560px">
+  <label class="form-label">Distill LoRA — high/low noise <span class="muted">(Wan2.2 A14B two-expert only; overrides the single LoRA per expert)</span></label>
+  <input type="text" id="cfg-accel-lora-high" class="form-input" placeholder="high-noise → transformer (e.g. repo:..._high_noise_..._4step.safetensors)">
+  <input type="text" id="cfg-accel-lora-low" class="form-input" style="margin-top:.4rem" placeholder="low-noise → transformer_2 (e.g. repo:..._low_noise_..._4step.safetensors)">
+  <span class="form-hint">Wan2.2 A14B has two experts; the distill LoRA must be fused into <b>both</b> or the clip collapses to a solid colour at 4 steps. Leave blank to apply the single LoRA above to both.</span>
+</div>
 <div style="display:flex;gap:1rem;flex-wrap:wrap">
   <div class="form-row" style="max-width:130px">
     <label class="form-label">LoRA weight</label>
@@ -2801,6 +2809,8 @@ async function _populateAccel(a){
   const sel = document.getElementById('cfg-accel-preset');
   sel.value = [...sel.options].some(o=>o.value===(a.preset||'')) ? a.preset : 'custom';
   document.getElementById('cfg-accel-lora').value = a.lora || '';
+  document.getElementById('cfg-accel-lora-high').value = a.lora_high || '';
+  document.getElementById('cfg-accel-lora-low').value = a.lora_low || '';
   document.getElementById('cfg-accel-weight').value = a.lora_weight != null ? a.lora_weight : '';
   document.getElementById('cfg-accel-steps').value = a.steps != null ? a.steps : '';
   document.getElementById('cfg-accel-guidance').value = a.guidance_scale != null ? a.guidance_scale : '';
@@ -2815,6 +2825,8 @@ function _collectAccel(){
     enabled: true,
     preset: document.getElementById('cfg-accel-preset').value || 'custom',
     lora: document.getElementById('cfg-accel-lora').value.trim() || null,
+    lora_high: document.getElementById('cfg-accel-lora-high').value.trim() || null,
+    lora_low: document.getElementById('cfg-accel-lora-low').value.trim() || null,
     lora_weight: num('cfg-accel-weight'),
     steps: num('cfg-accel-steps'),
     guidance_scale: num('cfg-accel-guidance'),
......
@@ -69,7 +69,7 @@ function fmtTime(s) {
   } catch { return ''; }
 }
-const KIND_LABEL = {training:'Training', image:'Image', video:'Video', audio:'Audio', text:'Text', pipeline:'Pipeline', request:'Request'};
+const KIND_LABEL = {training:'Training', image:'Image', video:'Video', audio:'Audio', text:'Text', pipeline:'Pipeline', request:'Request', loading:'Loading'};
 const STATUS_BADGE = {
   running:'badge-admin', queued:'badge-user', done:'badge-ok', error:'badge-err',
   cancelled:'badge-user', interrupted:'badge-warn'
......
@@ -266,13 +266,15 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
     device = _derive_device()
     model_type = _detect_audio_gen_type(model_name)
     _ag_cfg = model_info.get('config') or {}
+    from codai.tasks import loading_task
     try:
-        if model_type in ('musicgen', 'audiogen'):
-            pipe = await asyncio.get_event_loop().run_in_executor(
-                None, _load_musicgen, model_name, device)
-        else:
-            pipe = await asyncio.get_event_loop().run_in_executor(
-                None, _load_audioldm, model_name, device, _ag_cfg)
+        with loading_task(model_name, model_type="audio"):
+            if model_type in ('musicgen', 'audiogen'):
+                pipe = await asyncio.get_event_loop().run_in_executor(
+                    None, _load_musicgen, model_name, device)
+            else:
+                pipe = await asyncio.get_event_loop().run_in_executor(
+                    None, _load_audioldm, model_name, device, _ag_cfg)
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Failed to load audio gen model: {e}")
     multi_model_manager.models[model_key] = pipe
......
@@ -121,9 +121,11 @@ async def create_embeddings(request: EmbeddingsRequest, http_request: Request =
     if model_obj is None:
         device = _derive_device()
+        from codai.tasks import loading_task
         try:
-            model_obj = await asyncio.get_event_loop().run_in_executor(
-                None, _load_embedding_model, model_name, device, _emb_cfg)
+            with loading_task(model_name, model_type="embedding"):
+                model_obj = await asyncio.get_event_loop().run_in_executor(
+                    None, _load_embedding_model, model_name, device, _emb_cfg)
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Failed to load embedding model: {e}")
         multi_model_manager.models[model_key] = model_obj
......
@@ -1370,7 +1370,14 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
     is_gguf = _is_gguf_model(model_name)
     diffusers_error = None
     sdcpp_error = None
+    # Show the load as a (non-cancellable) Tasks-page entry spanning both
+    # backend attempts; finished done on success, error only if all fail.
+    from codai.tasks import task_registry as _treg
+    _ltid = _treg.register("loading", title=f"Loading {model_name}",
+                           model=model_key, status="running",
+                           cancellable=False, pausable=False)
     # Try diffusers first (for non-GGUF models)
     if not is_gguf:
         try:
@@ -1391,6 +1398,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
                 pass
             print(f"Loaded diffusers model: {model_name}")
+            _treg.finish(_ltid, "done")
             return await _generate_with_diffusers(pipeline, request, global_args, http_request)
         except ImportError as e:
@@ -1426,7 +1434,8 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
             except Exception:
                 pass
             print(f"Loaded sd.cpp model: {model_name}")
+            _treg.finish(_ltid, "done")
             return await _generate_with_sdcpp(sd_model, request, global_args,
                                               http_request, model_config=cfg)
         else:
@@ -1449,6 +1458,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
     if sdcpp_error:
         error_details.append(f"sd.cpp: {sdcpp_error}")
+    _treg.finish(_ltid, "error", "; ".join(error_details) or "no compatible backend")
     raise HTTPException(
         status_code=400,
         detail=f"Failed to load image model '{model_name}'. Errors: {'; '.join(error_details) if error_details else 'No compatible backend found'}"
......
@@ -254,4 +254,18 @@ configuration directory (--config DIR, default: OS-specific CoderAI directory).
              "Mid-flight jobs are marked 'cancelled' (checkpoints are kept, so they "
              "can still be restarted manually from the Tasks page).",
     )
+    parser.add_argument(
+        "--pipeline-cache",
+        action="store_true",
+        help="Cache quantized diffusers pipelines to disk after the first build "
+             "and reload them from that cache on later starts — skipping the "
+             "expensive re-download/re-quantization (e.g. the Wan2.2 A14B). The "
+             "fast acceleration LoRA fuse is re-applied per load. Uses extra disk.",
+    )
+    parser.add_argument(
+        "--rebuild-pipeline-cache",
+        action="store_true",
+        help="Ignore any existing pipeline cache and rebuild it from scratch this "
+             "run (use after changing a model's quantization/precision config).",
+    )
     return parser.parse_args()
@@ -878,6 +878,14 @@ def main():
     if not _resume_jobs:
         print("LoRA job recovery: DISABLED (interrupted training will be cancelled on restart)")
+    if getattr(args, "pipeline_cache", False):
+        try:
+            from codai.models.pipeline_cache import cache_root
+            _pc_extra = " (rebuilding this run)" if getattr(args, "rebuild_pipeline_cache", False) else ""
+            print(f"Pipeline cache: ENABLED{_pc_extra} — quantized pipelines cached at {cache_root()}")
+        except Exception:
+            print("Pipeline cache: ENABLED")
     # Set environment profiles module global args
     from codai.api.environments import set_global_args as set_envs_global_args
     set_envs_global_args(global_args)
@@ -964,13 +972,18 @@ def main():
     if not _debug_web:
         class _AccessNoiseFilter(logging.Filter):
             # uvicorn.access record args: (client_addr, method, full_path, http_ver, status)
-            _NOISY = ("/v1/loras/progress",)
+            _NOISY_PREFIX = ("/v1/loras/progress",)
+            # Exact-match only, so the live Tasks-page pollers are dropped but the
+            # user-initiated action endpoints (/admin/api/tasks/{id}/pause, …) still log.
+            _NOISY_EXACT = ("/admin/api/tasks", "/admin/api/system-stats")
             def filter(self, record):
                 try:
                     args = record.args
                     if isinstance(args, (tuple, list)) and len(args) >= 3:
                         path = str(args[2]).split("?", 1)[0]
-                        if any(path == p or path.startswith(p) for p in self._NOISY):
+                        if path in self._NOISY_EXACT:
+                            return False
+                        if any(path == p or path.startswith(p) for p in self._NOISY_PREFIX):
                             return False
                 except Exception:
                     pass
......
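The exact-versus-prefix distinction in the filter above can be checked in isolation with the standard logging module. This is a simplified sketch, not the project's class: the names AccessNoiseFilter and make_record are hypothetical, and the record args only imitate the (client_addr, method, full_path, http_ver, status) tuple described in the comment.

```python
import logging


class AccessNoiseFilter(logging.Filter):
    """Drop poller requests; keep everything else (simplified sketch)."""

    NOISY_PREFIX = ("/v1/loras/progress",)
    NOISY_EXACT = ("/admin/api/tasks", "/admin/api/system-stats")

    def filter(self, record: logging.LogRecord) -> bool:
        args = record.args
        if isinstance(args, (tuple, list)) and len(args) >= 3:
            # Strip the query string before matching.
            path = str(args[2]).split("?", 1)[0]
            if path in self.NOISY_EXACT:
                return False
            if any(path.startswith(p) for p in self.NOISY_PREFIX):
                return False
        return True


def make_record(path: str) -> logging.LogRecord:
    """Build a record shaped like an access-log entry (illustrative only)."""
    return logging.LogRecord(
        name="uvicorn.access", level=logging.INFO, pathname="example",
        lineno=0, msg='%s - "%s %s HTTP/%s" %d',
        args=("127.0.0.1:1", "GET", path, "1.1", 200), exc_info=None)


flt = AccessNoiseFilter()
for p in ("/admin/api/tasks?x=1", "/admin/api/tasks/abc/pause", "/v1/loras/progress/42"):
    print(p, "logged" if flt.filter(make_record(p)) else "dropped")
```

The poller path is dropped while the pause endpoint, which shares its prefix, still reaches the log.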
@@ -52,7 +52,15 @@ ACCEL_PRESETS: dict = {
         "label": "Wan2.2 Lightning (4-step DMD)",
         "family": "wan",
         "applies_to": ["video"],
-        "lora": "lightx2v/Wan2.2-Lightning",
+        # Wan2.2 A14B is a two-expert MoE: the distill LoRA must be fused into BOTH
+        # the high-noise (transformer) and low-noise (transformer_2) experts, or the
+        # clip collapses to a solid colour at 4 steps. These default to the locally
+        # installed lightx2v/Wan2.2-Lightning weights (resolved from cache — not a
+        # download); override per model in the Acceleration config for T2V or a
+        # different rank/version. `lora` stays None because the two experts differ.
+        "lora": None,
+        "lora_high": "lightx2v/Wan2.2-Lightning:Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1/high_noise_model.safetensors",
+        "lora_low": "lightx2v/Wan2.2-Lightning:Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1/low_noise_model.safetensors",
         "lora_weight": 1.0,
         "steps": 4,
         "guidance_scale": 1.0,
@@ -175,6 +183,12 @@ def resolve_acceleration(model_cfg: Optional[dict]) -> Optional[dict]:
     out = {
         "preset": preset_key or "custom",
         "lora": _pick("lora"),
+        # Wan2.2 A14B is a two-expert MoE: the distill LoRA differs for the
+        # high-noise (transformer) and low-noise (transformer_2) experts. When
+        # these are set they take precedence over the single `lora` per expert;
+        # otherwise the single `lora` is applied to BOTH experts.
+        "lora_high": _pick("lora_high"),
+        "lora_low": _pick("lora_low"),
         "lora_weight": _pick("lora_weight", 1.0),
         "steps": _pick("steps"),
         "guidance_scale": _pick("guidance_scale"),
@@ -255,35 +269,94 @@ def apply_accel_to_pipeline(pipe, accel: Optional[dict]) -> None:
             log.warning("[accel] flow_shift apply failed: %s", e)
     # 3. Fuse the distill LoRA (when one is configured — turbo has none).
-    lora_ref = accel.get("lora")
-    if not lora_ref:
+    # `_coderai_accel_fused` records whether a distill LoRA actually baked in,
+    # so the generator only drops to the preset's low step count when the model
+    # is genuinely distilled (running 4 steps un-distilled collapses the video
+    # to a solid colour — exactly the Wan2.2 dual-expert failure mode).
+    try:
+        pipe._coderai_accel_fused = False
+    except Exception:
+        pass
+    has_t2 = getattr(pipe, "transformer_2", None) is not None
+    lora_high = accel.get("lora_high") or accel.get("lora")
+    lora_low = accel.get("lora_low") or accel.get("lora")
+    if not lora_high and not lora_low:
+        # No LoRA (e.g. a full distilled model like SDXL-Turbo) — treat as distilled.
+        try:
+            pipe._coderai_accel_fused = True
+        except Exception:
+            pass
         return
     if not hasattr(pipe, "load_lora_weights"):
         log.warning("[accel] pipeline %s has no load_lora_weights — cannot fuse "
                     "acceleration LoRA", type(pipe).__name__)
         return
-    repo, weight_name = _split_lora_ref(lora_ref)
     weight = float(accel.get("lora_weight") or 1.0)
-    try:
-        load_kwargs = {"adapter_name": "__accel__"}
+    def _load_one(ref, into_t2: bool, adapter: str) -> bool:
+        repo, weight_name = _split_lora_ref(ref)
+        kw = {"adapter_name": adapter}
         if weight_name:
-            load_kwargs["weight_name"] = weight_name
-        pipe.load_lora_weights(repo, **load_kwargs)
+            kw["weight_name"] = weight_name
+        if into_t2:
+            kw["load_into_transformer_2"] = True
+        pipe.load_lora_weights(repo, **kw)
+        return True
+    loaded_adapters = []
+    try:
+        # High-noise expert (transformer) — always.
+        if lora_high and _load_one(lora_high, False, "__accel__"):
+            loaded_adapters.append("__accel__")
+        # Low-noise expert (transformer_2) — only on dual-expert Wan2.2 models.
+        if has_t2 and lora_low:
+            try:
+                if _load_one(lora_low, True, "__accel_2__"):
+                    loaded_adapters.append("__accel_2__")
+            except Exception as e2:
+                log.warning("[accel] could not load distill LoRA into transformer_2 "
+                            "(%s) — low-noise expert stays un-distilled", e2)
+        elif has_t2 and not lora_low:
+            log.warning("[accel] model has a second expert (transformer_2) but no "
+                        "low-noise distill LoRA — set acceleration.lora_low")
+        if not loaded_adapters:
+            raise RuntimeError("no distill adapter registered on the pipeline")
         try:
-            pipe.set_adapters(["__accel__"], [weight])
+            pipe.set_adapters(loaded_adapters, [weight] * len(loaded_adapters))
         except Exception:
             pass
-        # Bake it in, then drop the adapter handle so per-request LoRAs are clean.
-        pipe.fuse_lora(lora_scale=weight)
+        # Bake them in, then drop the adapter handles so per-request LoRAs are clean.
+        # CRITICAL: diffusers' Wan fuse_lora defaults to components=["transformer"],
+        # so without naming transformer_2 the low-noise expert's distill adapter is
+        # never fused — and the subsequent unload strips it off, leaving that expert
+        # undistilled. At 4 steps that collapses the clip to a solid colour. Fuse
+        # BOTH experts explicitly.
+        _fuse_components = ["transformer"]
+        if has_t2:
+            _fuse_components.append("transformer_2")
+        try:
+            pipe.fuse_lora(components=_fuse_components, lora_scale=weight)
+        except TypeError:
+            # Older diffusers without the `components` kwarg — best effort.
+            pipe.fuse_lora(lora_scale=weight)
         try:
             pipe.unload_lora_weights()
         except Exception:
             pass
-        log.info("[accel] fused distillation LoRA %s (weight=%s) into %s",
-                 repo, weight, type(pipe).__name__)
+        try:
+            pipe._coderai_accel_fused = True
+        except Exception:
+            pass
+        log.info("[accel] fused distillation LoRA(s) %s (weight=%s) into %s%s",
+                 loaded_adapters, weight, type(pipe).__name__,
+                 " (both experts)" if len(loaded_adapters) > 1 else "")
     except Exception as e:
-        log.warning("[accel] failed to fuse acceleration LoRA %s: %s — generating "
-                    "without acceleration", lora_ref, e)
+        log.warning("[accel] failed to fuse acceleration LoRA (high=%s low=%s): %s "
+                    "— generating WITHOUT acceleration (step count will fall back to "
+                    "a safe default, not the preset's distilled count)",
+                    lora_high, lora_low, e)
 def accel_call_defaults(accel: Optional[dict]) -> dict:
......
@@ -182,6 +182,15 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]],
                        bnb_4bit_compute_dtype=dtype, bnb_4bit_use_double_quant=True)
         return BnB(load_in_8bit=True)
+    def _bnb_incompatible(name: str) -> bool:
+        # bitsandbytes (4/8-bit) and optimum-quanto (2-bit) only quantize
+        # nn.Linear. A fully-convolutional component (the VAE) has no Linear
+        # layers, so applying them triggers a hard "no linear modules were found"
+        # error. Such components must stay full precision (a smaller VAE comes
+        # from a GGUF VAE instead, handled separately).
+        n = (name or '').lower()
+        return n == 'vae' or n.endswith('_vae') or n.startswith('vae')
     quant_mapping: Dict[str, Any] = {}
     descs = []
     if comp_q:
@@ -189,6 +198,11 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]],
             mode = _normalize_quant_mode(raw_mode)  # GGUF/none → None here
             if mode is None:
                 continue
+            if _bnb_incompatible(name):
+                print(f" Skipping {mode} for '{name}': it has no Linear layers "
+                      f"(conv-only) — bitsandbytes/quanto cannot quantize it; "
+                      f"leaving full precision (use a GGUF VAE to shrink the VAE).")
+                continue
             cfg_obj = _mk(comp_lib.get(name, 'diffusers'), mode)
             if cfg_obj is not None:
                 quant_mapping[name] = cfg_obj
@@ -198,6 +212,8 @@ def build_pipeline_quant_config(model_name: str, cfg: Optional[Dict[str, Any]],
         targets = [n for n in comp_lib if _is_heavy(n)] or \
                   ['transformer', 'transformer_2', 'text_encoder', 'unet']
         for name in targets:
+            if _bnb_incompatible(name):
+                continue
             cfg_obj = _mk(comp_lib.get(name, 'diffusers'), mode)
             if cfg_obj is not None:
                 quant_mapping[name] = cfg_obj
......
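The name test used to skip the VAE can be exercised on its own. The sketch below copies the predicate from the hunk above; quant_targets and the "4bit" mode label are hypothetical names used only to show the effect on a component list.

```python
from typing import Dict, Iterable


def bnb_incompatible(name: str) -> bool:
    """True for component names that look like a (conv-only) VAE."""
    n = (name or "").lower()
    return n == "vae" or n.endswith("_vae") or n.startswith("vae")


def quant_targets(components: Iterable[str], mode: str) -> Dict[str, str]:
    """Map each quantizable component to the mode, leaving VAEs out."""
    return {name: mode for name in components if not bnb_incompatible(name)}


print(quant_targets(["transformer", "transformer_2", "vae", "text_encoder"], "4bit"))
```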
@@ -817,7 +817,9 @@ class MultiModelManager:
             print(f"Loading default model on demand: {self.default_model}")
             _snap = self.vram_before_load()
             kwargs['expected_vram_gb'] = self._get_model_used_vram_gb(self.default_model)
-            model_manager.load_model(self.default_model, backend_type=backend_type, **kwargs)
+            from codai.tasks import loading_task
+            with loading_task(self.default_model, model_type="text"):
+                model_manager.load_model(self.default_model, backend_type=backend_type, **kwargs)
             self.add_model(self.default_model, model_manager)
             self.record_vram_delta(self.default_model, _snap)
             self.current_model_key = self.default_model
@@ -916,7 +918,9 @@ class MultiModelManager:
         # it can decide whether Flash-Attention-2 is safe (FA2 requires the
         # whole model on GPU; it device-side-asserts when layers offload).
         kwargs['expected_vram_gb'] = self._get_model_used_vram_gb(model_name)
-        model_manager.load_model(model_name, backend_type=backend_type, **kwargs)
+        from codai.tasks import loading_task
+        with loading_task(model_name, model_type="text"):
+            model_manager.load_model(model_name, backend_type=backend_type, **kwargs)
         self.add_model(model_name, model_manager)
         self.record_vram_delta(model_name, _snap)
         self.current_model_key = model_name
......
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""On-disk cache of *built* diffusers pipelines.

Building a large quantized video pipeline (e.g. Wan2.2 A14B at 4-bit) is slow:
download + bitsandbytes quantization of ~28B parameters. The weights don't change
between restarts, so once built we ``save_pretrained`` the pipeline to a local
cache keyed by ``(model, quantization, precision)``. A later start with
``--pipeline-cache`` reloads from there with a plain ``from_pretrained`` of the
already-quantized weights — no re-download, no re-quantization.

Scope: only the *base* pipeline is cached. The acceleration/distillation LoRA is
NOT baked into the cache — it is re-fused on every load (a fast operation), so the
cache stays independent of the (cheap to re-apply) ``acceleration`` config and we
avoid the fragile round-trip of serialising a fused + quantized model.

Everything here is best-effort: any failure (save or load) is swallowed and the
caller falls back to a normal build, so the cache can never break generation.
"""
import hashlib
import json
import os
import shutil
import time
from typing import Optional

# Bump when the cache layout / marker format changes so stale caches are ignored.
_CACHE_VERSION = 1


def _global_args():
    try:
        from codai.api.state import get_global_args
        return get_global_args()
    except Exception:
        return None


def enabled() -> bool:
    """True when --pipeline-cache was passed."""
    ga = _global_args()
    return bool(ga is not None and getattr(ga, "pipeline_cache", False))


def _force_rebuild() -> bool:
    ga = _global_args()
    return bool(ga is not None and getattr(ga, "rebuild_pipeline_cache", False))


def cache_root() -> str:
    """Root dir for cached pipelines. Sits next to the offload dir by default."""
    ga = _global_args()
    offload_dir = getattr(ga, "offload_dir", None) if ga else None
    if offload_dir:
        root = os.path.join(os.path.dirname(os.path.abspath(os.path.expanduser(offload_dir))),
                            "pipeline_cache")
    else:
        root = os.path.join(os.path.expanduser("~"), ".cache", "coderai", "pipeline_cache")
    return root


def _signature(model_name: str, model_cfg: Optional[dict]) -> str:
    """Stable hash of everything that changes the *built* (quantized) weights:
    the model id, the quantization choices, and the precision. NOT acceleration
    (re-applied per load) and NOT offload (a runtime placement decision)."""
    c = model_cfg or {}
    payload = {
        "v": _CACHE_VERSION,
        "model": model_name,
        "precision": c.get("precision") or "bf16",
        "load_in_4bit": bool(c.get("load_in_4bit", False)),
        "load_in_8bit": bool(c.get("load_in_8bit", False)),
        "component_quantization": c.get("component_quantization") or {},
    }
    blob = json.dumps(payload, sort_keys=True, default=str)
    return hashlib.sha256(blob.encode()).hexdigest()[:16]


def _safe_name(model_name: str) -> str:
    return "".join(ch if ch.isalnum() or ch in "-._" else "_" for ch in model_name)[:80]


def path(model_name: str, model_cfg: Optional[dict]) -> str:
    """Absolute cache directory for this model + quant/precision signature."""
    return os.path.join(cache_root(),
                        f"{_safe_name(model_name)}__{_signature(model_name, model_cfg)}")


def _marker(p: str) -> str:
    return os.path.join(p, ".coderai_pipeline_cache.json")
def valid(p: str) -> bool:
    """True if a complete, current cache exists at ``p`` and rebuild wasn't forced."""
    if not p or _force_rebuild():
        return False
    try:
        if not os.path.isfile(os.path.join(p, "model_index.json")):
            return False
        with open(_marker(p)) as f:
            meta = json.load(f)
        return meta.get("version") == _CACHE_VERSION and meta.get("complete") is True
    except Exception:
        return False


def invalidate(model_name: str, model_cfg: Optional[dict]) -> None:
    """Delete a model's cache dir (e.g. after a failed cache load) so the next
    build rewrites it. Best-effort."""
    try:
        p = path(model_name, model_cfg)
        if p and os.path.isdir(p):
            shutil.rmtree(p, ignore_errors=True)
            print(f" [pipeline-cache] invalidated {p}")
    except Exception:
        pass
def save(pipe, p: str, *, model_name: str = "", model_cfg: Optional[dict] = None) -> bool:
    """Serialize ``pipe`` to the cache dir ``p`` (atomic via a temp dir).
    Returns True on success. Any failure is logged and returns False — the caller
    keeps the freshly built in-memory pipeline regardless."""
    if not p:
        return False
    tmp = p + ".building"
    try:
        os.makedirs(cache_root(), exist_ok=True)
        if os.path.exists(tmp):
            shutil.rmtree(tmp, ignore_errors=True)
        print(f" [pipeline-cache] saving quantized pipeline → {p}")
        t0 = time.time()
        pipe.save_pretrained(tmp)
        with open(_marker(tmp), "w") as f:
            json.dump({
                "version": _CACHE_VERSION, "complete": True,
                "model": model_name, "saved_at": time.time(),
                "signature": _signature(model_name, model_cfg),
            }, f)
        if os.path.exists(p):
            shutil.rmtree(p, ignore_errors=True)
        os.replace(tmp, p)
        print(f" [pipeline-cache] saved in {time.time() - t0:.0f}s")
        return True
    except Exception as e:
        print(f" [pipeline-cache] save failed ({e}) — continuing without a cache")
        try:
            shutil.rmtree(tmp, ignore_errors=True)
        except Exception:
            pass
        return False
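The commit message describes the cache as fail-safe with invalidate-and-rebuild. The caller side is not shown in this hunk, so the following is only a sketch of how such a flow could be wired, under stated assumptions: `load_or_build`, `is_valid`, `FakePipe`, and the `build`/`load` callables are hypothetical names, and the fake pipeline only mimics a `save_pretrained` that writes `model_index.json`.

```python
import json
import os
import shutil
import tempfile

MARKER = ".coderai_pipeline_cache.json"
VERSION = 1


def is_valid(cache_dir):
    """A cache counts only if model_index.json and a complete marker both exist."""
    try:
        if not os.path.isfile(os.path.join(cache_dir, "model_index.json")):
            return False
        with open(os.path.join(cache_dir, MARKER)) as f:
            meta = json.load(f)
        return meta.get("version") == VERSION and meta.get("complete") is True
    except Exception:
        return False


def save(pipe, cache_dir):
    """Write into '<dir>.building' first, then swap it into place."""
    tmp = cache_dir + ".building"
    shutil.rmtree(tmp, ignore_errors=True)
    pipe.save_pretrained(tmp)
    with open(os.path.join(tmp, MARKER), "w") as f:
        json.dump({"version": VERSION, "complete": True}, f)
    shutil.rmtree(cache_dir, ignore_errors=True)
    os.replace(tmp, cache_dir)


def load_or_build(cache_dir, build, load):
    """Hypothetical caller: try the cache, self-heal on failure, else build and save."""
    if is_valid(cache_dir):
        try:
            return load(cache_dir), "cache"
        except Exception:
            # A broken cache is deleted so the next save rewrites it.
            shutil.rmtree(cache_dir, ignore_errors=True)
    pipe = build()
    try:
        save(pipe, cache_dir)
    except Exception:
        # Saving is best-effort; the in-memory pipeline is still usable.
        shutil.rmtree(cache_dir + ".building", ignore_errors=True)
    return pipe, "built"


class FakePipe:
    """Stand-in for a diffusers pipeline; only save_pretrained is mimicked."""

    def save_pretrained(self, directory):
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, "model_index.json"), "w") as f:
            json.dump({"_class_name": "FakePipe"}, f)


root = tempfile.mkdtemp()
cache_dir = os.path.join(root, "demo__0123456789abcdef")
first = load_or_build(cache_dir, FakePipe, lambda d: FakePipe())[1]
second = load_or_build(cache_dir, FakePipe, lambda d: FakePipe())[1]
```

In this sketch the first call builds and writes the cache and the second call is served from it. An interrupted save leaves only the `.building` directory behind, which `is_valid` never accepts.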
@@ -23,6 +23,7 @@ from codai.tasks.registry import (
     task_registry,
     raise_if_cancelled,
     wait_if_paused,
+    loading_task,
 )
 __all__ = [
@@ -32,4 +33,5 @@ __all__ = [
     "task_registry",
     "raise_if_cancelled",
     "wait_if_paused",
+    "loading_task",
 ]
@@ -32,6 +32,7 @@ a task with a ``job_id`` links the two.
 import threading
 import time
 import uuid
+from contextlib import contextmanager
 from dataclasses import asdict, dataclass, field
 from typing import Dict, List, Optional
@@ -60,6 +61,7 @@ class Task:
     ended_at: Optional[float] = None
     cancellable: bool = True
     restartable: bool = False
+    pausable: bool = True
     paused: bool = False
     def to_dict(self) -> dict:
@@ -83,13 +85,14 @@ class TaskRegistry:
     def register(self, kind: str, *, title: str = "", model: str = "",
                  total: int = 0, job_id: Optional[str] = None,
                  status: str = "queued", cancellable: bool = True,
-                 restartable: bool = False, task_id: Optional[str] = None) -> str:
+                 restartable: bool = False, pausable: bool = True,
+                 task_id: Optional[str] = None) -> str:
         tid = task_id or f"task-{uuid.uuid4().hex[:12]}"
         with self._lock:
             self._tasks[tid] = Task(
                 id=tid, kind=kind, title=title, model=model, total=total,
                 job_id=job_id, status=status, cancellable=cancellable,
-                restartable=restartable,
+                restartable=restartable, pausable=pausable,
             )
             self._events[tid] = threading.Event()
             self._pause_events[tid] = threading.Event()
@@ -256,3 +259,25 @@ def wait_if_paused(task_id: Optional[str]) -> None:
     Returns immediately when not paused; raises :class:`TaskCancelled` if the
     task is cancelled while paused. A falsy ``task_id`` is a no-op."""
     task_registry.wait_if_paused(task_id)
+@contextmanager
+def loading_task(model: str, *, model_type: str = "model", title: Optional[str] = None):
+    """Context manager that shows a model load as a Tasks-page entry.
+    Model loading can't be paused or cancelled (it's a single blocking
+    ``from_pretrained`` / ``Llama(...)`` call), so the task is registered
+    non-cancellable and non-pausable — the Tasks UI shows it with no action
+    buttons. The task finishes ``done`` on success or ``error`` on exception.
+    Re-entrant guard: a nested load of the same model_key reuses no task; each
+    call is independent (loads don't nest in practice)."""
+    label = title or f"Loading {model}"
+    tid = task_registry.register(
+        "loading", title=label, model=model or "", status="running",
+        cancellable=False, restartable=False, pausable=False)
+    try:
+        yield tid
+        task_registry.finish(tid, "done")
+    except BaseException as e:  # noqa: BLE001 — record then re-raise
+        task_registry.finish(tid, "error", str(e)[:200] or e.__class__.__name__)
+        raise
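A usage sketch for the `loading_task` context manager added above. To keep the example runnable on its own, the real `task_registry` is replaced by a hypothetical `StubRegistry` that only records what `register` and `finish` were called with; the model names are made up.

```python
from contextlib import contextmanager


class StubRegistry:
    """Hypothetical stand-in for task_registry; it only records the calls."""

    def __init__(self):
        self.tasks = {}

    def register(self, kind, **fields):
        tid = f"task-{len(self.tasks)}"
        self.tasks[tid] = {"kind": kind, **fields}
        return tid

    def finish(self, tid, status, error=""):
        self.tasks[tid].update(status=status, error=error)


registry = StubRegistry()


@contextmanager
def loading_task(model, *, title=None):
    """Same shape as the helper in the diff, wired to the stub registry."""
    tid = registry.register(
        "loading", title=title or f"Loading {model}", model=model or "",
        status="running", cancellable=False, restartable=False, pausable=False)
    try:
        yield tid
        registry.finish(tid, "done")
    except BaseException as e:
        # Record the failure, then let the original exception propagate.
        registry.finish(tid, "error", str(e)[:200] or e.__class__.__name__)
        raise


# A successful load: the body stands in for the blocking from_pretrained call.
with loading_task("wan22-a14b") as ok_tid:
    pass

# A failing load: the task is marked as an error and the exception still escapes.
try:
    with loading_task("broken-model") as bad_tid:
        raise RuntimeError("weights missing")
except RuntimeError:
    pass
```

Because the error branch re-raises, wrapping an existing load call in `with loading_task(...)` should not change how that call's failures reach the caller.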
@@ -1709,10 +1709,21 @@ def _generate_keyframes(client: CoderAIClient, image_model: str, keyframe_dir: P
                         fight_plan: list, outcome_plan: list, consistency: set,
                         lora_map: dict, char_strength: float, keyframe_steps: int,
                         keyframe_size: str, lora_weight: float,
-                        env_lora_map: dict = None, env_lora_weight: float = 0.8):
+                        env_lora_map: dict = None, env_lora_weight: float = 0.8,
+                        kf_cb=None):
     """Generate one keyframe still per clip (image model). Saved as PNG keyed by
     the clip's output stem so the render phase can pick them up as init images.
-    Resumable: existing PNGs are kept."""
+    Resumable: existing PNGs are kept.
+    kf_cb(stem, phase, ok) — optional; fired so callers (the web match-render job)
+    can show per-image progress. phase is "start" (this keyframe begins) or "end"
+    (finished, ok=True/False); a reused/existing PNG fires "end" with ok=True."""
+    def _kf(stem, phase, ok=None):
+        if kf_cb:
+            try:
+                kf_cb(stem, phase, ok)
+            except Exception:
+                pass
     keyframe_dir.mkdir(parents=True, exist_ok=True)
     use_ip = "ipadapter" in consistency or "keyframe" in consistency
     use_lora = "lora" in consistency
@@ -1751,7 +1762,9 @@ def _generate_keyframes(client: CoderAIClient, image_model: str, keyframe_dir: P
         out_png = keyframe_dir / f"{stem}.png"
         if out_png.exists() and out_png.stat().st_size > 0:
             skipped += 1
+            _kf(stem, "end", True)  # already present — show it as done
             continue
+        _kf(stem, "start")
         profiles = list(fighters) if use_ip else None
         loras = None
         if use_lora:
@@ -1771,8 +1784,10 @@ def _generate_keyframes(client: CoderAIClient, image_model: str, keyframe_dir: P
             )
             out_png.write_bytes(img)
             made += 1
+            _kf(stem, "end", True)
         except Exception as e:
             failed += 1
+            _kf(stem, "end", False)
             _log(f" ✗ keyframe {stem} failed: {e}")
     _log(f" ── Keyframes: {made} new, {skipped} reused, {failed} failed ──")
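The callback contract in the hunks above ("start" when a keyframe begins, "end" with ok=True/False when it finishes, and "end" with ok=True for a reused PNG) can be exercised in isolation. This is a simplified sketch, not the script's code: `generate` is a hypothetical stand-in that takes sets of existing and failing stems instead of calling an image model, while the progress formula mirrors the one used by the web job in this commit.

```python
def generate(stems, existing, failing, cb=None):
    """Hypothetical stand-in for the keyframe loop; fires cb per image."""

    def notify(stem, phase, ok=None):
        if cb:
            try:
                cb(stem, phase, ok)
            except Exception:
                pass  # a misbehaving progress callback must not abort generation

    for stem in stems:
        if stem in existing:
            notify(stem, "end", True)  # reused image: reported as already done
            continue
        notify(stem, "start")
        notify(stem, "end", stem not in failing)


work = ["r1", "r2", "r3"]
events = []    # every (stem, phase, ok) the callback received
progress = []  # overall percentage after each finished image
done = [0]     # one-element list so the closure can mutate the counter


def on_keyframe(stem, phase, ok=None):
    events.append((stem, phase, ok))
    if phase == "end":
        done[0] += 1
        progress.append(10 + int(88 * done[0] / max(1, len(work))))


generate(work, existing={"r1"}, failing={"r2"}, cb=on_keyframe)
```

Each finished image moves the bar by one step, so a three-image job passes through intermediate percentages rather than jumping straight to the end.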
@@ -2968,13 +2983,30 @@ def launch_web_ui(default_args):
             _done("no missing keyframes — all present")
             return
         _set_items([f"keyframe {s}" for s in work])
-        for i, s in enumerate(work):
-            _item(i, "start")
-            if not missing_only:
+        _kf_idx = {s: i for i, s in enumerate(work)}
+        # Delete the targeted PNGs first so they're actually regenerated
+        # (missing-only keeps existing ones → reported done by the callback).
+        if not missing_only:
+            for s in work:
                 try:
                     (kdir / f"{s}.png").unlink()
                 except Exception:
                     pass
+        # Per-image progress: _generate_keyframes fires kf_cb as each
+        # keyframe starts/finishes, so the bars advance image-by-image
+        # instead of all flipping at the end.
+        _kf_done = [0]
+        def _kf_cb(stem, phase, ok=None):
+            i = _kf_idx.get(stem)
+            if i is None:
+                return
+            _item(i, phase, ok)
+            if phase == "end":
+                _kf_done[0] += 1
+                _prog(10 + int(88 * _kf_done[0] / max(1, len(work))),
+                      f"keyframe {_kf_done[0]}/{len(work)} done")
         _prog(10, ("filling in {n} missing keyframe(s)…" if missing_only
                    else "regenerating {n} keyframe(s)…").format(n=len(work)))
         try:
@@ -2984,11 +3016,13 @@ def launch_web_ui(default_args):
                 float(getattr(default_args, "character_strength", 0.7)),
                 int(getattr(default_args, "keyframe_steps", 28)),
                 getattr(default_args, "keyframe_size", "512x512"), lw,
-                env_lora_map=env_lora_map, env_lora_weight=elw)
+                env_lora_map=env_lora_map, env_lora_weight=elw,
+                kf_cb=_kf_cb)
         except Exception as e:
             _fail(f"keyframe regeneration failed: {e}")
             return
-        # Mark each item done/failed by whether its PNG now exists.
+        # Safety net: resolve any item the callback didn't (e.g. a stem
+        # _generate_keyframes never visited) by whether its PNG now exists.
         for i, s in enumerate(work):
             _item(i, "end", (kdir / f"{s}.png").exists())
         made = sum(1 for s in work if (kdir / f"{s}.png").exists())