fix: Normalize model ID when matching Wan base models

The issue was that model IDs from HuggingFace use dots (wan2.2-i2v-a14b)
while user config names use underscores (wan2_2_i2v_a14b).

Now we normalize the model ID by replacing underscores with dots before
matching against the base_model_fallbacks dictionary.
parent 9a48c010
...@@ -9196,22 +9196,19 @@ def main(args): ...@@ -9196,22 +9196,19 @@ def main(args):
# Strategy 2: Try loading from base model and then fine-tuned weights (subfolder style) # Strategy 2: Try loading from base model and then fine-tuned weights (subfolder style)
if not loaded_with_base: if not loaded_with_base:
# Normalize model ID for matching (replace underscores with dots)
model_id_normalized = model_id_lower.replace('_', '.')
base_model_fallbacks = { base_model_fallbacks = {
"ltx": "Lightricks/LTX-Video", "ltx": "Lightricks/LTX-Video",
"ltxvideo": "Lightricks/LTX-Video", "ltxvideo": "Lightricks/LTX-Video",
# Wan models - more specific keys MUST come first! # Wan models - more specific keys MUST come first!
# (longer keys checked first so wan2_2_i2v matches before wan2_2) # (longer keys checked first so wan2.2.i2v matches before wan2.2)
"wan2_2_i2v_a14b": "Wan-AI/Wan2.2-I2V-14B-Diffusers", "wan2.2.i2v.a14b": "Wan-AI/Wan2.2-I2V-14B-Diffusers",
"wan2.2_i2v_a14b": "Wan-AI/Wan2.2-I2V-14B-Diffusers", "wan2.2.i2v": "Wan-AI/Wan2.2-I2V-14B-Diffusers",
"wan2_2_i2v": "Wan-AI/Wan2.2-I2V-14B-Diffusers",
"wan2.2_i2v": "Wan-AI/Wan2.2-I2V-14B-Diffusers",
"wan2_2": "Wan-AI/Wan2.2-T2V-14B-Diffusers",
"wan2.2": "Wan-AI/Wan2.2-T2V-14B-Diffusers", "wan2.2": "Wan-AI/Wan2.2-T2V-14B-Diffusers",
"wan2_1_i2v_a14b": "Wan-AI/Wan2.1-I2V-14B-Diffusers", "wan2.1.i2v.a14b": "Wan-AI/Wan2.1-I2V-14B-Diffusers",
"wan2.1_i2v_a14b": "Wan-AI/Wan2.1-I2V-14B-Diffusers", "wan2.1.i2v": "Wan-AI/Wan2.1-I2V-14B-Diffusers",
"wan2_1_i2v": "Wan-AI/Wan2.1-I2V-14B-Diffusers",
"wan2.1_i2v": "Wan-AI/Wan2.1-I2V-14B-Diffusers",
"wan2_1": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
"wan2.1": "Wan-AI/Wan2.1-T2V-14B-Diffusers", "wan2.1": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
# Generic Wan fallback (least specific - checked last) # Generic Wan fallback (least specific - checked last)
"wan": "Wan-AI/Wan2.1-T2V-14B-Diffusers", "wan": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
...@@ -9221,7 +9218,8 @@ def main(args): ...@@ -9221,7 +9218,8 @@ def main(args):
} }
for key, base_model in base_model_fallbacks.items(): for key, base_model in base_model_fallbacks.items():
if key in model_id_lower: # Check both original (with underscores) and normalized (with dots)
if key in model_id_lower or key in model_id_normalized:
print(f" Trying to load base model first: {base_model}") print(f" Trying to load base model first: {base_model}")
print(f" Then loading fine-tuned weights from: {model_id_to_load}") print(f" Then loading fine-tuned weights from: {model_id_to_load}")
try: try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment