Commit 6c2c0afc authored by Your Name's avatar Your Name

Fix LiteLLM

parent 8280060e
......@@ -16,8 +16,6 @@ from .models.parser import (
from .models.templates import AgenticTemplateManager
# OpenAI-compatible backends
__all__ = [
'ModelParserDispatcher',
'BaseParser',
......
# codai module - AI model parsing utilities
from .models.parser import (
ModelParserDispatcher,
BaseParser,
QwenParser,
DeepSeekParser,
LlamaParser,
MistralParser,
ClaudeParser,
CommandRParser,
GemmaParser,
GrokParser,
PhiParser,
ApexBig50Parser,
)
from .models.templates import AgenticTemplateManager
# OpenAI-compatible backends
from .openai.litellm import (
LiteLLMBackend,
get_litellm_backend,
set_litellm_backend,
LITELLM_AVAILABLE,
)
__all__ = [
'ModelParserDispatcher',
'BaseParser',
'QwenParser',
'DeepSeekParser',
'LlamaParser',
'MistralParser',
'ClaudeParser',
'CommandRParser',
'GemmaParser',
'GrokParser',
'PhiParser',
'ApexBig50Parser',
'AgenticTemplateManager',
'LiteLLMBackend',
'get_litellm_backend',
'set_litellm_backend',
'LITELLM_AVAILABLE',
]
import time
import uuid
# Try to import litellm for response formatting
# Fall back to plain dicts if litellm is not available or doesn't export these
try:
from litellm import ModelResponse, ChatCompletionChunk
LITELLM_AVAILABLE = True
except ImportError:
LITELLM_AVAILABLE = False
ModelResponse = None
ChatCompletionChunk = None
class OpenAIFormatter:
"""Formatter for standardizing chat completion responses in OpenAI format.
......@@ -123,3 +133,74 @@ class OpenAIFormatter:
chunk["usage"] = usage
return chunk
def format_litellm_full(self, text: str, prompt_tokens: int, completion_tokens: int, tool_calls=None) -> dict:
"""Format using litellm's ModelResponse if available.
Args:
text: The generated text content
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
tool_calls: Optional list of tool calls to include
Returns:
Dictionary representation of ModelResponse
"""
if not LITELLM_AVAILABLE or ModelResponse is None:
return self.format_full(text, prompt_tokens, completion_tokens, tool_calls)
try:
from litellm import Choices, Message, Usage
return ModelResponse(
id=self.id,
model=self.model_name,
object="chat.completion",
created=int(time.time()),
choices=[Choices(
finish_reason="tool_calls" if tool_calls else "stop",
index=0,
message=Message(content=text if not tool_calls else None, role="assistant", tool_calls=tool_calls)
)],
usage=Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
).model_dump()
except Exception:
# Fall back to plain dict if litellm fails
return self.format_full(text, prompt_tokens, completion_tokens, tool_calls)
def format_litellm_chunk(self, delta_text: str, is_final: bool = False, usage: dict = None) -> dict:
"""Format streaming chunk using litellm's ChatCompletionChunk if available.
Args:
delta_text: The incremental text content for this chunk
is_final: Whether this is the final chunk
usage: Optional usage information (typically only sent on final chunk)
Returns:
Dictionary representation of ChatCompletionChunk
"""
if not LITELLM_AVAILABLE or ChatCompletionChunk is None:
return self.format_chunk(delta_text, is_final, usage)
try:
from litellm import StreamingChoices, Delta, Usage
return ChatCompletionChunk(
id=self.id,
model=self.model_name,
object="chat.completion.chunk",
created=int(time.time()),
choices=[StreamingChoices(
finish_reason="stop" if is_final else None,
index=0,
delta=Delta(content=delta_text, role="assistant")
)],
usage=Usage(**usage) if usage else None
).model_dump()
except Exception:
# Fall back to plain dict if litellm fails
return self.format_chunk(delta_text, is_final, usage)
import time
import uuid
# Try to import litellm for response formatting
# Fall back to plain dicts if litellm is not available or doesn't export these
try:
from litellm import ModelResponse, ChatCompletionChunk
LITELLM_AVAILABLE = True
except ImportError:
LITELLM_AVAILABLE = False
ModelResponse = None
ChatCompletionChunk = None
class OpenAIFormatter:
"""Formatter for standardizing chat completion responses in OpenAI format.
This class provides final sanitization of responses before sending them
to clients. It processes the output of the internal parser and formats
them into proper OpenAI-compatible responses.
"""
def __init__(self, model_name: str):
self.model_name = model_name
self.id = f"chatcmpl-{uuid.uuid4()}"
def format_full(self, text: str, prompt_tokens: int, completion_tokens: int, tool_calls=None) -> dict:
"""Format a standard (non-streaming) response.
Args:
text: The generated text content
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
tool_calls: Optional list of tool calls to include
Returns:
Dictionary representation of the response
"""
message = {
"role": "assistant",
"content": text if not tool_calls else None,
}
if tool_calls:
message["tool_calls"] = tool_calls
choice = {
"index": 0,
"message": message,
"finish_reason": "tool_calls" if tool_calls else "stop",
}
return {
"id": self.id,
"object": "chat.completion",
"created": int(time.time()),
"model": self.model_name,
"choices": [choice],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
"provider": {
"provider_name": "coderai",
"provider_id": "coderai",
},
}
def format_chunk(self, delta_text: str, is_final: bool = False, usage: dict = None) -> dict:
"""Format a streaming chunk response.
Args:
delta_text: The incremental text content for this chunk
is_final: Whether this is the final chunk
usage: Optional usage information (typically only sent on final chunk)
Returns:
Dictionary representation of the chunk
"""
delta = {
"content": delta_text,
"role": "assistant",
}
choice = {
"index": 0,
"delta": delta,
"finish_reason": "stop" if is_final else None,
}
chunk = {
"id": self.id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model_name,
"choices": [choice],
}
if usage and is_final:
chunk["usage"] = usage
return chunk
def format_final_chunk(self, usage: dict = None) -> dict:
"""Format the final streaming chunk with usage information.
Args:
usage: Usage statistics dictionary with prompt_tokens, completion_tokens, total_tokens
Returns:
Dictionary representation of the final chunk
"""
delta = {
"content": None,
"role": "assistant",
}
choice = {
"index": 0,
"delta": delta,
"finish_reason": "stop",
}
chunk = {
"id": self.id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model_name,
"choices": [choice],
}
if usage:
chunk["usage"] = usage
return chunk
def format_litellm_full(self, text: str, prompt_tokens: int, completion_tokens: int, tool_calls=None) -> dict:
"""Format using litellm's ModelResponse if available.
Args:
text: The generated text content
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
tool_calls: Optional list of tool calls to include
Returns:
Dictionary representation of ModelResponse
"""
if not LITELLM_AVAILABLE or ModelResponse is None:
return self.format_full(text, prompt_tokens, completion_tokens, tool_calls)
try:
from litellm import Choices, Message, Usage
return ModelResponse(
id=self.id,
model=self.model_name,
object="chat.completion",
created=int(time.time()),
choices=[Choices(
finish_reason="tool_calls" if tool_calls else "stop",
index=0,
message=Message(content=text if not tool_calls else None, role="assistant", tool_calls=tool_calls)
)],
usage=Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
).model_dump()
except Exception:
# Fall back to plain dict if litellm fails
return self.format_full(text, prompt_tokens, completion_tokens, tool_calls)
def format_litellm_chunk(self, delta_text: str, is_final: bool = False, usage: dict = None) -> dict:
"""Format streaming chunk using litellm's ChatCompletionChunk if available.
Args:
delta_text: The incremental text content for this chunk
is_final: Whether this is the final chunk
usage: Optional usage information (typically only sent on final chunk)
Returns:
Dictionary representation of ChatCompletionChunk
"""
if not LITELLM_AVAILABLE or ChatCompletionChunk is None:
return self.format_chunk(delta_text, is_final, usage)
try:
from litellm import StreamingChoices, Delta, Usage
return ChatCompletionChunk(
id=self.id,
model=self.model_name,
object="chat.completion.chunk",
created=int(time.time()),
choices=[StreamingChoices(
finish_reason="stop" if is_final else None,
index=0,
delta=Delta(content=delta_text, role="assistant")
)],
usage=Usage(**usage) if usage else None
).model_dump()
except Exception:
# Fall back to plain dict if litellm fails
return self.format_chunk(delta_text, is_final, usage)
# codai.openai - OpenAI-compatible API implementations
__all__ = []
......@@ -5171,7 +5171,223 @@ async def create_speech(request: TTSRequest):
async def chat_completions(request: ChatCompletionRequest, http_request: Request = None):
"""Chat completions endpoint with streaming and tool support."""
# Continue with original implementation
# Check if we should use litellm backend
parser_type = getattr(global_args, 'parser', 'auto') if global_args else 'auto'
if parser_type == 'litellm':
# Use LiteLLM backend
from codai.openai.litellm import get_litellm_backend, LITELLM_AVAILABLE
if not LITELLM_AVAILABLE:
raise HTTPException(
status_code=500,
detail="LiteLLM is not installed. Run: pip install litellm"
)
# Check for API key in request - litellm requires an API key
# If not provided, use a fake key to allow the request to proceed
api_key = None
# Try to get API key from request body
if hasattr(request, 'api_key') and request.api_key:
api_key = request.api_key
# If no API key in body, try to get from Authorization header
if not api_key:
auth_header = http_request.headers.get('Authorization', '') if http_request else ''
if auth_header.startswith('Bearer '):
api_key = auth_header[7:] # Extract token after 'Bearer '
# If still no API key, use a fake key to allow litellm to proceed
# litellm will then fail with the actual provider error if needed
if not api_key:
api_key = "fake-key-for-local-testing"
print("DEBUG: No API key provided, using fake key for litellm")
# Determine the base URL for litellm to connect to
# Use the server's host and port for local connections
api_base = None
# Check if model starts with 'ollama:' - use local Ollama
if request.model and request.model.startswith('ollama:'):
# Get the host from the request headers
client_host = "127.0.0.1"
if http_request:
host_header = http_request.headers.get('host', '')
if host_header:
# Strip port if present
if ':' in host_header:
client_host = host_header.split(':')[0]
if client_host.replace('.', '').isdigit():
# It's an IP, keep it
pass
else:
# It's a hostname, use localhost
client_host = "127.0.0.1"
else:
client_host = host_header
# Get port from global_args or use default
port = getattr(global_args, 'port', 11434) if global_args else 11434
api_base = f"http://{client_host}:{port}/v1"
print(f"DEBUG: Using api_base for Ollama: {api_base}")
else:
# For non-Ollama models, use the server's own URL as base
# This allows LiteLLM to make requests to the local server
if http_request:
# Get the host from the request headers
host_header = http_request.headers.get('host', '')
if host_header:
# Strip port if present to reconstruct clean URL
if ':' in host_header:
client_host = host_header.split(':')[0]
# Keep the port from the request for consistency
server_port = host_header.split(':')[1] if len(host_header.split(':')) > 1 else str(getattr(global_args, 'port', 6745))
else:
client_host = host_header
server_port = str(getattr(global_args, 'port', 6745))
else:
# Fallback to client host if no Host header
client_host = http_request.client.host if http_request.client else "127.0.0.1"
server_port = str(getattr(global_args, 'port', 6745))
else:
# Fallback if no http_request
client_host = "127.0.0.1"
server_port = str(getattr(global_args, 'port', 6745))
# Determine protocol (http or https)
use_https = getattr(global_args, 'https', False) or getattr(global_args, 'pubkey', None)
protocol = "https" if use_https else "http"
api_base = f"{protocol}://{client_host}:{server_port}/v1"
print(f"DEBUG: Using api_base for local server: {api_base}")
# Get or create litellm backend
litellm_backend = get_litellm_backend(
model=request.model,
api_key=api_key,
api_base=api_base,
context_window=8192, # Default, can be made configurable
model_manager=multi_model_manager # Pass for alias resolution
)
# Get the tool_parser from multi_model_manager for model-specific parsing
tool_parser = multi_model_manager.tool_parser if hasattr(multi_model_manager, 'tool_parser') else None
# Convert messages to dict format
messages_dict = []
for msg in request.messages:
msg_dict = {"role": msg.role, "content": msg.content or ""}
if hasattr(msg, 'tool_calls') and msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if hasattr(msg, 'tool_call_id') and msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
# Prepare tools if provided
tools_dict = None
if request.tools:
tools_dict = request.tools
# Generate response
try:
if request.stream:
# Streaming response
from fastapi.responses import StreamingResponse
async def generate():
try:
async for chunk in await litellm_backend.chat_completion(
messages=messages_dict,
model=request.model,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
stop=request.stop,
tools=tools_dict,
tool_choice=request.tool_choice,
stream=True,
tool_parser=tool_parser,
):
# Add rate limit headers
headers = {}
if 'usage' in chunk:
headers = litellm_backend.get_rate_limit_headers(
prompt_tokens=chunk.get('usage', {}).get('prompt_tokens', 0),
completion_tokens=chunk.get('usage', {}).get('completion_tokens', 0)
)
# Handle Qwen tool calls if model is Qwen family
if 'qwen' in request.model.lower():
content = chunk.get('choices', [{}])[0].get('delta', {}).get('content', '')
tool_calls = chunk.get('choices', [{}])[0].get('delta', {}).get('tool_calls', [])
if not tool_calls and content:
# Try to parse tool calls from content
tool_calls = litellm_backend.parse_qwen_tool_calls(content)
if tool_calls:
# Strip tool tags from content
content = litellm_backend.strip_tool_tags(content)
chunk['choices'][0]['delta']['content'] = content
chunk['choices'][0]['delta']['tool_calls'] = tool_calls
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'internal_error'}})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
else:
# Non-streaming response
response = await litellm_backend.chat_completion(
messages=messages_dict,
model=request.model,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
stop=request.stop,
tools=tools_dict,
tool_choice=request.tool_choice,
stream=False,
tool_parser=tool_parser,
)
# Handle Qwen tool calls
if 'qwen' in request.model.lower() and 'choices' in response:
msg = response['choices'][0].get('message', {})
content = msg.get('content', '')
tool_calls = msg.get('tool_calls', [])
if not tool_calls and content:
tool_calls = litellm_backend.parse_qwen_tool_calls(content)
if tool_calls:
msg['content'] = litellm_backend.strip_tool_tags(content)
msg['tool_calls'] = tool_calls
response['choices'][0]['message'] = msg
# Add rate limit headers
headers = {}
if 'usage' in response:
headers = litellm_backend.get_rate_limit_headers(
prompt_tokens=response.get('usage', {}).get('prompt_tokens', 0),
completion_tokens=response.get('usage', {}).get('completion_tokens', 0)
)
from fastapi.responses import JSONResponse
return JSONResponse(content=response, headers=headers)
except Exception as e:
# Handle litellm errors
error_response = {
"error": {
"message": str(e),
"type": "internal_error",
"code": 500
}
}
return JSONResponse(content=error_response, status_code=500)
# Continue with original implementation for 'auto' parser
# Get the model for this request
requested_model = request.model
......@@ -5541,14 +5757,14 @@ async def stream_chat_response(
prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0
# Build complete final chunk with OpenAIFormatter for sanitization
# Use OpenAIFormatter for final chunk sanitization
formatter = OpenAIFormatter(model_name)
usage_details = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
final_chunk = formatter.format_final_chunk(usage=usage_details)
final_chunk = formatter.format_litellm_chunk("", is_final=True, usage=usage_details)
yield f"data: {json.dumps(final_chunk)}\n\n"
else:
# Calculate token counts for usage in final chunk
......@@ -5694,7 +5910,7 @@ async def generate_chat_response(
# Use OpenAIFormatter for final sanitization
formatter = OpenAIFormatter(model_name)
return formatter.format_full(
return formatter.format_litellm_full(
text=response_message.get("content", ""),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
......@@ -6225,6 +6441,13 @@ def parse_args():
default=None,
help="Path to store generated files (images, audio). If specified, files will be saved here and served over web.",
)
parser.add_argument(
"--parser",
type=str,
default="auto",
choices=["auto", "litellm"],
help="Tool call parser to use: 'auto' for internal parser, 'litellm' for LiteLLM's parser. Default: auto",
)
return parser.parse_args()
def main():
"""Main entry point."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment