Update check_model.py to fix incorrect pipeline classes in local config

parent 0a1b413b
...@@ -3,12 +3,103 @@ import sys ...@@ -3,12 +3,103 @@ import sys
import os import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from videogen import MODELS, detect_model_type from videogen import MODELS, detect_model_type, PIPELINE_CLASS_MAP, save_models_config
def check_and_fix_pipeline_class(model_name, info):
"""Check if the pipeline class is correct and fix if needed"""
model_id = info.get("id", "").lower()
current_class = info.get("class", "")
# Determine expected pipeline class based on model ID
expected_class = None
# Wan models should use WanPipeline
if "wan" in model_id:
expected_class = "WanPipeline"
# LTX models
elif "ltx-video" in model_id or "ltxvideo" in model_id:
expected_class = "LTXPipeline"
# Stable Video Diffusion
elif "stable-video-diffusion" in model_id or "svd" in model_id:
expected_class = "StableVideoDiffusionPipeline"
# CogVideoX
elif "cogvideox" in model_id or "cogvideo" in model_id:
expected_class = "CogVideoXPipeline"
# I2VGenXL
elif "i2vgen" in model_id:
expected_class = "I2VGenXLPipeline"
# Mochi
elif "mochi" in model_id:
expected_class = "MochiPipeline"
# AnimateDiff
elif "animatediff" in model_id:
expected_class = "AnimateDiffPipeline"
# Flux
elif "flux" in model_id:
if "img2img" in model_id or "i2i" in model_id:
expected_class = "FluxImg2ImgPipeline"
else:
expected_class = "FluxPipeline"
# Stable Diffusion XL
elif "stable-diffusion-xl" in model_id or "sdxl" in model_id:
if "img2img" in model_id or "i2i" in model_id:
expected_class = "StableDiffusionXLImg2ImgPipeline"
else:
expected_class = "StableDiffusionXLPipeline"
# Stable Diffusion 3
elif "stable-diffusion-3" in model_id or "sd3" in model_id:
if "img2img" in model_id or "i2i" in model_id:
expected_class = "StableDiffusion3Img2ImgPipeline"
else:
expected_class = "StableDiffusion3Pipeline"
# Check if fix needed
if expected_class and current_class != expected_class:
print(f"\n🔧 FIXING: {model_name}")
print(f" Current: {current_class}")
print(f" Expected: {expected_class}")
# Update the model
MODELS[model_name]["class"] = expected_class
return True
return False
# Check and fix all models
print("=" * 60)
print("CHECKING AND FIXING PIPELINE CLASSES")
print("=" * 60)
fixed_count = 0
for name, info in MODELS.items():
if check_and_fix_pipeline_class(name, info):
fixed_count += 1
if fixed_count > 0:
print(f"\n✅ Fixed {fixed_count} models")
save_models_config(MODELS)
else:
print("\n✅ All models have correct pipeline classes")
print("\n" + "=" * 60)
print("MODEL CHECK RESULTS")
print("=" * 60)
# Show Wan models
for name, info in MODELS.items(): for name, info in MODELS.items():
if name == "wan2_2_i2v_a14b": if "wan" in name.lower() or "wan" in info.get("id", "").lower():
print("Name:", name) print(f"\nName: {name}")
print("ID:", info.get("id")) print(f"ID: {info.get('id')}")
print("Class:", info.get("class")) print(f"Class: {info.get('class')}")
print("Supports I2V:", info.get("supports_i2v")) print(f"Supports I2V: {info.get('supports_i2v')}")
print("Tags:", info.get("tags")) print(f"Tags: {info.get('tags')}")
print("Capabilities:", detect_model_type(info)) print(f"Capabilities: {detect_model_type(info)}")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment