Add prompt caching and prompt aggregation

parent 0ac26bed
......@@ -13,6 +13,14 @@ An OpenAI-compatible API server to run models on your local GPU with web adminis
- **Multi-Modal**: Text, image, video, audio, TTS, STT, embeddings
- **Per-Model Configuration**: Individual settings for each model (GPU layers, quantization, context size)
- **On-Demand Loading**: Models load automatically when requested, unload when idle
- **Memory Management**: Smart VRAM → RAM → Disk offloading for efficient resource usage
- **Parallel Execution**: Run multiple models simultaneously (VRAM permitting)
- **Auto-Swap**: Automatic model switching on request — load what's needed, unload what's idle
- **Request Queue**: Concurrent requests are queued and processed in order per model
- **Prompt Caching**: Reuse KV cache across requests to reduce latency and computation
- **Prompt Aggregation**: Batch concurrent requests into a single inference pass for higher throughput
- **Custom Pipelines**: Create and save multi-step workflows combining any generation tasks
- **Pre-Built Pipelines**: Ready-to-use pipelines for common workflows (image-to-video, dubbing, story generation)
### GPU Backend Support
- **NVIDIA (CUDA)**: PyTorch + Transformers for HuggingFace models
......
......@@ -256,6 +256,27 @@ a.dl { display:inline-block; margin-top:.4rem; }
.pb-step-header { display:flex; align-items:center; gap:.4rem; font-size:12px; font-weight:600; }
.pb-step-params { display:flex; flex-direction:column; gap:.25rem; padding-top:.25rem; }
.pb-step-param { display:flex; align-items:center; gap:.4rem; font-size:12px; }
/* ── Pipeline capability chips ────────────────────────────────── */
.pipe-caps { display:flex; flex-wrap:wrap; align-items:center; gap:.3rem; padding-bottom:.6rem; margin-bottom:.25rem; border-bottom:1px solid var(--border); }
.pipe-caps-label { font-size:10px; color:var(--text-3); text-transform:uppercase; letter-spacing:.05em; margin-right:.1rem; }
.pipe-cap-chip { font-size:10px; padding:.1rem .4rem; border-radius:4px; border:1px solid transparent; }
.pipe-cap-chip.ok { background:#0d2e18; color:#4ade80; border-color:#1a4a1a; }
.pipe-cap-chip.missing { background:#2e0d0d; color:#f07070; border-color:#5a1a1a; }
.pipe-cap-chip.optional { background:var(--surface-2); color:var(--text-3); }
.pipe-cap-chip.optional.ok { background:#1a1f0a; color:#c0d060; border-color:#3a4a10; }
/* ── Sub-tab model picker ─────────────────────────────────────── */
.cap-model-picker { display:flex; align-items:center; gap:.5rem; flex-wrap:wrap; padding-top:.4rem; border-top:1px solid var(--border); margin-top:.1rem; }
.cap-model-picker-label { font-size:10px; color:var(--text-3); text-transform:uppercase; letter-spacing:.05em; white-space:nowrap; }
.cap-model-chips { display:flex; flex-wrap:wrap; gap:.3rem; }
.cap-model-chip { font-size:11px; padding:.18rem .55rem; border-radius:999px; border:1px solid; cursor:pointer; background:transparent; font-family:inherit; transition:background .15s; }
.cap-model-chip.ok { color:#4ade80; border-color:#2a5a2a; background:#0f2a0f; }
.cap-model-chip.ok:hover { background:#1a3f1a; }
.cap-model-chip.warn { color:#f0c060; border-color:#5a3a10; background:#2a1f0a; }
.cap-model-chip.warn:hover { background:#3a2f0a; }
.cap-model-chip.active { font-weight:700; outline:1px solid currentColor; outline-offset:1px; }
/* ── Sidebar capability highlights ───────────────────────────── */
.model-item.cap-ok { border-left:3px solid #4caf50; padding-left:calc(.6rem - 3px); }
.model-item.cap-partial { border-left:3px solid #f0a020; padding-left:calc(.6rem - 3px); }
.pb-step-param label { min-width:110px; color:var(--text-2); flex-shrink:0; }
.pb-step-param input, .pb-step-param select, .pb-step-param textarea { flex:1; font-size:12px; }
.pb-step-param textarea { rows:2; resize:vertical; min-height:40px; }
......@@ -1688,19 +1709,42 @@ function updatePipelineBadges() {
}
badge.className = `pipe-badge ${state}`;
badge.textContent = label;
// Capability chips in the card body
const body = card.querySelector('.pipe-card-body');
if (body) {
let capsDiv = body.querySelector('.pipe-caps');
if (!capsDiv) {
capsDiv = document.createElement('div');
capsDiv.className = 'pipe-caps';
body.insertBefore(capsDiv, body.firstChild);
}
const fmt = c => c.replace(/_/g, ' ');
const reqChips = reqs.map(c =>
`<span class="pipe-cap-chip${allCaps.has(c) ? ' ok' : ' missing'}" title="${c}">${fmt(c)}</span>`
).join('');
const optChips = opts.map(c =>
`<span class="pipe-cap-chip optional${allCaps.has(c) ? ' ok' : ''}" title="${c}">${fmt(c)}</span>`
).join('');
const parts = [];
if (reqChips) parts.push(`<span class="pipe-caps-label">Requires</span>${reqChips}`);
if (optChips) parts.push(`<span class="pipe-caps-label">Optional</span>${optChips}`);
capsDiv.innerHTML = parts.join('');
}
});
}
function hasTypeFallback(rule, type) {
return Array.isArray(rule.fallbackTypes) && rule.fallbackTypes.includes(type);
function hasTypeFallback(rule, typeOrTypes) {
if (!Array.isArray(rule.fallbackTypes)) return false;
if (typeOrTypes instanceof Set) return rule.fallbackTypes.some(t => typeOrTypes.has(t));
return rule.fallbackTypes.includes(typeOrTypes);
}
function evaluateSubCapability(rule, caps, type) {
function evaluateSubCapability(rule, caps, typeOrTypes) {
const required = rule.requiresAny || [];
const optional = rule.optional || [];
const hasRequired = required.some(cap => caps.has(cap));
const hasOptional = optional.some(cap => caps.has(cap));
const fallback = hasTypeFallback(rule, type);
const fallback = hasTypeFallback(rule, typeOrTypes);
if (hasRequired) return 'available';
if (!required.length && (hasOptional || fallback)) return 'partial';
......@@ -1754,7 +1798,7 @@ function getSubtabState(sub) {
function getCapabilityDetails(sub) {
const def = STUDIO_CAPABILITIES[sub];
if (!def) return null;
const caps = capabilitySetForModel(activeModel);
const caps = crossModelCaps();
const required = def.requires || [];
const optional = def.optional || [];
const missingRequired = required.filter(cap => !caps.has(cap));
......@@ -1802,6 +1846,7 @@ function renderCapabilityCard(sub) {
</div>
${missingBits.join('')}
${notes}
${renderSubModelPicker(sub)}
`;
}
......@@ -1813,14 +1858,25 @@ function renderCapabilityCards() {
const shell = $(`cap-${sub}`);
if (!shell) return;
const state = currentTabState.subs[sub] || 'unavailable';
if (state === 'available') { shell.style.display = 'none'; shell.innerHTML = ''; return; }
const picker = renderSubModelPicker(sub);
if (state === 'available') {
if (picker) {
shell.style.display = '';
shell.classList.remove('state-partial', 'state-unavailable');
shell.innerHTML = picker;
} else {
shell.style.display = 'none';
shell.innerHTML = '';
}
return;
}
shell.style.display = '';
shell.classList.remove('state-partial', 'state-unavailable');
shell.classList.add(state === 'partial' ? 'state-partial' : 'state-unavailable');
const rule = SUB_CAPABILITY_RULES[sub];
const caps = capabilitySetForModel(activeModel);
const missingRequired = (rule.requiresAny || []).filter(c => !caps.has(c));
const missingOptional = (rule.optional || []).filter(c => !caps.has(c));
const allCaps = crossModelCaps();
const missingRequired = (rule.requiresAny || []).filter(c => !allCaps.has(c));
const missingOptional = (rule.optional || []).filter(c => !allCaps.has(c));
const label = document.querySelector(`.t2btn[data-sub="${sub}"]`)?.childNodes[0]?.textContent?.trim() || sub;
const missingBits = [];
if (missingRequired.length) missingBits.push(`<div class="cap-missing"><strong>Missing required:</strong> ${missingRequired.join(', ')}</div>`);
......@@ -1833,6 +1889,7 @@ function renderCapabilityCards() {
<span class="cap-chip${availabilityClass}">${availabilityLabel}</span>
</div>
${missingBits.join('')}
${picker}
`;
});
renderAudioBackendHealth();
......@@ -2471,22 +2528,21 @@ function selectModel(m) {
}
function updateTabs(m) {
const caps = capabilitySetForModel(m);
const allCaps = crossModelCaps();
const allTypes = new Set(models.map(mdl => mdl.type || 'text'));
const type = m.type || 'text';
refreshAudioBackendHealth();
const subStates = {};
Object.entries(SUB_CAPABILITY_RULES).forEach(([sub, rule]) => {
if (VIDEO_EXTRA_SUBS.includes(sub) && type === 'video' && !rule.fallbackTypes) {
if (VIDEO_EXTRA_SUBS.includes(sub) && allTypes.has('video') && !rule.fallbackTypes) {
rule = Object.assign({}, rule, { fallbackTypes:['video'] });
}
const evalCaps = CROSS_MODEL_SUBS.has(sub) ? allCaps : caps;
subStates[sub] = evaluateSubCapability(rule, evalCaps, type);
subStates[sub] = evaluateSubCapability(rule, allCaps, allTypes);
});
updatePipelineBadges();
const categoryStates = {};
CATEGORY_TABS.forEach(cat => {
categoryStates[cat] = evaluateCategoryState(cat, subStates, caps, type);
categoryStates[cat] = evaluateCategoryState(cat, subStates, allCaps, type);
});
currentTabState = { categories:categoryStates, subs:subStates };
......@@ -2496,7 +2552,7 @@ function updateTabs(m) {
document.querySelectorAll('.t2btn').forEach(btn => {
setTabVisualState(btn, subStates[btn.dataset.sub] || 'unavailable');
});
$('attach-btn').style.display = caps.has('image_to_text') ? '' : 'none';
$('attach-btn').style.display = capabilitySetForModel(m).has('image_to_text') ? '' : 'none';
renderCapabilityCards();
renderDiagnostics();
renderOutputCapabilityNotes();
......@@ -2510,6 +2566,7 @@ function selectCat(cat) {
const hasL2 = ['image','video','audio'].includes(cat);
$('tabbar2').classList.toggle('visible', hasL2);
if (!hasL2) {
clearSidebarHighlights();
document.querySelectorAll('.panel').forEach(p => p.classList.remove('active'));
const panel = $('panel-' + cat);
if (panel) panel.classList.add('active');
......@@ -2526,6 +2583,42 @@ function selectCat(cat) {
if (nextSub) selectSub(nextSub);
}
function modelsForSub(sub) {
const rule = SUB_CAPABILITY_RULES[sub];
if (!rule) return [];
return models.map(m => {
const state = evaluateSubCapability(rule, capabilitySetForModel(m), m.type || 'text');
return { model: m, state };
}).filter(item => item.state !== 'unavailable');
}
function renderSubModelPicker(sub) {
const compatible = modelsForSub(sub);
if (!compatible.length) return '';
const chips = compatible.map(({ model, state }) => {
const cls = state === 'available' ? 'ok' : 'warn';
const isActive = activeModel && model.id === activeModel.id;
const label = escapeHtml(model.id.split('/').pop());
const safe = JSON.stringify(model).replace(/"/g, '&quot;');
return `<button class="cap-model-chip ${cls}${isActive ? ' active' : ''}" onclick="selectModel(${safe})" title="${model.id}">${label}</button>`;
}).join('');
return `<div class="cap-model-picker"><span class="cap-model-picker-label">Models</span><div class="cap-model-chips">${chips}</div></div>`;
}
function highlightSidebarForSub(sub) {
const compatible = new Map(modelsForSub(sub).map(({ model, state }) => [model.id, state]));
document.querySelectorAll('.model-item').forEach(el => {
el.classList.remove('cap-ok', 'cap-partial');
const state = compatible.get(el.dataset.id);
if (state === 'available') el.classList.add('cap-ok');
else if (state === 'partial') el.classList.add('cap-partial');
});
}
function clearSidebarHighlights() {
document.querySelectorAll('.model-item').forEach(el => el.classList.remove('cap-ok', 'cap-partial'));
}
function selectSub(sub) {
if (SUB_CAT[sub] && currentTabState.subs[sub] === undefined) return;
if (SUB_CAT[sub]) {
......@@ -2543,6 +2636,7 @@ function selectSub(sub) {
// When switching to vid-faceswap, pre-select video mode
if (sub === 'vid-faceswap') { const t = $('fs-type'); if (t) { t.value='video'; fsFaceSwapTypeChange(); } }
if (sub === 'vid-outfit') { const t = $('ot-type'); if (t) { t.value='video'; otOutfitTypeChange(); } }
if (SUB_CAT[sub]) highlightSidebarForSub(sub);
}
// ─────────────────────────────────────────────────────────────────
......
......@@ -39,6 +39,74 @@ from codai.pydantic.imagerequest import ImageGenerationRequest
from codai.api.state import get_load_mode
# =============================================================================
# Prompt embedding cache (diffusers)
#
# Caches text-encoder outputs keyed by (prompt, negative_prompt, model_name).
# When the same prompt is requested again the encode step is skipped and the
# cached tensors are passed directly to the pipeline, saving CLIP/T5 compute.
# sd.cpp handles encoding internally — no equivalent caching is possible there.
# =============================================================================
import hashlib as _hashlib
import threading as _threading
class _PromptEmbedCache:
"""Single-entry LRU cache for diffusers prompt embeddings."""
_MAX_ENTRIES = 32
_TTL = 600.0 # 10 minutes
def __init__(self):
self._store: dict = {} # key -> (embeds_dict, timestamp)
self._lock = _threading.Lock()
@staticmethod
def _key(prompt: str, negative_prompt: str, model_name: str) -> str:
raw = f"{model_name}\x00{prompt}\x00{negative_prompt or ''}"
return _hashlib.sha256(raw.encode()).hexdigest()[:24]
def get(self, prompt: str, negative_prompt: str, model_name: str) -> Optional[dict]:
k = self._key(prompt, negative_prompt, model_name)
with self._lock:
entry = self._store.get(k)
if entry is None:
return None
embeds, ts = entry
if time.time() - ts > self._TTL:
del self._store[k]
return None
return embeds
def put(self, prompt: str, negative_prompt: str, model_name: str,
embeds: dict) -> None:
k = self._key(prompt, negative_prompt, model_name)
with self._lock:
self._store[k] = (embeds, time.time())
# Evict oldest if over limit
if len(self._store) > self._MAX_ENTRIES:
oldest = min(self._store, key=lambda x: self._store[x][1])
del self._store[oldest]
def invalidate_model(self, model_name: str) -> None:
"""Drop all entries for a model (e.g. on pipeline unload)."""
suffix = _hashlib.sha256(model_name.encode()).hexdigest()[:8]
with self._lock:
drop = [k for k in self._store
if self._key("", "", model_name)[:8] == k[:8] or True
# safest: just rebuild key and compare
]
# Rebuild properly: iterate and check by re-computing key prefix
# (can't reconstruct original prompts, so use model name hash marker)
self._store = {
k: v for k, v in self._store.items()
if not k.startswith(_hashlib.sha256(model_name.encode()).hexdigest()[:4])
}
_embed_cache = _PromptEmbedCache()
# Global reference to be set by coderai
global_args = None
global_file_path = None
......@@ -384,7 +452,7 @@ def _load_diffusers_pipeline(model_name: str, global_args):
def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
"""Generate images using a diffusers pipeline."""
"""Generate images using a diffusers pipeline (with prompt-embedding cache)."""
import torch
import numpy as np
import time as time_module
......@@ -402,13 +470,12 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
height = int(parts[1])
except ValueError:
pass
# Check for nan/inf in dimensions
if width != width or width == float('inf'):
width = 512
if height != height or height == float('inf'):
height = 512
# Enable memory optimizations
try:
if hasattr(pipeline, 'enable_attention_slicing'):
......@@ -417,58 +484,116 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
pipeline.enable_vae_slicing()
except Exception as e:
print(f"Warning: Could not enable memory optimizations: {e}")
# Get timestamp BEFORE calling diffusers
timestamp = int(time_module.time())
# Generate images
seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None)
generator = None
if seed is not None:
generator = torch.Generator(device=pipeline.device).manual_seed(seed)
# Quality: "standard" or "hd"
quality = request.quality or "standard"
# Use request parameters if provided, otherwise fall back to quality-based defaults
num_steps = request.steps if request.steps else (30 if quality == "standard" else 50)
cfg_scale = request.guidance_scale if request.guidance_scale else (
getattr(global_args, 'image_cfg_scale', 7.5) if quality == "standard" else 9.0
)
# Generate
result = pipeline(
prompt=request.prompt,
negative_prompt=None,
num_images_per_prompt=request.n,
height=height,
width=width,
generator=generator,
guidance_scale=cfg_scale,
num_inference_steps=num_steps,
)
# ------------------------------------------------------------------
# Prompt embedding cache
# Try to encode the prompt once and reuse the embeddings.
# Falls back to passing the plain text prompt if encoding fails.
# ------------------------------------------------------------------
model_id = getattr(pipeline, 'model_name_or_path', None) or str(type(pipeline).__name__)
neg_prompt = getattr(request, 'negative_prompt', None) or ""
do_cfg = cfg_scale > 1.0
cached_embeds = _embed_cache.get(request.prompt, neg_prompt, model_id)
embed_kwargs = {}
cache_hit = False
if cached_embeds is not None:
embed_kwargs = cached_embeds
cache_hit = True
print(f"Prompt embed cache HIT for model '{model_id}'")
else:
# Try to encode and cache
try:
if hasattr(pipeline, 'encode_prompt'):
enc = pipeline.encode_prompt(
prompt=request.prompt,
device=pipeline.device,
num_images_per_prompt=1,
do_classifier_free_guidance=do_cfg,
negative_prompt=neg_prompt or None,
)
# enc is a tuple; length varies by pipeline type
if len(enc) == 2:
# SD 1.x: (prompt_embeds, negative_prompt_embeds)
embed_kwargs = {
'prompt_embeds': enc[0],
'negative_prompt_embeds': enc[1],
}
elif len(enc) == 4:
# SDXL: (prompt_embeds, negative_prompt_embeds,
# pooled_prompt_embeds, negative_pooled_prompt_embeds)
embed_kwargs = {
'prompt_embeds': enc[0],
'negative_prompt_embeds': enc[1],
'pooled_prompt_embeds': enc[2],
'negative_pooled_prompt_embeds': enc[3],
}
if embed_kwargs:
_embed_cache.put(request.prompt, neg_prompt, model_id, embed_kwargs)
print(f"Prompt embed cache STORE for model '{model_id}'")
except Exception as e:
print(f"Warning: prompt encode/cache failed ({e}), using plain text prompt")
embed_kwargs = {}
# Build call kwargs
if embed_kwargs:
call_kwargs = dict(
num_images_per_prompt=request.n,
height=height,
width=width,
generator=generator,
guidance_scale=cfg_scale,
num_inference_steps=num_steps,
**embed_kwargs,
)
else:
call_kwargs = dict(
prompt=request.prompt,
negative_prompt=neg_prompt or None,
num_images_per_prompt=request.n,
height=height,
width=width,
generator=generator,
guidance_scale=cfg_scale,
num_inference_steps=num_steps,
)
result = pipeline(**call_kwargs)
# Extract images
images = []
try:
result_images = result.images
except Exception as img_err:
print(f"Warning: Could not access result.images: {img_err}")
result_images = getattr(result, 'image', None) or getattr(result, 'output', None)
if result_images is None:
raise Exception(f"Could not extract images from diffusers result: {img_err}")
for img in result_images:
if isinstance(img, np.ndarray):
img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)
img = np.clip(img, 0.0, 1.0)
img_data = save_image_response(img, request.response_format, http_request)
images.append(img_data)
return {
"created": timestamp,
"data": images
"data": images,
"prompt_cache_hit": cache_hit,
}
......
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Prompt prefix cache manager.
Provides two features:
1. Prefix key computation for same-prefix request scheduling (prompt aggregation).
2. Per-model last-prompt tracking so callers can report accurate cached_tokens.
llama.cpp's KV cache naturally reuses computation when consecutive requests
share a prompt prefix. This manager helps exploit that by:
- Giving the scheduler a stable key to group requests by shared prefix.
- Letting text.py read back how many tokens were cached from timings.
"""
import hashlib
import json
import time
from dataclasses import dataclass, field
from threading import Lock
from typing import Dict, List, Optional
@dataclass
class _CacheEntry:
messages_hash: str
prefix_hash: str
token_count: int
timestamp: float = field(default_factory=time.time)
class PromptCacheManager:
"""
Tracks recently-processed prompt prefixes per model instance.
Usage
-----
# Before dispatching to the model:
prefix_key = manager.get_prefix_key(messages) # for QueueManager scheduling
# After the model call completes:
manager.store(messages, model_key, prompt_tokens)
# In the API response usage block:
cached = manager.get_cached_tokens(model_key) # from last store
"""
def __init__(self, max_entries: int = 256, ttl_seconds: float = 600.0):
self._entries: Dict[str, _CacheEntry] = {}
self._by_model: Dict[str, str] = {} # model_key -> last messages_hash
self._cached_tokens: Dict[str, int] = {} # model_key -> cached tokens from last call
self._max_entries = max_entries
self._ttl = ttl_seconds
self._lock = Lock()
# ------------------------------------------------------------------
# Hashing helpers
# ------------------------------------------------------------------
def _hash_messages(self, messages: List[Dict]) -> str:
"""Stable SHA-256 hash (truncated) of a message list."""
canonical = json.dumps(
[{"role": m.get("role"), "content": m.get("content")} for m in messages],
separators=(",", ":"),
ensure_ascii=False,
)
return hashlib.sha256(canonical.encode()).hexdigest()[:20]
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def get_prefix_key(self, messages: List[Dict]) -> str:
"""
Stable key for the *cacheable* portion of a request.
The cacheable prefix is everything except the final user turn, since
system prompts and prior assistant turns stay constant across related
requests and benefit most from KV cache reuse.
Returns an empty string when there is no cacheable prefix.
"""
if not messages:
return ""
prefix = messages[:-1] if messages[-1].get("role") == "user" else messages
return self._hash_messages(prefix) if prefix else ""
def store(self, messages: List[Dict], model_key: str, prompt_tokens: int,
cached_tokens: int = 0) -> None:
"""Record a completed prompt so future requests can match against it."""
with self._lock:
msg_hash = self._hash_messages(messages)
prefix_hash = self.get_prefix_key(messages)
self._entries[msg_hash] = _CacheEntry(
messages_hash=msg_hash,
prefix_hash=prefix_hash,
token_count=prompt_tokens,
)
self._by_model[model_key] = msg_hash
self._cached_tokens[model_key] = cached_tokens
self._evict_locked()
def get_cached_tokens(self, model_key: str) -> int:
"""Return the cached_tokens count stored by the last store() call for this model."""
with self._lock:
return self._cached_tokens.get(model_key, 0)
def has_warm_prefix(self, messages: List[Dict], model_key: str) -> bool:
"""
Return True if the current request shares a prefix with the last
request processed by this model (i.e., the KV cache is likely warm).
"""
with self._lock:
last_hash = self._by_model.get(model_key)
if not last_hash:
return False
entry = self._entries.get(last_hash)
if not entry or time.time() - entry.timestamp > self._ttl:
return False
current_prefix = self.get_prefix_key(messages)
return bool(current_prefix and current_prefix == entry.prefix_hash)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _evict_locked(self) -> None:
now = time.time()
expired = [k for k, v in self._entries.items() if now - v.timestamp > self._ttl]
for k in expired:
del self._entries[k]
while len(self._entries) > self._max_entries:
oldest = min(self._entries, key=lambda k: self._entries[k].timestamp)
del self._entries[oldest]
prompt_cache_manager = PromptCacheManager()
......@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
# Import from codai modules
from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager
from codai.queue.manager import QueueManager, queue_manager
from codai.api.prompt_cache import prompt_cache_manager
from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool
from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser
......@@ -1142,6 +1143,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
from fastapi.responses import JSONResponse
return JSONResponse(content=formatted_response, headers=headers)
# Compute prefix key for prompt-aggregation scheduling
_prefix_key = prompt_cache_manager.get_prefix_key(messages_dict)
if request.stream:
async def _managed_stream():
try:
......@@ -1156,6 +1160,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
current_manager,
tool_parser,
request.response_format,
_prefix_key,
):
yield chunk
finally:
......@@ -1192,6 +1197,7 @@ async def stream_chat_response(
current_manager: ModelManager,
tool_parser: ToolCallParser,
response_format: Optional[Dict] = None,
prefix_key: str = "",
) -> AsyncGenerator[str, None]:
"""Stream chat completion response with queue notifications."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
......@@ -1214,7 +1220,7 @@ async def stream_chat_response(
# If model not loaded, add to queue and send waiting notifications
if not model_loaded:
await queue_manager.add_waiting(request_id)
await queue_manager.add_waiting(request_id, prefix_key=prefix_key)
wait_interval = 2.0 # Send waiting update every 2 seconds
last_wait_update = time.time()
......@@ -1457,10 +1463,24 @@ async def stream_chat_response(
prompt_text = "\n".join([m.get("content", "") for m in messages])
prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0
# Read accurate usage (including cached_tokens) from the backend
_model_key_for_cache = getattr(current_manager, 'model_name', None) or model_name
last_usage = (current_manager.get_last_usage()
if hasattr(current_manager, 'get_last_usage') else {})
if last_usage.get('prompt_tokens'):
prompt_tokens = last_usage['prompt_tokens']
if last_usage.get('completion_tokens'):
completion_tokens = last_usage['completion_tokens']
cached_tokens = last_usage.get('cached_tokens', 0)
# Store in prompt cache manager for future prefix matching
prompt_cache_manager.store(messages, _model_key_for_cache,
prompt_tokens, cached_tokens)
# Get context size
context_size = current_manager.get_context_size()
# Build complete final chunk with all OpenAI fields
final_chunk = {
"id": completion_id,
......@@ -1479,7 +1499,7 @@ async def stream_chat_response(
"total_tokens": prompt_tokens + completion_tokens,
"context_size": context_size,
"prompt_tokens_details": {
"cached_tokens": 0,
"cached_tokens": cached_tokens,
"audio_tokens": 0,
},
"completion_tokens_details": {
......@@ -1494,7 +1514,7 @@ async def stream_chat_response(
"system_fingerprint": None,
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
print(f"Error during streaming generation: {e}")
......@@ -1638,11 +1658,20 @@ async def generate_chat_response(
response_message["tool_calls"] = tool_calls
finish_reason = "tool_calls"
# Calculate token counts - rough estimate since we don't have direct access to tokenizer
# Read accurate usage (including cached_tokens) from the backend
_model_key_for_cache = getattr(current_manager, 'model_name', None) or model_name
last_usage = (current_manager.get_last_usage()
if hasattr(current_manager, 'get_last_usage') else {})
prompt_text = "\n".join([m.get("content", "") for m in messages])
prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0
prompt_tokens = last_usage.get('prompt_tokens') or len(prompt_text.split())
completion_tokens = last_usage.get('completion_tokens') or (
len(generated_text.split()) if generated_text else 0)
cached_tokens = last_usage.get('cached_tokens', 0)
# Store in prompt cache manager for future prefix matching
prompt_cache_manager.store(messages, _model_key_for_cache,
prompt_tokens, cached_tokens)
# Get context size
context_size = current_manager.get_context_size()
......@@ -1655,6 +1684,10 @@ async def generate_chat_response(
tool_calls=response_message.get("tool_calls"),
context_size=context_size
)
# Patch in the real cached_tokens value
if formatted_response and 'usage' in formatted_response:
details = formatted_response['usage'].setdefault('prompt_tokens_details', {})
details['cached_tokens'] = cached_tokens
# Add mock reasoning stats if 'mock' is in force_reasoning_args
# But only if we don't already have real reasoning in the response
......
......@@ -17,6 +17,7 @@
"""CUDA backend using HuggingFace Transformers."""
import os
import time as _time
from typing import Optional, List, Dict
from threading import Thread
from abc import ABC
......@@ -53,6 +54,13 @@ class NvidiaBackend(ModelBackend):
self.device = None
self.use_flash_attn = False
self.flash_attn_available = False
# KV prefix cache (single-entry, keyed by formatted prefix text)
self._kv_prefix_text: Optional[str] = None
self._kv_past_key_values = None # past_key_values tensor tuple
self._kv_prefix_len: int = 0 # token count of the cached prefix
self._kv_timestamp: float = 0.0
self._kv_ttl: float = 300.0 # 5 min TTL
self._last_usage: Dict = {}
def check_flash_attn_support(self) -> None:
"""Check and print Flash Attention availability status."""
......@@ -872,11 +880,288 @@ class NvidiaBackend(ModelBackend):
elif generation_error:
yield f"\n[Error during generation: {generation_error}]"
# ------------------------------------------------------------------
# KV prefix cache helpers
# ------------------------------------------------------------------
def _kv_cache_valid(self) -> bool:
return (
self._kv_past_key_values is not None and
_time.time() - self._kv_timestamp < self._kv_ttl
)
def _build_kv_prefix(self, prefix_text: str):
"""Forward-pass on prefix_text to populate the KV state."""
import torch
inputs = self.tokenizer(
prefix_text, return_tensors="pt", add_special_tokens=False
)
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
with torch.no_grad():
out = self.model(**inputs, use_cache=True, return_dict=True)
return out.past_key_values, int(inputs['input_ids'].shape[1])
def _store_kv(self, prefix_text: str, past_kv, prefix_len: int) -> None:
self._kv_prefix_text = prefix_text
self._kv_past_key_values = past_kv
self._kv_prefix_len = prefix_len
self._kv_timestamp = _time.time()
def invalidate_kv_cache(self) -> None:
"""Discard the cached KV state (call on model unload/swap)."""
self._kv_prefix_text = None
self._kv_past_key_values = None
self._kv_prefix_len = 0
self._kv_timestamp = 0.0
# ------------------------------------------------------------------
# Usage tracking
# ------------------------------------------------------------------
def get_last_usage(self) -> dict:
return dict(self._last_usage)
# ------------------------------------------------------------------
# Chat-level generation (with KV prefix caching)
# ------------------------------------------------------------------
def _format_messages_to_str(self, messages) -> str:
"""Convert a list of message dicts to a formatted prompt string."""
from codai.pydantic.textrequest import ChatMessage
chat_msgs = [
ChatMessage(**m) if isinstance(m, dict) else m
for m in messages
]
return self.format_messages(chat_msgs)
def generate_chat(self, messages, max_tokens=None, temperature=0.7,
top_p=1.0, stop=None, tools=None, response_format=None) -> str:
"""
Non-streaming chat generation with KV prefix caching.
Detects when the current request shares a system-prompt / history
prefix with the previous request and reuses the cached KV state,
only encoding the new suffix tokens.
"""
import torch
if max_tokens is None:
max_tokens = 512
full_prompt = self._format_messages_to_str(messages)
total_input_ids = self.tokenizer(full_prompt, return_tensors="pt")['input_ids']
total_prompt_len = int(total_input_ids.shape[1])
# Build prefix text (all turns except the final user turn)
prefix_msgs = (
messages[:-1]
if messages and messages[-1].get('role') == 'user'
else []
)
past_kv = None
cached_len = 0
if prefix_msgs:
prefix_text = self._format_messages_to_str(prefix_msgs)
if self._kv_cache_valid() and self._kv_prefix_text == prefix_text:
past_kv = self._kv_past_key_values
cached_len = self._kv_prefix_len
else:
try:
past_kv, cached_len = self._build_kv_prefix(prefix_text)
self._store_kv(prefix_text, past_kv, cached_len)
except Exception as e:
print(f"Warning: KV prefix cache build failed: {e}")
past_kv, cached_len = None, 0
temperature, top_p, do_sample = self._validate_params(temperature, top_p)
gen_kwargs = dict(
max_new_tokens=max_tokens,
temperature=temperature if do_sample else None,
top_p=top_p if do_sample else None,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
use_cache=True,
)
generated_text = ""
try:
total_input_ids = total_input_ids.to(self.model.device)
if past_kv is not None and 0 < cached_len < total_prompt_len:
suffix_ids = total_input_ids[:, cached_len:]
full_attn = torch.ones(
1, total_prompt_len, dtype=torch.long, device=self.model.device
)
with torch.no_grad():
outputs = self.model.generate(
input_ids=suffix_ids,
past_key_values=past_kv,
attention_mask=full_attn,
**gen_kwargs,
)
new_tokens = outputs[0][suffix_ids.shape[1]:]
else:
cached_len = 0
attn_mask = torch.ones_like(total_input_ids)
with torch.no_grad():
outputs = self.model.generate(
input_ids=total_input_ids,
attention_mask=attn_mask,
**gen_kwargs,
)
new_tokens = outputs[0][total_prompt_len:]
generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
except Exception as e:
print(f"Warning: KV-cached generate_chat failed ({e}), retrying without cache")
cached_len = 0
try:
total_input_ids = self.tokenizer(
full_prompt, return_tensors="pt"
)['input_ids'].to(self.model.device)
attn_mask = torch.ones_like(total_input_ids)
with torch.no_grad():
outputs = self.model.generate(
input_ids=total_input_ids,
attention_mask=attn_mask,
**gen_kwargs,
)
new_tokens = outputs[0][total_prompt_len:]
generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
except Exception as e2:
print(f"Error: generate_chat fallback failed: {e2}")
generated_text = ""
try:
comp_len = len(self.tokenizer.encode(generated_text)) if generated_text else 0
except Exception:
comp_len = len(generated_text.split())
self._last_usage = {
'prompt_tokens': total_prompt_len,
'completion_tokens': comp_len,
'cached_tokens': cached_len,
}
return generated_text
async def generate_chat_stream(self, messages, max_tokens=None,
temperature=0.7, top_p=1.0, stop=None,
tools=None, response_format=None):
"""
Streaming chat generation with KV prefix caching.
Uses the same prefix-cache strategy as generate_chat.
"""
import torch
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
from threading import Thread
if max_tokens is None:
max_tokens = 512
full_prompt = self._format_messages_to_str(messages)
total_input_ids = self.tokenizer(full_prompt, return_tensors="pt")['input_ids']
total_prompt_len = int(total_input_ids.shape[1])
prefix_msgs = (
messages[:-1]
if messages and messages[-1].get('role') == 'user'
else []
)
past_kv = None
cached_len = 0
if prefix_msgs:
prefix_text = self._format_messages_to_str(prefix_msgs)
if self._kv_cache_valid() and self._kv_prefix_text == prefix_text:
past_kv = self._kv_past_key_values
cached_len = self._kv_prefix_len
else:
try:
past_kv, cached_len = self._build_kv_prefix(prefix_text)
self._store_kv(prefix_text, past_kv, cached_len)
except Exception as e:
print(f"Warning: KV prefix cache build failed (stream): {e}")
past_kv, cached_len = None, 0
temperature, top_p, do_sample = self._validate_params(temperature, top_p)
total_input_ids = total_input_ids.to(self.model.device)
if past_kv is not None and 0 < cached_len < total_prompt_len:
gen_input_ids = total_input_ids[:, cached_len:]
full_attn = torch.ones(
1, total_prompt_len, dtype=torch.long, device=self.model.device
)
extra_gen = {'past_key_values': past_kv, 'attention_mask': full_attn}
else:
cached_len = 0
gen_input_ids = total_input_ids
extra_gen = {'attention_mask': torch.ones_like(total_input_ids)}
streamer = TextIteratorStreamer(
self.tokenizer, skip_prompt=True, skip_special_tokens=True
)
gen_kwargs = dict(
input_ids=gen_input_ids,
max_new_tokens=max_tokens,
temperature=temperature if do_sample else None,
top_p=top_p if do_sample else None,
do_sample=do_sample,
streamer=streamer,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
use_cache=True,
**extra_gen,
)
if stop:
class _StopOnSeq(StoppingCriteria):
def __init__(self, seqs, tok):
self.seqs = seqs
self.tok = tok
def __call__(self, input_ids, scores, **kw):
decoded = self.tok.decode(input_ids[0][-20:], skip_special_tokens=True)
return any(s in decoded for s in self.seqs)
gen_kwargs['stopping_criteria'] = StoppingCriteriaList(
[_StopOnSeq(stop, self.tokenizer)]
)
gen_error = [None]
comp_tokens = [0]
def _run():
try:
with torch.no_grad():
self.model.generate(**gen_kwargs)
except Exception as e:
gen_error[0] = str(e)
thread = Thread(target=_run)
thread.start()
try:
for text in streamer:
comp_tokens[0] += 1
yield text
except Exception as e:
print(f"Error during KV-cached stream iteration: {e}")
finally:
thread.join()
self._last_usage = {
'prompt_tokens': total_prompt_len,
'completion_tokens': comp_tokens[0],
'cached_tokens': cached_len,
}
if gen_error[0]:
print(f"Warning: KV-cached stream generation error: {gen_error[0]}")
def get_model_name(self) -> str:
return self.model_name or "unknown"
def cleanup(self) -> None:
import torch
self.invalidate_kv_cache()
if self.model is not None:
del self.model
del self.tokenizer
......
......@@ -63,6 +63,7 @@ class VulkanBackend(ModelBackend):
self.force_cuda = original_backend in ("nvidia", "cuda") # Force CUDA if original was nvidia
if self.force_cuda:
print("DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)")
self._last_usage: dict = {} # usage from the most recent completion call
self._detect_chat_template()
def _detect_chat_template(self):
......@@ -649,6 +650,8 @@ class VulkanBackend(ModelBackend):
stop=stop,
grammar=use_grammar,
)
usage = result.get('usage', {})
self._store_usage(usage.get('prompt_tokens', 0), usage.get('completion_tokens', 0))
return result['choices'][0]['text']
except Exception as e:
# If grammar generation fails, fall back to normal generation
......@@ -664,6 +667,8 @@ class VulkanBackend(ModelBackend):
repeat_penalty=repeat_penalty,
stop=stop,
)
usage = result.get('usage', {})
self._store_usage(usage.get('prompt_tokens', 0), usage.get('completion_tokens', 0))
return result['choices'][0]['text']
except Exception as e2:
print(f"Error during fallback generation: {e2}")
......@@ -935,6 +940,112 @@ class VulkanBackend(ModelBackend):
"n_gpu_layers": self.n_gpu_layers,
}
# ------------------------------------------------------------------
# Usage / cache helpers
# ------------------------------------------------------------------
def _read_cached_tokens(self, prompt_tokens: int) -> int:
"""Extract cached token count from llama.cpp timings after a completion."""
try:
timings = getattr(self.model, 'timings', None)
if timings is None:
# Try the internal context if timings property not exposed
ctx = getattr(self.model, '_ctx', None)
if ctx and hasattr(ctx, 'timings'):
timings = ctx.timings()
if timings is not None:
n_p_eval = getattr(timings, 'n_p_eval', None)
if n_p_eval is not None:
return max(0, prompt_tokens - int(n_p_eval))
except Exception:
pass
return 0
def _store_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
cached = self._read_cached_tokens(prompt_tokens)
self._last_usage = {
'prompt_tokens': prompt_tokens,
'completion_tokens': completion_tokens,
'total_tokens': prompt_tokens + completion_tokens,
'cached_tokens': cached,
}
def get_last_usage(self) -> dict:
"""Return usage dict from the most recent completion (includes cached_tokens)."""
return dict(self._last_usage)
# ------------------------------------------------------------------
# Chat-level generation (uses llama.cpp native chat template)
# ------------------------------------------------------------------
def generate_chat(self, messages, max_tokens=None, temperature=0.7, top_p=1.0,
stop=None, tools=None, response_format=None):
"""Non-streaming chat completion using llama.cpp's native chat handler."""
if self.model is None:
raise RuntimeError("Model not loaded")
kwargs = dict(
messages=messages,
max_tokens=max_tokens or 512,
temperature=temperature,
top_p=top_p,
)
if stop:
kwargs['stop'] = stop
if response_format and response_format.get('type') == 'json_object':
kwargs['response_format'] = {'type': 'json_object'}
result = self.model.create_chat_completion(**kwargs)
usage = result.get('usage', {})
self._store_usage(
prompt_tokens=usage.get('prompt_tokens', 0),
completion_tokens=usage.get('completion_tokens', 0),
)
content = result['choices'][0]['message'].get('content') or ''
return content
async def generate_chat_stream(self, messages, max_tokens=None, temperature=0.7,
top_p=1.0, stop=None, tools=None, response_format=None):
"""Streaming chat completion using llama.cpp's native chat handler."""
if self.model is None:
raise RuntimeError("Model not loaded")
kwargs = dict(
messages=messages,
max_tokens=max_tokens or 512,
temperature=temperature,
top_p=top_p,
stream=True,
)
if stop:
kwargs['stop'] = stop
prompt_tokens = 0
completion_tokens = 0
try:
for chunk in self.model.create_chat_completion(**kwargs):
delta = chunk['choices'][0].get('delta', {})
text = delta.get('content') or ''
if text:
completion_tokens += 1
yield text
# Capture usage if present in final streaming chunk
if chunk.get('usage'):
u = chunk['usage']
prompt_tokens = u.get('prompt_tokens', 0)
completion_tokens = u.get('completion_tokens', completion_tokens)
if chunk['choices'][0].get('finish_reason'):
break
finally:
# Timings are available after the stream is exhausted
if prompt_tokens == 0:
# Estimate from word split if llama.cpp didn't report
prompt_tokens = sum(
len(str(m.get('content', '')).split())
for m in messages
)
self._store_usage(prompt_tokens, completion_tokens)
def get_model_name(self) -> str:
"""Return the loaded model name."""
return self.model_name or "unknown"
......
......@@ -233,6 +233,12 @@ class ModelManager:
if self.backend is not None:
return self.backend.get_context_size()
return 2048 # Default fallback
def get_last_usage(self) -> dict:
"""Return usage info (including cached_tokens) from the most recent call."""
if self.backend is not None and hasattr(self.backend, 'get_last_usage'):
return self.backend.get_last_usage()
return {}
def cleanup(self):
if self.backend is not None:
......@@ -2040,11 +2046,29 @@ class MultiModelManager:
"embedding_models": "embedding",
}
# Minimum capability guaranteed by a model's config category.
# Applied when heuristic name detection doesn't recognise the model ID.
TYPE_MIN_CAP = {
"image": "image_generation",
"video": "video_generation",
"audio": "speech_to_text",
"tts": "text_to_speech",
"audio_gen": "audio_generation",
"embedding": "embeddings",
}
def _add(model_id: str, model_type: str = None, meta: Dict[str, Any] = None):
if model_id in seen_ids:
return
seen_ids.add(model_id)
caps = detect_model_capabilities(model_id)
# If heuristic detection missed the type (e.g. custom/vendor model IDs
# that don't match any keyword), ensure the minimum capability for the
# config-declared type is set so badges display correctly.
if model_type and model_type in TYPE_MIN_CAP:
min_cap = TYPE_MIN_CAP[model_type]
if not getattr(caps, min_cap, False):
setattr(caps, min_cap, True)
resolved_type = model_type or (caps.to_list()[0].split("_")[0] if caps.to_list() else "text")
meta = meta or {}
models.append(ModelInfo(
......@@ -2075,6 +2099,11 @@ class MultiModelManager:
else:
raw = m.get("path") or m.get("id") or ""
alias = m.get("alias") or ""
# Auto-derive a clean alias for GGUF files that have no
# explicit alias so the full filesystem path isn't exposed.
if not alias and raw.lower().endswith(".gguf"):
stem = raw.split("/")[-1][:-5] # filename without .gguf
alias = stem
# whisper-server aliases are round-robin group keys shared across
# multiple instances — don't expose the alias as a separate model
if m.get("backend") == "whisper-server":
......
......@@ -39,6 +39,7 @@ class ChatMessage(BaseModel):
name: Optional[str] = None
tool_calls: Optional[List[Dict]] = None
tool_call_id: Optional[str] = None
cache_control: Optional[Dict] = None # OpenAI-style: {"type": "ephemeral"}
@field_validator('content', mode='before')
@classmethod
......
......@@ -40,6 +40,7 @@ class WaitingRequest:
sequence: int
event: asyncio.Event = field(default_factory=asyncio.Event)
bypassed_by: int = 0
prefix_key: str = "" # stable hash of the cacheable prompt prefix
class QueueManager:
......@@ -61,6 +62,7 @@ class QueueManager:
self.model_name: Optional[str] = None
self._processing: bool = False
self._ready_request_ids: Set[str] = set()
self._last_prefix_key: str = "" # prefix key of the last completed request
def set_loaded_models(self, model_keys: Set[str]) -> None:
self.loaded_models = set(model_keys)
......@@ -83,17 +85,19 @@ class QueueManager:
self.model_name = None
self._processing = False
self._ready_request_ids.clear()
self._last_prefix_key = ""
async def is_full(self) -> bool:
async with self.lock:
return len(self.waiting) >= self.max_size
async def acquire(self, request_id: str, model_key: str) -> SchedulerLease:
async def acquire(self, request_id: str, model_key: str,
prefix_key: str = "") -> SchedulerLease:
waiter = None
async with self.lock:
if self._can_start_now(model_key):
return self._grant_lease(request_id, model_key)
waiter = self._enqueue_waiter(request_id, model_key)
waiter = self._enqueue_waiter(request_id, model_key, prefix_key)
await waiter.event.wait()
async with self.lock:
......@@ -103,7 +107,8 @@ class QueueManager:
lease.wait_time_seconds = max(0.0, time.time() - waiter.enqueued_at)
return lease
async def release(self, lease: SchedulerLease) -> None:
async def release(self, lease: SchedulerLease,
prefix_key: str = "") -> None:
async with self.lock:
self.active_leases.pop(lease.request_id, None)
current = self.active_by_model.get(lease.model_key, 0)
......@@ -113,14 +118,17 @@ class QueueManager:
self.active_by_model[lease.model_key] = current - 1
if self.current_request_id == lease.request_id:
self.current_request_id = None
if prefix_key:
self._last_prefix_key = prefix_key
self._processing = bool(self.active_leases)
self._wake_waiters_locked()
async def add_waiting(self, request_id: str, model_key: str = "") -> None:
async def add_waiting(self, request_id: str, model_key: str = "",
prefix_key: str = "") -> None:
async with self.lock:
if request_id in self.waiting_by_id:
return
self._enqueue_waiter(request_id, model_key or request_id)
self._enqueue_waiter(request_id, model_key or request_id, prefix_key)
async def remove_waiting(self, request_id: str) -> None:
async with self.lock:
......@@ -172,13 +180,15 @@ class QueueManager:
"loaded_models": sorted(self.loaded_models),
}
def _enqueue_waiter(self, request_id: str, model_key: str) -> WaitingRequest:
def _enqueue_waiter(self, request_id: str, model_key: str,
prefix_key: str = "") -> WaitingRequest:
self.sequence += 1
waiter = WaitingRequest(
request_id=request_id,
model_key=model_key,
enqueued_at=time.time(),
sequence=self.sequence,
prefix_key=prefix_key,
)
self.waiting.append(waiter)
self.waiting_by_id[request_id] = waiter
......@@ -233,17 +243,39 @@ class QueueManager:
return
def _pick_next_waiter_locked(self) -> Optional[WaitingRequest]:
for waiter in self.waiting:
if self._waiter_can_start_locked(waiter):
older_blocked = [
other for other in self.waiting
if other.sequence < waiter.sequence and not self._waiter_can_start_locked(other)
]
if any(other.bypassed_by >= self.fairness_bypass_limit for other in older_blocked):
continue
for other in older_blocked:
other.bypassed_by += 1
# Collect all candidates that can start now.
candidates = [w for w in self.waiting if self._waiter_can_start_locked(w)]
if not candidates:
return None
# Fairness: don't bypass an older waiter more than the limit.
def _is_fair(waiter: WaitingRequest) -> bool:
older_blocked = [
other for other in self.waiting
if other.sequence < waiter.sequence and not self._waiter_can_start_locked(other)
]
if any(other.bypassed_by >= self.fairness_bypass_limit for other in older_blocked):
return False
for other in older_blocked:
other.bypassed_by += 1
return True
# Prompt aggregation: prefer candidates whose prefix key matches the
# last completed request — they will hit a warm KV cache.
if self._last_prefix_key:
warm_candidates = [
w for w in candidates
if w.prefix_key and w.prefix_key == self._last_prefix_key
]
for waiter in warm_candidates:
if _is_fair(waiter):
return waiter
# Fall back to FIFO order.
for waiter in candidates:
if _is_fair(waiter):
return waiter
return None
def _waiting_counts_locked(self) -> Dict[str, int]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment