Commit 076a7724 authored by Your Name

Remove --parser litellm option and add OpenAIFormatter for response sanitization

- Remove the --parser argument and litellm backend handling code
- Add OpenAIFormatter class in codai/models/parsers.py for final response sanitization
- Integrate formatter into both streaming and non-streaming response paths
- Use litellm's ModelResponse and ChatCompletionChunk for proper OpenAI format
parent b505de59
import time
import uuid
from litellm import ModelResponse, ChatCompletionChunk, Choices, StreamingChoices, Delta, Message, Usage
class OpenAIFormatter:
    """Build OpenAI-format chat completion payloads for a single response.

    One formatter is created per response: the completion id is generated
    once in ``__init__`` and reused by every payload this instance emits, so
    all streaming chunks of one response carry the same id. Payloads are
    built from litellm's response models and returned as plain dictionaries
    via ``model_dump()``.
    """

    def __init__(self, model_name: str):
        # The id is fixed for the lifetime of the formatter.
        self.model_name = model_name
        self.id = f"chatcmpl-{uuid.uuid4()}"

    def format_full(self, text: str, prompt_tokens: int, completion_tokens: int, tool_calls=None) -> dict:
        """Build a complete (non-streaming) chat completion payload.

        Args:
            text: Generated text; dropped (content is None) when tool calls
                are present.
            prompt_tokens: Token count of the prompt.
            completion_tokens: Token count of the completion.
            tool_calls: Optional tool calls to attach to the message.

        Returns:
            The ModelResponse dumped to a dictionary.
        """
        if tool_calls:
            finish_reason, content = "tool_calls", None
        else:
            finish_reason, content = "stop", text

        assistant_message = Message(content=content, role="assistant", tool_calls=tool_calls)
        only_choice = Choices(finish_reason=finish_reason, index=0, message=assistant_message)
        token_usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        response = ModelResponse(
            id=self.id,
            model=self.model_name,
            object="chat.completion",
            created=int(time.time()),
            choices=[only_choice],
            usage=token_usage,
        )
        return response.model_dump()

    def format_chunk(self, delta_text: str, is_final: bool = False, usage: dict = None) -> dict:
        """Build one streaming chunk payload.

        Args:
            delta_text: Incremental text carried by this chunk.
            is_final: When true the chunk's finish_reason is "stop",
                otherwise it is None.
            usage: Optional usage information, passed through unchanged.

        Returns:
            The ChatCompletionChunk dumped to a dictionary.
        """
        streaming_choice = StreamingChoices(
            finish_reason="stop" if is_final else None,
            index=0,
            delta=Delta(content=delta_text, role="assistant"),
        )
        chunk = ChatCompletionChunk(
            id=self.id,
            model=self.model_name,
            object="chat.completion.chunk",
            created=int(time.time()),
            choices=[streaming_choice],
            usage=usage,
        )
        return chunk.model_dump()

    def format_final_chunk(self, usage: dict = None) -> dict:
        """Build the terminating streaming chunk.

        The final chunk has no content, a finish_reason of "stop" and the
        supplied usage statistics.

        Args:
            usage: Usage dictionary (prompt_tokens, completion_tokens,
                total_tokens).

        Returns:
            The final ChatCompletionChunk dumped to a dictionary.
        """
        # A final chunk is a regular chunk with empty content and is_final set.
        return self.format_chunk(None, is_final=True, usage=usage)
"""
LiteLLM Backend - OpenAI-compatible chat completion using litellm.
This module provides a litellm-based backend for the OpenAI-compatible API,
used when --parser litellm is specified.
"""
import os
import json
import re
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
# litellm is treated as an optional dependency: when the import fails, the
# except branch defines placeholders so this module can still be imported and
# callers can check LITELLM_AVAILABLE before using the backend.
try:
    import litellm
    # Register 'coderai' as a custom provider name on the litellm module.
    # The intent (per the original author) is to route 'coderai/...' models
    # through litellm's OpenAI handling.
    # NOTE(review): confirm that `litellm.openai` is a valid `custom_handler`
    # value for litellm.custom_provider_map.
    litellm.custom_provider_map = [
        {"provider": "coderai", "custom_handler": litellm.openai}
    ]
    from litellm import acompletion, completion
    from litellm.exceptions import (
        AuthenticationError,
        BadRequestError,
        RateLimitError,
        ServiceUnavailableError,
        ContextWindowExceededError,
    )
    LITELLM_AVAILABLE = True
    # Map litellm exception classes to the status code and error type that
    # are reported in OpenAI-style error payloads.
    ERROR_CODE_MAP = {
        AuthenticationError: {"code": 401, "type": "invalid_api_key"},
        BadRequestError: {"code": 400, "type": "invalid_request_error"},
        RateLimitError: {"code": 429, "type": "rate_limit_error"},
        ServiceUnavailableError: {"code": 503, "type": "service_unavailable"},
        ContextWindowExceededError: {"code": 400, "type": "context_window_exceeded"},
    }
except ImportError:
    # Placeholders used when litellm is not installed.
    LITELLM_AVAILABLE = False
    litellm = None
    completion = None
    acompletion = None
    ERROR_CODE_MAP = {}
def get_error_response(status_code: int, message: str, error_type: str = "internal_error") -> Dict:
    """Build an OpenAI-style error payload.

    Args:
        status_code: Value stored under the error's "code" key.
        message: Human-readable error message.
        error_type: Error category stored under the "type" key.

    Returns:
        A dictionary with a single "error" entry holding message, type, code.
    """
    error_body = dict(message=message, type=error_type, code=status_code)
    return {"error": error_body}
class LiteLLMBackend:
    """
    LiteLLM-based backend for OpenAI-compatible chat completions.

    Used when --parser litellm is specified to leverage litellm's
    standardized response format and broader model support.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        api_base: Optional[str] = None,
        context_window: int = 4096,
        model_manager: Optional[Any] = None,
        **kwargs
    ):
        """
        Initialize the LiteLLM backend.

        Args:
            model: Model name to use (e.g., "gpt-3.5-turbo", "ollama/llama2")
            api_key: API key for the model provider; a placeholder key is
                substituted when none is given
            base_url: Custom base URL for OpenAI-compatible APIs
            api_base: Alternative to base_url (e.g. "http://localhost:11434/v1");
                only used when base_url is empty
            context_window: Maximum context window size, used for the
                rate limit headers
            model_manager: Object used to resolve model aliases
            **kwargs: Accepted for call-site compatibility and ignored
        """
        self.model = model
        # Fall back to a placeholder key so litellm does not refuse to run
        # when the caller supplied no key.
        self.api_key = api_key if api_key else "sk-fakekey"
        self.base_url = base_url or api_base
        self.context_window = context_window
        self.model_manager = model_manager
        self.tool_parser = None  # Tool parser used for post-processing
        self.tools_schema = {}  # Tools schema handed to the tool parser

        # Configure litellm. These are attributes on the litellm module, so
        # they apply to every user of the module, not just this instance.
        # The module-level name is None when litellm failed to import; skip
        # the configuration then so construction does not raise
        # AttributeError and chat_completion can report the missing package.
        if litellm is not None:
            if self.base_url:
                litellm.base_url = self.base_url
            litellm.api_key = self.api_key

        # Turn on litellm debug mode if global debug is enabled
        _setup_litellm_debug()
def normalize_model_name(self, model: str) -> str:
    """
    Map a model name (or alias) to the routed form used with litellm.

    The result always has the shape ``coderai/{provider}/{model}``:
    - ``provider/model`` with a recognised provider keeps that provider
      (lower-cased);
    - ``org/model`` with an unrecognised first segment keeps the org name
      as written;
    - a bare name is matched against known name prefixes to pick a provider;
    - anything else falls back to the ``coderai`` provider.

    Args:
        model: Original model name (may be an alias)

    Returns:
        Normalized model name: coderai/{provider}/{model}
    """
    print(f"DEBUG litellm: normalize_model_name input: {model}")
    # Aliases are resolved first so the routing below sees the real name.
    resolved_model = self._resolve_model_alias(model)
    print(f"DEBUG litellm: After alias resolution: {resolved_model}")

    # Provider names that are accepted as an explicit "provider/" prefix.
    recognised_providers = (
        'openai', 'anthropic', 'gemini', 'meta', 'mistral', 'cohere',
        'ai21', 'bedrock', 'azure', 'ollama', 'huggingface', 'deepseek',
        'qwen', 'sagemaker', 'vertex', 'aiplatform', 'vllm', 'tgi', 'coderai',
    )

    if '/' in resolved_model:
        # Everything after the first slash is kept verbatim as the model part.
        first_segment, _, model_part = resolved_model.partition('/')
        prefix = first_segment.lower()
        if prefix in recognised_providers:
            result = f"coderai/{prefix}/{model_part}"
            print(f"DEBUG litellm: Known provider '{prefix}', returning: {result}")
            return result
        # Unknown first segment: treat it as an org name and keep its casing.
        org_name = first_segment
        result = f"coderai/{org_name}/{model_part}"
        print(f"DEBUG litellm: Custom org/model, returning: {result}")
        return result

    # Bare model name: infer the provider from the start of the name.
    prefix_routes = (
        (('gpt-', 'gpt3', 'gpt4'), 'openai'),
        (('claude',), 'anthropic'),
        (('gemini', 'palm'), 'gemini'),
        (('llama', 'llama2', 'llama3'), 'meta'),
        (('mistral',), 'mistral'),
        (('amazon',), 'bedrock'),
        (('azure',), 'azure'),
        (('cohere',), 'cohere'),
        (('ai21',), 'ai21'),
        (('ollama',), 'ollama'),
        (('hf',), 'huggingface'),
        (('deepseek',), 'deepseek'),
        (('qwen',), 'qwen'),
    )
    model_lower = resolved_model.lower()
    for name_prefixes, provider in prefix_routes:
        if model_lower.startswith(name_prefixes):
            result = f"coderai/{provider}/{resolved_model}"
            print(f"DEBUG litellm: Detected provider '{provider}', returning: {result}")
            return result

    # Nothing matched: route through the default "coderai" provider.
    result = f"coderai/coderai/{resolved_model}"
    print(f"DEBUG litellm: Unknown provider, using 'coderai', returning: {result}")
    return result
def _resolve_model_alias(self, model: str) -> str:
"""
Resolve model alias to actual model name.
Handles aliases like "default", "image", "audio", "tts", or custom aliases
registered via --model-alias.
Args:
model: Model name or alias
Returns:
Resolved actual model name
"""
if not self.model_manager:
print(f"DEBUG litellm: No model_manager, returning model as-is: {model}")
return model
# Check if model is "default" or empty - use default_model
if not model or model == "default":
default_model = getattr(self.model_manager, 'default_model', None)
print(f"DEBUG litellm: Resolving 'default' alias to: {default_model}")
if default_model:
return default_model
return model
# Check if model is "image" - get first image model
if model == "image":
image_models = getattr(self.model_manager, 'image_models', [])
resolved = image_models[0] if image_models else model
print(f"DEBUG litellm: Resolving 'image' alias to: {resolved}")
return resolved
# Check if model is "audio" - get first audio model
if model == "audio":
audio_models = getattr(self.model_manager, 'audio_models', [])
resolved = audio_models[0] if audio_models else model
print(f"DEBUG litellm: Resolving 'audio' alias to: {resolved}")
return resolved
# Check if model is "tts" - get tts model
if model == "tts":
tts_model = getattr(self.model_manager, 'tts_model', None)
print(f"DEBUG litellm: Resolving 'tts' alias to: {tts_model}")
if tts_model:
return tts_model
return model
# Check custom aliases registered via --model-alias
model_aliases = getattr(self.model_manager, 'model_aliases', {})
if model in model_aliases:
resolved = model_aliases[model]
print(f"DEBUG litellm: Resolving alias '{model}' to: {resolved}")
return resolved
print(f"DEBUG litellm: Model '{model}' is not an alias, returning as-is")
return model
def _convert_messages(self, messages: List[Dict]) -> List[Dict]:
"""Convert OpenAI message format to litellm format."""
converted = []
for msg in messages:
# Handle both 'content' and 'tool' role variations
role = msg.get("role", "user")
content = msg.get("content", "")
# Handle tool calls
if "tool_calls" in msg and msg["tool_calls"]:
tool_calls = []
for tc in msg["tool_calls"]:
if isinstance(tc, dict):
tool_calls.append({
"id": tc.get("id", ""),
"type": "function",
"function": {
"name": tc.get("function", {}).get("name", ""),
"arguments": tc.get("function", {}).get("arguments", "")
}
})
# Add the assistant message with tool calls
converted.append({
"role": role,
"content": content,
"tool_calls": tool_calls
})
elif msg.get("tool_call_id"):
# Tool result message
converted.append({
"role": role,
"content": content,
"tool_call_id": msg.get("tool_call_id")
})
else:
converted.append({
"role": role,
"content": content
})
return converted
def _calculate_tokens_remaining(self, prompt_tokens: int) -> int:
"""Calculate remaining context window tokens."""
return max(0, self.context_window - prompt_tokens)
def _create_response_headers(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int
) -> Dict[str, str]:
"""Create rate limit headers for the response."""
remaining = self._calculate_tokens_remaining(prompt_tokens)
return {
"x-ratelimit-limit-tokens": str(self.context_window),
"x-ratelimit-remaining-tokens": str(remaining),
"x-ratelimit-limit-requests": "60", # Default, can be overridden
"x-ratelimit-remaining-requests": "60",
"x-ratelimit-limit-tokens-usage": str(total_tokens),
"x-ratelimit-remaining-tokens-usage": str(completion_tokens),
"x-ratelimit-token-usage": str(total_tokens),
}
def _parse_tool_calls(self, response: Dict) -> List[Dict]:
"""Parse tool calls from litellm response."""
tool_calls = []
# Check for tool calls in the response
if "choices" in response and response["choices"]:
choice = response["choices"][0]
if "message" in choice:
msg = choice["message"]
if "tool_calls" in msg:
for tc in msg["tool_calls"]:
if isinstance(tc, dict):
tool_calls.append({
"id": tc.get("id", f"call_{id(tc)}"),
"type": "function",
"function": {
"name": tc.get("function", {}).get("name", ""),
"arguments": tc.get("function", {}).get("arguments", "{}")
}
})
return tool_calls
def _extract_content(self, response: Dict) -> str:
"""Extract content from litellm response."""
if "choices" in response and response["choices"]:
choice = response["choices"][0]
if "message" in choice:
return choice["message"].get("content", "") or ""
return ""
def _create_chunk(
self,
content: str,
role: str = "assistant",
tool_calls: Optional[List[Dict]] = None,
finish_reason: Optional[str] = None,
index: int = 0
) -> Dict:
"""Create a chat completion chunk."""
chunk = {
"id": f"chatcmpl-{id(content)}",
"object": "chat.completion.chunk",
"created": 0,
"model": self.model,
"choices": [{
"index": index,
"delta": {
"role": role,
"content": content
},
"finish_reason": finish_reason
}]
}
if tool_calls:
chunk["choices"][0]["delta"]["tool_calls"] = tool_calls
return chunk
async def chat_completion(
    self,
    messages: List[Dict],
    model: Optional[str] = None,
    temperature: float = 0.7,
    top_p: float = 1.0,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    tools: Optional[List[Dict]] = None,
    tool_choice: Optional[Union[str, Dict]] = "auto",
    stream: bool = False,
    tool_parser=None,
    **kwargs
) -> Union[Dict, AsyncGenerator]:
    """
    Generate a chat completion using litellm.

    Args:
        messages: List of message dictionaries
        model: Optional model override (defaults to the instance's model)
        temperature: Sampling temperature
        top_p: Top-p sampling
        max_tokens: Maximum tokens to generate (omitted when falsy)
        stop: Stop sequences (omitted when falsy)
        tools: Tool definitions; tool_choice is only forwarded with them
        tool_choice: Tool choice mode
        stream: Whether to stream the response
        tool_parser: Optional tool parser used to post-process responses
        **kwargs: Extra arguments merged into the litellm call last, so they
            override the values built here

    Returns:
        The response dict, or (when stream is true) the async generator
        that yields chunk dicts.

    Raises:
        RuntimeError: If litellm is not installed.
    """
    if not LITELLM_AVAILABLE:
        raise RuntimeError("litellm is not installed. Run: pip install litellm")

    # Per-request state used by the response post-processing. It is rebuilt
    # on every call so that tools or buffered stream text from a previous
    # request can never leak into this one.
    self.tool_parser = tool_parser
    self.tools_schema = {}
    self._accumulated_content = ""
    for tool in tools or []:
        if isinstance(tool, dict) and 'function' in tool:
            func = tool.get('function', {})
            self.tools_schema[func.get('name', '')] = {
                'description': func.get('description', ''),
                'parameters': func.get('parameters', {})
            }

    # Prepare the model - normalize name for litellm
    use_model = self.normalize_model_name(model or self.model)

    # For HuggingFace models, set a placeholder API key to skip auth
    if 'huggingface' in use_model.lower():
        litellm.api_key = "sk-fakekey"
        print("DEBUG litellm: HuggingFace model - using fake key")

    # Diagnostic output. The key itself is a credential (it may be a
    # client's bearer token), so only report whether one is configured.
    key_state = "<set>" if litellm.api_key else "<unset>"
    print(f"DEBUG litellm: api_key={key_state}, api_base={self.base_url}, model={use_model}")

    completion_args = {
        "model": use_model,
        "messages": self._convert_messages(messages),
        "temperature": temperature,
        "top_p": top_p,
        "stream": stream,
    }
    if max_tokens:
        completion_args["max_tokens"] = max_tokens
    if stop:
        completion_args["stop"] = stop
    if tools:
        completion_args["tools"] = tools
        if tool_choice:
            completion_args["tool_choice"] = tool_choice

    # Add any additional kwargs
    completion_args.update(kwargs)

    if stream:
        # Not awaited: the caller iterates the async generator.
        return self._stream_response(completion_args)
    return await self._get_response(completion_args)
async def _get_response(self, completion_args: Dict) -> Dict:
    """Run one non-streaming completion and return it in OpenAI format.

    Any exception raised by the call or by the response processing is
    converted into an error dictionary instead of propagating.
    """
    try:
        raw_response = await acompletion(**completion_args)
        processed = self._process_response(raw_response)
    except Exception as exc:
        return self._handle_error(exc)
    return processed
def _process_response(self, response: Any) -> Dict:
"""Process litellm response into OpenAI format."""
# Convert litellm response to OpenAI format
usage = {}
if hasattr(response, "usage") and response.usage:
usage = {
"prompt_tokens": response.usage.get("prompt_tokens", 0),
"completion_tokens": response.usage.get("completion_tokens", 0),
"total_tokens": response.usage.get("total_tokens", 0),
}
# Extract message content
content = ""
tool_calls = []
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message"):
msg = choice.message
content = msg.content or ""
# Handle tool calls
if hasattr(msg, "tool_calls") and msg.tool_calls:
for tc in msg.tool_calls:
if hasattr(tc, "function"):
func = tc.function
tool_calls.append({
"id": tc.id or f"call_{id(tc)}",
"type": "function",
"function": {
"name": func.name,
"arguments": func.arguments
}
})
# Build OpenAI-compatible response
result = {
"id": f"chatcmpl-{id(response)}",
"object": "chat.completion",
"created": getattr(response, "created", 0),
"model": getattr(response, "model", self.model),
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": content,
},
"finish_reason": getattr(response.choices[0], "finish_reason", None) if hasattr(response, "choices") and response.choices else None,
}],
"usage": usage,
}
if tool_calls:
result["choices"][0]["message"]["tool_calls"] = tool_calls
# Use coderai's tool parser for post-processing if available
if self.tool_parser and content:
# Try to extract tool calls using coderai's parser
try:
# Convert tools to the format expected by coderai parser
tools_schema = {}
if hasattr(self, 'tools_schema') and self.tools_schema:
tools_schema = self.tools_schema
# Use coderai parser to extract tool calls from content
parsed_tool_calls = self.tool_parser.extract_tool_calls(content, tools_schema) if hasattr(self.tool_parser, 'extract_tool_calls') else None
if parsed_tool_calls:
# Replace tool calls with coderai-parsed versions
result["choices"][0]["message"]["tool_calls"] = parsed_tool_calls
# Strip tool tags from content
if hasattr(self.tool_parser, 'strip_tool_calls_from_content'):
clean_content = self.tool_parser.strip_tool_calls_from_content(content)
result["choices"][0]["message"]["content"] = clean_content
except Exception as e:
print(f"DEBUG litellm: Coderai parser post-processing error: {e}")
return result
async def _stream_response(self, completion_args: Dict) -> AsyncGenerator:
    """Yield OpenAI-format chunk dicts for a streaming completion.

    If starting or consuming the stream raises, a single error dictionary
    is yielded and the generator ends.
    """
    try:
        upstream = await acompletion(**completion_args)
        async for raw_chunk in upstream:
            yield self._process_stream_chunk(raw_chunk)
    except Exception as exc:
        yield self._handle_error(exc)
def _process_stream_chunk(self, chunk: Any) -> Dict:
"""Process a streaming chunk from litellm."""
content = ""
tool_calls = []
finish_reason = None
if hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta"):
delta = choice.delta
content = delta.content or ""
if hasattr(delta, "tool_calls") and delta.tool_calls:
for tc in delta.tool_calls:
if hasattr(tc, "function"):
func = tc.function
tool_calls.append({
"id": tc.id or f"call_{id(tc)}",
"type": "function",
"function": {
"name": func.name,
"arguments": func.arguments
}
})
finish_reason = getattr(choice, "finish_reason", None)
result = {
"id": f"chatcmpl-{id(chunk)}",
"object": "chat.completion.chunk",
"created": getattr(chunk, "created", 0),
"model": getattr(chunk, "model", self.model),
"choices": [{
"index": 0,
"delta": {},
"finish_reason": finish_reason,
}]
}
if content:
result["choices"][0]["delta"]["content"] = content
if tool_calls:
result["choices"][0]["delta"]["tool_calls"] = tool_calls
# Accumulate content for coderai parser post-processing at end of stream
if content:
if not hasattr(self, '_accumulated_content'):
self._accumulated_content = ""
self._accumulated_content += content
# Use coderai's tool parser for post-processing if available and this is final chunk
if self.tool_parser and hasattr(self, '_accumulated_content') and self._accumulated_content:
if finish_reason == 'stop':
try:
# Use coderai parser to extract tool calls from accumulated content
tools_schema = getattr(self, 'tools_schema', {})
if hasattr(self.tool_parser, 'extract_tool_calls'):
parsed_tool_calls = self.tool_parser.extract_tool_calls(self._accumulated_content, tools_schema)
if parsed_tool_calls:
# Add tool calls to final chunk
result["choices"][0]["delta"]["tool_calls"] = parsed_tool_calls
# Strip tool tags from content
if hasattr(self.tool_parser, 'strip_tool_calls_from_content'):
clean_content = self.tool_parser.strip_tool_calls_from_content(self._accumulated_content)
result["choices"][0]["delta"]["content"] = clean_content
# Clear accumulated content after processing
self._accumulated_content = ""
except Exception as e:
print(f"DEBUG litellm: Coderai parser stream post-processing error: {e}")
return result
def _handle_error(self, exception: Exception) -> Dict:
    """Convert an exception into an OpenAI-format error dictionary.

    The exception's class hierarchy is searched, most specific class first,
    for an entry in ERROR_CODE_MAP, so subclasses of a mapped exception get
    their parent's code and type instead of a generic 500. Exceptions with
    no mapped ancestor are reported as code 500 / "internal_error".

    Args:
        exception: The exception to report.

    Returns:
        Dictionary with a single "error" entry (message, type, code).
    """
    error_info = {"code": 500, "type": "internal_error"}
    for klass in type(exception).__mro__:
        if klass in ERROR_CODE_MAP:
            error_info = ERROR_CODE_MAP[klass]
            break
    return {
        "error": {
            "message": str(exception),
            "type": error_info["type"],
            "code": error_info["code"],
        }
    }
def parse_qwen_tool_calls(self, text: str) -> List[Dict]:
    """
    Parse Qwen-style tool calls from text content.

    Handles both <tool> and <tool_call> tags, with support for:
    - JSON format: <tool>{"name": "func", "arguments": {...}}</tool>
    - Coder style: <tool=func><parameter=key>value</parameter></tool>

    At most one tool call is returned: parsing stops at the first valid
    call. Tagged payloads that are valid JSON but not an object (a number,
    list, string, ...) are skipped.

    Args:
        text: Raw model output.

    Returns:
        A list with zero or one tool call dictionaries in OpenAI format.
    """
    tool_calls = []

    # 1. Repetition guard: if the model looped and emitted several tool
    # tags, keep only the first tagged segment.
    if text.count('<tool') > 1:
        parts = re.split(r'<(?:tool|tool_call)', text, flags=re.IGNORECASE)
        text = f"<tool{parts[1]}" if len(parts) > 1 else text

    # 2. Pre-cleaning: drop <|...|> markers and thinking sections.
    clean_text = re.sub(
        r'<\|.*?\|>|<(?:thought|think)>.*?((?:</(?:thought|think)>)|$)',
        '',
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )

    # 3. JSON payloads inside <tool> / <tool_call> (closing tag optional).
    tag_pattern = r'<(?:tool|tool_call)>(.*?)(?:</(?:tool|tool_call)>|$)'
    matches = re.findall(tag_pattern, clean_text, re.DOTALL | re.IGNORECASE)
    # If no tags found but text looks like JSON, try whole text
    if not matches and '{' in clean_text and '"name"' in clean_text:
        matches = [clean_text]

    for block in matches:
        block = block.strip()
        if not block:
            continue
        # Unwrap a markdown code fence, if present.
        json_str = re.sub(r'```(?:json)?\s*(.*?)\s*```', r'\1', block, flags=re.DOTALL).strip()
        # Recovery of unclosed JSON
        if json_str.startswith('{') and not json_str.endswith('}'):
            json_str += '}'
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError:
            # Fallback: salvage at least the function name.
            name_match = re.search(r'"name":\s*"([^"]+)"', json_str)
            if name_match:
                tool_calls.append({
                    "id": f"call_{id(name_match)}",
                    "type": "function",
                    "function": {
                        "name": name_match.group(1),
                        "arguments": "{}"
                    }
                })
                break
            continue
        # Only a JSON object can describe a call; other JSON values are skipped.
        if isinstance(data, dict) and 'name' in data:
            # "arguments" wins; "parameters" is used when it is absent.
            arguments = data.get('arguments', data.get('parameters', {}))
            tool_calls.append({
                "id": f"call_{id(data)}",
                "type": "function",
                "function": {
                    "name": data['name'],
                    "arguments": json.dumps(arguments)
                }
            })
            break  # Circuit breaker after first valid call

    # 4. Coder style fallback: <tool=name><parameter=k>v</parameter></tool>
    if not tool_calls:
        pattern = r'<(?:function|tool|call)=([^>]+)>(.*?)(?:</(?:function|tool|call|tool_call)>|$)'
        for name, body in re.findall(pattern, clean_text, re.DOTALL | re.IGNORECASE):
            args = {}
            for key, raw_value in re.findall(r'<parameter=([^>]+)>(.*?)</parameter>', body, re.DOTALL):
                value = raw_value.strip()
                try:
                    args[key.strip()] = json.loads(value)
                except ValueError:
                    # Not JSON: keep the parameter as a plain string.
                    args[key.strip()] = value
            tool_calls.append({
                "id": f"call_{id(args)}",
                "type": "function",
                "function": {
                    "name": name.strip(),
                    "arguments": json.dumps(args)
                }
            })
            break  # Circuit breaker
    return tool_calls
def strip_tool_tags(self, text: str) -> str:
    """Remove tool-call markup blocks from text and trim the result.

    Blocks opened by a <tool...>, <tool_call...> or <function...> tag are
    removed together with everything up to the matching closing tag.
    """
    block_patterns = (
        r'<tool[^>]*>.*?</tool[^>]*>',
        r'<tool_call[^>]*>.*?</tool_call[^>]*>',
        r'<function[^>]*>.*?</function[^>]*>',
    )
    remaining = text
    for block_pattern in block_patterns:
        remaining = re.sub(block_pattern, '', remaining, flags=re.DOTALL | re.IGNORECASE)
    return remaining.strip()
def get_rate_limit_headers(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> Dict[str, str]:
    """Return the rate limit headers for the given prompt and completion token counts."""
    return self._create_response_headers(
        prompt_tokens,
        completion_tokens,
        prompt_tokens + completion_tokens,
    )
# Module-level backend instance shared via get_litellm_backend() and
# set_litellm_backend(); None until one of them assigns it.
default_litellm_backend: Optional[LiteLLMBackend] = None
# Turn on litellm debug mode if global debug is enabled
def _setup_litellm_debug():
"""Turn on litellm debug mode if global debug is enabled."""
try:
import sys
# Check if global_debug is True in coderai module at runtime
if 'coderai' in sys.modules:
from coderai import global_debug
if global_debug:
import litellm
litellm._turn_on_debug()
print("DEBUG litellm: Debug mode enabled")
except Exception as e:
print(f"DEBUG litellm: Could not enable debug mode: {e}")
def get_litellm_backend(
    model: str = "gpt-3.5-turbo",
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    api_base: Optional[str] = None,
    context_window: int = 4096,
    model_manager: Optional[Any] = None,
    **kwargs
) -> LiteLLMBackend:
    """Create a LiteLLM backend and install it as the module default.

    A new instance is built on every call (an existing default is never
    reused), so the backend always reflects the arguments of the latest
    call, including the model manager used for alias resolution.

    Returns:
        The newly created backend, which is also stored in
        ``default_litellm_backend``.
    """
    global default_litellm_backend
    fresh_backend = LiteLLMBackend(
        model=model,
        api_key=api_key,
        base_url=base_url,
        api_base=api_base,
        context_window=context_window,
        model_manager=model_manager,
        **kwargs
    )
    default_litellm_backend = fresh_backend
    return fresh_backend
def set_litellm_backend(backend: LiteLLMBackend) -> None:
    """Replace the module-level default LiteLLM backend.

    Args:
        backend: Instance to store in ``default_litellm_backend``.
    """
    global default_litellm_backend
    default_litellm_backend = backend
...@@ -30,6 +30,7 @@ from threading import Thread ...@@ -30,6 +30,7 @@ from threading import Thread
# Import codai module for enhanced tool call parsing # Import codai module for enhanced tool call parsing
from codai.models import ModelParserDispatcher from codai.models import ModelParserDispatcher
from codai.models.parsers import OpenAIFormatter
# Per-model semaphores for request concurrency control # Per-model semaphores for request concurrency control
model_semaphores: dict = {} model_semaphores: dict = {}
load_mode = {"mode": "ondemand"} # Track load mode globally load_mode = {"mode": "ondemand"} # Track load mode globally
...@@ -5170,223 +5171,7 @@ async def create_speech(request: TTSRequest): ...@@ -5170,223 +5171,7 @@ async def create_speech(request: TTSRequest):
async def chat_completions(request: ChatCompletionRequest, http_request: Request = None): async def chat_completions(request: ChatCompletionRequest, http_request: Request = None):
"""Chat completions endpoint with streaming and tool support.""" """Chat completions endpoint with streaming and tool support."""
# Check if we should use litellm backend # Continue with original implementation
parser_type = getattr(global_args, 'parser', 'auto') if global_args else 'auto'
if parser_type == 'litellm':
# Use LiteLLM backend
from codai.openai.litellm import get_litellm_backend, LITELLM_AVAILABLE
if not LITELLM_AVAILABLE:
raise HTTPException(
status_code=500,
detail="LiteLLM is not installed. Run: pip install litellm"
)
# Check for API key in request - litellm requires an API key
# If not provided, use a fake key to allow the request to proceed
api_key = None
# Try to get API key from request body
if hasattr(request, 'api_key') and request.api_key:
api_key = request.api_key
# If no API key in body, try to get from Authorization header
if not api_key:
auth_header = http_request.headers.get('Authorization', '') if http_request else ''
if auth_header.startswith('Bearer '):
api_key = auth_header[7:] # Extract token after 'Bearer '
# If still no API key, use a fake key to allow litellm to proceed
# litellm will then fail with the actual provider error if needed
if not api_key:
api_key = "fake-key-for-local-testing"
print("DEBUG: No API key provided, using fake key for litellm")
# Determine the base URL for litellm to connect to
# Use the server's host and port for local connections
api_base = None
# Check if model starts with 'ollama:' - use local Ollama
if request.model and request.model.startswith('ollama:'):
# Get the host from the request headers
client_host = "127.0.0.1"
if http_request:
host_header = http_request.headers.get('host', '')
if host_header:
# Strip port if present
if ':' in host_header:
client_host = host_header.split(':')[0]
if client_host.replace('.', '').isdigit():
# It's an IP, keep it
pass
else:
# It's a hostname, use localhost
client_host = "127.0.0.1"
else:
client_host = host_header
# Get port from global_args or use default
port = getattr(global_args, 'port', 11434) if global_args else 11434
api_base = f"http://{client_host}:{port}/v1"
print(f"DEBUG: Using api_base for Ollama: {api_base}")
else:
# For non-Ollama models, use the server's own URL as base
# This allows LiteLLM to make requests to the local server
if http_request:
# Get the host from the request headers
host_header = http_request.headers.get('host', '')
if host_header:
# Strip port if present to reconstruct clean URL
if ':' in host_header:
client_host = host_header.split(':')[0]
# Keep the port from the request for consistency
server_port = host_header.split(':')[1] if len(host_header.split(':')) > 1 else str(getattr(global_args, 'port', 6745))
else:
client_host = host_header
server_port = str(getattr(global_args, 'port', 6745))
else:
# Fallback to client host if no Host header
client_host = http_request.client.host if http_request.client else "127.0.0.1"
server_port = str(getattr(global_args, 'port', 6745))
else:
# Fallback if no http_request
client_host = "127.0.0.1"
server_port = str(getattr(global_args, 'port', 6745))
# Determine protocol (http or https)
use_https = getattr(global_args, 'https', False) or getattr(global_args, 'pubkey', None)
protocol = "https" if use_https else "http"
api_base = f"{protocol}://{client_host}:{server_port}/v1"
print(f"DEBUG: Using api_base for local server: {api_base}")
# Get or create litellm backend
litellm_backend = get_litellm_backend(
model=request.model,
api_key=api_key,
api_base=api_base,
context_window=8192, # Default, can be made configurable
model_manager=multi_model_manager # Pass for alias resolution
)
# Get the tool_parser from multi_model_manager for model-specific parsing
tool_parser = multi_model_manager.tool_parser if hasattr(multi_model_manager, 'tool_parser') else None
# Convert messages to dict format
messages_dict = []
for msg in request.messages:
msg_dict = {"role": msg.role, "content": msg.content or ""}
if hasattr(msg, 'tool_calls') and msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if hasattr(msg, 'tool_call_id') and msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
# Prepare tools if provided
tools_dict = None
if request.tools:
tools_dict = request.tools
# Generate response
try:
if request.stream:
# Streaming response
from fastapi.responses import StreamingResponse
async def generate():
    """Relay the backend's streamed chat completion as Server-Sent Events.

    Each chunk produced by ``litellm_backend.chat_completion(stream=True)``
    is emitted as one ``data: <json>\n\n`` event, followed by a terminating
    ``data: [DONE]\n\n`` event once the backend stream is exhausted.

    For models whose name contains ``qwen``, a chunk whose delta has text
    content but no structured ``tool_calls`` is passed through
    ``litellm_backend.parse_qwen_tool_calls``; when that returns tool calls,
    the delta is rewritten in place with the tag-stripped content and the
    parsed calls.

    Any exception raised while streaming is reported to the client as a
    single ``{"error": {...}}`` event, after which the generator ends
    without emitting ``[DONE]``.

    Yields:
        str: SSE-framed lines ready to be written to the response body.
    """
    try:
        stream = await litellm_backend.chat_completion(
            messages=messages_dict,
            model=request.model,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens,
            stop=request.stop,
            tools=tools_dict,
            tool_choice=request.tool_choice,
            stream=True,
            tool_parser=tool_parser,
        )
        # The model name cannot change mid-stream, so decide once.
        is_qwen = 'qwen' in request.model.lower()

        async for chunk in stream:
            # Chunks may carry "usage": null; only a real usage object is
            # forwarded to the rate-limit helper.
            usage = chunk.get('usage')
            if usage is not None:
                # NOTE(review): the returned headers cannot be attached to a
                # response that is already streaming, so the result is
                # discarded. The call is kept in case the helper has side
                # effects -- TODO confirm and drop it if it is pure.
                litellm_backend.get_rate_limit_headers(
                    prompt_tokens=usage.get('prompt_tokens', 0),
                    completion_tokens=usage.get('completion_tokens', 0),
                )

            if is_qwen:
                # A chunk may have no choices at all (e.g. a usage-only final
                # chunk with "choices": []); indexing [0] unconditionally
                # would raise and abort the stream.
                choices = chunk.get('choices') or []
                first_choice = choices[0] if choices else {}
                delta = first_choice.get('delta') or {}
                content = delta.get('content', '')
                tool_calls = delta.get('tool_calls', [])
                if not tool_calls and content:
                    # Try to recover tool calls embedded in the text content.
                    tool_calls = litellm_backend.parse_qwen_tool_calls(content)
                    if tool_calls:
                        # `content` is non-empty here, so `delta` is the
                        # chunk's own dict and mutating it updates the chunk.
                        delta['content'] = litellm_backend.strip_tool_tags(content)
                        delta['tool_calls'] = tool_calls

            yield f"data: {json.dumps(chunk)}\n\n"

        yield "data: [DONE]\n\n"
    except Exception as e:
        # Top-level streaming boundary: surface the failure to the client as
        # an SSE error event rather than silently closing the connection.
        yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'internal_error'}})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
else:
# Non-streaming response
response = await litellm_backend.chat_completion(
messages=messages_dict,
model=request.model,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
stop=request.stop,
tools=tools_dict,
tool_choice=request.tool_choice,
stream=False,
tool_parser=tool_parser,
)
# Handle Qwen tool calls
if 'qwen' in request.model.lower() and 'choices' in response:
msg = response['choices'][0].get('message', {})
content = msg.get('content', '')
tool_calls = msg.get('tool_calls', [])
if not tool_calls and content:
tool_calls = litellm_backend.parse_qwen_tool_calls(content)
if tool_calls:
msg['content'] = litellm_backend.strip_tool_tags(content)
msg['tool_calls'] = tool_calls
response['choices'][0]['message'] = msg
# Add rate limit headers
headers = {}
if 'usage' in response:
headers = litellm_backend.get_rate_limit_headers(
prompt_tokens=response.get('usage', {}).get('prompt_tokens', 0),
completion_tokens=response.get('usage', {}).get('completion_tokens', 0)
)
from fastapi.responses import JSONResponse
return JSONResponse(content=response, headers=headers)
except Exception as e:
# Handle litellm errors
error_response = {
"error": {
"message": str(e),
"type": "internal_error",
"code": 500
}
}
return JSONResponse(content=error_response, status_code=500)
# Continue with original implementation for 'auto' parser
# Get the model for this request # Get the model for this request
requested_model = request.model requested_model = request.model
...@@ -5756,37 +5541,14 @@ async def stream_chat_response( ...@@ -5756,37 +5541,14 @@ async def stream_chat_response(
prompt_tokens = len(prompt_text.split()) prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0 completion_tokens = len(generated_text.split()) if generated_text else 0
# Build complete final chunk with all OpenAI fields # Build complete final chunk with OpenAIFormatter for sanitization
final_chunk = { formatter = OpenAIFormatter(model_name)
"id": completion_id, usage_details = {
"object": "chat.completion.chunk",
"created": created,
"model": model_name,
"choices": [{
"index": 0,
"finish_reason": "stop",
"logprobs": None,
"native_finish_reason": "stop",
}],
"usage": {
"prompt_tokens": prompt_tokens, "prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens, "completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens, "total_tokens": prompt_tokens + completion_tokens,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 0,
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
},
},
"provider": {
"provider_name": "coderai",
"provider_id": "coderai",
},
"system_fingerprint": None,
} }
final_chunk = formatter.format_final_chunk(usage=usage_details)
yield f"data: {json.dumps(final_chunk)}\n\n" yield f"data: {json.dumps(final_chunk)}\n\n"
else: else:
# Calculate token counts for usage in final chunk # Calculate token counts for usage in final chunk
...@@ -5930,56 +5692,14 @@ async def generate_chat_response( ...@@ -5930,56 +5692,14 @@ async def generate_chat_response(
prompt_tokens = len(prompt_text.split()) prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0 completion_tokens = len(generated_text.split()) if generated_text else 0
# Build complete OpenAI-compatible response with all standard fields # Use OpenAIFormatter for final sanitization
# Provider info formatter = OpenAIFormatter(model_name)
provider = { return formatter.format_full(
"provider_name": "coderai", text=response_message.get("content", ""),
"provider_id": "coderai", prompt_tokens=prompt_tokens,
} completion_tokens=completion_tokens,
tool_calls=response_message.get("tool_calls")
# Build choices with all OpenAI fields )
choice = {
"index": 0,
"message": response_message,
"finish_reason": finish_reason,
}
# Add logprobs (null since we don't have token-level probabilities)
choice["logprobs"] = None
# Add native_finish_reason (same as finish_reason for our purposes)
choice["native_finish_reason"] = finish_reason
# Build detailed usage information
usage_details = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
# Add prompt_tokens_details (breakdown of prompt tokens)
usage_details["prompt_tokens_details"] = {
"cached_tokens": 0,
"audio_tokens": 0,
}
# Add completion_tokens_details (breakdown of completion tokens)
usage_details["completion_tokens_details"] = {
"reasoning_tokens": 0,
"audio_tokens": 0,
}
return {
"id": completion_id,
"object": "chat.completion",
"created": created,
"model": model_name,
"choices": [choice],
"usage": usage_details,
# Additional OpenAI-compatible fields
"provider": provider,
"system_fingerprint": None,
}
except Exception as e: except Exception as e:
print(f"Error during generation: {e}") print(f"Error during generation: {e}")
raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
...@@ -6505,13 +6225,6 @@ def parse_args(): ...@@ -6505,13 +6225,6 @@ def parse_args():
default=None, default=None,
help="Path to store generated files (images, audio). If specified, files will be saved here and served over web.", help="Path to store generated files (images, audio). If specified, files will be saved here and served over web.",
) )
parser.add_argument(
"--parser",
type=str,
default="auto",
choices=["auto", "litellm"],
help="Tool call parser to use: 'auto' for internal parser, 'litellm' for LiteLLM's parser. Default: auto",
)
return parser.parse_args() return parser.parse_args()
def main(): def main():
"""Main entry point.""" """Main entry point."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment