Refactor: Always detect pipeline at runtime based on model ID + task

- Add get_pipeline_for_task() function that determines pipeline
  based on model ID AND task type (t2v, i2v, t2i, i2i, v2v)
- Pipeline class is now ALWAYS detected at runtime, not from config
- Remove old dynamic switching code that's now redundant
- Update check_model.py to show runtime detection instead of fixing config
- Update check_pipelines.py to show V2V pipelines
parent 2e248868
#!/usr/bin/env python3
"""Check model information - pipeline class is now detected at runtime"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from videogen import MODELS, detect_model_type, PIPELINE_CLASS_MAP, save_models_config
from videogen import MODELS, detect_model_type, get_pipeline_for_task
def check_and_fix_pipeline_class(model_name, info):
"""Check if the pipeline class is correct and fix if needed"""
model_id = info.get("id", "").lower()
current_class = info.get("class", "")
# Determine expected pipeline class based on model ID
expected_class = None
# Wan models should use WanPipeline
if "wan" in model_id:
expected_class = "WanPipeline"
# LTX models
elif "ltx-video" in model_id or "ltxvideo" in model_id:
expected_class = "LTXPipeline"
# Stable Video Diffusion
elif "stable-video-diffusion" in model_id or "svd" in model_id:
expected_class = "StableVideoDiffusionPipeline"
# CogVideoX
elif "cogvideox" in model_id or "cogvideo" in model_id:
expected_class = "CogVideoXPipeline"
# I2VGenXL
elif "i2vgen" in model_id:
expected_class = "I2VGenXLPipeline"
# Mochi
elif "mochi" in model_id:
expected_class = "MochiPipeline"
# AnimateDiff
elif "animatediff" in model_id:
expected_class = "AnimateDiffPipeline"
# Flux
elif "flux" in model_id:
if "img2img" in model_id or "i2i" in model_id:
expected_class = "FluxImg2ImgPipeline"
else:
expected_class = "FluxPipeline"
# Stable Diffusion XL
elif "stable-diffusion-xl" in model_id or "sdxl" in model_id:
if "img2img" in model_id or "i2i" in model_id:
expected_class = "StableDiffusionXLImg2ImgPipeline"
else:
expected_class = "StableDiffusionXLPipeline"
# Stable Diffusion 3
elif "stable-diffusion-3" in model_id or "sd3" in model_id:
if "img2img" in model_id or "i2i" in model_id:
expected_class = "StableDiffusion3Img2ImgPipeline"
else:
expected_class = "StableDiffusion3Pipeline"
# Check if fix needed
if expected_class and current_class != expected_class:
print(f"\n🔧 FIXING: {model_name}")
print(f" Current: {current_class}")
print(f" Expected: {expected_class}")
# Update the model
MODELS[model_name]["class"] = expected_class
return True
return False
# Check and fix all models
print("=" * 60)
print("CHECKING AND FIXING PIPELINE CLASSES")
print("=" * 60)
fixed_count = 0
for name, info in MODELS.items():
if check_and_fix_pipeline_class(name, info):
fixed_count += 1
if fixed_count > 0:
print(f"\n✅ Fixed {fixed_count} models")
save_models_config(MODELS)
else:
print("\n✅ All models have correct pipeline classes")
print("\n" + "=" * 60)
print("MODEL CHECK RESULTS")
print("MODEL CHECK - Pipeline is detected at runtime")
print("=" * 60)
print("\nPipeline class is now determined at RUNTIME based on:")
print(" 1. Model ID (e.g., Wan-AI/Wan2.2-I2V-A14B)")
print(" 2. Task type (t2v, i2v, t2i, i2i, v2v)")
print("\nThis means the stored 'class' field in config is ignored!")
print()
# Show Wan models
for name, info in MODELS.items():
if "wan" in name.lower() or "wan" in info.get("id", "").lower():
model_id = info.get("id", "")
if "wan" in name.lower() or "wan" in model_id.lower():
print(f"\nName: {name}")
print(f"ID: {info.get('id')}")
print(f"Class: {info.get('class')}")
print(f"Stored Class: {info.get('class')} (IGNORED)")
print(f"Supports I2V: {info.get('supports_i2v')}")
# Show what pipeline would be used for each task
print(f"\nRuntime pipeline selection:")
for task in ["t2v", "i2v", "v2v"]:
pipeline = get_pipeline_for_task(model_id, task)
print(f" {task.upper()}: {pipeline}")
print(f"Tags: {info.get('tags')}")
print(f"Capabilities: {detect_model_type(info)}")
......@@ -21,168 +21,74 @@ def check_pipelines():
print(f"\nFound {len(pipeline_names)} Pipeline classes\n")
# Group by base model
pipelines_by_base = {}
for name in sorted(pipeline_names):
# Categorize
base = name.replace("Pipeline", "").replace("Img2Img", "_I2I").replace("ImageToVideo", "_I2V").replace("VideoToVideo", "_V2V").replace("TextToImage", "_T2I").replace("TextToVideo", "_T2V").replace("ImageToImage", "_I2I").replace("VideoImageToVideo", "_V2V")
# Group by base
for prefix in ["Wan", "Flux", "StableDiffusion", "SD3", "LTX", "Mochi", "CogVideo", "I2VGen", "AnimateDiff"]:
if name.startswith(prefix):
base = prefix
break
if base not in pipelines_by_base:
pipelines_by_base[base] = []
pipelines_by_base[base].append(name)
# Print grouped results
# Print all Video-related pipelines
print("\n" + "=" * 70)
print("PIPELINES GROUPED BY BASE MODEL")
print("VIDEO PIPELINES (T2V, I2V, V2V)")
print("=" * 70)
for base in sorted(pipelines_by_base.keys()):
names = pipelines_by_base[base]
if len(names) > 1:
print(f"\n{base}:")
for n in sorted(names):
print(f" - {n}")
video_pipelines = {
"T2V": [],
"I2V": [],
"V2V": []
}
# Check specific models
for name in sorted(pipeline_names):
name_lower = name.lower()
# T2V pipelines
if any(x in name_lower for x in ["texttovideo", "text-to-video", "mochi", "ltx", "cogvideo", "animatediff"]) and "image" not in name_lower and "video" not in name_lower:
if "TextToVideo" in name or "MoChi" in name or name in ["LTXPipeline", "CogVideoXPipeline", "AnimateDiffPipeline"]:
video_pipelines["T2V"].append(name)
# I2V pipelines
elif "imagetovideo" in name_lower or "image-to-video" in name_lower or "i2v" in name_lower:
video_pipelines["I2V"].append(name)
# V2V pipelines
elif "videotovideo" in name_lower or "video-to-video" in name_lower or "v2v" in name_lower:
video_pipelines["V2V"].append(name)
print("\nT2V (Text-to-Video):")
for p in sorted(video_pipelines["T2V"]):
print(f" - {p}")
print("\nI2V (Image-to-Video):")
for p in sorted(video_pipelines["I2V"]):
print(f" - {p}")
print("\nV2V (Video-to-Video):")
for p in sorted(video_pipelines["V2V"]):
print(f" - {p}")
# Print all Image-related pipelines
print("\n" + "=" * 70)
print("CHECKING SPECIFIC PIPELINE VARIANTS")
print("IMAGE PIPELINES (T2I, I2I)")
print("=" * 70)
# Wan pipelines
print("\n--- Wan ---")
try:
from diffusers import WanPipeline
print(f"WanPipeline (T2V): {WanPipeline}")
except Exception as e:
print(f"WanPipeline (T2V): ERROR - {e}")
try:
from diffusers import WanImageToVideoPipeline
print(f"WanImageToVideoPipeline (I2V): {WanImageToVideoPipeline}")
except Exception as e:
print(f"WanImageToVideoPipeline (I2V): ERROR - {e}")
try:
from diffusers import WanVideoToVideoPipeline
print(f"WanVideoToVideoPipeline (V2V): {WanVideoToVideoPipeline}")
except Exception as e:
print(f"WanVideoToVideoPipeline (V2V): ERROR - {e}")
# LTX pipelines
print("\n--- LTX ---")
try:
from diffusers import LTXPipeline
print(f"LTXPipeline (T2V): {LTXPipeline}")
except Exception as e:
print(f"LTXPipeline (T2V): ERROR - {e}")
try:
from diffusers import LTXImageToVideoPipeline
print(f"LTXImageToVideoPipeline (I2V): {LTXImageToVideoPipeline}")
except Exception as e:
print(f"LTXImageToVideoPipeline (I2V): ERROR - {e}")
# Flux pipelines
print("\n--- Flux ---")
try:
from diffusers import FluxPipeline
print(f"FluxPipeline (T2I): {FluxPipeline}")
except Exception as e:
print(f"FluxPipeline (T2I): ERROR - {e}")
try:
from diffusers import FluxImg2ImgPipeline
print(f"FluxImg2ImgPipeline (I2I): {FluxImg2ImgPipeline}")
except Exception as e:
print(f"FluxImg2ImgPipeline (I2I): ERROR - {e}")
# Stable Diffusion XL pipelines
print("\n--- Stable Diffusion XL ---")
try:
from diffusers import StableDiffusionXLPipeline
print(f"StableDiffusionXLPipeline (T2I): {StableDiffusionXLPipeline}")
except Exception as e:
print(f"StableDiffusionXLPipeline (T2I): ERROR - {e}")
image_pipelines = {
"T2I": [],
"I2I": []
}
try:
from diffusers import StableDiffusionXLImg2ImgPipeline
print(f"StableDiffusionXLImg2ImgPipeline (I2I): {StableDiffusionXLImg2ImgPipeline}")
except Exception as e:
print(f"StableDiffusionXLImg2ImgPipeline (I2I): ERROR - {e}")
# Stable Diffusion 3 pipelines
print("\n--- Stable Diffusion 3 ---")
try:
from diffusers import StableDiffusion3Pipeline
print(f"StableDiffusion3Pipeline (T2I): {StableDiffusion3Pipeline}")
except Exception as e:
print(f"StableDiffusion3Pipeline (T2I): ERROR - {e}")
try:
from diffusers import StableDiffusion3Img2ImgPipeline
print(f"StableDiffusion3Img2ImgPipeline (I2I): {StableDiffusion3Img2ImgPipeline}")
except Exception as e:
print(f"StableDiffusion3Img2ImgPipeline (I2I): ERROR - {e}")
# Stable Video Diffusion
print("\n--- Stable Video Diffusion ---")
try:
from diffusers import StableVideoDiffusionPipeline
print(f"StableVideoDiffusionPipeline (I2V): {StableVideoDiffusionPipeline}")
except Exception as e:
print(f"StableVideoDiffusionPipeline (I2V): ERROR - {e}")
# CogVideoX
print("\n--- CogVideoX ---")
try:
from diffusers import CogVideoXPipeline
print(f"CogVideoXPipeline (T2V): {CogVideoXPipeline}")
except Exception as e:
print(f"CogVideoXPipeline (T2V): ERROR - {e}")
try:
from diffusers import CogVideoXImageToVideoPipeline
print(f"CogVideoXImageToVideoPipeline (I2V): {CogVideoXImageToVideoPipeline}")
except Exception as e:
print(f"CogVideoXImageToVideoPipeline (I2V): ERROR - {e}")
# I2VGenXL
print("\n--- I2VGenXL ---")
try:
from diffusers import I2VGenXLPipeline
print(f"I2VGenXLPipeline (I2V): {I2VGenXLPipeline}")
except Exception as e:
print(f"I2VGenXLPipeline (I2V): ERROR - {e}")
# Mochi
print("\n--- Mochi ---")
try:
from diffusers import MochiPipeline
print(f"MochiPipeline (T2V): {MochiPipeline}")
except Exception as e:
print(f"MochiPipeline (T2V): ERROR - {e}")
# AnimateDiff
print("\n--- AnimateDiff ---")
try:
from diffusers import AnimateDiffPipeline
print(f"AnimateDiffPipeline (T2V): {AnimateDiffPipeline}")
except Exception as e:
print(f"AnimateDiffPipeline (T2V): ERROR - {e}")
for name in sorted(pipeline_names):
name_lower = name.lower()
# T2I pipelines
if "texttoimage" in name_lower or "text-to-image" in name_lower:
if "img2img" not in name_lower and "inpaint" not in name_lower:
image_pipelines["T2I"].append(name)
# I2I pipelines
elif "img2img" in name_lower or "image-to-image" in name_lower:
image_pipelines["I2I"].append(name)
print("\nT2I (Text-to-Image):")
for p in sorted(set(image_pipelines["T2I"])):
print(f" - {p}")
print("\nI2I (Image-to-Image):")
for p in sorted(set(image_pipelines["I2I"])):
print(f" - {p}")
print("\n" + "=" * 70)
print("SUMMARY OF PIPELINE VARIANTS")
print("SUMMARY - PIPELINE SELECTION BY MODEL + TASK")
print("=" * 70)
print("""
Based on the check, here's the mapping of pipelines:
T2V (Text-to-Video):
- WanPipeline
- LTXPipeline
......@@ -197,6 +103,11 @@ I2V (Image-to-Video):
- CogVideoXImageToVideoPipeline
- I2VGenXLPipeline
V2V (Video-to-Video):
- WanVideoToVideoPipeline
- CogVideoXVideoToVideoPipeline
- AnimateDiffVideoToVideoPipeline
T2I (Text-to-Image):
- FluxPipeline
- StableDiffusionXLPipeline
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment