fix: Stream chunks normally, only add tool call chunk at end

Instead of collecting all chunks and sending a modified response:
- Stream chunks normally as they arrive (with deltas, as before)
- Only at the END, if a tool call pattern is detected, send an additional chunk with tool_calls
- Then send a final chunk with usage statistics

This preserves the original streaming behavior while adding tool call detection.
parent 72c50449
......@@ -1370,15 +1370,16 @@ class RotationHandler:
# Track completion tokens for Google responses (since Google doesn't provide them)
completion_tokens = 0
accumulated_response_text = "" # Track full response for token counting
accumulated_response_text = "" # Track full response for token counting and tool detection
# Collect all chunks first to know when we're at the last one
# Collect all chunks first (needed for Google's accumulated text format)
chunks_list = []
async for chunk in response:
chunks_list.append(chunk)
total_chunks = len(chunks_list)
chunk_idx = 0
sent_first_chunk = False # Track if we've sent the first chunk with role
for chunk in chunks_list:
try:
......@@ -1410,9 +1411,43 @@ class RotationHandler:
delta_text = chunk_text[len(accumulated_text):] if chunk_text.startswith(accumulated_text) else chunk_text
accumulated_text = chunk_text # Update accumulated text for next iteration
# Track completion tokens for Google responses
if delta_text:
accumulated_response_text += delta_text
# Track full response for tool detection
if chunk_text:
accumulated_response_text = chunk_text
# Check if this is the last chunk
is_last_chunk = (chunk_idx == total_chunks - 1)
chunk_finish_reason = finish_reason if is_last_chunk else None
# Only send if there's new content or it's the last chunk
if delta_text or is_last_chunk:
# Send streaming chunks normally (with deltas)
openai_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": delta_text if delta_text else "",
"refusal": None,
"role": "assistant" if not sent_first_chunk else None,
"tool_calls": None
},
"finish_reason": chunk_finish_reason,
"logprobs": None,
"native_finish_reason": chunk_finish_reason
}]
}
sent_first_chunk = True
chunk_id += 1
logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})")
yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8')
chunk_idx += 1
except Exception as chunk_error:
......@@ -1423,10 +1458,9 @@ class RotationHandler:
chunk_idx += 1
continue
# After collecting all chunks, check if the accumulated text contains a tool call pattern
# This handles models that return tool calls as text instead of using function_call attributes
# AFTER streaming all chunks, check if the accumulated text contains a tool call pattern
# If detected, send an additional chunk with the tool call
tool_calls = None
final_text = accumulated_response_text
logger.debug(f"=== ACCUMULATED RESPONSE TEXT ===")
logger.debug(f"Total length: {len(accumulated_response_text)}")
......@@ -1438,7 +1472,6 @@ class RotationHandler:
import re as re_module
# Simple approach: just look for "tool: {...}" pattern and extract the JSON
# This avoids complex nested parsing issues
tool_pattern = r'tool:\s*(\{[^{}]*\{[^{}]*\}[^{}]*\}|\{[^{}]+\})'
tool_match = re_module.search(tool_pattern, accumulated_response_text, re_module.DOTALL)
......@@ -1469,23 +1502,13 @@ class RotationHandler:
logger.debug(f"First 20 bytes (repr): {repr(tool_json_str[:20])}")
logger.debug(f"ASCII codes for first 20 chars: {[ord(c) for c in tool_json_str[:20]]}")
# The model may return literal \n (backslash-n) instead of actual newlines
# JSON requires actual newlines between tokens, not escape sequences
# We need to decode escape sequences, but carefully to preserve
# escaped quotes and backslashes inside string values
# Try parsing with unicode escape decoding
try:
# First try parsing as-is
parsed_tool = json.loads(tool_json_str)
logger.debug(f"Successfully parsed tool JSON as-is")
except json.JSONDecodeError as e:
logger.debug(f"JSON parse error (as-is): {e}")
logger.debug(f"Error at position {e.pos if hasattr(e, 'pos') else 'unknown'}")
# Try decoding escape sequences
# Replace literal \n (outside strings) with actual newlines
# This is tricky - we need to handle \n between tokens but preserve \\n in strings
try:
# Use codecs to decode unicode escape sequences
import codecs
decoded_json = codecs.decode(tool_json_str, 'unicode_escape')
logger.debug(f"Decoded JSON (first 200 chars): {decoded_json[:200]}")
......@@ -1493,7 +1516,6 @@ class RotationHandler:
logger.debug(f"Successfully parsed decoded JSON")
except (json.JSONDecodeError, UnicodeDecodeError) as e2:
logger.debug(f"Decoded JSON also failed: {e2}")
# Last resort: try fixing common issues
fixed_json = tool_json_str.replace("'", '"')
fixed_json = re_module.sub(r',\s*}', '}', fixed_json)
fixed_json = re_module.sub(r',\s*]', ']', fixed_json)
......@@ -1503,17 +1525,7 @@ class RotationHandler:
logger.debug(f"Successfully parsed fixed JSON")
except json.JSONDecodeError as e3:
logger.debug(f"Fixed JSON also failed: {e3}")
raise e # Re-raise original error
else:
# No tool call found, but check if the response is wrapped in assistant: [...] format
# and extract just the text content
assistant_text_pattern = r"^assistant:\s*\[.*'type':\s*'text',\s*'text':\s*['\"](.+?)['\"].*\]\s*$"
assistant_text_match = re_module.match(assistant_text_pattern, accumulated_response_text.strip(), re_module.DOTALL)
if assistant_text_match:
final_text = assistant_text_match.group(1)
# Unescape common escape sequences
final_text = final_text.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"')
logger.debug(f"Extracted text from assistant wrapper: {final_text[:100]}...")
raise e
# Convert to OpenAI tool_calls format
tool_calls = [{
......@@ -1525,26 +1537,11 @@ class RotationHandler:
}
}]
logger.info(f"Converted streaming tool call to OpenAI format: {tool_calls}")
# Extract final assistant text after the tool JSON
# Look for pattern: }\\nassistant: [{'type': 'text', 'text': "..."}]
# or just return empty since the tool call is the main content
after_tool = accumulated_response_text[json_end:]
assistant_pattern = r"assistant:\s*\[.*'text':\s*['\"](.+?)['\"].*\]\s*\]?\s*$"
assistant_match = re_module.search(assistant_pattern, after_tool, re_module.DOTALL)
if assistant_match:
final_text = assistant_match.group(1)
# Unescape common escape sequences
final_text = final_text.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"')
else:
final_text = ""
except (json.JSONDecodeError, ValueError, SyntaxError, Exception) as e:
logger.debug(f"Failed to parse tool JSON in streaming: {e}")
# Now send the response chunks
# If we detected tool calls, send them in the first chunk with role
# If tool calls detected, send an additional chunk with the tool call
if tool_calls:
# First chunk with tool_calls
tool_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
......@@ -1559,7 +1556,7 @@ class RotationHandler:
"delta": {
"content": None,
"refusal": None,
"role": "assistant",
"role": None,
"tool_calls": tool_calls
},
"finish_reason": None,
......@@ -1568,61 +1565,8 @@ class RotationHandler:
}]
}
yield f"data: {json.dumps(tool_chunk)}\n\n".encode('utf-8')
# If there's final assistant text, send it
if final_text:
text_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": final_text,
"refusal": None,
"role": None,
"tool_calls": None
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(text_chunk)}\n\n".encode('utf-8')
else:
# No tool calls detected, send text normally
# Send the accumulated text as a single chunk
if accumulated_response_text:
text_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": accumulated_response_text,
"refusal": None,
"role": "assistant",
"tool_calls": None
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(text_chunk)}\n\n".encode('utf-8')
# Send final chunk with finish reason and usage statistics
# Send final chunk with usage statistics
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name)
total_tokens = effective_context + completion_tokens
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment