Add tool-calling support (tools / tool_choice, tool_calls, tool_call_id, name) to provider handlers and stream rotation chat completions as SSE

parent 38a58f51
......@@ -31,7 +31,10 @@ class Message(BaseModel):
content: Union[str, List[Dict], List, None] = None
tool_calls: Optional[List[Dict]] = None
tool_call_id: Optional[str] = None
name: Optional[str] = None
class Config:
    # Pydantic v1-style model config: accept and retain request fields beyond
    # those declared on the model (e.g. provider-specific extensions), instead
    # of rejecting them with a validation error.
    extra = "allow" # Allow extra fields not defined in the model
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
......
......@@ -121,7 +121,8 @@ class GoogleProviderHandler(BaseProviderHandler):
self.client = genai.Client(api_key=api_key)
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0, stream: Optional[bool] = False) -> Dict:
temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Dict:
if self.is_rate_limited():
raise Exception("Provider rate limited")
......@@ -222,12 +223,25 @@ class OpenAIProviderHandler(BaseProviderHandler):
# Build request parameters
request_params = {
"model": model,
"messages": [{"role": msg["role"], "content": msg["content"]} for msg in messages],
"messages": [],
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream
}
# Build messages with all fields (including tool_calls and tool_call_id)
for msg in messages:
message = {"role": msg["role"]}
if "content" in msg and msg["content"] is not None:
message["content"] = msg["content"]
if "tool_calls" in msg and msg["tool_calls"] is not None:
message["tool_calls"] = msg["tool_calls"]
if "tool_call_id" in msg and msg["tool_call_id"] is not None:
message["tool_call_id"] = msg["tool_call_id"]
if "name" in msg and msg["name"] is not None:
message["name"] = msg["name"]
request_params["messages"].append(message)
# Add tools and tool_choice if provided
if tools is not None:
request_params["tools"] = tools
......@@ -271,7 +285,8 @@ class AnthropicProviderHandler(BaseProviderHandler):
self.client = Anthropic(api_key=api_key)
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0, stream: Optional[bool] = False) -> Dict:
temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Dict:
if self.is_rate_limited():
raise Exception("Provider rate limited")
......@@ -320,7 +335,8 @@ class OllamaProviderHandler(BaseProviderHandler):
self.client = httpx.AsyncClient(base_url=config.providers[provider_id].endpoint, timeout=timeout)
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0, stream: Optional[bool] = False) -> Dict:
temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Dict:
import logging
import json
logger = logging.getLogger(__name__)
......
......@@ -212,7 +212,27 @@ async def rotation_chat_completions(request: Request, body: ChatCompletionReques
try:
if body.stream:
logger.debug("Handling streaming rotation request")
return await rotation_handler.handle_rotation_request(body.model, body_dict)
rotation_config = config.get_rotation(body.model)
if not rotation_config:
raise HTTPException(status_code=400, detail=f"Rotation {body.model} not found")
async def stream_generator():
    """Relay the rotation handler's streamed chunks as SSE ``data:`` frames.

    Closes over ``rotation_handler``, ``body``, ``body_dict`` and ``logger``
    from the enclosing request handler.  Chunks that fail to serialize are
    logged and skipped so one bad chunk does not kill the stream; a top-level
    failure is reported to the client as a final error frame.
    """
    import json  # hoisted: original re-imported json on every chunk iteration

    try:
        response = await rotation_handler.handle_rotation_request(body.model, body_dict)
        # NOTE(review): synchronous `for` over the awaited response — confirm the
        # handler returns a sync iterable of chunks, not an async generator.
        for chunk in response:
            try:
                # Pydantic models expose model_dump(); plain dicts pass through as-is.
                chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk
                yield f"data: {json.dumps(chunk_dict)}\n\n".encode('utf-8')
            except Exception as chunk_error:
                # Best-effort: drop the unserializable chunk, keep streaming.
                logger.warning(f"Error serializing chunk: {str(chunk_error)}")
                continue
    except Exception as e:
        logger.error(f"Error in streaming response: {str(e)}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n".encode('utf-8')
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
logger.debug("Handling non-streaming rotation request")
result = await rotation_handler.handle_rotation_request(body.model, body_dict)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment