fix: Handle assistant wrapper pattern in streaming responses

- Detect and unwrap responses wrapped in 'assistant: [{"type": "text", "text": "..."}]' format
- Use extracted text for response content instead of raw accumulated text
- Fix variable scoping issue with tool_match
- Update token counting to use final_text when available
parent b307c7fb
......@@ -1370,16 +1370,15 @@ class RotationHandler:
# Track completion tokens for Google responses (since Google doesn't provide them)
completion_tokens = 0
accumulated_response_text = "" # Track full response for token counting and tool detection
accumulated_response_text = "" # Track full response for token counting
# Collect all chunks first (needed for Google's accumulated text format)
# Collect all chunks first to know when we're at the last one
chunks_list = []
async for chunk in response:
chunks_list.append(chunk)
total_chunks = len(chunks_list)
chunk_idx = 0
sent_first_chunk = False # Track if we've sent the first chunk with role
for chunk in chunks_list:
try:
......@@ -1411,43 +1410,9 @@ class RotationHandler:
delta_text = chunk_text[len(accumulated_text):] if chunk_text.startswith(accumulated_text) else chunk_text
accumulated_text = chunk_text # Update accumulated text for next iteration
# Track full response for tool detection
if chunk_text:
accumulated_response_text = chunk_text
# Check if this is the last chunk
is_last_chunk = (chunk_idx == total_chunks - 1)
chunk_finish_reason = finish_reason if is_last_chunk else None
# Only send if there's new content or it's the last chunk
if delta_text or is_last_chunk:
# Send streaming chunks normally (with deltas)
openai_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": delta_text if delta_text else "",
"refusal": None,
"role": "assistant" if not sent_first_chunk else None,
"tool_calls": None
},
"finish_reason": chunk_finish_reason,
"logprobs": None,
"native_finish_reason": chunk_finish_reason
}]
}
sent_first_chunk = True
chunk_id += 1
logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})")
yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8')
# Track completion tokens for Google responses
if delta_text:
accumulated_response_text += delta_text
chunk_idx += 1
except Exception as chunk_error:
......@@ -1458,9 +1423,10 @@ class RotationHandler:
chunk_idx += 1
continue
# AFTER streaming all chunks, check if the accumulated text contains a tool call pattern
# If detected, send an additional chunk with the tool call
# After collecting all chunks, check if the accumulated text contains a tool call pattern
# This handles models that return tool calls as text instead of using function_call attributes
tool_calls = None
final_text = accumulated_response_text
logger.debug(f"=== ACCUMULATED RESPONSE TEXT ===")
logger.debug(f"Total length: {len(accumulated_response_text)}")
......@@ -1471,7 +1437,26 @@ class RotationHandler:
if accumulated_response_text:
import re as re_module
# Initialize tool_match to None
tool_match = None
# First, check if the response is wrapped in "assistant: [{'type': 'text', 'text': '...'}]"
# This is a common pattern where the model wraps its response instead of returning plain text
assistant_wrapper_pattern = r"^assistant:\s*\[\s*\{\s*'type':\s*'text',\s*'text':\s*['\"](.+?)['\"]\s*\}\s*\]\s*$"
assistant_wrapper_match = re_module.match(assistant_wrapper_pattern, accumulated_response_text.strip(), re_module.DOTALL)
if assistant_wrapper_match:
# Extract the plain text from the wrapper
extracted_text = assistant_wrapper_match.group(1)
# Unescape common escape sequences
extracted_text = extracted_text.replace("\\'", "'").replace('\\"', '"').replace("\\n", "\n")
logger.debug(f"Extracted text from assistant wrapper: {extracted_text[:200]}...")
final_text = extracted_text
tool_calls = None
else:
# Check for tool call pattern
# Simple approach: just look for "tool: {...}" pattern and extract the JSON
# This avoids complex nested parsing issues
tool_pattern = r'tool:\s*(\{[^{}]*\{[^{}]*\}[^{}]*\}|\{[^{}]+\})'
tool_match = re_module.search(tool_pattern, accumulated_response_text, re_module.DOTALL)
......@@ -1502,20 +1487,13 @@ class RotationHandler:
logger.debug(f"First 20 bytes (repr): {repr(tool_json_str[:20])}")
logger.debug(f"ASCII codes for first 20 chars: {[ord(c) for c in tool_json_str[:20]]}")
# Try parsing with unicode escape decoding
try:
parsed_tool = json.loads(tool_json_str)
logger.debug(f"Successfully parsed tool JSON as-is")
logger.debug(f"Successfully parsed tool JSON")
except json.JSONDecodeError as e:
logger.debug(f"JSON parse error (as-is): {e}")
try:
import codecs
decoded_json = codecs.decode(tool_json_str, 'unicode_escape')
logger.debug(f"Decoded JSON (first 200 chars): {decoded_json[:200]}")
parsed_tool = json.loads(decoded_json)
logger.debug(f"Successfully parsed decoded JSON")
except (json.JSONDecodeError, UnicodeDecodeError) as e2:
logger.debug(f"Decoded JSON also failed: {e2}")
logger.debug(f"JSON parse error: {e}")
logger.debug(f"Error at position {e.pos if hasattr(e, 'pos') else 'unknown'}")
# Try fixing common issues: single quotes, trailing commas
fixed_json = tool_json_str.replace("'", '"')
fixed_json = re_module.sub(r',\s*}', '}', fixed_json)
fixed_json = re_module.sub(r',\s*]', ']', fixed_json)
......@@ -1523,9 +1501,9 @@ class RotationHandler:
try:
parsed_tool = json.loads(fixed_json)
logger.debug(f"Successfully parsed fixed JSON")
except json.JSONDecodeError as e3:
logger.debug(f"Fixed JSON also failed: {e3}")
raise e
except json.JSONDecodeError as e2:
logger.debug(f"Fixed JSON also failed: {e2}")
raise e # Re-raise original error
# Convert to OpenAI tool_calls format
tool_calls = [{
......@@ -1537,11 +1515,26 @@ class RotationHandler:
}
}]
logger.info(f"Converted streaming tool call to OpenAI format: {tool_calls}")
# Extract final assistant text after the tool JSON
# Look for pattern: }\\nassistant: [{'type': 'text', 'text': "..."}]
# or just return empty since the tool call is the main content
after_tool = accumulated_response_text[json_end:]
assistant_pattern = r"assistant:\s*\[.*'text':\s*['\"](.+?)['\"].*\]\s*\]?\s*$"
assistant_match = re_module.search(assistant_pattern, after_tool, re_module.DOTALL)
if assistant_match:
final_text = assistant_match.group(1)
# Unescape common escape sequences
final_text = final_text.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"')
else:
final_text = ""
except (json.JSONDecodeError, ValueError, SyntaxError, Exception) as e:
logger.debug(f"Failed to parse tool JSON in streaming: {e}")
# If tool calls detected, send an additional chunk with the tool call
# Now send the response chunks
# If we detected tool calls, send them in the first chunk with role
if tool_calls:
# First chunk with tool_calls
tool_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
......@@ -1556,7 +1549,7 @@ class RotationHandler:
"delta": {
"content": None,
"refusal": None,
"role": None,
"role": "assistant",
"tool_calls": tool_calls
},
"finish_reason": None,
......@@ -1566,9 +1559,64 @@ class RotationHandler:
}
yield f"data: {json.dumps(tool_chunk)}\n\n".encode('utf-8')
# Send final chunk with usage statistics
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name)
# If there's final assistant text, send it
if final_text:
text_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": final_text,
"refusal": None,
"role": None,
"tool_calls": None
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(text_chunk)}\n\n".encode('utf-8')
else:
# No tool calls detected, send text normally
# Send the final text (which may have been extracted from wrapper)
if final_text:
text_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model_name,
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": None,
"provider": provider_id,
"choices": [{
"index": 0,
"delta": {
"content": final_text,
"refusal": None,
"role": "assistant",
"tool_calls": None
},
"finish_reason": None,
"logprobs": None,
"native_finish_reason": None
}]
}
yield f"data: {json.dumps(text_chunk)}\n\n".encode('utf-8')
# Send final chunk with finish reason and usage statistics
# Use final_text for token counting (which may have been extracted from wrapper)
text_for_token_count = final_text if final_text else accumulated_response_text
if text_for_token_count:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": text_for_token_count}], model_name)
total_tokens = effective_context + completion_tokens
final_chunk = {
"id": response_id,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment