Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in
Toggle navigation
C
coderai
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
nexlab
coderai
Commits
07fb3251
Commit
07fb3251
authored
May 07, 2026
by
Stefy Lanza (nextime / spora )
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add prompt caching and prompt aggregation
parent
0ac26bed
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
939 additions
and
70 deletions
+939
-70
README.md
README.md
+8
-0
chat.html
codai/admin/templates/chat.html
+109
-15
images.py
codai/api/images.py
+155
-30
prompt_cache.py
codai/api/prompt_cache.py
+151
-0
text.py
codai/api/text.py
+42
-9
cuda.py
codai/backends/cuda.py
+285
-0
vulkan.py
codai/backends/vulkan.py
+111
-0
manager.py
codai/models/manager.py
+29
-0
textrequest.py
codai/pydantic/textrequest.py
+1
-0
manager.py
codai/queue/manager.py
+48
-16
No files found.
README.md
View file @
07fb3251
...
@@ -13,6 +13,14 @@ An OpenAI-compatible API server to run models on your local GPU with web adminis
...
@@ -13,6 +13,14 @@ An OpenAI-compatible API server to run models on your local GPU with web adminis
-
**Multi-Modal**
: Text, image, video, audio, TTS, STT, embeddings
-
**Multi-Modal**
: Text, image, video, audio, TTS, STT, embeddings
-
**Per-Model Configuration**
: Individual settings for each model (GPU layers, quantization, context size)
-
**Per-Model Configuration**
: Individual settings for each model (GPU layers, quantization, context size)
-
**On-Demand Loading**
: Models load automatically when requested, unload when idle
-
**On-Demand Loading**
: Models load automatically when requested, unload when idle
-
**Memory Management**
: Smart VRAM → RAM → Disk offloading for efficient resource usage
-
**Parallel Execution**
: Run multiple models simultaneously (VRAM permitting)
-
**Auto-Swap**
: Automatic model switching on request — load what's needed, unload what's idle
-
**Request Queue**
: Concurrent requests are queued and processed in order per model
-
**Prompt Caching**
: Reuse KV cache across requests to reduce latency and computation
-
**Prompt Aggregation**
: Batch concurrent requests into a single inference pass for higher throughput
-
**Custom Pipelines**
: Create and save multi-step workflows combining any generation tasks
-
**Pre-Built Pipelines**
: Ready-to-use pipelines for common workflows (image-to-video, dubbing, story generation)
### GPU Backend Support
### GPU Backend Support
-
**NVIDIA (CUDA)**
: PyTorch + Transformers for HuggingFace models
-
**NVIDIA (CUDA)**
: PyTorch + Transformers for HuggingFace models
...
...
codai/admin/templates/chat.html
View file @
07fb3251
This diff is collapsed.
Click to expand it.
codai/api/images.py
View file @
07fb3251
...
@@ -39,6 +39,74 @@ from codai.pydantic.imagerequest import ImageGenerationRequest
...
@@ -39,6 +39,74 @@ from codai.pydantic.imagerequest import ImageGenerationRequest
from
codai.api.state
import
get_load_mode
from
codai.api.state
import
get_load_mode
# =============================================================================
# Prompt embedding cache (diffusers)
#
# Caches text-encoder outputs keyed by (prompt, negative_prompt, model_name).
# When the same prompt is requested again the encode step is skipped and the
# cached tensors are passed directly to the pipeline, saving CLIP/T5 compute.
# sd.cpp handles encoding internally — no equivalent caching is possible there.
# =============================================================================
import
hashlib
as
_hashlib
import
threading
as
_threading
class _PromptEmbedCache:
    """Bounded, TTL-expiring cache for diffusers prompt embeddings.

    Entries are keyed by (model_name, prompt, negative_prompt).  At most
    ``_MAX_ENTRIES`` entries are kept; when the limit is exceeded the entry
    with the oldest store timestamp is evicted.  Entries older than ``_TTL``
    seconds are dropped lazily when they are looked up.  Every access to the
    underlying dict happens while holding ``_lock``.
    """

    _MAX_ENTRIES = 32
    _TTL = 600.0  # 10 minutes

    def __init__(self):
        self._store: dict = {}  # key -> (embeds_dict, timestamp)
        self._lock = _threading.Lock()

    @staticmethod
    def _model_tag(model_name: str) -> str:
        """Return a short hash of *model_name* used as the key prefix."""
        return _hashlib.sha256(model_name.encode()).hexdigest()[:16]

    @staticmethod
    def _key(prompt: str, negative_prompt: str, model_name: str) -> str:
        """Build the cache key ``"<model tag>:<content digest>"``.

        The model tag prefix is what lets ``invalidate_model`` find every
        entry of a model without knowing the prompts that produced them.
        A ``None`` or empty negative prompt hash to the same key.
        """
        raw = f"{model_name}\x00{prompt}\x00{negative_prompt or ''}"
        digest = _hashlib.sha256(raw.encode()).hexdigest()[:24]
        return f"{_PromptEmbedCache._model_tag(model_name)}:{digest}"

    def get(self, prompt: str, negative_prompt: str, model_name: str) -> Optional[dict]:
        """Return the cached embeddings dict, or ``None`` on miss or expiry."""
        k = self._key(prompt, negative_prompt, model_name)
        with self._lock:
            entry = self._store.get(k)
            if entry is None:
                return None
            embeds, ts = entry
            if time.time() - ts > self._TTL:
                # Expired: drop it so it no longer counts against the limit.
                del self._store[k]
                return None
            return embeds

    def put(self, prompt: str, negative_prompt: str, model_name: str, embeds: dict) -> None:
        """Store *embeds* for the given key, evicting the oldest entry if full."""
        k = self._key(prompt, negative_prompt, model_name)
        with self._lock:
            self._store[k] = (embeds, time.time())
            if len(self._store) > self._MAX_ENTRIES:
                oldest = min(self._store, key=lambda x: self._store[x][1])
                del self._store[oldest]

    def invalidate_model(self, model_name: str) -> None:
        """Drop all entries for a model (e.g. on pipeline unload)."""
        prefix = self._model_tag(model_name) + ":"
        with self._lock:
            self._store = {
                k: v for k, v in self._store.items() if not k.startswith(prefix)
            }
_embed_cache
=
_PromptEmbedCache
()
# Global reference to be set by coderai
# Global reference to be set by coderai
global_args
=
None
global_args
=
None
global_file_path
=
None
global_file_path
=
None
...
@@ -384,7 +452,7 @@ def _load_diffusers_pipeline(model_name: str, global_args):
...
@@ -384,7 +452,7 @@ def _load_diffusers_pipeline(model_name: str, global_args):
def
_generate_with_diffusers
(
pipeline
,
request
,
global_args
,
http_request
=
None
):
def
_generate_with_diffusers
(
pipeline
,
request
,
global_args
,
http_request
=
None
):
"""Generate images using a diffusers pipeline."""
"""Generate images using a diffusers pipeline
(with prompt-embedding cache)
."""
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
import
time
as
time_module
import
time
as
time_module
...
@@ -402,13 +470,12 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
...
@@ -402,13 +470,12 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
height
=
int
(
parts
[
1
])
height
=
int
(
parts
[
1
])
except
ValueError
:
except
ValueError
:
pass
pass
# Check for nan/inf in dimensions
if
width
!=
width
or
width
==
float
(
'inf'
):
if
width
!=
width
or
width
==
float
(
'inf'
):
width
=
512
width
=
512
if
height
!=
height
or
height
==
float
(
'inf'
):
if
height
!=
height
or
height
==
float
(
'inf'
):
height
=
512
height
=
512
# Enable memory optimizations
# Enable memory optimizations
try
:
try
:
if
hasattr
(
pipeline
,
'enable_attention_slicing'
):
if
hasattr
(
pipeline
,
'enable_attention_slicing'
):
...
@@ -417,58 +484,116 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
...
@@ -417,58 +484,116 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
pipeline
.
enable_vae_slicing
()
pipeline
.
enable_vae_slicing
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Warning: Could not enable memory optimizations: {e}"
)
print
(
f
"Warning: Could not enable memory optimizations: {e}"
)
# Get timestamp BEFORE calling diffusers
timestamp
=
int
(
time_module
.
time
())
timestamp
=
int
(
time_module
.
time
())
# Generate images
seed
=
request
.
seed
if
request
.
seed
is
not
None
else
getattr
(
global_args
,
'image_seed'
,
None
)
seed
=
request
.
seed
if
request
.
seed
is
not
None
else
getattr
(
global_args
,
'image_seed'
,
None
)
generator
=
None
generator
=
None
if
seed
is
not
None
:
if
seed
is
not
None
:
generator
=
torch
.
Generator
(
device
=
pipeline
.
device
)
.
manual_seed
(
seed
)
generator
=
torch
.
Generator
(
device
=
pipeline
.
device
)
.
manual_seed
(
seed
)
# Quality: "standard" or "hd"
quality
=
request
.
quality
or
"standard"
quality
=
request
.
quality
or
"standard"
# Use request parameters if provided, otherwise fall back to quality-based defaults
num_steps
=
request
.
steps
if
request
.
steps
else
(
30
if
quality
==
"standard"
else
50
)
num_steps
=
request
.
steps
if
request
.
steps
else
(
30
if
quality
==
"standard"
else
50
)
cfg_scale
=
request
.
guidance_scale
if
request
.
guidance_scale
else
(
cfg_scale
=
request
.
guidance_scale
if
request
.
guidance_scale
else
(
getattr
(
global_args
,
'image_cfg_scale'
,
7.5
)
if
quality
==
"standard"
else
9.0
getattr
(
global_args
,
'image_cfg_scale'
,
7.5
)
if
quality
==
"standard"
else
9.0
)
)
# Generate
# ------------------------------------------------------------------
result
=
pipeline
(
# Prompt embedding cache
prompt
=
request
.
prompt
,
# Try to encode the prompt once and reuse the embeddings.
negative_prompt
=
None
,
# Falls back to passing the plain text prompt if encoding fails.
num_images_per_prompt
=
request
.
n
,
# ------------------------------------------------------------------
height
=
height
,
model_id
=
getattr
(
pipeline
,
'model_name_or_path'
,
None
)
or
str
(
type
(
pipeline
)
.
__name__
)
width
=
width
,
neg_prompt
=
getattr
(
request
,
'negative_prompt'
,
None
)
or
""
generator
=
generator
,
do_cfg
=
cfg_scale
>
1.0
guidance_scale
=
cfg_scale
,
num_inference_steps
=
num_steps
,
cached_embeds
=
_embed_cache
.
get
(
request
.
prompt
,
neg_prompt
,
model_id
)
)
embed_kwargs
=
{}
cache_hit
=
False
if
cached_embeds
is
not
None
:
embed_kwargs
=
cached_embeds
cache_hit
=
True
print
(
f
"Prompt embed cache HIT for model '{model_id}'"
)
else
:
# Try to encode and cache
try
:
if
hasattr
(
pipeline
,
'encode_prompt'
):
enc
=
pipeline
.
encode_prompt
(
prompt
=
request
.
prompt
,
device
=
pipeline
.
device
,
num_images_per_prompt
=
1
,
do_classifier_free_guidance
=
do_cfg
,
negative_prompt
=
neg_prompt
or
None
,
)
# enc is a tuple; length varies by pipeline type
if
len
(
enc
)
==
2
:
# SD 1.x: (prompt_embeds, negative_prompt_embeds)
embed_kwargs
=
{
'prompt_embeds'
:
enc
[
0
],
'negative_prompt_embeds'
:
enc
[
1
],
}
elif
len
(
enc
)
==
4
:
# SDXL: (prompt_embeds, negative_prompt_embeds,
# pooled_prompt_embeds, negative_pooled_prompt_embeds)
embed_kwargs
=
{
'prompt_embeds'
:
enc
[
0
],
'negative_prompt_embeds'
:
enc
[
1
],
'pooled_prompt_embeds'
:
enc
[
2
],
'negative_pooled_prompt_embeds'
:
enc
[
3
],
}
if
embed_kwargs
:
_embed_cache
.
put
(
request
.
prompt
,
neg_prompt
,
model_id
,
embed_kwargs
)
print
(
f
"Prompt embed cache STORE for model '{model_id}'"
)
except
Exception
as
e
:
print
(
f
"Warning: prompt encode/cache failed ({e}), using plain text prompt"
)
embed_kwargs
=
{}
# Build call kwargs
if
embed_kwargs
:
call_kwargs
=
dict
(
num_images_per_prompt
=
request
.
n
,
height
=
height
,
width
=
width
,
generator
=
generator
,
guidance_scale
=
cfg_scale
,
num_inference_steps
=
num_steps
,
**
embed_kwargs
,
)
else
:
call_kwargs
=
dict
(
prompt
=
request
.
prompt
,
negative_prompt
=
neg_prompt
or
None
,
num_images_per_prompt
=
request
.
n
,
height
=
height
,
width
=
width
,
generator
=
generator
,
guidance_scale
=
cfg_scale
,
num_inference_steps
=
num_steps
,
)
result
=
pipeline
(
**
call_kwargs
)
# Extract images
# Extract images
images
=
[]
images
=
[]
try
:
try
:
result_images
=
result
.
images
result_images
=
result
.
images
except
Exception
as
img_err
:
except
Exception
as
img_err
:
print
(
f
"Warning: Could not access result.images: {img_err}"
)
result_images
=
getattr
(
result
,
'image'
,
None
)
or
getattr
(
result
,
'output'
,
None
)
result_images
=
getattr
(
result
,
'image'
,
None
)
or
getattr
(
result
,
'output'
,
None
)
if
result_images
is
None
:
if
result_images
is
None
:
raise
Exception
(
f
"Could not extract images from diffusers result: {img_err}"
)
raise
Exception
(
f
"Could not extract images from diffusers result: {img_err}"
)
for
img
in
result_images
:
for
img
in
result_images
:
if
isinstance
(
img
,
np
.
ndarray
):
if
isinstance
(
img
,
np
.
ndarray
):
img
=
np
.
nan_to_num
(
img
,
nan
=
0.0
,
posinf
=
1.0
,
neginf
=
0.0
)
img
=
np
.
nan_to_num
(
img
,
nan
=
0.0
,
posinf
=
1.0
,
neginf
=
0.0
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
img_data
=
save_image_response
(
img
,
request
.
response_format
,
http_request
)
img_data
=
save_image_response
(
img
,
request
.
response_format
,
http_request
)
images
.
append
(
img_data
)
images
.
append
(
img_data
)
return
{
return
{
"created"
:
timestamp
,
"created"
:
timestamp
,
"data"
:
images
"data"
:
images
,
"prompt_cache_hit"
:
cache_hit
,
}
}
...
...
codai/api/prompt_cache.py
0 → 100644
View file @
07fb3251
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Prompt prefix cache manager.
Provides two features:
1. Prefix key computation for same-prefix request scheduling (prompt aggregation).
2. Per-model last-prompt tracking so callers can report accurate cached_tokens.
llama.cpp's KV cache naturally reuses computation when consecutive requests
share a prompt prefix. This manager helps exploit that by:
- Giving the scheduler a stable key to group requests by shared prefix.
- Letting text.py read back how many tokens were cached from timings.
"""
import
hashlib
import
json
import
time
from
dataclasses
import
dataclass
,
field
from
threading
import
Lock
from
typing
import
Dict
,
List
,
Optional
@dataclass
class _CacheEntry:
    """Record of one completed prompt for one model."""

    messages_hash: str  # hash of the full message list
    prefix_hash: str  # hash of the cacheable prefix ("" when there is none)
    token_count: int  # prompt token count passed to store()
    timestamp: float = field(default_factory=time.time)  # creation time


class PromptCacheManager:
    """
    Tracks recently-processed prompt prefixes per model instance.

    Usage
    -----
    # Before dispatching to the model:
    prefix_key = manager.get_prefix_key(messages)   # for QueueManager scheduling

    # After the model call completes:
    manager.store(messages, model_key, prompt_tokens)

    # In the API response usage block:
    cached = manager.get_cached_tokens(model_key)   # from last store
    """

    def __init__(self, max_entries: int = 256, ttl_seconds: float = 600.0):
        # "<model_key>\x00<messages_hash>" -> entry.  Keyed per model so that
        # one model storing a prompt never refreshes another model's record.
        self._entries: Dict[str, _CacheEntry] = {}
        self._by_model: Dict[str, str] = {}  # model_key -> key of its latest entry
        self._cached_tokens: Dict[str, int] = {}  # model_key -> cached tokens from last call
        self._max_entries = max_entries
        self._ttl = ttl_seconds
        self._lock = Lock()

    # ------------------------------------------------------------------
    # Hashing helpers
    # ------------------------------------------------------------------

    def _hash_messages(self, messages: List[Dict]) -> str:
        """Stable SHA-256 hash (truncated) of a message list.

        Only ``role`` and ``content`` of each message take part.  Keys are
        sorted so structured content hashes the same regardless of the key
        order the client happened to send.
        """
        canonical = json.dumps(
            [{"role": m.get("role"), "content": m.get("content")} for m in messages],
            separators=(",", ":"),
            ensure_ascii=False,
            sort_keys=True,
        )
        return hashlib.sha256(canonical.encode()).hexdigest()[:20]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_prefix_key(self, messages: List[Dict]) -> str:
        """
        Stable key for the *cacheable* portion of a request.

        The cacheable prefix is everything except the final user turn, since
        system prompts and prior assistant turns stay constant across related
        requests and benefit most from KV cache reuse.

        Returns an empty string when there is no cacheable prefix.
        """
        if not messages:
            return ""
        prefix = messages[:-1] if messages[-1].get("role") == "user" else messages
        return self._hash_messages(prefix) if prefix else ""

    def store(self, messages: List[Dict], model_key: str,
              prompt_tokens: int, cached_tokens: int = 0) -> None:
        """Record a completed prompt so future requests can match against it."""
        # Hashing touches no shared state, so do it before taking the lock.
        msg_hash = self._hash_messages(messages)
        prefix_hash = self.get_prefix_key(messages)
        entry_key = f"{model_key}\x00{msg_hash}"
        with self._lock:
            self._entries[entry_key] = _CacheEntry(
                messages_hash=msg_hash,
                prefix_hash=prefix_hash,
                token_count=prompt_tokens,
            )
            self._by_model[model_key] = entry_key
            self._cached_tokens[model_key] = cached_tokens
            self._evict_locked()

    def get_cached_tokens(self, model_key: str) -> int:
        """Return the cached_tokens count stored by the last store() call for this model."""
        with self._lock:
            return self._cached_tokens.get(model_key, 0)

    def has_warm_prefix(self, messages: List[Dict], model_key: str) -> bool:
        """
        Return True if the current request shares a prefix with the last
        request processed by this model (i.e., the KV cache is likely warm).
        """
        with self._lock:
            last_key = self._by_model.get(model_key)
            if not last_key:
                return False
            entry = self._entries.get(last_key)
            # The entry may have been evicted, or be past its TTL.
            if not entry or time.time() - entry.timestamp > self._ttl:
                return False
            current_prefix = self.get_prefix_key(messages)
            return bool(current_prefix and current_prefix == entry.prefix_hash)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _evict_locked(self) -> None:
        """Drop expired entries, then the oldest ones beyond ``max_entries``.

        Must be called with ``self._lock`` held.
        """
        now = time.time()
        expired = [k for k, v in self._entries.items() if now - v.timestamp > self._ttl]
        for k in expired:
            del self._entries[k]
        overflow = len(self._entries) - self._max_entries
        if overflow > 0:
            by_age = sorted(self._entries, key=lambda k: self._entries[k].timestamp)
            for k in by_age[:overflow]:
                del self._entries[k]
prompt_cache_manager
=
PromptCacheManager
()
codai/api/text.py
View file @
07fb3251
...
@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
...
@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
# Import from codai modules
# Import from codai modules
from
codai.models.manager
import
ModelManager
,
WhisperServerManager
,
MultiModelManager
,
model_manager
,
multi_model_manager
from
codai.models.manager
import
ModelManager
,
WhisperServerManager
,
MultiModelManager
,
model_manager
,
multi_model_manager
from
codai.queue.manager
import
QueueManager
,
queue_manager
from
codai.queue.manager
import
QueueManager
,
queue_manager
from
codai.api.prompt_cache
import
prompt_cache_manager
from
codai.pydantic.textrequest
import
ChatCompletionRequest
,
ToolFunction
,
Tool
from
codai.pydantic.textrequest
import
ChatCompletionRequest
,
ToolFunction
,
Tool
from
codai.models.parser
import
filter_malformed_content
,
filter_repetition
,
format_tools_for_prompt
,
cleanup_control_tokens
,
OpenAIFormatter
,
ModelParserAdapter
,
ToolCallParser
from
codai.models.parser
import
filter_malformed_content
,
filter_repetition
,
format_tools_for_prompt
,
cleanup_control_tokens
,
OpenAIFormatter
,
ModelParserAdapter
,
ToolCallParser
...
@@ -1142,6 +1143,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
...
@@ -1142,6 +1143,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
from
fastapi.responses
import
JSONResponse
from
fastapi.responses
import
JSONResponse
return
JSONResponse
(
content
=
formatted_response
,
headers
=
headers
)
return
JSONResponse
(
content
=
formatted_response
,
headers
=
headers
)
# Compute prefix key for prompt-aggregation scheduling
_prefix_key
=
prompt_cache_manager
.
get_prefix_key
(
messages_dict
)
if
request
.
stream
:
if
request
.
stream
:
async
def
_managed_stream
():
async
def
_managed_stream
():
try
:
try
:
...
@@ -1156,6 +1160,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
...
@@ -1156,6 +1160,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
current_manager
,
current_manager
,
tool_parser
,
tool_parser
,
request
.
response_format
,
request
.
response_format
,
_prefix_key
,
):
):
yield
chunk
yield
chunk
finally
:
finally
:
...
@@ -1192,6 +1197,7 @@ async def stream_chat_response(
...
@@ -1192,6 +1197,7 @@ async def stream_chat_response(
current_manager
:
ModelManager
,
current_manager
:
ModelManager
,
tool_parser
:
ToolCallParser
,
tool_parser
:
ToolCallParser
,
response_format
:
Optional
[
Dict
]
=
None
,
response_format
:
Optional
[
Dict
]
=
None
,
prefix_key
:
str
=
""
,
)
->
AsyncGenerator
[
str
,
None
]:
)
->
AsyncGenerator
[
str
,
None
]:
"""Stream chat completion response with queue notifications."""
"""Stream chat completion response with queue notifications."""
completion_id
=
f
"chatcmpl-{uuid.uuid4().hex}"
completion_id
=
f
"chatcmpl-{uuid.uuid4().hex}"
...
@@ -1214,7 +1220,7 @@ async def stream_chat_response(
...
@@ -1214,7 +1220,7 @@ async def stream_chat_response(
# If model not loaded, add to queue and send waiting notifications
# If model not loaded, add to queue and send waiting notifications
if
not
model_loaded
:
if
not
model_loaded
:
await
queue_manager
.
add_waiting
(
request_id
)
await
queue_manager
.
add_waiting
(
request_id
,
prefix_key
=
prefix_key
)
wait_interval
=
2.0
# Send waiting update every 2 seconds
wait_interval
=
2.0
# Send waiting update every 2 seconds
last_wait_update
=
time
.
time
()
last_wait_update
=
time
.
time
()
...
@@ -1457,10 +1463,24 @@ async def stream_chat_response(
...
@@ -1457,10 +1463,24 @@ async def stream_chat_response(
prompt_text
=
"
\n
"
.
join
([
m
.
get
(
"content"
,
""
)
for
m
in
messages
])
prompt_text
=
"
\n
"
.
join
([
m
.
get
(
"content"
,
""
)
for
m
in
messages
])
prompt_tokens
=
len
(
prompt_text
.
split
())
prompt_tokens
=
len
(
prompt_text
.
split
())
completion_tokens
=
len
(
generated_text
.
split
())
if
generated_text
else
0
completion_tokens
=
len
(
generated_text
.
split
())
if
generated_text
else
0
# Read accurate usage (including cached_tokens) from the backend
_model_key_for_cache
=
getattr
(
current_manager
,
'model_name'
,
None
)
or
model_name
last_usage
=
(
current_manager
.
get_last_usage
()
if
hasattr
(
current_manager
,
'get_last_usage'
)
else
{})
if
last_usage
.
get
(
'prompt_tokens'
):
prompt_tokens
=
last_usage
[
'prompt_tokens'
]
if
last_usage
.
get
(
'completion_tokens'
):
completion_tokens
=
last_usage
[
'completion_tokens'
]
cached_tokens
=
last_usage
.
get
(
'cached_tokens'
,
0
)
# Store in prompt cache manager for future prefix matching
prompt_cache_manager
.
store
(
messages
,
_model_key_for_cache
,
prompt_tokens
,
cached_tokens
)
# Get context size
# Get context size
context_size
=
current_manager
.
get_context_size
()
context_size
=
current_manager
.
get_context_size
()
# Build complete final chunk with all OpenAI fields
# Build complete final chunk with all OpenAI fields
final_chunk
=
{
final_chunk
=
{
"id"
:
completion_id
,
"id"
:
completion_id
,
...
@@ -1479,7 +1499,7 @@ async def stream_chat_response(
...
@@ -1479,7 +1499,7 @@ async def stream_chat_response(
"total_tokens"
:
prompt_tokens
+
completion_tokens
,
"total_tokens"
:
prompt_tokens
+
completion_tokens
,
"context_size"
:
context_size
,
"context_size"
:
context_size
,
"prompt_tokens_details"
:
{
"prompt_tokens_details"
:
{
"cached_tokens"
:
0
,
"cached_tokens"
:
cached_tokens
,
"audio_tokens"
:
0
,
"audio_tokens"
:
0
,
},
},
"completion_tokens_details"
:
{
"completion_tokens_details"
:
{
...
@@ -1494,7 +1514,7 @@ async def stream_chat_response(
...
@@ -1494,7 +1514,7 @@ async def stream_chat_response(
"system_fingerprint"
:
None
,
"system_fingerprint"
:
None
,
}
}
yield
f
"data: {json.dumps(final_chunk)}
\n\n
"
yield
f
"data: {json.dumps(final_chunk)}
\n\n
"
yield
"data: [DONE]
\n\n
"
yield
"data: [DONE]
\n\n
"
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error during streaming generation: {e}"
)
print
(
f
"Error during streaming generation: {e}"
)
...
@@ -1638,11 +1658,20 @@ async def generate_chat_response(
...
@@ -1638,11 +1658,20 @@ async def generate_chat_response(
response_message
[
"tool_calls"
]
=
tool_calls
response_message
[
"tool_calls"
]
=
tool_calls
finish_reason
=
"tool_calls"
finish_reason
=
"tool_calls"
# Calculate token counts - rough estimate since we don't have direct access to tokenizer
# Read accurate usage (including cached_tokens) from the backend
_model_key_for_cache
=
getattr
(
current_manager
,
'model_name'
,
None
)
or
model_name
last_usage
=
(
current_manager
.
get_last_usage
()
if
hasattr
(
current_manager
,
'get_last_usage'
)
else
{})
prompt_text
=
"
\n
"
.
join
([
m
.
get
(
"content"
,
""
)
for
m
in
messages
])
prompt_text
=
"
\n
"
.
join
([
m
.
get
(
"content"
,
""
)
for
m
in
messages
])
prompt_tokens
=
len
(
prompt_text
.
split
())
prompt_tokens
=
last_usage
.
get
(
'prompt_tokens'
)
or
len
(
prompt_text
.
split
())
completion_tokens
=
len
(
generated_text
.
split
())
if
generated_text
else
0
completion_tokens
=
last_usage
.
get
(
'completion_tokens'
)
or
(
len
(
generated_text
.
split
())
if
generated_text
else
0
)
cached_tokens
=
last_usage
.
get
(
'cached_tokens'
,
0
)
# Store in prompt cache manager for future prefix matching
prompt_cache_manager
.
store
(
messages
,
_model_key_for_cache
,
prompt_tokens
,
cached_tokens
)
# Get context size
# Get context size
context_size
=
current_manager
.
get_context_size
()
context_size
=
current_manager
.
get_context_size
()
...
@@ -1655,6 +1684,10 @@ async def generate_chat_response(
...
@@ -1655,6 +1684,10 @@ async def generate_chat_response(
tool_calls
=
response_message
.
get
(
"tool_calls"
),
tool_calls
=
response_message
.
get
(
"tool_calls"
),
context_size
=
context_size
context_size
=
context_size
)
)
# Patch in the real cached_tokens value
if
formatted_response
and
'usage'
in
formatted_response
:
details
=
formatted_response
[
'usage'
]
.
setdefault
(
'prompt_tokens_details'
,
{})
details
[
'cached_tokens'
]
=
cached_tokens
# Add mock reasoning stats if 'mock' is in force_reasoning_args
# Add mock reasoning stats if 'mock' is in force_reasoning_args
# But only if we don't already have real reasoning in the response
# But only if we don't already have real reasoning in the response
...
...
codai/backends/cuda.py
View file @
07fb3251
This diff is collapsed.
Click to expand it.
codai/backends/vulkan.py
View file @
07fb3251
...
@@ -63,6 +63,7 @@ class VulkanBackend(ModelBackend):
...
@@ -63,6 +63,7 @@ class VulkanBackend(ModelBackend):
self
.
force_cuda
=
original_backend
in
(
"nvidia"
,
"cuda"
)
# Force CUDA if original was nvidia
self
.
force_cuda
=
original_backend
in
(
"nvidia"
,
"cuda"
)
# Force CUDA if original was nvidia
if
self
.
force_cuda
:
if
self
.
force_cuda
:
print
(
"DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)"
)
print
(
"DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)"
)
self
.
_last_usage
:
dict
=
{}
# usage from the most recent completion call
self
.
_detect_chat_template
()
self
.
_detect_chat_template
()
def
_detect_chat_template
(
self
):
def
_detect_chat_template
(
self
):
...
@@ -649,6 +650,8 @@ class VulkanBackend(ModelBackend):
...
@@ -649,6 +650,8 @@ class VulkanBackend(ModelBackend):
stop
=
stop
,
stop
=
stop
,
grammar
=
use_grammar
,
grammar
=
use_grammar
,
)
)
usage
=
result
.
get
(
'usage'
,
{})
self
.
_store_usage
(
usage
.
get
(
'prompt_tokens'
,
0
),
usage
.
get
(
'completion_tokens'
,
0
))
return
result
[
'choices'
][
0
][
'text'
]
return
result
[
'choices'
][
0
][
'text'
]
except
Exception
as
e
:
except
Exception
as
e
:
# If grammar generation fails, fall back to normal generation
# If grammar generation fails, fall back to normal generation
...
@@ -664,6 +667,8 @@ class VulkanBackend(ModelBackend):
...
@@ -664,6 +667,8 @@ class VulkanBackend(ModelBackend):
repeat_penalty
=
repeat_penalty
,
repeat_penalty
=
repeat_penalty
,
stop
=
stop
,
stop
=
stop
,
)
)
usage
=
result
.
get
(
'usage'
,
{})
self
.
_store_usage
(
usage
.
get
(
'prompt_tokens'
,
0
),
usage
.
get
(
'completion_tokens'
,
0
))
return
result
[
'choices'
][
0
][
'text'
]
return
result
[
'choices'
][
0
][
'text'
]
except
Exception
as
e2
:
except
Exception
as
e2
:
print
(
f
"Error during fallback generation: {e2}"
)
print
(
f
"Error during fallback generation: {e2}"
)
...
@@ -935,6 +940,112 @@ class VulkanBackend(ModelBackend):
...
@@ -935,6 +940,112 @@ class VulkanBackend(ModelBackend):
"n_gpu_layers"
:
self
.
n_gpu_layers
,
"n_gpu_layers"
:
self
.
n_gpu_layers
,
}
}
# ------------------------------------------------------------------
# Usage / cache helpers
# ------------------------------------------------------------------
def _read_cached_tokens(self, prompt_tokens: int) -> int:
    """Derive the cached-token count from the model's timings object.

    Looks for ``self.model.timings`` first and, when that is absent, calls
    ``self.model._ctx.timings()`` if available.  The result is
    ``prompt_tokens - n_p_eval`` clamped at zero (``n_p_eval`` is presumably
    the number of prompt tokens actually evaluated -- TODO confirm against
    the llama.cpp bindings in use).  Any failure, or missing timings,
    yields 0.
    """
    try:
        stats = getattr(self.model, 'timings', None)
        if stats is None:
            # Fall back to the internal context when no timings attribute exists.
            context = getattr(self.model, '_ctx', None)
            if context and hasattr(context, 'timings'):
                stats = context.timings()
        evaluated = None if stats is None else getattr(stats, 'n_p_eval', None)
        if evaluated is None:
            return 0
        return max(0, prompt_tokens - int(evaluated))
    except Exception:
        # Best-effort: usage reporting must never break a completion.
        return 0
def _store_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
    """Record token usage of the latest completion in ``self._last_usage``.

    The cached-token figure comes from ``self._read_cached_tokens``; the
    total is the sum of prompt and completion tokens.
    """
    reused = self._read_cached_tokens(prompt_tokens)
    snapshot = dict(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        cached_tokens=reused,
    )
    self._last_usage = snapshot
def get_last_usage(self) -> dict:
    """Return a shallow copy of the usage dict recorded by the most recent
    completion (includes cached_tokens); callers may mutate the result freely."""
    latest = self._last_usage
    return {**latest}
# ------------------------------------------------------------------
# Chat-level generation (uses llama.cpp native chat template)
# ------------------------------------------------------------------
def generate_chat(self, messages, max_tokens=None, temperature=0.7, top_p=1.0,
                  stop=None, tools=None, response_format=None):
    """Run one non-streaming chat completion and return the reply text.

    Calls ``self.model.create_chat_completion`` with the sampling settings,
    forwarding ``stop`` when given and a ``json_object`` response format when
    requested.  ``max_tokens`` falls back to 512.  ``tools`` is accepted but
    not forwarded.  Usage reported in the result is recorded through
    ``self._store_usage``.

    Raises:
        RuntimeError: if no model is loaded.
    """
    if self.model is None:
        raise RuntimeError("Model not loaded")

    call_args = {
        'messages': messages,
        'max_tokens': max_tokens or 512,
        'temperature': temperature,
        'top_p': top_p,
    }
    if stop:
        call_args['stop'] = stop
    wants_json = bool(response_format) and response_format.get('type') == 'json_object'
    if wants_json:
        call_args['response_format'] = {'type': 'json_object'}

    completion = self.model.create_chat_completion(**call_args)

    reported = completion.get('usage', {})
    self._store_usage(
        prompt_tokens=reported.get('prompt_tokens', 0),
        completion_tokens=reported.get('completion_tokens', 0),
    )

    message = completion['choices'][0]['message']
    # A missing or None content is normalised to an empty string.
    return message.get('content') or ''
async def generate_chat_stream(self, messages, max_tokens=None, temperature=0.7, top_p=1.0,
                               stop=None, tools=None, response_format=None):
    """Stream a chat completion, yielding each non-empty text delta.

    Calls ``self.model.create_chat_completion(stream=True, ...)`` and yields
    the ``content`` of every chunk delta.  ``stop`` is forwarded when given,
    and a ``json_object`` response format is forwarded exactly as the
    non-streaming ``generate_chat`` does.  ``max_tokens`` falls back to 512.
    ``tools`` is accepted but not forwarded.

    When the stream ends, or the consumer stops early, usage is recorded
    through ``self._store_usage``.  Completion tokens are counted as one per
    non-empty delta unless a chunk carries a ``usage`` dict.  Prompt tokens
    fall back to a whitespace word count of the messages when not reported.

    Raises:
        RuntimeError: if no model is loaded.
    """
    if self.model is None:
        raise RuntimeError("Model not loaded")

    kwargs = dict(
        messages=messages,
        max_tokens=max_tokens or 512,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    )
    if stop:
        kwargs['stop'] = stop
    if response_format and response_format.get('type') == 'json_object':
        # Keep streaming consistent with generate_chat(): honour JSON mode.
        kwargs['response_format'] = {'type': 'json_object'}

    prompt_tokens = 0
    completion_tokens = 0
    try:
        # NOTE(review): this iterates a synchronous generator inside a
        # coroutine, so each chunk presumably blocks the event loop while it
        # is produced -- confirm whether callers run this in a worker thread.
        for chunk in self.model.create_chat_completion(**kwargs):
            choice = chunk['choices'][0]
            text = choice.get('delta', {}).get('content') or ''
            if text:
                completion_tokens += 1
                yield text
            # Capture usage if present in final streaming chunk
            usage = chunk.get('usage')
            if usage:
                prompt_tokens = usage.get('prompt_tokens', 0)
                completion_tokens = usage.get('completion_tokens', completion_tokens)
            if choice.get('finish_reason'):
                break
    finally:
        # Runs on normal exhaustion, on error, and when the consumer closes
        # the generator early, so usage is always recorded.
        if prompt_tokens == 0:
            # Estimate from word split if llama.cpp didn't report; a None
            # content contributes nothing rather than the word "None".
            prompt_tokens = sum(
                len(str(m.get('content') or '').split()) for m in messages
            )
        self._store_usage(prompt_tokens, completion_tokens)
def
get_model_name
(
self
)
->
str
:
def
get_model_name
(
self
)
->
str
:
"""Return the loaded model name."""
"""Return the loaded model name."""
return
self
.
model_name
or
"unknown"
return
self
.
model_name
or
"unknown"
...
...
codai/models/manager.py
View file @
07fb3251
...
@@ -233,6 +233,12 @@ class ModelManager:
...
@@ -233,6 +233,12 @@ class ModelManager:
if
self
.
backend
is
not
None
:
if
self
.
backend
is
not
None
:
return
self
.
backend
.
get_context_size
()
return
self
.
backend
.
get_context_size
()
return
2048
# Default fallback
return
2048
# Default fallback
def get_last_usage(self) -> dict:
    """Fetch the usage record of the most recent call from the backend.

    Returns whatever the active backend's ``get_last_usage()`` reports
    (per the original docstring this includes ``cached_tokens``), or an
    empty dict when no backend is set or the backend lacks that method.
    """
    # Guard first: nothing to report without a backend that tracks usage.
    if self.backend is None or not hasattr(self.backend, 'get_last_usage'):
        return {}
    return self.backend.get_last_usage()
def
cleanup
(
self
):
def
cleanup
(
self
):
if
self
.
backend
is
not
None
:
if
self
.
backend
is
not
None
:
...
@@ -2040,11 +2046,29 @@ class MultiModelManager:
...
@@ -2040,11 +2046,29 @@ class MultiModelManager:
"embedding_models"
:
"embedding"
,
"embedding_models"
:
"embedding"
,
}
}
# Minimum capability guaranteed by a model's config category.
# Applied when heuristic name detection doesn't recognise the model ID.
TYPE_MIN_CAP
=
{
"image"
:
"image_generation"
,
"video"
:
"video_generation"
,
"audio"
:
"speech_to_text"
,
"tts"
:
"text_to_speech"
,
"audio_gen"
:
"audio_generation"
,
"embedding"
:
"embeddings"
,
}
def
_add
(
model_id
:
str
,
model_type
:
str
=
None
,
meta
:
Dict
[
str
,
Any
]
=
None
):
def
_add
(
model_id
:
str
,
model_type
:
str
=
None
,
meta
:
Dict
[
str
,
Any
]
=
None
):
if
model_id
in
seen_ids
:
if
model_id
in
seen_ids
:
return
return
seen_ids
.
add
(
model_id
)
seen_ids
.
add
(
model_id
)
caps
=
detect_model_capabilities
(
model_id
)
caps
=
detect_model_capabilities
(
model_id
)
# If heuristic detection missed the type (e.g. custom/vendor model IDs
# that don't match any keyword), ensure the minimum capability for the
# config-declared type is set so badges display correctly.
if
model_type
and
model_type
in
TYPE_MIN_CAP
:
min_cap
=
TYPE_MIN_CAP
[
model_type
]
if
not
getattr
(
caps
,
min_cap
,
False
):
setattr
(
caps
,
min_cap
,
True
)
resolved_type
=
model_type
or
(
caps
.
to_list
()[
0
]
.
split
(
"_"
)[
0
]
if
caps
.
to_list
()
else
"text"
)
resolved_type
=
model_type
or
(
caps
.
to_list
()[
0
]
.
split
(
"_"
)[
0
]
if
caps
.
to_list
()
else
"text"
)
meta
=
meta
or
{}
meta
=
meta
or
{}
models
.
append
(
ModelInfo
(
models
.
append
(
ModelInfo
(
...
@@ -2075,6 +2099,11 @@ class MultiModelManager:
...
@@ -2075,6 +2099,11 @@ class MultiModelManager:
else
:
else
:
raw
=
m
.
get
(
"path"
)
or
m
.
get
(
"id"
)
or
""
raw
=
m
.
get
(
"path"
)
or
m
.
get
(
"id"
)
or
""
alias
=
m
.
get
(
"alias"
)
or
""
alias
=
m
.
get
(
"alias"
)
or
""
# Auto-derive a clean alias for GGUF files that have no
# explicit alias so the full filesystem path isn't exposed.
if
not
alias
and
raw
.
lower
()
.
endswith
(
".gguf"
):
stem
=
raw
.
split
(
"/"
)[
-
1
][:
-
5
]
# filename without .gguf
alias
=
stem
# whisper-server aliases are round-robin group keys shared across
# whisper-server aliases are round-robin group keys shared across
# multiple instances — don't expose the alias as a separate model
# multiple instances — don't expose the alias as a separate model
if
m
.
get
(
"backend"
)
==
"whisper-server"
:
if
m
.
get
(
"backend"
)
==
"whisper-server"
:
...
...
codai/pydantic/textrequest.py
View file @
07fb3251
...
@@ -39,6 +39,7 @@ class ChatMessage(BaseModel):
...
@@ -39,6 +39,7 @@ class ChatMessage(BaseModel):
name
:
Optional
[
str
]
=
None
name
:
Optional
[
str
]
=
None
tool_calls
:
Optional
[
List
[
Dict
]]
=
None
tool_calls
:
Optional
[
List
[
Dict
]]
=
None
tool_call_id
:
Optional
[
str
]
=
None
tool_call_id
:
Optional
[
str
]
=
None
cache_control
:
Optional
[
Dict
]
=
None
# OpenAI-style: {"type": "ephemeral"}
@
field_validator
(
'content'
,
mode
=
'before'
)
@
field_validator
(
'content'
,
mode
=
'before'
)
@
classmethod
@
classmethod
...
...
codai/queue/manager.py
View file @
07fb3251
...
@@ -40,6 +40,7 @@ class WaitingRequest:
...
@@ -40,6 +40,7 @@ class WaitingRequest:
sequence
:
int
sequence
:
int
event
:
asyncio
.
Event
=
field
(
default_factory
=
asyncio
.
Event
)
event
:
asyncio
.
Event
=
field
(
default_factory
=
asyncio
.
Event
)
bypassed_by
:
int
=
0
bypassed_by
:
int
=
0
prefix_key
:
str
=
""
# stable hash of the cacheable prompt prefix
class
QueueManager
:
class
QueueManager
:
...
@@ -61,6 +62,7 @@ class QueueManager:
...
@@ -61,6 +62,7 @@ class QueueManager:
self
.
model_name
:
Optional
[
str
]
=
None
self
.
model_name
:
Optional
[
str
]
=
None
self
.
_processing
:
bool
=
False
self
.
_processing
:
bool
=
False
self
.
_ready_request_ids
:
Set
[
str
]
=
set
()
self
.
_ready_request_ids
:
Set
[
str
]
=
set
()
self
.
_last_prefix_key
:
str
=
""
# prefix key of the last completed request
def
set_loaded_models
(
self
,
model_keys
:
Set
[
str
])
->
None
:
def
set_loaded_models
(
self
,
model_keys
:
Set
[
str
])
->
None
:
self
.
loaded_models
=
set
(
model_keys
)
self
.
loaded_models
=
set
(
model_keys
)
...
@@ -83,17 +85,19 @@ class QueueManager:
...
@@ -83,17 +85,19 @@ class QueueManager:
self
.
model_name
=
None
self
.
model_name
=
None
self
.
_processing
=
False
self
.
_processing
=
False
self
.
_ready_request_ids
.
clear
()
self
.
_ready_request_ids
.
clear
()
self
.
_last_prefix_key
=
""
async
def
is_full
(
self
)
->
bool
:
async
def
is_full
(
self
)
->
bool
:
async
with
self
.
lock
:
async
with
self
.
lock
:
return
len
(
self
.
waiting
)
>=
self
.
max_size
return
len
(
self
.
waiting
)
>=
self
.
max_size
async
def
acquire
(
self
,
request_id
:
str
,
model_key
:
str
)
->
SchedulerLease
:
async
def
acquire
(
self
,
request_id
:
str
,
model_key
:
str
,
prefix_key
:
str
=
""
)
->
SchedulerLease
:
waiter
=
None
waiter
=
None
async
with
self
.
lock
:
async
with
self
.
lock
:
if
self
.
_can_start_now
(
model_key
):
if
self
.
_can_start_now
(
model_key
):
return
self
.
_grant_lease
(
request_id
,
model_key
)
return
self
.
_grant_lease
(
request_id
,
model_key
)
waiter
=
self
.
_enqueue_waiter
(
request_id
,
model_key
)
waiter
=
self
.
_enqueue_waiter
(
request_id
,
model_key
,
prefix_key
)
await
waiter
.
event
.
wait
()
await
waiter
.
event
.
wait
()
async
with
self
.
lock
:
async
with
self
.
lock
:
...
@@ -103,7 +107,8 @@ class QueueManager:
...
@@ -103,7 +107,8 @@ class QueueManager:
lease
.
wait_time_seconds
=
max
(
0.0
,
time
.
time
()
-
waiter
.
enqueued_at
)
lease
.
wait_time_seconds
=
max
(
0.0
,
time
.
time
()
-
waiter
.
enqueued_at
)
return
lease
return
lease
async
def
release
(
self
,
lease
:
SchedulerLease
)
->
None
:
async
def
release
(
self
,
lease
:
SchedulerLease
,
prefix_key
:
str
=
""
)
->
None
:
async
with
self
.
lock
:
async
with
self
.
lock
:
self
.
active_leases
.
pop
(
lease
.
request_id
,
None
)
self
.
active_leases
.
pop
(
lease
.
request_id
,
None
)
current
=
self
.
active_by_model
.
get
(
lease
.
model_key
,
0
)
current
=
self
.
active_by_model
.
get
(
lease
.
model_key
,
0
)
...
@@ -113,14 +118,17 @@ class QueueManager:
...
@@ -113,14 +118,17 @@ class QueueManager:
self
.
active_by_model
[
lease
.
model_key
]
=
current
-
1
self
.
active_by_model
[
lease
.
model_key
]
=
current
-
1
if
self
.
current_request_id
==
lease
.
request_id
:
if
self
.
current_request_id
==
lease
.
request_id
:
self
.
current_request_id
=
None
self
.
current_request_id
=
None
if
prefix_key
:
self
.
_last_prefix_key
=
prefix_key
self
.
_processing
=
bool
(
self
.
active_leases
)
self
.
_processing
=
bool
(
self
.
active_leases
)
self
.
_wake_waiters_locked
()
self
.
_wake_waiters_locked
()
async
def
add_waiting
(
self
,
request_id
:
str
,
model_key
:
str
=
""
)
->
None
:
async
def
add_waiting
(
self
,
request_id
:
str
,
model_key
:
str
=
""
,
prefix_key
:
str
=
""
)
->
None
:
async
with
self
.
lock
:
async
with
self
.
lock
:
if
request_id
in
self
.
waiting_by_id
:
if
request_id
in
self
.
waiting_by_id
:
return
return
self
.
_enqueue_waiter
(
request_id
,
model_key
or
request_id
)
self
.
_enqueue_waiter
(
request_id
,
model_key
or
request_id
,
prefix_key
)
async
def
remove_waiting
(
self
,
request_id
:
str
)
->
None
:
async
def
remove_waiting
(
self
,
request_id
:
str
)
->
None
:
async
with
self
.
lock
:
async
with
self
.
lock
:
...
@@ -172,13 +180,15 @@ class QueueManager:
...
@@ -172,13 +180,15 @@ class QueueManager:
"loaded_models"
:
sorted
(
self
.
loaded_models
),
"loaded_models"
:
sorted
(
self
.
loaded_models
),
}
}
def
_enqueue_waiter
(
self
,
request_id
:
str
,
model_key
:
str
)
->
WaitingRequest
:
def
_enqueue_waiter
(
self
,
request_id
:
str
,
model_key
:
str
,
prefix_key
:
str
=
""
)
->
WaitingRequest
:
self
.
sequence
+=
1
self
.
sequence
+=
1
waiter
=
WaitingRequest
(
waiter
=
WaitingRequest
(
request_id
=
request_id
,
request_id
=
request_id
,
model_key
=
model_key
,
model_key
=
model_key
,
enqueued_at
=
time
.
time
(),
enqueued_at
=
time
.
time
(),
sequence
=
self
.
sequence
,
sequence
=
self
.
sequence
,
prefix_key
=
prefix_key
,
)
)
self
.
waiting
.
append
(
waiter
)
self
.
waiting
.
append
(
waiter
)
self
.
waiting_by_id
[
request_id
]
=
waiter
self
.
waiting_by_id
[
request_id
]
=
waiter
...
@@ -233,17 +243,39 @@ class QueueManager:
...
@@ -233,17 +243,39 @@ class QueueManager:
return
return
def
_pick_next_waiter_locked
(
self
)
->
Optional
[
WaitingRequest
]:
def
_pick_next_waiter_locked
(
self
)
->
Optional
[
WaitingRequest
]:
for
waiter
in
self
.
waiting
:
# Collect all candidates that can start now.
if
self
.
_waiter_can_start_locked
(
waiter
):
candidates
=
[
w
for
w
in
self
.
waiting
if
self
.
_waiter_can_start_locked
(
w
)]
older_blocked
=
[
if
not
candidates
:
other
for
other
in
self
.
waiting
return
None
if
other
.
sequence
<
waiter
.
sequence
and
not
self
.
_waiter_can_start_locked
(
other
)
]
# Fairness: don't bypass an older waiter more than the limit.
if
any
(
other
.
bypassed_by
>=
self
.
fairness_bypass_limit
for
other
in
older_blocked
):
def
_is_fair
(
waiter
:
WaitingRequest
)
->
bool
:
continue
older_blocked
=
[
for
other
in
older_blocked
:
other
for
other
in
self
.
waiting
other
.
bypassed_by
+=
1
if
other
.
sequence
<
waiter
.
sequence
and
not
self
.
_waiter_can_start_locked
(
other
)
]
if
any
(
other
.
bypassed_by
>=
self
.
fairness_bypass_limit
for
other
in
older_blocked
):
return
False
for
other
in
older_blocked
:
other
.
bypassed_by
+=
1
return
True
# Prompt aggregation: prefer candidates whose prefix key matches the
# last completed request — they will hit a warm KV cache.
if
self
.
_last_prefix_key
:
warm_candidates
=
[
w
for
w
in
candidates
if
w
.
prefix_key
and
w
.
prefix_key
==
self
.
_last_prefix_key
]
for
waiter
in
warm_candidates
:
if
_is_fair
(
waiter
):
return
waiter
# Fall back to FIFO order.
for
waiter
in
candidates
:
if
_is_fair
(
waiter
):
return
waiter
return
waiter
return
None
return
None
def
_waiting_counts_locked
(
self
)
->
Dict
[
str
,
int
]:
def
_waiting_counts_locked
(
self
)
->
Dict
[
str
,
int
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment