Fix: Add pipeline component mismatch fallback for image model loading in I2V mode

parent 0a1f210c
......@@ -4245,7 +4245,54 @@ def main(args):
try:
print(f"\n🖼️ Loading image model for I2V: {args.image_model}")
img_pipe = ImgCls.from_pretrained(img_model_id_to_load, **img_kwargs)
# Initialize flag for pipeline mismatch fallback
img_pipeline_loaded_successfully = False
try:
img_pipe = ImgCls.from_pretrained(img_model_id_to_load, **img_kwargs)
img_pipeline_loaded_successfully = True
except Exception as e:
error_str = str(e)
# Check if this is a pipeline component mismatch error
is_component_mismatch = "expected" in error_str and "but only" in error_str and "were passed" in error_str
if is_component_mismatch:
# Try to re-detect the correct pipeline class from model ID pattern
detected_class = None
img_model_id_lower = img_model_id_to_load.lower()
# Force detection based on model ID patterns
if "flux" in img_model_id_lower:
detected_class = "FluxPipeline"
elif "sdxl" in img_model_id_lower or "stable-diffusion-xl" in img_model_id_lower:
detected_class = "StableDiffusionXLPipeline"
elif "sd3" in img_model_id_lower or "stable-diffusion-3" in img_model_id_lower:
detected_class = "StableDiffusion3Pipeline"
elif "sd15" in img_model_id_lower or "stable-diffusion-1.5" in img_model_id_lower:
detected_class = "StableDiffusionPipeline"
if detected_class and detected_class != img_info["class"]:
print(f"\n⚠️ Image model pipeline component mismatch detected!")
print(f" Configured class: {img_info['class']}")
print(f" Detected class: {detected_class}")
print(f" Retrying with detected pipeline class: {detected_class}")
# Get the correct pipeline class
CorrectImgCls = get_pipeline_class(detected_class)
if CorrectImgCls:
try:
img_pipe = CorrectImgCls.from_pretrained(img_model_id_to_load, **img_kwargs)
print(f" ✅ Successfully loaded image model with {detected_class}")
ImgCls = CorrectImgCls
img_pipeline_loaded_successfully = True
except Exception as retry_e:
print(f" ❌ Retry with {detected_class} also failed: {retry_e}")
raise e # Re-raise original error
if not img_pipeline_loaded_successfully:
raise e # Re-raise if we couldn't handle it
# Apply LoRA if image model is a LoRA adapter
if img_is_lora and img_lora_id:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment