Fix I2V pipeline auto-detection

- Add detect_model_family() to identify model family (wan, sdxl, sd, ltx, etc.)
- Add get_pipeline_for_model_family() for proper pipeline selection based on family + task
- Enhance detect_generation_type() to check --image FIRST for I2V detection
- Add support for --image_model, --prompt_image, --prompt_animation as I2V indicators
- Add support for audio/subtitle options as T2V+V2V chaining indicators

This fixes the issue where SDXL models were incorrectly using WanPipeline
for I2V tasks, causing type mismatch errors (expected UMT5EncoderModel,
got CLIPTextModel). Now SDXL models correctly use DiffusionPipeline
or StableDiffusionXLPipeline.
parent 803b1763
......@@ -1493,124 +1493,339 @@ def get_pipeline_for_task(model_id, task_type):
"""
model_id_lower = model_id.lower()
# Video pipelines
if task_type in ["t2v", "i2v", "v2v"]:
# Wan models
# First, detect the model family to determine which pipeline family to use
model_family = detect_model_family(model_id)
# Now select the appropriate pipeline based on model family and task type
return get_pipeline_for_model_family(model_family, task_type)
def detect_model_family(model_id):
    """Detect the model family from a model ID string.

    Matching is purely substring-based on the lowercased ID and ordered from
    most specific to least specific to reduce false positives.

    Args:
        model_id: HuggingFace model ID (e.g. "Wan-AI/Wan2.1-T2V-14B") or
            local model name.

    Returns one of:
    - "wan"         : Wan models (Wan-AI)
    - "flux"        : Flux models (Black Forest Labs)
    - "sdxl"        : Stable Diffusion XL
    - "sd"          : Stable Diffusion 1.5
    - "sd3"         : Stable Diffusion 3
    - "ltx"         : LTX-Video models
    - "svd"         : Stable Video Diffusion
    - "cogvideox"   : CogVideoX models
    - "mochi"       : Mochi models
    - "animatediff" : AnimateDiff models
    - "other_video" : Other known video models
    - "other_image" : Other known image models
    - "unknown"     : Unknown model family
    """
    model_id_lower = model_id.lower()
    # Wan models - check first as they're specific
    if "wan" in model_id_lower:
        return "wan"
    # Flux models
    if "flux" in model_id_lower:
        return "flux"
    # Stable Diffusion XL
    if "sdxl" in model_id_lower or "stable-diffusion-xl" in model_id_lower:
        return "sdxl"
    # Stable Diffusion 3
    if "sd3" in model_id_lower or "stable-diffusion-3" in model_id_lower:
        return "sd3"
    # Stable Diffusion 1.5 (check after XL and 3 to avoid false positives)
    if "stable-diffusion" in model_id_lower and "xl" not in model_id_lower and "3" not in model_id_lower:
        return "sd"
    # LTX-Video
    if "ltx" in model_id_lower or "ltx-video" in model_id_lower:
        return "ltx"
    # Stable Video Diffusion
    if "stable-video-diffusion" in model_id_lower or "svd" in model_id_lower:
        return "svd"
    # CogVideoX
    if "cogvideo" in model_id_lower:
        return "cogvideox"
    # Mochi
    if "mochi" in model_id_lower:
        return "mochi"
    # AnimateDiff
    if "animatediff" in model_id_lower:
        return "animatediff"
    # Other known video model families all map to the generic video bucket;
    # get_video_pipeline_for_family() falls back to DiffusionPipeline
    # auto-detection for these.
    other_video_patterns = [
        "allegro", "hunyuan", "open-sora", "opensora", "stepvideo",
        "step-video", "zeroscope", "modelscope", "text-to-video-ms",
        "latte", "hotshot", "i2vgen",
    ]
    if any(pat in model_id_lower for pat in other_video_patterns):
        return "other_video"
    # Lumina
    if "lumina" in model_id_lower:
        return "other_image"
    # Pony / Animagine models (usually SDXL-based)
    if "pony" in model_id_lower or "animagine" in model_id_lower:
        return "sdxl"
    # Common SD 1.5 model name patterns.
    # NOTE: patterns must be lowercase to match model_id_lower; the earlier
    # "rev Animated" (capital A) could never match.
    sd15_patterns = ["deliberate", "juggernaut", "realistic_vision", "realisticvision",
                     "anything", "rev animated", "counterfeit", "chilloutmix", "pastel mix",
                     "dreamlike", "douyin", "ghostmix", "toonyou", "redshift"]
    if any(pat in model_id_lower for pat in sd15_patterns):
        # Check if it's SDXL (some of these names have "xl" variants)
        if "xl" not in model_id_lower:
            return "sd"
    # Common SDXL model name patterns
    sdxl_patterns = ["sdxl", "pony", "animagine", "juggernaut", "cyberrealistic",
                     "realcartoon", "majicmix", "dreamshaper", "epicrealism",
                     "absolutereality", "proteus"]
    if any(pat in model_id_lower for pat in sdxl_patterns):
        return "sdxl"
    # Default: check if it looks like a video model
    if any(x in model_id_lower for x in ["video", "animation", "motion"]):
        return "other_video"
    # Unknown family - callers fall back to DiffusionPipeline auto-detection
    return "unknown"
def get_pipeline_for_model_family(model_family, task_type):
    """Get the appropriate pipeline class based on model family and task type.

    This is the second step of pipeline detection:
    1. detect_model_family() - identifies the model family
    2. get_pipeline_for_model_family() - selects the correct pipeline

    Args:
        model_family: The model family (from detect_model_family)
        task_type: The task type (t2v, i2v, t2i, i2i, v2v)

    Returns the pipeline class name (string).
    """
    # Handle video generation tasks (t2v, i2v, v2v)
    if task_type in ["t2v", "i2v", "v2v"]:
        return get_video_pipeline_for_family(model_family, task_type)
    # Handle image generation tasks (t2i, i2i)
    elif task_type in ["t2i", "i2i"]:
        return get_image_pipeline_for_family(model_family, task_type)
    # Unknown task - use DiffusionPipeline as fallback
    return "DiffusionPipeline"
def get_video_pipeline_for_family(model_family, task_type):
    """Get the appropriate video pipeline for a given model family and task type."""
    # Families whose pipeline class depends on the task; the "t2v" entry
    # doubles as the default for unrecognized task types.
    per_task = {
        "wan": {
            "i2v": "WanImageToVideoPipeline",
            "v2v": "WanVideoToVideoPipeline",
            "t2v": "WanPipeline",
        },
        "ltx": {
            "i2v": "LTXImageToVideoPipeline",
            # LTX supports video-to-video via its base pipeline
            "v2v": "LTXPipeline",
            "t2v": "LTXPipeline",
        },
        "cogvideox": {
            "i2v": "CogVideoXImageToVideoPipeline",
            "v2v": "CogVideoXVideoToVideoPipeline",
            "t2v": "CogVideoXPipeline",
        },
        "animatediff": {
            "i2v": "AnimateDiffPipeline",
            "v2v": "AnimateDiffVideoToVideoPipeline",
            "t2v": "AnimateDiffPipeline",
        },
    }
    table = per_task.get(model_family)
    if table is not None:
        return table.get(task_type, table["t2v"])
    # Families with a single pipeline class regardless of task.
    single = {
        "svd": "StableVideoDiffusionPipeline",
        "i2vgen": "I2VGenXLPipeline",
        "mochi": "MochiPipeline",
    }
    if model_family in single:
        return single[model_family]
    # Everything else - other known video models, image-only families
    # (sdxl/sd/sd3/flux, which cannot do video; the pipeline itself will
    # surface the error), and unknown families - lets diffusers auto-detect.
    return "DiffusionPipeline"
def get_image_pipeline_for_family(model_family, task_type):
    """Get the appropriate image pipeline for a given model family and task type.

    Args:
        model_family: The model family (from detect_model_family)
        task_type: The image task type ("t2i" or "i2i")

    Returns the pipeline class name (string).
    """
    # Flux
    if model_family == "flux":
        if task_type == "i2i":
            return "FluxImg2ImgPipeline"
        return "FluxPipeline"
    # Stable Diffusion 3
    elif model_family == "sd3":
        if task_type == "i2i":
            return "StableDiffusion3Img2ImgPipeline"
        return "StableDiffusion3Pipeline"
    # Stable Diffusion XL
    elif model_family == "sdxl":
        if task_type == "i2i":
            return "StableDiffusionXLImg2ImgPipeline"
        return "StableDiffusionXLPipeline"
    # Stable Diffusion 1.5
    elif model_family == "sd":
        if task_type == "i2i":
            return "StableDiffusionImg2ImgPipeline"
        return "StableDiffusionPipeline"
    # Lumina
    # NOTE(review): detect_model_family() maps Lumina models to "other_image",
    # so this branch is only reachable if a caller passes "lumina" directly -
    # confirm which family name is intended.
    elif model_family == "lumina":
        return "LuminaText2ImgPipeline"
    # Other image models
    elif model_family == "other_image":
        return "DiffusionPipeline"
    # Video model families - they don't support pure image generation
    # Use DiffusionPipeline as fallback
    elif model_family in ["wan", "ltx", "svd", "cogvideox", "mochi", "animatediff", "other_video"]:
        return "DiffusionPipeline"
    # Unknown family - use DiffusionPipeline
    return "DiffusionPipeline"
def detect_generation_type_from_args(args):
    """Detect the generation type from command-line arguments.

    This is the PRIMARY way to detect generation type - it looks at:
    - --image: provided image file (I2V or I2I)
    - --video / --video_to_video: provided video file / explicit V2V flag
    - --image_to_video: explicit I2V flag
    - --image_to_image: explicit I2I flag
    - --image_model: separate image model implies generate-then-animate (I2V)
    - --output: output file extension (can indicate T2I vs T2V)

    Args:
        args: parsed argparse namespace, or None.

    Returns one of: "t2v", "i2v", "v2v", "t2i", "i2i"
    """
    if args is None:
        return "t2v"  # Default when no arguments are available
    # Check for explicit flags first (highest priority)
    # I2V mode: --image_to_video flag
    if getattr(args, 'image_to_video', False):
        return "i2v"
    # --image argument provided: I2V or I2I depending on other options
    if getattr(args, 'image', None):
        # A separate image model means generate an image, then animate it
        if getattr(args, 'image_model', None):
            return "i2v"
        # Explicit --image_to_image flag wins next
        if getattr(args, 'image_to_image', False):
            return "i2i"
        # Check output extension - if it's a video extension, it's I2V
        output = getattr(args, 'output', None)
        if output:
            ext = os.path.splitext(output)[1].lower()
            if ext in [".mp4", ".avi", ".mov", ".webm", ".mkv"]:
                return "i2v"
        # Default: --image without a recognized video output still means I2V.
        # NOTE(review): an image-extension output with --image also lands here
        # (returns "i2v", not "i2i") - confirm that is intended.
        return "i2v"
    # V2V mode: --video argument provided (or --video_to_video flag)
    if getattr(args, 'video_to_video', False) or getattr(args, 'video', None):
        return "v2v"
    # I2I mode: --image_to_image flag (without --image is unusual but handled)
    if getattr(args, 'image_to_image', False):
        return "i2i"
    # Check output extension for T2I vs T2V
    output = getattr(args, 'output', None)
    if output:
        ext = os.path.splitext(output)[1].lower()
        # Image output = T2I
        if ext in [".png", ".jpg", ".jpeg", ".gif", ".webp"]:
            return "t2i"
        # Video output = T2V
        if ext in [".mp4", ".avi", ".mov", ".webm", ".mkv"]:
            return "t2v"
    # Default: T2V (video generation)
    return "t2v"
def parse_hf_url_or_id(input_str):
"""Parse either a HuggingFace URL or model ID and return the model ID
......@@ -3924,6 +4139,70 @@ def detect_generation_type(prompt, prompt_image=None, prompt_animation=None, arg
"style_keywords": [], # New: detected style keywords
}
# CRITICAL: Check if --image argument is provided (I2V mode)
# This should be checked FIRST as it overrides other detections
if args is not None and hasattr(args, 'image') and args.image:
result["type"] = "i2v"
result["needs_image"] = True
result["needs_video"] = True
return result
# Also check for --image_to_video flag (explicit I2V mode)
if args is not None and hasattr(args, 'image_to_video') and args.image_to_video:
result["type"] = "i2v"
result["needs_image"] = True
result["needs_video"] = True
return result
# Check if --prompt_image or --prompt_animation is provided (T2I + I2V chaining)
if args is not None:
has_prompt_image = hasattr(args, 'prompt_image') and args.prompt_image
has_prompt_animation = hasattr(args, 'prompt_animation') and args.prompt_animation
has_image_model = hasattr(args, 'image_model') and args.image_model
# Check for audio operations that require video generation
has_audio = hasattr(args, 'generate_audio') and args.generate_audio
has_music = hasattr(args, 'music_model') and args.music_model
has_lip_sync = hasattr(args, 'lip_sync') and args.lip_sync
has_sync_audio = hasattr(args, 'sync_audio') and args.sync_audio
# Check for subtitle operations
has_subtitles = hasattr(args, 'create_subtitles') and args.create_subtitles
has_burn_subtitles = hasattr(args, 'burn_subtitles') and args.burn_subtitles
# T2I + I2V chaining: image_model OR prompt_image/prompt_animation
if has_prompt_image or has_prompt_animation or has_image_model:
result["type"] = "i2v"
result["needs_image"] = True
result["needs_video"] = True
result["chain_t2i"] = True # Flag for T2I + I2V chaining
return result
# T2V + V2V chaining: audio operations, subtitles, or prompt_animation
if has_prompt_animation or has_audio or has_music or has_lip_sync or has_sync_audio or has_subtitles or has_burn_subtitles:
result["type"] = "t2v" # Primary is T2V
result["needs_video"] = True
result["needs_audio"] = has_audio or has_music
result["chain_v2v"] = True # Flag for T2V + V2V chaining
return result
# Check if --video or --video-to-video is provided (V2V mode)
if args is not None and (getattr(args, 'video_to_video', False) or getattr(args, 'video', None)):
result["type"] = "v2v"
result["needs_image"] = True
result["needs_video"] = True
return result
# Check output extension - if video output with no image input, it's T2V
if args is not None and hasattr(args, 'output') and args.output:
output_ext = os.path.splitext(args.output)[1].lower()
if output_ext in [".mp4", ".avi", ".mov", ".webm", ".mkv"]:
# No image input but video output = T2V
if not (getattr(args, 'image', None) or getattr(args, 'prompt_image', None)):
result["type"] = "t2v"
result["needs_video"] = True
return result
# Detect NSFW
is_nsfw, confidence, reason = detect_nsfw_text(all_text)
result["is_nsfw"] = is_nsfw
......@@ -8279,7 +8558,7 @@ def main(args):
model_id = m_info["id"]
is_i2v_mode = args.image_to_video or args.image
is_i2i_mode = args.image_to_image
is_v2v_mode = getattr(args, 'video_to_video', False) or getattr(args, 'input_video', None) is not None
is_v2v_mode = getattr(args, 'video_to_video', False) or getattr(args, 'video', None) is not None
if is_i2v_mode:
task_type = "i2v"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment