Fix Wan 2.2 I2V base model detection

- Fixed model ID normalization to handle hyphens (in addition to underscores)
- Fixed dictionary key ordering in base_model_fallbacks so more specific keys (wan2.2.i2v) are checked before generic keys (wan2.2)
- Fixed Wan 2.1 I2V base model mapping (was incorrectly pointing to T2V)
- Fixed base model detection in earlier code sections to check model ID directly instead of relying on m_info.get('supports_i2v')
- Fixed typo: Wan 2.2 generic fallback now correctly uses Wan2.2-T2V

Now Wan 2.2 I2V models like Wan-AI/Wan2.2-I2V-A14B will correctly use Wan-AI/Wan2.2-I2V-14B-Diffusers as the base model instead of the incorrect Wan-AI/Wan2.2-T2V-14B-Diffusers.
parent 4a5213f8
...@@ -2239,6 +2239,11 @@ def update_all_models(hf_token=None): ...@@ -2239,6 +2239,11 @@ def update_all_models(hf_token=None):
if is_lora: if is_lora:
if "wan" in model_id.lower(): if "wan" in model_id.lower():
# Wan 2.2 models - use the new MoE base
if "wan2.2" in model_id.lower() or "wan2_2" in model_id.lower():
base_model = "Wan-AI/Wan2.2-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.2-T2V-14B-Diffusers"
else:
# Wan 2.1 and earlier
base_model = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.1-T2V-14B-Diffusers" base_model = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
elif "svd" in model_id.lower() or "stable-video" in model_id.lower(): elif "svd" in model_id.lower() or "stable-video" in model_id.lower():
base_model = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" base_model = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
...@@ -2340,6 +2345,11 @@ def update_all_models(hf_token=None): ...@@ -2340,6 +2345,11 @@ def update_all_models(hf_token=None):
if is_lora: if is_lora:
if "wan" in model_id.lower(): if "wan" in model_id.lower():
# Wan 2.2 models - use the new MoE base
if "wan2.2" in model_id.lower() or "wan2_2" in model_id.lower():
base_model = "Wan-AI/Wan2.2-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.2-T2V-14B-Diffusers"
else:
# Wan 2.1 and earlier
base_model = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.1-T2V-14B-Diffusers" base_model = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if m["is_i2v"] else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
elif "svd" in model_id.lower() or "stable-video" in model_id.lower(): elif "svd" in model_id.lower() or "stable-video" in model_id.lower():
base_model = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" base_model = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
...@@ -3868,10 +3878,11 @@ def select_best_model(gen_type, models, vram_gb=24, prefer_quality=True, return_ ...@@ -3868,10 +3878,11 @@ def select_best_model(gen_type, models, vram_gb=24, prefer_quality=True, return_
if "wan" in lora_id: if "wan" in lora_id:
if "wan2.2" in lora_id: if "wan2.2" in lora_id:
# Wan 2.2 models - use the new MoE base # Wan 2.2 models - use the new MoE base
base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers" if info.get("supports_i2v") else "Wan-AI/Wan2.2-T2V-14B-Diffusers" # IMPORTANT: For I2V models, always use I2V base model, not T2V
base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers" if "i2v" in lora_id else "Wan-AI/Wan2.2-T2V-14B-Diffusers"
else: else:
# Wan 2.1 and earlier # Wan 2.1 and earlier
base_model_id = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if info.get("supports_i2v") else "Wan-AI/Wan2.1-T2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if "i2v" in lora_id else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
elif "svd" in lora_id or "stable-video" in lora_id: elif "svd" in lora_id or "stable-video" in lora_id:
base_model_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" base_model_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
elif "sdxl" in lora_id: elif "sdxl" in lora_id:
...@@ -8070,23 +8081,25 @@ def main(args): ...@@ -8070,23 +8081,25 @@ def main(args):
# Wan models - check for version 2.2 first, then 2.1 # Wan models - check for version 2.2 first, then 2.1
if "wan" in lora_id_lower: if "wan" in lora_id_lower:
if "wan2.2" in lora_id_lower or "wan2_2" in lora_id_lower: if "wan2.2" in lora_id_lower or "wan2_2" in lora_id_lower:
# Wan 2.2 models # Wan 2.2 models - use lora_id_lower to determine I2V vs T2V
if m_info.get("supports_i2v"): # This is more reliable than m_info.get("supports_i2v")
if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers"
else: else:
base_model_id = "Wan-AI/Wan2.2-T2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.2-T2V-14B-Diffusers"
elif "wan2.1" in lora_id_lower or "wan2_1" in lora_id_lower: elif "wan2.1" in lora_id_lower or "wan2_1" in lora_id_lower:
# Wan 2.1 models # Wan 2.1 models - use lora_id_lower to determine I2V vs T2V
if m_info.get("supports_i2v"): if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.1-I2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.1-I2V-14B-Diffusers"
else: else:
base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
else: else:
# Generic Wan - default to 2.2 I2V if supports_i2v # Generic Wan - check the lora_id for i2v instead of m_info
if m_info.get("supports_i2v"): # This is more reliable as m_info.supports_i2v may not be set correctly
if "i2v" in lora_id_lower:
base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.2-I2V-14B-Diffusers"
else: else:
base_model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" base_model_id = "Wan-AI/Wan2.2-T2V-14B-Diffusers"
elif "svd" in lora_id_lower or "stable-video" in lora_id_lower: elif "svd" in lora_id_lower or "stable-video" in lora_id_lower:
base_model_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" base_model_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
elif "sdxl" in lora_id_lower: elif "sdxl" in lora_id_lower:
...@@ -9208,32 +9221,38 @@ def main(args): ...@@ -9208,32 +9221,38 @@ def main(args):
# Strategy 2: Try loading from base model and then fine-tuned weights (subfolder style) # Strategy 2: Try loading from base model and then fine-tuned weights (subfolder style)
if not loaded_with_base: if not loaded_with_base:
# Normalize model ID for matching (replace underscores with dots) # Normalize model ID for matching (replace underscores AND hyphens with dots)
model_id_normalized = model_id_lower.replace('_', '.') # This ensures wan2.2-i2v-a14b matches wan2.2.i2v.a14b
model_id_normalized = model_id_lower.replace('_', '.').replace('-', '.')
base_model_fallbacks = { base_model_fallbacks = {
"ltx": "Lightricks/LTX-Video", "ltx": "Lightricks/LTX-Video",
"ltxvideo": "Lightricks/LTX-Video", "ltxvideo": "Lightricks/LTX-Video",
# Wan 2.2 models - I2V uses I2V base, T2V uses T2V base # Wan 2.2 I2V models - more specific keys FIRST (before generic "wan2.2")
"wan2.2.i2v.a14b": "Wan-AI/Wan2.2-I2V-14B-Diffusers", "wan2.2.i2v.a14b": "Wan-AI/Wan2.2-I2V-14B-Diffusers",
"wan2.2.i2v": "Wan-AI/Wan2.2-I2V-14B-Diffusers", "wan2.2.i2v": "Wan-AI/Wan2.2-I2V-14B-Diffusers",
"wan2.2.t2v": "Wan-AI/Wan2.2-T2V-14B-Diffusers",
# Wan 2.2 generic - MUST come after specific I2V/T2V keys
"wan2.2": "Wan-AI/Wan2.2-T2V-14B-Diffusers", "wan2.2": "Wan-AI/Wan2.2-T2V-14B-Diffusers",
# Wan 2.1 models # Wan 2.1 I2V models - more specific keys FIRST
"wan2.1.i2v.a14b": "Wan-AI/Wan2.1-T2V-14B-Diffusers", "wan2.1.i2v.a14b": "Wan-AI/Wan2.1-I2V-14B-Diffusers",
"wan2.1.i2v": "Wan-AI/Wan2.1-T2V-14B-Diffusers", "wan2.1.i2v": "Wan-AI/Wan2.1-I2V-14B-Diffusers",
"wan2.1.t2v": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
# Wan 2.1 generic - MUST come after specific I2V/T2V keys
"wan2.1": "Wan-AI/Wan2.1-T2V-14B-Diffusers", "wan2.1": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
# Generic Wan fallback (least specific - checked last) # Generic Wan fallback (least specific - checked last)
"wan": "Wan-AI/Wan2.1-T2V-14B-Diffusers", "wan": "Wan-AI/Wan2.2-T2V-14B-Diffusers",
"svd": "stabilityai/stable-video-diffusion-img2vid-xt-1-1", "svd": "stabilityai/stable-video-diffusion-img2vid-xt-1-1",
"cogvideo": "THUDM/CogVideoX-5b", "cogvideo": "THUDM/CogVideoX-5b",
"mochi": "genmo/mochi-1-preview", "mochi": "genmo/mochi-1-preview",
} }
for key, base_model in base_model_fallbacks.items(): for key, base_model in base_model_fallbacks.items():
# Check both original (with underscores) and normalized (with dots) # Check both original (with underscores) and normalized (with dots/hyphens)
if key in model_id_lower or key in model_id_normalized: if key in model_id_lower or key in model_id_normalized:
print(f" Trying to load base model first: {base_model}") print(f" Trying to load base model first: {base_model}")
print(f" Then loading fine-tuned weights from: {model_id_to_load}") print(f" Then loading fine-tuned weights from: {model_id_to_load}")
# Found the base model - break after processing
try: try:
# Determine the correct pipeline class for the base model # Determine the correct pipeline class for the base model
FallbackPipelineClass = None FallbackPipelineClass = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment