Commit cd1040bb authored by Your Name

Resolve model aliases in LiteLLM backend

- Add model_manager parameter to LiteLLMBackend for alias resolution
- Add _resolve_model_alias() method to handle default, image, audio, tts aliases
- Update get_litellm_backend() to pass model_manager
- Update coderai call site to pass multi_model_manager

Now --parser litellm will resolve aliases like 'default', 'image' to actual model names before normalizing for litellm.
parent eed5a3ff
...@@ -63,6 +63,7 @@ class LiteLLMBackend: ...@@ -63,6 +63,7 @@ class LiteLLMBackend:
api_key: Optional[str] = None, api_key: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
context_window: int = 4096, context_window: int = 4096,
model_manager: Optional[Any] = None,
**kwargs **kwargs
): ):
""" """
...@@ -73,11 +74,13 @@ class LiteLLMBackend: ...@@ -73,11 +74,13 @@ class LiteLLMBackend:
api_key: API key for the model provider api_key: API key for the model provider
base_url: Custom base URL for OpenAI-compatible APIs base_url: Custom base URL for OpenAI-compatible APIs
context_window: Maximum context window size for rate limit headers context_window: Maximum context window size for rate limit headers
model_manager: Reference to MultiModelManager for resolving aliases
""" """
self.model = model self.model = model
self.api_key = api_key or os.environ.get("OPENAI_API_KEY") self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
self.base_url = base_url self.base_url = base_url
self.context_window = context_window self.context_window = context_window
self.model_manager = model_manager
# Configure litellm # Configure litellm
if base_url: if base_url:
...@@ -92,15 +95,21 @@ class LiteLLMBackend: ...@@ -92,15 +95,21 @@ class LiteLLMBackend:
LiteLLM requires model names in format "provider/model" e.g., "openai/gpt-3.5-turbo". LiteLLM requires model names in format "provider/model" e.g., "openai/gpt-3.5-turbo".
If no provider is specified, add a default provider based on common patterns. If no provider is specified, add a default provider based on common patterns.
Also handles model aliases (default, image, etc.) by resolving them to
the actual model name from the model manager.
Args: Args:
model: Original model name model: Original model name (may be an alias)
Returns: Returns:
Normalized model name with provider prefix Normalized model name with provider prefix
""" """
# First, resolve alias to actual model name if we have a model manager
resolved_model = self._resolve_model_alias(model)
# If already has a provider prefix (contains /), return as-is # If already has a provider prefix (contains /), return as-is
if '/' in model: if '/' in resolved_model:
return model return resolved_model
# Common model name patterns and their providers # Common model name patterns and their providers
provider_map = { provider_map = {
...@@ -137,15 +146,66 @@ class LiteLLMBackend: ...@@ -137,15 +146,66 @@ class LiteLLMBackend:
'qwen': 'qwen/', 'qwen': 'qwen/',
} }
model_lower = model.lower() model_lower = resolved_model.lower()
# Check for known patterns # Check for known patterns
for pattern, provider in provider_map.items(): for pattern, provider in provider_map.items():
if model_lower.startswith(pattern): if model_lower.startswith(pattern):
return f"{provider}{model}" return f"{provider}{resolved_model}"
# Default: assume OpenAI-compatible local model # Default: assume OpenAI-compatible local model
return f"openai/{model}" return f"openai/{resolved_model}"
def _resolve_model_alias(self, model: str) -> str:
    """
    Resolve a model alias to the actual model name.

    Built-in aliases are looked up on ``self.model_manager``:

    - ``"default"`` (or an empty/None model) -> ``default_model``
    - ``"image"`` -> first entry of ``image_models``
    - ``"audio"`` -> first entry of ``audio_models``
    - ``"tts"`` -> ``tts_model``

    Any other name is looked up in ``model_aliases`` (custom aliases
    registered via --model-alias). Built-in aliases take precedence over
    custom ones with the same name.

    Args:
        model: Model name or alias

    Returns:
        The resolved model name, or ``model`` unchanged when there is no
        model manager, the alias is unknown, or the manager has nothing
        configured for that alias.
    """
    manager = self.model_manager
    if not manager:
        return model

    # "default" or an empty model name maps to the manager's default model.
    if not model or model == "default":
        return getattr(manager, 'default_model', None) or model

    # "image" / "audio" map to the first model registered for that kind.
    if model == "image":
        image_models = getattr(manager, 'image_models', [])
        return image_models[0] if image_models else model

    if model == "audio":
        audio_models = getattr(manager, 'audio_models', [])
        return audio_models[0] if audio_models else model

    # "tts" maps to the single configured TTS model.
    if model == "tts":
        return getattr(manager, 'tts_model', None) or model

    # Custom aliases. The attribute may exist but be None, in which case
    # getattr's default is not used; fall back to an empty mapping so the
    # membership test below cannot raise TypeError.
    model_aliases = getattr(manager, 'model_aliases', None) or {}
    if model in model_aliases:
        return model_aliases[model]

    return model
def _convert_messages(self, messages: List[Dict]) -> List[Dict]: def _convert_messages(self, messages: List[Dict]) -> List[Dict]:
"""Convert OpenAI message format to litellm format.""" """Convert OpenAI message format to litellm format."""
...@@ -586,19 +646,22 @@ def get_litellm_backend( ...@@ -586,19 +646,22 @@ def get_litellm_backend(
api_key: Optional[str] = None, api_key: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
context_window: int = 4096, context_window: int = 4096,
model_manager: Optional[Any] = None,
**kwargs **kwargs
) -> LiteLLMBackend: ) -> LiteLLMBackend:
"""Get or create the default LiteLLM backend instance.""" """Get or create the default LiteLLM backend instance."""
global default_litellm_backend global default_litellm_backend
if default_litellm_backend is None: # Always create a new instance with the provided model_manager
default_litellm_backend = LiteLLMBackend( # This ensures aliases are resolved correctly on each call
model=model, default_litellm_backend = LiteLLMBackend(
api_key=api_key, model=model,
base_url=base_url, api_key=api_key,
context_window=context_window, base_url=base_url,
**kwargs context_window=context_window,
) model_manager=model_manager,
**kwargs
)
return default_litellm_backend return default_litellm_backend
......
...@@ -5186,7 +5186,8 @@ async def chat_completions(request: ChatCompletionRequest): ...@@ -5186,7 +5186,8 @@ async def chat_completions(request: ChatCompletionRequest):
# Get or create litellm backend # Get or create litellm backend
litellm_backend = get_litellm_backend( litellm_backend = get_litellm_backend(
model=request.model, model=request.model,
context_window=8192 # Default, can be made configurable context_window=8192, # Default, can be made configurable
model_manager=multi_model_manager # Pass for alias resolution
) )
# Convert messages to dict format # Convert messages to dict format
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment