Wan LoRA trainer: fix fp32/bf16 dtype mismatch in the train step

torch.rand defaults to fp32, so the rectified-flow interpolation promoted x_t to
fp32 while the patch-embedding Conv3d stays bf16 (bitsandbytes 4-bit quantizes
only Linear layers), raising "Input type (float) and bias type (BFloat16) should
be the same". Compute the interpolation in fp32 then cast x_t/target back to the
model compute dtype, and pass timestep as fp32 (Wan casts it internally).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent ad891a34
@@ -952,12 +952,16 @@ def _train_wan(req, base_path, images, instance_prompt,
 x0 = latents_list[step % n].to(device, dtype=compute_dtype)
 noise = torch.randn_like(x0)
 # Rectified-flow timestep with Wan resolution shift applied to sigma.
-u = torch.rand(1, device=device)
+# Compute the interpolation in fp32 for stability, then cast x_t back to
+# the model's compute dtype — the patch-embedding Conv3d is NOT quantized
+# (bnb only touches Linear), so it needs inputs in compute_dtype.
+u = torch.rand(1, device=device, dtype=torch.float32)
 sigma = (shift * u) / (1.0 + (shift - 1.0) * u)
 s = sigma.view(-1, 1, 1, 1, 1)
-x_t = (1.0 - s) * x0 + s * noise
-target = noise - x0 # flow-matching velocity
-timestep = (sigma * num_train_t).to(compute_dtype)
+x0f = x0.float()
+x_t = ((1.0 - s) * x0f + s * noise.float()).to(compute_dtype)
+target = (noise.float() - x0f).to(compute_dtype) # flow-matching velocity
+timestep = (sigma * num_train_t).to(torch.float32) # cast internally by Wan
 tr = _pick_expert(float(sigma.item()))
 pred = tr(hidden_states=x_t, timestep=timestep,
 encoder_hidden_states=encoder_hidden_states.to(device),
......
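The dtype behaviour this commit addresses can be reproduced in isolation. The sketch below is a minimal illustration, not the trainer's code: `rectified_flow_sample` is a hypothetical helper that mirrors the patched lines, and `shift=3.0`, `num_train_t=1000`, and the latent shape are assumed values chosen only for the demo. It shows that mixing an fp32 `s` with a bf16 `x0` promotes the result to fp32, and that the fp32-then-cast version keeps `x_t` and `target` in the compute dtype while `timestep` stays fp32.

```python
import torch


def rectified_flow_sample(x0, shift, num_train_t=1000):
    """Hypothetical helper mirroring the patched lines (not the trainer's API)."""
    compute_dtype = x0.dtype
    noise = torch.randn_like(x0)
    # Sample u explicitly in fp32 and apply the resolution shift to sigma.
    u = torch.rand(1, device=x0.device, dtype=torch.float32)
    sigma = (shift * u) / (1.0 + (shift - 1.0) * u)
    s = sigma.view(-1, 1, 1, 1, 1)
    # Interpolate in fp32, then cast back to the model's compute dtype.
    x0f = x0.float()
    x_t = ((1.0 - s) * x0f + s * noise.float()).to(compute_dtype)
    target = (noise.float() - x0f).to(compute_dtype)  # flow-matching velocity
    timestep = (sigma * num_train_t).to(torch.float32)
    return x_t, target, timestep


# Assumed latent shape (batch, channels, frames, height, width) for the demo.
x0 = torch.zeros(1, 16, 1, 8, 8, dtype=torch.bfloat16)

# Pre-fix behaviour: an fp32 s combined with a bf16 x0 promotes x_t to fp32.
s = torch.rand(1).view(-1, 1, 1, 1, 1)
naive_x_t = (1.0 - s) * x0 + s * torch.randn_like(x0)

# Post-fix behaviour: x_t and target come back in bf16, timestep in fp32.
x_t, target, timestep = rectified_flow_sample(x0, shift=3.0)

print(naive_x_t.dtype, x_t.dtype, target.dtype, timestep.dtype)
```

Feeding `naive_x_t` to a bf16 `Conv3d` is what triggers the mismatch quoted in the commit message; the cast at the end of the interpolation avoids it without giving up fp32 arithmetic for the blend itself.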