Fix pipeline component mismatch fallback and indentation

- Add fallback mechanism for models with incorrect model_index.json
- Detect pipeline class from model ID patterns when component mismatch occurs
- Fix indentation error in auto mode retry logic block
- Properly handle Wan2.2-I2V models with misconfigured pipeline class
parent 4668132a
Pipeline #230 canceled with stages
@@ -3518,86 +3518,137 @@ def main(args):
print(f" [DEBUG] Response: {e.response}")
print()
# Check if we should retry with an alternative model (auto mode)
# This applies to ALL error types - we try alternatives before giving up
if getattr(args, '_auto_mode', False):
retry_count = getattr(args, '_retry_count', 0)
max_retries = getattr(args, '_max_retries', 3)
alternative_models = getattr(args, '_auto_alternative_models', [])
failed_base_models = getattr(args, '_failed_base_models', set())
user_specified_model = getattr(args, '_user_specified_model', False)
# Check if this is a pipeline component mismatch error
# This happens when the model_index.json has the wrong _class_name
is_component_mismatch = "expected" in error_str and "but only" in error_str and "were passed" in error_str
if is_component_mismatch:
# Try to re-detect the correct pipeline class from model ID pattern
detected_class = None
model_id_lower = model_id_to_load.lower()
# If user explicitly specified the model, don't retry with alternatives
# The user's model choice should be preserved
if user_specified_model:
print(f"\n⚠️ User-specified model failed: {model_id_to_load}")
print(f" The model was explicitly provided with --model, not retrying with alternatives.")
print(f" Please verify the model exists or try a different model.")
else:
# If this was a LoRA with a base model, track the failed base model
if is_lora and base_model_id:
failed_base_models.add(base_model_id)
args._failed_base_models = failed_base_models
print(f" ⚠️ Base model failed: {base_model_id}")
print(f" Will skip other LoRAs depending on this base model")
# Force detection based on model ID patterns (most reliable for misconfigured models)
if "wan2.1" in model_id_lower or "wan2.2" in model_id_lower or "wan2" in model_id_lower:
detected_class = "WanPipeline"
elif "svd" in model_id_lower or "stable-video-diffusion" in model_id_lower:
detected_class = "StableVideoDiffusionPipeline"
elif "ltx" in model_id_lower:
detected_class = "LTXVideoPipeline"
elif "mochi" in model_id_lower:
detected_class = "MochiPipeline"
elif "cogvideo" in model_id_lower:
detected_class = "CogVideoXPipeline"
elif "flux" in model_id_lower:
detected_class = "FluxPipeline"
if detected_class and detected_class != m_info["class"]:
print(f"\n⚠️ Pipeline component mismatch detected!")
print(f" Configured class: {m_info['class']}")
print(f" Detected class: {detected_class}")
print(f" The model's model_index.json may have an incorrect _class_name.")
print(f" Retrying with detected pipeline class: {detected_class}")
# Find next valid alternative (skip LoRAs with failed base models)
next_model = None
skipped_loras = []
# Get the correct pipeline class
CorrectPipelineClass = get_pipeline_class(detected_class)
if CorrectPipelineClass:
try:
pipe = CorrectPipelineClass.from_pretrained(model_id_to_load, **pipe_kwargs)
# Success! Update the model info for future runs
print(f" ✅ Successfully loaded with {detected_class}")
# Continue with the loaded pipeline
timing.end_step() # model_loading
# Update PipelineClass for the rest of the code
PipelineClass = CorrectPipelineClass
# Skip to after the error handling
goto_after_loading = True
except Exception as retry_e:
print(f" ❌ Retry with {detected_class} also failed: {retry_e}")
# Continue with normal error handling
is_component_mismatch = False # Don't retry again below
error_str = str(retry_e)
# If we successfully loaded with the corrected pipeline, skip error handling
if not locals().get('goto_after_loading', False):
# Check if we should retry with an alternative model (auto mode)
# This applies to ALL error types - we try alternatives before giving up
if getattr(args, '_auto_mode', False):
retry_count = getattr(args, '_retry_count', 0)
max_retries = getattr(args, '_max_retries', 3)
alternative_models = getattr(args, '_auto_alternative_models', [])
failed_base_models = getattr(args, '_failed_base_models', set())
user_specified_model = getattr(args, '_user_specified_model', False)
while alternative_models:
candidate_name, candidate_info, candidate_reason = alternative_models.pop(0)
# If user explicitly specified the model, don't retry with alternatives
# The user's model choice should be preserved
if user_specified_model:
print(f"\n⚠️ User-specified model failed: {model_id_to_load}")
print(f" The model was explicitly provided with --model, not retrying with alternatives.")
print(f" Please verify the model exists or try a different model.")
else:
# If this was a LoRA with a base model, track the failed base model
if is_lora and base_model_id:
failed_base_models.add(base_model_id)
args._failed_base_models = failed_base_models
print(f" ⚠️ Base model failed: {base_model_id}")
print(f" Will skip other LoRAs depending on this base model")
# Check if this is a LoRA with a failed base model
if candidate_info.get("is_lora", False):
candidate_base = candidate_info.get("base_model") or candidate_info.get("_inferred_base_model")
if candidate_base and candidate_base in failed_base_models:
skipped_loras.append((candidate_name, candidate_base))
continue # Skip this LoRA
# Find next valid alternative (skip LoRAs with failed base models)
next_model = None
skipped_loras = []
# Found a valid candidate
next_model = (candidate_name, candidate_info, candidate_reason)
break
# Update the alternatives list
args._auto_alternative_models = alternative_models
if skipped_loras:
print(f" ⏭️ Skipped {len(skipped_loras)} LoRA(s) with failed base models")
if retry_count < max_retries and next_model:
# We have a valid alternative - retry with it
args._retry_count = retry_count + 1
next_model_name, next_model_info, next_reason = next_model
while alternative_models:
candidate_name, candidate_info, candidate_reason = alternative_models.pop(0)
# Check if this is a LoRA with a failed base model
if candidate_info.get("is_lora", False):
candidate_base = candidate_info.get("base_model") or candidate_info.get("_inferred_base_model")
if candidate_base and candidate_base in failed_base_models:
skipped_loras.append((candidate_name, candidate_base))
continue # Skip this LoRA
# Found a valid candidate
next_model = (candidate_name, candidate_info, candidate_reason)
break
# Print appropriate error message based on error type
if "404" in error_str or "Entry Not Found" in error_str or "Repository Not Found" in error_str:
print(f"❌ Model not found on HuggingFace: {model_id_to_load}")
elif "401" in error_str or "Unauthorized" in error_str:
print(f"❌ Model requires authentication: {model_id_to_load}")
elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str:
print(f"❌ Pipeline compatibility error: {model_id_to_load}")
print(f" This model uses an incompatible pipeline architecture.")
else:
print(f"❌ Model loading failed: {model_id_to_load}")
print(f" Error: {error_str[:100]}...")
# Update the alternatives list
args._auto_alternative_models = alternative_models
print(f"\n🔄 Retrying with alternative model ({args._retry_count}/{max_retries})...")
print(f" New model: {next_model_name}")
print(f" {next_reason}")
if skipped_loras:
print(f" ⏭️ Skipped {len(skipped_loras)} LoRA(s) with failed base models")
# Update args with new model and recurse
args.model = next_model_name
# Clean up any partial model loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Retry main() with the new model
return main(args)
# No more valid alternatives or retries exhausted
print(f"\n❌ All model retries exhausted ({retry_count}/{max_retries} attempts)")
# Print detailed error message for the user
if retry_count < max_retries and next_model:
# We have a valid alternative - retry with it
args._retry_count = retry_count + 1
next_model_name, next_model_info, next_reason = next_model
# Print appropriate error message based on error type
if "404" in error_str or "Entry Not Found" in error_str or "Repository Not Found" in error_str:
print(f"❌ Model not found on HuggingFace: {model_id_to_load}")
elif "401" in error_str or "Unauthorized" in error_str:
print(f"❌ Model requires authentication: {model_id_to_load}")
elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str:
print(f"❌ Pipeline compatibility error: {model_id_to_load}")
print(f" This model uses an incompatible pipeline architecture.")
else:
print(f"❌ Model loading failed: {model_id_to_load}")
print(f" Error: {error_str[:100]}...")
print(f"\n🔄 Retrying with alternative model ({args._retry_count}/{max_retries})...")
print(f" New model: {next_model_name}")
print(f" {next_reason}")
# Update args with new model and recurse
args.model = next_model_name
# Clean up any partial model loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Retry main() with the new model
return main(args)
# No more valid alternatives or retries exhausted
print(f"\n❌ All model retries exhausted ({retry_count}/{max_retries} attempts)")
# Print detailed error message for the user
if "404" in error_str or "Entry Not Found" in error_str:
print(f"❌ Model not found on HuggingFace: {model_id_to_load}")
print(f" This model may have been removed or the ID is incorrect.")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment