Fix I2V model loading: use correct pipeline class for base models

- When loading fine-tuned component models (like LTXVideoTransformer3DModel),
  use the correct pipeline class for the base model instead of the configured
  PipelineClass which may be wrong
- Add proper pipeline class detection for LTX, Wan, SVD, CogVideo, Mochi
- This fixes loading models like Muinez/ltxvideo-2b-nsfw which have
  config.json only (no model_index.json)
parent 6482f2ac
......@@ -8384,9 +8384,36 @@ def main(args):
print(f" Loading base pipeline: {base_model}")
print(f" Then loading fine-tuned {class_name} from: {model_id_to_load}")
# Load base pipeline
pipe = PipelineClass.from_pretrained(base_model, **pipe_kwargs)
print(f" ✅ Base pipeline loaded")
# Determine the correct pipeline class for the base model
BasePipelineClass = None
if "LTX-Video" in base_model or "ltx" in base_model.lower():
try:
from diffusers import LTXVideoPipeline
BasePipelineClass = LTXVideoPipeline
if debug:
print(f" [DEBUG] Using LTXVideoPipeline for base model")
except ImportError:
if debug:
print(f" [DEBUG] LTXVideoPipeline not available, trying generic")
elif "Wan" in base_model or "wan" in base_model.lower():
try:
from diffusers import WanPipeline
BasePipelineClass = WanPipeline
except ImportError:
pass
elif "stable-video-diffusion" in base_model.lower() or "svd" in base_model.lower():
from diffusers import StableVideoDiffusionPipeline
BasePipelineClass = StableVideoDiffusionPipeline
# Fallback to current PipelineClass if we couldn't determine
if BasePipelineClass is None:
BasePipelineClass = PipelineClass
if debug:
print(f" [DEBUG] Using fallback PipelineClass: {PipelineClass.__name__}")
# Load base pipeline with correct class
pipe = BasePipelineClass.from_pretrained(base_model, **pipe_kwargs)
print(f" ✅ Base pipeline loaded with {BasePipelineClass.__name__}")
# Load the fine-tuned component
if class_name == "LTXVideoTransformer3DModel":
......@@ -8429,9 +8456,43 @@ def main(args):
print(f" Trying to load base model first: {base_model}")
print(f" Then loading fine-tuned weights from: {model_id_to_load}")
try:
# Load base model
pipe = PipelineClass.from_pretrained(base_model, **pipe_kwargs)
print(f" ✅ Base model loaded")
# Determine the correct pipeline class for the base model
FallbackPipelineClass = None
if "LTX-Video" in base_model or "ltx" in base_model.lower():
try:
from diffusers import LTXVideoPipeline
FallbackPipelineClass = LTXVideoPipeline
except ImportError:
pass
elif "Wan" in base_model or "wan" in base_model.lower():
try:
from diffusers import WanPipeline
FallbackPipelineClass = WanPipeline
except ImportError:
pass
elif "stable-video-diffusion" in base_model.lower() or "svd" in base_model.lower():
from diffusers import StableVideoDiffusionPipeline
FallbackPipelineClass = StableVideoDiffusionPipeline
elif "cogvideo" in base_model.lower():
try:
from diffusers import CogVideoXPipeline
FallbackPipelineClass = CogVideoXPipeline
except ImportError:
pass
elif "mochi" in base_model.lower():
try:
from diffusers import MochiPipeline
FallbackPipelineClass = MochiPipeline
except ImportError:
pass
# Fallback to current PipelineClass
if FallbackPipelineClass is None:
FallbackPipelineClass = PipelineClass
# Load base model with correct pipeline class
pipe = FallbackPipelineClass.from_pretrained(base_model, **pipe_kwargs)
print(f" ✅ Base model loaded with {FallbackPipelineClass.__name__}")
# Now try to load the fine-tuned components
# This works for models that have component folders but no model_index.json
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment