Commit 87ed1011 authored by Stefy Lanza (nextime / spora)

Merge branch 'experimental'

parents 258c584b 40787d0a
...@@ -10144,6 +10144,54 @@ def main(args): ...@@ -10144,6 +10144,54 @@ def main(args):
print("Unknown output format.") print("Unknown output format.")
return return
# Process frames for export - ensure correct format
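# WanImageToVideoPipeline returns a tensor shaped (batch, channels, frames, height, width)
# or (frames, channels, height, width); it must be converted to (frames, height, width, channels).
# NOTE(review): the 4D branch below treats axis 0 as channels, so a (frames, channels, height, width) input is not reordered by it - confirm the layouts actually produced.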
if isinstance(frames, torch.Tensor):
# Convert tensor to numpy
frames = frames.cpu().numpy()
# Handle different frame array shapes
if isinstance(frames, np.ndarray):
# If frames is 5D (batch, channels, frames, height, width), extract first batch
if frames.ndim == 5:
frames = frames[0] # Now (channels, frames, height, width)
# If frames is 4D (channels, frames, height, width), transpose to (frames, height, width, channels)
if frames.ndim == 4:
if frames.shape[0] in [1, 3, 4]: # channels first
frames = np.transpose(frames, (1, 2, 3, 0)) # (frames, height, width, channels)
elif frames.shape[-1] in [1, 3, 4]: # channels last, already correct
pass # Already in correct format
# If frames is 3D, treat it as a single image: a channels-first (channels, height, width) array is transposed to (height, width, channels), then a leading frame dimension is added
if frames.ndim == 3:
if frames.shape[-1] not in [1, 3, 4] and frames.shape[0] in [1, 3, 4]:
# channels first, need to transpose
frames = np.transpose(frames, (1, 2, 0))
# If single frame (height, width, channels), add batch dimension
if frames.ndim == 3 and frames.shape[-1] in [1, 3, 4]:
frames = np.expand_dims(frames, axis=0)
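# Normalize the channel count for export: expand grayscale to RGB and truncate arrays with more than 4 channels to their first 3 (4-channel frames are left unchanged)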
if frames.ndim == 4 and frames.shape[-1] == 1:
# Grayscale, convert to RGB
frames = np.repeat(frames, 3, axis=-1)
elif frames.ndim == 4 and frames.shape[-1] > 4:
# Too many channels, take first 3
frames = frames[..., :3]
# Ensure frames are in 0-255 uint8 range for video export
if isinstance(frames, np.ndarray):
if frames.dtype == np.float32 or frames.dtype == np.float64:
# Assume 0-1 range, convert to 0-255
if frames.max() <= 1.0:
frames = (frames * 255).astype(np.uint8)
else:
frames = frames.astype(np.uint8)
elif frames.dtype != np.uint8:
frames = frames.astype(np.uint8)
export_to_video(frames, f"{args.output}.mp4", fps=args.fps) export_to_video(frames, f"{args.output}.mp4", fps=args.fps)
timing.end_step() # video_generation timing.end_step() # video_generation
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment