quant: reject checkpoints whose weights weren't actually quantized

GPTQModel silently leaves layers it can't map (e.g. gemma-4's fused batched MoE
experts) in bf16, producing a near-full-size "checkpoint" that the loader would
redirect to and then offload. The worker now scans the saved safetensors and, if
<50% of large weight bytes are int-packed, deletes the output and marks the job
failed (so it falls back to bitsandbytes) instead of reporting "done".
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 6d053dc1
...@@ -270,6 +270,22 @@ def _quantize_worker(model_name: str, method: str, bits: int, group_size: int) - ...@@ -270,6 +270,22 @@ def _quantize_worker(model_name: str, method: str, bits: int, group_size: int) -
except Exception: except Exception:
pass pass
# Verify the quantizer actually compressed the weights. Some architectures
# (e.g. gemma-4's fused-batched MoE experts) aren't covered by GPTQModel's
# module map, so quantization silently leaves the bulk of the weights in
# bf16/fp16 — producing a near-full-size "checkpoint" that wastes disk and
# would offload at load time. Reject those instead of marking them done.
frac = _quantized_fraction(out_dir)
if frac is not None and frac < 0.5:
import shutil
shutil.rmtree(out_dir, ignore_errors=True)
_set_job(model_name, status="failed", progress=1.0, finished=time.time(),
error=(f"only {frac*100:.0f}% of weights were quantized — this "
f"architecture's layers (likely fused MoE experts) aren't "
f"supported by the quantizer; checkpoint discarded"),
message="quantization left most weights unquantized — discarded")
return
_set_job(model_name, status="done", progress=1.0, _set_job(model_name, status="done", progress=1.0,
message="quantization complete", output=str(out_dir), message="quantization complete", output=str(out_dir),
finished=time.time()) finished=time.time())
...@@ -280,6 +296,45 @@ def _quantize_worker(model_name: str, method: str, bits: int, group_size: int) - ...@@ -280,6 +296,45 @@ def _quantize_worker(model_name: str, method: str, bits: int, group_size: int) -
message=f"quantization failed: {e}", finished=time.time()) message=f"quantization failed: {e}", finished=time.time())
def _quantized_fraction(ckpt_dir: Path) -> Optional[float]:
    """Estimate what share of a checkpoint's weight bytes are integer-packed.

    Walks every ``*.safetensors`` file directly inside ``ckpt_dir`` and reads
    only tensor metadata (dtype and shape) through ``get_slice``. Bytes are
    tallied into two buckets:

    * integer tensors (``I32``/``I16``/``I8``/``U8`` -- e.g. GPTQ
      ``qweight``/``qzeros``) always count as low-bit;
    * floating tensors (``BF16``/``F16``/``F32``) count as full precision
      only when they have at least two dimensions and at least 1,000,000
      elements, so norms, biases, scales and router gates are ignored.

    Tensors of any other dtype are not counted at all.

    Returns:
        ``low_bit_bytes / (low_bit_bytes + full_precision_bytes)``; values
        near 1.0 mean the checkpoint is properly quantized, values near 0.0
        mean the quantizer skipped most layers. ``None`` when the answer
        cannot be determined: no shard files, nothing countable, or any
        exception while scanning (including ``safetensors`` being
        unavailable).
    """
    packed_dtypes = ("I32", "I16", "I8", "U8")
    float_dtypes = ("BF16", "F16", "F32")
    bytes_per_element = {
        "I32": 4, "I16": 2, "I8": 1, "U8": 1,
        "BF16": 2, "F16": 2, "F32": 4,
    }

    try:
        # Imported lazily, inside the guard, so a missing package yields
        # "cannot determine" (None) rather than an error.
        from safetensors import safe_open

        shard_paths = list(ckpt_dir.glob("*.safetensors"))
        if not shard_paths:
            return None

        packed_bytes = 0
        float_bytes = 0
        for shard_path in shard_paths:
            with safe_open(str(shard_path), framework="numpy") as handle:
                for tensor_name in handle.keys():
                    view = handle.get_slice(tensor_name)
                    dtype = view.get_dtype()
                    shape = view.get_shape()

                    element_count = 1
                    for dim in shape:
                        element_count *= dim

                    if dtype in packed_dtypes:
                        packed_bytes += element_count * bytes_per_element[dtype]
                        continue
                    if dtype not in float_dtypes:
                        continue
                    # Only large 2-D+ float tensors signal an unquantized
                    # layer; small ones are expected to stay in float.
                    if len(shape) >= 2 and element_count >= 1_000_000:
                        float_bytes += element_count * bytes_per_element[dtype]

        counted_bytes = packed_bytes + float_bytes
        if counted_bytes <= 0:
            return None
        return packed_bytes / counted_bytes
    except Exception:
        return None
def _calibration_samples() -> List[str]: def _calibration_samples() -> List[str]:
"""A small, generic calibration set for GPTQ Hessian estimation.""" """A small, generic calibration set for GPTQ Hessian estimation."""
base = [ base = [
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment