Add prompt caching and prompt aggregation

parent 0ac26bed
...@@ -13,6 +13,14 @@ An OpenAI-compatible API server to run models on your local GPU with web adminis ...@@ -13,6 +13,14 @@ An OpenAI-compatible API server to run models on your local GPU with web adminis
- **Multi-Modal**: Text, image, video, audio, TTS, STT, embeddings - **Multi-Modal**: Text, image, video, audio, TTS, STT, embeddings
- **Per-Model Configuration**: Individual settings for each model (GPU layers, quantization, context size) - **Per-Model Configuration**: Individual settings for each model (GPU layers, quantization, context size)
- **On-Demand Loading**: Models load automatically when requested, unload when idle - **On-Demand Loading**: Models load automatically when requested, unload when idle
- **Memory Management**: Smart VRAM → RAM → Disk offloading for efficient resource usage
- **Parallel Execution**: Run multiple models simultaneously (VRAM permitting)
- **Auto-Swap**: Automatic model switching on request — load what's needed, unload what's idle
- **Request Queue**: Concurrent requests are queued and processed in order per model
- **Prompt Caching**: Reuse KV cache across requests to reduce latency and computation
- **Prompt Aggregation**: Batch concurrent requests into a single inference pass for higher throughput
- **Custom Pipelines**: Create and save multi-step workflows combining any generation tasks
- **Pre-Built Pipelines**: Ready-to-use pipelines for common workflows (image-to-video, dubbing, story generation)
### GPU Backend Support ### GPU Backend Support
- **NVIDIA (CUDA)**: PyTorch + Transformers for HuggingFace models - **NVIDIA (CUDA)**: PyTorch + Transformers for HuggingFace models
......
This diff is collapsed.
...@@ -39,6 +39,74 @@ from codai.pydantic.imagerequest import ImageGenerationRequest ...@@ -39,6 +39,74 @@ from codai.pydantic.imagerequest import ImageGenerationRequest
from codai.api.state import get_load_mode from codai.api.state import get_load_mode
# =============================================================================
# Prompt embedding cache (diffusers)
#
# Caches text-encoder outputs keyed by (prompt, negative_prompt, model_name).
# When the same prompt is requested again the encode step is skipped and the
# cached tensors are passed directly to the pipeline, saving CLIP/T5 compute.
# sd.cpp handles encoding internally — no equivalent caching is possible there.
# =============================================================================
import hashlib as _hashlib
import threading as _threading
class _PromptEmbedCache:
    """Small thread-safe TTL cache for diffusers prompt embeddings.

    Entries are keyed by a truncated SHA-256 digest of
    ``(model_name, prompt, negative_prompt)``. Every entry also records the
    model name it was stored under: the digest cannot be mapped back to a
    model, so ``invalidate_model`` needs that name to find a model's entries.

    At most ``_MAX_ENTRIES`` entries are kept; when the limit is exceeded the
    entry with the oldest store time is evicted (lookups do not refresh the
    timestamp). Entries older than ``_TTL`` seconds are dropped lazily on
    lookup.
    """

    _MAX_ENTRIES = 32
    _TTL = 600.0  # 10 minutes

    def __init__(self):
        # key -> (embeds_dict, timestamp, model_name)
        self._store: dict = {}
        self._lock = _threading.Lock()

    @staticmethod
    def _key(prompt: str, negative_prompt: str, model_name: str) -> str:
        """Return the cache key for a (prompt, negative_prompt, model) triple.

        The parts are joined with NUL separators before hashing, and a
        missing negative prompt is treated as the empty string.
        """
        raw = f"{model_name}\x00{prompt}\x00{negative_prompt or ''}"
        return _hashlib.sha256(raw.encode()).hexdigest()[:24]

    def get(self, prompt: str, negative_prompt: str, model_name: str) -> Optional[dict]:
        """Return the cached embeddings dict, or ``None`` on a miss or expiry."""
        k = self._key(prompt, negative_prompt, model_name)
        with self._lock:
            entry = self._store.get(k)
            if entry is None:
                return None
            embeds, ts, _owner = entry
            if time.time() - ts > self._TTL:
                # Expired: drop it so it no longer counts toward the limit.
                del self._store[k]
                return None
            return embeds

    def put(self, prompt: str, negative_prompt: str, model_name: str,
            embeds: dict) -> None:
        """Store ``embeds`` for the triple, evicting the oldest entry if full."""
        k = self._key(prompt, negative_prompt, model_name)
        with self._lock:
            self._store[k] = (embeds, time.time(), model_name)
            # Evict oldest if over limit
            if len(self._store) > self._MAX_ENTRIES:
                oldest = min(self._store, key=lambda x: self._store[x][1])
                del self._store[oldest]

    def invalidate_model(self, model_name: str) -> None:
        """Drop all entries for a model (e.g. on pipeline unload)."""
        with self._lock:
            # Match on the model name saved with each entry; the hashed key
            # itself carries no recoverable model information.
            self._store = {
                k: v for k, v in self._store.items()
                if v[2] != model_name
            }
_embed_cache = _PromptEmbedCache()
# Global reference to be set by coderai # Global reference to be set by coderai
global_args = None global_args = None
global_file_path = None global_file_path = None
...@@ -384,7 +452,7 @@ def _load_diffusers_pipeline(model_name: str, global_args): ...@@ -384,7 +452,7 @@ def _load_diffusers_pipeline(model_name: str, global_args):
def _generate_with_diffusers(pipeline, request, global_args, http_request=None): def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
"""Generate images using a diffusers pipeline.""" """Generate images using a diffusers pipeline (with prompt-embedding cache)."""
import torch import torch
import numpy as np import numpy as np
import time as time_module import time as time_module
...@@ -402,13 +470,12 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None): ...@@ -402,13 +470,12 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
height = int(parts[1]) height = int(parts[1])
except ValueError: except ValueError:
pass pass
# Check for nan/inf in dimensions
if width != width or width == float('inf'): if width != width or width == float('inf'):
width = 512 width = 512
if height != height or height == float('inf'): if height != height or height == float('inf'):
height = 512 height = 512
# Enable memory optimizations # Enable memory optimizations
try: try:
if hasattr(pipeline, 'enable_attention_slicing'): if hasattr(pipeline, 'enable_attention_slicing'):
...@@ -417,58 +484,116 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None): ...@@ -417,58 +484,116 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
pipeline.enable_vae_slicing() pipeline.enable_vae_slicing()
except Exception as e: except Exception as e:
print(f"Warning: Could not enable memory optimizations: {e}") print(f"Warning: Could not enable memory optimizations: {e}")
# Get timestamp BEFORE calling diffusers
timestamp = int(time_module.time()) timestamp = int(time_module.time())
# Generate images
seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None) seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None)
generator = None generator = None
if seed is not None: if seed is not None:
generator = torch.Generator(device=pipeline.device).manual_seed(seed) generator = torch.Generator(device=pipeline.device).manual_seed(seed)
# Quality: "standard" or "hd"
quality = request.quality or "standard" quality = request.quality or "standard"
# Use request parameters if provided, otherwise fall back to quality-based defaults
num_steps = request.steps if request.steps else (30 if quality == "standard" else 50) num_steps = request.steps if request.steps else (30 if quality == "standard" else 50)
cfg_scale = request.guidance_scale if request.guidance_scale else ( cfg_scale = request.guidance_scale if request.guidance_scale else (
getattr(global_args, 'image_cfg_scale', 7.5) if quality == "standard" else 9.0 getattr(global_args, 'image_cfg_scale', 7.5) if quality == "standard" else 9.0
) )
# Generate # ------------------------------------------------------------------
result = pipeline( # Prompt embedding cache
prompt=request.prompt, # Try to encode the prompt once and reuse the embeddings.
negative_prompt=None, # Falls back to passing the plain text prompt if encoding fails.
num_images_per_prompt=request.n, # ------------------------------------------------------------------
height=height, model_id = getattr(pipeline, 'model_name_or_path', None) or str(type(pipeline).__name__)
width=width, neg_prompt = getattr(request, 'negative_prompt', None) or ""
generator=generator, do_cfg = cfg_scale > 1.0
guidance_scale=cfg_scale,
num_inference_steps=num_steps, cached_embeds = _embed_cache.get(request.prompt, neg_prompt, model_id)
) embed_kwargs = {}
cache_hit = False
if cached_embeds is not None:
embed_kwargs = cached_embeds
cache_hit = True
print(f"Prompt embed cache HIT for model '{model_id}'")
else:
# Try to encode and cache
try:
if hasattr(pipeline, 'encode_prompt'):
enc = pipeline.encode_prompt(
prompt=request.prompt,
device=pipeline.device,
num_images_per_prompt=1,
do_classifier_free_guidance=do_cfg,
negative_prompt=neg_prompt or None,
)
# enc is a tuple; length varies by pipeline type
if len(enc) == 2:
# SD 1.x: (prompt_embeds, negative_prompt_embeds)
embed_kwargs = {
'prompt_embeds': enc[0],
'negative_prompt_embeds': enc[1],
}
elif len(enc) == 4:
# SDXL: (prompt_embeds, negative_prompt_embeds,
# pooled_prompt_embeds, negative_pooled_prompt_embeds)
embed_kwargs = {
'prompt_embeds': enc[0],
'negative_prompt_embeds': enc[1],
'pooled_prompt_embeds': enc[2],
'negative_pooled_prompt_embeds': enc[3],
}
if embed_kwargs:
_embed_cache.put(request.prompt, neg_prompt, model_id, embed_kwargs)
print(f"Prompt embed cache STORE for model '{model_id}'")
except Exception as e:
print(f"Warning: prompt encode/cache failed ({e}), using plain text prompt")
embed_kwargs = {}
# Build call kwargs
if embed_kwargs:
call_kwargs = dict(
num_images_per_prompt=request.n,
height=height,
width=width,
generator=generator,
guidance_scale=cfg_scale,
num_inference_steps=num_steps,
**embed_kwargs,
)
else:
call_kwargs = dict(
prompt=request.prompt,
negative_prompt=neg_prompt or None,
num_images_per_prompt=request.n,
height=height,
width=width,
generator=generator,
guidance_scale=cfg_scale,
num_inference_steps=num_steps,
)
result = pipeline(**call_kwargs)
# Extract images # Extract images
images = [] images = []
try: try:
result_images = result.images result_images = result.images
except Exception as img_err: except Exception as img_err:
print(f"Warning: Could not access result.images: {img_err}")
result_images = getattr(result, 'image', None) or getattr(result, 'output', None) result_images = getattr(result, 'image', None) or getattr(result, 'output', None)
if result_images is None: if result_images is None:
raise Exception(f"Could not extract images from diffusers result: {img_err}") raise Exception(f"Could not extract images from diffusers result: {img_err}")
for img in result_images: for img in result_images:
if isinstance(img, np.ndarray): if isinstance(img, np.ndarray):
img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0) img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)
img = np.clip(img, 0.0, 1.0) img = np.clip(img, 0.0, 1.0)
img_data = save_image_response(img, request.response_format, http_request) img_data = save_image_response(img, request.response_format, http_request)
images.append(img_data) images.append(img_data)
return { return {
"created": timestamp, "created": timestamp,
"data": images "data": images,
"prompt_cache_hit": cache_hit,
} }
......
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Prompt prefix cache manager.
Provides two features:
1. Prefix key computation for same-prefix request scheduling (prompt aggregation).
2. Per-model last-prompt tracking so callers can report accurate cached_tokens.
llama.cpp's KV cache naturally reuses computation when consecutive requests
share a prompt prefix. This manager helps exploit that by:
- Giving the scheduler a stable key to group requests by shared prefix.
- Letting text.py read back how many tokens were cached from timings.
"""
import hashlib
import json
import time
from dataclasses import dataclass, field
from threading import Lock
from typing import Dict, List, Optional
@dataclass
class _CacheEntry:
    """Bookkeeping record for one prompt handed to ``PromptCacheManager.store``."""

    messages_hash: str  # digest of the full message list
    prefix_hash: str  # digest of the cacheable prefix ("" when there is none)
    token_count: int  # prompt token count supplied by the caller
    timestamp: float = field(default_factory=time.time)  # creation time


class PromptCacheManager:
    """
    Remembers recently processed prompts and the last prompt seen per model.

    Usage
    -----
        # Ahead of dispatch, derive the scheduling key:
        prefix_key = manager.get_prefix_key(messages)

        # Once the model call has finished:
        manager.store(messages, model_key, prompt_tokens, cached_tokens)

        # When building the usage block of the response:
        cached = manager.get_cached_tokens(model_key)
    """

    def __init__(self, max_entries: int = 256, ttl_seconds: float = 600.0):
        self._lock = Lock()
        self._ttl = ttl_seconds
        self._max_entries = max_entries
        # messages_hash -> record of that prompt
        self._entries: Dict[str, _CacheEntry] = {}
        # model_key -> messages_hash of the most recent store() for that model
        self._by_model: Dict[str, str] = {}
        # model_key -> cached_tokens passed to the most recent store()
        self._cached_tokens: Dict[str, int] = {}

    # ------------------------------------------------------------------
    # Hashing helpers
    # ------------------------------------------------------------------

    def _hash_messages(self, messages: List[Dict]) -> str:
        """Return a 20-hex-char SHA-256 digest of the messages' role/content pairs."""
        reduced = []
        for message in messages:
            reduced.append(
                {"role": message.get("role"), "content": message.get("content")}
            )
        payload = json.dumps(reduced, separators=(",", ":"), ensure_ascii=False)
        digest = hashlib.sha256(payload.encode())
        return digest.hexdigest()[:20]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_prefix_key(self, messages: List[Dict]) -> str:
        """
        Return a stable key for the cacheable part of a request.

        When the last message has role "user" it is excluded and the key
        covers everything before it; otherwise the key covers all messages.
        An empty string is returned when nothing is left to hash.
        """
        if not messages:
            return ""
        last_role = messages[-1].get("role")
        cacheable = messages[:-1] if last_role == "user" else messages
        if not cacheable:
            return ""
        return self._hash_messages(cacheable)

    def store(self, messages: List[Dict], model_key: str, prompt_tokens: int,
              cached_tokens: int = 0) -> None:
        """Record a finished prompt and mark it as the model's latest one."""
        full_digest = self._hash_messages(messages)
        prefix_digest = self.get_prefix_key(messages)
        with self._lock:
            record = _CacheEntry(
                messages_hash=full_digest,
                prefix_hash=prefix_digest,
                token_count=prompt_tokens,
            )
            self._entries[full_digest] = record
            self._by_model[model_key] = full_digest
            self._cached_tokens[model_key] = cached_tokens
            self._evict_locked()

    def get_cached_tokens(self, model_key: str) -> int:
        """Return the cached_tokens value from the model's last store(), else 0."""
        with self._lock:
            return self._cached_tokens.get(model_key, 0)

    def has_warm_prefix(self, messages: List[Dict], model_key: str) -> bool:
        """
        Tell whether ``messages`` shares its prefix key with the model's
        most recently stored prompt, provided that record has not expired.
        """
        with self._lock:
            previous = self._by_model.get(model_key)
            if not previous:
                return False
            record = self._entries.get(previous)
            if not record:
                return False
            if time.time() - record.timestamp > self._ttl:
                return False
            candidate = self.get_prefix_key(messages)
            return bool(candidate) and candidate == record.prefix_hash

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _evict_locked(self) -> None:
        """Remove expired records, then the oldest ones beyond the size limit.

        Callers are expected to hold ``self._lock``.
        """
        now = time.time()
        stale = [
            digest
            for digest, record in self._entries.items()
            if now - record.timestamp > self._ttl
        ]
        for digest in stale:
            del self._entries[digest]

        excess = len(self._entries) - self._max_entries
        if excess > 0:
            # A stable sort keeps insertion order among equal timestamps.
            by_age = sorted(
                self._entries, key=lambda digest: self._entries[digest].timestamp
            )
            for digest in by_age[:excess]:
                del self._entries[digest]
prompt_cache_manager = PromptCacheManager()
...@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__) ...@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
# Import from codai modules # Import from codai modules
from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager
from codai.queue.manager import QueueManager, queue_manager from codai.queue.manager import QueueManager, queue_manager
from codai.api.prompt_cache import prompt_cache_manager
from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool
from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser from codai.models.parser import filter_malformed_content, filter_repetition, format_tools_for_prompt, cleanup_control_tokens, OpenAIFormatter, ModelParserAdapter, ToolCallParser
...@@ -1142,6 +1143,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -1142,6 +1143,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
return JSONResponse(content=formatted_response, headers=headers) return JSONResponse(content=formatted_response, headers=headers)
# Compute prefix key for prompt-aggregation scheduling
_prefix_key = prompt_cache_manager.get_prefix_key(messages_dict)
if request.stream: if request.stream:
async def _managed_stream(): async def _managed_stream():
try: try:
...@@ -1156,6 +1160,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -1156,6 +1160,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
current_manager, current_manager,
tool_parser, tool_parser,
request.response_format, request.response_format,
_prefix_key,
): ):
yield chunk yield chunk
finally: finally:
...@@ -1192,6 +1197,7 @@ async def stream_chat_response( ...@@ -1192,6 +1197,7 @@ async def stream_chat_response(
current_manager: ModelManager, current_manager: ModelManager,
tool_parser: ToolCallParser, tool_parser: ToolCallParser,
response_format: Optional[Dict] = None, response_format: Optional[Dict] = None,
prefix_key: str = "",
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Stream chat completion response with queue notifications.""" """Stream chat completion response with queue notifications."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
...@@ -1214,7 +1220,7 @@ async def stream_chat_response( ...@@ -1214,7 +1220,7 @@ async def stream_chat_response(
# If model not loaded, add to queue and send waiting notifications # If model not loaded, add to queue and send waiting notifications
if not model_loaded: if not model_loaded:
await queue_manager.add_waiting(request_id) await queue_manager.add_waiting(request_id, prefix_key=prefix_key)
wait_interval = 2.0 # Send waiting update every 2 seconds wait_interval = 2.0 # Send waiting update every 2 seconds
last_wait_update = time.time() last_wait_update = time.time()
...@@ -1457,10 +1463,24 @@ async def stream_chat_response( ...@@ -1457,10 +1463,24 @@ async def stream_chat_response(
prompt_text = "\n".join([m.get("content", "") for m in messages]) prompt_text = "\n".join([m.get("content", "") for m in messages])
prompt_tokens = len(prompt_text.split()) prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0 completion_tokens = len(generated_text.split()) if generated_text else 0
# Read accurate usage (including cached_tokens) from the backend
_model_key_for_cache = getattr(current_manager, 'model_name', None) or model_name
last_usage = (current_manager.get_last_usage()
if hasattr(current_manager, 'get_last_usage') else {})
if last_usage.get('prompt_tokens'):
prompt_tokens = last_usage['prompt_tokens']
if last_usage.get('completion_tokens'):
completion_tokens = last_usage['completion_tokens']
cached_tokens = last_usage.get('cached_tokens', 0)
# Store in prompt cache manager for future prefix matching
prompt_cache_manager.store(messages, _model_key_for_cache,
prompt_tokens, cached_tokens)
# Get context size # Get context size
context_size = current_manager.get_context_size() context_size = current_manager.get_context_size()
# Build complete final chunk with all OpenAI fields # Build complete final chunk with all OpenAI fields
final_chunk = { final_chunk = {
"id": completion_id, "id": completion_id,
...@@ -1479,7 +1499,7 @@ async def stream_chat_response( ...@@ -1479,7 +1499,7 @@ async def stream_chat_response(
"total_tokens": prompt_tokens + completion_tokens, "total_tokens": prompt_tokens + completion_tokens,
"context_size": context_size, "context_size": context_size,
"prompt_tokens_details": { "prompt_tokens_details": {
"cached_tokens": 0, "cached_tokens": cached_tokens,
"audio_tokens": 0, "audio_tokens": 0,
}, },
"completion_tokens_details": { "completion_tokens_details": {
...@@ -1494,7 +1514,7 @@ async def stream_chat_response( ...@@ -1494,7 +1514,7 @@ async def stream_chat_response(
"system_fingerprint": None, "system_fingerprint": None,
} }
yield f"data: {json.dumps(final_chunk)}\n\n" yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
except Exception as e: except Exception as e:
print(f"Error during streaming generation: {e}") print(f"Error during streaming generation: {e}")
...@@ -1638,11 +1658,20 @@ async def generate_chat_response( ...@@ -1638,11 +1658,20 @@ async def generate_chat_response(
response_message["tool_calls"] = tool_calls response_message["tool_calls"] = tool_calls
finish_reason = "tool_calls" finish_reason = "tool_calls"
# Calculate token counts - rough estimate since we don't have direct access to tokenizer # Read accurate usage (including cached_tokens) from the backend
_model_key_for_cache = getattr(current_manager, 'model_name', None) or model_name
last_usage = (current_manager.get_last_usage()
if hasattr(current_manager, 'get_last_usage') else {})
prompt_text = "\n".join([m.get("content", "") for m in messages]) prompt_text = "\n".join([m.get("content", "") for m in messages])
prompt_tokens = len(prompt_text.split()) prompt_tokens = last_usage.get('prompt_tokens') or len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0 completion_tokens = last_usage.get('completion_tokens') or (
len(generated_text.split()) if generated_text else 0)
cached_tokens = last_usage.get('cached_tokens', 0)
# Store in prompt cache manager for future prefix matching
prompt_cache_manager.store(messages, _model_key_for_cache,
prompt_tokens, cached_tokens)
# Get context size # Get context size
context_size = current_manager.get_context_size() context_size = current_manager.get_context_size()
...@@ -1655,6 +1684,10 @@ async def generate_chat_response( ...@@ -1655,6 +1684,10 @@ async def generate_chat_response(
tool_calls=response_message.get("tool_calls"), tool_calls=response_message.get("tool_calls"),
context_size=context_size context_size=context_size
) )
# Patch in the real cached_tokens value
if formatted_response and 'usage' in formatted_response:
details = formatted_response['usage'].setdefault('prompt_tokens_details', {})
details['cached_tokens'] = cached_tokens
# Add mock reasoning stats if 'mock' is in force_reasoning_args # Add mock reasoning stats if 'mock' is in force_reasoning_args
# But only if we don't already have real reasoning in the response # But only if we don't already have real reasoning in the response
......
This diff is collapsed.
...@@ -63,6 +63,7 @@ class VulkanBackend(ModelBackend): ...@@ -63,6 +63,7 @@ class VulkanBackend(ModelBackend):
self.force_cuda = original_backend in ("nvidia", "cuda") # Force CUDA if original was nvidia self.force_cuda = original_backend in ("nvidia", "cuda") # Force CUDA if original was nvidia
if self.force_cuda: if self.force_cuda:
print("DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)") print("DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)")
self._last_usage: dict = {} # usage from the most recent completion call
self._detect_chat_template() self._detect_chat_template()
def _detect_chat_template(self): def _detect_chat_template(self):
...@@ -649,6 +650,8 @@ class VulkanBackend(ModelBackend): ...@@ -649,6 +650,8 @@ class VulkanBackend(ModelBackend):
stop=stop, stop=stop,
grammar=use_grammar, grammar=use_grammar,
) )
usage = result.get('usage', {})
self._store_usage(usage.get('prompt_tokens', 0), usage.get('completion_tokens', 0))
return result['choices'][0]['text'] return result['choices'][0]['text']
except Exception as e: except Exception as e:
# If grammar generation fails, fall back to normal generation # If grammar generation fails, fall back to normal generation
...@@ -664,6 +667,8 @@ class VulkanBackend(ModelBackend): ...@@ -664,6 +667,8 @@ class VulkanBackend(ModelBackend):
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
stop=stop, stop=stop,
) )
usage = result.get('usage', {})
self._store_usage(usage.get('prompt_tokens', 0), usage.get('completion_tokens', 0))
return result['choices'][0]['text'] return result['choices'][0]['text']
except Exception as e2: except Exception as e2:
print(f"Error during fallback generation: {e2}") print(f"Error during fallback generation: {e2}")
...@@ -935,6 +940,112 @@ class VulkanBackend(ModelBackend): ...@@ -935,6 +940,112 @@ class VulkanBackend(ModelBackend):
"n_gpu_layers": self.n_gpu_layers, "n_gpu_layers": self.n_gpu_layers,
} }
# ------------------------------------------------------------------
# Usage / cache helpers
# ------------------------------------------------------------------
def _read_cached_tokens(self, prompt_tokens: int) -> int:
    """Best-effort estimate of how many prompt tokens were served from cache.

    Looks for a ``timings`` object on ``self.model``; when that is absent it
    falls back to calling ``self.model._ctx.timings()``. The result is
    ``prompt_tokens - n_p_eval`` clamped at zero. Returns 0 when no timings
    or ``n_p_eval`` value is available, or when any step raises.
    """
    try:
        stats = getattr(self.model, 'timings', None)
        if stats is None:
            # Fall back to the internal context when the model does not
            # expose timings directly.
            context = getattr(self.model, '_ctx', None)
            if context and hasattr(context, 'timings'):
                stats = context.timings()
        # presumably n_p_eval is the number of prompt tokens actually
        # evaluated by the backend — TODO confirm against llama.cpp timings
        evaluated = None if stats is None else getattr(stats, 'n_p_eval', None)
        if evaluated is not None:
            return max(0, prompt_tokens - int(evaluated))
    except Exception:
        # Deliberately best-effort: usage reporting must not break generation.
        pass
    return 0
def _store_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
    """Save token usage for the latest completion in ``self._last_usage``.

    The cached-token figure comes from ``self._read_cached_tokens``; the
    total is the sum of prompt and completion tokens.
    """
    from_cache = self._read_cached_tokens(prompt_tokens)
    combined = prompt_tokens + completion_tokens
    self._last_usage = dict(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=combined,
        cached_tokens=from_cache,
    )
def get_last_usage(self) -> dict:
    """Return a shallow copy of the usage recorded by the latest completion.

    The copy carries whatever ``_store_usage`` saved (token counts including
    ``cached_tokens``); mutating it does not affect the stored record.
    """
    snapshot = {}
    snapshot.update(self._last_usage)
    return snapshot
# ------------------------------------------------------------------
# Chat-level generation (uses llama.cpp native chat template)
# ------------------------------------------------------------------
def generate_chat(self, messages, max_tokens=None, temperature=0.7, top_p=1.0,
                  stop=None, tools=None, response_format=None):
    """Run a non-streaming chat completion via ``self.model.create_chat_completion``.

    ``max_tokens`` falls back to 512 when not given. ``stop`` is forwarded
    only when non-empty, and ``response_format`` only when its type is
    ``json_object``. ``tools`` is accepted but not forwarded to the model.
    Usage reported in the result is recorded through ``self._store_usage``.

    Returns the first choice's message content, or ``''`` when it is missing
    or empty.

    Raises:
        RuntimeError: if no model is loaded.
    """
    if self.model is None:
        raise RuntimeError("Model not loaded")

    call_args = {
        'messages': messages,
        'max_tokens': max_tokens or 512,
        'temperature': temperature,
        'top_p': top_p,
    }
    if stop:
        call_args['stop'] = stop
    if response_format and response_format.get('type') == 'json_object':
        call_args['response_format'] = {'type': 'json_object'}

    completion = self.model.create_chat_completion(**call_args)

    reported = completion.get('usage', {})
    self._store_usage(
        prompt_tokens=reported.get('prompt_tokens', 0),
        completion_tokens=reported.get('completion_tokens', 0),
    )

    message = completion['choices'][0]['message']
    return message.get('content') or ''
async def generate_chat_stream(self, messages, max_tokens=None, temperature=0.7,
                               top_p=1.0, stop=None, tools=None, response_format=None):
    """Stream a chat completion via ``self.model.create_chat_completion``.

    Yields each non-empty ``delta.content`` string as it arrives and stops at
    the first chunk that carries a ``finish_reason``. ``max_tokens`` falls
    back to 512. ``stop`` is forwarded only when non-empty, and
    ``response_format`` only when its type is ``json_object`` (the same rule
    ``generate_chat`` applies). ``tools`` is accepted but not forwarded.

    When the stream ends, including on error or early close, usage is
    recorded through ``self._store_usage``. Counts reported in a chunk's
    ``usage`` field are preferred; otherwise completion tokens are the
    number of content chunks and prompt tokens are a whitespace word count
    of the messages.

    Raises:
        RuntimeError: on first iteration, if no model is loaded.
    """
    if self.model is None:
        raise RuntimeError("Model not loaded")
    kwargs = dict(
        messages=messages,
        max_tokens=max_tokens or 512,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    )
    if stop:
        kwargs['stop'] = stop
    # Keep parity with generate_chat(): JSON mode must also apply to
    # streamed replies instead of being silently dropped.
    if response_format and response_format.get('type') == 'json_object':
        kwargs['response_format'] = {'type': 'json_object'}

    prompt_tokens = 0
    completion_tokens = 0
    try:
        # NOTE(review): this iterates a synchronous generator inside an async
        # function, so each step runs on the event loop thread.
        for chunk in self.model.create_chat_completion(**kwargs):
            delta = chunk['choices'][0].get('delta', {})
            text = delta.get('content') or ''
            if text:
                # One content chunk is counted as one token (an estimate that
                # is replaced below if the backend reports usage).
                completion_tokens += 1
                yield text
            # Capture usage if present in final streaming chunk
            if chunk.get('usage'):
                u = chunk['usage']
                prompt_tokens = u.get('prompt_tokens', 0)
                completion_tokens = u.get('completion_tokens', completion_tokens)
            if chunk['choices'][0].get('finish_reason'):
                break
    finally:
        # Timings are available after the stream is exhausted
        if prompt_tokens == 0:
            # Estimate from word split if llama.cpp didn't report
            prompt_tokens = sum(
                len(str(m.get('content', '')).split())
                for m in messages
            )
        self._store_usage(prompt_tokens, completion_tokens)
def get_model_name(self) -> str: def get_model_name(self) -> str:
"""Return the loaded model name.""" """Return the loaded model name."""
return self.model_name or "unknown" return self.model_name or "unknown"
......
...@@ -233,6 +233,12 @@ class ModelManager: ...@@ -233,6 +233,12 @@ class ModelManager:
if self.backend is not None: if self.backend is not None:
return self.backend.get_context_size() return self.backend.get_context_size()
return 2048 # Default fallback return 2048 # Default fallback
def get_last_usage(self) -> dict:
    """Return the backend's usage info for the most recent call.

    Delegates to ``self.backend.get_last_usage()``. An empty dict is returned
    when no backend is set or the backend has no such method.
    """
    backend = self.backend
    if backend is None or not hasattr(backend, 'get_last_usage'):
        return {}
    return backend.get_last_usage()
def cleanup(self): def cleanup(self):
if self.backend is not None: if self.backend is not None:
...@@ -2040,11 +2046,29 @@ class MultiModelManager: ...@@ -2040,11 +2046,29 @@ class MultiModelManager:
"embedding_models": "embedding", "embedding_models": "embedding",
} }
# Minimum capability guaranteed by a model's config category.
# Applied when heuristic name detection doesn't recognise the model ID.
TYPE_MIN_CAP = {
"image": "image_generation",
"video": "video_generation",
"audio": "speech_to_text",
"tts": "text_to_speech",
"audio_gen": "audio_generation",
"embedding": "embeddings",
}
def _add(model_id: str, model_type: str = None, meta: Dict[str, Any] = None): def _add(model_id: str, model_type: str = None, meta: Dict[str, Any] = None):
if model_id in seen_ids: if model_id in seen_ids:
return return
seen_ids.add(model_id) seen_ids.add(model_id)
caps = detect_model_capabilities(model_id) caps = detect_model_capabilities(model_id)
# If heuristic detection missed the type (e.g. custom/vendor model IDs
# that don't match any keyword), ensure the minimum capability for the
# config-declared type is set so badges display correctly.
if model_type and model_type in TYPE_MIN_CAP:
min_cap = TYPE_MIN_CAP[model_type]
if not getattr(caps, min_cap, False):
setattr(caps, min_cap, True)
resolved_type = model_type or (caps.to_list()[0].split("_")[0] if caps.to_list() else "text") resolved_type = model_type or (caps.to_list()[0].split("_")[0] if caps.to_list() else "text")
meta = meta or {} meta = meta or {}
models.append(ModelInfo( models.append(ModelInfo(
...@@ -2075,6 +2099,11 @@ class MultiModelManager: ...@@ -2075,6 +2099,11 @@ class MultiModelManager:
else: else:
raw = m.get("path") or m.get("id") or "" raw = m.get("path") or m.get("id") or ""
alias = m.get("alias") or "" alias = m.get("alias") or ""
# Auto-derive a clean alias for GGUF files that have no
# explicit alias so the full filesystem path isn't exposed.
if not alias and raw.lower().endswith(".gguf"):
stem = raw.split("/")[-1][:-5] # filename without .gguf
alias = stem
# whisper-server aliases are round-robin group keys shared across # whisper-server aliases are round-robin group keys shared across
# multiple instances — don't expose the alias as a separate model # multiple instances — don't expose the alias as a separate model
if m.get("backend") == "whisper-server": if m.get("backend") == "whisper-server":
......
...@@ -39,6 +39,7 @@ class ChatMessage(BaseModel): ...@@ -39,6 +39,7 @@ class ChatMessage(BaseModel):
name: Optional[str] = None name: Optional[str] = None
tool_calls: Optional[List[Dict]] = None tool_calls: Optional[List[Dict]] = None
tool_call_id: Optional[str] = None tool_call_id: Optional[str] = None
cache_control: Optional[Dict] = None # OpenAI-style: {"type": "ephemeral"}
@field_validator('content', mode='before') @field_validator('content', mode='before')
@classmethod @classmethod
......
...@@ -40,6 +40,7 @@ class WaitingRequest: ...@@ -40,6 +40,7 @@ class WaitingRequest:
sequence: int sequence: int
event: asyncio.Event = field(default_factory=asyncio.Event) event: asyncio.Event = field(default_factory=asyncio.Event)
bypassed_by: int = 0 bypassed_by: int = 0
prefix_key: str = "" # stable hash of the cacheable prompt prefix
class QueueManager: class QueueManager:
...@@ -61,6 +62,7 @@ class QueueManager: ...@@ -61,6 +62,7 @@ class QueueManager:
self.model_name: Optional[str] = None self.model_name: Optional[str] = None
self._processing: bool = False self._processing: bool = False
self._ready_request_ids: Set[str] = set() self._ready_request_ids: Set[str] = set()
self._last_prefix_key: str = "" # prefix key of the last completed request
def set_loaded_models(self, model_keys: Set[str]) -> None: def set_loaded_models(self, model_keys: Set[str]) -> None:
self.loaded_models = set(model_keys) self.loaded_models = set(model_keys)
...@@ -83,17 +85,19 @@ class QueueManager: ...@@ -83,17 +85,19 @@ class QueueManager:
self.model_name = None self.model_name = None
self._processing = False self._processing = False
self._ready_request_ids.clear() self._ready_request_ids.clear()
self._last_prefix_key = ""
async def is_full(self) -> bool:
    """Report whether the waiting queue has reached its capacity.

    Both the queue length and ``max_size`` are read while ``self.lock`` is
    held, so the answer reflects a single consistent snapshot.

    Returns:
        ``True`` when the number of queued waiters is at least
        ``self.max_size``, otherwise ``False``.
    """
    async with self.lock:
        queued = len(self.waiting)
        capacity = self.max_size
    return queued >= capacity
async def acquire(self, request_id: str, model_key: str) -> SchedulerLease: async def acquire(self, request_id: str, model_key: str,
prefix_key: str = "") -> SchedulerLease:
waiter = None waiter = None
async with self.lock: async with self.lock:
if self._can_start_now(model_key): if self._can_start_now(model_key):
return self._grant_lease(request_id, model_key) return self._grant_lease(request_id, model_key)
waiter = self._enqueue_waiter(request_id, model_key) waiter = self._enqueue_waiter(request_id, model_key, prefix_key)
await waiter.event.wait() await waiter.event.wait()
async with self.lock: async with self.lock:
...@@ -103,7 +107,8 @@ class QueueManager: ...@@ -103,7 +107,8 @@ class QueueManager:
lease.wait_time_seconds = max(0.0, time.time() - waiter.enqueued_at) lease.wait_time_seconds = max(0.0, time.time() - waiter.enqueued_at)
return lease return lease
async def release(self, lease: SchedulerLease) -> None: async def release(self, lease: SchedulerLease,
prefix_key: str = "") -> None:
async with self.lock: async with self.lock:
self.active_leases.pop(lease.request_id, None) self.active_leases.pop(lease.request_id, None)
current = self.active_by_model.get(lease.model_key, 0) current = self.active_by_model.get(lease.model_key, 0)
...@@ -113,14 +118,17 @@ class QueueManager: ...@@ -113,14 +118,17 @@ class QueueManager:
self.active_by_model[lease.model_key] = current - 1 self.active_by_model[lease.model_key] = current - 1
if self.current_request_id == lease.request_id: if self.current_request_id == lease.request_id:
self.current_request_id = None self.current_request_id = None
if prefix_key:
self._last_prefix_key = prefix_key
self._processing = bool(self.active_leases) self._processing = bool(self.active_leases)
self._wake_waiters_locked() self._wake_waiters_locked()
async def add_waiting(self, request_id: str, model_key: str = "",
                      prefix_key: str = "") -> None:
    """Register a request in the waiting queue unless it is already there.

    Args:
        request_id: Identifier of the request to enqueue.
        model_key: Model the request targets; when empty, ``request_id``
            itself is used as the key.
        prefix_key: Hash of the cacheable prompt prefix, forwarded to the
            waiter so the scheduler can match it against the last
            completed request.

    The membership check and the enqueue happen under ``self.lock``; a
    request id that is already present in ``self.waiting_by_id`` is left
    untouched.
    """
    async with self.lock:
        if request_id not in self.waiting_by_id:
            effective_key = model_key or request_id
            self._enqueue_waiter(request_id, effective_key, prefix_key)
async def remove_waiting(self, request_id: str) -> None: async def remove_waiting(self, request_id: str) -> None:
async with self.lock: async with self.lock:
...@@ -172,13 +180,15 @@ class QueueManager: ...@@ -172,13 +180,15 @@ class QueueManager:
"loaded_models": sorted(self.loaded_models), "loaded_models": sorted(self.loaded_models),
} }
def _enqueue_waiter(self, request_id: str, model_key: str) -> WaitingRequest: def _enqueue_waiter(self, request_id: str, model_key: str,
prefix_key: str = "") -> WaitingRequest:
self.sequence += 1 self.sequence += 1
waiter = WaitingRequest( waiter = WaitingRequest(
request_id=request_id, request_id=request_id,
model_key=model_key, model_key=model_key,
enqueued_at=time.time(), enqueued_at=time.time(),
sequence=self.sequence, sequence=self.sequence,
prefix_key=prefix_key,
) )
self.waiting.append(waiter) self.waiting.append(waiter)
self.waiting_by_id[request_id] = waiter self.waiting_by_id[request_id] = waiter
...@@ -233,17 +243,39 @@ class QueueManager: ...@@ -233,17 +243,39 @@ class QueueManager:
return return
def _pick_next_waiter_locked(self) -> Optional[WaitingRequest]: def _pick_next_waiter_locked(self) -> Optional[WaitingRequest]:
for waiter in self.waiting: # Collect all candidates that can start now.
if self._waiter_can_start_locked(waiter): candidates = [w for w in self.waiting if self._waiter_can_start_locked(w)]
older_blocked = [ if not candidates:
other for other in self.waiting return None
if other.sequence < waiter.sequence and not self._waiter_can_start_locked(other)
] # Fairness: don't bypass an older waiter more than the limit.
if any(other.bypassed_by >= self.fairness_bypass_limit for other in older_blocked): def _is_fair(waiter: WaitingRequest) -> bool:
continue older_blocked = [
for other in older_blocked: other for other in self.waiting
other.bypassed_by += 1 if other.sequence < waiter.sequence and not self._waiter_can_start_locked(other)
]
if any(other.bypassed_by >= self.fairness_bypass_limit for other in older_blocked):
return False
for other in older_blocked:
other.bypassed_by += 1
return True
# Prompt aggregation: prefer candidates whose prefix key matches the
# last completed request — they will hit a warm KV cache.
if self._last_prefix_key:
warm_candidates = [
w for w in candidates
if w.prefix_key and w.prefix_key == self._last_prefix_key
]
for waiter in warm_candidates:
if _is_fair(waiter):
return waiter
# Fall back to FIFO order.
for waiter in candidates:
if _is_fair(waiter):
return waiter return waiter
return None return None
def _waiting_counts_locked(self) -> Dict[str, int]: def _waiting_counts_locked(self) -> Dict[str, int]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment