feat: Add tool call detection in streaming responses

- Detect tool calls in accumulated streaming text after all chunks received
- Parse nested 'assistant: [...]' format with tool calls inside
- Parse simple 'tool: {...}' format
- Convert detected tool calls to OpenAI-compatible format
- Send tool_calls in first chunk, then final assistant text
- Proper handling of finish_reason in final chunk
parent 847353e3
...@@ -1410,44 +1410,9 @@ class RotationHandler: ...@@ -1410,44 +1410,9 @@ class RotationHandler:
delta_text = chunk_text[len(accumulated_text):] if chunk_text.startswith(accumulated_text) else chunk_text delta_text = chunk_text[len(accumulated_text):] if chunk_text.startswith(accumulated_text) else chunk_text
accumulated_text = chunk_text # Update accumulated text for next iteration accumulated_text = chunk_text # Update accumulated text for next iteration
# Check if this is the last chunk # Track completion tokens for Google responses
is_last_chunk = (chunk_idx == total_chunks - 1) if delta_text:
chunk_finish_reason = finish_reason if is_last_chunk else None accumulated_response_text += delta_text
# Only send if there's new content or it's the last chunk with finish_reason
if delta_text or is_last_chunk:
# Create OpenAI-compatible chunk with additional fields
openai_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": delta_text if delta_text else "",
"refusal": None,
"role": "assistant",
"tool_calls": None
},
"finish_reason": chunk_finish_reason,
"logprobs": None,
"native_finish_reason": chunk_finish_reason
}]
}
chunk_id += 1
logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})")
# Track completion tokens for Google responses
if delta_text:
accumulated_response_text += delta_text
yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8')
chunk_idx += 1 chunk_idx += 1
except Exception as chunk_error: except Exception as chunk_error:
...@@ -1458,8 +1423,222 @@ class RotationHandler: ...@@ -1458,8 +1423,222 @@ class RotationHandler:
chunk_idx += 1 chunk_idx += 1
continue continue
# Send final chunk with usage statistics (empty content) # After collecting all chunks, check if the accumulated text contains a tool call pattern
# Calculate completion tokens for Google responses (count tokens in full response) # This handles models that return tool calls as text instead of using function_call attributes
# --- Tool-call detection over the fully-accumulated streamed text ---
# NOTE(review): this span was recovered from a web-rendered diff; the original
# indentation was lost, so the nesting shown here is flat and the true block
# structure (try/except pairing, loop scope of the `break` below) must be
# confirmed against the real source file before further edits.
#
# Purpose (from the visible code): after all upstream chunks have been
# consumed, scan `accumulated_response_text` for textual tool-call markers —
# either "assistant: [...]" wrapping a list whose 'text' item contains
# "tool: {...}", or a bare "tool: {...}" — convert the first match into an
# OpenAI-style `tool_calls` list, then emit SSE chunks accordingly.
# Free variables assumed in scope from the enclosing generator:
# accumulated_response_text, response_id, created_time, model_name,
# system_fingerprint, provider_id, logger, json — TODO confirm.
tool_calls = None
# `final_text` defaults to the raw accumulated text; it is replaced (or
# cleared) only when a tool-call pattern is successfully parsed below.
final_text = accumulated_response_text
# Check for tool call patterns in the accumulated text
if accumulated_response_text:
# Local alias import — presumably `re` is not imported at module level
# or is shadowed; verify against the full file.
import re as re_module
# Pattern 0: "assistant: [...]" wrapping everything (nested format)
outer_assistant_pattern = r"^assistant:\s*(\[.*\])\s*$"
outer_assistant_match = re_module.match(outer_assistant_pattern, accumulated_response_text.strip(), re_module.DOTALL)
if outer_assistant_match:
try:
outer_content = json.loads(outer_assistant_match.group(1))
if isinstance(outer_content, list) and len(outer_content) > 0:
for item in outer_content:
if isinstance(item, dict) and item.get('type') == 'text':
inner_text = item.get('text', '')
# Now parse the inner text for tool calls.
# group(1) = non-greedy JSON guess, group(2) = optional
# trailing "assistant: [...]" with the final reply text.
inner_tool_pattern = r'tool:\s*(\{.*?\})\s*(?:assistant:\s*(\[.*\]))?\s*$'
inner_tool_match = re_module.search(inner_tool_pattern, inner_text, re_module.DOTALL)
if inner_tool_match:
# NOTE(review): this value is immediately recomputed by the
# brace-counting pass below; it only survives as a fallback
# if `tool:`/`{` lookup somehow fails.
tool_json_str = inner_tool_match.group(1)
# Parse the tool JSON - handle multi-line content
try:
# Extract JSON using a more robust method: brace counting
# captures nested objects that the non-greedy regex would
# truncate. NOTE(review): a '{'/'}' inside a JSON *string
# value* would still unbalance this counter.
tool_start = inner_text.find('tool:')
if tool_start != -1:
json_start = inner_text.find('{', tool_start)
brace_count = 0
json_end = json_start
for i, c in enumerate(inner_text[json_start:], json_start):
if c == '{':
brace_count += 1
elif c == '}':
brace_count -= 1
if brace_count == 0:
json_end = i + 1
break
tool_json_str = inner_text[json_start:json_end]
parsed_tool = json.loads(tool_json_str)
# Convert to OpenAI tool_calls format.
# NOTE(review): f-string has no placeholder — every call
# gets the literal id "call_0"; multiple tool calls would
# collide. Consider a counter or uuid.
tool_calls = [{
"id": f"call_0",
"type": "function",
"function": {
# 'action' (fallback 'name') becomes the function
# name; all remaining keys become its arguments.
"name": parsed_tool.get('action', parsed_tool.get('name', 'unknown')),
"arguments": json.dumps({k: v for k, v in parsed_tool.items() if k not in ['action', 'name']})
}
}]
logger.info(f"Converted streaming tool call to OpenAI format: {tool_calls}")
# Extract the final assistant text if present
if inner_tool_match.group(2):
try:
final_assistant = json.loads(inner_tool_match.group(2))
if isinstance(final_assistant, list) and len(final_assistant) > 0:
for final_item in final_assistant:
if isinstance(final_item, dict) and final_item.get('type') == 'text':
final_text = final_item.get('text', '')
break
else:
final_text = ""
else:
final_text = ""
except json.JSONDecodeError:
# Unparseable trailer: suppress raw text rather than
# echo the marker syntax back to the client.
final_text = ""
else:
# Tool call consumed the whole message; no trailing reply.
final_text = ""
except (json.JSONDecodeError, Exception) as e:
# NOTE(review): `Exception` subsumes JSONDecodeError — the
# tuple is redundant; also consider narrowing this catch.
logger.debug(f"Failed to parse streaming tool JSON: {e}")
# NOTE(review): stops after the first 'text' item — confirm the
# intended loop scope once real indentation is available.
break
except (json.JSONDecodeError, Exception) as e:
logger.debug(f"Failed to parse outer assistant format in streaming: {e}")
# Pattern 1: Simple "tool: {...}" format (not nested).
# NOTE(review): `tool_calls` is still None whenever this elif is reached
# (it is only assigned inside the branch above), so the `not tool_calls`
# guard is always true here. Also, this whole branch near-duplicates the
# nested-format parsing above — a shared helper would remove ~50 lines.
elif not tool_calls:
tool_pattern = r'tool:\s*(\{.*?\})\s*(?:assistant:\s*(\[.*\]))?\s*$'
tool_match = re_module.search(tool_pattern, accumulated_response_text, re_module.DOTALL)
if tool_match:
try:
# Extract JSON using brace counting (same caveats as above:
# braces inside string values break the count).
tool_start = accumulated_response_text.find('tool:')
if tool_start != -1:
json_start = accumulated_response_text.find('{', tool_start)
brace_count = 0
json_end = json_start
for i, c in enumerate(accumulated_response_text[json_start:], json_start):
if c == '{':
brace_count += 1
elif c == '}':
brace_count -= 1
if brace_count == 0:
json_end = i + 1
break
tool_json_str = accumulated_response_text[json_start:json_end]
parsed_tool = json.loads(tool_json_str)
# Convert to OpenAI tool_calls format (see id/name notes above).
tool_calls = [{
"id": f"call_0",
"type": "function",
"function": {
"name": parsed_tool.get('action', parsed_tool.get('name', 'unknown')),
"arguments": json.dumps({k: v for k, v in parsed_tool.items() if k not in ['action', 'name']})
}
}]
logger.info(f"Converted simple streaming tool call to OpenAI format: {tool_calls}")
# Extract the final assistant text if present
if tool_match.group(2):
try:
final_assistant = json.loads(tool_match.group(2))
if isinstance(final_assistant, list) and len(final_assistant) > 0:
for final_item in final_assistant:
if isinstance(final_item, dict) and final_item.get('type') == 'text':
final_text = final_item.get('text', '')
break
else:
final_text = ""
else:
final_text = ""
except json.JSONDecodeError:
final_text = ""
else:
final_text = ""
except (json.JSONDecodeError, Exception) as e:
logger.debug(f"Failed to parse simple streaming tool JSON: {e}")
# --- Emission phase: replay the parsed result as OpenAI-style SSE chunks ---
# If we detected tool calls, send them in the first chunk with role
if tool_calls:
# First chunk with tool_calls: delta carries role + tool_calls,
# content is None, finish_reason deferred to the final chunk
# (emitted after this span — see the hunk below this fragment).
tool_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": None,
"refusal": None,
"role": "assistant",
"tool_calls": tool_calls
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(tool_chunk)}\n\n".encode('utf-8')
# If there's final assistant text, send it
if final_text:
# Follow-up content chunk: role is None since it was sent above.
text_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": final_text,
"refusal": None,
"role": None,
"tool_calls": None
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(text_chunk)}\n\n".encode('utf-8')
else:
# No tool calls detected, send text normally.
# NOTE(review): the entire response is buffered and sent as ONE chunk
# here, which defeats incremental streaming for the no-tool-call path —
# confirm this trade-off is intentional.
if accumulated_response_text:
text_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": accumulated_response_text,
"refusal": None,
"role": "assistant",
"tool_calls": None
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(text_chunk)}\n\n".encode('utf-8')
# Send final chunk with finish reason and usage statistics
if accumulated_response_text: if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name) completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name)
total_tokens = effective_context + completion_tokens total_tokens = effective_context + completion_tokens
...@@ -1483,12 +1662,12 @@ class RotationHandler: ...@@ -1483,12 +1662,12 @@ class RotationHandler:
"content": "", "content": "",
"function_call": None, "function_call": None,
"refusal": None, "refusal": None,
"role": "assistant", "role": None,
"tool_calls": None "tool_calls": None
}, },
"finish_reason": None, "finish_reason": "stop",
"logprobs": None, "logprobs": None,
"native_finish_reason": None "native_finish_reason": "stop"
}] }]
} }
yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8') yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment