Fix Wan 2.2 I2V base model detection

- Fixed model ID normalization to handle hyphens (in addition to underscores)
- Fixed dictionary key ordering in base_model_fallbacks so more specific keys (wan2.2.i2v) are checked before generic keys (wan2.2)
- Fixed Wan 2.1 I2V base model mapping (was incorrectly pointing to T2V)
- Fixed base model detection in earlier code sections to check model ID directly instead of relying on m_info.get('supports_i2v')
- Fixed typo: Wan 2.2 generic fallback now correctly uses Wan2.2-T2V

Now Wan 2.2 I2V models like Wan-AI/Wan2.2-I2V-A14B will correctly use Wan-AI/Wan2.2-I2V-14B-Diffusers as the base model instead of the incorrect Wan-AI/Wan2.2-T2V-14B-Diffusers.
parent 4a5213f8
......@@ -2239,6 +2239,11 @@ def update_all_models(hf_token=None):
if is_lora:
if "wan" in model_id.lower():
# Wan 2.2 models - use the new MoE base
if "wan2.2" in model_id.lower() or "wan2_2" in model_id.lower():
base_model = "Wan-AI/Wan2.2-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.2-T2V-14B-Diffusers"
else:
# Wan 2.1 and earlier
base_model = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
elif "svd" in model_id.lower() or "stable-video" in model_id.lower():
base_model = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
......@@ -2340,6 +2345,11 @@ def update_all_models(hf_token=None):
if is_lora:
if "wan" in model_id.lower():
# Wan 2.2 models - use the new MoE base
if "wan2.2" in model_id.lower() or "wan2_2" in model_id.lower():
base_model = "Wan-AI/Wan2.2-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.2-T2V-14B-Diffusers"
else:
# Wan 2.1 and earlier
base_model = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
elif "svd" in model_id.lower() or "stable-video" in model_id.lower():
base_model = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
......@@ -3868,10 +3878,11 @@ def select_best_model(gen_type, models, vram_gb=24, prefer_quality=True, return_
if "wan" in lora_id:
if "wan2.2" in lora_id:
# Wan 2.2 models - use the new MoE base
base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers" if info.get("supports_i2v") else "Wan-AI/Wan2.2-T2V-14B-Diffusers"
# IMPORTANT: For I2V models, always use I2V base model, not T2V
base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers" if "i2v" in lora_id else "Wan-AI/Wan2.2-T2V-14B-Diffusers"
else:
# Wan 2.1 and earlier
base_model_id = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if info.get("supports_i2v") else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
base_model_id = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if "i2v" in lora_id else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
elif "svd" in lora_id or "stable-video" in lora_id:
base_model_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
elif "sdxl" in lora_id:
......@@ -8070,23 +8081,25 @@ def main(args):
# Wan models - check for version 2.2 first, then 2.1
if "wan" in lora_id_lower:
if "wan2.2" in lora_id_lower or "wan2_2" in lora_id_lower:
# Wan 2.2 models
if m_info.get("supports_i2v"):
# Wan 2.2 models - use lora_id_lower to determine I2V vs T2V
# This is more reliable than m_info.get("supports_i2v")
if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers"
else:
base_model_id = "Wan-AI/Wan2.2-T2V-14B-Diffusers"
elif "wan2.1" in lora_id_lower or "wan2_1" in lora_id_lower:
# Wan 2.1 models
if m_info.get("supports_i2v"):
# Wan 2.1 models - use lora_id_lower to determine I2V vs T2V
if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.1-I2V-14B-Diffusers"
else:
base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
else:
# Generic Wan - default to 2.2 I2V if supports_i2v
if m_info.get("supports_i2v"):
# Generic Wan - check the lora_id for i2v instead of m_info
# This is more reliable as m_info.supports_i2v may not be set correctly
if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers"
else:
base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
base_model_id = "Wan-AI/Wan2.2-T2V-14B-Diffusers"
elif "svd" in lora_id_lower or "stable-video" in lora_id_lower:
base_model_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
elif "sdxl" in lora_id_lower:
......@@ -9208,32 +9221,38 @@ def main(args):
# Strategy 2: Try loading from base model and then fine-tuned weights (subfolder style)
if not loaded_with_base:
# Normalize model ID for matching (replace underscores with dots)
model_id_normalized = model_id_lower.replace('_', '.')
# Normalize model ID for matching (replace underscores AND hyphens with dots)
# This ensures wan2.2-i2v-a14b matches wan2.2.i2v.a14b
model_id_normalized = model_id_lower.replace('_', '.').replace('-', '.')
base_model_fallbacks = {
"ltx": "Lightricks/LTX-Video",
"ltxvideo": "Lightricks/LTX-Video",
# Wan 2.2 models - I2V uses I2V base, T2V uses T2V base
# Wan 2.2 I2V models - more specific keys FIRST (before generic "wan2.2")
"wan2.2.i2v.a14b": "Wan-AI/Wan2.2-I2V-14B-Diffusers",
"wan2.2.i2v": "Wan-AI/Wan2.2-I2V-14B-Diffusers",
"wan2.2.t2v": "Wan-AI/Wan2.2-T2V-14B-Diffusers",
# Wan 2.2 generic - MUST come after specific I2V/T2V keys
"wan2.2": "Wan-AI/Wan2.2-T2V-14B-Diffusers",
# Wan 2.1 models
"wan2.1.i2v.a14b": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
"wan2.1.i2v": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
# Wan 2.1 I2V models - more specific keys FIRST
"wan2.1.i2v.a14b": "Wan-AI/Wan2.1-I2V-14B-Diffusers",
"wan2.1.i2v": "Wan-AI/Wan2.1-I2V-14B-Diffusers",
"wan2.1.t2v": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
# Wan 2.1 generic - MUST come after specific I2V/T2V keys
"wan2.1": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
# Generic Wan fallback (least specific - checked last)
"wan": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
"wan": "Wan-AI/Wan2.2-T2V-14B-Diffusers",
"svd": "stabilityai/stable-video-diffusion-img2vid-xt-1-1",
"cogvideo": "THUDM/CogVideoX-5b",
"mochi": "genmo/mochi-1-preview",
}
for key, base_model in base_model_fallbacks.items():
# Check both original (with underscores) and normalized (with dots)
# Check both original (with underscores) and normalized (with dots/hyphens)
if key in model_id_lower or key in model_id_normalized:
print(f" Trying to load base model first: {base_model}")
print(f" Then loading fine-tuned weights from: {model_id_to_load}")
# Found the base model - break after processing
try:
# Determine the correct pipeline class for the base model
FallbackPipelineClass = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment