Fix component-only model loading to use correct pipeline for I2V mode

- When loading LTX-Video base model in I2V mode, use LTXImageToVideoPipeline
- When loading LTX-Video base model in T2V mode, use LTXPipeline
- Update PipelineClass after loading base pipeline to match the actual class used
- This fixes the 'LTXPipeline.__call__() got an unexpected keyword argument image' error
parent 202861ad
......@@ -8509,11 +8509,31 @@ def main(args):
# Determine the correct pipeline class for the base model
BasePipelineClass = None
if "LTX-Video" in base_model or "ltx" in base_model.lower():
# Check if we're in I2V mode - use LTXImageToVideoPipeline for I2V
is_i2v_mode = (args.image_to_video or args.image)
if is_i2v_mode:
try:
from diffusers import LTXImageToVideoPipeline
BasePipelineClass = LTXImageToVideoPipeline
if debug:
print(f" [DEBUG] Using LTXImageToVideoPipeline for I2V mode")
except ImportError as ie:
if debug:
print(f" [DEBUG] LTXImageToVideoPipeline not available: {ie}")
# Fallback to T2V pipeline
try:
from diffusers import LTXPipeline
BasePipelineClass = LTXPipeline
if debug:
print(f" [DEBUG] Falling back to LTXPipeline (T2V only)")
except ImportError:
pass
else:
try:
from diffusers import LTXPipeline
BasePipelineClass = LTXPipeline
if debug:
print(f" [DEBUG] Using LTXPipeline for base model")
print(f" [DEBUG] Using LTXPipeline for T2V mode")
except ImportError as ie:
if debug:
print(f" [DEBUG] LTXPipeline not available: {ie}")
......@@ -8522,7 +8542,7 @@ def main(args):
from diffusers import LTXImageToVideoPipeline
BasePipelineClass = LTXImageToVideoPipeline
if debug:
print(f" [DEBUG] Using LTXImageToVideoPipeline for base model")
print(f" [DEBUG] Using LTXImageToVideoPipeline as fallback")
except ImportError:
pass
elif "Wan" in base_model or "wan" in base_model.lower():
......@@ -8545,6 +8565,8 @@ def main(args):
try:
pipe = BasePipelineClass.from_pretrained(base_model, **pipe_kwargs)
print(f" ✅ Base pipeline loaded with {BasePipelineClass.__name__}")
# Update PipelineClass to match the base pipeline class
PipelineClass = BasePipelineClass
except Exception as base_load_e:
# Check if this is a tokenizer/cache error
error_str = str(base_load_e)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment