Wan LoRA: cache the whole stack (VAE + UMT5 + transformer) across jobs

Extend the cross-job cache from just the transformer expert(s) to the full Wan
stack: VAE, tokenizer and text encoder are kept on CPU between jobs (moved to GPU
only while encoding), experts stay on GPU. A back-to-back training against the
same base now reloads nothing from disk — previously the small VAE/text-encoder
still reloaded each job. The releaser and error path clear all cached components.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent fc78dc98
......@@ -395,7 +395,11 @@ def _release_base_cache(needed_gb: float = 0.0) -> float:
# between jobs — the external releaser above drops them the moment a generation
# needs the GPU.
_wan_lock = threading.RLock()
_wan_cache = {"path": None, "quantize": None, "experts": None, "boundary": None}
# Caches the whole Wan stack between same-base jobs: transformer expert(s) on the
# GPU, plus the VAE / tokenizer / text-encoder held on CPU (moved to the GPU only
# while encoding). So a back-to-back training reloads nothing from disk.
_wan_cache = {"path": None, "quantize": None, "experts": None, "boundary": None,
"vae": None, "tokenizer": None, "text_encoder": None}
def _drop_wan_cache_locked() -> None:
......@@ -406,7 +410,8 @@ def _drop_wan_cache_locked() -> None:
tr.delete_adapters("default")
except Exception:
pass
_wan_cache.update(path=None, quantize=None, experts=None, boundary=None)
_wan_cache.update(path=None, quantize=None, experts=None, boundary=None,
vae=None, tokenizer=None, text_encoder=None)
def _drop_wan_cache() -> None:
......@@ -417,21 +422,22 @@ def _drop_wan_cache() -> None:
_free_train_vram()
def _acquire_wan_experts(base_path, quantize, load_fn):
"""Return cached Wan (experts, boundary) for base_path, loading via load_fn()
on a miss. A change of base_path/quantize drops the previous cache first."""
def _acquire_wan_components(base_path, quantize, load_fn):
"""Return the cached Wan stack {vae, tokenizer, text_encoder, experts,
boundary} for base_path, building it via load_fn() on a miss. A change of
base_path/quantize drops the previous cache first."""
with _wan_lock:
c = _wan_cache
if (c["experts"] is not None and c["path"] == base_path
and c["quantize"] == quantize):
_dbg_lora(f"reusing cached Wan transformer(s): {base_path}")
return c["experts"], c["boundary"]
if c["experts"] is not None:
_dbg_lora(f"reusing cached Wan stack (no reload): {base_path}")
return c
if c["experts"] is not None or c["vae"] is not None:
_dbg_lora(f"Wan base changed ({c['path']} → {base_path}); dropping cache")
_drop_wan_cache_locked()
experts, boundary = load_fn()
c.update(path=base_path, quantize=quantize, experts=experts, boundary=boundary)
return experts, boundary
comps = load_fn() # {vae, tokenizer, text_encoder, experts, boundary}
c.update(path=base_path, quantize=quantize, **comps)
return c
# Let the model manager reclaim these caches when a generation needs VRAM.
......@@ -887,12 +893,59 @@ def _train_wan(req, base_path, images, instance_prompt,
quantize = bool(getattr(req, "quantize_4bit", True))
num_frames = max(1, int(getattr(req, "num_frames", 1) or 1))
# ── 1. VAE (3D): encode each still as a 1-frame video latent, then offload ──
_set_progress(status="preparing", message=f"loading Wan VAE: {base_path}")
vae = AutoencoderKLWan.from_pretrained(base_path, subfolder="vae",
torch_dtype=torch.float32).to(device)
vae.requires_grad_(False)
vae.eval()
q_cfg = None
if quantize:
try:
from diffusers import BitsAndBytesConfig as _DiffBnb
q_cfg = _DiffBnb(load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=compute_dtype)
except Exception as e:
print(f" [lora][wan] 4-bit unavailable ({e}); loading in bf16")
q_cfg = None
def _load_transformer(subfolder):
kw = dict(subfolder=subfolder, torch_dtype=compute_dtype)
if q_cfg is not None:
kw["quantization_config"] = q_cfg
tr = WanTransformer3DModel.from_pretrained(base_path, **kw)
if q_cfg is None:
tr = tr.to(device)
return tr
def _build_components():
# Built once per base and cached: VAE/tokenizer/text-encoder live on CPU
# between jobs (moved to GPU only while encoding); experts stay on GPU.
_set_progress(status="preparing", message=f"loading Wan VAE: {base_path}")
_vae = AutoencoderKLWan.from_pretrained(base_path, subfolder="vae",
torch_dtype=torch.float32)
_vae.requires_grad_(False)
_vae.eval()
_set_progress(status="preparing", message="loading UMT5 text encoder")
_tok = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer")
_te = UMT5EncoderModel.from_pretrained(
base_path, subfolder="text_encoder", torch_dtype=compute_dtype)
_te.requires_grad_(False)
_te.eval()
_set_progress(
status="preparing",
message=f"loading Wan transformer{' (4-bit)' if quantize else ''}")
exp = [("transformer", _load_transformer("transformer"))]
if _os.path.isdir(_os.path.join(base_path, "transformer_2")):
exp.append(("transformer_2", _load_transformer("transformer_2")))
b = getattr(exp[0][1].config, "boundary_ratio", None)
return {"vae": _vae, "tokenizer": _tok, "text_encoder": _te,
"experts": exp, "boundary": b}
# Reuse the whole Wan stack across consecutive trainings against the same base.
comps = _acquire_wan_components(base_path, quantize, _build_components)
vae = comps["vae"]
tokenizer = comps["tokenizer"]
text_encoder = comps["text_encoder"]
experts, boundary = comps["experts"], comps["boundary"]
# ── 1. VAE (3D): encode each still as a 1-frame video latent ───────────────
_set_progress(status="preparing", message="encoding reference images (VAE)")
vae.to(device)
z_dim = int(vae.config.z_dim)
lat_mean = torch.tensor(vae.config.latents_mean).view(1, z_dim, 1, 1, 1).to(device)
lat_std = torch.tensor(vae.config.latents_std).view(1, z_dim, 1, 1, 1).to(device)
......@@ -914,17 +967,12 @@ def _train_wan(req, base_path, images, instance_prompt,
lat = vae.encode(vid).latent_dist.sample() # [1,z,t,h,w]
lat = (lat - lat_mean) / lat_std
latents_list.append(lat.to(compute_dtype).cpu())
vae.to("cpu")
del vae
vae.to("cpu") # keep cached, free its VRAM
_free_train_vram()
# ── 2. Text encoder (UMT5): encode the instance prompt once, then offload ──
# ── 2. Text encoder (UMT5): encode the instance prompt once ────────────────
_set_progress(status="preparing", message="encoding prompt (UMT5)")
tokenizer = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer")
text_encoder = UMT5EncoderModel.from_pretrained(
base_path, subfolder="text_encoder", torch_dtype=compute_dtype).to(device)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder.to(device)
with torch.no_grad():
tok = tokenizer(instance_prompt, padding="max_length", max_length=512,
truncation=True, return_tensors="pt")
......@@ -932,42 +980,10 @@ def _train_wan(req, base_path, images, instance_prompt,
mask = tok.attention_mask.to(device)
enc = text_encoder(ids, attention_mask=mask).last_hidden_state
encoder_hidden_states = (enc * mask.unsqueeze(-1)).to(compute_dtype).cpu()
text_encoder.to("cpu")
del text_encoder
text_encoder.to("cpu") # keep cached, free its VRAM
_free_train_vram()
# ── 3. Transformer expert(s) + LoRA ───────────────────────────────────────
_set_progress(status="preparing",
message=f"loading Wan transformer{' (4-bit)' if quantize else ''}")
q_cfg = None
if quantize:
try:
from diffusers import BitsAndBytesConfig as _DiffBnb
q_cfg = _DiffBnb(load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=compute_dtype)
except Exception as e:
print(f" [lora][wan] 4-bit unavailable ({e}); loading in bf16")
q_cfg = None
def _load_transformer(subfolder):
kw = dict(subfolder=subfolder, torch_dtype=compute_dtype)
if q_cfg is not None:
kw["quantization_config"] = q_cfg
tr = WanTransformer3DModel.from_pretrained(base_path, **kw)
if q_cfg is None:
tr = tr.to(device)
return tr
def _load_experts():
exp = [("transformer", _load_transformer("transformer"))]
if _os.path.isdir(_os.path.join(base_path, "transformer_2")):
exp.append(("transformer_2", _load_transformer("transformer_2")))
b = getattr(exp[0][1].config, "boundary_ratio", None)
return exp, b
# Reuse the transformer(s) across consecutive trainings against the same base.
experts, boundary = _acquire_wan_experts(base_path, quantize, _load_experts)
lora_cfg = PeftLoraConfig(
r=rank, lora_alpha=rank, init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment