Add rate limiting functionality to providers and models

parent 3c2fe27c
@@ -50,6 +50,9 @@ class RequestHandler:
            raise HTTPException(status_code=503, detail="Provider temporarily unavailable")

        try:
            # Apply rate limiting
            await handler.apply_rate_limit()

            response = await handler.handle_request(
                model=request_data['model'],
                messages=request_data['messages'],
@@ -80,6 +83,9 @@ class RequestHandler:
        async def stream_generator():
            try:
                # Apply rate limiting
                await handler.apply_rate_limit()

                response = await handler.handle_request(
                    model=request_data['model'],
                    messages=request_data['messages'],
@@ -108,6 +114,9 @@ class RequestHandler:
        handler = get_provider_handler(provider_id, api_key)

        try:
            # Apply rate limiting
            await handler.apply_rate_limit()

            models = await handler.get_models()
            return [model.dict() for model in models]
        except Exception as e:
@@ -145,6 +154,10 @@ class RotationHandler:
            raise HTTPException(status_code=503, detail="All providers temporarily unavailable")

        try:
            # Apply rate limiting with model-specific rate limit if available
            rate_limit = selected_model.get('rate_limit')
            await handler.apply_rate_limit(rate_limit)

            response = await handler.handle_request(
                model=model_name,
                messages=request_data['messages'],
@@ -170,7 +183,8 @@ class RotationHandler:
                    "id": f"{provider['provider_id']}/{model['name']}",
                    "name": model['name'],
                    "provider_id": provider['provider_id'],
                    "weight": model['weight'],
                    "rate_limit": model.get('rate_limit')
                })

        return all_models
@@ -51,6 +51,7 @@ class Model(BaseModel):
    name: str
    provider_id: str
    weight: int = 1
    rate_limit: Optional[float] = None

class Provider(BaseModel):
    id: str
...
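For reference, a minimal sketch of how the new optional field behaves on the `Model` schema above. The `id` field and the imports are assumed from usage elsewhere in the diff; they are not shown in this hunk:

```python
from typing import Optional
from pydantic import BaseModel

class Model(BaseModel):
    id: str                              # assumed from Model(id=...) usage in the handlers
    name: str
    provider_id: str
    weight: int = 1
    rate_limit: Optional[float] = None   # minimum seconds between requests; None/0 disables

# A model entry that omits rate_limit validates to None, so the handler
# falls back to the provider-level default at request time.
m = Model(id="openai/gpt-4", name="gpt-4", provider_id="openai", weight=2)
print(m.rate_limit)  # None
```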
@@ -38,12 +38,29 @@ class BaseProviderHandler:
        self.provider_id = provider_id
        self.api_key = api_key
        self.error_tracking = config.error_tracking[provider_id]
        self.last_request_time = 0
        self.rate_limit = config.providers[provider_id].rate_limit

    def is_rate_limited(self) -> bool:
        if self.error_tracking['disabled_until'] and self.error_tracking['disabled_until'] > time.time():
            return True
        return False

    async def apply_rate_limit(self, rate_limit: Optional[float] = None):
        """Apply rate limiting by waiting if necessary"""
        if rate_limit is None:
            rate_limit = self.rate_limit

        if rate_limit and rate_limit > 0:
            current_time = time.time()
            time_since_last_request = current_time - self.last_request_time
            required_wait = rate_limit - time_since_last_request

            if required_wait > 0:
                await asyncio.sleep(required_wait)

            self.last_request_time = time.time()

    def record_failure(self):
        self.error_tracking['failures'] += 1
        self.error_tracking['last_failure'] = time.time()
@@ -70,10 +87,13 @@ class GoogleProviderHandler(BaseProviderHandler):
        import logging
        logging.info(f"GoogleProviderHandler: Handling request for model {model}")
        logging.info(f"GoogleProviderHandler: Messages: {messages}")

        # Apply rate limiting
        await self.apply_rate_limit()

        # Build content from messages
        content = "\n\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

        # Generate content using the google-genai client
        response = self.client.models.generate_content(
            model=model,
@@ -83,10 +103,10 @@ class GoogleProviderHandler(BaseProviderHandler):
                "max_output_tokens": max_tokens
            }
        )

        logging.info(f"GoogleProviderHandler: Response received: {response}")
        self.record_success()

        # Return the response as a dictionary
        return {
            "candidates": [{
@@ -113,11 +133,14 @@ class GoogleProviderHandler(BaseProviderHandler):
        try:
            import logging
            logging.info("GoogleProviderHandler: Getting models list")

            # Apply rate limiting
            await self.apply_rate_limit()

            # List models using the google-genai client
            models = self.client.models.list()
            logging.info(f"GoogleProviderHandler: Models received: {models}")

            # Convert to our Model format
            result = []
            for model in models:
@@ -126,7 +149,7 @@ class GoogleProviderHandler(BaseProviderHandler):
                    name=model.display_name or model.name,
                    provider_id=self.provider_id
                ))

            return result
        except Exception as e:
            import logging
@@ -147,7 +170,10 @@ class OpenAIProviderHandler(BaseProviderHandler):
        import logging
        logging.info(f"OpenAIProviderHandler: Handling request for model {model}")
        logging.info(f"OpenAIProviderHandler: Messages: {messages}")

        # Apply rate limiting
        await self.apply_rate_limit()

        response = self.client.chat.completions.create(
            model=model,
            messages=[{"role": msg["role"], "content": msg["content"]} for msg in messages],
@@ -168,10 +194,13 @@ class OpenAIProviderHandler(BaseProviderHandler):
        try:
            import logging
            logging.info("OpenAIProviderHandler: Getting models list")

            # Apply rate limiting
            await self.apply_rate_limit()

            models = self.client.models.list()
            logging.info(f"OpenAIProviderHandler: Models received: {models}")

            return [Model(id=model.id, name=model.id, provider_id=self.provider_id) for model in models]
        except Exception as e:
            import logging
@@ -192,7 +221,10 @@ class AnthropicProviderHandler(BaseProviderHandler):
        import logging
        logging.info(f"AnthropicProviderHandler: Handling request for model {model}")
        logging.info(f"AnthropicProviderHandler: Messages: {messages}")

        # Apply rate limiting
        await self.apply_rate_limit()

        response = self.client.messages.create(
            model=model,
            messages=[{"role": msg["role"], "content": msg["content"]} for msg in messages],
@@ -227,6 +259,9 @@ class OllamaProviderHandler(BaseProviderHandler):
            raise Exception("Provider rate limited")

        try:
            # Apply rate limiting
            await self.apply_rate_limit()

            response = await self.client.post("/api/generate", json={
                "model": model,
                "prompt": "\n\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]),
@@ -243,6 +278,9 @@ class OllamaProviderHandler(BaseProviderHandler):
            raise e

    async def get_models(self) -> List[Model]:
        # Apply rate limiting
        await self.apply_rate_limit()

        response = await self.client.get("/api/tags")
        response.raise_for_status()
        models = response.json().get('models', [])
...
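The `apply_rate_limit` helper above treats `rate_limit` as a minimum number of seconds between consecutive requests from one handler instance, not as requests per second. Below is a self-contained sketch of the same interval-based throttle, with illustrative class and function names rather than the project's:

```python
import asyncio
import time
from typing import Optional

class IntervalThrottle:
    """Enforce a minimum interval (in seconds) between awaited calls."""

    def __init__(self, rate_limit: Optional[float] = None):
        self.rate_limit = rate_limit        # instance-level default
        self.last_request_time = 0.0

    async def apply_rate_limit(self, rate_limit: Optional[float] = None) -> None:
        # A per-call value (e.g. a model-specific limit) overrides the default.
        if rate_limit is None:
            rate_limit = self.rate_limit
        if rate_limit and rate_limit > 0:
            required_wait = rate_limit - (time.time() - self.last_request_time)
            if required_wait > 0:
                await asyncio.sleep(required_wait)
            self.last_request_time = time.time()

async def demo() -> None:
    throttle = IntervalThrottle(rate_limit=1.0)   # at most one call per second
    for i in range(3):
        await throttle.apply_rate_limit()
        print(f"call {i} at {time.monotonic():.2f}s")

if __name__ == "__main__":
    asyncio.run(demo())
```

Because the timestamp lives on the handler instance, the limit applies per handler object within one process; it is not shared across workers or across different providers.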
@@ -5,134 +5,152 @@
            "name": "Google AI Studio",
            "endpoint": "https://generativelanguage.googleapis.com/v1beta",
            "type": "google",
            "api_key_required": true,
            "rate_limit": 0
        },
        "openai": {
            "id": "openai",
            "name": "OpenAI",
            "endpoint": "https://api.openai.com/v1",
            "type": "openai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "anthropic": {
            "id": "anthropic",
            "name": "Anthropic",
            "endpoint": "https://api.anthropic.com/v1",
            "type": "anthropic",
            "api_key_required": true,
            "rate_limit": 0
        },
        "ollama": {
            "id": "ollama",
            "name": "Ollama",
            "endpoint": "http://localhost:11434",
            "type": "ollama",
            "api_key_required": false,
            "rate_limit": 0
        },
        "azure_openai": {
            "id": "azure_openai",
            "name": "Azure OpenAI",
            "endpoint": "https://your-azure-endpoint.openai.azure.com",
            "type": "openai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "cohere": {
            "id": "cohere",
            "name": "Cohere",
            "endpoint": "https://api.cohere.com/v1",
            "type": "cohere",
            "api_key_required": true,
            "rate_limit": 0
        },
        "huggingface": {
            "id": "huggingface",
            "name": "Hugging Face",
            "endpoint": "https://api-inference.huggingface.co",
            "type": "huggingface",
            "api_key_required": true,
            "rate_limit": 0
        },
        "replicate": {
            "id": "replicate",
            "name": "Replicate",
            "endpoint": "https://api.replicate.com/v1",
            "type": "replicate",
            "api_key_required": true,
            "rate_limit": 0
        },
        "togetherai": {
            "id": "togetherai",
            "name": "Together AI",
            "endpoint": "https://api.together.xyz/v1",
            "type": "openai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "groq": {
            "id": "groq",
            "name": "Groq",
            "endpoint": "https://api.groq.com/openai/v1",
            "type": "openai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "mistralai": {
            "id": "mistralai",
            "name": "Mistral AI",
            "endpoint": "https://api.mistral.ai/v1",
            "type": "openai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "stabilityai": {
            "id": "stabilityai",
            "name": "Stability AI",
            "endpoint": "https://api.stability.ai/v2beta",
            "type": "stabilityai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "kilo": {
            "id": "kilo",
            "name": "KiloCode",
            "endpoint": "https://kilocode.ai/api/openrouter",
            "type": "openai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "perplexity": {
            "id": "perplexity",
            "name": "Perplexity AI",
            "endpoint": "https://api.perplexity.ai",
            "type": "openai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "poe": {
            "id": "poe",
            "name": "Poe",
            "endpoint": "https://api.poe.com/v1",
            "type": "poe",
            "api_key_required": true,
            "rate_limit": 0
        },
        "lanai": {
            "id": "lanai",
            "name": "Llama AI",
            "endpoint": "https://api.lanai.ai/v1",
            "type": "lanai",
            "api_key_required": true,
            "rate_limit": 0
        },
        "amazon": {
            "id": "amazon",
            "name": "Amazon Bedrock",
            "endpoint": "https://api.amazon.com/bedrock/v1",
            "type": "amazon",
            "api_key_required": true,
            "rate_limit": 0
        },
        "ibm": {
            "id": "ibm",
            "name": "IBM Watson",
            "endpoint": "https://api.ibm.com/watson/v1",
            "type": "ibm",
            "api_key_required": true,
            "rate_limit": 0
        },
        "microsoft": {
            "id": "microsoft",
            "name": "Microsoft Azure AI",
            "endpoint": "https://api.microsoft.com/v1",
            "type": "microsoft",
            "api_key_required": true,
            "rate_limit": 0
        }
    }
}
@@ -9,11 +9,13 @@
            "models": [
                {
                    "name": "gemini-2.0-flash",
                    "weight": 3,
                    "rate_limit": 0
                },
                {
                    "name": "gemini-1.5-pro",
                    "weight": 1,
                    "rate_limit": 0
                }
            ]
        },
@@ -23,11 +25,13 @@
            "models": [
                {
                    "name": "gpt-4",
                    "weight": 2,
                    "rate_limit": 0
                },
                {
                    "name": "gpt-3.5-turbo",
                    "weight": 1,
                    "rate_limit": 0
                }
            ]
        },
@@ -37,11 +41,13 @@
            "models": [
                {
                    "name": "claude-3-5-sonnet-20241022",
                    "weight": 2,
                    "rate_limit": 0
                },
                {
                    "name": "claude-3-haiku-20240307",
                    "weight": 1,
                    "rate_limit": 0
                }
            ]
        }
@@ -56,11 +62,13 @@
            "models": [
                {
                    "name": "gemini-1.5-pro",
                    "weight": 2,
                    "rate_limit": 0
                },
                {
                    "name": "gemini-2.0-flash",
                    "weight": 1,
                    "rate_limit": 0
                }
            ]
        },
@@ -70,11 +78,13 @@
            "models": [
                {
                    "name": "gpt-4",
                    "weight": 2,
                    "rate_limit": 0
                },
                {
                    "name": "gpt-3.5-turbo",
                    "weight": 1,
                    "rate_limit": 0
                }
            ]
        }
...
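In the rotation path, `handle_rotation_request` passes the selected model's `rate_limit` into `apply_rate_limit`, so a model-level value overrides the provider default, and an explicit `0` disables throttling for that model even when the provider sets one. A small hypothetical helper making that precedence explicit (the function name is illustrative; in the diff this logic is inlined in `apply_rate_limit`):

```python
from typing import Optional

def effective_rate_limit(model_rate_limit: Optional[float],
                         provider_rate_limit: Optional[float]) -> Optional[float]:
    # None means "not set on the model", so fall back to the provider default.
    # Any other value, including 0, wins over the provider setting.
    return provider_rate_limit if model_rate_limit is None else model_rate_limit

print(effective_rate_limit(None, 2.0))  # 2.0  -> provider default applies
print(effective_rate_limit(0.5, 2.0))   # 0.5  -> model-specific limit wins
print(effective_rate_limit(0, 2.0))     # 0    -> model explicitly disables throttling
```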
@@ -61,14 +61,14 @@ async def chat_completions(provider_id: str, request: Request, body: ChatComplet
    logger.debug(f"Received chat_completions request for provider: {provider_id}")
    logger.debug(f"Request headers: {dict(request.headers)}")
    logger.debug(f"Request body: {body}")

    body_dict = body.model_dump()

    # Check if it's a rotation
    if provider_id in config.rotations:
        logger.debug("Handling rotation request")
        return await rotation_handler.handle_rotation_request(provider_id, body_dict)

    # Check if it's a provider
    if provider_id not in config.providers:
        logger.error(f"Provider {provider_id} not found")
@@ -93,12 +93,12 @@ async def chat_completions(provider_id: str, request: Request, body: ChatComplet
@app.get("/api/{provider_id}/models")
async def list_models(request: Request, provider_id: str):
    logger.debug(f"Received list_models request for provider: {provider_id}")

    # Check if it's a rotation
    if provider_id in config.rotations:
        logger.debug("Handling rotation model list request")
        return await rotation_handler.handle_rotation_model_list(provider_id)

    # Check if it's a provider
    if provider_id not in config.providers:
        logger.error(f"Provider {provider_id} not found")
...