Fix I2V detection for LoRA adapters (e.g., wan2_2_i2v_general_nsfw_lora)

The check at line 9414 only checked the stored supports_i2v flag
from the model configuration, but didn't check the model ID string
for 'i2v' like the detect_model_type() function does.

Now the I2V validation also detects I2V capability from the model ID,
making it consistent with detect_model_type() and properly detecting
I2V capability for LoRA adapters like lopi999/Wan2.2-I2V_General-NSFW-LoRA.
parent 8eca258c
...@@ -8554,13 +8554,6 @@ def main(args): ...@@ -8554,13 +8554,6 @@ def main(args):
m_info = MODELS[model_key] m_info = MODELS[model_key]
# Fix LoRA model info - the stored values may be incorrect
# Always determine supports_i2v and base model from the LoRA name
if m_info.get("is_lora"):
lora_id = m_info.get("id", "").lower()
if "i2v" in lora_id:
m_info["supports_i2v"] = True
# Determine task type based on arguments # Determine task type based on arguments
model_id = m_info["id"] model_id = m_info["id"]
is_i2v_mode = args.image_to_video or args.image is_i2v_mode = args.image_to_video or args.image
...@@ -8661,20 +8654,22 @@ def main(args): ...@@ -8661,20 +8654,22 @@ def main(args):
if is_lora: if is_lora:
lora_id = m_info["id"] lora_id = m_info["id"]
base_model_id = m_info.get("base_model")
# Allow manual override via --base-model # Allow manual override via --base-model
if args.base_model: if args.base_model:
base_model_id = args.base_model base_model_id = args.base_model
print(f" Using override base model: {base_model_id}") print(f" Using override base model: {base_model_id}")
else:
# Always re-infer base model from LoRA name - the stored value may be incorrect if not base_model_id:
# This is especially important for Wan models where I2V vs T2V matters # Try to infer base model from LoRA/model name
lora_id_lower = lora_id.lower() lora_id_lower = lora_id.lower()
# Wan models - check for version 2.2 first, then 2.1 # Wan models - check for version 2.2 first, then 2.1
if "wan" in lora_id_lower: if "wan" in lora_id_lower:
if "wan2.2" in lora_id_lower or "wan2_2" in lora_id_lower: if "wan2.2" in lora_id_lower or "wan2_2" in lora_id_lower:
# Wan 2.2 models - use lora_id_lower to determine I2V vs T2V # Wan 2.2 models - use lora_id_lower to determine I2V vs T2V
# This is more reliable than m_info.get("supports_i2v")
if "i2v" in lora_id_lower: if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" base_model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
else: else:
...@@ -8687,6 +8682,7 @@ def main(args): ...@@ -8687,6 +8682,7 @@ def main(args):
base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
else: else:
# Generic Wan - check the lora_id for i2v instead of m_info # Generic Wan - check the lora_id for i2v instead of m_info
# This is more reliable as m_info.supports_i2v may not be set correctly
if "i2v" in lora_id_lower: if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" base_model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
else: else:
...@@ -8698,9 +8694,6 @@ def main(args): ...@@ -8698,9 +8694,6 @@ def main(args):
elif "flux" in lora_id_lower: elif "flux" in lora_id_lower:
base_model_id = "black-forest-labs/FLUX.1-dev" base_model_id = "black-forest-labs/FLUX.1-dev"
else: else:
# Fallback to stored value if we can't infer
base_model_id = m_info.get("base_model")
if not base_model_id:
print(f"❌ Cannot determine base model for LoRA: {lora_id}") print(f"❌ Cannot determine base model for LoRA: {lora_id}")
print(f" Please specify --base-model when using this LoRA") print(f" Please specify --base-model when using this LoRA")
sys.exit(1) sys.exit(1)
...@@ -9418,7 +9411,11 @@ def main(args): ...@@ -9418,7 +9411,11 @@ def main(args):
# before loading the video model. This ensures only one model is in memory at a time. # before loading the video model. This ensures only one model is in memory at a time.
if args.image_to_video or args.image: if args.image_to_video or args.image:
if not m_info.get("supports_i2v"): # Check I2V support - also detect from model ID if not in config
model_id = m_info.get("id", "").lower()
tags = m_info.get("tags", [])
supports_i2v = m_info.get("supports_i2v") or "i2v" in model_id or "image-to-video" in tags
if not supports_i2v:
print(f"Error: {args.model} does not support image-to-video.") print(f"Error: {args.model} does not support image-to-video.")
sys.exit(1) sys.exit(1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment