Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in
Toggle navigation
C
coderai
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
nexlab
coderai
Commits
07fb3251
Commit
07fb3251
authored
May 07, 2026
by
Stefy Lanza (nextime / spora )
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add prompt caching and prompt aggregation
parent
0ac26bed
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
939 additions
and
70 deletions
+939
-70
README.md
README.md
+8
-0
chat.html
codai/admin/templates/chat.html
+109
-15
images.py
codai/api/images.py
+155
-30
prompt_cache.py
codai/api/prompt_cache.py
+151
-0
text.py
codai/api/text.py
+42
-9
cuda.py
codai/backends/cuda.py
+285
-0
vulkan.py
codai/backends/vulkan.py
+111
-0
manager.py
codai/models/manager.py
+29
-0
textrequest.py
codai/pydantic/textrequest.py
+1
-0
manager.py
codai/queue/manager.py
+48
-16
No files found.
README.md
View file @
07fb3251
...
...
@@ -13,6 +13,14 @@ An OpenAI-compatible API server to run models on your local GPU with web adminis
-
**Multi-Modal**
: Text, image, video, audio, TTS, STT, embeddings
-
**Per-Model Configuration**
: Individual settings for each model (GPU layers, quantization, context size)
-
**On-Demand Loading**
: Models load automatically when requested, unload when idle
-
**Memory Management**
: Smart VRAM → RAM → Disk offloading for efficient resource usage
-
**Parallel Execution**
: Run multiple models simultaneously (VRAM permitting)
-
**Auto-Swap**
: Automatic model switching on request — load what's needed, unload what's idle
-
**Request Queue**
: Concurrent requests are queued and processed in order per model
-
**Prompt Caching**
: Reuse KV cache across requests to reduce latency and computation
-
**Prompt Aggregation**
: Batch concurrent requests into a single inference pass for higher throughput
-
**Custom Pipelines**
: Create and save multi-step workflows combining any generation tasks
-
**Pre-Built Pipelines**
: Ready-to-use pipelines for common workflows (image-to-video, dubbing, story generation)
### GPU Backend Support
-
**NVIDIA (CUDA)**
: PyTorch + Transformers for HuggingFace models
...
...
codai/admin/templates/chat.html
View file @
07fb3251
...
...
@@ -256,6 +256,27 @@ a.dl { display:inline-block; margin-top:.4rem; }
.pb-step-header
{
display
:
flex
;
align-items
:
center
;
gap
:
.4rem
;
font-size
:
12px
;
font-weight
:
600
;
}
.pb-step-params
{
display
:
flex
;
flex-direction
:
column
;
gap
:
.25rem
;
padding-top
:
.25rem
;
}
.pb-step-param
{
display
:
flex
;
align-items
:
center
;
gap
:
.4rem
;
font-size
:
12px
;
}
/* ── Pipeline capability chips ────────────────────────────────── */
.pipe-caps
{
display
:
flex
;
flex-wrap
:
wrap
;
align-items
:
center
;
gap
:
.3rem
;
padding-bottom
:
.6rem
;
margin-bottom
:
.25rem
;
border-bottom
:
1px
solid
var
(
--border
);
}
.pipe-caps-label
{
font-size
:
10px
;
color
:
var
(
--text-3
);
text-transform
:
uppercase
;
letter-spacing
:
.05em
;
margin-right
:
.1rem
;
}
.pipe-cap-chip
{
font-size
:
10px
;
padding
:
.1rem
.4rem
;
border-radius
:
4px
;
border
:
1px
solid
transparent
;
}
.pipe-cap-chip.ok
{
background
:
#0d2e18
;
color
:
#4ade80
;
border-color
:
#1a4a1a
;
}
.pipe-cap-chip.missing
{
background
:
#2e0d0d
;
color
:
#f07070
;
border-color
:
#5a1a1a
;
}
.pipe-cap-chip.optional
{
background
:
var
(
--surface-2
);
color
:
var
(
--text-3
);
}
.pipe-cap-chip.optional.ok
{
background
:
#1a1f0a
;
color
:
#c0d060
;
border-color
:
#3a4a10
;
}
/* ── Sub-tab model picker ─────────────────────────────────────── */
.cap-model-picker
{
display
:
flex
;
align-items
:
center
;
gap
:
.5rem
;
flex-wrap
:
wrap
;
padding-top
:
.4rem
;
border-top
:
1px
solid
var
(
--border
);
margin-top
:
.1rem
;
}
.cap-model-picker-label
{
font-size
:
10px
;
color
:
var
(
--text-3
);
text-transform
:
uppercase
;
letter-spacing
:
.05em
;
white-space
:
nowrap
;
}
.cap-model-chips
{
display
:
flex
;
flex-wrap
:
wrap
;
gap
:
.3rem
;
}
.cap-model-chip
{
font-size
:
11px
;
padding
:
.18rem
.55rem
;
border-radius
:
999px
;
border
:
1px
solid
;
cursor
:
pointer
;
background
:
transparent
;
font-family
:
inherit
;
transition
:
background
.15s
;
}
.cap-model-chip.ok
{
color
:
#4ade80
;
border-color
:
#2a5a2a
;
background
:
#0f2a0f
;
}
.cap-model-chip.ok
:hover
{
background
:
#1a3f1a
;
}
.cap-model-chip.warn
{
color
:
#f0c060
;
border-color
:
#5a3a10
;
background
:
#2a1f0a
;
}
.cap-model-chip.warn
:hover
{
background
:
#3a2f0a
;
}
.cap-model-chip.active
{
font-weight
:
700
;
outline
:
1px
solid
currentColor
;
outline-offset
:
1px
;
}
/* ── Sidebar capability highlights ───────────────────────────── */
.model-item.cap-ok
{
border-left
:
3px
solid
#4caf50
;
padding-left
:
calc
(
.6rem
-
3px
);
}
.model-item.cap-partial
{
border-left
:
3px
solid
#f0a020
;
padding-left
:
calc
(
.6rem
-
3px
);
}
.pb-step-param
label
{
min-width
:
110px
;
color
:
var
(
--text-2
);
flex-shrink
:
0
;
}
.pb-step-param
input
,
.pb-step-param
select
,
.pb-step-param
textarea
{
flex
:
1
;
font-size
:
12px
;
}
.pb-step-param
textarea
{
rows
:
2
;
resize
:
vertical
;
min-height
:
40px
;
}
...
...
@@ -1688,19 +1709,42 @@ function updatePipelineBadges() {
}
badge
.
className
=
`pipe-badge
${
state
}
`
;
badge
.
textContent
=
label
;
// Capability chips in the card body
const
body
=
card
.
querySelector
(
'.pipe-card-body'
);
if
(
body
)
{
let
capsDiv
=
body
.
querySelector
(
'.pipe-caps'
);
if
(
!
capsDiv
)
{
capsDiv
=
document
.
createElement
(
'div'
);
capsDiv
.
className
=
'pipe-caps'
;
body
.
insertBefore
(
capsDiv
,
body
.
firstChild
);
}
const
fmt
=
c
=>
c
.
replace
(
/_/g
,
' '
);
const
reqChips
=
reqs
.
map
(
c
=>
`<span class="pipe-cap-chip
${
allCaps
.
has
(
c
)
?
' ok'
:
' missing'
}
" title="
${
c
}
">
${
fmt
(
c
)}
</span>`
).
join
(
''
);
const
optChips
=
opts
.
map
(
c
=>
`<span class="pipe-cap-chip optional
${
allCaps
.
has
(
c
)
?
' ok'
:
''
}
" title="
${
c
}
">
${
fmt
(
c
)}
</span>`
).
join
(
''
);
const
parts
=
[];
if
(
reqChips
)
parts
.
push
(
`<span class="pipe-caps-label">Requires</span>
${
reqChips
}
`
);
if
(
optChips
)
parts
.
push
(
`<span class="pipe-caps-label">Optional</span>
${
optChips
}
`
);
capsDiv
.
innerHTML
=
parts
.
join
(
''
);
}
});
}
function
hasTypeFallback
(
rule
,
type
)
{
return
Array
.
isArray
(
rule
.
fallbackTypes
)
&&
rule
.
fallbackTypes
.
includes
(
type
);
function
hasTypeFallback
(
rule
,
typeOrTypes
)
{
if
(
!
Array
.
isArray
(
rule
.
fallbackTypes
))
return
false
;
if
(
typeOrTypes
instanceof
Set
)
return
rule
.
fallbackTypes
.
some
(
t
=>
typeOrTypes
.
has
(
t
));
return
rule
.
fallbackTypes
.
includes
(
typeOrTypes
);
}
function
evaluateSubCapability
(
rule
,
caps
,
type
)
{
function
evaluateSubCapability
(
rule
,
caps
,
type
OrTypes
)
{
const
required
=
rule
.
requiresAny
||
[];
const
optional
=
rule
.
optional
||
[];
const
hasRequired
=
required
.
some
(
cap
=>
caps
.
has
(
cap
));
const
hasOptional
=
optional
.
some
(
cap
=>
caps
.
has
(
cap
));
const
fallback
=
hasTypeFallback
(
rule
,
type
);
const
fallback
=
hasTypeFallback
(
rule
,
type
OrTypes
);
if
(
hasRequired
)
return
'available'
;
if
(
!
required
.
length
&&
(
hasOptional
||
fallback
))
return
'partial'
;
...
...
@@ -1754,7 +1798,7 @@ function getSubtabState(sub) {
function
getCapabilityDetails
(
sub
)
{
const
def
=
STUDIO_CAPABILITIES
[
sub
];
if
(
!
def
)
return
null
;
const
caps
=
c
apabilitySetForModel
(
activeModel
);
const
caps
=
c
rossModelCaps
(
);
const
required
=
def
.
requires
||
[];
const
optional
=
def
.
optional
||
[];
const
missingRequired
=
required
.
filter
(
cap
=>
!
caps
.
has
(
cap
));
...
...
@@ -1802,6 +1846,7 @@ function renderCapabilityCard(sub) {
</div>
${
missingBits
.
join
(
''
)}
${
notes
}
${
renderSubModelPicker
(
sub
)}
`
;
}
...
...
@@ -1813,14 +1858,25 @@ function renderCapabilityCards() {
const
shell
=
$
(
`cap-
${
sub
}
`
);
if
(
!
shell
)
return
;
const
state
=
currentTabState
.
subs
[
sub
]
||
'unavailable'
;
if
(
state
===
'available'
)
{
shell
.
style
.
display
=
'none'
;
shell
.
innerHTML
=
''
;
return
;
}
const
picker
=
renderSubModelPicker
(
sub
);
if
(
state
===
'available'
)
{
if
(
picker
)
{
shell
.
style
.
display
=
''
;
shell
.
classList
.
remove
(
'state-partial'
,
'state-unavailable'
);
shell
.
innerHTML
=
picker
;
}
else
{
shell
.
style
.
display
=
'none'
;
shell
.
innerHTML
=
''
;
}
return
;
}
shell
.
style
.
display
=
''
;
shell
.
classList
.
remove
(
'state-partial'
,
'state-unavailable'
);
shell
.
classList
.
add
(
state
===
'partial'
?
'state-partial'
:
'state-unavailable'
);
const
rule
=
SUB_CAPABILITY_RULES
[
sub
];
const
caps
=
capabilitySetForModel
(
activeModel
);
const
missingRequired
=
(
rule
.
requiresAny
||
[]).
filter
(
c
=>
!
c
aps
.
has
(
c
));
const
missingOptional
=
(
rule
.
optional
||
[]).
filter
(
c
=>
!
c
aps
.
has
(
c
));
const
allCaps
=
crossModelCaps
(
);
const
missingRequired
=
(
rule
.
requiresAny
||
[]).
filter
(
c
=>
!
allC
aps
.
has
(
c
));
const
missingOptional
=
(
rule
.
optional
||
[]).
filter
(
c
=>
!
allC
aps
.
has
(
c
));
const
label
=
document
.
querySelector
(
`.t2btn[data-sub="
${
sub
}
"]`
)?.
childNodes
[
0
]?.
textContent
?.
trim
()
||
sub
;
const
missingBits
=
[];
if
(
missingRequired
.
length
)
missingBits
.
push
(
`<div class="cap-missing"><strong>Missing required:</strong>
${
missingRequired
.
join
(
', '
)}
</div>`
);
...
...
@@ -1833,6 +1889,7 @@ function renderCapabilityCards() {
<span class="cap-chip
${
availabilityClass
}
">
${
availabilityLabel
}
</span>
</div>
${
missingBits
.
join
(
''
)}
${
picker
}
`
;
});
renderAudioBackendHealth
();
...
...
@@ -2471,22 +2528,21 @@ function selectModel(m) {
}
function
updateTabs
(
m
)
{
const
caps
=
capabilitySetForModel
(
m
);
const
allCaps
=
crossModelCaps
();
const
allTypes
=
new
Set
(
models
.
map
(
mdl
=>
mdl
.
type
||
'text'
));
const
type
=
m
.
type
||
'text'
;
refreshAudioBackendHealth
();
const
subStates
=
{};
Object
.
entries
(
SUB_CAPABILITY_RULES
).
forEach
(([
sub
,
rule
])
=>
{
if
(
VIDEO_EXTRA_SUBS
.
includes
(
sub
)
&&
type
===
'video'
&&
!
rule
.
fallbackTypes
)
{
if
(
VIDEO_EXTRA_SUBS
.
includes
(
sub
)
&&
allTypes
.
has
(
'video'
)
&&
!
rule
.
fallbackTypes
)
{
rule
=
Object
.
assign
({},
rule
,
{
fallbackTypes
:[
'video'
]
});
}
const
evalCaps
=
CROSS_MODEL_SUBS
.
has
(
sub
)
?
allCaps
:
caps
;
subStates
[
sub
]
=
evaluateSubCapability
(
rule
,
evalCaps
,
type
);
subStates
[
sub
]
=
evaluateSubCapability
(
rule
,
allCaps
,
allTypes
);
});
updatePipelineBadges
();
const
categoryStates
=
{};
CATEGORY_TABS
.
forEach
(
cat
=>
{
categoryStates
[
cat
]
=
evaluateCategoryState
(
cat
,
subStates
,
c
aps
,
type
);
categoryStates
[
cat
]
=
evaluateCategoryState
(
cat
,
subStates
,
allC
aps
,
type
);
});
currentTabState
=
{
categories
:
categoryStates
,
subs
:
subStates
};
...
...
@@ -2496,7 +2552,7 @@ function updateTabs(m) {
document
.
querySelectorAll
(
'.t2btn'
).
forEach
(
btn
=>
{
setTabVisualState
(
btn
,
subStates
[
btn
.
dataset
.
sub
]
||
'unavailable'
);
});
$
(
'attach-btn'
).
style
.
display
=
cap
s
.
has
(
'image_to_text'
)
?
''
:
'none'
;
$
(
'attach-btn'
).
style
.
display
=
cap
abilitySetForModel
(
m
)
.
has
(
'image_to_text'
)
?
''
:
'none'
;
renderCapabilityCards
();
renderDiagnostics
();
renderOutputCapabilityNotes
();
...
...
@@ -2510,6 +2566,7 @@ function selectCat(cat) {
const
hasL2
=
[
'image'
,
'video'
,
'audio'
].
includes
(
cat
);
$
(
'tabbar2'
).
classList
.
toggle
(
'visible'
,
hasL2
);
if
(
!
hasL2
)
{
clearSidebarHighlights
();
document
.
querySelectorAll
(
'.panel'
).
forEach
(
p
=>
p
.
classList
.
remove
(
'active'
));
const
panel
=
$
(
'panel-'
+
cat
);
if
(
panel
)
panel
.
classList
.
add
(
'active'
);
...
...
@@ -2526,6 +2583,42 @@ function selectCat(cat) {
if
(
nextSub
)
selectSub
(
nextSub
);
}
function
modelsForSub
(
sub
)
{
const
rule
=
SUB_CAPABILITY_RULES
[
sub
];
if
(
!
rule
)
return
[];
return
models
.
map
(
m
=>
{
const
state
=
evaluateSubCapability
(
rule
,
capabilitySetForModel
(
m
),
m
.
type
||
'text'
);
return
{
model
:
m
,
state
};
}).
filter
(
item
=>
item
.
state
!==
'unavailable'
);
}
function
renderSubModelPicker
(
sub
)
{
const
compatible
=
modelsForSub
(
sub
);
if
(
!
compatible
.
length
)
return
''
;
const
chips
=
compatible
.
map
(({
model
,
state
})
=>
{
const
cls
=
state
===
'available'
?
'ok'
:
'warn'
;
const
isActive
=
activeModel
&&
model
.
id
===
activeModel
.
id
;
const
label
=
escapeHtml
(
model
.
id
.
split
(
'/'
).
pop
());
const
safe
=
JSON
.
stringify
(
model
).
replace
(
/"/g
,
'"'
);
return
`<button class="cap-model-chip
${
cls
}${
isActive
?
' active'
:
''
}
" onclick="selectModel(
${
safe
}
)" title="
${
model
.
id
}
">
${
label
}
</button>`
;
}).
join
(
''
);
return
`<div class="cap-model-picker"><span class="cap-model-picker-label">Models</span><div class="cap-model-chips">
${
chips
}
</div></div>`
;
}
function
highlightSidebarForSub
(
sub
)
{
const
compatible
=
new
Map
(
modelsForSub
(
sub
).
map
(({
model
,
state
})
=>
[
model
.
id
,
state
]));
document
.
querySelectorAll
(
'.model-item'
).
forEach
(
el
=>
{
el
.
classList
.
remove
(
'cap-ok'
,
'cap-partial'
);
const
state
=
compatible
.
get
(
el
.
dataset
.
id
);
if
(
state
===
'available'
)
el
.
classList
.
add
(
'cap-ok'
);
else
if
(
state
===
'partial'
)
el
.
classList
.
add
(
'cap-partial'
);
});
}
function
clearSidebarHighlights
()
{
document
.
querySelectorAll
(
'.model-item'
).
forEach
(
el
=>
el
.
classList
.
remove
(
'cap-ok'
,
'cap-partial'
));
}
function
selectSub
(
sub
)
{
if
(
SUB_CAT
[
sub
]
&&
currentTabState
.
subs
[
sub
]
===
undefined
)
return
;
if
(
SUB_CAT
[
sub
])
{
...
...
@@ -2543,6 +2636,7 @@ function selectSub(sub) {
// When switching to vid-faceswap, pre-select video mode
if
(
sub
===
'vid-faceswap'
)
{
const
t
=
$
(
'fs-type'
);
if
(
t
)
{
t
.
value
=
'video'
;
fsFaceSwapTypeChange
();
}
}
if
(
sub
===
'vid-outfit'
)
{
const
t
=
$
(
'ot-type'
);
if
(
t
)
{
t
.
value
=
'video'
;
otOutfitTypeChange
();
}
}
if
(
SUB_CAT
[
sub
])
highlightSidebarForSub
(
sub
);
}
// ─────────────────────────────────────────────────────────────────
...
...
codai/api/images.py
View file @
07fb3251
...
...
@@ -39,6 +39,74 @@ from codai.pydantic.imagerequest import ImageGenerationRequest
from
codai.api.state
import
get_load_mode
# =============================================================================
# Prompt embedding cache (diffusers)
#
# Caches text-encoder outputs keyed by (prompt, negative_prompt, model_name).
# When the same prompt is requested again the encode step is skipped and the
# cached tensors are passed directly to the pipeline, saving CLIP/T5 compute.
# sd.cpp handles encoding internally — no equivalent caching is possible there.
# =============================================================================
import
hashlib
as
_hashlib
import
threading
as
_threading
class
_PromptEmbedCache
:
"""Single-entry LRU cache for diffusers prompt embeddings."""
_MAX_ENTRIES
=
32
_TTL
=
600.0
# 10 minutes
def
__init__
(
self
):
self
.
_store
:
dict
=
{}
# key -> (embeds_dict, timestamp)
self
.
_lock
=
_threading
.
Lock
()
@
staticmethod
def
_key
(
prompt
:
str
,
negative_prompt
:
str
,
model_name
:
str
)
->
str
:
raw
=
f
"{model_name}
\x00
{prompt}
\x00
{negative_prompt or ''}"
return
_hashlib
.
sha256
(
raw
.
encode
())
.
hexdigest
()[:
24
]
def
get
(
self
,
prompt
:
str
,
negative_prompt
:
str
,
model_name
:
str
)
->
Optional
[
dict
]:
k
=
self
.
_key
(
prompt
,
negative_prompt
,
model_name
)
with
self
.
_lock
:
entry
=
self
.
_store
.
get
(
k
)
if
entry
is
None
:
return
None
embeds
,
ts
=
entry
if
time
.
time
()
-
ts
>
self
.
_TTL
:
del
self
.
_store
[
k
]
return
None
return
embeds
def
put
(
self
,
prompt
:
str
,
negative_prompt
:
str
,
model_name
:
str
,
embeds
:
dict
)
->
None
:
k
=
self
.
_key
(
prompt
,
negative_prompt
,
model_name
)
with
self
.
_lock
:
self
.
_store
[
k
]
=
(
embeds
,
time
.
time
())
# Evict oldest if over limit
if
len
(
self
.
_store
)
>
self
.
_MAX_ENTRIES
:
oldest
=
min
(
self
.
_store
,
key
=
lambda
x
:
self
.
_store
[
x
][
1
])
del
self
.
_store
[
oldest
]
def
invalidate_model
(
self
,
model_name
:
str
)
->
None
:
"""Drop all entries for a model (e.g. on pipeline unload)."""
suffix
=
_hashlib
.
sha256
(
model_name
.
encode
())
.
hexdigest
()[:
8
]
with
self
.
_lock
:
drop
=
[
k
for
k
in
self
.
_store
if
self
.
_key
(
""
,
""
,
model_name
)[:
8
]
==
k
[:
8
]
or
True
# safest: just rebuild key and compare
]
# Rebuild properly: iterate and check by re-computing key prefix
# (can't reconstruct original prompts, so use model name hash marker)
self
.
_store
=
{
k
:
v
for
k
,
v
in
self
.
_store
.
items
()
if
not
k
.
startswith
(
_hashlib
.
sha256
(
model_name
.
encode
())
.
hexdigest
()[:
4
])
}
_embed_cache
=
_PromptEmbedCache
()
# Global reference to be set by coderai
global_args
=
None
global_file_path
=
None
...
...
@@ -384,7 +452,7 @@ def _load_diffusers_pipeline(model_name: str, global_args):
def
_generate_with_diffusers
(
pipeline
,
request
,
global_args
,
http_request
=
None
):
"""Generate images using a diffusers pipeline."""
"""Generate images using a diffusers pipeline
(with prompt-embedding cache)
."""
import
torch
import
numpy
as
np
import
time
as
time_module
...
...
@@ -402,13 +470,12 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
height
=
int
(
parts
[
1
])
except
ValueError
:
pass
# Check for nan/inf in dimensions
if
width
!=
width
or
width
==
float
(
'inf'
):
width
=
512
if
height
!=
height
or
height
==
float
(
'inf'
):
height
=
512
# Enable memory optimizations
try
:
if
hasattr
(
pipeline
,
'enable_attention_slicing'
):
...
...
@@ -417,58 +484,116 @@ def _generate_with_diffusers(pipeline, request, global_args, http_request=None):
pipeline
.
enable_vae_slicing
()
except
Exception
as
e
:
print
(
f
"Warning: Could not enable memory optimizations: {e}"
)
# Get timestamp BEFORE calling diffusers
timestamp
=
int
(
time_module
.
time
())
# Generate images
seed
=
request
.
seed
if
request
.
seed
is
not
None
else
getattr
(
global_args
,
'image_seed'
,
None
)
generator
=
None
if
seed
is
not
None
:
generator
=
torch
.
Generator
(
device
=
pipeline
.
device
)
.
manual_seed
(
seed
)
# Quality: "standard" or "hd"
quality
=
request
.
quality
or
"standard"
# Use request parameters if provided, otherwise fall back to quality-based defaults
num_steps
=
request
.
steps
if
request
.
steps
else
(
30
if
quality
==
"standard"
else
50
)
cfg_scale
=
request
.
guidance_scale
if
request
.
guidance_scale
else
(
getattr
(
global_args
,
'image_cfg_scale'
,
7.5
)
if
quality
==
"standard"
else
9.0
)
# Generate
result
=
pipeline
(
prompt
=
request
.
prompt
,
negative_prompt
=
None
,
num_images_per_prompt
=
request
.
n
,
height
=
height
,
width
=
width
,
generator
=
generator
,
guidance_scale
=
cfg_scale
,
num_inference_steps
=
num_steps
,
)
# ------------------------------------------------------------------
# Prompt embedding cache
# Try to encode the prompt once and reuse the embeddings.
# Falls back to passing the plain text prompt if encoding fails.
# ------------------------------------------------------------------
model_id
=
getattr
(
pipeline
,
'model_name_or_path'
,
None
)
or
str
(
type
(
pipeline
)
.
__name__
)
neg_prompt
=
getattr
(
request
,
'negative_prompt'
,
None
)
or
""
do_cfg
=
cfg_scale
>
1.0
cached_embeds
=
_embed_cache
.
get
(
request
.
prompt
,
neg_prompt
,
model_id
)
embed_kwargs
=
{}
cache_hit
=
False
if
cached_embeds
is
not
None
:
embed_kwargs
=
cached_embeds
cache_hit
=
True
print
(
f
"Prompt embed cache HIT for model '{model_id}'"
)
else
:
# Try to encode and cache
try
:
if
hasattr
(
pipeline
,
'encode_prompt'
):
enc
=
pipeline
.
encode_prompt
(
prompt
=
request
.
prompt
,
device
=
pipeline
.
device
,
num_images_per_prompt
=
1
,
do_classifier_free_guidance
=
do_cfg
,
negative_prompt
=
neg_prompt
or
None
,
)
# enc is a tuple; length varies by pipeline type
if
len
(
enc
)
==
2
:
# SD 1.x: (prompt_embeds, negative_prompt_embeds)
embed_kwargs
=
{
'prompt_embeds'
:
enc
[
0
],
'negative_prompt_embeds'
:
enc
[
1
],
}
elif
len
(
enc
)
==
4
:
# SDXL: (prompt_embeds, negative_prompt_embeds,
# pooled_prompt_embeds, negative_pooled_prompt_embeds)
embed_kwargs
=
{
'prompt_embeds'
:
enc
[
0
],
'negative_prompt_embeds'
:
enc
[
1
],
'pooled_prompt_embeds'
:
enc
[
2
],
'negative_pooled_prompt_embeds'
:
enc
[
3
],
}
if
embed_kwargs
:
_embed_cache
.
put
(
request
.
prompt
,
neg_prompt
,
model_id
,
embed_kwargs
)
print
(
f
"Prompt embed cache STORE for model '{model_id}'"
)
except
Exception
as
e
:
print
(
f
"Warning: prompt encode/cache failed ({e}), using plain text prompt"
)
embed_kwargs
=
{}
# Build call kwargs
if
embed_kwargs
:
call_kwargs
=
dict
(
num_images_per_prompt
=
request
.
n
,
height
=
height
,
width
=
width
,
generator
=
generator
,
guidance_scale
=
cfg_scale
,
num_inference_steps
=
num_steps
,
**
embed_kwargs
,
)
else
:
call_kwargs
=
dict
(
prompt
=
request
.
prompt
,
negative_prompt
=
neg_prompt
or
None
,
num_images_per_prompt
=
request
.
n
,
height
=
height
,
width
=
width
,
generator
=
generator
,
guidance_scale
=
cfg_scale
,
num_inference_steps
=
num_steps
,
)
result
=
pipeline
(
**
call_kwargs
)
# Extract images
images
=
[]
try
:
result_images
=
result
.
images
except
Exception
as
img_err
:
print
(
f
"Warning: Could not access result.images: {img_err}"
)
result_images
=
getattr
(
result
,
'image'
,
None
)
or
getattr
(
result
,
'output'
,
None
)
if
result_images
is
None
:
raise
Exception
(
f
"Could not extract images from diffusers result: {img_err}"
)
for
img
in
result_images
:
if
isinstance
(
img
,
np
.
ndarray
):
img
=
np
.
nan_to_num
(
img
,
nan
=
0.0
,
posinf
=
1.0
,
neginf
=
0.0
)
img
=
np
.
clip
(
img
,
0.0
,
1.0
)
img_data
=
save_image_response
(
img
,
request
.
response_format
,
http_request
)
images
.
append
(
img_data
)
return
{
"created"
:
timestamp
,
"data"
:
images
"data"
:
images
,
"prompt_cache_hit"
:
cache_hit
,
}
...
...
codai/api/prompt_cache.py
0 → 100644
View file @
07fb3251
# CoderAI - OpenAI-compatible API server
# Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Prompt prefix cache manager.
Provides two features:
1. Prefix key computation for same-prefix request scheduling (prompt aggregation).
2. Per-model last-prompt tracking so callers can report accurate cached_tokens.
llama.cpp's KV cache naturally reuses computation when consecutive requests
share a prompt prefix. This manager helps exploit that by:
- Giving the scheduler a stable key to group requests by shared prefix.
- Letting text.py read back how many tokens were cached from timings.
"""
import
hashlib
import
json
import
time
from
dataclasses
import
dataclass
,
field
from
threading
import
Lock
from
typing
import
Dict
,
List
,
Optional
@
dataclass
class
_CacheEntry
:
messages_hash
:
str
prefix_hash
:
str
token_count
:
int
timestamp
:
float
=
field
(
default_factory
=
time
.
time
)
class
PromptCacheManager
:
"""
Tracks recently-processed prompt prefixes per model instance.
Usage
-----
# Before dispatching to the model:
prefix_key = manager.get_prefix_key(messages) # for QueueManager scheduling
# After the model call completes:
manager.store(messages, model_key, prompt_tokens)
# In the API response usage block:
cached = manager.get_cached_tokens(model_key) # from last store
"""
def
__init__
(
self
,
max_entries
:
int
=
256
,
ttl_seconds
:
float
=
600.0
):
self
.
_entries
:
Dict
[
str
,
_CacheEntry
]
=
{}
self
.
_by_model
:
Dict
[
str
,
str
]
=
{}
# model_key -> last messages_hash
self
.
_cached_tokens
:
Dict
[
str
,
int
]
=
{}
# model_key -> cached tokens from last call
self
.
_max_entries
=
max_entries
self
.
_ttl
=
ttl_seconds
self
.
_lock
=
Lock
()
# ------------------------------------------------------------------
# Hashing helpers
# ------------------------------------------------------------------
def
_hash_messages
(
self
,
messages
:
List
[
Dict
])
->
str
:
"""Stable SHA-256 hash (truncated) of a message list."""
canonical
=
json
.
dumps
(
[{
"role"
:
m
.
get
(
"role"
),
"content"
:
m
.
get
(
"content"
)}
for
m
in
messages
],
separators
=
(
","
,
":"
),
ensure_ascii
=
False
,
)
return
hashlib
.
sha256
(
canonical
.
encode
())
.
hexdigest
()[:
20
]
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def
get_prefix_key
(
self
,
messages
:
List
[
Dict
])
->
str
:
"""
Stable key for the *cacheable* portion of a request.
The cacheable prefix is everything except the final user turn, since
system prompts and prior assistant turns stay constant across related
requests and benefit most from KV cache reuse.
Returns an empty string when there is no cacheable prefix.
"""
if
not
messages
:
return
""
prefix
=
messages
[:
-
1
]
if
messages
[
-
1
]
.
get
(
"role"
)
==
"user"
else
messages
return
self
.
_hash_messages
(
prefix
)
if
prefix
else
""
def
store
(
self
,
messages
:
List
[
Dict
],
model_key
:
str
,
prompt_tokens
:
int
,
cached_tokens
:
int
=
0
)
->
None
:
"""Record a completed prompt so future requests can match against it."""
with
self
.
_lock
:
msg_hash
=
self
.
_hash_messages
(
messages
)
prefix_hash
=
self
.
get_prefix_key
(
messages
)
self
.
_entries
[
msg_hash
]
=
_CacheEntry
(
messages_hash
=
msg_hash
,
prefix_hash
=
prefix_hash
,
token_count
=
prompt_tokens
,
)
self
.
_by_model
[
model_key
]
=
msg_hash
self
.
_cached_tokens
[
model_key
]
=
cached_tokens
self
.
_evict_locked
()
def
get_cached_tokens
(
self
,
model_key
:
str
)
->
int
:
"""Return the cached_tokens count stored by the last store() call for this model."""
with
self
.
_lock
:
return
self
.
_cached_tokens
.
get
(
model_key
,
0
)
def
has_warm_prefix
(
self
,
messages
:
List
[
Dict
],
model_key
:
str
)
->
bool
:
"""
Return True if the current request shares a prefix with the last
request processed by this model (i.e., the KV cache is likely warm).
"""
with
self
.
_lock
:
last_hash
=
self
.
_by_model
.
get
(
model_key
)
if
not
last_hash
:
return
False
entry
=
self
.
_entries
.
get
(
last_hash
)
if
not
entry
or
time
.
time
()
-
entry
.
timestamp
>
self
.
_ttl
:
return
False
current_prefix
=
self
.
get_prefix_key
(
messages
)
return
bool
(
current_prefix
and
current_prefix
==
entry
.
prefix_hash
)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def
_evict_locked
(
self
)
->
None
:
now
=
time
.
time
()
expired
=
[
k
for
k
,
v
in
self
.
_entries
.
items
()
if
now
-
v
.
timestamp
>
self
.
_ttl
]
for
k
in
expired
:
del
self
.
_entries
[
k
]
while
len
(
self
.
_entries
)
>
self
.
_max_entries
:
oldest
=
min
(
self
.
_entries
,
key
=
lambda
k
:
self
.
_entries
[
k
]
.
timestamp
)
del
self
.
_entries
[
oldest
]
prompt_cache_manager
=
PromptCacheManager
()
codai/api/text.py
View file @
07fb3251
...
...
@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
# Import from codai modules
from
codai.models.manager
import
ModelManager
,
WhisperServerManager
,
MultiModelManager
,
model_manager
,
multi_model_manager
from
codai.queue.manager
import
QueueManager
,
queue_manager
from
codai.api.prompt_cache
import
prompt_cache_manager
from
codai.pydantic.textrequest
import
ChatCompletionRequest
,
ToolFunction
,
Tool
from
codai.models.parser
import
filter_malformed_content
,
filter_repetition
,
format_tools_for_prompt
,
cleanup_control_tokens
,
OpenAIFormatter
,
ModelParserAdapter
,
ToolCallParser
...
...
@@ -1142,6 +1143,9 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
from
fastapi.responses
import
JSONResponse
return
JSONResponse
(
content
=
formatted_response
,
headers
=
headers
)
# Compute prefix key for prompt-aggregation scheduling
_prefix_key
=
prompt_cache_manager
.
get_prefix_key
(
messages_dict
)
if
request
.
stream
:
async
def
_managed_stream
():
try
:
...
...
@@ -1156,6 +1160,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
current_manager
,
tool_parser
,
request
.
response_format
,
_prefix_key
,
):
yield
chunk
finally
:
...
...
@@ -1192,6 +1197,7 @@ async def stream_chat_response(
current_manager
:
ModelManager
,
tool_parser
:
ToolCallParser
,
response_format
:
Optional
[
Dict
]
=
None
,
prefix_key
:
str
=
""
,
)
->
AsyncGenerator
[
str
,
None
]:
"""Stream chat completion response with queue notifications."""
completion_id
=
f
"chatcmpl-{uuid.uuid4().hex}"
...
...
@@ -1214,7 +1220,7 @@ async def stream_chat_response(
# If model not loaded, add to queue and send waiting notifications
if
not
model_loaded
:
await
queue_manager
.
add_waiting
(
request_id
)
await
queue_manager
.
add_waiting
(
request_id
,
prefix_key
=
prefix_key
)
wait_interval
=
2.0
# Send waiting update every 2 seconds
last_wait_update
=
time
.
time
()
...
...
@@ -1457,10 +1463,24 @@ async def stream_chat_response(
prompt_text
=
"
\n
"
.
join
([
m
.
get
(
"content"
,
""
)
for
m
in
messages
])
prompt_tokens
=
len
(
prompt_text
.
split
())
completion_tokens
=
len
(
generated_text
.
split
())
if
generated_text
else
0
# Read accurate usage (including cached_tokens) from the backend
_model_key_for_cache
=
getattr
(
current_manager
,
'model_name'
,
None
)
or
model_name
last_usage
=
(
current_manager
.
get_last_usage
()
if
hasattr
(
current_manager
,
'get_last_usage'
)
else
{})
if
last_usage
.
get
(
'prompt_tokens'
):
prompt_tokens
=
last_usage
[
'prompt_tokens'
]
if
last_usage
.
get
(
'completion_tokens'
):
completion_tokens
=
last_usage
[
'completion_tokens'
]
cached_tokens
=
last_usage
.
get
(
'cached_tokens'
,
0
)
# Store in prompt cache manager for future prefix matching
prompt_cache_manager
.
store
(
messages
,
_model_key_for_cache
,
prompt_tokens
,
cached_tokens
)
# Get context size
context_size
=
current_manager
.
get_context_size
()
# Build complete final chunk with all OpenAI fields
final_chunk
=
{
"id"
:
completion_id
,
...
...
@@ -1479,7 +1499,7 @@ async def stream_chat_response(
"total_tokens"
:
prompt_tokens
+
completion_tokens
,
"context_size"
:
context_size
,
"prompt_tokens_details"
:
{
"cached_tokens"
:
0
,
"cached_tokens"
:
cached_tokens
,
"audio_tokens"
:
0
,
},
"completion_tokens_details"
:
{
...
...
@@ -1494,7 +1514,7 @@ async def stream_chat_response(
"system_fingerprint"
:
None
,
}
yield
f
"data: {json.dumps(final_chunk)}
\n\n
"
yield
"data: [DONE]
\n\n
"
except
Exception
as
e
:
print
(
f
"Error during streaming generation: {e}"
)
...
...
@@ -1638,11 +1658,20 @@ async def generate_chat_response(
response_message
[
"tool_calls"
]
=
tool_calls
finish_reason
=
"tool_calls"
# Calculate token counts - rough estimate since we don't have direct access to tokenizer
# Read accurate usage (including cached_tokens) from the backend
_model_key_for_cache
=
getattr
(
current_manager
,
'model_name'
,
None
)
or
model_name
last_usage
=
(
current_manager
.
get_last_usage
()
if
hasattr
(
current_manager
,
'get_last_usage'
)
else
{})
prompt_text
=
"
\n
"
.
join
([
m
.
get
(
"content"
,
""
)
for
m
in
messages
])
prompt_tokens
=
len
(
prompt_text
.
split
())
completion_tokens
=
len
(
generated_text
.
split
())
if
generated_text
else
0
prompt_tokens
=
last_usage
.
get
(
'prompt_tokens'
)
or
len
(
prompt_text
.
split
())
completion_tokens
=
last_usage
.
get
(
'completion_tokens'
)
or
(
len
(
generated_text
.
split
())
if
generated_text
else
0
)
cached_tokens
=
last_usage
.
get
(
'cached_tokens'
,
0
)
# Store in prompt cache manager for future prefix matching
prompt_cache_manager
.
store
(
messages
,
_model_key_for_cache
,
prompt_tokens
,
cached_tokens
)
# Get context size
context_size
=
current_manager
.
get_context_size
()
...
...
@@ -1655,6 +1684,10 @@ async def generate_chat_response(
tool_calls
=
response_message
.
get
(
"tool_calls"
),
context_size
=
context_size
)
# Patch in the real cached_tokens value
if
formatted_response
and
'usage'
in
formatted_response
:
details
=
formatted_response
[
'usage'
]
.
setdefault
(
'prompt_tokens_details'
,
{})
details
[
'cached_tokens'
]
=
cached_tokens
# Add mock reasoning stats if 'mock' is in force_reasoning_args
# But only if we don't already have real reasoning in the response
...
...
codai/backends/cuda.py
View file @
07fb3251
...
...
@@ -17,6 +17,7 @@
"""CUDA backend using HuggingFace Transformers."""
import
os
import
time
as
_time
from
typing
import
Optional
,
List
,
Dict
from
threading
import
Thread
from
abc
import
ABC
...
...
@@ -53,6 +54,13 @@ class NvidiaBackend(ModelBackend):
self
.
device
=
None
self
.
use_flash_attn
=
False
self
.
flash_attn_available
=
False
# KV prefix cache (single-entry, keyed by formatted prefix text)
self
.
_kv_prefix_text
:
Optional
[
str
]
=
None
self
.
_kv_past_key_values
=
None
# past_key_values tensor tuple
self
.
_kv_prefix_len
:
int
=
0
# token count of the cached prefix
self
.
_kv_timestamp
:
float
=
0.0
self
.
_kv_ttl
:
float
=
300.0
# 5 min TTL
self
.
_last_usage
:
Dict
=
{}
def
check_flash_attn_support
(
self
)
->
None
:
"""Check and print Flash Attention availability status."""
...
...
@@ -872,11 +880,288 @@ class NvidiaBackend(ModelBackend):
elif
generation_error
:
yield
f
"
\n
[Error during generation: {generation_error}]"
# ------------------------------------------------------------------
# KV prefix cache helpers
# ------------------------------------------------------------------
def
_kv_cache_valid
(
self
)
->
bool
:
return
(
self
.
_kv_past_key_values
is
not
None
and
_time
.
time
()
-
self
.
_kv_timestamp
<
self
.
_kv_ttl
)
def
_build_kv_prefix
(
self
,
prefix_text
:
str
):
"""Forward-pass on prefix_text to populate the KV state."""
import
torch
inputs
=
self
.
tokenizer
(
prefix_text
,
return_tensors
=
"pt"
,
add_special_tokens
=
False
)
inputs
=
{
k
:
v
.
to
(
self
.
model
.
device
)
for
k
,
v
in
inputs
.
items
()}
with
torch
.
no_grad
():
out
=
self
.
model
(
**
inputs
,
use_cache
=
True
,
return_dict
=
True
)
return
out
.
past_key_values
,
int
(
inputs
[
'input_ids'
]
.
shape
[
1
])
def
_store_kv
(
self
,
prefix_text
:
str
,
past_kv
,
prefix_len
:
int
)
->
None
:
self
.
_kv_prefix_text
=
prefix_text
self
.
_kv_past_key_values
=
past_kv
self
.
_kv_prefix_len
=
prefix_len
self
.
_kv_timestamp
=
_time
.
time
()
def
invalidate_kv_cache
(
self
)
->
None
:
"""Discard the cached KV state (call on model unload/swap)."""
self
.
_kv_prefix_text
=
None
self
.
_kv_past_key_values
=
None
self
.
_kv_prefix_len
=
0
self
.
_kv_timestamp
=
0.0
# ------------------------------------------------------------------
# Usage tracking
# ------------------------------------------------------------------
def
get_last_usage
(
self
)
->
dict
:
return
dict
(
self
.
_last_usage
)
# ------------------------------------------------------------------
# Chat-level generation (with KV prefix caching)
# ------------------------------------------------------------------
def
_format_messages_to_str
(
self
,
messages
)
->
str
:
"""Convert a list of message dicts to a formatted prompt string."""
from
codai.pydantic.textrequest
import
ChatMessage
chat_msgs
=
[
ChatMessage
(
**
m
)
if
isinstance
(
m
,
dict
)
else
m
for
m
in
messages
]
return
self
.
format_messages
(
chat_msgs
)
def
generate_chat
(
self
,
messages
,
max_tokens
=
None
,
temperature
=
0.7
,
top_p
=
1.0
,
stop
=
None
,
tools
=
None
,
response_format
=
None
)
->
str
:
"""
Non-streaming chat generation with KV prefix caching.
Detects when the current request shares a system-prompt / history
prefix with the previous request and reuses the cached KV state,
only encoding the new suffix tokens.
"""
import
torch
if
max_tokens
is
None
:
max_tokens
=
512
full_prompt
=
self
.
_format_messages_to_str
(
messages
)
total_input_ids
=
self
.
tokenizer
(
full_prompt
,
return_tensors
=
"pt"
)[
'input_ids'
]
total_prompt_len
=
int
(
total_input_ids
.
shape
[
1
])
# Build prefix text (all turns except the final user turn)
prefix_msgs
=
(
messages
[:
-
1
]
if
messages
and
messages
[
-
1
]
.
get
(
'role'
)
==
'user'
else
[]
)
past_kv
=
None
cached_len
=
0
if
prefix_msgs
:
prefix_text
=
self
.
_format_messages_to_str
(
prefix_msgs
)
if
self
.
_kv_cache_valid
()
and
self
.
_kv_prefix_text
==
prefix_text
:
past_kv
=
self
.
_kv_past_key_values
cached_len
=
self
.
_kv_prefix_len
else
:
try
:
past_kv
,
cached_len
=
self
.
_build_kv_prefix
(
prefix_text
)
self
.
_store_kv
(
prefix_text
,
past_kv
,
cached_len
)
except
Exception
as
e
:
print
(
f
"Warning: KV prefix cache build failed: {e}"
)
past_kv
,
cached_len
=
None
,
0
temperature
,
top_p
,
do_sample
=
self
.
_validate_params
(
temperature
,
top_p
)
gen_kwargs
=
dict
(
max_new_tokens
=
max_tokens
,
temperature
=
temperature
if
do_sample
else
None
,
top_p
=
top_p
if
do_sample
else
None
,
do_sample
=
do_sample
,
pad_token_id
=
self
.
tokenizer
.
pad_token_id
,
eos_token_id
=
self
.
tokenizer
.
eos_token_id
,
use_cache
=
True
,
)
generated_text
=
""
try
:
total_input_ids
=
total_input_ids
.
to
(
self
.
model
.
device
)
if
past_kv
is
not
None
and
0
<
cached_len
<
total_prompt_len
:
suffix_ids
=
total_input_ids
[:,
cached_len
:]
full_attn
=
torch
.
ones
(
1
,
total_prompt_len
,
dtype
=
torch
.
long
,
device
=
self
.
model
.
device
)
with
torch
.
no_grad
():
outputs
=
self
.
model
.
generate
(
input_ids
=
suffix_ids
,
past_key_values
=
past_kv
,
attention_mask
=
full_attn
,
**
gen_kwargs
,
)
new_tokens
=
outputs
[
0
][
suffix_ids
.
shape
[
1
]:]
else
:
cached_len
=
0
attn_mask
=
torch
.
ones_like
(
total_input_ids
)
with
torch
.
no_grad
():
outputs
=
self
.
model
.
generate
(
input_ids
=
total_input_ids
,
attention_mask
=
attn_mask
,
**
gen_kwargs
,
)
new_tokens
=
outputs
[
0
][
total_prompt_len
:]
generated_text
=
self
.
tokenizer
.
decode
(
new_tokens
,
skip_special_tokens
=
True
)
except
Exception
as
e
:
print
(
f
"Warning: KV-cached generate_chat failed ({e}), retrying without cache"
)
cached_len
=
0
try
:
total_input_ids
=
self
.
tokenizer
(
full_prompt
,
return_tensors
=
"pt"
)[
'input_ids'
]
.
to
(
self
.
model
.
device
)
attn_mask
=
torch
.
ones_like
(
total_input_ids
)
with
torch
.
no_grad
():
outputs
=
self
.
model
.
generate
(
input_ids
=
total_input_ids
,
attention_mask
=
attn_mask
,
**
gen_kwargs
,
)
new_tokens
=
outputs
[
0
][
total_prompt_len
:]
generated_text
=
self
.
tokenizer
.
decode
(
new_tokens
,
skip_special_tokens
=
True
)
except
Exception
as
e2
:
print
(
f
"Error: generate_chat fallback failed: {e2}"
)
generated_text
=
""
try
:
comp_len
=
len
(
self
.
tokenizer
.
encode
(
generated_text
))
if
generated_text
else
0
except
Exception
:
comp_len
=
len
(
generated_text
.
split
())
self
.
_last_usage
=
{
'prompt_tokens'
:
total_prompt_len
,
'completion_tokens'
:
comp_len
,
'cached_tokens'
:
cached_len
,
}
return
generated_text
async
def
generate_chat_stream
(
self
,
messages
,
max_tokens
=
None
,
temperature
=
0.7
,
top_p
=
1.0
,
stop
=
None
,
tools
=
None
,
response_format
=
None
):
"""
Streaming chat generation with KV prefix caching.
Uses the same prefix-cache strategy as generate_chat.
"""
import
torch
from
transformers
import
TextIteratorStreamer
,
StoppingCriteria
,
StoppingCriteriaList
from
threading
import
Thread
if
max_tokens
is
None
:
max_tokens
=
512
full_prompt
=
self
.
_format_messages_to_str
(
messages
)
total_input_ids
=
self
.
tokenizer
(
full_prompt
,
return_tensors
=
"pt"
)[
'input_ids'
]
total_prompt_len
=
int
(
total_input_ids
.
shape
[
1
])
prefix_msgs
=
(
messages
[:
-
1
]
if
messages
and
messages
[
-
1
]
.
get
(
'role'
)
==
'user'
else
[]
)
past_kv
=
None
cached_len
=
0
if
prefix_msgs
:
prefix_text
=
self
.
_format_messages_to_str
(
prefix_msgs
)
if
self
.
_kv_cache_valid
()
and
self
.
_kv_prefix_text
==
prefix_text
:
past_kv
=
self
.
_kv_past_key_values
cached_len
=
self
.
_kv_prefix_len
else
:
try
:
past_kv
,
cached_len
=
self
.
_build_kv_prefix
(
prefix_text
)
self
.
_store_kv
(
prefix_text
,
past_kv
,
cached_len
)
except
Exception
as
e
:
print
(
f
"Warning: KV prefix cache build failed (stream): {e}"
)
past_kv
,
cached_len
=
None
,
0
temperature
,
top_p
,
do_sample
=
self
.
_validate_params
(
temperature
,
top_p
)
total_input_ids
=
total_input_ids
.
to
(
self
.
model
.
device
)
if
past_kv
is
not
None
and
0
<
cached_len
<
total_prompt_len
:
gen_input_ids
=
total_input_ids
[:,
cached_len
:]
full_attn
=
torch
.
ones
(
1
,
total_prompt_len
,
dtype
=
torch
.
long
,
device
=
self
.
model
.
device
)
extra_gen
=
{
'past_key_values'
:
past_kv
,
'attention_mask'
:
full_attn
}
else
:
cached_len
=
0
gen_input_ids
=
total_input_ids
extra_gen
=
{
'attention_mask'
:
torch
.
ones_like
(
total_input_ids
)}
streamer
=
TextIteratorStreamer
(
self
.
tokenizer
,
skip_prompt
=
True
,
skip_special_tokens
=
True
)
gen_kwargs
=
dict
(
input_ids
=
gen_input_ids
,
max_new_tokens
=
max_tokens
,
temperature
=
temperature
if
do_sample
else
None
,
top_p
=
top_p
if
do_sample
else
None
,
do_sample
=
do_sample
,
streamer
=
streamer
,
pad_token_id
=
self
.
tokenizer
.
pad_token_id
,
eos_token_id
=
self
.
tokenizer
.
eos_token_id
,
use_cache
=
True
,
**
extra_gen
,
)
if
stop
:
class
_StopOnSeq
(
StoppingCriteria
):
def
__init__
(
self
,
seqs
,
tok
):
self
.
seqs
=
seqs
self
.
tok
=
tok
def
__call__
(
self
,
input_ids
,
scores
,
**
kw
):
decoded
=
self
.
tok
.
decode
(
input_ids
[
0
][
-
20
:],
skip_special_tokens
=
True
)
return
any
(
s
in
decoded
for
s
in
self
.
seqs
)
gen_kwargs
[
'stopping_criteria'
]
=
StoppingCriteriaList
(
[
_StopOnSeq
(
stop
,
self
.
tokenizer
)]
)
gen_error
=
[
None
]
comp_tokens
=
[
0
]
def
_run
():
try
:
with
torch
.
no_grad
():
self
.
model
.
generate
(
**
gen_kwargs
)
except
Exception
as
e
:
gen_error
[
0
]
=
str
(
e
)
thread
=
Thread
(
target
=
_run
)
thread
.
start
()
try
:
for
text
in
streamer
:
comp_tokens
[
0
]
+=
1
yield
text
except
Exception
as
e
:
print
(
f
"Error during KV-cached stream iteration: {e}"
)
finally
:
thread
.
join
()
self
.
_last_usage
=
{
'prompt_tokens'
:
total_prompt_len
,
'completion_tokens'
:
comp_tokens
[
0
],
'cached_tokens'
:
cached_len
,
}
if
gen_error
[
0
]:
print
(
f
"Warning: KV-cached stream generation error: {gen_error[0]}"
)
def
get_model_name
(
self
)
->
str
:
return
self
.
model_name
or
"unknown"
def
cleanup
(
self
)
->
None
:
import
torch
self
.
invalidate_kv_cache
()
if
self
.
model
is
not
None
:
del
self
.
model
del
self
.
tokenizer
...
...
codai/backends/vulkan.py
View file @
07fb3251
...
...
@@ -63,6 +63,7 @@ class VulkanBackend(ModelBackend):
self
.
force_cuda
=
original_backend
in
(
"nvidia"
,
"cuda"
)
# Force CUDA if original was nvidia
if
self
.
force_cuda
:
print
(
"DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)"
)
self
.
_last_usage
:
dict
=
{}
# usage from the most recent completion call
self
.
_detect_chat_template
()
def
_detect_chat_template
(
self
):
...
...
@@ -649,6 +650,8 @@ class VulkanBackend(ModelBackend):
stop
=
stop
,
grammar
=
use_grammar
,
)
usage
=
result
.
get
(
'usage'
,
{})
self
.
_store_usage
(
usage
.
get
(
'prompt_tokens'
,
0
),
usage
.
get
(
'completion_tokens'
,
0
))
return
result
[
'choices'
][
0
][
'text'
]
except
Exception
as
e
:
# If grammar generation fails, fall back to normal generation
...
...
@@ -664,6 +667,8 @@ class VulkanBackend(ModelBackend):
repeat_penalty
=
repeat_penalty
,
stop
=
stop
,
)
usage
=
result
.
get
(
'usage'
,
{})
self
.
_store_usage
(
usage
.
get
(
'prompt_tokens'
,
0
),
usage
.
get
(
'completion_tokens'
,
0
))
return
result
[
'choices'
][
0
][
'text'
]
except
Exception
as
e2
:
print
(
f
"Error during fallback generation: {e2}"
)
...
...
@@ -935,6 +940,112 @@ class VulkanBackend(ModelBackend):
"n_gpu_layers"
:
self
.
n_gpu_layers
,
}
# ------------------------------------------------------------------
# Usage / cache helpers
# ------------------------------------------------------------------
def
_read_cached_tokens
(
self
,
prompt_tokens
:
int
)
->
int
:
"""Extract cached token count from llama.cpp timings after a completion."""
try
:
timings
=
getattr
(
self
.
model
,
'timings'
,
None
)
if
timings
is
None
:
# Try the internal context if timings property not exposed
ctx
=
getattr
(
self
.
model
,
'_ctx'
,
None
)
if
ctx
and
hasattr
(
ctx
,
'timings'
):
timings
=
ctx
.
timings
()
if
timings
is
not
None
:
n_p_eval
=
getattr
(
timings
,
'n_p_eval'
,
None
)
if
n_p_eval
is
not
None
:
return
max
(
0
,
prompt_tokens
-
int
(
n_p_eval
))
except
Exception
:
pass
return
0
def
_store_usage
(
self
,
prompt_tokens
:
int
,
completion_tokens
:
int
)
->
None
:
cached
=
self
.
_read_cached_tokens
(
prompt_tokens
)
self
.
_last_usage
=
{
'prompt_tokens'
:
prompt_tokens
,
'completion_tokens'
:
completion_tokens
,
'total_tokens'
:
prompt_tokens
+
completion_tokens
,
'cached_tokens'
:
cached
,
}
def
get_last_usage
(
self
)
->
dict
:
"""Return usage dict from the most recent completion (includes cached_tokens)."""
return
dict
(
self
.
_last_usage
)
# ------------------------------------------------------------------
# Chat-level generation (uses llama.cpp native chat template)
# ------------------------------------------------------------------
def
generate_chat
(
self
,
messages
,
max_tokens
=
None
,
temperature
=
0.7
,
top_p
=
1.0
,
stop
=
None
,
tools
=
None
,
response_format
=
None
):
"""Non-streaming chat completion using llama.cpp's native chat handler."""
if
self
.
model
is
None
:
raise
RuntimeError
(
"Model not loaded"
)
kwargs
=
dict
(
messages
=
messages
,
max_tokens
=
max_tokens
or
512
,
temperature
=
temperature
,
top_p
=
top_p
,
)
if
stop
:
kwargs
[
'stop'
]
=
stop
if
response_format
and
response_format
.
get
(
'type'
)
==
'json_object'
:
kwargs
[
'response_format'
]
=
{
'type'
:
'json_object'
}
result
=
self
.
model
.
create_chat_completion
(
**
kwargs
)
usage
=
result
.
get
(
'usage'
,
{})
self
.
_store_usage
(
prompt_tokens
=
usage
.
get
(
'prompt_tokens'
,
0
),
completion_tokens
=
usage
.
get
(
'completion_tokens'
,
0
),
)
content
=
result
[
'choices'
][
0
][
'message'
]
.
get
(
'content'
)
or
''
return
content
async
def
generate_chat_stream
(
self
,
messages
,
max_tokens
=
None
,
temperature
=
0.7
,
top_p
=
1.0
,
stop
=
None
,
tools
=
None
,
response_format
=
None
):
"""Streaming chat completion using llama.cpp's native chat handler."""
if
self
.
model
is
None
:
raise
RuntimeError
(
"Model not loaded"
)
kwargs
=
dict
(
messages
=
messages
,
max_tokens
=
max_tokens
or
512
,
temperature
=
temperature
,
top_p
=
top_p
,
stream
=
True
,
)
if
stop
:
kwargs
[
'stop'
]
=
stop
prompt_tokens
=
0
completion_tokens
=
0
try
:
for
chunk
in
self
.
model
.
create_chat_completion
(
**
kwargs
):
delta
=
chunk
[
'choices'
][
0
]
.
get
(
'delta'
,
{})
text
=
delta
.
get
(
'content'
)
or
''
if
text
:
completion_tokens
+=
1
yield
text
# Capture usage if present in final streaming chunk
if
chunk
.
get
(
'usage'
):
u
=
chunk
[
'usage'
]
prompt_tokens
=
u
.
get
(
'prompt_tokens'
,
0
)
completion_tokens
=
u
.
get
(
'completion_tokens'
,
completion_tokens
)
if
chunk
[
'choices'
][
0
]
.
get
(
'finish_reason'
):
break
finally
:
# Timings are available after the stream is exhausted
if
prompt_tokens
==
0
:
# Estimate from word split if llama.cpp didn't report
prompt_tokens
=
sum
(
len
(
str
(
m
.
get
(
'content'
,
''
))
.
split
())
for
m
in
messages
)
self
.
_store_usage
(
prompt_tokens
,
completion_tokens
)
def
get_model_name
(
self
)
->
str
:
"""Return the loaded model name."""
return
self
.
model_name
or
"unknown"
...
...
codai/models/manager.py
View file @
07fb3251
...
...
@@ -233,6 +233,12 @@ class ModelManager:
if
self
.
backend
is
not
None
:
return
self
.
backend
.
get_context_size
()
return
2048
# Default fallback
def
get_last_usage
(
self
)
->
dict
:
"""Return usage info (including cached_tokens) from the most recent call."""
if
self
.
backend
is
not
None
and
hasattr
(
self
.
backend
,
'get_last_usage'
):
return
self
.
backend
.
get_last_usage
()
return
{}
def
cleanup
(
self
):
if
self
.
backend
is
not
None
:
...
...
@@ -2040,11 +2046,29 @@ class MultiModelManager:
"embedding_models"
:
"embedding"
,
}
# Minimum capability guaranteed by a model's config category.
# Applied when heuristic name detection doesn't recognise the model ID.
TYPE_MIN_CAP
=
{
"image"
:
"image_generation"
,
"video"
:
"video_generation"
,
"audio"
:
"speech_to_text"
,
"tts"
:
"text_to_speech"
,
"audio_gen"
:
"audio_generation"
,
"embedding"
:
"embeddings"
,
}
def
_add
(
model_id
:
str
,
model_type
:
str
=
None
,
meta
:
Dict
[
str
,
Any
]
=
None
):
if
model_id
in
seen_ids
:
return
seen_ids
.
add
(
model_id
)
caps
=
detect_model_capabilities
(
model_id
)
# If heuristic detection missed the type (e.g. custom/vendor model IDs
# that don't match any keyword), ensure the minimum capability for the
# config-declared type is set so badges display correctly.
if
model_type
and
model_type
in
TYPE_MIN_CAP
:
min_cap
=
TYPE_MIN_CAP
[
model_type
]
if
not
getattr
(
caps
,
min_cap
,
False
):
setattr
(
caps
,
min_cap
,
True
)
resolved_type
=
model_type
or
(
caps
.
to_list
()[
0
]
.
split
(
"_"
)[
0
]
if
caps
.
to_list
()
else
"text"
)
meta
=
meta
or
{}
models
.
append
(
ModelInfo
(
...
...
@@ -2075,6 +2099,11 @@ class MultiModelManager:
else
:
raw
=
m
.
get
(
"path"
)
or
m
.
get
(
"id"
)
or
""
alias
=
m
.
get
(
"alias"
)
or
""
# Auto-derive a clean alias for GGUF files that have no
# explicit alias so the full filesystem path isn't exposed.
if
not
alias
and
raw
.
lower
()
.
endswith
(
".gguf"
):
stem
=
raw
.
split
(
"/"
)[
-
1
][:
-
5
]
# filename without .gguf
alias
=
stem
# whisper-server aliases are round-robin group keys shared across
# multiple instances — don't expose the alias as a separate model
if
m
.
get
(
"backend"
)
==
"whisper-server"
:
...
...
codai/pydantic/textrequest.py
View file @
07fb3251
...
...
@@ -39,6 +39,7 @@ class ChatMessage(BaseModel):
name
:
Optional
[
str
]
=
None
tool_calls
:
Optional
[
List
[
Dict
]]
=
None
tool_call_id
:
Optional
[
str
]
=
None
cache_control
:
Optional
[
Dict
]
=
None
# OpenAI-style: {"type": "ephemeral"}
@
field_validator
(
'content'
,
mode
=
'before'
)
@
classmethod
...
...
codai/queue/manager.py
View file @
07fb3251
...
...
@@ -40,6 +40,7 @@ class WaitingRequest:
sequence
:
int
event
:
asyncio
.
Event
=
field
(
default_factory
=
asyncio
.
Event
)
bypassed_by
:
int
=
0
prefix_key
:
str
=
""
# stable hash of the cacheable prompt prefix
class
QueueManager
:
...
...
@@ -61,6 +62,7 @@ class QueueManager:
self
.
model_name
:
Optional
[
str
]
=
None
self
.
_processing
:
bool
=
False
self
.
_ready_request_ids
:
Set
[
str
]
=
set
()
self
.
_last_prefix_key
:
str
=
""
# prefix key of the last completed request
def
set_loaded_models
(
self
,
model_keys
:
Set
[
str
])
->
None
:
self
.
loaded_models
=
set
(
model_keys
)
...
...
@@ -83,17 +85,19 @@ class QueueManager:
self
.
model_name
=
None
self
.
_processing
=
False
self
.
_ready_request_ids
.
clear
()
self
.
_last_prefix_key
=
""
async
def
is_full
(
self
)
->
bool
:
async
with
self
.
lock
:
return
len
(
self
.
waiting
)
>=
self
.
max_size
async
def
acquire
(
self
,
request_id
:
str
,
model_key
:
str
)
->
SchedulerLease
:
async
def
acquire
(
self
,
request_id
:
str
,
model_key
:
str
,
prefix_key
:
str
=
""
)
->
SchedulerLease
:
waiter
=
None
async
with
self
.
lock
:
if
self
.
_can_start_now
(
model_key
):
return
self
.
_grant_lease
(
request_id
,
model_key
)
waiter
=
self
.
_enqueue_waiter
(
request_id
,
model_key
)
waiter
=
self
.
_enqueue_waiter
(
request_id
,
model_key
,
prefix_key
)
await
waiter
.
event
.
wait
()
async
with
self
.
lock
:
...
...
@@ -103,7 +107,8 @@ class QueueManager:
lease
.
wait_time_seconds
=
max
(
0.0
,
time
.
time
()
-
waiter
.
enqueued_at
)
return
lease
async
def
release
(
self
,
lease
:
SchedulerLease
)
->
None
:
async
def
release
(
self
,
lease
:
SchedulerLease
,
prefix_key
:
str
=
""
)
->
None
:
async
with
self
.
lock
:
self
.
active_leases
.
pop
(
lease
.
request_id
,
None
)
current
=
self
.
active_by_model
.
get
(
lease
.
model_key
,
0
)
...
...
@@ -113,14 +118,17 @@ class QueueManager:
self
.
active_by_model
[
lease
.
model_key
]
=
current
-
1
if
self
.
current_request_id
==
lease
.
request_id
:
self
.
current_request_id
=
None
if
prefix_key
:
self
.
_last_prefix_key
=
prefix_key
self
.
_processing
=
bool
(
self
.
active_leases
)
self
.
_wake_waiters_locked
()
async
def
add_waiting
(
self
,
request_id
:
str
,
model_key
:
str
=
""
)
->
None
:
async
def
add_waiting
(
self
,
request_id
:
str
,
model_key
:
str
=
""
,
prefix_key
:
str
=
""
)
->
None
:
async
with
self
.
lock
:
if
request_id
in
self
.
waiting_by_id
:
return
self
.
_enqueue_waiter
(
request_id
,
model_key
or
request_id
)
self
.
_enqueue_waiter
(
request_id
,
model_key
or
request_id
,
prefix_key
)
async
def
remove_waiting
(
self
,
request_id
:
str
)
->
None
:
async
with
self
.
lock
:
...
...
@@ -172,13 +180,15 @@ class QueueManager:
"loaded_models"
:
sorted
(
self
.
loaded_models
),
}
def
_enqueue_waiter
(
self
,
request_id
:
str
,
model_key
:
str
)
->
WaitingRequest
:
def
_enqueue_waiter
(
self
,
request_id
:
str
,
model_key
:
str
,
prefix_key
:
str
=
""
)
->
WaitingRequest
:
self
.
sequence
+=
1
waiter
=
WaitingRequest
(
request_id
=
request_id
,
model_key
=
model_key
,
enqueued_at
=
time
.
time
(),
sequence
=
self
.
sequence
,
prefix_key
=
prefix_key
,
)
self
.
waiting
.
append
(
waiter
)
self
.
waiting_by_id
[
request_id
]
=
waiter
...
...
@@ -233,17 +243,39 @@ class QueueManager:
return
def
_pick_next_waiter_locked
(
self
)
->
Optional
[
WaitingRequest
]:
for
waiter
in
self
.
waiting
:
if
self
.
_waiter_can_start_locked
(
waiter
):
older_blocked
=
[
other
for
other
in
self
.
waiting
if
other
.
sequence
<
waiter
.
sequence
and
not
self
.
_waiter_can_start_locked
(
other
)
]
if
any
(
other
.
bypassed_by
>=
self
.
fairness_bypass_limit
for
other
in
older_blocked
):
continue
for
other
in
older_blocked
:
other
.
bypassed_by
+=
1
# Collect all candidates that can start now.
candidates
=
[
w
for
w
in
self
.
waiting
if
self
.
_waiter_can_start_locked
(
w
)]
if
not
candidates
:
return
None
# Fairness: don't bypass an older waiter more than the limit.
def
_is_fair
(
waiter
:
WaitingRequest
)
->
bool
:
older_blocked
=
[
other
for
other
in
self
.
waiting
if
other
.
sequence
<
waiter
.
sequence
and
not
self
.
_waiter_can_start_locked
(
other
)
]
if
any
(
other
.
bypassed_by
>=
self
.
fairness_bypass_limit
for
other
in
older_blocked
):
return
False
for
other
in
older_blocked
:
other
.
bypassed_by
+=
1
return
True
# Prompt aggregation: prefer candidates whose prefix key matches the
# last completed request — they will hit a warm KV cache.
if
self
.
_last_prefix_key
:
warm_candidates
=
[
w
for
w
in
candidates
if
w
.
prefix_key
and
w
.
prefix_key
==
self
.
_last_prefix_key
]
for
waiter
in
warm_candidates
:
if
_is_fair
(
waiter
):
return
waiter
# Fall back to FIFO order.
for
waiter
in
candidates
:
if
_is_fair
(
waiter
):
return
waiter
return
None
def
_waiting_counts_locked
(
self
)
->
Dict
[
str
,
int
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment