Commit 5341ee6a authored by Your Name's avatar Your Name

feat: Pipeline fixes, regex optimization, GBNF grammar support, and prompt distillation

- Fixed streaming mode pipeline issues:
  - Fixed n-gram counting to handle partial matches correctly
  - Added per-chunk filtering to prevent duplicate n-grams across chunks

- Optimized regex patterns (~35 patterns pre-compiled):
  - Pre-compiled all regex patterns for better performance
  - Added false positive protection with length-based filtering
  - Optimized tool call parsing in parser.py

- Added grammar-guided generation (--ggg / --grammar-guided-gen):
  - New GBNF grammar file (tool_call_grammar.gbnf) for tool call parsing
  - Grammar loading utilities in models/grammar.py
  - Vulkan backend: Added GBNF grammar support via llama_generate_grammar
  - CUDA backend: Added outlines support for structured output

- Added prompt distillation (--tools-closer-prompt):
  - New CLI option --tools-closer-prompt for prompt distillation
  - Enables generating distilled tool descriptions for better accuracy
parent d9cba7ec
...@@ -10,6 +10,22 @@ from codai.backends.base import ModelBackend ...@@ -10,6 +10,22 @@ from codai.backends.base import ModelBackend
from codai.models.capabilities import detect_model_capabilities from codai.models.capabilities import detect_model_capabilities
from codai.pydantic.textrequest import ChatMessage from codai.pydantic.textrequest import ChatMessage
# Try to import outlines for grammar-guided generation
try:
from outlines import models, generate
OUTLINES_AVAILABLE = True
except ImportError:
OUTLINES_AVAILABLE = False
models = None
generate = None
# Import global flag from coderai (will be None if not running as server)
try:
import coderai
_grammar_guided_gen = getattr(coderai, 'grammar_guided_gen', False)
except (ImportError, AttributeError):
_grammar_guided_gen = False
class NvidiaBackend(ModelBackend): class NvidiaBackend(ModelBackend):
"""Backend for NVIDIA GPUs using HuggingFace Transformers.""" """Backend for NVIDIA GPUs using HuggingFace Transformers."""
...@@ -406,8 +422,82 @@ class NvidiaBackend(ModelBackend): ...@@ -406,8 +422,82 @@ class NvidiaBackend(ModelBackend):
def generate(self, prompt: str, max_tokens: Optional[int] = None, def generate(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0, temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None) -> str: stop: Optional[List[str]] = None,
"""Generate text non-streaming.""" grammar: Optional[str] = None) -> str:
"""Generate text non-streaming.
Args:
prompt: Input prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
grammar: Optional regex pattern for constrained generation (outlines)
"""
import torch
from transformers import LogitsProcessor, LogitsProcessorList
# Check for grammar-guided generation using outlines
use_grammar = grammar
if use_grammar is None:
# Check global flag
try:
import coderai
if getattr(coderai, 'grammar_guided_gen', False):
if not OUTLINES_AVAILABLE:
print("Warning: outlines not installed. Run: pip install outlines")
use_grammar = None
else:
# Use a regex pattern for tool calls
use_grammar = r'<tool>.*?</tool>|\{.*?"name".*?"arguments".*?\}|\[.*?"name".*?"arguments".*?\]'
except (ImportError, AttributeError):
pass
# If outlines is available and grammar is enabled, use outlines
if use_grammar and OUTLINES_AVAILABLE:
try:
return self._generate_with_outlines(prompt, max_tokens, temperature, top_p, stop, use_grammar)
except Exception as e:
print(f"Warning: Outlines generation failed: {e}, falling back to normal generation")
# Fall through to normal generation
# Normal generation without grammar
return self._generate_normal(prompt, max_tokens, temperature, top_p, stop)
def _generate_with_outlines(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float,
stop: Optional[List[str]],
pattern: str) -> str:
"""Generate text using outlines library for grammar-guided generation."""
if max_tokens is None:
max_tokens = 512
# Create outlines model from the loaded model
model = models.Transformers(self.model, self.tokenizer)
# Create regex generator
regex_generator = generate.regex(model, pattern=pattern)
# Generate with outlines
# Note: outlines uses its own sampling parameters
result = regex_generator(
prompt,
max_tokens=max_tokens,
temperature=temperature if temperature > 0 else 0.7,
)
# Extract the generated text (outlines returns the full output)
if isinstance(result, str):
# Remove the prompt from the result
if result.startswith(prompt):
result = result[len(prompt):]
return result
return str(result)
def _generate_normal(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float,
stop: Optional[List[str]]) -> str:
"""Normal generation without grammar constraints."""
import torch import torch
from transformers import LogitsProcessor, LogitsProcessorList from transformers import LogitsProcessor, LogitsProcessorList
...@@ -469,8 +559,83 @@ class NvidiaBackend(ModelBackend): ...@@ -469,8 +559,83 @@ class NvidiaBackend(ModelBackend):
async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None, async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0, temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None): stop: Optional[List[str]] = None,
"""Generate text in streaming fashion.""" grammar: Optional[str] = None):
"""Generate text in streaming fashion.
Args:
prompt: Input prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
grammar: Optional regex pattern for constrained generation (outlines)
"""
# Check for grammar-guided generation using outlines
use_grammar = grammar
if use_grammar is None:
# Check global flag
try:
import coderai
if getattr(coderai, 'grammar_guided_gen', False):
if not OUTLINES_AVAILABLE:
print("Warning: outlines not installed. Run: pip install outlines")
use_grammar = None
else:
# Use a regex pattern for tool calls
use_grammar = r'<tool>.*?</tool>|\{.*?"name".*?"arguments".*?\}|\[.*?"name".*?"arguments".*?\]'
except (ImportError, AttributeError):
pass
# If outlines is available and grammar is enabled, use outlines
if use_grammar and OUTLINES_AVAILABLE:
try:
async for chunk in self._generate_stream_outlines(prompt, max_tokens, temperature, top_p, stop, use_grammar):
yield chunk
return
except Exception as e:
print(f"Warning: Outlines streaming generation failed: {e}, falling back to normal generation")
# Fall through to normal generation
# Normal streaming generation without grammar
async for chunk in self._generate_stream_normal(prompt, max_tokens, temperature, top_p, stop):
yield chunk
async def _generate_stream_outlines(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float,
stop: Optional[List[str]],
pattern: str):
"""Generate text using outlines library in streaming mode."""
if max_tokens is None:
max_tokens = 512
# Create outlines model from the loaded model
model = models.Transformers(self.model, self.tokenizer)
# Create regex generator
regex_generator = generate.regex(model, pattern=pattern)
# Generate with outlines (outlines doesn't support true streaming, so we yield the result)
result = regex_generator(
prompt,
max_tokens=max_tokens,
temperature=temperature if temperature > 0 else 0.7,
)
# Extract the generated text (outlines returns the full output)
if isinstance(result, str):
# Remove the prompt from the result
if result.startswith(prompt):
result = result[len(prompt):]
# Yield the entire result as a single chunk (outlines doesn't support true streaming)
yield result
else:
yield str(result)
async def _generate_stream_normal(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float,
stop: Optional[List[str]]):
"""Normal streaming generation without grammar constraints."""
import torch import torch
from transformers import TextIteratorStreamer, LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList from transformers import TextIteratorStreamer, LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
......
...@@ -13,6 +13,14 @@ from codai.models.utils import ( ...@@ -13,6 +13,14 @@ from codai.models.utils import (
get_reasoning_stop_tokens get_reasoning_stop_tokens
) )
from codai.models.cache import get_cached_model_path from codai.models.cache import get_cached_model_path
from codai.models.grammar import get_tool_call_grammar, is_grammar_available
# Import global flag from coderai (will be None if not running as server)
try:
import coderai
_grammar_guided_gen = getattr(coderai, 'grammar_guided_gen', False)
except (ImportError, AttributeError):
_grammar_guided_gen = False
try: try:
from llama_cpp import Llama from llama_cpp import Llama
...@@ -527,7 +535,8 @@ class VulkanBackend(ModelBackend): ...@@ -527,7 +535,8 @@ class VulkanBackend(ModelBackend):
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
temperature: float = 0.7, temperature: float = 0.7,
top_p: float = 1.0, top_p: float = 1.0,
stop: Optional[List[str]] = None stop: Optional[List[str]] = None,
grammar: Optional[str] = None
) -> str: ) -> str:
"""Generate text non-streaming. """Generate text non-streaming.
...@@ -537,6 +546,7 @@ class VulkanBackend(ModelBackend): ...@@ -537,6 +546,7 @@ class VulkanBackend(ModelBackend):
temperature: Sampling temperature temperature: Sampling temperature
top_p: Top-p sampling top_p: Top-p sampling
stop: Stop sequences stop: Stop sequences
grammar: Optional GBNF grammar string for constrained generation
Returns: Returns:
Generated text Generated text
...@@ -571,6 +581,17 @@ class VulkanBackend(ModelBackend): ...@@ -571,6 +581,17 @@ class VulkanBackend(ModelBackend):
# Get default stop tokens based on template # Get default stop tokens based on template
stop = get_reasoning_stop_tokens(self.chat_template) stop = get_reasoning_stop_tokens(self.chat_template)
# Check for grammar-guided generation
use_grammar = grammar
if use_grammar is None:
# Check global flag
try:
import coderai
if getattr(coderai, 'grammar_guided_gen', False):
use_grammar = get_tool_call_grammar()
except (ImportError, AttributeError):
pass
try: try:
result = self.model.create_completion( result = self.model.create_completion(
prompt=prompt, prompt=prompt,
...@@ -580,9 +601,27 @@ class VulkanBackend(ModelBackend): ...@@ -580,9 +601,27 @@ class VulkanBackend(ModelBackend):
top_k=top_k, top_k=top_k,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
stop=stop, stop=stop,
grammar=use_grammar,
) )
return result['choices'][0]['text'] return result['choices'][0]['text']
except Exception as e: except Exception as e:
# If grammar generation fails, fall back to normal generation
if use_grammar:
print(f"Warning: Grammar-guided generation failed: {e}, falling back to normal generation")
try:
result = self.model.create_completion(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repeat_penalty=repeat_penalty,
stop=stop,
)
return result['choices'][0]['text']
except Exception as e2:
print(f"Error during fallback generation: {e2}")
raise
print(f"Error during generation: {e}") print(f"Error during generation: {e}")
raise raise
...@@ -613,6 +652,7 @@ class VulkanBackend(ModelBackend): ...@@ -613,6 +652,7 @@ class VulkanBackend(ModelBackend):
top_p = kwargs.get('top_p', 0.9) top_p = kwargs.get('top_p', 0.9)
top_k = kwargs.get('top_k', 40) top_k = kwargs.get('top_k', 40)
repeat_penalty = kwargs.get('repeat_penalty', 1.1) repeat_penalty = kwargs.get('repeat_penalty', 1.1)
grammar = kwargs.get('grammar', None)
# Get stop strings # Get stop strings
stop = kwargs.get('stop', None) stop = kwargs.get('stop', None)
...@@ -620,6 +660,17 @@ class VulkanBackend(ModelBackend): ...@@ -620,6 +660,17 @@ class VulkanBackend(ModelBackend):
# Get default stop tokens based on template # Get default stop tokens based on template
stop = get_reasoning_stop_tokens(self.chat_template) stop = get_reasoning_stop_tokens(self.chat_template)
# Check for grammar-guided generation
use_grammar = grammar
if use_grammar is None:
# Check global flag
try:
import coderai
if getattr(coderai, 'grammar_guided_gen', False):
use_grammar = get_tool_call_grammar()
except (ImportError, AttributeError):
pass
try: try:
# For chat, we need to extract just the new text from each chunk # For chat, we need to extract just the new text from each chunk
# The first chunk will have the full prompt + first token # The first chunk will have the full prompt + first token
...@@ -637,6 +688,7 @@ class VulkanBackend(ModelBackend): ...@@ -637,6 +688,7 @@ class VulkanBackend(ModelBackend):
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
stop=stop, stop=stop,
stream=True, stream=True,
grammar=use_grammar,
): ):
text = chunk['choices'][0].get('text', '') text = chunk['choices'][0].get('text', '')
...@@ -659,6 +711,41 @@ class VulkanBackend(ModelBackend): ...@@ -659,6 +711,41 @@ class VulkanBackend(ModelBackend):
break break
except Exception as e: except Exception as e:
# If grammar generation fails, fall back to normal generation
if use_grammar:
print(f"Warning: Grammar-guided streaming generation failed: {e}, falling back to normal generation")
try:
first_chunk = True
prompt_len = len(prompt) if isinstance(prompt, str) else 0
for chunk in self.model.create_completion(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repeat_penalty=repeat_penalty,
stop=stop,
stream=True,
):
text = chunk['choices'][0].get('text', '')
if first_chunk:
if text and len(text) > prompt_len:
new_text = text[prompt_len:]
if new_text:
yield new_text
first_chunk = False
else:
if text:
yield text
if chunk['choices'][0].get('finish_reason'):
break
except Exception as e2:
print(f"Error during fallback streaming generation: {e2}")
raise
else:
print(f"Error during streaming generation: {e}") print(f"Error during streaming generation: {e}")
raise raise
......
"""Grammar loading utilities for grammar-guided generation."""
import os
from typing import Optional
# Default grammar file path
DEFAULT_GRAMMAR_PATH = os.path.join(os.path.dirname(__file__), "tool_call_grammar.gbnf")
# Cache for the loaded grammar
_grammar_cache: Optional[str] = None
def load_tool_call_grammar(grammar_path: Optional[str] = None) -> str:
"""Load the GBNF grammar for tool calls.
Args:
grammar_path: Optional path to custom grammar file.
If None, uses default tool_call_grammar.gbnf.
Returns:
The grammar string.
"""
global _grammar_cache
path = grammar_path or DEFAULT_GRAMMAR_PATH
# Return cached version if available
if _grammar_cache is not None and path == DEFAULT_GRAMMAR_PATH:
return _grammar_cache
try:
with open(path, 'r') as f:
grammar = f.read()
# Cache the default grammar
if path == DEFAULT_GRAMMAR_PATH:
_grammar_cache = grammar
return grammar
except FileNotFoundError:
print(f"Warning: Grammar file not found at {path}")
return ""
except Exception as e:
print(f"Warning: Failed to load grammar from {path}: {e}")
return ""
def get_tool_call_grammar() -> str:
"""Get the default tool call grammar.
Returns:
The GBNF grammar string for tool calls.
"""
return load_tool_call_grammar()
def is_grammar_available() -> bool:
"""Check if the default grammar file is available.
Returns:
True if grammar is available, False otherwise.
"""
return os.path.exists(DEFAULT_GRAMMAR_PATH)
This diff is collapsed.
...@@ -304,7 +304,8 @@ class AgenticTemplateManager: ...@@ -304,7 +304,8 @@ class AgenticTemplateManager:
def format_for_raw_completion(self, system_prompt: str, user_message: str, def format_for_raw_completion(self, system_prompt: str, user_message: str,
inject_system: bool = True, inject_system: bool = True,
force_reasoning: bool = True, force_reasoning: bool = True,
tools: Optional[List[Dict]] = None) -> Tuple[str, List[str]]: tools: Optional[List[Dict]] = None,
tools_closer_prompt: bool = False) -> Tuple[str, List[str]]:
""" """
Format prompt for raw completion (bypassing chat API). Format prompt for raw completion (bypassing chat API).
...@@ -314,6 +315,8 @@ class AgenticTemplateManager: ...@@ -314,6 +315,8 @@ class AgenticTemplateManager:
inject_system: If True, injects agentic system instructions inject_system: If True, injects agentic system instructions
force_reasoning: If True, seeds prompt with thought tag to force reasoning force_reasoning: If True, seeds prompt with thought tag to force reasoning
tools: Optional list of tool definitions to include in the prompt tools: Optional list of tool definitions to include in the prompt
tools_closer_prompt: If True, place tools right before the user's message
instead of in the system prompt (prompt distillation)
Returns: Returns:
Tuple of (formatted_prompt, stop_tokens) Tuple of (formatted_prompt, stop_tokens)
...@@ -326,8 +329,8 @@ class AgenticTemplateManager: ...@@ -326,8 +329,8 @@ class AgenticTemplateManager:
# Get tool call tags for this model family # Get tool call tags for this model family
tool_tags = self.TOOL_CALL_TAGS.get(self.family_key, self.TOOL_CALL_TAGS["generic"]) tool_tags = self.TOOL_CALL_TAGS.get(self.family_key, self.TOOL_CALL_TAGS["generic"])
# Add tool descriptions to system prompt if tools are provided AND no custom system prompt exists # Build tools text if tools are provided
# (don't override client's custom system prompt with tool instructions) tools_text = None
if tools and not has_custom_system: if tools and not has_custom_system:
import json import json
tool_descriptions = [] tool_descriptions = []
...@@ -345,7 +348,14 @@ class AgenticTemplateManager: ...@@ -345,7 +348,14 @@ class AgenticTemplateManager:
tools_text += f"\n\nIMPORTANT: When you need to use a tool, you MUST format your response EXACTLY as:\n" tools_text += f"\n\nIMPORTANT: When you need to use a tool, you MUST format your response EXACTLY as:\n"
tools_text += tool_tags["json_format"] tools_text += tool_tags["json_format"]
# Prepend tools to system prompt # Handle tools placement based on tools_closer_prompt flag
if tools_text:
if tools_closer_prompt:
# Prompt distillation: place tools right before the user message
# Don't add tools to system prompt
pass
else:
# Traditional behavior: prepend tools to system prompt
effective_system = f"{tools_text}\n\n{effective_system}" if effective_system else tools_text effective_system = f"{tools_text}\n\n{effective_system}" if effective_system else tools_text
# Inject system prompt if requested # Inject system prompt if requested
...@@ -369,6 +379,42 @@ class AgenticTemplateManager: ...@@ -369,6 +379,42 @@ class AgenticTemplateManager:
if prompt.endswith(thought_tag + "\n"): if prompt.endswith(thought_tag + "\n"):
prompt = prompt[:-len(thought_tag + "\n")] prompt = prompt[:-len(thought_tag + "\n")]
# If tools_closer_prompt is enabled and we have tools, insert them before user message
if tools_closer_prompt and tools_text:
# Find the position right before the user's message and insert tools there
# The template format is: {sys}\nuser\n{user} or similar
# We need to find where user message starts and insert tools there
# For reasoning prefix templates, tools are inserted after system and before user content
# Format: <system>\n\nAvailable tools: <tools>\n\nUser: <message>
# Find the user message part in the prompt and insert tools before it
user_marker = None
# Try different common user message markers based on template family
if self.family_key in ("qwen", "llama3", "deepseek", "yi"):
user_marker = "<|im_start|>user"
elif self.family_key == "phi3":
user_marker = "<|user|>"
elif self.family_key == "gemma":
user_marker = "<start_of_turn>user"
elif self.family_key == "mistral":
user_marker = "[INST]"
elif self.family_key == "anthropic":
user_marker = "\n\nHuman:"
elif self.family_key == "command-r":
user_marker = "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>"
else:
# Generic: look for "User:" or "\nUser\n" or similar
user_marker = "\nUser:"
# Find where user message starts
user_pos = prompt.find(user_marker)
if user_pos != -1:
# Insert tools right before user message
tools_section = f"\n\nAvailable tools: {tools_text}\n"
prompt = prompt[:user_pos] + tools_section + prompt[user_pos:]
stop_tokens = self.get_stop_tokens() stop_tokens = self.get_stop_tokens()
return prompt, stop_tokens return prompt, stop_tokens
......
# GBNF Grammar for Tool Call Generation
# This grammar constrains the model to produce valid JSON tool calls
# Supports multiple tool call formats:
# 1. XML-style: <tool>{"name": "...", "arguments": {...}}</tool>
# 2. JSON array: [{"name": "...", "arguments": {...}}]
# 3. JSON object: {"name": "...", "arguments": {...}}
root ::= tool-call
# Main tool call - accepts XML-wrapped JSON or direct JSON
tool-call ::= (xml-wrapper-json | json-array | json-object) ws*
# XML-style: <tool>{"name": "...", "arguments": {...}}</tool>
xml-wrapper-json ::= "<tool>" ws* json-object ws* "</tool>"
# JSON array format: [{"name": "...", "arguments": {...}}, ...]
json-array ::= "[" ws* json-object (ws* "," ws* json-object)* ws* "]"
# JSON object format: {"name": "...", "arguments": {...}}
json-object ::= "{" ws* string ws* ":" ws* value ws* ("," ws* string ws* ":" ws* value)* ws* "}"
# Tool call entry
tool-call-entry ::= "{" ws* "\"name\"" ws* ":" ws* string ws* "," ws* "\"arguments\"" ws* ":" ws* arguments-object ws* "}"
# Arguments object
arguments-object ::= "{" ws* (string ws* ":" ws* arg-value ws* ("," ws* string ws* ":" ws* arg-value ws*)*)? ws* "}"
# String (JSON string with escapes)
string ::= "\"" (escaped-char | [^"\\])* "\""
# Escaped character
escaped-char ::= "\\" (["\\/bfnrt] | "u" hex hex hex hex)
# Hex digit
hex ::= [0-9A-Fa-f]
# Value (any JSON value)
value ::= string | number | boolean | null | json-object | json-array
# Argument value (restricted set for tool arguments)
arg-value ::= string | number | boolean | null | arg-object | arg-array
# Argument object
arg-object ::= "{" ws* (string ws* ":" ws* arg-value ws* ("," ws* string ws* ":" ws* arg-value ws*)*)? ws* "}"
# Argument array
arg-array ::= "[" ws* (arg-value ws* ("," ws* arg-value ws*)*)? ws* "]"
# Number (integer or float)
number ::= "-"? [0-9]+ ("." [0-9]+)? (("e" | "E") ("+" | "-")? [0-9]+)?
# Boolean
boolean ::= "true" | "false"
# Null
null ::= "null"
# Whitespace (optional)
ws ::= [ \t\n\r]*
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment