Commit d9cba7ec authored by Your Name's avatar Your Name

Fix tool extraction and repetition detection

- Add repetition filtering for model output (n-gram detection)
- Improve reasoning extraction to exclude tool call content
- Add JSON validation for extracted tool calls
- Ensure fixes work in both streaming and non-streaming modes
parent 72917a8a
......@@ -15,6 +15,8 @@ from .parser import (
OpenAIFormatter,
ToolCallParser,
ModelParserAdapter,
filter_repetition,
validate_json_complete,
)
from .templates import AgenticTemplateManager
......@@ -36,4 +38,6 @@ __all__ = [
'ToolCallParser',
'ModelParserAdapter',
'AgenticTemplateManager',
'filter_repetition',
'validate_json_complete',
]
......@@ -26,6 +26,7 @@ def extract_reasoning_content(text: str, model_family: str = None) -> Tuple[str,
"""Extract reasoning/thinking content from model output.
Returns tuple of (reasoning_content, clean_text).
The reasoning_content will have any tool call tags stripped out.
"""
reasoning_content = ""
clean_text = text
......@@ -56,6 +57,26 @@ def extract_reasoning_content(text: str, model_family: str = None) -> Tuple[str,
for p in [r'<thought>.*?</thought>', r'<think>.*?</think>']:
clean_text = re.sub(p, '', clean_text, flags=re.DOTALL | re.IGNORECASE)
# FIX: If reasoning contains tool call tags, split at the first tool tag
# The tool call part should NOT be in reasoning - it should be left in clean_text for tool extraction
if reasoning_content:
tool_tag_patterns = ["<tool_call>", "<tool>", "<|tool_call|", "<function="]
earliest_tool_idx = len(reasoning_content)
earliest_tool_tag = None
for tag in tool_tag_patterns:
idx = reasoning_content.find(tag)
if idx != -1 and idx < earliest_tool_idx:
earliest_tool_idx = idx
earliest_tool_tag = tag
if earliest_tool_tag:
# Split: everything before the tool tag is reasoning, tool part goes back to clean_text
tool_part = reasoning_content[earliest_tool_idx:]
reasoning_content = reasoning_content[:earliest_tool_idx].strip()
# Prepend the tool part to clean_text so it can be extracted as a tool call
clean_text = tool_part + " " + clean_text
clean_text = clean_text.strip()
return reasoning_content, clean_text
......@@ -161,6 +182,24 @@ class QwenParser(BaseParser):
@validate_tool_output
def parse(self, text: str) -> List[Dict]:
# 0. PRE-VALIDATION: Check if text looks like reasoning output
# If text contains thinking/reasoning tags, extract only the content after them
# This prevents parsing partial tool calls from reasoning blocks
thinking_pattern = r'<\|.*?\|>|<(?:thought|think)>.*?((?:</(?:thought|think)>)|$)|<\|begin.*?\|><\|end.*?\|>'
has_thinking = re.search(thinking_pattern, text, flags=re.IGNORECASE)
# If text has thinking tags, check if there's actual content after them
if has_thinking:
# Find the last thinking tag position
thinking_matches = list(re.finditer(thinking_pattern, text, flags=re.DOTALL | re.IGNORECASE))
if thinking_matches:
last_think_end = thinking_matches[-1].end()
content_after_thinking = text[last_think_end:].strip()
# If there's no meaningful content after thinking, return empty
if not content_after_thinking or len(content_after_thinking) < 5:
print(f"DEBUG QwenParser: Text appears to be reasoning only, no content after thinking tags")
return []
# 1. IMMEDIATE REPETITION GUARD
# If the model is looping the same tag, we only care about the first one.
if text.count('<tool') > 1:
......@@ -218,6 +257,11 @@ class QwenParser(BaseParser):
if json_str.startswith('{') and not json_str.endswith('}'):
json_str += '}'
# Validate JSON is complete before accepting
if not validate_json_complete(json_str):
print(f"DEBUG QwenParser: JSON appears incomplete, skipping: {json_str[:50]}...")
continue
try:
data = json.loads(json_str)
if 'name' in data:
......@@ -773,6 +817,9 @@ def filter_malformed_content(text: str) -> str:
if not text:
return text
# Apply repetition filtering first
text = filter_repetition(text)
# Remove diff-like blocks that shouldn't be in the output
filtered = text
......@@ -794,6 +841,182 @@ def filter_malformed_content(text: str) -> str:
return filtered
def filter_repetition(text: str, min_repeat_count: int = 3, ngram_sizes: tuple = (2, 3)) -> str:
"""
Detect and remove n-gram repetition from text.
This function looks for sequences of 2-3 words that are repeated 3 or more times
consecutively (like "does does does" or "the the the the") and removes the duplicates.
Args:
text: The input text to filter
min_repeat_count: Minimum number of repetitions to trigger removal (default: 3)
ngram_sizes: Tuple of n-gram sizes to check (default: (2, 3))
Returns:
Text with repetition removed
"""
if not text or len(text) < 10:
return text
import re
# Split into words while preserving whitespace for reconstruction
# Use a regex that captures words and the whitespace between them
parts = re.split(r'(\s+)', text)
words = []
for i, part in enumerate(parts):
if i % 2 == 0:
# Even indices are text content
words.append(part)
else:
# Odd indices are whitespace - attach to previous word
if words:
words[-1] = words[-1] + part
if not words:
return text
# Convert to list of (word, is_word) tuples to track what to keep
result = []
i = 0
while i < len(words):
word = words[i]
# Check if this is a word (contains non-whitespace)
is_word = bool(word.strip())
if not is_word:
# Keep whitespace as-is
result.append(word)
i += 1
continue
# Try each n-gram size
found_repetition = False
for ngram_size in ngram_sizes:
if i + ngram_size * min_repeat_count > len(words):
continue
# Build the n-gram sequence to check
ngram_parts = []
valid = True
for j in range(ngram_size):
idx = i + j
if idx >= len(words):
valid = False
break
# Get the word part only (strip whitespace)
w = words[idx].strip()
if not w:
valid = False
break
ngram_parts.append(w)
if not valid or len(ngram_parts) != ngram_size:
continue
# Check if this n-gram repeats
ngram_str = ' '.join(ngram_parts)
repeat_count = 1
# Count consecutive repetitions
check_idx = i
while check_idx + ngram_size * (repeat_count + 1) <= len(words):
# Check if next n-gram matches
next_ngram = []
for j in range(ngram_size):
idx = check_idx + ngram_size * (repeat_count + 1) + j
if idx >= len(words):
break
w = words[idx].strip()
if not w:
break
next_ngram.append(w)
if next_ngram == ngram_parts:
repeat_count += 1
check_idx = check_idx + ngram_size
else:
break
# If we found enough repetitions, remove duplicates
if repeat_count >= min_repeat_count:
# Keep only the first occurrence
for j in range(ngram_size):
result.append(words[i + j])
# Skip all the repeated n-grams
i += ngram_size * repeat_count
found_repetition = True
break
if not found_repetition:
result.append(word)
i += 1
return ''.join(result)
def validate_json_complete(json_str: str) -> bool:
"""
Validate that a JSON string is complete (not truncated).
Checks for:
- Balanced braces and brackets
- No unclosed strings
- Valid structure
Args:
json_str: The JSON string to validate
Returns:
True if JSON appears complete, False if it appears truncated
"""
if not json_str:
return False
json_str = json_str.strip()
# Check if it starts with { or [
if not (json_str.startswith('{') or json_str.startswith('[')):
return False
# Try to parse it
try:
json.loads(json_str)
return True
except json.JSONDecodeError as e:
# Check if the error is due to truncation vs. syntax error
error_msg = str(e)
# Common truncation errors
if 'Expecting' in error_msg and ('property name' in error_msg or 'value' in error_msg or 'string' in error_msg):
# This is likely truncated - we got cut off in the middle
return False
# If we have a valid start but missing end, it's truncated
if json_str.endswith(',') or json_str.endswith(':'):
return False
# Check for unclosed braces/brackets
open_braces = json_str.count('{')
close_braces = json_str.count('}')
open_brackets = json_str.count('[')
close_brackets = json_str.count(']')
if open_braces > close_braces or open_brackets > close_brackets:
return False
# Try again - if it still fails, it's a syntax error
try:
json.loads(json_str)
return True
except:
return False
# =============================================================================
# Tool Formatting
# =============================================================================
......
......@@ -30,6 +30,7 @@ from threading import Thread
# Import codai module for enhanced tool call parsing
from codai.models import ModelParserDispatcher, OpenAIFormatter, ToolCallParser, ModelParserAdapter
from codai.models.parser import filter_repetition
# Import from codai modules for use in this file
from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager
......@@ -2380,6 +2381,30 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
**extra_params,
)
# FIX: Apply repetition filtering to both reasoning and final text
reasoning_text = filter_repetition(reasoning_text)
second_pass_result = filter_repetition(second_pass_result)
# FIX: If reasoning contains tool call tags, split at the first tool tag
# The tool call part should NOT be in reasoning - it should be left for tool extraction
tool_tag_patterns = ["<tool_call>", "<tool>", "<|tool_call|", "<function="]
earliest_tool_idx = len(reasoning_text)
earliest_tool_tag = None
for tag in tool_tag_patterns:
idx = reasoning_text.find(tag)
if idx != -1 and idx < earliest_tool_idx:
earliest_tool_idx = idx
earliest_tool_tag = tag
if earliest_tool_tag:
# Split: everything before the tool tag is reasoning, everything from the tag onwards goes to second_pass_result
tool_part = reasoning_text[earliest_tool_idx:]
reasoning_text = reasoning_text[:earliest_tool_idx].strip()
# Prepend the tool part to second_pass_result so it can be extracted as a tool call
second_pass_result = tool_part + second_pass_result
if global_debug:
print(f"DEBUG: Moved tool call from reasoning to second_pass_result: {tool_part[:100]}...")
# In debug mode, dump the full generated text (second pass result)
if global_debug:
print(f"\n{'='*80}")
......@@ -2397,21 +2422,17 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
print(reasoning_text)
print(f"{'='*80}\n")
# Try to extract tool calls from the second pass result
# If second pass is empty, try the reasoning text as fallback
# Try to extract tool calls from the second pass result ONLY
# FIX: Do NOT fall back to reasoning text - tool calls should only come from final response
extracted_tool_calls = None
text_for_tool_extraction = second_pass_result
# If second pass is empty or just whitespace, try reasoning text
if not text_for_tool_extraction or not text_for_tool_extraction.strip():
if global_debug:
print(f"DEBUG: Second pass result is empty, trying reasoning text")
print(f"DEBUG: Reasoning text length: {len(reasoning_text)}")
print(f"DEBUG: Reasoning text preview: {reasoning_text[:200] if reasoning_text else 'empty'}")
text_for_tool_extraction = reasoning_text
# CRITICAL: Only extract from second pass, never from reasoning
# Reasoning may contain partial/incomplete tool calls that confuse the parser
if global_debug:
print(f"DEBUG: Final text for tool extraction: {text_for_tool_extraction[:200] if text_for_tool_extraction else 'empty'}")
print(f"DEBUG: Tool extraction - using second_pass_result only")
print(f"DEBUG: Second pass result length: {len(second_pass_result) if second_pass_result else 0}")
print(f"DEBUG: Reasoning text length: {len(reasoning_text) if reasoning_text else 0}")
if request.tools and text_for_tool_extraction:
# Convert tools for ModelParserAdapter
......@@ -2443,6 +2464,23 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
adapter = ModelParserAdapter(model_name=response_model_name)
extracted_tool_calls = adapter.extract_tool_calls(text_for_tool_extraction, tools_list)
# FIX: Validate extracted tool calls have valid JSON
if extracted_tool_calls:
from codai.models.parser import validate_json_complete
validated_calls = []
for tc in extracted_tool_calls:
args = tc.get('function', {}).get('arguments', '{}')
if isinstance(args, str) and validate_json_complete(args):
validated_calls.append(tc)
elif isinstance(args, dict):
# Dict is already valid
validated_calls.append(tc)
if len(validated_calls) != len(extracted_tool_calls):
if global_debug:
print(f"DEBUG: Filtered out {len(extracted_tool_calls) - len(validated_calls)} invalid tool calls")
extracted_tool_calls = validated_calls if validated_calls else None
if global_debug and extracted_tool_calls:
print(f"\n{'='*80}")
print(f"=== RAW STREAM: EXTRACTED TOOL CALLS (DEBUG) ===")
......@@ -2534,6 +2572,30 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
# Clean up control tokens from final text
final_text = cleanup_control_tokens(final_text)
# FIX: Apply repetition filtering to reasoning and final text
reasoning_text = filter_repetition(reasoning_text)
final_text = filter_repetition(final_text)
# FIX: If reasoning contains tool call tags, split at the first tool tag
# The tool call part should NOT be in reasoning - it should be left for tool extraction in final_text
tool_tag_patterns = ["<tool_call>", "<tool>", "<|tool_call|", "<function="]
earliest_tool_idx = len(reasoning_text)
earliest_tool_tag = None
for tag in tool_tag_patterns:
idx = reasoning_text.find(tag)
if idx != -1 and idx < earliest_tool_idx:
earliest_tool_idx = idx
earliest_tool_tag = tag
if earliest_tool_tag:
# Split: everything before the tool tag is reasoning, everything from the tag onwards goes to final_text
tool_part = reasoning_text[earliest_tool_idx:]
reasoning_text = reasoning_text[:earliest_tool_idx].strip()
# Prepend the tool part to final_text so it can be extracted as a tool call
final_text = tool_part + final_text
if global_debug:
print(f"RAW: Moved tool call from reasoning to final_text: {tool_part[:100]}...")
if global_debug:
print(f"RAW: Final text after cleanup: {final_text[:100]}...")
......@@ -2607,6 +2669,22 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
# Extract tool calls from final_text only (after reasoning is done)
extracted_tool_calls = adapter.extract_tool_calls(final_text, tools_list)
# FIX: Validate extracted tool calls have valid JSON
if extracted_tool_calls:
from codai.models.parser import validate_json_complete
validated_calls = []
for tc in extracted_tool_calls:
args = tc.get('function', {}).get('arguments', '{}')
if isinstance(args, str) and validate_json_complete(args):
validated_calls.append(tc)
elif isinstance(args, dict):
# Dict is already valid
validated_calls.append(tc)
if len(validated_calls) != len(extracted_tool_calls):
print(f"DEBUG: Filtered out {len(extracted_tool_calls) - len(validated_calls)} invalid tool calls in non-streaming")
extracted_tool_calls = validated_calls if validated_calls else None
if extracted_tool_calls:
# Strip tool calls from the text
clean_text = adapter.strip_tool_calls_from_content(final_text)
......@@ -2898,6 +2976,9 @@ async def stream_chat_response(
# Always filter malformed content
filtered_chunk = filter_malformed_content(chunk)
# Apply repetition filtering to prevent infinite loops
filtered_chunk = filter_repetition(filtered_chunk)
# Always filter out tool call format
filtered_chunk = tool_parser.strip_tool_calls_from_content(filtered_chunk)
......@@ -2965,6 +3046,20 @@ async def stream_chat_response(
continue
try:
tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
# FIX: Validate extracted tool calls have valid JSON (stream_chat_response)
if tool_calls:
from codai.models.parser import validate_json_complete
validated_calls = []
for tc in tool_calls:
args = tc.get('function', {}).get('arguments', '{}')
if isinstance(args, str) and validate_json_complete(args):
validated_calls.append(tc)
elif isinstance(args, dict):
validated_calls.append(tc)
if len(validated_calls) != len(tool_calls):
print(f"DEBUG: Filtered out {len(tool_calls) - len(validated_calls)} invalid tool calls in stream_chat_response")
tool_calls = validated_calls if validated_calls else None
except Exception as e:
print(f"DEBUG: Error extracting tool calls: {e}")
tool_calls = None
......@@ -3118,6 +3213,9 @@ async def generate_chat_response(
# Always filter out malformed content
generated_text = filter_malformed_content(generated_text)
# Apply repetition filtering to prevent infinite loops
generated_text = filter_repetition(generated_text)
# Dump raw output if enabled
if global_dump:
print(f"\n{'='*80}")
......@@ -3160,6 +3258,20 @@ async def generate_chat_response(
continue
try:
tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
# FIX: Validate extracted tool calls have valid JSON (generate_chat_response)
if tool_calls:
from codai.models.parser import validate_json_complete
validated_calls = []
for tc in tool_calls:
args = tc.get('function', {}).get('arguments', '{}')
if isinstance(args, str) and validate_json_complete(args):
validated_calls.append(tc)
elif isinstance(args, dict):
validated_calls.append(tc)
if len(validated_calls) != len(tool_calls):
print(f"DEBUG: Filtered out {len(tool_calls) - len(validated_calls)} invalid tool calls in generate_chat_response")
tool_calls = validated_calls if validated_calls else None
except Exception as e:
print(f"DEBUG: Error extracting tool calls: {e}")
tool_calls = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment