Wan LoRA trainer: fix fp32/bf16 dtype mismatch in the train step

torch.rand defaults to fp32, so the rectified-flow interpolation promoted x_t to
fp32 while the patch-embedding Conv3d stays bf16 (bitsandbytes 4-bit quantizes
only Linear layers), raising "Input type (float) and bias type (BFloat16) should
be the same". Compute the interpolation in fp32 then cast x_t/target back to the
model compute dtype, and pass timestep as fp32 (Wan casts it internally).
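
A minimal sketch of the fix described above, not the trainer's actual code.
The helper name make_flow_inputs, the tensor shapes, and the interpolation
convention (x_t = (1 - t) * x0 + t * noise, target = noise - x0) are
assumptions for illustration; any timestep scaling the model expects is
omitted.

```python
import torch


def make_flow_inputs(latents, compute_dtype=torch.bfloat16):
    """Hypothetical helper: build (x_t, target, timestep) for one train step."""
    batch = latents.shape[0]
    # torch.rand returns fp32 by default; keep t in fp32 for the interpolation.
    t = torch.rand(batch, device=latents.device)
    # Reshape t to (B, 1, 1, ...) so it broadcasts over the latent dimensions.
    t_b = t.view(batch, *([1] * (latents.dim() - 1)))
    # Do the interpolation in fp32 for precision.
    x0 = latents.float()
    noise = torch.randn_like(latents, dtype=torch.float32)
    x_t = (1.0 - t_b) * x0 + t_b * noise
    target = noise - x0
    # Cast back to the model compute dtype so the bf16 Conv3d accepts x_t.
    # The timestep stays fp32 (per the message, Wan casts it internally).
    return x_t.to(compute_dtype), target.to(compute_dtype), t


# Assumed shape (batch, channels, frames, height, width), for illustration only.
latents = torch.randn(2, 16, 4, 8, 8, dtype=torch.bfloat16)
x_t, target, timestep = make_flow_inputs(latents)

# Without the cast, mixing a bf16 tensor with an fp32 tensor promotes to fp32.
promoted = torch.rand(2).view(2, 1, 1, 1, 1) * latents
```

Here promoted reproduces the original problem (a bf16 tensor multiplied by an
fp32 tensor yields fp32), while x_t and target come back in the compute dtype.
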
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>