Fix OOM in I2V mode: sequential model loading

- Defer I2V model loading when in I2V mode without provided image
- Generate image first with T2I model
- Unload T2I model completely (del, empty_cache, gc.collect)
- Then load I2V model and generate video
- This ensures only one model is in memory at a time
- Fixes Linux OOM killer issue when loading multiple models
parent 1c242c7e
Pipeline #236 canceled with stages
......@@ -3684,10 +3684,25 @@ def main(args):
print(f"Custom Wan VAE load failed: {e}")
timing.start()
timing.begin_step("model_loading")
debug = getattr(args, 'debug', False)
# ─── DEFER MODEL LOADING FOR I2V MODE ─────────────────────────────────────────
# For I2V mode without --image, we need to generate the image first.
# To avoid OOM, we should NOT load the I2V model until after the image is generated.
# We'll set pipe = None and load it later after image generation.
defer_i2v_loading = False
if (args.image_to_video or args.image) and not args.image and m_info.get("supports_i2v"):
# I2V mode without provided image - need to generate image first
defer_i2v_loading = True
print(f"\n⏳ Deferring I2V model loading until after image generation")
print(f" (To avoid OOM, image model will be loaded and unloaded first)")
pipe = None
pipeline_loaded_successfully = True # Skip the loading block below
else:
timing.begin_step("model_loading")
# Initialize flag for pipeline mismatch fallback
pipeline_loaded_successfully = False
......@@ -4146,6 +4161,10 @@ def main(args):
print(f"✨ Done! Seed: {seed}")
return
# ─── I2V Mode: Generate image FIRST, then load video model ─────────────────────
# IMPORTANT: To avoid OOM, we generate the image first, then unload the image model
# before loading the video model. This ensures only one model is in memory at a time.
if args.image_to_video or args.image:
if not m_info.get("supports_i2v"):
print(f"Error: {args.model} does not support image-to-video.")
......@@ -4169,7 +4188,8 @@ def main(args):
print(f"❌ Failed to load image: {e}")
sys.exit(1)
else:
# Generate image using image_model
# Generate image using image_model FIRST (before loading I2V model)
# This is critical to avoid OOM - we load T2I, generate, unload, then load I2V
timing.begin_step("image_generation")
img_info = MODELS[args.image_model]
......@@ -4222,6 +4242,7 @@ def main(args):
img_kwargs["low_cpu_mem_usage"] = True
try:
print(f"\n🖼️ Loading image model for I2V: {args.image_model}")
img_pipe = ImgCls.from_pretrained(img_model_id_to_load, **img_kwargs)
# Apply LoRA if image model is a LoRA adapter
......@@ -4237,6 +4258,7 @@ def main(args):
img_pipe.enable_model_cpu_offload()
img_prompt = ", ".join(args.prompt_image) if args.prompt_image else main_prompt
print(f" Generating initial image...")
with torch.no_grad():
init_image = img_pipe(
img_prompt,
......@@ -4247,13 +4269,92 @@ def main(args):
if is_main:
init_image.save(f"{args.output}_init.png")
print(f" Saved initial image: {args.output}_init.png")
print(f" Saved initial image: {args.output}_init.png")
timing.end_step() # image_generation
# ─── CRITICAL: Unload image model to free memory ───────────────────
print(f"\n🗑️ Unloading image model to free memory...")
del img_pipe
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
import gc
gc.collect()
print(f" ✅ Image model unloaded, memory freed")
log_memory()
except Exception as e:
print(f"Image generation failed: {e}")
sys.exit(1)
# ─── Now load the I2V model (after image model is unloaded) ─────────────
timing.begin_step("i2v_model_loading")
print(f"\n📹 Loading I2V model: {args.model}")
# Reload the I2V pipeline
try:
pipe = PipelineClass.from_pretrained(model_id_to_load, **pipe_kwargs)
# Apply LoRA if this is a LoRA model
if is_lora and lora_id:
print(f" Loading LoRA adapter: {lora_id}")
try:
pipe.load_lora_weights(lora_id)
print(f" ✅ LoRA applied successfully")
except Exception as lora_e:
print(f" ⚠️ LoRA loading failed: {lora_e}")
print(f" Continuing with base model...")
if args.no_filter and hasattr(pipe, "safety_checker"):
pipe.safety_checker = None
# Re-apply offloading strategy
if off == "auto_map":
pipe.enable_model_cpu_offload()
elif off == "sequential":
pipe.enable_sequential_cpu_offload()
elif off == "group":
try:
pipe.enable_group_offload(group_size=args.offload_group_size)
except:
print("Group offload unavailable → model offload fallback")
pipe.enable_model_cpu_offload()
elif off == "model":
pipe.enable_model_cpu_offload()
else:
pipe.to("cuda" if torch.cuda.is_available() else "cpu")
pipe.enable_attention_slicing("max")
try:
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
except:
pass
if torch.cuda.is_available():
try:
pipe.enable_xformers_memory_efficient_attention()
except:
pass
if "wan" in args.model and hasattr(pipe, "scheduler"):
try:
pipe.scheduler = UniPCMultistepScheduler.from_config(
pipe.scheduler.config,
prediction_type="flow_prediction",
flow_shift=extra.get("flow_shift", 3.0)
)
except:
pass
print(f" ✅ I2V model loaded successfully")
timing.end_step() # i2v_model_loading
except Exception as e:
print(f"❌ Failed to load I2V model: {e}")
sys.exit(1)
# ─── Audio Generation (Pre-video) ──────────────────────────────────────────
audio_path = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment