Refactor: Always detect pipeline at runtime based on model ID + task

- Add get_pipeline_for_task() function that determines pipeline
  based on model ID AND task type (t2v, i2v, t2i, i2i, v2v)
- Pipeline class is now ALWAYS detected at runtime, not from config
- Remove old dynamic switching code that's now redundant
- Update check_model.py to show runtime detection instead of fixing config
- Update check_pipelines.py to show V2V pipelines
parent 2e248868
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Check model information - pipeline class is now detected at runtime"""
import sys import sys
import os import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from videogen import MODELS, detect_model_type, PIPELINE_CLASS_MAP, save_models_config from videogen import MODELS, detect_model_type, get_pipeline_for_task
def check_and_fix_pipeline_class(model_name, info):
    """Validate the stored pipeline class for one model entry.

    Derives the expected diffusers pipeline class name from substrings of
    the model ID and, when the stored ``class`` field disagrees, patches the
    entry in the module-level ``MODELS`` registry in place.

    Args:
        model_name: Key of the model inside ``MODELS``.
        info: The model's config dict; only ``id`` and ``class`` are read.

    Returns:
        True when the stored class was wrong and has been fixed,
        False otherwise (already correct, or model family unknown).
    """
    model_id = info.get("id", "").lower()
    stored = info.get("class", "")
    wants_i2i = "img2img" in model_id or "i2i" in model_id

    # Ordered rules: (ID markers, text-to-X class, optional img2img class).
    # First matching rule wins, mirroring the priority of the checks.
    rules = (
        (("wan",), "WanPipeline", None),
        (("ltx-video", "ltxvideo"), "LTXPipeline", None),
        (("stable-video-diffusion", "svd"), "StableVideoDiffusionPipeline", None),
        (("cogvideox", "cogvideo"), "CogVideoXPipeline", None),
        (("i2vgen",), "I2VGenXLPipeline", None),
        (("mochi",), "MochiPipeline", None),
        (("animatediff",), "AnimateDiffPipeline", None),
        (("flux",), "FluxPipeline", "FluxImg2ImgPipeline"),
        (("stable-diffusion-xl", "sdxl"),
         "StableDiffusionXLPipeline", "StableDiffusionXLImg2ImgPipeline"),
        (("stable-diffusion-3", "sd3"),
         "StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline"),
    )

    expected = None
    for markers, base_cls, i2i_cls in rules:
        if any(marker in model_id for marker in markers):
            expected = i2i_cls if (i2i_cls and wants_i2i) else base_cls
            break

    # Nothing to do when the family is unknown or the stored value matches.
    if not expected or stored == expected:
        return False

    print(f"\n🔧 FIXING: {model_name}")
    print(f" Current: {stored}")
    print(f" Expected: {expected}")
    MODELS[model_name]["class"] = expected
    return True
# Check and fix all models
print("=" * 60) print("=" * 60)
print("CHECKING AND FIXING PIPELINE CLASSES") print("MODEL CHECK - Pipeline is detected at runtime")
print("=" * 60)
fixed_count = 0
for name, info in MODELS.items():
if check_and_fix_pipeline_class(name, info):
fixed_count += 1
if fixed_count > 0:
print(f"\n✅ Fixed {fixed_count} models")
save_models_config(MODELS)
else:
print("\n✅ All models have correct pipeline classes")
print("\n" + "=" * 60)
print("MODEL CHECK RESULTS")
print("=" * 60) print("=" * 60)
print("\nPipeline class is now determined at RUNTIME based on:")
print(" 1. Model ID (e.g., Wan-AI/Wan2.2-I2V-A14B)")
print(" 2. Task type (t2v, i2v, t2i, i2i, v2v)")
print("\nThis means the stored 'class' field in config is ignored!")
print()
# Show Wan models # Show Wan models
for name, info in MODELS.items(): for name, info in MODELS.items():
if "wan" in name.lower() or "wan" in info.get("id", "").lower(): model_id = info.get("id", "")
if "wan" in name.lower() or "wan" in model_id.lower():
print(f"\nName: {name}") print(f"\nName: {name}")
print(f"ID: {info.get('id')}") print(f"ID: {info.get('id')}")
print(f"Class: {info.get('class')}") print(f"Stored Class: {info.get('class')} (IGNORED)")
print(f"Supports I2V: {info.get('supports_i2v')}") print(f"Supports I2V: {info.get('supports_i2v')}")
# Show what pipeline would be used for each task
print(f"\nRuntime pipeline selection:")
for task in ["t2v", "i2v", "v2v"]:
pipeline = get_pipeline_for_task(model_id, task)
print(f" {task.upper()}: {pipeline}")
print(f"Tags: {info.get('tags')}") print(f"Tags: {info.get('tags')}")
print(f"Capabilities: {detect_model_type(info)}")
...@@ -21,168 +21,74 @@ def check_pipelines(): ...@@ -21,168 +21,74 @@ def check_pipelines():
print(f"\nFound {len(pipeline_names)} Pipeline classes\n") print(f"\nFound {len(pipeline_names)} Pipeline classes\n")
# Group by base model # Print all Video-related pipelines
pipelines_by_base = {}
for name in sorted(pipeline_names):
# Categorize
base = name.replace("Pipeline", "").replace("Img2Img", "_I2I").replace("ImageToVideo", "_I2V").replace("VideoToVideo", "_V2V").replace("TextToImage", "_T2I").replace("TextToVideo", "_T2V").replace("ImageToImage", "_I2I").replace("VideoImageToVideo", "_V2V")
# Group by base
for prefix in ["Wan", "Flux", "StableDiffusion", "SD3", "LTX", "Mochi", "CogVideo", "I2VGen", "AnimateDiff"]:
if name.startswith(prefix):
base = prefix
break
if base not in pipelines_by_base:
pipelines_by_base[base] = []
pipelines_by_base[base].append(name)
# Print grouped results
print("\n" + "=" * 70) print("\n" + "=" * 70)
print("PIPELINES GROUPED BY BASE MODEL") print("VIDEO PIPELINES (T2V, I2V, V2V)")
print("=" * 70) print("=" * 70)
for base in sorted(pipelines_by_base.keys()): video_pipelines = {
names = pipelines_by_base[base] "T2V": [],
if len(names) > 1: "I2V": [],
print(f"\n{base}:") "V2V": []
for n in sorted(names): }
print(f" - {n}")
# Check specific models for name in sorted(pipeline_names):
name_lower = name.lower()
# T2V pipelines
if any(x in name_lower for x in ["texttovideo", "text-to-video", "mochi", "ltx", "cogvideo", "animatediff"]) and "image" not in name_lower and "video" not in name_lower:
if "TextToVideo" in name or "MoChi" in name or name in ["LTXPipeline", "CogVideoXPipeline", "AnimateDiffPipeline"]:
video_pipelines["T2V"].append(name)
# I2V pipelines
elif "imagetovideo" in name_lower or "image-to-video" in name_lower or "i2v" in name_lower:
video_pipelines["I2V"].append(name)
# V2V pipelines
elif "videotovideo" in name_lower or "video-to-video" in name_lower or "v2v" in name_lower:
video_pipelines["V2V"].append(name)
print("\nT2V (Text-to-Video):")
for p in sorted(video_pipelines["T2V"]):
print(f" - {p}")
print("\nI2V (Image-to-Video):")
for p in sorted(video_pipelines["I2V"]):
print(f" - {p}")
print("\nV2V (Video-to-Video):")
for p in sorted(video_pipelines["V2V"]):
print(f" - {p}")
# Print all Image-related pipelines
print("\n" + "=" * 70) print("\n" + "=" * 70)
print("CHECKING SPECIFIC PIPELINE VARIANTS") print("IMAGE PIPELINES (T2I, I2I)")
print("=" * 70) print("=" * 70)
# Wan pipelines image_pipelines = {
print("\n--- Wan ---") "T2I": [],
try: "I2I": []
from diffusers import WanPipeline }
print(f"WanPipeline (T2V): {WanPipeline}")
except Exception as e:
print(f"WanPipeline (T2V): ERROR - {e}")
try:
from diffusers import WanImageToVideoPipeline
print(f"WanImageToVideoPipeline (I2V): {WanImageToVideoPipeline}")
except Exception as e:
print(f"WanImageToVideoPipeline (I2V): ERROR - {e}")
try:
from diffusers import WanVideoToVideoPipeline
print(f"WanVideoToVideoPipeline (V2V): {WanVideoToVideoPipeline}")
except Exception as e:
print(f"WanVideoToVideoPipeline (V2V): ERROR - {e}")
# LTX pipelines
print("\n--- LTX ---")
try:
from diffusers import LTXPipeline
print(f"LTXPipeline (T2V): {LTXPipeline}")
except Exception as e:
print(f"LTXPipeline (T2V): ERROR - {e}")
try:
from diffusers import LTXImageToVideoPipeline
print(f"LTXImageToVideoPipeline (I2V): {LTXImageToVideoPipeline}")
except Exception as e:
print(f"LTXImageToVideoPipeline (I2V): ERROR - {e}")
# Flux pipelines
print("\n--- Flux ---")
try:
from diffusers import FluxPipeline
print(f"FluxPipeline (T2I): {FluxPipeline}")
except Exception as e:
print(f"FluxPipeline (T2I): ERROR - {e}")
try:
from diffusers import FluxImg2ImgPipeline
print(f"FluxImg2ImgPipeline (I2I): {FluxImg2ImgPipeline}")
except Exception as e:
print(f"FluxImg2ImgPipeline (I2I): ERROR - {e}")
# Stable Diffusion XL pipelines
print("\n--- Stable Diffusion XL ---")
try:
from diffusers import StableDiffusionXLPipeline
print(f"StableDiffusionXLPipeline (T2I): {StableDiffusionXLPipeline}")
except Exception as e:
print(f"StableDiffusionXLPipeline (T2I): ERROR - {e}")
try: for name in sorted(pipeline_names):
from diffusers import StableDiffusionXLImg2ImgPipeline name_lower = name.lower()
print(f"StableDiffusionXLImg2ImgPipeline (I2I): {StableDiffusionXLImg2ImgPipeline}") # T2I pipelines
except Exception as e: if "texttoimage" in name_lower or "text-to-image" in name_lower:
print(f"StableDiffusionXLImg2ImgPipeline (I2I): ERROR - {e}") if "img2img" not in name_lower and "inpaint" not in name_lower:
image_pipelines["T2I"].append(name)
# Stable Diffusion 3 pipelines # I2I pipelines
print("\n--- Stable Diffusion 3 ---") elif "img2img" in name_lower or "image-to-image" in name_lower:
try: image_pipelines["I2I"].append(name)
from diffusers import StableDiffusion3Pipeline
print(f"StableDiffusion3Pipeline (T2I): {StableDiffusion3Pipeline}") print("\nT2I (Text-to-Image):")
except Exception as e: for p in sorted(set(image_pipelines["T2I"])):
print(f"StableDiffusion3Pipeline (T2I): ERROR - {e}") print(f" - {p}")
try: print("\nI2I (Image-to-Image):")
from diffusers import StableDiffusion3Img2ImgPipeline for p in sorted(set(image_pipelines["I2I"])):
print(f"StableDiffusion3Img2ImgPipeline (I2I): {StableDiffusion3Img2ImgPipeline}") print(f" - {p}")
except Exception as e:
print(f"StableDiffusion3Img2ImgPipeline (I2I): ERROR - {e}")
# Stable Video Diffusion
print("\n--- Stable Video Diffusion ---")
try:
from diffusers import StableVideoDiffusionPipeline
print(f"StableVideoDiffusionPipeline (I2V): {StableVideoDiffusionPipeline}")
except Exception as e:
print(f"StableVideoDiffusionPipeline (I2V): ERROR - {e}")
# CogVideoX
print("\n--- CogVideoX ---")
try:
from diffusers import CogVideoXPipeline
print(f"CogVideoXPipeline (T2V): {CogVideoXPipeline}")
except Exception as e:
print(f"CogVideoXPipeline (T2V): ERROR - {e}")
try:
from diffusers import CogVideoXImageToVideoPipeline
print(f"CogVideoXImageToVideoPipeline (I2V): {CogVideoXImageToVideoPipeline}")
except Exception as e:
print(f"CogVideoXImageToVideoPipeline (I2V): ERROR - {e}")
# I2VGenXL
print("\n--- I2VGenXL ---")
try:
from diffusers import I2VGenXLPipeline
print(f"I2VGenXLPipeline (I2V): {I2VGenXLPipeline}")
except Exception as e:
print(f"I2VGenXLPipeline (I2V): ERROR - {e}")
# Mochi
print("\n--- Mochi ---")
try:
from diffusers import MochiPipeline
print(f"MochiPipeline (T2V): {MochiPipeline}")
except Exception as e:
print(f"MochiPipeline (T2V): ERROR - {e}")
# AnimateDiff
print("\n--- AnimateDiff ---")
try:
from diffusers import AnimateDiffPipeline
print(f"AnimateDiffPipeline (T2V): {AnimateDiffPipeline}")
except Exception as e:
print(f"AnimateDiffPipeline (T2V): ERROR - {e}")
print("\n" + "=" * 70) print("\n" + "=" * 70)
print("SUMMARY OF PIPELINE VARIANTS") print("SUMMARY - PIPELINE SELECTION BY MODEL + TASK")
print("=" * 70) print("=" * 70)
print(""" print("""
Based on the check, here's the mapping of pipelines:
T2V (Text-to-Video): T2V (Text-to-Video):
- WanPipeline - WanPipeline
- LTXPipeline - LTXPipeline
...@@ -197,6 +103,11 @@ I2V (Image-to-Video): ...@@ -197,6 +103,11 @@ I2V (Image-to-Video):
- CogVideoXImageToVideoPipeline - CogVideoXImageToVideoPipeline
- I2VGenXLPipeline - I2VGenXLPipeline
V2V (Video-to-Video):
- WanVideoToVideoPipeline
- CogVideoXVideoToVideoPipeline
- AnimateDiffVideoToVideoPipeline
T2I (Text-to-Image): T2I (Text-to-Image):
- FluxPipeline - FluxPipeline
- StableDiffusionXLPipeline - StableDiffusionXLPipeline
......
...@@ -1227,6 +1227,138 @@ def detect_pipeline_class(model_info): ...@@ -1227,6 +1227,138 @@ def detect_pipeline_class(model_info):
return None return None
def get_pipeline_for_task(model_id, task_type):
    """Return the diffusers pipeline class name for a model ID + task.

    The pipeline is ALWAYS resolved at runtime from:
      - the model ID (which determines the base model family), and
      - the task type: "t2v", "i2v", "v2v", "t2i" or "i2i"
        (matched case-insensitively).

    It does NOT use stored config values.

    Args:
        model_id: HuggingFace model ID, e.g. "Wan-AI/Wan2.2-I2V-A14B".
        task_type: One of "t2v", "i2v", "v2v", "t2i", "i2i" (any case).

    Returns:
        The pipeline class name as a string. Unrecognized video models fall
        back to "WanPipeline", unrecognized image models to "FluxPipeline",
        and unknown task types to "DiffusionPipeline".
    """
    model_id_lower = model_id.lower()
    task = task_type.lower()  # accept "I2V" as well as "i2v"

    # ── Video tasks ──────────────────────────────────────────────────────
    if task in ("t2v", "i2v", "v2v"):
        # Wan family has dedicated T2V / I2V / V2V variants
        if "wan" in model_id_lower:
            if task == "i2v":
                return "WanImageToVideoPipeline"
            if task == "v2v":
                return "WanVideoToVideoPipeline"
            return "WanPipeline"
        # LTX: T2V + I2V (no V2V variant; V2V falls back to T2V)
        if "ltx" in model_id_lower:
            if task == "i2v":
                return "LTXImageToVideoPipeline"
            return "LTXPipeline"
        # Stable Video Diffusion is image-to-video only
        if "stable-video-diffusion" in model_id_lower or "svd" in model_id_lower:
            return "StableVideoDiffusionPipeline"
        # CogVideoX: T2V / I2V / V2V variants
        if "cogvideo" in model_id_lower:
            if task == "i2v":
                return "CogVideoXImageToVideoPipeline"
            if task == "v2v":
                return "CogVideoXVideoToVideoPipeline"
            return "CogVideoXPipeline"
        if "i2vgen" in model_id_lower:
            return "I2VGenXLPipeline"
        if "mochi" in model_id_lower:
            return "MochiPipeline"
        # AnimateDiff: T2V + V2V variants
        if "animatediff" in model_id_lower:
            if task == "v2v":
                return "AnimateDiffVideoToVideoPipeline"
            return "AnimateDiffPipeline"
        if "allegro" in model_id_lower:
            return "AllegroPipeline"
        # NOTE(review): HunyuanDiT is a text-to-IMAGE pipeline; for video
        # tasks "HunyuanVideoPipeline" may be what is intended — confirm.
        if "hunyuan" in model_id_lower:
            return "HunyuanDiTPipeline"
        if "open-sora" in model_id_lower or "opensora" in model_id_lower:
            return "OpenSoraPipeline"
        if "stepvideo" in model_id_lower or "step-video" in model_id_lower:
            return "StepVideoPipeline"
        # NOTE(review): zeroscope checkpoints normally load with
        # TextToVideoSDPipeline; TextToVideoZeroPipeline is a different
        # (zero-shot, SD-checkpoint-based) technique — confirm this mapping.
        if "zeroscope" in model_id_lower:
            return "TextToVideoZeroPipeline"
        if "modelscope" in model_id_lower or "text-to-video-ms" in model_id_lower:
            return "TextToVideoSDPipeline"
        if "latte" in model_id_lower:
            return "LattePipeline"
        if "hotshot" in model_id_lower:
            return "HotshotXLPipeline"
        # Default fallback for unrecognized video models
        return "WanPipeline"

    # ── Image tasks ──────────────────────────────────────────────────────
    if task in ("t2i", "i2i"):
        if "flux" in model_id_lower:
            if task == "i2i":
                return "FluxImg2ImgPipeline"
            return "FluxPipeline"
        if "stable-diffusion-3" in model_id_lower or "sd3" in model_id_lower:
            if task == "i2i":
                return "StableDiffusion3Img2ImgPipeline"
            return "StableDiffusion3Pipeline"
        if "stable-diffusion-xl" in model_id_lower or "sdxl" in model_id_lower:
            if task == "i2i":
                return "StableDiffusionXLImg2ImgPipeline"
            return "StableDiffusionXLPipeline"
        # Classic Stable Diffusion (1.x / 2.x). Exclude only the XL and SD3
        # families explicitly — the previous guard ("3" not in model_id)
        # wrongly rejected any ID containing a literal '3', e.g.
        # "stable-diffusion-v1-3", sending it to the Flux fallback.
        if ("stable-diffusion" in model_id_lower
                and "xl" not in model_id_lower
                and "stable-diffusion-3" not in model_id_lower):
            if task == "i2i":
                return "StableDiffusionImg2ImgPipeline"
            return "StableDiffusionPipeline"
        if "lumina2" in model_id_lower or "lumina-2" in model_id_lower:
            return "Lumina2Text2ImgPipeline"
        if "lumina" in model_id_lower:
            return "LuminaText2ImgPipeline"
        # Default fallback for unrecognized image models
        return "FluxPipeline"

    # Unknown task type — let diffusers auto-detect
    return "DiffusionPipeline"
def parse_hf_url_or_id(input_str): def parse_hf_url_or_id(input_str):
"""Parse either a HuggingFace URL or model ID and return the model ID """Parse either a HuggingFace URL or model ID and return the model ID
...@@ -7811,7 +7943,33 @@ def main(args): ...@@ -7811,7 +7943,33 @@ def main(args):
sys.exit(1) sys.exit(1)
m_info = MODELS[args.model] m_info = MODELS[args.model]
PipelineClass = get_pipeline_class(m_info["class"])
# Determine task type based on arguments
model_id = m_info["id"]
is_i2v_mode = args.image_to_video or args.image
is_i2i_mode = args.image_to_image
is_v2v_mode = args.input_video is not None
if is_i2v_mode:
task_type = "i2v"
elif is_v2v_mode:
task_type = "v2v"
elif is_i2i_mode:
task_type = "i2i"
elif args.prompt or args.auto:
# Check if it's an image model or video model
if m_info.get("supports_i2v") or m_info.get("is_video"):
task_type = "t2v"
else:
task_type = "t2i"
else:
task_type = "t2v" # Default
# ALWAYS detect pipeline class at runtime based on model + task
# Do NOT use stored config value
pipeline_class = get_pipeline_for_task(model_id, task_type)
PipelineClass = get_pipeline_class(pipeline_class)
print(f" 📦 Using {pipeline_class} for {task_type.upper()} task")
if not PipelineClass: if not PipelineClass:
pipeline_class = m_info['class'] pipeline_class = m_info['class']
...@@ -7879,94 +8037,6 @@ def main(args): ...@@ -7879,94 +8037,6 @@ def main(args):
if variant := extra.get("variant"): if variant := extra.get("variant"):
pipe_kwargs["variant"] = variant pipe_kwargs["variant"] = variant
# ─── DYNAMIC PIPELINE CLASS SELECTION ───────────────────────────────────────
# Switch pipeline class based on task mode (I2V vs T2V, I2I vs T2I, V2V)
# This is similar to the LTX I2V handling already in place
is_i2v_mode = args.image_to_video or args.image
is_i2i_mode = args.image_to_image
is_v2v_mode = args.input_video is not None # Video-to-video if input video provided
# Handle WanPipeline - switch between T2V, I2V, V2V variants
if m_info["class"] == "WanPipeline":
if is_i2v_mode:
try:
from diffusers import WanImageToVideoPipeline
PipelineClass = WanImageToVideoPipeline
print(f" 🔄 Switched to WanImageToVideoPipeline for I2V mode")
except ImportError:
print(f" ⚠️ WanImageToVideoPipeline not available, using WanPipeline")
elif is_v2v_mode:
try:
from diffusers import WanVideoToVideoPipeline
PipelineClass = WanVideoToVideoPipeline
print(f" 🔄 Switched to WanVideoToVideoPipeline for V2V mode")
except ImportError:
print(f" ⚠️ WanVideoToVideoPipeline not available, using WanPipeline")
# Handle LTXPipeline - switch between T2V and I2V variants
if m_info["class"] == "LTXPipeline":
if is_i2v_mode:
try:
from diffusers import LTXImageToVideoPipeline
PipelineClass = LTXImageToVideoPipeline
print(f" 🔄 Switched to LTXImageToVideoPipeline for I2V mode")
except ImportError:
print(f" ⚠️ LTXImageToVideoPipeline not available, using LTXPipeline")
# Handle CogVideoXPipeline - switch between T2V and I2V variants
if m_info["class"] == "CogVideoXPipeline":
if is_i2v_mode:
try:
from diffusers import CogVideoXImageToVideoPipeline
PipelineClass = CogVideoXImageToVideoPipeline
print(f" 🔄 Switched to CogVideoXImageToVideoPipeline for I2V mode")
except ImportError:
print(f" ⚠️ CogVideoXImageToVideoPipeline not available, using CogVideoXPipeline")
elif is_v2v_mode:
try:
from diffusers import CogVideoXVideoToVideoPipeline
PipelineClass = CogVideoXVideoToVideoPipeline
print(f" 🔄 Switched to CogVideoXVideoToVideoPipeline for V2V mode")
except ImportError:
print(f" ⚠️ CogVideoXVideoToVideoPipeline not available, using CogVideoXPipeline")
# Handle AnimateDiffPipeline - switch between T2V and V2V variants
if m_info["class"] == "AnimateDiffPipeline":
if is_v2v_mode:
try:
from diffusers import AnimateDiffVideoToVideoPipeline
PipelineClass = AnimateDiffVideoToVideoPipeline
print(f" 🔄 Switched to AnimateDiffVideoToVideoPipeline for V2V mode")
except ImportError:
print(f" ⚠️ AnimateDiffVideoToVideoPipeline not available, using AnimateDiffPipeline")
# Handle Flux pipelines - can do T2I and I2I, need to switch based on mode
if "FluxPipeline" in m_info["class"] and is_i2i_mode:
try:
from diffusers import FluxImg2ImgPipeline
# Check if there's a specific FluxImg2Img available
PipelineClass = FluxImg2ImgPipeline
print(f" 🔄 Switched to FluxImg2ImgPipeline for I2I mode")
except ImportError:
print(f" ⚠️ FluxImg2ImgPipeline not available, using FluxPipeline")
# Handle StableDiffusionXLPipeline - can do T2I and I2I
if "StableDiffusionXLPipeline" in m_info["class"] and is_i2i_mode:
try:
from diffusers import StableDiffusionXLImg2ImgPipeline
PipelineClass = StableDiffusionXLImg2ImgPipeline
print(f" 🔄 Switched to StableDiffusionXLImg2ImgPipeline for I2I mode")
except ImportError:
print(f" ⚠️ StableDiffusionXLImg2ImgPipeline not available, using StableDiffusionXLPipeline")
# Handle StableDiffusion3Pipeline - can do T2I and I2I
if "StableDiffusion3Pipeline" in m_info["class"] and is_i2i_mode:
try:
from diffusers import StableDiffusion3Img2ImgPipeline
PipelineClass = StableDiffusion3Img2ImgPipeline
print(f" 🔄 Switched to StableDiffusion3Img2ImgPipeline for I2I mode")
except ImportError:
print(f" ⚠️ StableDiffusion3Img2ImgPipeline not available, using StableDiffusion3Pipeline")
# Handle LoRA models - need to load base model first # Handle LoRA models - need to load base model first
is_lora = m_info.get("is_lora", False) is_lora = m_info.get("is_lora", False)
lora_id = None lora_id = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment