feat: Add tool call detection in streaming responses

- Detect tool calls in accumulated streaming text after all chunks received
- Parse nested 'assistant: [...]' format with tool calls inside
- Parse simple 'tool: {...}' format
- Convert detected tool calls to OpenAI-compatible format
- Send tool_calls in first chunk, then final assistant text
- Proper handling of finish_reason in final chunk
parent 847353e3
...@@ -1410,14 +1410,158 @@ class RotationHandler: ...@@ -1410,14 +1410,158 @@ class RotationHandler:
delta_text = chunk_text[len(accumulated_text):] if chunk_text.startswith(accumulated_text) else chunk_text delta_text = chunk_text[len(accumulated_text):] if chunk_text.startswith(accumulated_text) else chunk_text
accumulated_text = chunk_text # Update accumulated text for next iteration accumulated_text = chunk_text # Update accumulated text for next iteration
# Check if this is the last chunk # Track completion tokens for Google responses
is_last_chunk = (chunk_idx == total_chunks - 1) if delta_text:
chunk_finish_reason = finish_reason if is_last_chunk else None accumulated_response_text += delta_text
# Only send if there's new content or it's the last chunk with finish_reason chunk_idx += 1
if delta_text or is_last_chunk: except Exception as chunk_error:
# Create OpenAI-compatible chunk with additional fields error_msg = str(chunk_error)
openai_chunk = { logger.error(f"Error processing Google chunk: {error_msg}")
logger.error(f"Chunk type: {type(chunk)}")
logger.error(f"Chunk content: {chunk}")
chunk_idx += 1
continue
# After collecting all chunks, check if the accumulated text contains a tool call pattern
# This handles models that return tool calls as text instead of using function_call attributes
tool_calls = None
final_text = accumulated_response_text
# Check for tool call patterns in the accumulated text
if accumulated_response_text:
import re as re_module
# Pattern 0: "assistant: [...]" wrapping everything (nested format)
outer_assistant_pattern = r"^assistant:\s*(\[.*\])\s*$"
outer_assistant_match = re_module.match(outer_assistant_pattern, accumulated_response_text.strip(), re_module.DOTALL)
if outer_assistant_match:
try:
outer_content = json.loads(outer_assistant_match.group(1))
if isinstance(outer_content, list) and len(outer_content) > 0:
for item in outer_content:
if isinstance(item, dict) and item.get('type') == 'text':
inner_text = item.get('text', '')
# Now parse the inner text for tool calls
inner_tool_pattern = r'tool:\s*(\{.*?\})\s*(?:assistant:\s*(\[.*\]))?\s*$'
inner_tool_match = re_module.search(inner_tool_pattern, inner_text, re_module.DOTALL)
if inner_tool_match:
tool_json_str = inner_tool_match.group(1)
# Parse the tool JSON - handle multi-line content
try:
# Extract JSON using a more robust method
tool_start = inner_text.find('tool:')
if tool_start != -1:
json_start = inner_text.find('{', tool_start)
brace_count = 0
json_end = json_start
for i, c in enumerate(inner_text[json_start:], json_start):
if c == '{':
brace_count += 1
elif c == '}':
brace_count -= 1
if brace_count == 0:
json_end = i + 1
break
tool_json_str = inner_text[json_start:json_end]
parsed_tool = json.loads(tool_json_str)
# Convert to OpenAI tool_calls format
tool_calls = [{
"id": f"call_0",
"type": "function",
"function": {
"name": parsed_tool.get('action', parsed_tool.get('name', 'unknown')),
"arguments": json.dumps({k: v for k, v in parsed_tool.items() if k not in ['action', 'name']})
}
}]
logger.info(f"Converted streaming tool call to OpenAI format: {tool_calls}")
# Extract the final assistant text if present
if inner_tool_match.group(2):
try:
final_assistant = json.loads(inner_tool_match.group(2))
if isinstance(final_assistant, list) and len(final_assistant) > 0:
for final_item in final_assistant:
if isinstance(final_item, dict) and final_item.get('type') == 'text':
final_text = final_item.get('text', '')
break
else:
final_text = ""
else:
final_text = ""
except json.JSONDecodeError:
final_text = ""
else:
final_text = ""
except (json.JSONDecodeError, Exception) as e:
logger.debug(f"Failed to parse streaming tool JSON: {e}")
break
except (json.JSONDecodeError, Exception) as e:
logger.debug(f"Failed to parse outer assistant format in streaming: {e}")
# Pattern 1: Simple "tool: {...}" format (not nested)
elif not tool_calls:
tool_pattern = r'tool:\s*(\{.*?\})\s*(?:assistant:\s*(\[.*\]))?\s*$'
tool_match = re_module.search(tool_pattern, accumulated_response_text, re_module.DOTALL)
if tool_match:
try:
# Extract JSON using brace counting
tool_start = accumulated_response_text.find('tool:')
if tool_start != -1:
json_start = accumulated_response_text.find('{', tool_start)
brace_count = 0
json_end = json_start
for i, c in enumerate(accumulated_response_text[json_start:], json_start):
if c == '{':
brace_count += 1
elif c == '}':
brace_count -= 1
if brace_count == 0:
json_end = i + 1
break
tool_json_str = accumulated_response_text[json_start:json_end]
parsed_tool = json.loads(tool_json_str)
# Convert to OpenAI tool_calls format
tool_calls = [{
"id": f"call_0",
"type": "function",
"function": {
"name": parsed_tool.get('action', parsed_tool.get('name', 'unknown')),
"arguments": json.dumps({k: v for k, v in parsed_tool.items() if k not in ['action', 'name']})
}
}]
logger.info(f"Converted simple streaming tool call to OpenAI format: {tool_calls}")
# Extract the final assistant text if present
if tool_match.group(2):
try:
final_assistant = json.loads(tool_match.group(2))
if isinstance(final_assistant, list) and len(final_assistant) > 0:
for final_item in final_assistant:
if isinstance(final_item, dict) and final_item.get('type') == 'text':
final_text = final_item.get('text', '')
break
else:
final_text = ""
else:
final_text = ""
except json.JSONDecodeError:
final_text = ""
else:
final_text = ""
except (json.JSONDecodeError, Exception) as e:
logger.debug(f"Failed to parse simple streaming tool JSON: {e}")
# Now send the response chunks
# If we detected tool calls, send them in the first chunk with role
if tool_calls:
# First chunk with tool_calls
tool_chunk = {
"id": response_id, "id": response_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": created_time, "created": created_time,
...@@ -1429,37 +1573,72 @@ class RotationHandler: ...@@ -1429,37 +1573,72 @@ class RotationHandler:
"choices": [{ "choices": [{
"index": 0, "index": 0,
"delta": { "delta": {
"content": delta_text if delta_text else "", "content": None,
"refusal": None, "refusal": None,
"role": "assistant", "role": "assistant",
"tool_calls": None "tool_calls": tool_calls
}, },
"finish_reason": chunk_finish_reason, "finish_reason": None,
"logprobs": None, "logprobs": None,
"native_finish_reason": chunk_finish_reason "native_finish_reason": None
}] }]
} }
yield f"data: {json.dumps(tool_chunk)}\n\n".encode('utf-8')
chunk_id += 1 # If there's final assistant text, send it
logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})") if final_text:
text_chunk = {
# Track completion tokens for Google responses "id": response_id,
if delta_text: "object": "chat.completion.chunk",
accumulated_response_text += delta_text "created": created_time,
"model": model_name,
yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8') "service_tier": None,
"system_fingerprint": system_fingerprint,
chunk_idx += 1 "usage": None,
except Exception as chunk_error: "provider": provider_id,
error_msg = str(chunk_error) "choices": [{
logger.error(f"Error processing Google chunk: {error_msg}") "index": 0,
logger.error(f"Chunk type: {type(chunk)}") "delta": {
logger.error(f"Chunk content: {chunk}") "content": final_text,
chunk_idx += 1 "refusal": None,
continue "role": None,
"tool_calls": None
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(text_chunk)}\n\n".encode('utf-8')
else:
# No tool calls detected, send text normally
# Send the accumulated text as a single chunk
if accumulated_response_text:
text_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": accumulated_response_text,
"refusal": None,
"role": "assistant",
"tool_calls": None
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(text_chunk)}\n\n".encode('utf-8')
# Send final chunk with usage statistics (empty content) # Send final chunk with finish reason and usage statistics
# Calculate completion tokens for Google responses (count tokens in full response)
if accumulated_response_text: if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name) completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name)
total_tokens = effective_context + completion_tokens total_tokens = effective_context + completion_tokens
...@@ -1483,12 +1662,12 @@ class RotationHandler: ...@@ -1483,12 +1662,12 @@ class RotationHandler:
"content": "", "content": "",
"function_call": None, "function_call": None,
"refusal": None, "refusal": None,
"role": "assistant", "role": None,
"tool_calls": None "tool_calls": None
}, },
"finish_reason": None, "finish_reason": "stop",
"logprobs": None, "logprobs": None,
"native_finish_reason": None "native_finish_reason": "stop"
}] }]
} }
yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8') yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment