feat: smart context caching, VRAM offload fix, GPTQ/AWQ quant backend

Smart context caching (both text backends):
- Per-instance generation lock so pooled concurrent requests can't corrupt a
  shared KV cache (GGUF + HF, incl. streaming worker thread).
- GGUF: enable multi-slot LlamaRAMCache, budget via kv_cache_budget_mb (512MB).
- HF: replace single exact-text KV slot with an LRU of token-prefix slots +
  token-level longest-common-prefix + DynamicCache clone/crop (handles
  mid-history edits); kv_cache_slots (default 3).
- Session-affinity routing in ModelInstancePool.acquire(session_key); key from
  user/X-Session-Id else a stable prefix hash.
- RAM-pressure ladder drops reclaimable prefix caches before evicting models.

VRAM fix:
- Auto-fit check no longer double-counts the KV/activation reserve when
  expected_vram_gb is already a peak estimate — borderline models (e.g.
  gemma-4-26B-A4B) stay GPU-resident instead of forced into MoE-thrashing
  device_map offload.

GPTQ/AWQ fast-kernel quant backend (HF path):
- New codai/models/quant.py: GPTQModel capability detection, quantized-checkpoint
  cache, on-demand background quantize job (falls back to bnb if unsupported).
- quant_backend config (auto|bnb|gptq|awq); loader auto-uses a quantized
  checkpoint with Marlin/ExLlama when present, else bitsandbytes.
- Admin endpoints + "Quantize to 4-bit" button with live status on the model page.
- requirements-nvidia.txt documents the from-source install + numpy caveat.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent e4c040e2
...@@ -1669,6 +1669,52 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad ...@@ -1669,6 +1669,52 @@ async def api_model_disable(request: Request, username: str = Depends(require_ad
return {"success": True} return {"success": True}
@router.get("/admin/api/quantize-capabilities", summary="GPTQ/AWQ quantization availability")
async def api_quantize_capabilities(username: str = Depends(require_admin)):
    """Report whether fast-kernel (GPTQ/AWQ) quantization is available + any jobs."""
    # Imported lazily, inside the handler, exactly as the other quantize routes do.
    from codai.models import quant

    # Query the quant module in the same order the response keys are listed.
    caps = quant.capabilities()
    available = quant.is_available()
    jobs = quant.all_jobs()
    return {"capabilities": caps, "available": available, "jobs": jobs}
@router.post("/admin/api/model-quantize", summary="Quantize a model to fast-kernel 4-bit")
async def api_model_quantize(request: Request, username: str = Depends(require_admin)):
    """Start (or report) an on-demand background GPTQ/AWQ quantization.
    Body: {path|model_id, method?(gptq|awq), bits?(4), group_size?(128)}.
    Quantization is heavy and slow; it runs in the background and the produced
    checkpoint is picked up automatically on the model's next load. Falls back to
    bitsandbytes if the fast kernels are unavailable or the arch is unsupported.
    """
    from codai.models import quant

    # A malformed body is a client error: answer 400 rather than letting the
    # decode error (a ValueError subclass) surface as a 500.
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="request body must be valid JSON")
    # Valid JSON that is not an object (list, string, number, null) has no
    # .get(); reject it explicitly instead of crashing with AttributeError.
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")

    # "path" takes precedence over "model_id"; a non-string value is treated
    # the same as a missing one (it cannot be stripped or used as an id).
    model_id = data.get("path") or data.get("model_id") or ""
    if not isinstance(model_id, str) or not model_id.strip():
        raise HTTPException(status_code=400, detail="path/model_id is required")
    model_id = model_id.strip()

    # Method defaults to GPTQ and is matched case-insensitively.
    method = data.get("method") or "gptq"
    if not isinstance(method, str) or method.lower() not in ("gptq", "awq"):
        raise HTTPException(status_code=400, detail="method must be 'gptq' or 'awq'")
    method = method.lower()

    try:
        bits = int(data.get("bits", 4))
        group_size = int(data.get("group_size", 128))
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="bits/group_size must be integers")

    job = quant.start_quantization(model_id, method=method, bits=bits, group_size=group_size)
    # "unavailable" is the one job status reported to the client as a failure.
    return {"success": job.get("status") != "unavailable", "job": job}
@router.get("/admin/api/quantize-status", summary="Quantization job status")
async def api_quantize_status(model_id: str = "", username: str = Depends(require_admin)):
    """Status for one model's quant job (?model_id=...), or all jobs."""
    from codai.models import quant

    # No model_id given: report every known job under "jobs".
    if not model_id:
        return {"jobs": quant.all_jobs()}
    # Otherwise report just that model's job under "job" (id is whitespace-trimmed).
    wanted = model_id.strip()
    return {"job": quant.get_job(wanted)}
@router.get("/admin/api/model-loaded-status", summary="Model load status") @router.get("/admin/api/model-loaded-status", summary="Model load status")
async def api_model_loaded_status(username: str = Depends(require_admin)): async def api_model_loaded_status(username: str = Depends(require_admin)):
"""Return loaded model keys with per-model instance pool info.""" """Return loaded model keys with per-model instance pool info."""
...@@ -2117,7 +2163,8 @@ async def api_model_configure(request: Request, username: str = Depends(require_ ...@@ -2117,7 +2163,8 @@ async def api_model_configure(request: Request, username: str = Depends(require_
"max_vram", "sdcpp_flash_attn", "sdcpp_diffusion_flash_attn", "vae_tiling", "max_vram", "sdcpp_flash_attn", "sdcpp_diffusion_flash_attn", "vae_tiling",
"component_quantization", "output_crf", "force_vram_update", "component_quantization", "output_crf", "force_vram_update",
"balanced_gpu_percent", "acceleration", "balanced_gpu_percent", "acceleration",
"cache_type_k", "cache_type_v", "turboquant", "engine"): "cache_type_k", "cache_type_v", "turboquant", "engine",
"quant_backend", "kv_cache_budget_mb", "kv_cache_slots"):
if key in data: if key in data:
entry[key] = data[key] entry[key] = data[key]
......
...@@ -645,6 +645,15 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson } ...@@ -645,6 +645,15 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<option value="q4_0">q4_0 (smallest)</option> <option value="q4_0">q4_0 (smallest)</option>
</select> </select>
</div> </div>
<div class="form-row" style="margin:0">
<p class="muted" style="margin:0;font-size:0.85em;line-height:1.4">
<strong>Note:</strong> KV-cache quantization only applies to GGUF models on
the llama.cpp backend. HF-transformers models (incl. <strong>gemma</strong> and
other sliding-window / hybrid linear-attention architectures) ignore this and
keep an fp16 KV cache — their quantized-cache path crashes during generation.
For those, lower <strong>n_ctx</strong> to shrink the KV VRAM reserve instead.
</p>
</div>
<div class="form-row" style="margin:0"> <div class="form-row" style="margin:0">
<label class="form-label">Max GPU % <span class="muted">(optional)</span></label> <label class="form-label">Max GPU % <span class="muted">(optional)</span></label>
<input type="number" id="cfg-max-gpu" class="form-input" min="1" max="100" placeholder="e.g. 90"> <input type="number" id="cfg-max-gpu" class="form-input" min="1" max="100" placeholder="e.g. 90">
...@@ -661,6 +670,30 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson } ...@@ -661,6 +670,30 @@ window.__DEFAULT_WHISPER_SERVER_PATH__ = {{ default_whisper_server_path|tojson }
<label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px"><input type="checkbox" id="cfg-noram"> No RAM fallback</label> <label style="display:flex;align-items:center;gap:.5rem;cursor:pointer;font-size:13px"><input type="checkbox" id="cfg-noram"> No RAM fallback</label>
</div> </div>
<!-- Fast-kernel 4-bit (GPTQ/AWQ via Marlin) — HF transformers models only -->
<div id="cfg-fastquant-wrap" style="margin-top:1rem;padding:.75rem;border:1px solid var(--border);border-radius:6px;background:var(--raised)">
<div style="display:flex;align-items:center;gap:1rem;flex-wrap:wrap">
<div style="flex:1;min-width:220px">
<label class="form-label">Quant backend <span class="muted">(HF transformers)</span></label>
<select id="cfg-quant-backend" class="form-input">
<option value="auto">Auto — fast checkpoint if present, else bitsandbytes</option>
<option value="bnb">bitsandbytes only (NF4)</option>
<option value="gptq">GPTQ (Marlin/ExLlama) — needs quantized checkpoint</option>
<option value="awq">AWQ (Marlin) — needs quantized checkpoint</option>
</select>
</div>
<div style="display:flex;align-items:flex-end;gap:.5rem">
<button type="button" class="btn btn-secondary btn-sm" id="cfg-quantize-btn" onclick="startQuantize()">Quantize to 4-bit</button>
</div>
</div>
<div class="form-hint" id="cfg-quant-status" style="margin-top:.5rem">
bitsandbytes NF4 is the slowest 4-bit option. Quantizing to GPTQ/AWQ uses
Marlin kernels (2–4× faster) but is a heavy one-time background job; the
produced checkpoint is used automatically on the next load. Falls back to
bitsandbytes if unavailable or the architecture is unsupported.
</div>
</div>
<!-- Per-component quantization (diffusers image/video pipelines) --> <!-- Per-component quantization (diffusers image/video pipelines) -->
<div id="cfg-compquant-wrap"> <div id="cfg-compquant-wrap">
<div class="card-title" style="margin-top:1.25rem">Component Quantization <span class="muted" style="font-weight:normal">(image / video pipelines)</span></div> <div class="card-title" style="margin-top:1.25rem">Component Quantization <span class="muted" style="font-weight:normal">(image / video pipelines)</span></div>
...@@ -2084,6 +2117,88 @@ function _collectComponentQuant(){ ...@@ -2084,6 +2117,88 @@ function _collectComponentQuant(){
return Object.keys(out).length ? out : null; return Object.keys(out).length ? out : null;
} }
// ---- Fast-kernel (GPTQ/AWQ) quantization ----
let _quantPollTimer = null;
let _quantCaps = null;
/**
 * Fetch the server's fast-kernel quantization capabilities, memoized in
 * `_quantCaps` after the first successful response.
 *
 * Only a successful (2xx) response is cached. A network error or an HTTP error
 * status yields a synthetic `{available:false, capabilities:{error}}` object
 * for this call only, so the next call retries instead of leaving the
 * Quantize button disabled until the page is reloaded.
 *
 * @returns {Promise<Object>} the capabilities payload, or the synthetic
 *   "unavailable" object on failure. Never rejects.
 */
async function _quantGetCaps(){
  if(_quantCaps) return _quantCaps;
  try{
    const r = await fetch('/admin/api/quantize-capabilities');
    // An error body (e.g. {detail:...}) is not a capabilities payload.
    if(!r.ok) throw new Error('HTTP ' + r.status);
    _quantCaps = await r.json();
    return _quantCaps;
  }catch(e){
    // Deliberately not stored in _quantCaps: failures may be transient.
    return {available:false, capabilities:{error:String(e)}};
  }
}
/**
 * Render the quantization status line and enable/disable the Quantize button.
 *
 * @param {?Object} job  job record (`status`, `progress`, `message`, `error`)
 *   or null/undefined when no job is known for the model.
 * @param {?Object} caps capabilities payload (`available`, `capabilities`).
 *
 * Does nothing when the status element is absent. A missing button is
 * tolerated too, so the status text is still rendered instead of the whole
 * call throwing on `btn.disabled`.
 */
function _renderQuantStatus(job, caps){
  const el = document.getElementById('cfg-quant-status');
  const btn = document.getElementById('cfg-quantize-btn');
  if(!el) return;
  // Null-safe button toggle (the button may not be in the DOM).
  const setDisabled = (v)=>{ if(btn) btn.disabled = v; };

  // Fast kernels unavailable: lock the button and explain why.
  if(caps && !caps.available){
    setDisabled(true);
    const err = (caps.capabilities && caps.capabilities.error) || 'GPTQModel not installed';
    el.innerHTML = `<span style="color:var(--warn,#d97706)">Fast-kernel quantization unavailable — using bitsandbytes. (${esc(err)})</span>`;
    return;
  }

  const kernels = (caps && caps.capabilities && caps.capabilities.backends || []).join(', ');
  const status = job ? job.status : null;
  if(status === 'running'){
    // No second job can be started while one is in flight.
    setDisabled(true);
    const pct = Math.round((job.progress||0)*100);
    el.innerHTML = `<span style="color:var(--accent,#6366f1)">Quantizing… ${pct}% — ${esc(job.message||'')}</span>`;
  } else if(status === 'done'){
    setDisabled(false);
    el.innerHTML = `<span style="color:var(--ok,#16a34a)">✓ Quantized checkpoint ready — used automatically on next load.</span> <span class="muted">Kernels: ${esc(kernels)}</span>`;
  } else if(status === 'failed'){
    setDisabled(false);
    el.innerHTML = `<span style="color:var(--danger,#dc2626)">Quantization failed: ${esc(job.error||'')}</span> <span class="muted">Falls back to bitsandbytes.</span>`;
  } else {
    // Idle: no job, or a status this view does not special-case.
    setDisabled(false);
    el.innerHTML = `Marlin/ExLlama 4-bit (2–4× faster than bitsandbytes). One-time background job; checkpoint auto-used on next load. <span class="muted">Kernels: ${esc(kernels||'detecting…')}</span>`;
  }
}
// Monotonic token identifying the most recent _refreshQuantStatus() call.
let _quantRefreshSeq = 0;

/**
 * Render the quantization status for a model and poll it while a job runs.
 *
 * @param {string=} modelPath model path/id; defaults to the `cfg-path` field.
 *
 * Each call supersedes earlier ones: a call that is still awaiting when a
 * newer call starts returns without rendering or arming a timer, and a poll
 * response that arrives after being superseded is discarded. Every interval
 * clears itself by its own id, so overlapping calls cannot leave an orphaned
 * timer polling forever.
 */
async function _refreshQuantStatus(modelPath){
  const seq = ++_quantRefreshSeq;
  if(_quantPollTimer){ clearInterval(_quantPollTimer); _quantPollTimer = null; }
  const caps = await _quantGetCaps();
  // A newer call started while we awaited — it owns the UI now.
  if(seq !== _quantRefreshSeq) return;

  const path = modelPath || (document.getElementById('cfg-path')||{}).value || '';
  const job = (caps.jobs && caps.jobs[path]) || null;
  _renderQuantStatus(job, caps);
  if(!caps.available) return;

  // Poll unconditionally at first: `caps` is memoized, so its job list can be
  // stale (e.g. right after a job was started). The first tick fetches the
  // fresh status and the timer stops itself once no job is running.
  const timer = setInterval(async ()=>{
    const stop = ()=>{
      clearInterval(timer);
      if(_quantPollTimer === timer) _quantPollTimer = null;
    };
    try{
      const r = await fetch('/admin/api/quantize-status?model_id='+encodeURIComponent(path));
      const j = await r.json();
      // Superseded while the request was in flight: drop the stale result.
      if(seq !== _quantRefreshSeq){ stop(); return; }
      _renderQuantStatus(j.job, caps);
      if(!j.job || (j.job.status !== 'running')) stop();
    }catch(e){ stop(); }
  }, 2000);
  _quantPollTimer = timer;
}
/**
 * Click handler for "Quantize to 4-bit": start a background quantization job
 * for the model in the `cfg-path` field, then refresh the status view.
 *
 * The method is AWQ only when the backend selector is set to 'awq'; every
 * other selection (including a missing selector) quantizes with GPTQ.
 * HTTP error responses are shown in the status line (using the response's
 * `detail` field when present) instead of being treated as a started job.
 */
async function startQuantize(){
  const path = (document.getElementById('cfg-path')||{}).value || '';
  if(!path){ alert('No model selected.'); return; }
  const backendSel = document.getElementById('cfg-quant-backend') || {};
  const method = (backendSel.value === 'awq') ? 'awq' : 'gptq';
  const el = document.getElementById('cfg-quant-status');
  if(el) el.innerHTML = '<span class="muted">Starting quantization…</span>';
  try{
    const r = await fetch('/admin/api/model-quantize', {
      method:'POST', headers:{'Content-Type':'application/json'},
      body: JSON.stringify({path, method})
    });
    // An error response may not carry a JSON body at all.
    const j = await r.json().catch(()=>null);
    if(!r.ok){
      const detail = (j && j.detail != null) ? j.detail : ('HTTP ' + r.status);
      const text = (typeof detail === 'string') ? detail : JSON.stringify(detail);
      if(el) el.innerHTML = '<span style="color:var(--danger,#dc2626)">Failed to start: '+esc(text)+'</span>';
      return;
    }
    if(j && !j.success && j.job && j.job.status === 'unavailable'){
      _renderQuantStatus(null, {available:false, capabilities:j.job.caps||{}});
      return;
    }
    _refreshQuantStatus(path);
  }catch(e){
    if(el) el.innerHTML = '<span style="color:var(--danger,#dc2626)">Failed to start: '+esc(String(e))+'</span>';
  }
}
function _renderWhisperServerRows(models){ function _renderWhisperServerRows(models){
if(!models.length) return ''; if(!models.length) return '';
const rows = models.map(m=>{ const rows = models.map(m=>{
...@@ -2788,6 +2903,8 @@ function openCfgModal(idx, cfgIdx){ ...@@ -2788,6 +2903,8 @@ function openCfgModal(idx, cfgIdx){
document.getElementById('cfg-ram-gb').value = s.manual_ram_gb != null ? s.manual_ram_gb : ''; document.getElementById('cfg-ram-gb').value = s.manual_ram_gb != null ? s.manual_ram_gb : '';
document.getElementById('cfg-4bit').checked = !!s.load_in_4bit; document.getElementById('cfg-4bit').checked = !!s.load_in_4bit;
document.getElementById('cfg-8bit').checked = !!s.load_in_8bit; document.getElementById('cfg-8bit').checked = !!s.load_in_8bit;
document.getElementById('cfg-quant-backend').value = s.quant_backend || 'auto';
_refreshQuantStatus(m.path);
_renderComponentQuant(s.component_quantization || {}); _renderComponentQuant(s.component_quantization || {});
document.getElementById('cfg-flash').checked = !!s.flash_attention; document.getElementById('cfg-flash').checked = !!s.flash_attention;
document.getElementById('cfg-noram').checked = !!s.no_ram; document.getElementById('cfg-noram').checked = !!s.no_ram;
...@@ -3136,6 +3253,7 @@ async function saveModelConfig(){ ...@@ -3136,6 +3253,7 @@ async function saveModelConfig(){
manual_ram_gb: isNaN(ramGb) ? null : ramGb, manual_ram_gb: isNaN(ramGb) ? null : ramGb,
load_in_4bit: document.getElementById('cfg-4bit').checked, load_in_4bit: document.getElementById('cfg-4bit').checked,
load_in_8bit: document.getElementById('cfg-8bit').checked, load_in_8bit: document.getElementById('cfg-8bit').checked,
quant_backend: document.getElementById('cfg-quant-backend').value || 'auto',
component_quantization: _collectComponentQuant(), component_quantization: _collectComponentQuant(),
flash_attention: document.getElementById('cfg-flash').checked, flash_attention: document.getElementById('cfg-flash').checked,
no_ram: document.getElementById('cfg-noram').checked, no_ram: document.getElementById('cfg-noram').checked,
......
...@@ -81,6 +81,57 @@ def set_global_tools_closer_prompt(tools_closer: bool): ...@@ -81,6 +81,57 @@ def set_global_tools_closer_prompt(tools_closer: bool):
_set_global_tools_closer_prompt(tools_closer) _set_global_tools_closer_prompt(tools_closer)
def _conversation_session_key(request, http_request=None) -> Optional[str]:
    """Derive a stable per-conversation key for instance/KV-cache affinity.

    Prefers an explicit identifier the client supplies (an ``X-Session-Id``
    header, then the OpenAI ``user`` field). Otherwise falls back to a hash of
    the stable opening of the conversation (system prompt + first user turn for
    chat, or the prompt head for completions) — stable across the turns of one
    conversation, distinct between conversations.

    For completions, ``prompt`` may be a string, a list whose first element is
    used, or a flat list of token ids. A token-id list is hashed from its
    leading tokens rather than from its first token alone, so requests that
    merely share a first token do not all collapse onto the same key.

    Args:
        request: Parsed request object; ``user``, ``messages`` and ``prompt``
            are read with ``getattr`` so any of them may be absent.
        http_request: Optional HTTP request exposing ``headers.get``.

    Returns:
        ``"sid:<id>"``, ``"user:<id>"`` or ``"hash:<16 hex chars>"``; None if
        nothing usable is available (callers then fall back to least-busy
        routing). Never raises.
    """
    try:
        # 1) Explicit id wins.
        if http_request is not None:
            sid = http_request.headers.get('x-session-id')
            if sid:
                return f"sid:{sid}"
        uid = getattr(request, 'user', None)
        if uid:
            return f"user:{uid}"

        # 2) Hash the stable opening of the conversation.
        import hashlib
        parts = []
        msgs = getattr(request, 'messages', None)
        if msgs:
            for m in msgs:
                # Messages may be model objects (attributes) or plain dicts.
                is_dict = isinstance(m, dict)
                role = getattr(m, 'role', None) or (m.get('role') if is_dict else None)
                content = getattr(m, 'content', None) or (m.get('content') if is_dict else None)
                if not isinstance(content, str):
                    content = str(content)
                if role == 'system':
                    parts.append(f"system:{content}")
                elif role == 'user':
                    # Only the first user turn is part of the stable opening;
                    # everything after it changes as the conversation grows.
                    parts.append(f"user:{content}")
                    break
        else:
            prompt = getattr(request, 'prompt', None)
            if isinstance(prompt, list):
                if prompt and isinstance(prompt[0], int):
                    # Flat token-id prompt: keep a bounded head of the tokens.
                    prompt = prompt[:256]
                else:
                    prompt = prompt[0] if prompt else ''
            if prompt:
                parts.append(str(prompt)[:1024])

        if not parts:
            return None
        digest = hashlib.sha256("\n".join(parts).encode('utf-8', 'ignore')).hexdigest()[:16]
        return f"hash:{digest}"
    except Exception:
        # Affinity is an optimisation only; any failure means "no key".
        return None
def set_grammar_guided_gen(enabled: bool): def set_grammar_guided_gen(enabled: bool):
"""Set the grammar-guided generation flag (via state module).""" """Set the grammar-guided generation flag (via state module)."""
_set_grammar_guided_gen(enabled) _set_grammar_guided_gen(enabled)
...@@ -424,7 +475,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -424,7 +475,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
_model_key = model_info.get('model_key') _model_key = model_info.get('model_key')
_candidate = None _candidate = None
_acq = multi_model_manager.acquire_model_instance(_model_key) if _model_key else None _session_key = _conversation_session_key(request, http_request)
_acq = multi_model_manager.acquire_model_instance(
_model_key, session_key=_session_key) if _model_key else None
if _acq: if _acq:
_instance_idx, _candidate = _acq _instance_idx, _candidate = _acq
# Guard against stale pool entries (model evicted but pool not cleared) # Guard against stale pool entries (model evicted but pool not cleared)
...@@ -2072,10 +2125,13 @@ async def completions(request: CompletionRequest): ...@@ -2072,10 +2125,13 @@ async def completions(request: CompletionRequest):
if model_info.get('error'): if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error']) raise HTTPException(status_code=404, detail=model_info['error'])
# Acquire the least-busy instance (increments ref-count; released on response completion) # Acquire an instance (session-affinity when derivable, else least-busy;
# increments ref-count; released on response completion).
_model_key = model_info.get('model_key') _model_key = model_info.get('model_key')
_instance_idx = None _instance_idx = None
_acq = multi_model_manager.acquire_model_instance(_model_key) if _model_key else None _session_key = _conversation_session_key(request)
_acq = multi_model_manager.acquire_model_instance(
_model_key, session_key=_session_key) if _model_key else None
if _acq: if _acq:
_instance_idx, mm = _acq _instance_idx, mm = _acq
else: else:
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import os import os
import time as _time import time as _time
import threading
from typing import Optional, List, Dict from typing import Optional, List, Dict
from threading import Thread from threading import Thread
from abc import ABC from abc import ABC
...@@ -79,13 +80,24 @@ class NvidiaBackend(ModelBackend): ...@@ -79,13 +80,24 @@ class NvidiaBackend(ModelBackend):
self.device = None self.device = None
self.use_flash_attn = False self.use_flash_attn = False
self.flash_attn_available = False self.flash_attn_available = False
# KV prefix cache (single-entry, keyed by formatted prefix text) # Multi-slot KV prefix cache. Each slot keeps a pristine prefill cache for
self._kv_prefix_text: Optional[str] = None # a token prefix; on a new request we reuse the slot sharing the longest
self._kv_past_key_values = None # past_key_values tensor tuple # token prefix (so two interleaved conversations stay warm, and an edit in
self._kv_prefix_len: int = 0 # token count of the cached prefix # the middle of the history only re-encodes from the edit point). Slots are
self._kv_timestamp: float = 0.0 # an LRU list (most-recently-used last), each a dict:
# {"ids": list[int], "cache": past_key_values, "length": int, "ts": float}
self._kv_slots: list = []
self._kv_max_slots: int = 3 # overridden from config (kv_cache_slots)
self._kv_min_reuse: int = 16 # don't bother reusing a tiny prefix
self._kv_ttl: float = 300.0 # 5 min TTL self._kv_ttl: float = 300.0 # 5 min TTL
self._last_usage: Dict = {} self._last_usage: Dict = {}
# Serializes generation (prefix-cache build + model.generate) within this
# one instance. The instance pool can hand the same backend to two
# concurrent requests when all instances are busy; overlapping generate
# calls share one KV cache and would corrupt each other. A plain Lock (not
# RLock) is used so the streaming path can acquire it in a worker thread
# and release it from the event-loop thread.
self._gen_lock = threading.Lock()
def check_flash_attn_support(self) -> None: def check_flash_attn_support(self) -> None:
"""Check and print Flash Attention availability status.""" """Check and print Flash Attention availability status."""
...@@ -282,6 +294,34 @@ class NvidiaBackend(ModelBackend): ...@@ -282,6 +294,34 @@ class NvidiaBackend(ModelBackend):
return 2 return 2
return 4 return 4
def _warn_kv_quant_ignored(self):
    """Log a note when a user-requested KV-cache quantization is being dropped.

    The requested spec is the pending ``cache_type_k``, else the pending
    ``cache_type_v``. Nothing is printed when no quantized type was requested
    (empty, f16/fp16/bf16/f32, none, auto) or when ``_kv_quant_compatible()``
    reports the model supports a quantized KV cache. Otherwise a single line
    explains that the request is ignored and KV stays fp16, naming the
    architecture family as sliding-window/gemma or hybrid linear-attention.
    """
    requested = (
        getattr(self, '_pending_cache_type_k', None)
        or getattr(self, '_pending_cache_type_v', None)
        or ''
    )
    spec = str(requested).lower()

    # Specs that mean "no KV quantization was asked for".
    unquantized = ('', 'f16', 'fp16', 'bf16', 'f32', 'none', 'auto')
    if spec in unquantized:
        return
    # The request can be honoured, so there is nothing to warn about.
    if self._kv_quant_compatible():
        return

    if self._is_sliding_window_model():
        why = 'sliding-window/gemma'
    else:
        why = 'hybrid linear-attention'
    print(
        f" Note: KV-cache quant '{spec}' IGNORED — {why} models use HF "
        f"transformers' quanto/HQQ QuantizedCache, which crashes on this "
        f"architecture during generation, so KV stays fp16. q4_0-style KV "
        f"quantization only applies to GGUF models on the llama.cpp backend. "
        f"Lower n_ctx to shrink the KV reserve instead."
    )
def _kv_quant_compatible(self) -> bool: def _kv_quant_compatible(self) -> bool:
"""Whether the model supports transformers' quantized KV cache. """Whether the model supports transformers' quantized KV cache.
...@@ -583,8 +623,16 @@ class NvidiaBackend(ModelBackend): ...@@ -583,8 +623,16 @@ class NvidiaBackend(ModelBackend):
bnb_4bit_quant_type='nf4', bnb_4bit_quant_type='nf4',
bnb_4bit_compute_dtype=torch.float16, bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True, bnb_4bit_use_double_quant=True,
# Required when device_map spills modules to CPU/disk: without it
# bitsandbytes refuses any offloaded quantized model and aborts
# the load ("set llm_int8_enable_fp32_cpu_offload=True"). Keeps
# quantized modules on GPU while non-quantized ones go to CPU/fp32.
llm_int8_enable_fp32_cpu_offload=True,
)
return BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=True,
) )
return BitsAndBytesConfig(load_in_8bit=True)
def _is_moe_model(self, model_name: str) -> bool: def _is_moe_model(self, model_name: str) -> bool:
"""Check if model is a MoE model.""" """Check if model is a MoE model."""
...@@ -706,6 +754,11 @@ class NvidiaBackend(ModelBackend): ...@@ -706,6 +754,11 @@ class NvidiaBackend(ModelBackend):
offload_strategy = kwargs.get('offload_strategy', 'auto') offload_strategy = kwargs.get('offload_strategy', 'auto')
max_gpu_percent = kwargs.get('max_gpu_percent', None) max_gpu_percent = kwargs.get('max_gpu_percent', None)
expected_vram_gb = kwargs.get('expected_vram_gb') or 0 expected_vram_gb = kwargs.get('expected_vram_gb') or 0
# _get_model_used_vram_gb() always returns a PEAK estimate (a measured
# resident total, or weights + _runtime_reserve_gb), so the auto-fit check
# must not re-add the KV/activation reserve. Default True; a caller that
# ever passes a weights-only number can override with False.
expected_is_total = bool(kwargs.get('expected_vram_is_total', True))
# Check for --no-ram mode # Check for --no-ram mode
no_ram = kwargs.get('no_ram', False) no_ram = kwargs.get('no_ram', False)
...@@ -724,6 +777,45 @@ class NvidiaBackend(ModelBackend): ...@@ -724,6 +777,45 @@ class NvidiaBackend(ModelBackend):
self._pending_cache_type_k = kwargs.get('cache_type_k') self._pending_cache_type_k = kwargs.get('cache_type_k')
self._pending_cache_type_v = kwargs.get('cache_type_v') self._pending_cache_type_v = kwargs.get('cache_type_v')
# Per-model multi-slot prefix-cache size (number of warm conversations to
# keep). <=0 disables prefix caching for this model; unset keeps the
# default. Honours kwargs first, then the raw models.json entry.
_raw_cfg = kwargs.get('_raw_cfg') or {}
_slots = kwargs.get('kv_cache_slots', _raw_cfg.get('kv_cache_slots'))
if _slots is not None:
try:
self._kv_max_slots = max(0, int(_slots))
except (TypeError, ValueError):
pass
# GPTQ/AWQ fast-kernel path. quant_backend: auto|bnb|gptq|awq (default auto).
# When enabled AND a locally-quantized checkpoint already exists for this
# model, load THAT (it carries its own quant config and loads via Marlin/
# ExLlama through transformers' native path) instead of the fp16 source +
# bitsandbytes. No checkpoint → fall through to the bnb path unchanged
# (quantization itself is an explicit, on-demand admin action, never auto).
self._loaded_quant_backend = "bnb"
quant_backend = str(kwargs.get('quant_backend',
_raw_cfg.get('quant_backend') or 'auto')).lower()
if quant_backend in ('auto', 'gptq', 'awq'):
try:
from codai.models import quant as _quant
if _quant.is_available():
_methods = ('gptq', 'awq') if quant_backend == 'auto' else (quant_backend,)
for _m in _methods:
_ckpt = _quant.find_quantized_checkpoint(model_name, _m)
if _ckpt:
print(f"Using fast-kernel {_m.upper()} checkpoint: {_ckpt}")
model_name = _ckpt
load_in_4bit = load_in_8bit = False # checkpoint self-quantized
self._loaded_quant_backend = _m
break
elif quant_backend in ('gptq', 'awq'):
print(f" quant_backend={quant_backend} requested but GPTQModel/"
f"fast kernels unavailable — falling back to bitsandbytes")
except Exception as _qe:
print(f" GPTQ/AWQ checkpoint lookup failed ({_qe}); using bitsandbytes")
print(f"Loading HuggingFace model: {model_name}") print(f"Loading HuggingFace model: {model_name}")
# Flash-Attention-2 requires the ENTIRE model resident on a single CUDA # Flash-Attention-2 requires the ENTIRE model resident on a single CUDA
...@@ -906,16 +998,27 @@ class NvidiaBackend(ModelBackend): ...@@ -906,16 +998,27 @@ class NvidiaBackend(ModelBackend):
if torch.cuda.is_available() and expected_vram_gb > 0: if torch.cuda.is_available() and expected_vram_gb > 0:
_free, _ = torch.cuda.mem_get_info(0) _free, _ = torch.cuda.mem_get_info(0)
_free_gb = _free / 1e9 _free_gb = _free / 1e9
# expected_vram_gb is already a PEAK estimate: the measured
# path returns a real resident total (KV + activations
# included), and the estimate path adds _runtime_reserve_gb
# on top of the weights. Re-adding the KV reserve + a fixed
# activation pad here double-counted that headroom (e.g. a
# 24.3 GB measured total became a 28.6 GB "need"), which
# pushed models that actually fit straight into device_map
# CPU offload — catastrophic for MoE models (expert thrash).
# Compare the peak directly; the KV/activation pad is added only when the
# caller marks the figure as weights-only (expected_vram_is_total=False).
_kv_gb = self._kv_cache_reserve_bytes() / 1e9 _kv_gb = self._kv_cache_reserve_bytes() / 1e9
_act_gb = 1.5 if _kv_gb > 0 else 0.0 if expected_is_total:
_need_gb = expected_vram_gb + _kv_gb + _act_gb _need_gb = expected_vram_gb
else:
_need_gb = expected_vram_gb + _kv_gb + (1.5 if _kv_gb > 0 else 0.0)
_borderline = 3.0 if offload_strategy == 'auto-borderline' else 0.0 _borderline = 3.0 if offload_strategy == 'auto-borderline' else 0.0
_fits = _need_gb <= (_free_gb - 0.5 + _borderline) _fits = _need_gb <= (_free_gb - 0.5 + _borderline)
if _fits: if _fits:
print(f"\n Auto: peak VRAM need {_need_gb:.1f} GB " print(f"\n Auto: peak VRAM need {_need_gb:.1f} GB "
f"(weights {expected_vram_gb:.1f} + KV {_kv_gb:.1f} " f"({'measured total' if expected_is_total else f'weights {expected_vram_gb:.1f} + KV {_kv_gb:.1f}'}) "
f"+ act {_act_gb:.1f}) fits in {_free_gb:.1f} GB free " f"fits in {_free_gb:.1f} GB free — loading full-GPU (no offload)")
f"— loading full-GPU (no offload)")
else: else:
print(f"\n Auto: peak VRAM need {_need_gb:.1f} GB > " print(f"\n Auto: peak VRAM need {_need_gb:.1f} GB > "
f"{_free_gb:.1f} GB free — going straight to " f"{_free_gb:.1f} GB free — going straight to "
...@@ -1078,6 +1181,7 @@ class NvidiaBackend(ModelBackend): ...@@ -1078,6 +1181,7 @@ class NvidiaBackend(ModelBackend):
f"weight budget {weight_budget/1e9:.1f}→{new_budget/1e9:.1f}GB " f"weight budget {weight_budget/1e9:.1f}→{new_budget/1e9:.1f}GB "
f"(rest spills to CPU)" f"(rest spills to CPU)"
) )
self._warn_kv_quant_ignored()
weight_budget = new_budget weight_budget = new_budget
max_memory[i] = weight_budget max_memory[i] = weight_budget
...@@ -1257,7 +1361,7 @@ class NvidiaBackend(ModelBackend): ...@@ -1257,7 +1361,7 @@ class NvidiaBackend(ModelBackend):
gen_kwargs["repetition_penalty"] = max(presence_penalty, frequency_penalty) if max(presence_penalty, frequency_penalty) > 1.0 else 1.0 gen_kwargs["repetition_penalty"] = max(presence_penalty, frequency_penalty) if max(presence_penalty, frequency_penalty) > 1.0 else 1.0
try: try:
with torch.no_grad(): with self._gen_lock, torch.no_grad():
outputs = self.model.generate(**gen_kwargs) outputs = self.model.generate(**gen_kwargs)
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:] generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
...@@ -1268,7 +1372,7 @@ class NvidiaBackend(ModelBackend): ...@@ -1268,7 +1372,7 @@ class NvidiaBackend(ModelBackend):
print(f"Warning: CUDA OOM during generation. Clearing cache and retrying...") print(f"Warning: CUDA OOM during generation. Clearing cache and retrying...")
torch.cuda.empty_cache() torch.cuda.empty_cache()
try: try:
with torch.no_grad(): with self._gen_lock, torch.no_grad():
outputs = self.model.generate(**gen_kwargs) outputs = self.model.generate(**gen_kwargs)
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:] generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
...@@ -1433,6 +1537,7 @@ class NvidiaBackend(ModelBackend): ...@@ -1433,6 +1537,7 @@ class NvidiaBackend(ModelBackend):
def generate_with_error_handling(): def generate_with_error_handling():
nonlocal generation_error nonlocal generation_error
try: try:
with self._gen_lock:
self.model.generate(**generation_kwargs) self.model.generate(**generation_kwargs)
except (RuntimeError, torch.cuda.OutOfMemoryError) as e: except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
error_msg = str(e).lower() error_msg = str(e).lower()
...@@ -1487,12 +1592,6 @@ class NvidiaBackend(ModelBackend): ...@@ -1487,12 +1592,6 @@ class NvidiaBackend(ModelBackend):
# KV prefix cache helpers # KV prefix cache helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _kv_cache_valid(self) -> bool:
return (
self._kv_past_key_values is not None and
_time.time() - self._kv_timestamp < self._kv_ttl
)
def _model_on_cuda(self) -> bool: def _model_on_cuda(self) -> bool:
"""Return True only when the model's first parameter is actually on a CUDA device.""" """Return True only when the model's first parameter is actually on a CUDA device."""
try: try:
...@@ -1500,33 +1599,135 @@ class NvidiaBackend(ModelBackend): ...@@ -1500,33 +1599,135 @@ class NvidiaBackend(ModelBackend):
except StopIteration: except StopIteration:
return False return False
def _build_kv_prefix(self, prefix_text: str): @staticmethod
"""Forward-pass on prefix_text to populate the KV state.""" def _lcp_len(a: list, b: list) -> int:
"""Length of the longest common prefix of two token-id lists."""
n = 0
for x, y in zip(a, b):
if x != y:
break
n += 1
return n
def _clone_crop(self, cache, length: int):
"""Return a deep copy of ``cache`` truncated to ``length`` tokens.
Generation extends the cache it is handed in place, so we must never pass
a stored slot directly — clone it first and let the clone be mutated while
the stored prefix stays pristine for the next reuse.
"""
import copy
clone = copy.deepcopy(cache)
try:
if length < clone.get_seq_length():
clone.crop(length)
except Exception:
pass
return clone
def _build_kv_prefix(self, input_ids):
"""Prefill ``input_ids`` (a [1, P] tensor) and return (past_key_values, P).
Building from a slice of the request's own ``total_input_ids`` keeps the
stored prefix token-aligned with future prompts, so token-level prefix
matching is exact (no add_special_tokens / re-render drift).
"""
import torch import torch
# KV prefix caching requires CUDA tensors; skip on CPU-mode models.
if not self._model_on_cuda(): if not self._model_on_cuda():
raise RuntimeError("KV prefix cache requires CUDA; model is on CPU") raise RuntimeError("KV prefix cache requires CUDA; model is on CPU")
inputs = self.tokenizer( input_ids = input_ids.to(self.model.device)
attn = torch.ones_like(input_ids)
with self._gen_lock, torch.no_grad():
out = self.model(input_ids=input_ids, attention_mask=attn,
use_cache=True, return_dict=True)
return out.past_key_values, int(input_ids.shape[1])
def _kv_prune(self) -> None:
"""Drop expired slots (TTL) and enforce the slot-count cap (evict LRU)."""
now = _time.time()
self._kv_slots = [s for s in self._kv_slots if now - s["ts"] < self._kv_ttl]
while len(self._kv_slots) > max(0, self._kv_max_slots):
old = self._kv_slots.pop(0)
old.pop("cache", None)
def _store_kv(self, ids: list, past_kv, length: int) -> None:
"""Insert/refresh a slot for token prefix ``ids`` (most-recently-used)."""
self._kv_slots = [s for s in self._kv_slots if s["ids"] != ids[:s["length"]] or s["length"] != length]
self._kv_slots.append({
"ids": list(ids[:length]),
"cache": past_kv,
"length": int(length),
"ts": _time.time(),
})
self._kv_prune()
def _lookup_kv(self, total_ids: list):
"""Find the slot sharing the longest token prefix with ``total_ids``.
Returns (cloned_cache, matched_len) ready to hand to generate() with
input_ids = total_input_ids[:, matched_len:], or (None, 0) on a miss.
"""
self._kv_prune()
best = None
best_m = 0
for s in self._kv_slots:
m = self._lcp_len(total_ids, s["ids"])
if m > best_m:
best_m, best = m, s
# Need a non-trivial match that still leaves at least one token to decode.
if best is None or best_m < self._kv_min_reuse or best_m >= len(total_ids):
return None, 0
clone = self._clone_crop(best["cache"], best_m)
best["ts"] = _time.time() # mark reused (LRU freshness)
return clone, best_m
def _reuse_or_seed_prefix(self, total_input_ids, total_prompt_len, prefix_msgs,
enable_thinking=False, tools=None):
"""Return (past_kv, cached_len) for the request's KV prefix.
First tries a token-level match against the multi-slot cache (reuses an
existing conversation's KV, including after a mid-history edit). On a miss
it seeds a new slot by prefilling the message-boundary prefix, built from a
slice of ``total_input_ids`` so the stored tokens stay aligned with future
prompts. Returns (None, 0) when nothing usable could be prepared.
"""
if self._kv_max_slots <= 0:
return None, 0
total_ids = total_input_ids[0].tolist()
# 1) Reuse the warmest matching slot.
past_kv, cached_len = self._lookup_kv(total_ids)
if past_kv is not None:
return past_kv, cached_len
# 2) Seed: prefill up to the aligned message-boundary, then keep a pristine
# copy in a slot and hand a clone to the caller for generation.
try:
prefix_text = self._build_chat_prompt(
prefix_msgs, enable_thinking=enable_thinking,
add_generation_prompt=False, tools=tools)
rendered = self.tokenizer(
prefix_text, return_tensors="pt", add_special_tokens=False prefix_text, return_tensors="pt", add_special_tokens=False
) )['input_ids'][0].tolist()
inputs = {k: v.to(self.model.device) for k, v in inputs.items()} # Aligned boundary = how far the rendered prefix matches the full prompt.
with torch.no_grad(): boundary = min(self._lcp_len(total_ids, rendered), total_prompt_len - 1)
out = self.model(**inputs, use_cache=True, return_dict=True) if boundary < self._kv_min_reuse:
return out.past_key_values, int(inputs['input_ids'].shape[1]) return None, 0
built_kv, boundary = self._build_kv_prefix(total_input_ids[:, :boundary])
def _store_kv(self, prefix_text: str, past_kv, prefix_len: int) -> None: self._store_kv(total_ids, built_kv, boundary)
self._kv_prefix_text = prefix_text return self._clone_crop(built_kv, boundary), boundary
self._kv_past_key_values = past_kv except Exception as e:
self._kv_prefix_len = prefix_len print(f"Warning: KV prefix cache build failed: {e}")
self._kv_timestamp = _time.time() return None, 0
def invalidate_kv_cache(self) -> None:
    """Throw away every cached KV slot (call on model unload/swap).

    Each slot's ``cache`` entry is removed explicitly so the KV state is
    released even if something else still holds the slot dict.
    """
    stale, self._kv_slots = self._kv_slots, []
    for slot in stale:
        slot.pop("cache", None)
def clear_prefix_cache(self) -> None:
    """Drop all cached KV slots (called by the RAM/VRAM-pressure ladder).

    Takes the per-instance generation lock first, so the slots are not
    discarded while a forward pass on this backend is running.
    """
    guard = self._gen_lock
    with guard:
        self.invalidate_kv_cache()
def _kv_prefix_supported(self) -> bool: def _kv_prefix_supported(self) -> bool:
"""Whether this model can safely reuse a manually-prefilled KV cache. """Whether this model can safely reuse a manually-prefilled KV cache.
...@@ -1600,7 +1801,6 @@ class NvidiaBackend(ModelBackend): ...@@ -1600,7 +1801,6 @@ class NvidiaBackend(ModelBackend):
return free / 1e9 >= min_free_gb return free / 1e9 >= min_free_gb
except Exception: except Exception:
return False return False
self._kv_timestamp = 0.0
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Usage tracking # Usage tracking
...@@ -1887,18 +2087,9 @@ class NvidiaBackend(ModelBackend): ...@@ -1887,18 +2087,9 @@ class NvidiaBackend(ModelBackend):
cached_len = 0 cached_len = 0
if prefix_msgs and self._model_on_cuda() and self._kv_prefix_supported() and self._kv_prefix_headroom_ok(): if prefix_msgs and self._model_on_cuda() and self._kv_prefix_supported() and self._kv_prefix_headroom_ok():
prefix_text = self._build_chat_prompt( past_kv, cached_len = self._reuse_or_seed_prefix(
prefix_msgs, enable_thinking=enable_thinking, add_generation_prompt=False, tools=tools) total_input_ids, total_prompt_len, prefix_msgs,
if self._kv_cache_valid() and self._kv_prefix_text == prefix_text: enable_thinking=enable_thinking, tools=tools)
past_kv = self._kv_past_key_values
cached_len = self._kv_prefix_len
else:
try:
past_kv, cached_len = self._build_kv_prefix(prefix_text)
self._store_kv(prefix_text, past_kv, cached_len)
except Exception as e:
print(f"Warning: KV prefix cache build failed: {e}")
past_kv, cached_len = None, 0
temperature, top_p, do_sample = self._validate_params(temperature, top_p) temperature, top_p, do_sample = self._validate_params(temperature, top_p)
gen_kwargs = dict( gen_kwargs = dict(
...@@ -1925,7 +2116,7 @@ class NvidiaBackend(ModelBackend): ...@@ -1925,7 +2116,7 @@ class NvidiaBackend(ModelBackend):
full_attn = torch.ones( full_attn = torch.ones(
1, total_prompt_len, dtype=torch.long, device=self.model.device 1, total_prompt_len, dtype=torch.long, device=self.model.device
) )
with torch.no_grad(): with self._gen_lock, torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(
input_ids=suffix_ids, input_ids=suffix_ids,
past_key_values=past_kv, past_key_values=past_kv,
...@@ -1936,7 +2127,7 @@ class NvidiaBackend(ModelBackend): ...@@ -1936,7 +2127,7 @@ class NvidiaBackend(ModelBackend):
else: else:
cached_len = 0 cached_len = 0
attn_mask = torch.ones_like(total_input_ids) attn_mask = torch.ones_like(total_input_ids)
with torch.no_grad(): with self._gen_lock, torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(
input_ids=total_input_ids, input_ids=total_input_ids,
attention_mask=attn_mask, attention_mask=attn_mask,
...@@ -1958,7 +2149,7 @@ class NvidiaBackend(ModelBackend): ...@@ -1958,7 +2149,7 @@ class NvidiaBackend(ModelBackend):
full_prompt, return_tensors="pt" full_prompt, return_tensors="pt"
)['input_ids'].to(self.model.device) )['input_ids'].to(self.model.device)
attn_mask = torch.ones_like(total_input_ids) attn_mask = torch.ones_like(total_input_ids)
with torch.no_grad(): with self._gen_lock, torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(
input_ids=total_input_ids, input_ids=total_input_ids,
attention_mask=attn_mask, attention_mask=attn_mask,
...@@ -1976,7 +2167,7 @@ class NvidiaBackend(ModelBackend): ...@@ -1976,7 +2167,7 @@ class NvidiaBackend(ModelBackend):
full_prompt, return_tensors="pt" full_prompt, return_tensors="pt"
)['input_ids'].to(self.model.device) )['input_ids'].to(self.model.device)
attn_mask = torch.ones_like(total_input_ids) attn_mask = torch.ones_like(total_input_ids)
with torch.no_grad(): with self._gen_lock, torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(
input_ids=total_input_ids, input_ids=total_input_ids,
attention_mask=attn_mask, attention_mask=attn_mask,
...@@ -2032,18 +2223,9 @@ class NvidiaBackend(ModelBackend): ...@@ -2032,18 +2223,9 @@ class NvidiaBackend(ModelBackend):
cached_len = 0 cached_len = 0
if prefix_msgs and self._model_on_cuda() and self._kv_prefix_supported() and self._kv_prefix_headroom_ok(): if prefix_msgs and self._model_on_cuda() and self._kv_prefix_supported() and self._kv_prefix_headroom_ok():
prefix_text = self._build_chat_prompt( past_kv, cached_len = self._reuse_or_seed_prefix(
prefix_msgs, enable_thinking=enable_thinking, add_generation_prompt=False, tools=tools) total_input_ids, total_prompt_len, prefix_msgs,
if self._kv_cache_valid() and self._kv_prefix_text == prefix_text: enable_thinking=enable_thinking, tools=tools)
past_kv = self._kv_past_key_values
cached_len = self._kv_prefix_len
else:
try:
past_kv, cached_len = self._build_kv_prefix(prefix_text)
self._store_kv(prefix_text, past_kv, cached_len)
except Exception as e:
print(f"Warning: KV prefix cache build failed (stream): {e}")
past_kv, cached_len = None, 0
temperature, top_p, do_sample = self._validate_params(temperature, top_p) temperature, top_p, do_sample = self._validate_params(temperature, top_p)
...@@ -2122,7 +2304,7 @@ class NvidiaBackend(ModelBackend): ...@@ -2122,7 +2304,7 @@ class NvidiaBackend(ModelBackend):
def _run(): def _run():
try: try:
with torch.no_grad(): with self._gen_lock, torch.no_grad():
self.model.generate(**gen_kwargs) self.model.generate(**gen_kwargs)
except Exception as e: except Exception as e:
gen_error[0] = str(e) gen_error[0] = str(e)
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import os import os
import json import json
import threading
from typing import AsyncIterator, Optional, Union, List, Dict, Any from typing import AsyncIterator, Optional, Union, List, Dict, Any
from pathlib import Path from pathlib import Path
...@@ -143,7 +144,7 @@ def _install_layer_log_callback(): ...@@ -143,7 +144,7 @@ def _install_layer_log_callback():
return _cb # caller must hold this reference return _cb # caller must hold this reference
async def _aiter_blocking(sync_iter): async def _aiter_blocking(sync_iter, lock=None):
"""Bridge a blocking (sync) generator onto the asyncio event loop. """Bridge a blocking (sync) generator onto the asyncio event loop.
llama.cpp's create_(chat_)completion returns a *synchronous* generator whose llama.cpp's create_(chat_)completion returns a *synchronous* generator whose
...@@ -167,6 +168,13 @@ async def _aiter_blocking(sync_iter): ...@@ -167,6 +168,13 @@ async def _aiter_blocking(sync_iter):
except StopIteration: except StopIteration:
return _SENT return _SENT
# When a per-instance generation lock is supplied, hold it for the whole
# iteration so a concurrent request on the same backend can't interleave its
# forward passes into this one's KV cache. Acquired in a worker thread so a
# contended lock doesn't block the event loop; released from here (a plain
# threading.Lock permits cross-thread release).
if lock is not None:
await asyncio.to_thread(lock.acquire)
try: try:
while True: while True:
item = await asyncio.to_thread(_next) item = await asyncio.to_thread(_next)
...@@ -180,6 +188,11 @@ async def _aiter_blocking(sync_iter): ...@@ -180,6 +188,11 @@ async def _aiter_blocking(sync_iter):
close() close()
except Exception: except Exception:
pass pass
if lock is not None:
try:
lock.release()
except RuntimeError:
pass
class VulkanBackend(ModelBackend): class VulkanBackend(ModelBackend):
...@@ -198,6 +211,14 @@ class VulkanBackend(ModelBackend): ...@@ -198,6 +211,14 @@ class VulkanBackend(ModelBackend):
if self.force_cuda: if self.force_cuda:
print("DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)") print("DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)")
self._last_usage: dict = {} # usage from the most recent completion call self._last_usage: dict = {} # usage from the most recent completion call
# Serializes the synchronous forward pass within this one instance. The
# instance pool may hand the same backend to two concurrent requests when
# all instances are busy; overlapping create_completion calls share one KV
# cache and would corrupt each other. Distinct pool instances each have
# their own lock, so they still run in parallel. A plain Lock (not RLock)
# is used so the streaming path can acquire it in a worker thread and
# release it from the event-loop thread.
self._gen_lock = threading.Lock()
self._detect_chat_template() self._detect_chat_template()
def _detect_chat_template(self): def _detect_chat_template(self):
...@@ -772,10 +793,57 @@ class VulkanBackend(ModelBackend): ...@@ -772,10 +793,57 @@ class VulkanBackend(ModelBackend):
except Exception: except Exception:
pass pass
# Multi-slot prefix cache. llama-cpp-python's LlamaRAMCache keeps several
# past sequences' KV states in host RAM and, on each completion, reloads
# the one sharing the longest token prefix with the new prompt (see
# Llama._create_completion). This lets two interleaved conversations both
# stay "warm" instead of evicting each other — only the changed suffix is
# re-evaluated. Bounded by a per-model byte budget so it can't grow without
# limit (the states live in CPU RAM, copied back on a slot switch).
# kv_cache_budget_mb honours kwargs first, then the raw models.json entry,
# mirroring how cache_type_k / flash_attn are resolved above. <=0 disables
# it; unset falls back to a sensible default.
_DEFAULT_CACHE_MB = 512
_budget_mb = kwargs.get('kv_cache_budget_mb', _raw_cfg.get('kv_cache_budget_mb'))
self._setup_prefix_cache(_budget_mb if _budget_mb is not None else _DEFAULT_CACHE_MB)
# Try to detect and set up chat template # Try to detect and set up chat template
self._finalize_chat_template_detection() self._finalize_chat_template_detection()
print(f" chat_template: {self.chat_template}") print(f" chat_template: {self.chat_template}")
def _setup_prefix_cache(self, budget_mb) -> None:
"""Attach a multi-slot LlamaRAMCache sized to ``budget_mb`` megabytes."""
try:
budget = max(0, int(budget_mb)) * 1024 * 1024
except (TypeError, ValueError):
budget = 512 * 1024 * 1024
self._prefix_cache_budget = budget
if budget <= 0:
print(" Prefix cache : disabled (kv_cache_budget_mb=0)")
return
try:
from llama_cpp import LlamaRAMCache
self.model.set_cache(LlamaRAMCache(capacity_bytes=budget))
print(f" Prefix cache : multi-slot RAM cache, budget {budget // (1024*1024)} MB")
except Exception as e:
print(f" Prefix cache : could not enable ({e}); relying on single-slot prefix match")
def clear_prefix_cache(self) -> None:
    """Empty the host-RAM prefix cache (called by the RAM-pressure ladder).

    A brand-new LlamaRAMCache with the same byte budget replaces the current
    one, so only the stored sequences are dropped and caching keeps working
    afterwards. Does nothing when no budget was configured or no model is
    loaded. Any failure is ignored (best-effort reclaim).
    """
    capacity = getattr(self, '_prefix_cache_budget', 0)
    if not capacity:
        return
    if self.model is None:
        return
    try:
        from llama_cpp import LlamaRAMCache
        # Swap the cache under the generation lock so it is not replaced in
        # the middle of a completion on this instance.
        with self._gen_lock:
            self.model.set_cache(LlamaRAMCache(capacity_bytes=capacity))
    except Exception:
        pass
def generate( def generate(
self, self,
prompt: str, prompt: str,
...@@ -846,6 +914,7 @@ class VulkanBackend(ModelBackend): ...@@ -846,6 +914,7 @@ class VulkanBackend(ModelBackend):
pass pass
try: try:
with self._gen_lock:
result = self.model.create_completion( result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(), stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
...@@ -865,6 +934,7 @@ class VulkanBackend(ModelBackend): ...@@ -865,6 +934,7 @@ class VulkanBackend(ModelBackend):
if use_grammar: if use_grammar:
print(f"Warning: Grammar-guided generation failed: {e}, falling back to normal generation") print(f"Warning: Grammar-guided generation failed: {e}, falling back to normal generation")
try: try:
with self._gen_lock:
result = self.model.create_completion( result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(), stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
...@@ -963,7 +1033,7 @@ class VulkanBackend(ModelBackend): ...@@ -963,7 +1033,7 @@ class VulkanBackend(ModelBackend):
stop=stop, stop=stop,
stream=True, stream=True,
grammar=use_grammar, grammar=use_grammar,
)): ), lock=self._gen_lock):
text = chunk['choices'][0].get('text', '') text = chunk['choices'][0].get('text', '')
if first_chunk: if first_chunk:
...@@ -1002,7 +1072,7 @@ class VulkanBackend(ModelBackend): ...@@ -1002,7 +1072,7 @@ class VulkanBackend(ModelBackend):
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
stop=stop, stop=stop,
stream=True, stream=True,
)): ), lock=self._gen_lock):
text = chunk['choices'][0].get('text', '') text = chunk['choices'][0].get('text', '')
if first_chunk: if first_chunk:
...@@ -1071,7 +1141,7 @@ class VulkanBackend(ModelBackend): ...@@ -1071,7 +1141,7 @@ class VulkanBackend(ModelBackend):
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
stop=stop, stop=stop,
stream=True, stream=True,
)): ), lock=self._gen_lock):
text = chunk['choices'][0].get('text', '') text = chunk['choices'][0].get('text', '')
if first_chunk: if first_chunk:
...@@ -1089,6 +1159,7 @@ class VulkanBackend(ModelBackend): ...@@ -1089,6 +1159,7 @@ class VulkanBackend(ModelBackend):
return {"stream": generate_stream(), "content": ""} return {"stream": generate_stream(), "content": ""}
else: else:
with self._gen_lock:
result = self.model.create_completion( result = self.model.create_completion(
stopping_criteria=_make_llama_thermal_criteria(), stopping_criteria=_make_llama_thermal_criteria(),
prompt=prompt, prompt=prompt,
...@@ -1210,6 +1281,7 @@ class VulkanBackend(ModelBackend): ...@@ -1210,6 +1281,7 @@ class VulkanBackend(ModelBackend):
if _tc is not None: if _tc is not None:
kwargs['stopping_criteria'] = _tc kwargs['stopping_criteria'] = _tc
with self._gen_lock:
result = self.model.create_chat_completion(**kwargs) result = self.model.create_chat_completion(**kwargs)
usage = result.get('usage', {}) usage = result.get('usage', {})
self._store_usage( self._store_usage(
...@@ -1241,7 +1313,7 @@ class VulkanBackend(ModelBackend): ...@@ -1241,7 +1313,7 @@ class VulkanBackend(ModelBackend):
prompt_tokens = 0 prompt_tokens = 0
completion_tokens = 0 completion_tokens = 0
try: try:
async for chunk in _aiter_blocking(self.model.create_chat_completion(**kwargs)): async for chunk in _aiter_blocking(self.model.create_chat_completion(**kwargs), lock=self._gen_lock):
delta = chunk['choices'][0].get('delta', {}) delta = chunk['choices'][0].get('delta', {})
text = delta.get('content') or '' text = delta.get('content') or ''
if text: if text:
......
...@@ -84,6 +84,42 @@ def _trim_cpu_ram() -> None: ...@@ -84,6 +84,42 @@ def _trim_cpu_ram() -> None:
pass pass
def _drop_prefix_caches() -> int:
"""Release reclaimable KV prefix caches across all loaded text backends.
The GGUF LlamaRAMCache lives in host RAM and the HF KV slots in VRAM; both are
pure caches that can be rebuilt on demand. Dropping them is a cheap rung on the
RAM/VRAM-pressure ladder, run before evicting whole models. Returns the number
of backends whose cache was cleared.
"""
n = 0
try:
from codai.models.manager import multi_model_manager as _mm
except Exception:
return 0
def _clear(obj):
nonlocal n
backend = getattr(obj, 'backend', None)
fn = getattr(backend, 'clear_prefix_cache', None)
if callable(fn):
try:
fn()
n += 1
except Exception:
pass
try:
for pool in list(getattr(_mm, 'model_pools', {}).values()):
for inst in list(getattr(pool, 'instances', [])):
_clear(inst)
for obj in list(getattr(_mm, 'models', {}).values()):
_clear(obj)
except Exception:
pass
return n
class ModelManager: class ModelManager:
"""Manages the loaded model and tokenizer.""" """Manages the loaded model and tokenizer."""
...@@ -502,6 +538,10 @@ class ModelInstancePool: ...@@ -502,6 +538,10 @@ class ModelInstancePool:
self.ref_counts: list = [] self.ref_counts: list = []
self.max_instances: int = max_instances self.max_instances: int = max_instances
self._lock = threading.Lock() self._lock = threading.Lock()
# Session-affinity map: session_key -> instance index. Routes successive
# requests of one conversation back to the instance that already holds its
# warm KV prefix, instead of the purely load-based least-busy pick.
self._affinity: dict = {}
@property @property
def count(self) -> int: def count(self) -> int:
...@@ -519,15 +559,41 @@ class ModelInstancePool: ...@@ -519,15 +559,41 @@ class ModelInstancePool:
self.ref_counts.append(0) self.ref_counts.append(0)
return idx return idx
def acquire(self, session_key=None):
    """Return (idx, instance) of the chosen instance, incrementing its ref-count.

    When ``session_key`` is given, prefer the instance this session was last
    routed to (cache affinity) as long as it isn't markedly busier than the
    least-busy one; otherwise fall back to the least-busy instance and record
    the new mapping. With ``session_key=None`` the behaviour is unchanged
    (pure least-busy).

    Returns ``None`` when the pool holds no instances.
    """
    with self._lock:
        if not self.instances:
            return None
        least = min(range(len(self.instances)), key=lambda i: self.ref_counts[i])
        idx = least
        if session_key is not None:
            pinned = self._affinity.get(session_key)
            if pinned is not None and 0 <= pinned < len(self.instances):
                # Honour affinity unless the pinned instance is busier than the
                # least-busy one by more than one in-flight request (keeps a hot
                # cache without letting one instance pile up under load).
                if self.ref_counts[pinned] <= self.ref_counts[least] + 1:
                    idx = pinned
            # Remove before re-inserting so the key moves to the end of the
            # dict. _prune_affinity drops entries from the front, and plain
            # assignment to an existing key keeps its original position, which
            # would evict long-lived active sessions first.
            self._affinity.pop(session_key, None)
            self._affinity[session_key] = idx
            self._prune_affinity()
        self.ref_counts[idx] += 1
        return idx, self.instances[idx]
def _prune_affinity(self) -> None:
"""Bound the affinity map so it can't grow without limit (caller holds lock)."""
_CAP = 512
if len(self._affinity) > _CAP:
# Drop arbitrary oldest-inserted entries; dicts preserve insertion order.
for k in list(self._affinity.keys())[: len(self._affinity) - _CAP]:
self._affinity.pop(k, None)
def release(self, idx: int) -> None: def release(self, idx: int) -> None:
with self._lock: with self._lock:
if 0 <= idx < len(self.ref_counts): if 0 <= idx < len(self.ref_counts):
...@@ -3604,15 +3670,17 @@ class MultiModelManager: ...@@ -3604,15 +3670,17 @@ class MultiModelManager:
self.active_in_vram = key self.active_in_vram = key
self.models_in_vram.add(key) self.models_in_vram.add(key)
def acquire_model_instance(self, model_key: str): def acquire_model_instance(self, model_key: str, session_key=None):
"""Acquire the least-busy instance, incrementing its ref-count. """Acquire an instance, incrementing its ref-count.
Returns (instance_idx, model_obj) or None if no instance is loaded. ``session_key`` (optional) routes successive requests of one conversation
Callers MUST call release_model_instance() when done. back to the instance holding its warm KV prefix; without it the least-busy
instance is chosen. Returns (instance_idx, model_obj) or None if no
instance is loaded. Callers MUST call release_model_instance() when done.
""" """
pool = self.model_pools.get(model_key) pool = self.model_pools.get(model_key)
if pool and pool.count > 0: if pool and pool.count > 0:
return pool.acquire() return pool.acquire(session_key=session_key)
obj = self.models.get(model_key) obj = self.models.get(model_key)
if obj is not None: if obj is not None:
return 0, obj return 0, obj
......
"""GPTQ/AWQ fast-kernel quantization support for the HuggingFace (NvidiaBackend) path.
bitsandbytes NF4 is the slowest 4-bit option and especially hurts MoE models. This
module lets coderai (a) detect whether GPTQModel + fast kernels (Marlin/ExLlama) are
available, (b) resolve where a locally-quantized checkpoint for a model lives, and
(c) run an on-demand background job that quantizes a model to 4-bit GPTQ and caches
the result. Loading a produced (or otherwise pre-quantized) checkpoint then goes
through transformers' native quantization path, which picks the fast kernel.
The whole module degrades gracefully: if GPTQModel can't import, capability checks
return False and callers fall back to bitsandbytes.
"""
from __future__ import annotations
import os
import re
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, Any
from codai.models.cache import get_model_cache_dir
# --------------------------------------------------------------------------------
# Capability detection (cached — import probing is cheap but not free)
# --------------------------------------------------------------------------------
# Memoised probe result; None until the first successful call to capabilities().
_caps_cache: Optional[Dict[str, Any]] = None
# Guards the probe so concurrent first callers don't run it twice.
_caps_lock = threading.Lock()


def capabilities(refresh: bool = False) -> Dict[str, Any]:
    """Describe GPTQ/AWQ availability and the fast-kernel names found.

    The returned dict has four keys:

    * ``available`` -- True when ``gptqmodel`` imported.
    * ``version`` -- ``gptqmodel.__version__`` (``"?"`` if missing), else None.
    * ``backends`` -- names from a fixed wish-list that exist as attributes on
      ``gptqmodel.utils.backend.BACKEND``.
    * ``error`` -- text of the import failure, else None.

    The result is memoised; pass ``refresh=True`` to probe again.
    """
    global _caps_cache

    # Fast path without the lock.
    cached = _caps_cache
    if cached is not None and not refresh:
        return cached

    with _caps_lock:
        # Re-check: another thread may have filled the cache while we waited.
        cached = _caps_cache
        if cached is not None and not refresh:
            return cached

        report: Dict[str, Any] = dict(
            available=False, version=None, backends=[], error=None
        )
        try:
            import gptqmodel  # noqa: F401
            report["version"] = getattr(gptqmodel, "__version__", "?")
            # Which accelerated kernel names this gptqmodel build knows about.
            found: List[str] = []
            try:
                from gptqmodel.utils.backend import BACKEND
                for kernel in ("MARLIN", "EXLLAMA_V2", "EXLLAMA_V1", "TRITON",
                               "AWQ_MARLIN", "AWQ_GEMM"):
                    if hasattr(BACKEND, kernel):
                        found.append(kernel)
            except Exception:
                found = []
            report["backends"] = found
            report["available"] = True
        except Exception as e:  # ImportError or a broken transitive dep
            report["error"] = str(e)

        _caps_cache = report
        return report
def is_available() -> bool:
    """Report whether GPTQModel imported and at least one fast kernel was found."""
    caps = capabilities()
    if not caps["available"]:
        return False
    return bool(caps["backends"])
# --------------------------------------------------------------------------------
# Quantized-checkpoint locations
# --------------------------------------------------------------------------------
def _safe_slug(model_name: str) -> str:
"""Filesystem-safe slug for a model id/path."""
return re.sub(r"[^A-Za-z0-9._-]+", "_", str(model_name)).strip("_")
def quantized_checkpoint_dir(model_name: str, method: str = "gptq") -> Path:
    """Return ``<model cache>/quantized/<method>/<slug>`` for ``model_name``.

    ``method`` is lower-cased; the final component is ``_safe_slug(model_name)``.
    The directory is not created here.
    """
    cache_root = Path(get_model_cache_dir())
    return cache_root / "quantized" / method.lower() / _safe_slug(model_name)
def find_quantized_checkpoint(model_name: str, method: str = "gptq") -> Optional[str]:
    """Locate a locally-quantized checkpoint for ``model_name``.

    The checkpoint directory is accepted only when it exists, contains a
    ``config.json`` and holds at least one ``*.safetensors`` or ``*.bin`` file.

    Returns:
        The directory path as a string, or None when any of those is missing.
    """
    ckpt = quantized_checkpoint_dir(model_name, method)
    if not (ckpt.is_dir() and (ckpt / "config.json").is_file()):
        return None
    for pattern in ("*.safetensors", "*.bin"):
        if any(ckpt.glob(pattern)):
            return str(ckpt)
    return None
# --------------------------------------------------------------------------------
# Background quantization job
# --------------------------------------------------------------------------------
# Registry of quantization jobs: model_name -> job status dict.
_jobs: Dict[str, Dict[str, Any]] = {}
# Guards every read and write of _jobs.
_jobs_lock = threading.Lock()


def get_job(model_name: str) -> Optional[Dict[str, Any]]:
    """Return a copy of the job record for ``model_name``, or None if absent.

    The copy is shallow and taken under the registry lock, so callers can read
    it without racing the worker thread's updates.
    """
    with _jobs_lock:
        record = _jobs.get(model_name)
        if not record:
            return None
        return dict(record)
def all_jobs() -> Dict[str, Dict[str, Any]]:
    """Return a snapshot of every job record, keyed by model name.

    Each record is shallow-copied under the registry lock.
    """
    with _jobs_lock:
        snapshot: Dict[str, Dict[str, Any]] = {}
        for name, record in _jobs.items():
            snapshot[name] = dict(record)
        return snapshot
def _set_job(model_name: str, **fields) -> None:
with _jobs_lock:
j = _jobs.setdefault(model_name, {"model": model_name})
j.update(fields)
def start_quantization(model_name: str, method: str = "gptq", bits: int = 4,
                       group_size: int = 128) -> Dict[str, Any]:
    """Kick off (or report) a background quantization for ``model_name``.

    Returns the job status dict. Idempotent: if a job is already running or a
    checkpoint already exists, it returns that state instead of starting again.

    Order of checks:

    1. Quantizer unavailable -> ``{"status": "unavailable", ...}`` (nothing is
       recorded in the registry).
    2. A job for this model is running -> that job's record. This comes before
       the checkpoint probe: a running job that is midway through saving can
       already have a config and a weights file on disk, and must not be
       reported as finished.
    3. A ready checkpoint exists -> record and return ``status="done"``, with
       any ``error`` left over from an earlier failed run cleared.
    4. Otherwise register a new running job and start the worker thread.

    NOTE(review): jobs are keyed by ``model_name`` only, so a running job for
    one ``method`` also blocks a request for another method on the same model.
    """
    method = (method or "gptq").lower()
    if not is_available():
        return {"model": model_name, "status": "unavailable",
                "error": "GPTQModel / fast kernels not installed",
                "caps": capabilities()}

    with _jobs_lock:
        cur = _jobs.get(model_name)
        if cur and cur.get("status") == "running":
            return dict(cur)

    existing = find_quantized_checkpoint(model_name, method)
    if existing:
        _set_job(model_name, status="done", method=method, output=existing,
                 progress=1.0, message="already quantized", error=None)
        return get_job(model_name)

    with _jobs_lock:
        # Re-check: another caller may have started the job while the
        # checkpoint probe ran without the lock.
        cur = _jobs.get(model_name)
        if cur and cur.get("status") == "running":
            return dict(cur)
        _jobs[model_name] = {"model": model_name, "method": method, "bits": bits,
                             "status": "running", "progress": 0.0,
                             "message": "starting", "started": time.time(),
                             "error": None, "output": None}

    # get_job() takes _jobs_lock itself, so the thread is started and the
    # record read only after the lock has been released.
    t = threading.Thread(
        target=_quantize_worker,
        args=(model_name, method, bits, group_size),
        name=f"quantize-{_safe_slug(model_name)[:24]}",
        daemon=True,
    )
    t.start()
    return get_job(model_name)
def _quantize_worker(model_name: str, method: str, bits: int, group_size: int) -> None:
    """Run GPTQModel quantization and write the checkpoint to the cache dir.

    Heavy and slow (loads the source model, runs calibration). Runs in its own
    thread; never raises — failures are recorded on the job and the loader falls
    back to bitsandbytes.

    Args:
        model_name: Source model identifier; passed to ``GPTQModel.load`` and
            used as the job key for progress updates.
        method: Quantization method name; used here only to select the output
            directory.
        bits: Bit width handed to ``QuantizeConfig``.
        group_size: Group size handed to ``QuantizeConfig``.
    """
    try:
        # Resolved inside the try so that a failure to compute the output path
        # is recorded on the job instead of escaping the thread and leaving the
        # job stuck in "running".
        out_dir = quantized_checkpoint_dir(model_name, method)

        _set_job(model_name, message="loading quantizer", progress=0.02)
        from gptqmodel import GPTQModel, QuantizeConfig

        # Calibration data: a small generic text sample. Enough to populate the
        # GPTQ Hessian statistics without a domain-specific corpus.
        calib = _calibration_samples()

        _set_job(model_name, message="loading source model (this is slow)", progress=0.05)
        qcfg = QuantizeConfig(bits=bits, group_size=group_size)
        model = GPTQModel.load(model_name, qcfg)

        _set_job(model_name, message="quantizing (calibration passes)", progress=0.15)
        model.quantize(calib)

        out_dir.mkdir(parents=True, exist_ok=True)
        _set_job(model_name, message="saving checkpoint", progress=0.9)
        model.save(str(out_dir))

        # Persist the tokenizer alongside so the checkpoint loads standalone.
        # Best-effort: the quantized weights are already on disk, so a tokenizer
        # failure does not fail the job — but it is reported in the job message
        # rather than silently dropped.
        done_message = "quantization complete"
        try:
            from transformers import AutoTokenizer

            # NOTE(review): trust_remote_code=True lets the model repo run its
            # own tokenizer code; acceptable only for admin-selected models.
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            tokenizer.save_pretrained(str(out_dir))
        except Exception as tok_err:
            done_message = f"quantization complete (tokenizer not saved: {tok_err})"

        _set_job(model_name, status="done", progress=1.0,
                 message=done_message, output=str(out_dir),
                 finished=time.time())
    except Exception as e:
        # Likely causes: arch not supported by the quantizer, or OOM. Leave any
        # partial output in place but mark the job failed so the loader uses bnb.
        _set_job(model_name, status="failed", error=str(e),
                 message=f"quantization failed: {e}", finished=time.time())
def _calibration_samples() -> List[str]:
"""A small, generic calibration set for GPTQ Hessian estimation."""
base = [
"The quick brown fox jumps over the lazy dog.",
"In computer science, a hash table is a data structure that maps keys to values.",
"def fibonacci(n):\n a, b = 0, 1\n for _ in range(n):\n a, b = b, a + b\n return a",
"The mitochondria is the powerhouse of the cell, generating most of the cell's ATP.",
"Once upon a time, in a land far away, there lived a wise old programmer.",
"To be, or not to be, that is the question that has echoed through the centuries.",
"Machine learning models learn patterns from data to make predictions on unseen inputs.",
"The capital of France is Paris, a city known for its art, culture, and history.",
]
# GPTQModel wants a few hundred short rows; repeat the seed set.
return [base[i % len(base)] for i in range(256)]
...@@ -171,6 +171,14 @@ def _mitigate(rss_gb: float, cap_gb: float, leak: bool, loading: bool = False) - ...@@ -171,6 +171,14 @@ def _mitigate(rss_gb: float, cap_gb: float, leak: bool, loading: bool = False) -
actions.append("drop_upscalers") actions.append("drop_upscalers")
except Exception: except Exception:
pass pass
# Drop reclaimable KV prefix caches (host-RAM LlamaRAMCache + HF KV slots)
# before resorting to evicting whole models — they rebuild on demand.
try:
from codai.models.manager import _drop_prefix_caches
if _drop_prefix_caches():
actions.append("drop_prefix_caches")
except Exception:
pass
# Still over and eviction is enabled → unload idle LRU models. # Still over and eviction is enabled → unload idle LRU models.
try: try:
from codai.models.manager import multi_model_manager as _mm from codai.models.manager import multi_model_manager as _mm
......
...@@ -58,6 +58,23 @@ tiktoken>=0.5.0 ...@@ -58,6 +58,23 @@ tiktoken>=0.5.0
tokenizers>=0.15.0 tokenizers>=0.15.0
protobuf>=3.20.0 protobuf>=3.20.0
# Optional: fast-kernel 4-bit quantization (GPTQ/AWQ via Marlin/ExLlama).
# Much faster than bitsandbytes NF4, especially for MoE models. coderai uses this
# for the per-model `quant_backend` (auto|bnb|gptq|awq) and the on-demand
# "Quantize to 4-bit" button on the model admin page. The feature auto-detects
# GPTQModel and falls back to bitsandbytes when it (or its kernels) are absent.
#
# On bleeding-edge stacks (e.g. CUDA 13 + torch 2.11) install the LATEST from
# source for new-arch support (e.g. gemma-4 MoE). GPTQModel 7.x is a pure-Python
# wheel (no CUDA compile) but pins numpy==2.2.6 in its metadata; do NOT let that
# downgrade a newer numpy your stack needs — install with the deps below and a
# numpy constraint, e.g.:
# pip install --no-deps "git+https://github.com/ModelCloud/GPTQModel.git"
# pip install -c <(echo "numpy==<your-version>") \
# torchao device-smi tokenicer logbar pypcre defuser maturin
# gptqmodel>=7.1.0
# torchao>=0.16.0
# llama-cpp-python with CUDA support (for GGUF files on CUDA backend) # llama-cpp-python with CUDA support (for GGUF files on CUDA backend)
llama-cpp-python>=0.2.0 llama-cpp-python>=0.2.0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment