Preserve user-specified model in auto mode retry logic

- Track if user explicitly specified --model before auto mode runs
- Skip retry with alternative models when user's model fails
- Show clear error message explaining user's choice is preserved
- Only auto-selected models can be retried with alternatives
parent be1e5b9d
Pipeline #229 canceled with stages
...@@ -3318,6 +3318,11 @@ def main(args): ...@@ -3318,6 +3318,11 @@ def main(args):
if getattr(args, 'auto', False) and not hasattr(args, '_auto_mode'): if getattr(args, 'auto', False) and not hasattr(args, '_auto_mode'):
if not args.prompt: if not args.prompt:
parser.error("--auto requires --prompt to analyze") parser.error("--auto requires --prompt to analyze")
# Track if user explicitly specified the model (before auto mode modifies it)
# This is used to preserve user's model choice during retry
args._user_specified_model = args.model is not None and args.model != (list(MODELS.keys())[0] if MODELS else None)
args = run_auto_mode(args, MODELS) args = run_auto_mode(args, MODELS)
if args is None: if args is None:
sys.exit(1) sys.exit(1)
...@@ -3520,69 +3525,77 @@ def main(args): ...@@ -3520,69 +3525,77 @@ def main(args):
max_retries = getattr(args, '_max_retries', 3) max_retries = getattr(args, '_max_retries', 3)
alternative_models = getattr(args, '_auto_alternative_models', []) alternative_models = getattr(args, '_auto_alternative_models', [])
failed_base_models = getattr(args, '_failed_base_models', set()) failed_base_models = getattr(args, '_failed_base_models', set())
user_specified_model = getattr(args, '_user_specified_model', False)
# If this was a LoRA with a base model, track the failed base model # If user explicitly specified the model, don't retry with alternatives
if is_lora and base_model_id: # The user's model choice should be preserved
failed_base_models.add(base_model_id) if user_specified_model:
args._failed_base_models = failed_base_models print(f"\n⚠️ User-specified model failed: {model_id_to_load}")
print(f" ⚠️ Base model failed: {base_model_id}") print(f" The model was explicitly provided with --model, not retrying with alternatives.")
print(f" Will skip other LoRAs depending on this base model") print(f" Please verify the model exists or try a different model.")
else:
# Find next valid alternative (skip LoRAs with failed base models) # If this was a LoRA with a base model, track the failed base model
next_model = None if is_lora and base_model_id:
skipped_loras = [] failed_base_models.add(base_model_id)
args._failed_base_models = failed_base_models
while alternative_models: print(f" ⚠️ Base model failed: {base_model_id}")
candidate_name, candidate_info, candidate_reason = alternative_models.pop(0) print(f" Will skip other LoRAs depending on this base model")
# Check if this is a LoRA with a failed base model # Find next valid alternative (skip LoRAs with failed base models)
if candidate_info.get("is_lora", False): next_model = None
candidate_base = candidate_info.get("base_model") or candidate_info.get("_inferred_base_model") skipped_loras = []
if candidate_base and candidate_base in failed_base_models:
skipped_loras.append((candidate_name, candidate_base))
continue # Skip this LoRA
# Found a valid candidate while alternative_models:
next_model = (candidate_name, candidate_info, candidate_reason) candidate_name, candidate_info, candidate_reason = alternative_models.pop(0)
break
# Check if this is a LoRA with a failed base model
# Update the alternatives list if candidate_info.get("is_lora", False):
args._auto_alternative_models = alternative_models candidate_base = candidate_info.get("base_model") or candidate_info.get("_inferred_base_model")
if candidate_base and candidate_base in failed_base_models:
if skipped_loras: skipped_loras.append((candidate_name, candidate_base))
print(f" ⏭️ Skipped {len(skipped_loras)} LoRA(s) with failed base models") continue # Skip this LoRA
if retry_count < max_retries and next_model: # Found a valid candidate
# We have a valid alternative - retry with it next_model = (candidate_name, candidate_info, candidate_reason)
args._retry_count = retry_count + 1 break
next_model_name, next_model_info, next_reason = next_model
# Print appropriate error message based on error type # Update the alternatives list
if "404" in error_str or "Entry Not Found" in error_str or "Repository Not Found" in error_str: args._auto_alternative_models = alternative_models
print(f"❌ Model not found on HuggingFace: {model_id_to_load}")
elif "401" in error_str or "Unauthorized" in error_str:
print(f"❌ Model requires authentication: {model_id_to_load}")
elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str:
print(f"❌ Pipeline compatibility error: {model_id_to_load}")
print(f" This model uses an incompatible pipeline architecture.")
else:
print(f"❌ Model loading failed: {model_id_to_load}")
print(f" Error: {error_str[:100]}...")
print(f"\n🔄 Retrying with alternative model ({args._retry_count}/{max_retries})...") if skipped_loras:
print(f" New model: {next_model_name}") print(f" ⏭️ Skipped {len(skipped_loras)} LoRA(s) with failed base models")
print(f" {next_reason}")
# Update args with new model and recurse if retry_count < max_retries and next_model:
args.model = next_model_name # We have a valid alternative - retry with it
# Clean up any partial model loading args._retry_count = retry_count + 1
if torch.cuda.is_available(): next_model_name, next_model_info, next_reason = next_model
torch.cuda.empty_cache()
# Retry main() with the new model # Print appropriate error message based on error type
return main(args) if "404" in error_str or "Entry Not Found" in error_str or "Repository Not Found" in error_str:
print(f"❌ Model not found on HuggingFace: {model_id_to_load}")
# No more valid alternatives or retries exhausted elif "401" in error_str or "Unauthorized" in error_str:
print(f"\n❌ All model retries exhausted ({retry_count}/{max_retries} attempts)") print(f"❌ Model requires authentication: {model_id_to_load}")
elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str:
print(f"❌ Pipeline compatibility error: {model_id_to_load}")
print(f" This model uses an incompatible pipeline architecture.")
else:
print(f"❌ Model loading failed: {model_id_to_load}")
print(f" Error: {error_str[:100]}...")
print(f"\n🔄 Retrying with alternative model ({args._retry_count}/{max_retries})...")
print(f" New model: {next_model_name}")
print(f" {next_reason}")
# Update args with new model and recurse
args.model = next_model_name
# Clean up any partial model loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Retry main() with the new model
return main(args)
# No more valid alternatives or retries exhausted
print(f"\n❌ All model retries exhausted ({retry_count}/{max_retries} attempts)")
# Print detailed error message for the user # Print detailed error message for the user
if "404" in error_str or "Entry Not Found" in error_str: if "404" in error_str or "Entry Not Found" in error_str:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment