Fix OOM in I2V mode: sequential model loading

- Defer I2V model loading when in I2V mode without a provided image
- Generate image first with T2I model
- Unload T2I model completely (del, empty_cache, gc.collect)
- Then load I2V model and generate video
- This ensures only one model is in memory at a time
- Fixes Linux OOM killer issue when loading multiple models
parent 1c242c7e
Pipeline #236 canceled with stages
...@@ -3684,242 +3684,257 @@ def main(args): ...@@ -3684,242 +3684,257 @@ def main(args):
print(f"Custom Wan VAE load failed: {e}") print(f"Custom Wan VAE load failed: {e}")
timing.start() timing.start()
timing.begin_step("model_loading")
debug = getattr(args, 'debug', False) debug = getattr(args, 'debug', False)
# Initialize flag for pipeline mismatch fallback # ─── DEFER MODEL LOADING FOR I2V MODE ─────────────────────────────────────────
pipeline_loaded_successfully = False # For I2V mode without --image, we need to generate the image first.
# To avoid OOM, we should NOT load the I2V model until after the image is generated.
if debug: # We'll set pipe = None and load it later after image generation.
print(f"\n🔍 [DEBUG] Model Loading Details:")
print(f" [DEBUG] Model ID to load: {model_id_to_load}") defer_i2v_loading = False
print(f" [DEBUG] Pipeline class: {m_info['class']}") if (args.image_to_video or args.image) and not args.image and m_info.get("supports_i2v"):
print(f" [DEBUG] Is LoRA: {is_lora}") # I2V mode without provided image - need to generate image first
if is_lora: defer_i2v_loading = True
print(f" [DEBUG] LoRA ID: {lora_id}") print(f"\n⏳ Deferring I2V model loading until after image generation")
print(f" [DEBUG] Pipeline kwargs:") print(f" (To avoid OOM, image model will be loaded and unloaded first)")
for k, v in pipe_kwargs.items(): pipe = None
if k == "max_memory": pipeline_loaded_successfully = True # Skip the loading block below
print(f" {k}: {v}") else:
elif k == "device_map": timing.begin_step("model_loading")
print(f" {k}: {v}")
else: # Initialize flag for pipeline mismatch fallback
print(f" {k}: {v}") pipeline_loaded_successfully = False
print(f" [DEBUG] HF Token: {'***' + os.environ.get('HF_TOKEN', '')[-4:] if os.environ.get('HF_TOKEN') else 'Not set'}")
print(f" [DEBUG] Cache dir: {os.environ.get('HF_HOME', 'default')}")
print()
try:
pipe = PipelineClass.from_pretrained(model_id_to_load, **pipe_kwargs)
pipeline_loaded_successfully = True
except Exception as e:
error_str = str(e)
if debug: if debug:
print(f"\n🔍 [DEBUG] Error Details:") print(f"\n🔍 [DEBUG] Model Loading Details:")
print(f" [DEBUG] Exception type: {type(e).__name__}") print(f" [DEBUG] Model ID to load: {model_id_to_load}")
print(f" [DEBUG] Error message: {error_str}") print(f" [DEBUG] Pipeline class: {m_info['class']}")
if hasattr(e, 'response'): print(f" [DEBUG] Is LoRA: {is_lora}")
print(f" [DEBUG] Response: {e.response}") if is_lora:
print(f" [DEBUG] LoRA ID: {lora_id}")
print(f" [DEBUG] Pipeline kwargs:")
for k, v in pipe_kwargs.items():
if k == "max_memory":
print(f" {k}: {v}")
elif k == "device_map":
print(f" {k}: {v}")
else:
print(f" {k}: {v}")
print(f" [DEBUG] HF Token: {'***' + os.environ.get('HF_TOKEN', '')[-4:] if os.environ.get('HF_TOKEN') else 'Not set'}")
print(f" [DEBUG] Cache dir: {os.environ.get('HF_HOME', 'default')}")
print() print()
# Check if this is a pipeline component mismatch error try:
# This happens when the model_index.json has the wrong _class_name pipe = PipelineClass.from_pretrained(model_id_to_load, **pipe_kwargs)
is_component_mismatch = "expected" in error_str and "but only" in error_str and "were passed" in error_str pipeline_loaded_successfully = True
except Exception as e:
if is_component_mismatch: error_str = str(e)
# Try to re-detect the correct pipeline class from model ID pattern
detected_class = None
model_id_lower = model_id_to_load.lower()
# Force detection based on model ID patterns (most reliable for misconfigured models) if debug:
if "wan2.1" in model_id_lower or "wan2.2" in model_id_lower or "wan2" in model_id_lower: print(f"\n🔍 [DEBUG] Error Details:")
detected_class = "WanPipeline" print(f" [DEBUG] Exception type: {type(e).__name__}")
elif "svd" in model_id_lower or "stable-video-diffusion" in model_id_lower: print(f" [DEBUG] Error message: {error_str}")
detected_class = "StableVideoDiffusionPipeline" if hasattr(e, 'response'):
elif "ltx" in model_id_lower: print(f" [DEBUG] Response: {e.response}")
detected_class = "LTXVideoPipeline" print()
elif "mochi" in model_id_lower:
detected_class = "MochiPipeline" # Check if this is a pipeline component mismatch error
elif "cogvideo" in model_id_lower: # This happens when the model_index.json has the wrong _class_name
detected_class = "CogVideoXPipeline" is_component_mismatch = "expected" in error_str and "but only" in error_str and "were passed" in error_str
elif "flux" in model_id_lower:
detected_class = "FluxPipeline"
if detected_class and detected_class != m_info["class"]: if is_component_mismatch:
print(f"\n⚠️ Pipeline component mismatch detected!") # Try to re-detect the correct pipeline class from model ID pattern
print(f" Configured class: {m_info['class']}") detected_class = None
print(f" Detected class: {detected_class}") model_id_lower = model_id_to_load.lower()
print(f" The model's model_index.json may have an incorrect _class_name.")
print(f" Retrying with detected pipeline class: {detected_class}")
# Get the correct pipeline class # Force detection based on model ID patterns (most reliable for misconfigured models)
CorrectPipelineClass = get_pipeline_class(detected_class) if "wan2.1" in model_id_lower or "wan2.2" in model_id_lower or "wan2" in model_id_lower:
if CorrectPipelineClass: detected_class = "WanPipeline"
try: elif "svd" in model_id_lower or "stable-video-diffusion" in model_id_lower:
pipe = CorrectPipelineClass.from_pretrained(model_id_to_load, **pipe_kwargs) detected_class = "StableVideoDiffusionPipeline"
# Success! Update the model info for future runs elif "ltx" in model_id_lower:
print(f" ✅ Successfully loaded with {detected_class}") detected_class = "LTXVideoPipeline"
# Update PipelineClass for the rest of the code elif "mochi" in model_id_lower:
PipelineClass = CorrectPipelineClass detected_class = "MochiPipeline"
# Mark as successfully loaded elif "cogvideo" in model_id_lower:
pipeline_loaded_successfully = True detected_class = "CogVideoXPipeline"
except Exception as retry_e: elif "flux" in model_id_lower:
print(f" ❌ Retry with {detected_class} also failed: {retry_e}") detected_class = "FluxPipeline"
# Continue with normal error handling
is_component_mismatch = False # Don't retry again below
error_str = str(retry_e)
# If we successfully loaded with the corrected pipeline, skip error handling
if not pipeline_loaded_successfully:
# Check if we should retry with an alternative model (auto mode)
# This applies to ALL error types - we try alternatives before giving up
if getattr(args, '_auto_mode', False):
retry_count = getattr(args, '_retry_count', 0)
max_retries = getattr(args, '_max_retries', 3)
alternative_models = getattr(args, '_auto_alternative_models', [])
failed_base_models = getattr(args, '_failed_base_models', set())
user_specified_model = getattr(args, '_user_specified_model', False)
# If user explicitly specified the model, don't retry with alternatives if detected_class and detected_class != m_info["class"]:
# The user's model choice should be preserved print(f"\n⚠️ Pipeline component mismatch detected!")
if user_specified_model: print(f" Configured class: {m_info['class']}")
print(f"\n⚠️ User-specified model failed: {model_id_to_load}") print(f" Detected class: {detected_class}")
print(f" The model was explicitly provided with --model, not retrying with alternatives.") print(f" The model's model_index.json may have an incorrect _class_name.")
print(f" Please verify the model exists or try a different model.") print(f" Retrying with detected pipeline class: {detected_class}")
else:
# Record the failure for auto-disable tracking
record_model_failure(args.model, model_id_to_load)
# If this was a LoRA with a base model, track the failed base model
if is_lora and base_model_id:
failed_base_models.add(base_model_id)
args._failed_base_models = failed_base_models
print(f" ⚠️ Base model failed: {base_model_id}")
print(f" Will skip other LoRAs depending on this base model")
# Find next valid alternative (skip LoRAs with failed base models AND disabled models) # Get the correct pipeline class
next_model = None CorrectPipelineClass = get_pipeline_class(detected_class)
skipped_loras = [] if CorrectPipelineClass:
skipped_disabled = [] try:
pipe = CorrectPipelineClass.from_pretrained(model_id_to_load, **pipe_kwargs)
# Success! Update the model info for future runs
print(f" Successfully loaded with {detected_class}")
# Update PipelineClass for the rest of the code
PipelineClass = CorrectPipelineClass
# Mark as successfully loaded
pipeline_loaded_successfully = True
except Exception as retry_e:
print(f" Retry with {detected_class} also failed: {retry_e}")
# Continue with normal error handling
is_component_mismatch = False # Don't retry again below
error_str = str(retry_e)
# If we successfully loaded with the corrected pipeline, skip error handling
if not pipeline_loaded_successfully:
# Check if we should retry with an alternative model (auto mode)
# This applies to ALL error types - we try alternatives before giving up
if getattr(args, '_auto_mode', False):
retry_count = getattr(args, '_retry_count', 0)
max_retries = getattr(args, '_max_retries', 3)
alternative_models = getattr(args, '_auto_alternative_models', [])
failed_base_models = getattr(args, '_failed_base_models', set())
user_specified_model = getattr(args, '_user_specified_model', False)
while alternative_models: # If user explicitly specified the model, don't retry with alternatives
candidate_name, candidate_info, candidate_reason = alternative_models.pop(0) # The user's model choice should be preserved
if user_specified_model:
print(f"\n⚠️ User-specified model failed: {model_id_to_load}")
print(f" The model was explicitly provided with --model, not retrying with alternatives.")
print(f" Please verify the model exists or try a different model.")
else:
# Record the failure for auto-disable tracking
record_model_failure(args.model, model_id_to_load)
# Check if this model is disabled for auto mode # If this was a LoRA with a base model, track the failed base model
candidate_id = candidate_info.get("id", "") if is_lora and base_model_id:
if is_model_disabled(candidate_id, candidate_name): failed_base_models.add(base_model_id)
skipped_disabled.append((candidate_name, candidate_id)) args._failed_base_models = failed_base_models
continue # Skip disabled model print(f" ⚠️ Base model failed: {base_model_id}")
print(f" Will skip other LoRAs depending on this base model")
# Check if this is a LoRA with a failed base model # Find next valid alternative (skip LoRAs with failed base models AND disabled models)
if candidate_info.get("is_lora", False): next_model = None
candidate_base = candidate_info.get("base_model") or candidate_info.get("_inferred_base_model") skipped_loras = []
if candidate_base and candidate_base in failed_base_models: skipped_disabled = []
skipped_loras.append((candidate_name, candidate_base))
continue # Skip this LoRA
# Found a valid candidate while alternative_models:
next_model = (candidate_name, candidate_info, candidate_reason) candidate_name, candidate_info, candidate_reason = alternative_models.pop(0)
break
# Check if this model is disabled for auto mode
# Update the alternatives list candidate_id = candidate_info.get("id", "")
args._auto_alternative_models = alternative_models if is_model_disabled(candidate_id, candidate_name):
skipped_disabled.append((candidate_name, candidate_id))
if skipped_loras: continue # Skip disabled model
print(f" ⏭️ Skipped {len(skipped_loras)} LoRA(s) with failed base models")
# Check if this is a LoRA with a failed base model
if skipped_disabled: if candidate_info.get("is_lora", False):
print(f" ⏭️ Skipped {len(skipped_disabled)} auto-disabled model(s)") candidate_base = candidate_info.get("base_model") or candidate_info.get("_inferred_base_model")
if candidate_base and candidate_base in failed_base_models:
if retry_count < max_retries and next_model: skipped_loras.append((candidate_name, candidate_base))
# We have a valid alternative - retry with it continue # Skip this LoRA
args._retry_count = retry_count + 1
next_model_name, next_model_info, next_reason = next_model # Found a valid candidate
next_model = (candidate_name, candidate_info, candidate_reason)
break
# Print appropriate error message based on error type # Update the alternatives list
if "404" in error_str or "Entry Not Found" in error_str or "Repository Not Found" in error_str: args._auto_alternative_models = alternative_models
print(f"❌ Model not found on HuggingFace: {model_id_to_load}")
elif "401" in error_str or "Unauthorized" in error_str:
print(f"❌ Model requires authentication: {model_id_to_load}")
elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str:
print(f"❌ Pipeline compatibility error: {model_id_to_load}")
print(f" This model uses an incompatible pipeline architecture.")
else:
print(f"❌ Model loading failed: {model_id_to_load}")
print(f" Error: {error_str[:100]}...")
print(f"\n🔄 Retrying with alternative model ({args._retry_count}/{max_retries})...") if skipped_loras:
print(f" New model: {next_model_name}") print(f" ⏭️ Skipped {len(skipped_loras)} LoRA(s) with failed base models")
print(f" {next_reason}")
# Update args with new model and recurse if skipped_disabled:
args.model = next_model_name print(f" ⏭️ Skipped {len(skipped_disabled)} auto-disabled model(s)")
# Clean up any partial model loading
if torch.cuda.is_available(): if retry_count < max_retries and next_model:
torch.cuda.empty_cache() # We have a valid alternative - retry with it
# Retry main() with the new model args._retry_count = retry_count + 1
return main(args) next_model_name, next_model_info, next_reason = next_model
# No more valid alternatives or retries exhausted # Print appropriate error message based on error type
print(f"\n❌ All model retries exhausted ({retry_count}/{max_retries} attempts)") if "404" in error_str or "Entry Not Found" in error_str or "Repository Not Found" in error_str:
print(f" Model not found on HuggingFace: {model_id_to_load}")
# Print detailed error message for the user elif "401" in error_str or "Unauthorized" in error_str:
if "404" in error_str or "Entry Not Found" in error_str: print(f" Model requires authentication: {model_id_to_load}")
print(f"❌ Model not found on HuggingFace: {model_id_to_load}") elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str:
print(f" This model may have been removed or the ID is incorrect.") print(f" Pipeline compatibility error: {model_id_to_load}")
if debug: print(f" This model uses an incompatible pipeline architecture.")
print(f"\n [DEBUG] Troubleshooting:") else:
print(f" - Check if the model exists: https://huggingface.co/{model_id_to_load}") print(f" Model loading failed: {model_id_to_load}")
print(f" - Verify the model ID spelling") print(f" Error: {error_str[:100]}...")
print(f" - The model may have been renamed or moved")
print(f"\n 💡 Try searching for an alternative:") print(f"\n🔄 Retrying with alternative model ({args._retry_count}/{max_retries})...")
print(f" videogen --search-models ltxvideo") print(f" New model: {next_model_name}")
print(f"\n 💡 Or use the official LTX Video model:") print(f" {next_reason}")
print(f" videogen --model ltx_video --prompt 'your prompt' ...")
elif "401" in error_str or "Unauthorized" in error_str: # Update args with new model and recurse
print(f"❌ Model requires authentication: {model_id_to_load}") args.model = next_model_name
print(f" Set your HuggingFace token:") # Clean up any partial model loading
print(f" export HF_TOKEN=your_token_here") if torch.cuda.is_available():
print(f" huggingface-cli login") torch.cuda.empty_cache()
if debug: # Retry main() with the new model
print(f"\n [DEBUG] To get a token:") return main(args)
print(f" 1. Go to https://huggingface.co/settings/tokens")
print(f" 2. Create a new token with 'read' permissions") # No more valid alternatives or retries exhausted
print(f" 3. Export it: export HF_TOKEN=hf_xxx") print(f"\n All model retries exhausted ({retry_count}/{max_retries} attempts)")
elif "gated" in error_str.lower():
print(f"❌ This is a gated model: {model_id_to_load}") # Print detailed error message for the user
print(f" You need to accept the license on HuggingFace:") if "404" in error_str or "Entry Not Found" in error_str:
print(f" https://huggingface.co/{model_id_to_load}") print(f" Model not found on HuggingFace: {model_id_to_load}")
print(f" Then set HF_TOKEN and run again.") print(f" This model may have been removed or the ID is incorrect.")
elif "connection" in error_str.lower() or "timeout" in error_str.lower(): if debug:
print(f"❌ Network error loading model: {model_id_to_load}") print(f"\n [DEBUG] Troubleshooting:")
print(f" Check your internet connection and try again.") print(f" - Check if the model exists: https://huggingface.co/{model_id_to_load}")
if debug: print(f" - Verify the model ID spelling")
print(f"\n [DEBUG] Network troubleshooting:") print(f" - The model may have been renamed or moved")
print(f" - Check if you can access: https://huggingface.co/{model_id_to_load}") print(f"\n 💡 Try searching for an alternative:")
print(f" - Try with a VPN if HuggingFace is blocked") print(f" videogen --search-models ltxvideo")
print(f" - Check if HF_ENDPOINT is set (for China mirror): {os.environ.get('HF_ENDPOINT', 'not set')}") print(f"\n 💡 Or use the official LTX Video model:")
elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str: print(f" videogen --model ltx_video --prompt 'your prompt' ...")
print(f"❌ Pipeline compatibility error: {model_id_to_load}") elif "401" in error_str or "Unauthorized" in error_str:
print(f" This model uses a pipeline architecture incompatible with your diffusers version.") print(f" Model requires authentication: {model_id_to_load}")
print(f" The model may require a specific diffusers version or different pipeline class.") print(f" Set your HuggingFace token:")
if debug: print(f" export HF_TOKEN=your_token_here")
print(f"\n [DEBUG] Compatibility troubleshooting:") print(f" huggingface-cli login")
print(f" - Try updating diffusers: pip install --upgrade git+https://github.com/huggingface/diffusers.git") if debug:
print(f" - Check the model's documentation for required versions") print(f"\n [DEBUG] To get a token:")
print(f" - The model may be incorrectly configured in models.json") print(f" 1. Go to https://huggingface.co/settings/tokens")
print(f"\n 💡 Try a different model with --model <name>") print(f" 2. Create a new token with 'read' permissions")
else: print(f" 3. Export it: export HF_TOKEN=hf_xxx")
print(f"Model loading failed: {e}") elif "gated" in error_str.lower():
if debug: print(f" This is a gated model: {model_id_to_load}")
import traceback print(f" You need to accept the license on HuggingFace:")
print(f"\n [DEBUG] Full traceback:") print(f" https://huggingface.co/{model_id_to_load}")
traceback.print_exc() print(f" Then set HF_TOKEN and run again.")
elif "connection" in error_str.lower() or "timeout" in error_str.lower():
print(f"\n 💡 Try searching for alternative models: videogen --search-models <query>") print(f" Network error loading model: {model_id_to_load}")
sys.exit(1) print(f" Check your internet connection and try again.")
if debug:
print(f"\n [DEBUG] Network troubleshooting:")
print(f" - Check if you can access: https://huggingface.co/{model_id_to_load}")
print(f" - Try with a VPN if HuggingFace is blocked")
print(f" - Check if HF_ENDPOINT is set (for China mirror): {os.environ.get('HF_ENDPOINT', 'not set')}")
elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str:
print(f" Pipeline compatibility error: {model_id_to_load}")
print(f" This model uses a pipeline architecture incompatible with your diffusers version.")
print(f" The model may require a specific diffusers version or different pipeline class.")
if debug:
print(f"\n [DEBUG] Compatibility troubleshooting:")
print(f" - Try updating diffusers: pip install --upgrade git+https://github.com/huggingface/diffusers.git")
print(f" - Check the model's documentation for required versions")
print(f" - The model may be incorrectly configured in models.json")
print(f"\n 💡 Try a different model with --model <name>")
else:
print(f"Model loading failed: {e}")
if debug:
import traceback
print(f"\n [DEBUG] Full traceback:")
traceback.print_exc()
print(f"\n 💡 Try searching for alternative models: videogen --search-models <query>")
sys.exit(1)
timing.end_step() # model_loading timing.end_step() # model_loading
...@@ -4146,6 +4161,10 @@ def main(args): ...@@ -4146,6 +4161,10 @@ def main(args):
print(f"✨ Done! Seed: {seed}") print(f"✨ Done! Seed: {seed}")
return return
# ─── I2V Mode: Generate image FIRST, then load video model ─────────────────────
# IMPORTANT: To avoid OOM, we generate the image first, then unload the image model
# before loading the video model. This ensures only one model is in memory at a time.
if args.image_to_video or args.image: if args.image_to_video or args.image:
if not m_info.get("supports_i2v"): if not m_info.get("supports_i2v"):
print(f"Error: {args.model} does not support image-to-video.") print(f"Error: {args.model} does not support image-to-video.")
...@@ -4169,7 +4188,8 @@ def main(args): ...@@ -4169,7 +4188,8 @@ def main(args):
print(f"❌ Failed to load image: {e}") print(f"❌ Failed to load image: {e}")
sys.exit(1) sys.exit(1)
else: else:
# Generate image using image_model # Generate image using image_model FIRST (before loading I2V model)
# This is critical to avoid OOM - we load T2I, generate, unload, then load I2V
timing.begin_step("image_generation") timing.begin_step("image_generation")
img_info = MODELS[args.image_model] img_info = MODELS[args.image_model]
...@@ -4222,6 +4242,7 @@ def main(args): ...@@ -4222,6 +4242,7 @@ def main(args):
img_kwargs["low_cpu_mem_usage"] = True img_kwargs["low_cpu_mem_usage"] = True
try: try:
print(f"\n🖼️ Loading image model for I2V: {args.image_model}")
img_pipe = ImgCls.from_pretrained(img_model_id_to_load, **img_kwargs) img_pipe = ImgCls.from_pretrained(img_model_id_to_load, **img_kwargs)
# Apply LoRA if image model is a LoRA adapter # Apply LoRA if image model is a LoRA adapter
...@@ -4237,6 +4258,7 @@ def main(args): ...@@ -4237,6 +4258,7 @@ def main(args):
img_pipe.enable_model_cpu_offload() img_pipe.enable_model_cpu_offload()
img_prompt = ", ".join(args.prompt_image) if args.prompt_image else main_prompt img_prompt = ", ".join(args.prompt_image) if args.prompt_image else main_prompt
print(f" Generating initial image...")
with torch.no_grad(): with torch.no_grad():
init_image = img_pipe( init_image = img_pipe(
img_prompt, img_prompt,
...@@ -4247,12 +4269,91 @@ def main(args): ...@@ -4247,12 +4269,91 @@ def main(args):
if is_main: if is_main:
init_image.save(f"{args.output}_init.png") init_image.save(f"{args.output}_init.png")
print(f" Saved initial image: {args.output}_init.png") print(f" Saved initial image: {args.output}_init.png")
timing.end_step() # image_generation timing.end_step() # image_generation
# ─── CRITICAL: Unload image model to free memory ───────────────────
print(f"\n🗑️ Unloading image model to free memory...")
del img_pipe
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
import gc
gc.collect()
print(f" ✅ Image model unloaded, memory freed")
log_memory()
except Exception as e: except Exception as e:
print(f"Image generation failed: {e}") print(f"Image generation failed: {e}")
sys.exit(1) sys.exit(1)
# ─── Now load the I2V model (after image model is unloaded) ─────────────
timing.begin_step("i2v_model_loading")
print(f"\n📹 Loading I2V model: {args.model}")
# Reload the I2V pipeline
try:
pipe = PipelineClass.from_pretrained(model_id_to_load, **pipe_kwargs)
# Apply LoRA if this is a LoRA model
if is_lora and lora_id:
print(f" Loading LoRA adapter: {lora_id}")
try:
pipe.load_lora_weights(lora_id)
print(f" ✅ LoRA applied successfully")
except Exception as lora_e:
print(f" ⚠️ LoRA loading failed: {lora_e}")
print(f" Continuing with base model...")
if args.no_filter and hasattr(pipe, "safety_checker"):
pipe.safety_checker = None
# Re-apply offloading strategy
if off == "auto_map":
pipe.enable_model_cpu_offload()
elif off == "sequential":
pipe.enable_sequential_cpu_offload()
elif off == "group":
try:
pipe.enable_group_offload(group_size=args.offload_group_size)
except:
print("Group offload unavailable → model offload fallback")
pipe.enable_model_cpu_offload()
elif off == "model":
pipe.enable_model_cpu_offload()
else:
pipe.to("cuda" if torch.cuda.is_available() else "cpu")
pipe.enable_attention_slicing("max")
try:
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
except:
pass
if torch.cuda.is_available():
try:
pipe.enable_xformers_memory_efficient_attention()
except:
pass
if "wan" in args.model and hasattr(pipe, "scheduler"):
try:
pipe.scheduler = UniPCMultistepScheduler.from_config(
pipe.scheduler.config,
prediction_type="flow_prediction",
flow_shift=extra.get("flow_shift", 3.0)
)
except:
pass
print(f" ✅ I2V model loaded successfully")
timing.end_step() # i2v_model_loading
except Exception as e:
print(f"❌ Failed to load I2V model: {e}")
sys.exit(1)
# ─── Audio Generation (Pre-video) ────────────────────────────────────────── # ─── Audio Generation (Pre-video) ──────────────────────────────────────────
audio_path = None audio_path = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment