#!/usr/bin/env python3
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from videogen import MODELS, detect_model_type, PIPELINE_CLASS_MAP, save_models_config

def _is_img2img(model_id):
    """Return True if the lower-cased *model_id* names an image-to-image checkpoint."""
    return "img2img" in model_id or "i2i" in model_id


def _video_family_class(model_id, current_class, prefix, t2v_class, i2v_class):
    """Pick the pipeline for a video family that has both T2V and I2V pipelines.

    An explicit task marker in the id wins: "i2v" selects ``i2v_class`` and
    "t2v" selects ``t2v_class``. Ids containing "ti2v" (text+image-to-video
    checkpoints) are treated as task-neutral rather than as I2V-only.

    Without a marker, a class that already belongs to the family (its name
    starts with ``prefix``) is kept, so hand-chosen variants are not
    overwritten. Anything else falls back to ``t2v_class``, which matches the
    previous behaviour.
    """
    if "i2v" in model_id and "ti2v" not in model_id:
        return i2v_class
    if "t2v" in model_id:
        return t2v_class
    if current_class.startswith(prefix):
        return current_class
    return t2v_class


def _expected_pipeline_class(model_id, current_class):
    """Return the pipeline class *model_id* should use, or None if unrecognized.

    ``model_id`` must already be lower-cased. Families are tested in order
    and the first substring match wins. For example, an AnimateDiff id that
    also mentions "sdxl" resolves to AnimateDiff, and "i2vgen" ids are only
    reached after the Wan/LTX/CogVideoX checks.
    """
    if "wan" in model_id:
        return _video_family_class(
            model_id, current_class, "Wan", "WanPipeline", "WanImageToVideoPipeline"
        )
    if "ltx-video" in model_id or "ltxvideo" in model_id:
        return _video_family_class(
            model_id, current_class, "LTX", "LTXPipeline", "LTXImageToVideoPipeline"
        )
    if "stable-video-diffusion" in model_id or "svd" in model_id:
        return "StableVideoDiffusionPipeline"
    if "cogvideo" in model_id:  # also covers "cogvideox"
        return _video_family_class(
            model_id,
            current_class,
            "CogVideoX",
            "CogVideoXPipeline",
            "CogVideoXImageToVideoPipeline",
        )
    if "i2vgen" in model_id:
        return "I2VGenXLPipeline"
    if "mochi" in model_id:
        return "MochiPipeline"
    if "animatediff" in model_id:
        return "AnimateDiffPipeline"
    if "flux" in model_id:
        return "FluxImg2ImgPipeline" if _is_img2img(model_id) else "FluxPipeline"
    if "stable-diffusion-xl" in model_id or "sdxl" in model_id:
        if _is_img2img(model_id):
            return "StableDiffusionXLImg2ImgPipeline"
        return "StableDiffusionXLPipeline"
    if "stable-diffusion-3" in model_id or "sd3" in model_id:
        if _is_img2img(model_id):
            return "StableDiffusion3Img2ImgPipeline"
        return "StableDiffusion3Pipeline"
    return None


def check_and_fix_pipeline_class(model_name, info):
    """Check that a model's pipeline class matches its id, and fix it if not.

    The expected class is derived from substrings of the lower-cased
    ``info["id"]``. Ids that match no known family are left untouched.

    Args:
        model_name: The model's key; used only in the printed report.
        info: The model's config dict. Its ``"class"`` entry is updated in
            place when a fix is needed. Callers iterating ``MODELS.items()``
            pass the live entry, so the change lands in ``MODELS``.

    Returns:
        True if the class was changed, False otherwise.
    """
    # ``or ""`` guards against keys that exist but hold None.
    model_id = (info.get("id") or "").lower()
    current_class = info.get("class") or ""

    expected_class = _expected_pipeline_class(model_id, current_class)
    if not expected_class or current_class == expected_class:
        return False

    print(f"\n🔧 FIXING: {model_name}")
    print(f"   Current:  {current_class}")
    print(f"   Expected: {expected_class}")
    info["class"] = expected_class
    return True

# Walk every registered model, repair mismatched pipeline classes, and
# persist the config only when something actually changed.
rule = "=" * 60
print(rule)
print("CHECKING AND FIXING PIPELINE CLASSES")
print(rule)

fixed_count = sum(
    1 for name, info in MODELS.items() if check_and_fix_pipeline_class(name, info)
)

if not fixed_count:
    print("\n✅ All models have correct pipeline classes")
else:
    print(f"\n✅ Fixed {fixed_count} models")
    save_models_config(MODELS)

print("\n" + "=" * 60)
print("MODEL CHECK RESULTS")
print("=" * 60)

# Show Wan models, matched either by config key or by model id.
for name, info in MODELS.items():
    # info.get("id", "") still returns None when the key exists but is None,
    # which would crash .lower(); coerce to "" first.
    model_id = (info.get("id") or "").lower()
    if "wan" in name.lower() or "wan" in model_id:
        print(f"\nName: {name}")
        print(f"ID: {info.get('id')}")
        print(f"Class: {info.get('class')}")
        print(f"Supports I2V: {info.get('supports_i2v')}")
        print(f"Tags: {info.get('tags')}")
        print(f"Capabilities: {detect_model_type(info)}")
