Commit eed5a3ff authored by Your Name's avatar Your Name

Add normalize_model_name() to litellm backend

- Add method to normalize model names for litellm
- Maps common model patterns to providers (gpt-* -> openai/, llama -> meta/, etc.)
- Falls back to openai/ for unknown models
parent 71e521ff
...@@ -84,6 +84,68 @@ class LiteLLMBackend: ...@@ -84,6 +84,68 @@ class LiteLLMBackend:
litellm.base_url = base_url litellm.base_url = base_url
if api_key: if api_key:
litellm.api_key = api_key litellm.api_key = api_key
def normalize_model_name(self, model: str) -> str:
"""
Normalize model name for litellm.
LiteLLM requires model names in format "provider/model" e.g., "openai/gpt-3.5-turbo".
If no provider is specified, add a default provider based on common patterns.
Args:
model: Original model name
Returns:
Normalized model name with provider prefix
"""
# If already has a provider prefix (contains /), return as-is
if '/' in model:
return model
# Common model name patterns and their providers
provider_map = {
# OpenAI models
'gpt-': 'openai/',
'gpt3': 'openai/',
'gpt4': 'openai/',
# Anthropic models
'claude': 'anthropic/',
# Google models
'gemini': 'gemini/',
'palm': 'gemini/',
# Meta/Llama models
'llama': 'meta/',
'llama2': 'meta/',
'llama3': 'meta/',
'mistral': 'mistral/',
# Mistral models
# AWS models
'amazon': 'bedrock/',
# Azure models
'azure': 'azure/',
# Cohere models
'cohere': 'cohere/',
# AI21 models
'ai21': 'ai21/',
# Local/Ollama models
'ollama': 'ollama/',
# HuggingFace models
'hf': 'huggingface/',
# DeepSeek models
'deepseek': 'deepseek/',
# Qwen models
'qwen': 'qwen/',
}
model_lower = model.lower()
# Check for known patterns
for pattern, provider in provider_map.items():
if model_lower.startswith(pattern):
return f"{provider}{model}"
# Default: assume OpenAI-compatible local model
return f"openai/{model}"
def _convert_messages(self, messages: List[Dict]) -> List[Dict]: def _convert_messages(self, messages: List[Dict]) -> List[Dict]:
"""Convert OpenAI message format to litellm format.""" """Convert OpenAI message format to litellm format."""
...@@ -243,8 +305,8 @@ class LiteLLMBackend: ...@@ -243,8 +305,8 @@ class LiteLLMBackend:
if not LITELLM_AVAILABLE: if not LITELLM_AVAILABLE:
raise RuntimeError("litellm is not installed. Run: pip install litellm") raise RuntimeError("litellm is not installed. Run: pip install litellm")
# Prepare the model # Prepare the model - normalize name for litellm
use_model = model or self.model use_model = self.normalize_model_name(model or self.model)
# Convert messages to litellm format # Convert messages to litellm format
litellm_messages = self._convert_messages(messages) litellm_messages = self._convert_messages(messages)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment