Commit a93b69cf authored by Your Name

Pass api_base to litellm for local model connections

- When the model name starts with 'ollama:', construct api_base from the request's Host header and the configured port (global_args.port, default 11434)
- api_base is now passed to LiteLLMBackend for local connections
parent c723cf43
......@@ -62,6 +62,7 @@ class LiteLLMBackend:
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None, # Add api_base parameter
context_window: int = 4096,
model_manager: Optional[Any] = None,
**kwargs
......@@ -73,6 +74,7 @@ class LiteLLMBackend:
model: Model name to use (e.g., "gpt-3.5-turbo", "ollama/llama2")
api_key: API key for the model provider
base_url: Custom base URL for OpenAI-compatible APIs
api_base: API base URL (alternative to base_url, e.g., "http://localhost:11434/v1")
context_window: Maximum context window size for rate limit headers
model_manager: Reference to MultiModelManager for resolving aliases
"""
......@@ -80,17 +82,17 @@ class LiteLLMBackend:
# Use provided API key, or generate a fake one if not provided
# This allows litellm to proceed without requiring an API key
self.api_key = api_key if api_key else "fake-key-for-local-testing"
self.base_url = base_url
self.base_url = base_url or api_base # Use either base_url or api_base
self.context_window = context_window
self.model_manager = model_manager
self.tool_parser = None # Coderai's tool parser for post-processing
self.tools_schema = {} # Tools schema for coderai parser
# Configure litellm
if base_url:
litellm.base_url = base_url
if api_key:
litellm.api_key = api_key
if self.base_url:
litellm.base_url = self.base_url
if self.api_key:
litellm.api_key = self.api_key
def normalize_model_name(self, model: str) -> str:
"""
......@@ -741,6 +743,7 @@ def get_litellm_backend(
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None, # Add api_base parameter
context_window: int = 4096,
model_manager: Optional[Any] = None,
**kwargs
......@@ -754,6 +757,7 @@ def get_litellm_backend(
model=model,
api_key=api_key,
base_url=base_url,
api_base=api_base,
context_window=context_window,
model_manager=model_manager,
**kwargs
......
......@@ -5203,10 +5203,39 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
api_key = "fake-key-for-local-testing"
print("DEBUG: No API key provided, using fake key for litellm")
# Determine the base URL for litellm to connect to
# Use the server's host and port for local connections
api_base = None
# Check if model starts with 'ollama:' - use local Ollama
if request.model and request.model.startswith('ollama:'):
# Get the host from the request headers
client_host = "127.0.0.1"
if http_request:
host_header = http_request.headers.get('host', '')
if host_header:
# Strip port if present
if ':' in host_header:
client_host = host_header.split(':')[0]
if client_host.replace('.', '').isdigit():
# It's an IP, keep it
pass
else:
# It's a hostname, use localhost
client_host = "127.0.0.1"
else:
client_host = host_header
# Get port from global_args or use default
port = getattr(global_args, 'port', 11434) if global_args else 11434
api_base = f"http://{client_host}:{port}/v1"
print(f"DEBUG: Using api_base for Ollama: {api_base}")
# Get or create litellm backend
litellm_backend = get_litellm_backend(
model=request.model,
api_key=api_key,
api_base=api_base,
context_window=8192, # Default, can be made configurable
model_manager=multi_model_manager # Pass for alias resolution
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment