Fix LoRA base model and I2V detection

- Always infer base model from LoRA name, not stored database value
- Fix supports_i2v detection for LoRAs based on LoRA name
- Ensures Wan 2.2 I2V LoRAs use correct I2V base model
parent 4a3889e4
...@@ -8554,6 +8554,13 @@ def main(args): ...@@ -8554,6 +8554,13 @@ def main(args):
m_info = MODELS[model_key] m_info = MODELS[model_key]
# Fix LoRA model info - the stored values may be incorrect
# Always determine supports_i2v and base model from the LoRA name
if m_info.get("is_lora"):
lora_id = m_info.get("id", "").lower()
if "i2v" in lora_id:
m_info["supports_i2v"] = True
# Determine task type based on arguments # Determine task type based on arguments
model_id = m_info["id"] model_id = m_info["id"]
is_i2v_mode = args.image_to_video or args.image is_i2v_mode = args.image_to_video or args.image
...@@ -8654,22 +8661,20 @@ def main(args): ...@@ -8654,22 +8661,20 @@ def main(args):
if is_lora: if is_lora:
lora_id = m_info["id"] lora_id = m_info["id"]
base_model_id = m_info.get("base_model")
# Allow manual override via --base-model # Allow manual override via --base-model
if args.base_model: if args.base_model:
base_model_id = args.base_model base_model_id = args.base_model
print(f" Using override base model: {base_model_id}") print(f" Using override base model: {base_model_id}")
else:
if not base_model_id: # Always re-infer base model from LoRA name - the stored value may be incorrect
# Try to infer base model from LoRA/model name # This is especially important for Wan models where I2V vs T2V matters
lora_id_lower = lora_id.lower() lora_id_lower = lora_id.lower()
# Wan models - check for version 2.2 first, then 2.1 # Wan models - check for version 2.2 first, then 2.1
if "wan" in lora_id_lower: if "wan" in lora_id_lower:
if "wan2.2" in lora_id_lower or "wan2_2" in lora_id_lower: if "wan2.2" in lora_id_lower or "wan2_2" in lora_id_lower:
# Wan 2.2 models - use lora_id_lower to determine I2V vs T2V # Wan 2.2 models - use lora_id_lower to determine I2V vs T2V
# This is more reliable than m_info.get("supports_i2v")
if "i2v" in lora_id_lower: if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" base_model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
else: else:
...@@ -8682,7 +8687,6 @@ def main(args): ...@@ -8682,7 +8687,6 @@ def main(args):
base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
else: else:
# Generic Wan - check the lora_id for i2v instead of m_info # Generic Wan - check the lora_id for i2v instead of m_info
# This is more reliable as m_info.supports_i2v may not be set correctly
if "i2v" in lora_id_lower: if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" base_model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
else: else:
...@@ -8694,9 +8698,12 @@ def main(args): ...@@ -8694,9 +8698,12 @@ def main(args):
elif "flux" in lora_id_lower: elif "flux" in lora_id_lower:
base_model_id = "black-forest-labs/FLUX.1-dev" base_model_id = "black-forest-labs/FLUX.1-dev"
else: else:
print(f"❌ Cannot determine base model for LoRA: {lora_id}") # Fallback to stored value if we can't infer
print(f" Please specify --base-model when using this LoRA") base_model_id = m_info.get("base_model")
sys.exit(1) if not base_model_id:
print(f"❌ Cannot determine base model for LoRA: {lora_id}")
print(f" Please specify --base-model when using this LoRA")
sys.exit(1)
print(f" LoRA detected: {lora_id}") print(f" LoRA detected: {lora_id}")
print(f" Base model: {base_model_id}") print(f" Base model: {base_model_id}")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment