feat: smart context caching, VRAM offload fix, GPTQ/AWQ quant backend

Smart context caching (both text backends):
- Per-instance generation lock so pooled concurrent requests can't corrupt a
  shared KV cache (GGUF + HF, incl. streaming worker thread).
- GGUF: enable multi-slot LlamaRAMCache, budget via kv_cache_budget_mb (512MB).
- HF: replace single exact-text KV slot with an LRU of token-prefix slots +
  token-level longest-common-prefix + DynamicCache clone/crop (handles
  mid-history edits); kv_cache_slots (default 3).
- Session-affinity routing in ModelInstancePool.acquire(session_key); key from
  user/X-Session-Id else a stable prefix hash.
- RAM-pressure ladder drops reclaimable prefix caches before evicting models.

VRAM fix:
- Auto-fit check no longer double-counts the KV/activation reserve when
  expected_vram_gb is already a peak estimate — borderline models (e.g.
  gemma-4-26B-A4B) stay GPU-resident instead of forced into MoE-thrashing
  device_map offload.

GPTQ/AWQ fast-kernel quant backend (HF path):
- New codai/models/quant.py: GPTQModel capability detection, quantized-checkpoint
  cache, on-demand background quantize job (falls back to bnb if unsupported).
- quant_backend config (auto|bnb|gptq|awq); loader auto-uses a quantized
  checkpoint with Marlin/ExLlama when present, else bitsandbytes.
- Admin endpoints + "Quantize to 4-bit" button with live status on the model page.
- requirements-nvidia.txt documents the from-source install + numpy caveat.
Co-Authored-By: 's avatarClaude Opus 4.8 <noreply@anthropic.com>
parent e4c040e2
...@@ -1669,6 +1669,52 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad ...@@ -1669,6 +1669,52 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad
return {"success": True} return {"success": True}
@router.get("/admin/api/quantize-capabilities", summary="GPTQ/AWQ quantization availability")
async def api_quantize_capabilities(username: str = Depends(require_admin)):
"""Report whether fast-kernel (GPTQ/AWQ) quantization is available + any jobs."""
from codai.models import quant
return {
"capabilities": quant.capabilities(),
"available": quant.is_available(),
"jobs": quant.all_jobs(),
}
@router.post("/admin/api/model-quantize", summary="Quantize a model to fast-kernel 4-bit")
async def api_model_quantize(request: Request, username: str = Depends(require_admin)):
"""Start (or report) an on-demand background GPTQ/AWQ quantization.
Body: {path|model_id, method?(gptq|awq), bits?(4), group_size?(128)}.
Quantization is heavy and slow; it runs in the background and the produced
checkpoint is picked up automatically on the model's next load. Falls back to
bitsandbytes if the fast kernels are unavailable or the arch is unsupported.
"""
from codai.models import quant
data = await request.json()
model_id = (data.get("path") or data.get("model_id") or "").strip()
if not model_id:
raise HTTPException(status_code=400, detail="path/model_id is required")
method = (data.get("method") or "gptq").lower()
if method not in ("gptq", "awq"):
raise HTTPException(status_code=400, detail="method must be 'gptq' or 'awq'")
try:
bits = int(data.get("bits", 4))
group_size = int(data.get("group_size", 128))
except (TypeError, ValueError):
raise HTTPException(status_code=400, detail="bits/group_size must be integers")
job = quant.start_quantization(model_id, method=method, bits=bits, group_size=group_size)
return {"success": job.get("status") != "unavailable", "job": job}
@router.get("/admin/api/quantize-status", summary="Quantization job status")
async def api_quantize_status(model_id: str = "", username: str = Depends(require_admin)):
"""Status for one model's quant job (?model_id=...), or all jobs."""
from codai.models import quant
if model_id:
return {"job": quant.get_job(model_id.strip())}
return {"jobs": quant.all_jobs()}
@router.get("/admin/api/model-loaded-status", summary="Model load status") @router.get("/admin/api/model-loaded-status", summary="Model load status")
async def api_model_loaded_status(username: str = Depends(require_admin)): async def api_model_loaded_status(username: str = Depends(require_admin)):
"""Return loaded model keys with per-model instance pool info.""" """Return loaded model keys with per-model instance pool info."""
...@@ -2117,7 +2163,8 @@ async def api_model_configure(request: Request, username: str = Depends(require_ ...@@ -2117,7 +2163,8 @@ async def api_model_configure(request: Request, username: str = Depends(require_
"max_vram", "sdcpp_flash_attn", "sdcpp_diffusion_flash_attn", "vae_tiling", "max_vram", "sdcpp_flash_attn", "sdcpp_diffusion_flash_attn", "vae_tiling",
"component_quantization", "output_crf", "force_vram_update", "component_quantization", "output_crf", "force_vram_update",
"balanced_gpu_percent", "acceleration", "balanced_gpu_percent", "acceleration",
"cache_type_k", "cache_type_v", "turboquant", "engine"): "cache_type_k", "cache_type_v", "turboquant", "engine",
"quant_backend", "kv_cache_budget_mb", "kv_cache_slots"):
if key in data: if key in data:
entry[key] = data[key] entry[key] = data[key]
......
...@@ -645,6 +645,15 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson } ...@@ -645,6 +645,15 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<option value="q4_0">q4_0 (smallest)</option> <option value="q4_0">q4_0 (smallest)</option>
</select> </select>
</div> </div>
<div class="form-row" style="margin:0">
<p class="muted" style="margin:0;font-size:0.85em;line-height:1.4">
<strong>Note:</strong> KV-cache quantization only applies to GGUF models on
the llama.cpp backend. HF-transformers models (incl. <strong>gemma</strong> and
other sliding-window / hybrid linear-attention architectures) ignore this and
keep an fp16 KV cache — their quantized-cache path crashes during generation.
For those, lower <strong>n_ctx</strong> to shrink the KV VRAM reserve instead.
</p>
</div>
<div class="form-row" style="margin:0"> <div class="form-row" style="margin:0">
<label class="form-label">Max GPU % <span class="muted">(optional)</span></label> <label class="form-label">Max GPU % <span class="muted">(optional)</span></label>
<input type="number" id="cfg-max-gpu" class="form-input" min="1" max="100" placeholder="e.g. 90"> <input type="number" id="cfg-max-gpu" class="form-input" min="1" max="100" placeholder="e.g. 90">
...@@ -661,6 +670,30 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson } ...@@ -661,6 +670,30 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px"><input type="checkbox" id="cfg-noram"> No RAM fallback</label> <label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px"><input type="checkbox" id="cfg-noram"> No RAM fallback</label>
</div> </div>
<!-- Fast-kernel 4-bit (GPTQ/AWQ via Marlin) — HF transformers models only -->
<div id="cfg-fastquant-wrap" style="margin-top:1rem;padding:.75rem;border:1px solid var(--border);border-radius:6px;background:var(--raised)">
<div style="display:flex;align-items:center;gap:1rem;flex-wrap:wrap">
<div style="flex:1;min-width:220px">
<label class="form-label">Quant backend <span class="muted">(HF transformers)</span></label>
<select id="cfg-quant-backend" class="form-input">
<option value="auto">Auto — fast checkpoint if present, else bitsandbytes</option>
<option value="bnb">bitsandbytes only (NF4)</option>
<option value="gptq">GPTQ (Marlin/ExLlama) — needs quantized checkpoint</option>
<option value="awq">AWQ (Marlin) — needs quantized checkpoint</option>
</select>
</div>
<div style="display:flex;align-items:flex-end;gap:.5rem">
<button type="button" class="btn btn-secondary btn-sm" id="cfg-quantize-btn" onclick="startQuantize()">Quantize to 4-bit</button>
</div>
</div>
<div class="form-hint" id="cfg-quant-status" style="margin-top:.5rem">
bitsandbytes NF4 is the slowest 4-bit option. Quantizing to GPTQ/AWQ uses
Marlin kernels (2–4× faster) but is a heavy one-time background job; the
produced checkpoint is used automatically on the next load. Falls back to
bitsandbytes if unavailable or the architecture is unsupported.
</div>
</div>
<!-- Per-component quantization (diffusers image/video pipelines) --> <!-- Per-component quantization (diffusers image/video pipelines) -->
<div id="cfg-compquant-wrap"> <div id="cfg-compquant-wrap">
<div class="card-title" style="margin-top:1.25rem">Component Quantization <span class="muted" style="font-weight:normal">(image / video pipelines)</span></div> <div class="card-title" style="margin-top:1.25rem">Component Quantization <span class="muted" style="font-weight:normal">(image / video pipelines)</span></div>
...@@ -2084,6 +2117,88 @@ function _collectComponentQuant(){ ...@@ -2084,6 +2117,88 @@ function _collectComponentQuant(){
return Object.keys(out).length ? out : null; return Object.keys(out).length ? out : null;
} }
// ---- Fast-kernel (GPTQ/AWQ) quantization ----
let _quantPollTimer = null;
let _quantCaps = null;
async function _quantGetCaps(){
if(_quantCaps) return _quantCaps;
try{
const r = await fetch('/admin/api/quantize-capabilities');
const j = await r.json();
_quantCaps = j;
}catch(e){ _quantCaps = {available:false, capabilities:{error:String(e)}}; }
return _quantCaps;
}
function _renderQuantStatus(job, caps){
const el = document.getElementById('cfg-quant-status');
const btn = document.getElementById('cfg-quantize-btn');
if(!el) return;
if(caps && !caps.available){
btn.disabled = true;
const err = (caps.capabilities && caps.capabilities.error) || 'GPTQModel not installed';
el.innerHTML = `<span style="color:var(--warn,#d97706)">Fast-kernel quantization unavailable — using bitsandbytes. (${esc(err)})</span>`;
return;
}
const kernels = (caps && caps.capabilities && caps.capabilities.backends || []).join(', ');
if(job && job.status === 'running'){
btn.disabled = true;
const pct = Math.round((job.progress||0)*100);
el.innerHTML = `<span style="color:var(--accent,#6366f1)">Quantizing… ${pct}% — ${esc(job.message||'')}</span>`;
} else if(job && job.status === 'done'){
btn.disabled = false;
el.innerHTML = `<span style="color:var(--ok,#16a34a)">✓ Quantized checkpoint ready — used automatically on next load.</span> <span class="muted">Kernels: ${esc(kernels)}</span>`;
} else if(job && job.status === 'failed'){
btn.disabled = false;
el.innerHTML = `<span style="color:var(--danger,#dc2626)">Quantization failed: ${esc(job.error||'')}</span> <span class="muted">Falls back to bitsandbytes.</span>`;
} else {
btn.disabled = false;
el.innerHTML = `Marlin/ExLlama 4-bit (2–4× faster than bitsandbytes). One-time background job; checkpoint auto-used on next load. <span class="muted">Kernels: ${esc(kernels||'detecting…')}</span>`;
}
}
async function _refreshQuantStatus(modelPath){
if(_quantPollTimer){ clearInterval(_quantPollTimer); _quantPollTimer = null; }
const caps = await _quantGetCaps();
const path = modelPath || (document.getElementById('cfg-path')||{}).value || '';
let job = (caps.jobs && caps.jobs[path]) || null;
_renderQuantStatus(job, caps);
// Poll while a job is running.
if(caps.available){
_quantPollTimer = setInterval(async ()=>{
try{
const r = await fetch('/admin/api/quantize-status?model_id='+encodeURIComponent(path));
const j = await r.json();
_renderQuantStatus(j.job, caps);
if(!j.job || (j.job.status !== 'running')){ clearInterval(_quantPollTimer); _quantPollTimer = null; }
}catch(e){ clearInterval(_quantPollTimer); _quantPollTimer = null; }
}, 2000);
}
}
async function startQuantize(){
const path = (document.getElementById('cfg-path')||{}).value || '';
if(!path){ alert('No model selected.'); return; }
const method = (document.getElementById('cfg-quant-backend').value === 'awq') ? 'awq' : 'gptq';
const el = document.getElementById('cfg-quant-status');
if(el) el.innerHTML = '<span class="muted">Starting quantization…</span>';
try{
const r = await fetch('/admin/api/model-quantize', {
method:'POST', headers:{'Content-Type':'application/json'},
body: JSON.stringify({path, method})
});
const j = await r.json();
if(!j.success && j.job && j.job.status === 'unavailable'){
_renderQuantStatus(null, {available:false, capabilities:j.job.caps||{}});
return;
}
_refreshQuantStatus(path);
}catch(e){
if(el) el.innerHTML = '<span style="color:var(--danger,#dc2626)">Failed to start: '+esc(String(e))+'</span>';
}
}
function _renderWhisperServerRows(models){ function _renderWhisperServerRows(models){
if(!models.length) return ''; if(!models.length) return '';
const rows = models.map(m=>{ const rows = models.map(m=>{
...@@ -2788,6 +2903,8 @@ function openCfgModal(idx, cfgIdx){ ...@@ -2788,6 +2903,8 @@ function openCfgModal(idx, cfgIdx){
document.getElementById('cfg-ram-gb').value = s.manual_ram_gb != null ? s.manual_ram_gb : ''; document.getElementById('cfg-ram-gb').value = s.manual_ram_gb != null ? s.manual_ram_gb : '';
document.getElementById('cfg-4bit').checked = !!s.load_in_4bit; document.getElementById('cfg-4bit').checked = !!s.load_in_4bit;
document.getElementById('cfg-8bit').checked = !!s.load_in_8bit; document.getElementById('cfg-8bit').checked = !!s.load_in_8bit;
document.getElementById('cfg-quant-backend').value = s.quant_backend || 'auto';
_refreshQuantStatus(m.path);
_renderComponentQuant(s.component_quantization || {}); _renderComponentQuant(s.component_quantization || {});
document.getElementById('cfg-flash').checked = !!s.flash_attention; document.getElementById('cfg-flash').checked = !!s.flash_attention;
document.getElementById('cfg-noram').checked = !!s.no_ram; document.getElementById('cfg-noram').checked = !!s.no_ram;
...@@ -3136,6 +3253,7 @@ async function saveModelConfig(){ ...@@ -3136,6 +3253,7 @@ async function saveModelConfig(){
manual_ram_gb: isNaN(ramGb) ? null : ramGb, manual_ram_gb: isNaN(ramGb) ? null : ramGb,
load_in_4bit: document.getElementById('cfg-4bit').checked, load_in_4bit: document.getElementById('cfg-4bit').checked,
load_in_8bit: document.getElementById('cfg-8bit').checked, load_in_8bit: document.getElementById('cfg-8bit').checked,
quant_backend: document.getElementById('cfg-quant-backend').value || 'auto',
component_quantization: _collectComponentQuant(), component_quantization: _collectComponentQuant(),
flash_attention: document.getElementById('cfg-flash').checked, flash_attention: document.getElementById('cfg-flash').checked,
no_ram: document.getElementById('cfg-noram').checked, no_ram: document.getElementById('cfg-noram').checked,
......
...@@ -81,6 +81,57 @@ def set_global_tools_closer_prompt(tools_closer: bool): ...@@ -81,6 +81,57 @@ def set_global_tools_closer_prompt(tools_closer: bool):
_set_global_tools_closer_prompt(tools_closer) _set_global_tools_closer_prompt(tools_closer)
def _conversation_session_key(request, http_request=None) -> Optional[str]:
"""Derive a stable per-conversation key for instance/KV-cache affinity.
Prefers an explicit identifier the client supplies (the OpenAI ``user`` field
or an ``X-Session-Id`` header). Otherwise falls back to a hash of the stable
opening of the conversation (system prompt + first user turn for chat, or the
prompt head for completions) — stable across the turns of one conversation,
distinct between conversations. Returns None if nothing usable is available
(callers then fall back to least-busy routing). Never raises.
"""
try:
# 1) Explicit id wins.
if http_request is not None:
sid = http_request.headers.get('x-session-id')
if sid:
return f"sid:{sid}"
uid = getattr(request, 'user', None)
if uid:
return f"user:{uid}"
# 2) Hash the stable opening of the conversation.
import hashlib
parts = []
msgs = getattr(request, 'messages', None)
if msgs:
first_user_seen = False
for m in msgs:
role = getattr(m, 'role', None) or (m.get('role') if isinstance(m, dict) else None)
content = getattr(m, 'content', None) or (m.get('content') if isinstance(m, dict) else None)
if not isinstance(content, str):
content = str(content)
if role == 'system':
parts.append(f"system:{content}")
elif role == 'user' and not first_user_seen:
parts.append(f"user:{content}")
first_user_seen = True
break
else:
prompt = getattr(request, 'prompt', None)
if isinstance(prompt, list):
prompt = prompt[0] if prompt else ''
if prompt:
parts.append(str(prompt)[:1024])
if not parts:
return None
digest = hashlib.sha256("\n".join(parts).encode('utf-8', 'ignore')).hexdigest()[:16]
return f"hash:{digest}"
except Exception:
return None
def set_grammar_guided_gen(enabled: bool): def set_grammar_guided_gen(enabled: bool):
"""Set the grammar-guided generation flag (via state module).""" """Set the grammar-guided generation flag (via state module)."""
_set_grammar_guided_gen(enabled) _set_grammar_guided_gen(enabled)
...@@ -424,7 +475,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -424,7 +475,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
_model_key = model_info.get('model_key') _model_key = model_info.get('model_key')
_candidate = None _candidate = None
_acq = multi_model_manager.acquire_model_instance(_model_key) if _model_key else None _session_key = _conversation_session_key(request, http_request)
_acq = multi_model_manager.acquire_model_instance(
_model_key, session_key=_session_key) if _model_key else None
if _acq: if _acq:
_instance_idx, _candidate = _acq _instance_idx, _candidate = _acq
# Guard against stale pool entries (model evicted but pool not cleared) # Guard against stale pool entries (model evicted but pool not cleared)
...@@ -2072,10 +2125,13 @@ async def completions(request: CompletionRequest): ...@@ -2072,10 +2125,13 @@ async def completions(request: CompletionRequest):
if model_info.get('error'): if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error']) raise HTTPException(status_code=404, detail=model_info['error'])
# Acquire the least-busy instance (increments ref-count; released on response completion) # Acquire an instance (session-affinity when derivable, else least-busy;
# increments ref-count; released on response completion).
_model_key = model_info.get('model_key') _model_key = model_info.get('model_key')
_instance_idx = None _instance_idx = None
_acq = multi_model_manager.acquire_model_instance(_model_key) if _model_key else None _session_key = _conversation_session_key(request)
_acq = multi_model_manager.acquire_model_instance(
_model_key, session_key=_session_key) if _model_key else None
if _acq: if _acq:
_instance_idx, mm = _acq _instance_idx, mm = _acq
else: else:
......
This diff is collapsed.
This diff is collapsed.
...@@ -84,9 +84,45 @@ def _trim_cpu_ram() -> None: ...@@ -84,9 +84,45 @@ def _trim_cpu_ram() -> None:
pass pass
def _drop_prefix_caches() -> int:
"""Release reclaimable KV prefix caches across all loaded text backends.
The GGUF LlamaRAMCache lives in host RAM and the HF KV slots in VRAM; both are
pure caches that can be rebuilt on demand. Dropping them is a cheap rung on the
RAM/VRAM-pressure ladder, run before evicting whole models. Returns the number
of backends whose cache was cleared.
"""
n = 0
try:
from codai.models.manager import multi_model_manager as _mm
except Exception:
return 0
def _clear(obj):
nonlocal n
backend = getattr(obj, 'backend', None)
fn = getattr(backend, 'clear_prefix_cache', None)
if callable(fn):
try:
fn()
n += 1
except Exception:
pass
try:
for pool in list(getattr(_mm, 'model_pools', {}).values()):
for inst in list(getattr(pool, 'instances', [])):
_clear(inst)
for obj in list(getattr(_mm, 'models', {}).values()):
_clear(obj)
except Exception:
pass
return n
class ModelManager: class ModelManager:
"""Manages the loaded model and tokenizer.""" """Manages the loaded model and tokenizer."""
def __init__(self, backend=None, backend_type=None): def __init__(self, backend=None, backend_type=None):
self.backend = backend self.backend = backend
self.backend_type = backend_type self.backend_type = backend_type
...@@ -502,6 +538,10 @@ class ModelInstancePool: ...@@ -502,6 +538,10 @@ class ModelInstancePool:
self.ref_counts: list = [] self.ref_counts: list = []
self.max_instances: int = max_instances self.max_instances: int = max_instances
self._lock = threading.Lock() self._lock = threading.Lock()
# Session-affinity map: session_key -> instance index. Routes successive
# requests of one conversation back to the instance that already holds its
# warm KV prefix, instead of the purely load-based least-busy pick.
self._affinity: dict = {}
@property @property
def count(self) -> int: def count(self) -> int:
...@@ -519,15 +559,41 @@ class ModelInstancePool: ...@@ -519,15 +559,41 @@ class ModelInstancePool:
self.ref_counts.append(0) self.ref_counts.append(0)
return idx return idx
def acquire(self): def acquire(self, session_key=None):
"""Return (idx, instance) of the least-busy instance, incrementing its ref-count.""" """Return (idx, instance) of the chosen instance, incrementing its ref-count.
When ``session_key`` is given, prefer the instance this session was last
routed to (cache affinity) as long as it isn't markedly busier than the
least-busy one; otherwise fall back to the least-busy instance and record
the new mapping. With ``session_key=None`` the behaviour is unchanged
(pure least-busy).
"""
with self._lock: with self._lock:
if not self.instances: if not self.instances:
return None return None
idx = min(range(len(self.instances)), key=lambda i: self.ref_counts[i]) least = min(range(len(self.instances)), key=lambda i: self.ref_counts[i])
idx = least
if session_key is not None:
pinned = self._affinity.get(session_key)
if pinned is not None and 0 <= pinned < len(self.instances):
# Honour affinity unless the pinned instance is busier than the
# least-busy one by more than one in-flight request (keeps a hot
# cache without letting one instance pile up under load).
if self.ref_counts[pinned] <= self.ref_counts[least] + 1:
idx = pinned
self._affinity[session_key] = idx
self._prune_affinity()
self.ref_counts[idx] += 1 self.ref_counts[idx] += 1
return idx, self.instances[idx] return idx, self.instances[idx]
def _prune_affinity(self) -> None:
"""Bound the affinity map so it can't grow without limit (caller holds lock)."""
_CAP = 512
if len(self._affinity) > _CAP:
# Drop arbitrary oldest-inserted entries; dicts preserve insertion order.
for k in list(self._affinity.keys())[: len(self._affinity) - _CAP]:
self._affinity.pop(k, None)
def release(self, idx: int) -> None: def release(self, idx: int) -> None:
with self._lock: with self._lock:
if 0 <= idx < len(self.ref_counts): if 0 <= idx < len(self.ref_counts):
...@@ -3604,15 +3670,17 @@ class MultiModelManager: ...@@ -3604,15 +3670,17 @@ class MultiModelManager:
self.active_in_vram = key self.active_in_vram = key
self.models_in_vram.add(key) self.models_in_vram.add(key)
def acquire_model_instance(self, model_key: str): def acquire_model_instance(self, model_key: str, session_key=None):
"""Acquire the least-busy instance, incrementing its ref-count. """Acquire an instance, incrementing its ref-count.
Returns (instance_idx, model_obj) or None if no instance is loaded. ``session_key`` (optional) routes successive requests of one conversation
Callers MUST call release_model_instance() when done. back to the instance holding its warm KV prefix; without it the least-busy
instance is chosen. Returns (instance_idx, model_obj) or None if no
instance is loaded. Callers MUST call release_model_instance() when done.
""" """
pool = self.model_pools.get(model_key) pool = self.model_pools.get(model_key)
if pool and pool.count > 0: if pool and pool.count > 0:
return pool.acquire() return pool.acquire(session_key=session_key)
obj = self.models.get(model_key) obj = self.models.get(model_key)
if obj is not None: if obj is not None:
return 0, obj return 0, obj
......
"""GPTQ/AWQ fast-kernel quantization support for the HuggingFace (NvidiaBackend) path.
bitsandbytes NF4 is the slowest 4-bit option and especially hurts MoE models. This
module lets coderai (a) detect whether GPTQModel + fast kernels (Marlin/ExLlama) are
available, (b) resolve where a locally-quantized checkpoint for a model lives, and
(c) run an on-demand background job that quantizes a model to 4-bit GPTQ and caches
the result. Loading a produced (or otherwise pre-quantized) checkpoint then goes
through transformers' native quantization path, which picks the fast kernel.
The whole module degrades gracefully: if GPTQModel can't import, capability checks
return False and callers fall back to bitsandbytes.
"""
from __future__ import annotations
import os
import re
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, Any
from codai.models.cache import get_model_cache_dir
# --------------------------------------------------------------------------------
# Capability detection (cached — import probing is cheap but not free)
# --------------------------------------------------------------------------------
_caps_cache: Optional[Dict[str, Any]] = None
_caps_lock = threading.Lock()
def capabilities(refresh: bool = False) -> Dict[str, Any]:
"""Return a dict describing GPTQ/AWQ availability and usable kernels.
Keys: ``available`` (bool), ``version`` (str|None), ``backends`` (list[str] of
fast-kernel names that imported), ``error`` (str|None). Result is memoised.
"""
global _caps_cache
if _caps_cache is not None and not refresh:
return _caps_cache
with _caps_lock:
if _caps_cache is not None and not refresh:
return _caps_cache
caps: Dict[str, Any] = {"available": False, "version": None,
"backends": [], "error": None}
try:
import gptqmodel # noqa: F401
caps["version"] = getattr(gptqmodel, "__version__", "?")
# Which accelerated inference kernels are importable on this box.
try:
from gptqmodel.utils.backend import BACKEND
wanted = ["MARLIN", "EXLLAMA_V2", "EXLLAMA_V1", "TRITON",
"AWQ_MARLIN", "AWQ_GEMM"]
caps["backends"] = [b for b in wanted if hasattr(BACKEND, b)]
except Exception:
caps["backends"] = []
caps["available"] = True
except Exception as e: # ImportError or a broken transitive dep
caps["error"] = str(e)
_caps_cache = caps
return caps
def is_available() -> bool:
"""True when GPTQModel imports and at least one fast kernel is present."""
c = capabilities()
return bool(c["available"] and c["backends"])
# --------------------------------------------------------------------------------
# Quantized-checkpoint locations
# --------------------------------------------------------------------------------
def _safe_slug(model_name: str) -> str:
"""Filesystem-safe slug for a model id/path."""
return re.sub(r"[^A-Za-z0-9._-]+", "_", str(model_name)).strip("_")
def quantized_checkpoint_dir(model_name: str, method: str = "gptq") -> Path:
"""Cache directory where coderai stores a self-quantized checkpoint."""
root = Path(get_model_cache_dir()) / "quantized" / method.lower()
return root / _safe_slug(model_name)
def find_quantized_checkpoint(model_name: str, method: str = "gptq") -> Optional[str]:
"""Return the path to a usable locally-quantized checkpoint, or None.
A checkpoint counts as ready when its directory holds a config.json plus at
least one weights shard (so a half-finished/aborted quant isn't picked up).
"""
d = quantized_checkpoint_dir(model_name, method)
if not d.is_dir():
return None
if not (d / "config.json").is_file():
return None
has_weights = any(d.glob("*.safetensors")) or any(d.glob("*.bin"))
return str(d) if has_weights else None
# --------------------------------------------------------------------------------
# Background quantization job
# --------------------------------------------------------------------------------
_jobs: Dict[str, Dict[str, Any]] = {} # model_name -> job status dict
_jobs_lock = threading.Lock()
def get_job(model_name: str) -> Optional[Dict[str, Any]]:
with _jobs_lock:
j = _jobs.get(model_name)
return dict(j) if j else None
def all_jobs() -> Dict[str, Dict[str, Any]]:
with _jobs_lock:
return {k: dict(v) for k, v in _jobs.items()}
def _set_job(model_name: str, **fields) -> None:
with _jobs_lock:
j = _jobs.setdefault(model_name, {"model": model_name})
j.update(fields)
def start_quantization(model_name: str, method: str = "gptq", bits: int = 4,
group_size: int = 128) -> Dict[str, Any]:
"""Kick off (or report) a background quantization for ``model_name``.
Returns the job status dict. Idempotent: if a job is already running or a
checkpoint already exists, it returns that state instead of starting again.
"""
method = (method or "gptq").lower()
if not is_available():
return {"model": model_name, "status": "unavailable",
"error": "GPTQModel / fast kernels not installed",
"caps": capabilities()}
existing = find_quantized_checkpoint(model_name, method)
if existing:
_set_job(model_name, status="done", method=method, output=existing,
progress=1.0, message="already quantized")
return get_job(model_name)
with _jobs_lock:
cur = _jobs.get(model_name)
if cur and cur.get("status") == "running":
return dict(cur)
_jobs[model_name] = {"model": model_name, "method": method, "bits": bits,
"status": "running", "progress": 0.0,
"message": "starting", "started": time.time(),
"error": None, "output": None}
t = threading.Thread(
target=_quantize_worker,
args=(model_name, method, bits, group_size),
name=f"quantize-{_safe_slug(model_name)[:24]}",
daemon=True,
)
t.start()
return get_job(model_name)
def _quantize_worker(model_name: str, method: str, bits: int, group_size: int) -> None:
"""Run GPTQModel quantization and write the checkpoint to the cache dir.
Heavy and slow (loads the source model, runs calibration). Runs in its own
thread; never raises — failures are recorded on the job and the loader falls
back to bitsandbytes.
"""
out_dir = quantized_checkpoint_dir(model_name, method)
try:
_set_job(model_name, message="loading quantizer", progress=0.02)
from gptqmodel import GPTQModel, QuantizeConfig
# Calibration data: a small generic text sample. Enough to populate the
# GPTQ Hessian statistics without a domain-specific corpus.
calib = _calibration_samples()
_set_job(model_name, message="loading source model (this is slow)", progress=0.05)
qcfg = QuantizeConfig(bits=bits, group_size=group_size)
model = GPTQModel.load(model_name, qcfg)
_set_job(model_name, message="quantizing (calibration passes)", progress=0.15)
model.quantize(calib)
out_dir.mkdir(parents=True, exist_ok=True)
_set_job(model_name, message="saving checkpoint", progress=0.9)
model.save(str(out_dir))
# Persist the tokenizer alongside so the checkpoint loads standalone.
try:
from transformers import AutoTokenizer
AutoTokenizer.from_pretrained(model_name, trust_remote_code=True).save_pretrained(str(out_dir))
except Exception:
pass
_set_job(model_name, status="done", progress=1.0,
message="quantization complete", output=str(out_dir),
finished=time.time())
except Exception as e:
# Likely causes: arch not supported by the quantizer, or OOM. Leave any
# partial output in place but mark the job failed so the loader uses bnb.
_set_job(model_name, status="failed", error=str(e),
message=f"quantization failed: {e}", finished=time.time())
def _calibration_samples() -> List[str]:
"""A small, generic calibration set for GPTQ Hessian estimation."""
base = [
"The quick brown fox jumps over the lazy dog.",
"In computer science, a hash table is a data structure that maps keys to values.",
"def fibonacci(n):\n a, b = 0, 1\n for _ in range(n):\n a, b = b, a + b\n return a",
"The mitochondria is the powerhouse of the cell, generating most of the cell's ATP.",
"Once upon a time, in a land far away, there lived a wise old programmer.",
"To be, or not to be, that is the question that has echoed through the centuries.",
"Machine learning models learn patterns from data to make predictions on unseen inputs.",
"The capital of France is Paris, a city known for its art, culture, and history.",
]
# GPTQModel wants a few hundred short rows; repeat the seed set.
return [base[i % len(base)] for i in range(256)]
...@@ -171,6 +171,14 @@ def _mitigate(rss_gb: float, cap_gb: float, leak: bool, loading: bool = False) - ...@@ -171,6 +171,14 @@ def _mitigate(rss_gb: float, cap_gb: float, leak: bool, loading: bool = False) -
actions.append("drop_upscalers") actions.append("drop_upscalers")
except Exception: except Exception:
pass pass
# Drop reclaimable KV prefix caches (host-RAM LlamaRAMCache + HF KV slots)
# before resorting to evicting whole models — they rebuild on demand.
try:
from codai.models.manager import _drop_prefix_caches
if _drop_prefix_caches():
actions.append("drop_prefix_caches")
except Exception:
pass
# Still over and eviction is enabled → unload idle LRU models. # Still over and eviction is enabled → unload idle LRU models.
try: try:
from codai.models.manager import multi_model_manager as _mm from codai.models.manager import multi_model_manager as _mm
......
...@@ -58,6 +58,23 @@ tiktoken>=0.5.0 ...@@ -58,6 +58,23 @@ tiktoken>=0.5.0
tokenizers>=0.15.0 tokenizers>=0.15.0
protobuf>=3.20.0 protobuf>=3.20.0
# Optional: fast-kernel 4-bit quantization (GPTQ/AWQ via Marlin/ExLlama).
# Much faster than bitsandbytes NF4, especially for MoE models. coderai uses this
# for the per-model `quant_backend` (auto|bnb|gptq|awq) and the on-demand
# "Quantize to 4-bit" button on the model admin page. The feature auto-detects
# GPTQModel and falls back to bitsandbytes when it (or its kernels) are absent.
#
# On bleeding-edge stacks (e.g. CUDA 13 + torch 2.11) install the LATEST from
# source for new-arch support (e.g. gemma-4 MoE). GPTQModel 7.x is a pure-Python
# wheel (no CUDA compile) but pins numpy==2.2.6 in its metadata; do NOT let that
# downgrade a newer numpy your stack needs — install with the deps below and a
# numpy constraint, e.g.:
# pip install --no-deps "git+https://github.com/ModelCloud/GPTQModel.git"
# pip install -c <(echo "numpy==<your-version>") \
# torchao device-smi tokenicer logbar pypcre defuser maturin
# gptqmodel>=7.1.0
# torchao>=0.16.0
# llama-cpp-python with CUDA support (for GGUF files on CUDA backend) # llama-cpp-python with CUDA support (for GGUF files on CUDA backend)
llama-cpp-python>=0.2.0 llama-cpp-python>=0.2.0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment