Try to fix tool-call passthrough to providers and streaming for rotation requests

parent 38a58f51
@@ -31,7 +31,10 @@ class Message(BaseModel):
     content: Union[str, List[Dict], List, None] = None
     tool_calls: Optional[List[Dict]] = None
     tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+    class Config:
+        extra = "allow"  # Allow extra fields not defined in the model
 
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
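These three added lines are what let tool messages survive request validation: without `name` and `extra = "allow"`, Pydantic would reject or strip the fields the OpenAI tool-calling protocol needs on the return trip. A minimal sketch of the resulting behavior, assuming Pydantic v2 (the `model_dump` calls later in the diff suggest v2) and a `role` field inferred from context above the hunk:

```python
from typing import Dict, List, Optional, Union
from pydantic import BaseModel

class Message(BaseModel):
    role: str  # assumed field; the hunk starts just below it
    content: Union[str, List[Dict], List, None] = None
    tool_calls: Optional[List[Dict]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None

    class Config:
        extra = "allow"  # unknown provider-specific keys are kept, not rejected

# An assistant turn that only calls a tool, and the tool's reply, both validate:
assistant = Message(
    role="assistant",
    tool_calls=[{"id": "call_1", "type": "function",
                 "function": {"name": "get_weather", "arguments": "{}"}}],
)
tool_reply = Message(role="tool", tool_call_id="call_1", content="4C, overcast")
print(assistant.model_dump(exclude_none=True))
```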
...
@@ -121,7 +121,8 @@ class GoogleProviderHandler(BaseProviderHandler):
         self.client = genai.Client(api_key=api_key)
 
     async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
-                             temperature: Optional[float] = 1.0, stream: Optional[bool] = False) -> Dict:
+                             temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
+                             tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Dict:
         if self.is_rate_limited():
             raise Exception("Provider rate limited")
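This signature change is repeated verbatim for the Anthropic and Ollama handlers below, so every provider now accepts `tools` and `tool_choice` even where it ignores them. A hedged call-site sketch — only the parameter names and types come from the diff; the tool schema and model name are illustrative:

```python
async def demo(handler) -> dict:
    # Illustrative OpenAI-style tool definition; not defined anywhere in this repo.
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    return await handler.handle_request(
        model="gpt-4o",  # any model the handler serves
        messages=[{"role": "user", "content": "Weather in Oslo?"}],
        max_tokens=256,
        temperature=0.2,
        stream=False,
        tools=tools,
        tool_choice="auto",  # or {"type": "function", "function": {"name": "get_weather"}}
    )
```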
@@ -222,12 +223,25 @@ class OpenAIProviderHandler(BaseProviderHandler):
         # Build request parameters
         request_params = {
             "model": model,
-            "messages": [{"role": msg["role"], "content": msg["content"]} for msg in messages],
+            "messages": [],
             "max_tokens": max_tokens,
             "temperature": temperature,
             "stream": stream
         }
 
+        # Build messages with all fields (including tool_calls and tool_call_id)
+        for msg in messages:
+            message = {"role": msg["role"]}
+            if "content" in msg and msg["content"] is not None:
+                message["content"] = msg["content"]
+            if "tool_calls" in msg and msg["tool_calls"] is not None:
+                message["tool_calls"] = msg["tool_calls"]
+            if "tool_call_id" in msg and msg["tool_call_id"] is not None:
+                message["tool_call_id"] = msg["tool_call_id"]
+            if "name" in msg and msg["name"] is not None:
+                message["name"] = msg["name"]
+            request_params["messages"].append(message)
+
         # Add tools and tool_choice if provided
         if tools is not None:
             request_params["tools"] = tools
@@ -271,7 +285,8 @@ class AnthropicProviderHandler(BaseProviderHandler):
         self.client = Anthropic(api_key=api_key)
 
     async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
-                             temperature: Optional[float] = 1.0, stream: Optional[bool] = False) -> Dict:
+                             temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
+                             tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Dict:
         if self.is_rate_limited():
             raise Exception("Provider rate limited")
@@ -320,7 +335,8 @@ class OllamaProviderHandler(BaseProviderHandler):
         self.client = httpx.AsyncClient(base_url=config.providers[provider_id].endpoint, timeout=timeout)
 
     async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
-                             temperature: Optional[float] = 1.0, stream: Optional[bool] = False) -> Dict:
+                             temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
+                             tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Dict:
         import logging
         import json
         logger = logging.getLogger(__name__)
...
@@ -212,7 +212,27 @@ async def rotation_chat_completions(request: Request, body: ChatCompletionRequest
     try:
         if body.stream:
             logger.debug("Handling streaming rotation request")
-            return await rotation_handler.handle_rotation_request(body.model, body_dict)
+            rotation_config = config.get_rotation(body.model)
+            if not rotation_config:
+                raise HTTPException(status_code=400, detail=f"Rotation {body.model} not found")
+
+            async def stream_generator():
+                try:
+                    response = await rotation_handler.handle_rotation_request(body.model, body_dict)
+                    for chunk in response:
+                        try:
+                            chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk
+                            import json
+                            yield f"data: {json.dumps(chunk_dict)}\n\n".encode('utf-8')
+                        except Exception as chunk_error:
+                            logger.warning(f"Error serializing chunk: {str(chunk_error)}")
+                            continue
+                except Exception as e:
+                    logger.error(f"Error in streaming response: {str(e)}")
+                    import json
+                    yield f"data: {json.dumps({'error': str(e)})}\n\n".encode('utf-8')
+
+            return StreamingResponse(stream_generator(), media_type="text/event-stream")
         else:
             logger.debug("Handling non-streaming rotation request")
             result = await rotation_handler.handle_rotation_request(body.model, body_dict)
...
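For manual testing of the new streaming branch, something like the client below should work. It is a sketch under stated assumptions: the route path, port, rotation name, and chunk shape (OpenAI-compatible deltas) are placeholders, not taken from the repo.

```python
import json
import httpx

def stream_chat(base_url: str = "http://localhost:8000") -> None:
    payload = {
        "model": "my-rotation",  # placeholder rotation name
        "stream": True,
        "messages": [{"role": "user", "content": "Hello"}],
    }
    # Route path is a guess from the handler name; substitute the real one.
    with httpx.stream("POST", f"{base_url}/rotation/v1/chat/completions",
                      json=payload, timeout=60) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue
            chunk = json.loads(line[len("data: "):])
            if "error" in chunk:  # the generator reports failures in-band
                raise RuntimeError(chunk["error"])
            delta = chunk.get("choices", [{}])[0].get("delta", {})
            print(delta.get("content") or "", end="", flush=True)
```

Two points worth double-checking in review: `stream_generator` iterates the provider response synchronously (`for chunk in response:`) inside an async generator, which blocks the event loop if the underlying SDK stream does blocking I/O, and it never emits a terminating `data: [DONE]` line, which many OpenAI-compatible clients use to detect end of stream.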