Add rate limiting functionality to providers and models

parent 3c2fe27c
@@ -50,6 +50,9 @@ class RequestHandler:
raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
try:
# Apply rate limiting
await handler.apply_rate_limit()
response = await handler.handle_request(
model=request_data['model'],
messages=request_data['messages'],
@@ -80,6 +83,9 @@ class RequestHandler:
async def stream_generator():
try:
# Apply rate limiting
await handler.apply_rate_limit()
response = await handler.handle_request(
model=request_data['model'],
messages=request_data['messages'],
@@ -108,6 +114,9 @@ class RequestHandler:
handler = get_provider_handler(provider_id, api_key)
try:
# Apply rate limiting
await handler.apply_rate_limit()
models = await handler.get_models()
return [model.dict() for model in models]
except Exception as e:
@@ -145,6 +154,10 @@ class RotationHandler:
raise HTTPException(status_code=503, detail="All providers temporarily unavailable")
try:
# Apply rate limiting with model-specific rate limit if available
rate_limit = selected_model.get('rate_limit')
await handler.apply_rate_limit(rate_limit)
response = await handler.handle_request(
model=model_name,
messages=request_data['messages'],
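Here the rotation handler forwards the selected model's optional rate_limit to apply_rate_limit, which falls back to the provider-level value only when the model leaves it unset (an explicit 0 disables throttling rather than falling back). A minimal sketch of that precedence, with an illustrative provider default:

provider_rate_limit = 1.0                                             # assumed provider-level default (seconds between requests)
selected_model = {"name": "gpt-4", "weight": 2, "rate_limit": 0.5}    # illustrative rotation entry

effective = selected_model.get('rate_limit')
if effective is None:                                                 # mirrors the fallback inside apply_rate_limit
    effective = provider_rate_limit
# effective == 0.5 -> at most one request to this model every 0.5 seconds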
@@ -170,7 +183,8 @@ class RotationHandler:
"id": f"{provider['provider_id']}/{model['name']}",
"name": model['name'],
"provider_id": provider['provider_id'],
"weight": model['weight']
"weight": model['weight'],
"rate_limit": model.get('rate_limit')
})
return all_models
@@ -51,6 +51,7 @@ class Model(BaseModel):
name: str
provider_id: str
weight: int = 1
rate_limit: Optional[float] = None
class Provider(BaseModel):
id: str
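For reference, a sketch of the extended Model schema and the serialization used by the models endpoints; the id field is an assumption inferred from the Model(id=...) calls in the provider handlers, since it sits above this hunk:

from typing import Optional
from pydantic import BaseModel

class Model(BaseModel):
    id: str                                # assumed; declared above the hunk shown here
    name: str
    provider_id: str
    weight: int = 1
    rate_limit: Optional[float] = None     # minimum seconds between requests; None = use provider default

m = Model(id="openai/gpt-4", name="gpt-4", provider_id="openai", weight=2, rate_limit=1.5)
print(m.dict())                            # same serialization as the models endpoints above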
@@ -38,12 +38,29 @@ class BaseProviderHandler:
self.provider_id = provider_id
self.api_key = api_key
self.error_tracking = config.error_tracking[provider_id]
self.last_request_time = 0
self.rate_limit = config.providers[provider_id].rate_limit
def is_rate_limited(self) -> bool:
if self.error_tracking['disabled_until'] and self.error_tracking['disabled_until'] > time.time():
return True
return False
async def apply_rate_limit(self, rate_limit: Optional[float] = None):
"""Apply rate limiting by waiting if necessary"""
if rate_limit is None:
rate_limit = self.rate_limit
if rate_limit and rate_limit > 0:
current_time = time.time()
time_since_last_request = current_time - self.last_request_time
required_wait = rate_limit - time_since_last_request
if required_wait > 0:
await asyncio.sleep(required_wait)
self.last_request_time = time.time()
def record_failure(self):
self.error_tracking['failures'] += 1
self.error_tracking['last_failure'] = time.time()
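Note the semantics: rate_limit is the minimum number of seconds between request starts (required_wait = rate_limit - elapsed), not a requests-per-second figure, and a value of 0 or None leaves the handler unthrottled; last_request_time is stamped inside apply_rate_limit, before the upstream call is made. A small standalone sketch of the same wait logic (the class and names here are illustrative, not part of the change):

import asyncio
import time

class Throttle:
    """Illustrative re-implementation of apply_rate_limit's wait logic."""
    def __init__(self, rate_limit: float = 0):
        self.rate_limit = rate_limit        # minimum seconds between requests; 0 disables throttling
        self.last_request_time = 0.0

    async def wait(self):
        if self.rate_limit and self.rate_limit > 0:
            elapsed = time.time() - self.last_request_time
            remaining = self.rate_limit - elapsed
            if remaining > 0:
                await asyncio.sleep(remaining)
        self.last_request_time = time.time()

async def demo():
    t = Throttle(rate_limit=2.0)
    await t.wait()                          # first call returns immediately
    start = time.time()
    await t.wait()                          # second call sleeps roughly 2 seconds
    print(f"waited {time.time() - start:.1f}s")

asyncio.run(demo())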
@@ -70,10 +87,13 @@ class GoogleProviderHandler(BaseProviderHandler):
import logging
logging.info(f"GoogleProviderHandler: Handling request for model {model}")
logging.info(f"GoogleProviderHandler: Messages: {messages}")
# Apply rate limiting
await self.apply_rate_limit()
# Build content from messages
content = "\n\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# Generate content using the google-genai client
response = self.client.models.generate_content(
model=model,
@@ -83,10 +103,10 @@ class GoogleProviderHandler(BaseProviderHandler):
"max_output_tokens": max_tokens
}
)
logging.info(f"GoogleProviderHandler: Response received: {response}")
self.record_success()
# Return the response as a dictionary
return {
"candidates": [{
@@ -113,11 +133,14 @@ class GoogleProviderHandler(BaseProviderHandler):
try:
import logging
logging.info("GoogleProviderHandler: Getting models list")
# Apply rate limiting
await self.apply_rate_limit()
# List models using the google-genai client
models = self.client.models.list()
logging.info(f"GoogleProviderHandler: Models received: {models}")
# Convert to our Model format
result = []
for model in models:
@@ -126,7 +149,7 @@ class GoogleProviderHandler(BaseProviderHandler):
name=model.display_name or model.name,
provider_id=self.provider_id
))
return result
except Exception as e:
import logging
@@ -147,7 +170,10 @@ class OpenAIProviderHandler(BaseProviderHandler):
import logging
logging.info(f"OpenAIProviderHandler: Handling request for model {model}")
logging.info(f"OpenAIProviderHandler: Messages: {messages}")
# Apply rate limiting
await self.apply_rate_limit()
response = self.client.chat.completions.create(
model=model,
messages=[{"role": msg["role"], "content": msg["content"]} for msg in messages],
@@ -168,10 +194,13 @@ class OpenAIProviderHandler(BaseProviderHandler):
try:
import logging
logging.info("OpenAIProviderHandler: Getting models list")
# Apply rate limiting
await self.apply_rate_limit()
models = self.client.models.list()
logging.info(f"OpenAIProviderHandler: Models received: {models}")
return [Model(id=model.id, name=model.id, provider_id=self.provider_id) for model in models]
except Exception as e:
import logging
@@ -192,7 +221,10 @@ class AnthropicProviderHandler(BaseProviderHandler):
import logging
logging.info(f"AnthropicProviderHandler: Handling request for model {model}")
logging.info(f"AnthropicProviderHandler: Messages: {messages}")
# Apply rate limiting
await self.apply_rate_limit()
response = self.client.messages.create(
model=model,
messages=[{"role": msg["role"], "content": msg["content"]} for msg in messages],
@@ -227,6 +259,9 @@ class OllamaProviderHandler(BaseProviderHandler):
raise Exception("Provider rate limited")
try:
# Apply rate limiting
await self.apply_rate_limit()
response = await self.client.post("/api/generate", json={
"model": model,
"prompt": "\n\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]),
@@ -243,6 +278,9 @@ class OllamaProviderHandler(BaseProviderHandler):
raise e
async def get_models(self) -> List[Model]:
# Apply rate limiting
await self.apply_rate_limit()
response = await self.client.get("/api/tags")
response.raise_for_status()
models = response.json().get('models', [])
@@ -5,134 +5,152 @@
"name": "Google AI Studio",
"endpoint": "https://generativelanguage.googleapis.com/v1beta",
"type": "google",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"openai": {
"id": "openai",
"name": "OpenAI",
"endpoint": "https://api.openai.com/v1",
"type": "openai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"anthropic": {
"id": "anthropic",
"name": "Anthropic",
"endpoint": "https://api.anthropic.com/v1",
"type": "anthropic",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"ollama": {
"id": "ollama",
"name": "Ollama",
"endpoint": "http://localhost:11434",
"type": "ollama",
"api_key_required": false
"api_key_required": false,
"rate_limit": 0
},
"azure_openai": {
"id": "azure_openai",
"name": "Azure OpenAI",
"endpoint": "https://your-azure-endpoint.openai.azure.com",
"type": "openai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"cohere": {
"id": "cohere",
"name": "Cohere",
"endpoint": "https://api.cohere.com/v1",
"type": "cohere",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"huggingface": {
"id": "huggingface",
"name": "Hugging Face",
"endpoint": "https://api-inference.huggingface.co",
"type": "huggingface",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"replicate": {
"id": "replicate",
"name": "Replicate",
"endpoint": "https://api.replicate.com/v1",
"type": "replicate",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"togetherai": {
"id": "togetherai",
"name": "Together AI",
"endpoint": "https://api.together.xyz/v1",
"type": "openai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"groq": {
"id": "groq",
"name": "Groq",
"endpoint": "https://api.groq.com/openai/v1",
"type": "openai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"mistralai": {
"id": "mistralai",
"name": "Mistral AI",
"endpoint": "https://api.mistral.ai/v1",
"type": "openai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"stabilityai": {
"id": "stabilityai",
"name": "Stability AI",
"endpoint": "https://api.stability.ai/v2beta",
"type": "stabilityai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"kilo": {
"id": "kilo",
"name": "KiloCode",
"endpoint": "https://kilocode.ai/api/openrouter",
"type": "openai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"perplexity": {
"id": "perplexity",
"name": "Perplexity AI",
"endpoint": "https://api.perplexity.ai",
"type": "openai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"poe": {
"id": "poe",
"name": "Poe",
"endpoint": "https://api.poe.com/v1",
"type": "poe",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"lanai": {
"id": "lanai",
"name": "Llama AI",
"endpoint": "https://api.lanai.ai/v1",
"type": "lanai",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"amazon": {
"id": "amazon",
"name": "Amazon Bedrock",
"endpoint": "https://api.amazon.com/bedrock/v1",
"type": "amazon",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"ibm": {
"id": "ibm",
"name": "IBM Watson",
"endpoint": "https://api.ibm.com/watson/v1",
"type": "ibm",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
},
"microsoft": {
"id": "microsoft",
"name": "Microsoft Azure AI",
"endpoint": "https://api.microsoft.com/v1",
"type": "microsoft",
"api_key_required": true
"api_key_required": true,
"rate_limit": 0
}
}
}
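Every provider entry now carries "rate_limit": 0, and because apply_rate_limit only throttles when the value is greater than 0, the default leaves all providers unthrottled. To enable provider-level throttling, set the value to the desired minimum interval in seconds, e.g. (the 2.0 is purely illustrative; the other fields are copied from the entry above):

"groq": {
    "id": "groq",
    "name": "Groq",
    "endpoint": "https://api.groq.com/openai/v1",
    "type": "openai",
    "api_key_required": true,
    "rate_limit": 2.0
}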
@@ -9,11 +9,13 @@
"models": [
{
"name": "gemini-2.0-flash",
"weight": 3
"weight": 3,
"rate_limit": 0
},
{
"name": "gemini-1.5-pro",
"weight": 1
"weight": 1,
"rate_limit": 0
}
]
},
@@ -23,11 +25,13 @@
"models": [
{
"name": "gpt-4",
"weight": 2
"weight": 2,
"rate_limit": 0
},
{
"name": "gpt-3.5-turbo",
"weight": 1
"weight": 1,
"rate_limit": 0
}
]
},
@@ -37,11 +41,13 @@
"models": [
{
"name": "claude-3-5-sonnet-20241022",
"weight": 2
"weight": 2,
"rate_limit": 0
},
{
"name": "claude-3-haiku-20240307",
"weight": 1
"weight": 1,
"rate_limit": 0
}
]
}
@@ -56,11 +62,13 @@
"models": [
{
"name": "gemini-1.5-pro",
"weight": 2
"weight": 2,
"rate_limit": 0
},
{
"name": "gemini-2.0-flash",
"weight": 1
"weight": 1,
"rate_limit": 0
}
]
},
@@ -70,11 +78,13 @@
"models": [
{
"name": "gpt-4",
"weight": 2
"weight": 2,
"rate_limit": 0
},
{
"name": "gpt-3.5-turbo",
"weight": 1
"weight": 1,
"rate_limit": 0
}
]
}
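A model-level rate_limit in a rotation entry overrides the provider default for that model only (see the rate_limit pass-through in RotationHandler above); all values shipped here are 0, meaning no per-model throttle. An illustrative entry with throttling enabled:

{
    "name": "gemini-2.0-flash",
    "weight": 3,
    "rate_limit": 1.5
}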
@@ -61,14 +61,14 @@ async def chat_completions(provider_id: str, request: Request, body: ChatComplet
logger.debug(f"Received chat_completions request for provider: {provider_id}")
logger.debug(f"Request headers: {dict(request.headers)}")
logger.debug(f"Request body: {body}")
body_dict = body.model_dump()
# Check if it's a rotation
if provider_id in config.rotations:
logger.debug("Handling rotation request")
return await rotation_handler.handle_rotation_request(provider_id, body_dict)
# Check if it's a provider
if provider_id not in config.providers:
logger.error(f"Provider {provider_id} not found")
@@ -93,12 +93,12 @@ async def chat_completions(provider_id: str, request: Request, body: ChatComplet
@app.get("/api/{provider_id}/models")
async def list_models(request: Request, provider_id: str):
logger.debug(f"Received list_models request for provider: {provider_id}")
# Check if it's a rotation
if provider_id in config.rotations:
logger.debug("Handling rotation model list request")
return await rotation_handler.handle_rotation_model_list(provider_id)
# Check if it's a provider
if provider_id not in config.providers:
logger.error(f"Provider {provider_id} not found")
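Both endpoints accept either a provider id or a rotation id in the {provider_id} slot. A hedged usage sketch with httpx; the base URL, port, rotation id, and the chat-completions path are assumptions, since only the models route decorator appears in the hunks above, and any API-key header is omitted:

import httpx

BASE = "http://localhost:8000"                       # assumed host and port

# List models for a single provider (or for a rotation id)
models = httpx.get(f"{BASE}/api/openai/models").json()

# Chat completion routed through a rotation; path assumed from the models route above
resp = httpx.post(
    f"{BASE}/api/my-rotation/chat/completions",      # "my-rotation" is a placeholder rotation id
    json={"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]},
)
print(resp.json())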