Commit 6db57c26 authored by Stefy Lanza (nextime / spora)

Merge branch 'experimental' into 'master'

Fix I2V pipeline auto-detection

See merge request !1
parents 803b1763 1ef3d1b8
......@@ -1493,125 +1493,340 @@ def get_pipeline_for_task(model_id, task_type):
"""
model_id_lower = model_id.lower()
# Video pipelines
# First, detect the model family to determine which pipeline family to use
model_family = detect_model_family(model_id)
# Now select the appropriate pipeline based on model family and task type
return get_pipeline_for_model_family(model_family, task_type)
def detect_model_family(model_id):
    """Detect the model family from a model ID.

    Matching is done on the lowercased ID, and check order matters:
    specific substrings ("sdxl", "stable-diffusion-3") are tested before
    the generic "stable-diffusion" to avoid false positives.

    Args:
        model_id: HuggingFace repo ID or local model name.

    Returns:
        One of: "wan", "flux", "sdxl", "sd", "sd3", "ltx", "svd",
        "cogvideox", "mochi", "animatediff", "other_video",
        "other_image", or "unknown" when no pattern matches.
    """
    model_id_lower = model_id.lower()
    # Wan models - check first as they're specific
    if "wan" in model_id_lower:
        return "wan"
    # Flux models
    if "flux" in model_id_lower:
        return "flux"
    # Stable Diffusion XL
    if "sdxl" in model_id_lower or "stable-diffusion-xl" in model_id_lower:
        return "sdxl"
    # Stable Diffusion 3
    if "sd3" in model_id_lower or "stable-diffusion-3" in model_id_lower:
        return "sd3"
    # Stable Diffusion 1.5 (checked after XL and 3 to avoid false positives).
    # NOTE(review): the bare "3" exclusion also rejects IDs such as
    # "stable-diffusion-v1-3"; kept as-is to preserve existing behavior.
    if "stable-diffusion" in model_id_lower and "xl" not in model_id_lower and "3" not in model_id_lower:
        return "sd"
    # LTX-Video ("ltx" already covers the longer "ltx-video" form)
    if "ltx" in model_id_lower:
        return "ltx"
    # Stable Video Diffusion
    if "stable-video-diffusion" in model_id_lower or "svd" in model_id_lower:
        return "svd"
    # CogVideoX
    if "cogvideo" in model_id_lower:
        return "cogvideox"
    # Mochi
    if "mochi" in model_id_lower:
        return "mochi"
    # AnimateDiff
    if "animatediff" in model_id_lower:
        return "animatediff"
    # Other known video model families without a dedicated branch:
    # Allegro, Hunyuan, OpenSora, StepVideo, Zeroscope, Modelscope,
    # Latte, Hotshot, I2VGenXL.
    other_video_patterns = (
        "allegro", "hunyuan", "open-sora", "opensora", "stepvideo",
        "step-video", "zeroscope", "modelscope", "text-to-video-ms",
        "latte", "hotshot", "i2vgen",
    )
    if any(pat in model_id_lower for pat in other_video_patterns):
        return "other_video"
    # Lumina
    if "lumina" in model_id_lower:
        return "other_image"
    # Pony / Animagine models (usually SDXL-based)
    if "pony" in model_id_lower or "animagine" in model_id_lower:
        return "sdxl"
    # Common SD 1.5 model name patterns.
    # BUGFIX: the old "rev Animated" entry contained an uppercase letter
    # and could therefore never match the lowercased model ID; use
    # lowercase variants matching real repo spellings instead. Hyphen/
    # underscore variants added for "pastel mix" for the same reason.
    sd15_patterns = ["deliberate", "juggernaut", "realistic_vision", "realisticvision",
                     "anything", "rev animated", "rev_animated", "rev-animated",
                     "counterfeit", "chilloutmix", "pastel mix", "pastel-mix",
                     "dreamlike", "douyin", "ghostmix", "toonyou", "redshift"]
    if any(pat in model_id_lower for pat in sd15_patterns):
        # Some checkpoints of these lines ship SDXL variants with "xl" in
        # the name; those fall through to the SDXL patterns below.
        if "xl" not in model_id_lower:
            return "sd"
    # Common SDXL model name patterns
    sdxl_patterns = ["sdxl", "pony", "animagine", "juggernaut", "cyberrealistic",
                     "realcartoon", "majicmix", "dreamshaper", "epicrealism",
                     "absolutereality", "proteus"]
    if any(pat in model_id_lower for pat in sdxl_patterns):
        return "sdxl"
    # Heuristic fallback: names that merely sound like video models
    if any(x in model_id_lower for x in ["video", "animation", "motion"]):
        return "other_video"
    # Family could not be determined
    return "unknown"
def get_pipeline_for_model_family(model_family, task_type):
    """Get the appropriate pipeline class name for a model family and task.

    This is the second step of pipeline detection:
      1. detect_model_family() - identifies the model family
      2. get_pipeline_for_model_family() - selects the correct pipeline

    Args:
        model_family: The model family (from detect_model_family).
        task_type: The task type (t2v, i2v, v2v, t2i, i2i).

    Returns:
        The pipeline class name (string).
    """
    # BUGFIX: the previous body still contained the pre-refactor checks
    # against `model_id_lower`, a name that is not defined in this
    # function, so every video/image task raised a NameError before the
    # delegating calls below could ever run. Only the dispatch remains.
    # Video generation tasks (t2v, i2v, v2v)
    if task_type in ["t2v", "i2v", "v2v"]:
        return get_video_pipeline_for_family(model_family, task_type)
    # Image generation tasks (t2i, i2i)
    if task_type in ["t2i", "i2i"]:
        return get_image_pipeline_for_family(model_family, task_type)
    # Unknown task - let diffusers auto-detect
    return "DiffusionPipeline"
def get_video_pipeline_for_family(model_family, task_type):
    """Get the appropriate video pipeline for a given model family and task type."""
    # Task-specific pipeline variants, keyed by family then task.
    task_specific = {
        "wan": {"i2v": "WanImageToVideoPipeline",
                "v2v": "WanVideoToVideoPipeline"},
        "ltx": {"i2v": "LTXImageToVideoPipeline"},
        "cogvideox": {"i2v": "CogVideoXImageToVideoPipeline",
                      "v2v": "CogVideoXVideoToVideoPipeline"},
        "animatediff": {"v2v": "AnimateDiffVideoToVideoPipeline"},
    }
    # Default pipeline per family (used for t2v and any task without a
    # dedicated variant above; LTX handles v2v via its base pipeline).
    family_default = {
        "wan": "WanPipeline",
        "ltx": "LTXPipeline",
        "svd": "StableVideoDiffusionPipeline",
        "cogvideox": "CogVideoXPipeline",
        "i2vgen": "I2VGenXLPipeline",
        "mochi": "MochiPipeline",
        "animatediff": "AnimateDiffPipeline",
    }
    chosen = task_specific.get(model_family, {}).get(task_type)
    if chosen is not None:
        return chosen
    # "other_video", the image families (sdxl/sd/sd3/flux) and unknown
    # families all fall back to DiffusionPipeline so diffusers can
    # auto-detect (or fail with its own error for image-only models).
    return family_default.get(model_family, "DiffusionPipeline")
def get_image_pipeline_for_family(model_family, task_type):
    """Get the appropriate image pipeline for a given model family and task type."""
    # Img2img variants, used only when the task is explicitly i2i.
    img2img_variant = {
        "flux": "FluxImg2ImgPipeline",
        "sd3": "StableDiffusion3Img2ImgPipeline",
        "sdxl": "StableDiffusionXLImg2ImgPipeline",
        "sd": "StableDiffusionImg2ImgPipeline",
    }
    # Text-to-image defaults per family (Lumina has no i2i variant here).
    text2img_default = {
        "flux": "FluxPipeline",
        "sd3": "StableDiffusion3Pipeline",
        "sdxl": "StableDiffusionXLPipeline",
        "sd": "StableDiffusionPipeline",
        "lumina": "LuminaText2ImgPipeline",
    }
    if task_type == "i2i" and model_family in img2img_variant:
        return img2img_variant[model_family]
    # "other_image", the video families (wan/ltx/svd/cogvideox/mochi/
    # animatediff/other_video) and unknown families all fall back to
    # DiffusionPipeline so diffusers can auto-detect.
    return text2img_default.get(model_family, "DiffusionPipeline")
def detect_generation_type_from_args(args):
    """Detect the generation type from command-line arguments.

    This is the PRIMARY way to detect generation type. It inspects:
      - args.image:           input image file (I2V or I2I)
      - args.video:           input video file (V2V)
      - args.image_to_video:  explicit I2V flag
      - args.image_to_image:  explicit I2I flag
      - args.output:          file extension hints T2I vs T2V

    Returns one of: "t2v", "i2v", "v2v", "t2i", "i2i".
    """
    if args is None:
        # No arguments at all: fall back to plain text-to-video.
        return "t2v"

    def opt(name):
        # Missing attributes are treated the same as unset/falsy ones.
        return getattr(args, name, None)

    video_exts = {".mp4", ".avi", ".mov", ".webm", ".mkv"}

    # Explicit I2V flag has the highest priority.
    if opt('image_to_video'):
        return "i2v"

    # An input image means I2V or I2I.
    if opt('image'):
        if opt('image_model'):
            # A separate image model means: generate an image, then animate it.
            return "i2v"
        if opt('image_to_image'):
            return "i2i"
        # With or without a video-extension output, an image input ends up
        # as I2V; the extension check is kept for parity with the original flow.
        out = opt('output')
        if out and os.path.splitext(out)[1].lower() in video_exts:
            return "i2v"
        return "i2v"

    # Input video (or the explicit flag) means V2V.
    if opt('video_to_video') or opt('video'):
        return "v2v"

    # Explicit I2I flag without an input image would be odd, but honor it.
    if opt('image_to_image'):
        return "i2i"

    # Fall back to the output extension to split T2I from T2V.
    out = opt('output')
    if out:
        suffix = os.path.splitext(out)[1].lower()
        if suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
            return "t2i"
        if suffix in video_exts:
            return "t2v"

    # Default: text-to-video.
    return "t2v"
def parse_hf_url_or_id(input_str):
"""Parse either a HuggingFace URL or model ID and return the model ID
......@@ -3924,6 +4139,70 @@ def detect_generation_type(prompt, prompt_image=None, prompt_animation=None, arg
"style_keywords": [], # New: detected style keywords
}
# CRITICAL: Check if --image argument is provided (I2V mode)
# This should be checked FIRST as it overrides other detections
if args is not None and hasattr(args, 'image') and args.image:
result["type"] = "i2v"
result["needs_image"] = True
result["needs_video"] = True
return result
# Also check for --image_to_video flag (explicit I2V mode)
if args is not None and hasattr(args, 'image_to_video') and args.image_to_video:
result["type"] = "i2v"
result["needs_image"] = True
result["needs_video"] = True
return result
# Check if --prompt_image or --prompt_animation is provided (T2I + I2V chaining)
if args is not None:
has_prompt_image = hasattr(args, 'prompt_image') and args.prompt_image
has_prompt_animation = hasattr(args, 'prompt_animation') and args.prompt_animation
has_image_model = hasattr(args, 'image_model') and args.image_model
# Check for audio operations that require video generation
has_audio = hasattr(args, 'generate_audio') and args.generate_audio
has_music = hasattr(args, 'music_model') and args.music_model
has_lip_sync = hasattr(args, 'lip_sync') and args.lip_sync
has_sync_audio = hasattr(args, 'sync_audio') and args.sync_audio
# Check for subtitle operations
has_subtitles = hasattr(args, 'create_subtitles') and args.create_subtitles
has_burn_subtitles = hasattr(args, 'burn_subtitles') and args.burn_subtitles
# T2I + I2V chaining: image_model OR prompt_image/prompt_animation
if has_prompt_image or has_prompt_animation or has_image_model:
result["type"] = "i2v"
result["needs_image"] = True
result["needs_video"] = True
result["chain_t2i"] = True # Flag for T2I + I2V chaining
return result
# T2V + V2V chaining: audio operations, subtitles, or prompt_animation
if has_prompt_animation or has_audio or has_music or has_lip_sync or has_sync_audio or has_subtitles or has_burn_subtitles:
result["type"] = "t2v" # Primary is T2V
result["needs_video"] = True
result["needs_audio"] = has_audio or has_music
result["chain_v2v"] = True # Flag for T2V + V2V chaining
return result
# Check if --video or --video-to-video is provided (V2V mode)
if args is not None and (getattr(args, 'video_to_video', False) or getattr(args, 'video', None)):
result["type"] = "v2v"
result["needs_image"] = True
result["needs_video"] = True
return result
# Check output extension - if video output with no image input, it's T2V
if args is not None and hasattr(args, 'output') and args.output:
output_ext = os.path.splitext(args.output)[1].lower()
if output_ext in [".mp4", ".avi", ".mov", ".webm", ".mkv"]:
# No image input but video output = T2V
if not (getattr(args, 'image', None) or getattr(args, 'prompt_image', None)):
result["type"] = "t2v"
result["needs_video"] = True
return result
# Detect NSFW
is_nsfw, confidence, reason = detect_nsfw_text(all_text)
result["is_nsfw"] = is_nsfw
......@@ -8279,7 +8558,7 @@ def main(args):
model_id = m_info["id"]
is_i2v_mode = args.image_to_video or args.image
is_i2i_mode = args.image_to_image
is_v2v_mode = getattr(args, 'video_to_video', False) or getattr(args, 'input_video', None) is not None
is_v2v_mode = getattr(args, 'video_to_video', False) or getattr(args, 'video', None) is not None
if is_i2v_mode:
task_type = "i2v"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment