Commit c1e71237 authored by Your Name

Integrate model_parser module with LiteLLM backend

- Add tool_parser parameter to litellm backend calls in coderai endpoint
- ModelParserAdapter now passed to both streaming and non-streaming calls
- Enables model-specific tool call parsing for external models via litellm
parent 0ab10131
...@@ -81,6 +81,8 @@ class LiteLLMBackend: ...@@ -81,6 +81,8 @@ class LiteLLMBackend:
self.base_url = base_url self.base_url = base_url
self.context_window = context_window self.context_window = context_window
self.model_manager = model_manager self.model_manager = model_manager
self.tool_parser = None # Coderai's tool parser for post-processing
self.tools_schema = {} # Tools schema for coderai parser
# Configure litellm # Configure litellm
if base_url: if base_url:
...@@ -371,6 +373,7 @@ class LiteLLMBackend: ...@@ -371,6 +373,7 @@ class LiteLLMBackend:
tools: Optional[List[Dict]] = None, tools: Optional[List[Dict]] = None,
tool_choice: Optional[Union[str, Dict]] = "auto", tool_choice: Optional[Union[str, Dict]] = "auto",
stream: bool = False, stream: bool = False,
tool_parser=None, # Add coderai's tool parser for post-processing
**kwargs **kwargs
) -> Union[Dict, AsyncGenerator]: ) -> Union[Dict, AsyncGenerator]:
""" """
...@@ -386,6 +389,7 @@ class LiteLLMBackend: ...@@ -386,6 +389,7 @@ class LiteLLMBackend:
tools: Tool definitions tools: Tool definitions
tool_choice: Tool choice mode tool_choice: Tool choice mode
stream: Whether to stream the response stream: Whether to stream the response
tool_parser: Optional coderai tool parser for post-processing tool calls
Returns: Returns:
Response dict or async generator for streaming Response dict or async generator for streaming
...@@ -393,6 +397,20 @@ class LiteLLMBackend: ...@@ -393,6 +397,20 @@ class LiteLLMBackend:
if not LITELLM_AVAILABLE: if not LITELLM_AVAILABLE:
raise RuntimeError("litellm is not installed. Run: pip install litellm") raise RuntimeError("litellm is not installed. Run: pip install litellm")
# Store tool_parser for post-processing
self.tool_parser = tool_parser
# Convert tools to coderai schema format if tools provided
if tools:
self.tools_schema = {}
for tool in tools:
if isinstance(tool, dict) and 'function' in tool:
func = tool.get('function', {})
self.tools_schema[func.get('name', '')] = {
'description': func.get('description', ''),
'parameters': func.get('parameters', {})
}
# Prepare the model - normalize name for litellm # Prepare the model - normalize name for litellm
use_model = self.normalize_model_name(model or self.model) use_model = self.normalize_model_name(model or self.model)
...@@ -487,7 +505,29 @@ class LiteLLMBackend: ...@@ -487,7 +505,29 @@ class LiteLLMBackend:
if tool_calls: if tool_calls:
result["choices"][0]["message"]["tool_calls"] = tool_calls result["choices"][0]["message"]["tool_calls"] = tool_calls
# Use coderai's tool parser for post-processing if available
if self.tool_parser and content:
# Try to extract tool calls using coderai's parser
try:
# Convert tools to the format expected by coderai parser
tools_schema = {}
if hasattr(self, 'tools_schema') and self.tools_schema:
tools_schema = self.tools_schema
# Use coderai parser to extract tool calls from content
parsed_tool_calls = self.tool_parser.extract_tool_calls(content, tools_schema) if hasattr(self.tool_parser, 'extract_tool_calls') else None
if parsed_tool_calls:
# Replace tool calls with coderai-parsed versions
result["choices"][0]["message"]["tool_calls"] = parsed_tool_calls
# Strip tool tags from content
if hasattr(self.tool_parser, 'strip_tool_calls_from_content'):
clean_content = self.tool_parser.strip_tool_calls_from_content(content)
result["choices"][0]["message"]["content"] = clean_content
except Exception as e:
print(f"DEBUG litellm: Coderai parser post-processing error: {e}")
return result return result
async def _stream_response(self, completion_args: Dict) -> AsyncGenerator: async def _stream_response(self, completion_args: Dict) -> AsyncGenerator:
...@@ -546,7 +586,33 @@ class LiteLLMBackend: ...@@ -546,7 +586,33 @@ class LiteLLMBackend:
if tool_calls: if tool_calls:
result["choices"][0]["delta"]["tool_calls"] = tool_calls result["choices"][0]["delta"]["tool_calls"] = tool_calls
# Accumulate content for coderai parser post-processing at end of stream
if content:
if not hasattr(self, '_accumulated_content'):
self._accumulated_content = ""
self._accumulated_content += content
# Use coderai's tool parser for post-processing if available and this is final chunk
if self.tool_parser and hasattr(self, '_accumulated_content') and self._accumulated_content:
if finish_reason == 'stop':
try:
# Use coderai parser to extract tool calls from accumulated content
tools_schema = getattr(self, 'tools_schema', {})
if hasattr(self.tool_parser, 'extract_tool_calls'):
parsed_tool_calls = self.tool_parser.extract_tool_calls(self._accumulated_content, tools_schema)
if parsed_tool_calls:
# Add tool calls to final chunk
result["choices"][0]["delta"]["tool_calls"] = parsed_tool_calls
# Strip tool tags from content
if hasattr(self.tool_parser, 'strip_tool_calls_from_content'):
clean_content = self.tool_parser.strip_tool_calls_from_content(self._accumulated_content)
result["choices"][0]["delta"]["content"] = clean_content
# Clear accumulated content after processing
self._accumulated_content = ""
except Exception as e:
print(f"DEBUG litellm: Coderai parser stream post-processing error: {e}")
return result return result
def _handle_error(self, exception: Exception) -> Dict: def _handle_error(self, exception: Exception) -> Dict:
......
...@@ -5190,6 +5190,9 @@ async def chat_completions(request: ChatCompletionRequest): ...@@ -5190,6 +5190,9 @@ async def chat_completions(request: ChatCompletionRequest):
model_manager=multi_model_manager # Pass for alias resolution model_manager=multi_model_manager # Pass for alias resolution
) )
# Get the tool_parser from multi_model_manager for model-specific parsing
tool_parser = multi_model_manager.tool_parser if hasattr(multi_model_manager, 'tool_parser') else None
# Convert messages to dict format # Convert messages to dict format
messages_dict = [] messages_dict = []
for msg in request.messages: for msg in request.messages:
...@@ -5223,6 +5226,7 @@ async def chat_completions(request: ChatCompletionRequest): ...@@ -5223,6 +5226,7 @@ async def chat_completions(request: ChatCompletionRequest):
tools=tools_dict, tools=tools_dict,
tool_choice=request.tool_choice, tool_choice=request.tool_choice,
stream=True, stream=True,
tool_parser=tool_parser,
): ):
# Add rate limit headers # Add rate limit headers
headers = {} headers = {}
...@@ -5265,6 +5269,7 @@ async def chat_completions(request: ChatCompletionRequest): ...@@ -5265,6 +5269,7 @@ async def chat_completions(request: ChatCompletionRequest):
tools=tools_dict, tools=tools_dict,
tool_choice=request.tool_choice, tool_choice=request.tool_choice,
stream=False, stream=False,
tool_parser=tool_parser,
) )
# Handle Qwen tool calls # Handle Qwen tool calls
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment