Add automatic RGB/BGR colorspace detection for video models

- Add detect_model_colorspace() function to detect colorspace by generating test red frame
- Add get_model_colorspace() helper to retrieve or detect colorspace
- Modify frame processing to auto-swap BGR based on model config
- Preserve --swap_bgr flag as manual override
- Colorspace is saved to model config (~/.config/videogen/models.json) after first detection
parent 93a6883d
......@@ -1199,6 +1199,161 @@ def update_model_pipeline_class(model_name, new_pipeline_class):
return True
def detect_model_colorspace(pipe, model_name, m_info, args):
    """Detect whether a model outputs RGB or BGR colorspace.

    Generates one small test clip from a "pure red" prompt (and, for
    image-to-video pipelines, a pure red conditioning image), then checks
    which channel of the output carries the red energy: first channel
    means RGB, last channel means BGR.

    Args:
        pipe: The loaded diffusion pipeline.
        model_name: Name of the model in the MODELS config.
        m_info: Model info dict from MODELS.
        args: Command line arguments (kept for interface parity; unused).

    Returns:
        str: "RGB" or "BGR". Falls back to "RGB" on any failure or when
        the channel measurement is ambiguous.
    """
    global MODELS

    # Short-circuit if the colorspace was already detected and cached.
    existing_colorspace = m_info.get("colorspace")
    if existing_colorspace in ("RGB", "BGR"):
        return existing_colorspace

    print(f" 🔍 Detecting colorspace for {model_name}...")
    print(f" (This is done once per model by generating a test frame)")

    try:
        # A prompt that should produce a predominantly red frame.
        test_prompt = "solid red color, uniform red background, pure red"
        # Small dimensions and a single frame keep detection fast.
        test_height = 256
        test_width = 256
        test_frames = 1

        with torch.no_grad():
            video_kwargs = {
                "prompt": test_prompt,
                "height": test_height,
                "width": test_width,
                "num_frames": test_frames,
                "num_inference_steps": 10,  # Minimal steps for speed
                "guidance_scale": 5.0,
            }
            # I2V pipelines need a conditioning image; feed them pure red
            # so the output color is controlled, not prompt-dependent.
            pipeline_class_name = type(pipe).__name__
            i2v_pipelines = ('StableVideoDiffusionPipeline', 'I2VGenXLPipeline',
                             'LTXImageToVideoPipeline', 'WanImageToVideoPipeline',
                             'CogVideoXImageToVideoPipeline')
            if pipeline_class_name in i2v_pipelines:
                red_image = Image.new('RGB', (test_width, test_height), color=(255, 0, 0))
                video_kwargs["image"] = red_image
            output = pipe(**video_kwargs)

        # Extract the frame stack from whichever attribute the pipeline uses.
        if hasattr(output, "frames"):
            test_frames_data = output.frames[0] if isinstance(output.frames, list) else output.frames
        elif hasattr(output, "videos"):
            test_frames_data = output.videos[0]
        else:
            print(f" ⚠️ Could not analyze output format, assuming RGB")
            return "RGB"

        if isinstance(test_frames_data, torch.Tensor):
            test_frames_data = test_frames_data.cpu().numpy()

        # Normalize shape toward (frames, height, width, channels).
        if test_frames_data.ndim == 5:
            test_frames_data = test_frames_data[0]  # Drop batch axis
        if test_frames_data.ndim == 4:
            # BUGFIX: only transpose genuinely channels-first layouts
            # (C, F, H, W). The old check (`shape[0] in [1, 3, 4]`) also
            # matched channels-last output with a single frame — exactly
            # what this function generates, e.g. (1, H, W, 3) — and
            # scrambled the axes, derailing detection.
            if (test_frames_data.shape[-1] not in (1, 3, 4)
                    and test_frames_data.shape[0] in (1, 3, 4)):
                test_frames_data = np.transpose(test_frames_data, (1, 2, 3, 0))

        # Take the first frame.
        test_frame = test_frames_data[0] if test_frames_data.ndim == 4 else test_frames_data

        # Normalize floats (assumed 0-1 when max <= 1) to uint8 0-255.
        if test_frame.dtype in (np.float32, np.float64):
            if test_frame.max() <= 1.0:
                test_frame = test_frame * 255
            test_frame = test_frame.astype(np.uint8)

        # Guarantee exactly 3 channels for the per-channel analysis.
        if test_frame.ndim == 2:
            test_frame = np.stack([test_frame] * 3, axis=-1)
        elif test_frame.shape[-1] > 3:
            test_frame = test_frame[..., :3]  # Drop alpha / extra channels

        # Analyze the center region only, to avoid border artifacts.
        h, w = test_frame.shape[:2]
        center_region = test_frame[h//4:3*h//4, w//4:3*w//4]
        r_avg = np.mean(center_region[..., 0])
        g_avg = np.mean(center_region[..., 1])
        b_avg = np.mean(center_region[..., 2])
        print(f" Channel averages - R: {r_avg:.1f}, G: {g_avg:.1f}, B: {b_avg:.1f}")

        # Red landing in channel 0 => RGB; in channel 2 => BGR.
        # The +20 margin guards against near-gray, inconclusive output.
        if r_avg > b_avg + 20:
            detected_colorspace = "RGB"
        elif b_avg > r_avg + 20:
            detected_colorspace = "BGR"
        else:
            print(f" ⚠️ Colorspace ambiguous, defaulting to RGB")
            detected_colorspace = "RGB"

        print(f" ✅ Detected colorspace: {detected_colorspace}")

        # Persist the result so detection runs only once per model.
        if model_name in MODELS:
            MODELS[model_name]["colorspace"] = detected_colorspace
            save_models_config(MODELS)
            print(f" 📝 Saved colorspace to model config")

        return detected_colorspace

    except Exception as e:
        # Best-effort feature: never let detection abort generation.
        print(f" ⚠️ Colorspace detection failed: {e}")
        print(f" Defaulting to RGB")
        return "RGB"
def get_model_colorspace(pipe, model_name, m_info, args):
    """Return the colorspace ("RGB" or "BGR") for a model.

    A value already cached in the global MODELS config wins; otherwise
    the colorspace is detected (and persisted) by
    detect_model_colorspace().

    Args:
        pipe: The loaded diffusion pipeline.
        model_name: Name of the model in the MODELS config.
        m_info: Model info dict from MODELS.
        args: Command line arguments.

    Returns:
        str: "RGB" or "BGR"
    """
    # Fast path: value previously saved in the config.
    entry = MODELS.get(model_name)
    if entry is not None and "colorspace" in entry:
        return entry["colorspace"]
    # Slow path: run one-time detection, which also saves the result.
    return detect_model_colorspace(pipe, model_name, m_info, args)
def validate_hf_model(model_id, hf_token=None, debug=False):
"""Validate if a HuggingFace model exists and get its info
......@@ -10120,6 +10275,12 @@ def main(args):
print(f" ✅ Model loaded successfully")
timing.end_step() # model_loading
# Detect colorspace if not already known (only for video models)
if not ("pony" in args.model or "flux" in args.model):
colorspace = get_model_colorspace(pipe, args.model, m_info, args)
if colorspace:
print(f" 📊 Model colorspace: {colorspace}")
# ─── Audio Generation (Pre-video) ──────────────────────────────────────────
audio_path = None
......@@ -10244,9 +10405,20 @@ def main(args):
frames = 1.0 - frames
# Check if BGR->RGB channel swap is needed
# First check if user explicitly requested swap
if args.swap_bgr:
print(f" 🔄 Swapping BGR to RGB channels as requested (--swap_bgr)...")
frames = frames[..., ::-1] # Reverse channel order
else:
# Auto-detect colorspace if model info available
colorspace = m_info.get("colorspace")
if colorspace == "BGR":
print(f" 🔄 Auto-swapping BGR to RGB (detected from model config)...")
frames = frames[..., ::-1] # Reverse channel order
elif colorspace == "RGB":
# RGB colorspace, no swap needed
pass
# If colorspace not set, assume RGB (most common)
# Now convert from [0, 1] to [0, 255]
frames = np.clip(frames, 0.0, 1.0) * 255
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment