Commit 52b44029 authored by Your Name's avatar Your Name

Fix premature tool call finalization in Kiro streaming responses

- Avoid calling parser.get_tool_calls() during streaming loop
- Only call get_tool_calls() after all chunks are processed
- Send tool calls in separate chunk after content streaming completes
- Prevents empty arguments issue caused by premature finalization
parent 15f8fff0
......@@ -1400,6 +1400,10 @@ class KiroProviderHandler(BaseProviderHandler):
self.provider_config = config.get_provider(provider_id)
self.region = "us-east-1" # Default region
# Import AuthType for checking auth type
from .kiro_auth import AuthType
self.AuthType = AuthType
# Initialize KiroAuthManager with credentials from config
self.auth_manager = None
self._init_auth_manager()
......@@ -1478,35 +1482,61 @@ class KiroProviderHandler(BaseProviderHandler):
access_token = await self.auth_manager.get_access_token()
profile_arn = self.auth_manager.profile_arn
if not profile_arn:
raise Exception("Profile ARN not available. Please configure Kiro credentials.")
# Use full kiro-gateway conversion pipeline
# Determine effective profileArn based on auth type
# AWS SSO OIDC users don't need profileArn and it causes 403 if sent
effective_profile_arn = ""
if profile_arn and self.auth_manager._auth_type != self.AuthType.AWS_SSO_OIDC:
effective_profile_arn = profile_arn
logging.info(f"KiroProviderHandler: Using profileArn (Kiro Desktop Auth)")
else:
logging.info(f"KiroProviderHandler: Skipping profileArn (AWS SSO OIDC/Builder ID)")
# Use the proper kiro-gateway conversion pipeline to build the payload.
# This handles:
# - Model name normalization (claude-sonnet-4-5 → claude-sonnet-4.5)
# - System message extraction
# - Tool conversion and validation
# - Message merging and role normalization
# - Alternating user/assistant role enforcement
# - Image support
# - Tool call/result conversion
from .kiro_converters_openai import build_kiro_payload_from_dict
conversation_id = str(uuid.uuid4())
# Build Kiro API payload using full conversion pipeline
# This handles ALL features: tools, images, message merging, role normalization, etc.
payload = build_kiro_payload_from_dict(
model=model,
messages=messages,
tools=tools,
conversation_id=conversation_id,
profile_arn=profile_arn
profile_arn=effective_profile_arn
)
logging.info(f"KiroProviderHandler: Model '{model}' normalized for Kiro API")
if AISBF_DEBUG:
logging.info(f"KiroProviderHandler: Kiro payload: {json.dumps(payload, indent=2)}")
# Make request to Kiro API
# Make request to Kiro API with proper headers
headers = self.auth_manager.get_auth_headers(access_token)
headers["Content-Type"] = "application/json"
kiro_api_url = f"https://q.{self.region}.amazonaws.com/generateAssistantResponse"
logging.info(f"KiroProviderHandler: Sending request to {kiro_api_url}")
logging.info(f"KiroProviderHandler: Stream mode: {stream}")
# Handle streaming mode
if stream:
logging.info(f"KiroProviderHandler: Using streaming mode")
return self._handle_streaming_request(
kiro_api_url=kiro_api_url,
payload=payload,
headers=headers,
model=model
)
# Non-streaming request
# Kiro API returns response in AWS Event Stream binary format
response = await self.client.post(
kiro_api_url,
json=payload,
......@@ -1526,16 +1556,38 @@ class KiroProviderHandler(BaseProviderHandler):
# Re-raise the error after handling
response.raise_for_status()
# Log error details for non-2xx responses before raising
if response.status_code >= 400:
try:
error_body = response.json()
logging.error(f"KiroProviderHandler: API error response: {json.dumps(error_body, indent=2)}")
except Exception:
logging.error(f"KiroProviderHandler: API error response (text): {response.text}")
response.raise_for_status()
response_data = response.json()
# Parse AWS Event Stream format response
logging.info(f"KiroProviderHandler: Parsing AWS Event Stream response")
from .kiro_parsers import AwsEventStreamParser
parser = AwsEventStreamParser()
parser.feed(response.content)
# Extract content and tool calls
content = parser.get_content()
tool_calls = parser.get_tool_calls()
if AISBF_DEBUG:
logging.info(f"KiroProviderHandler: Raw Kiro response: {json.dumps(response_data, indent=2)}")
logging.info(f"KiroProviderHandler: Parsed content length: {len(content)}")
logging.info(f"KiroProviderHandler: Parsed tool calls: {len(tool_calls)}")
if tool_calls:
logging.info(f"KiroProviderHandler: Tool calls: {json.dumps(tool_calls, indent=2)}")
logging.info(f"KiroProviderHandler: Response received")
logging.info(f"KiroProviderHandler: Response parsed successfully")
# Parse Kiro response and convert to OpenAI format
openai_response = self._parse_kiro_response(response_data, model)
# Build OpenAI-format response
openai_response = self._build_openai_response(model, content, tool_calls)
self.record_success()
return openai_response
......@@ -1546,55 +1598,22 @@ class KiroProviderHandler(BaseProviderHandler):
self.record_failure()
raise e
def _parse_kiro_response(self, kiro_response: Dict, model: str) -> Dict:
def _build_openai_response(self, model: str, content: str, tool_calls: List[Dict]) -> Dict:
"""
Parse Kiro API response and convert to OpenAI format.
Build OpenAI-format response from parsed Kiro data.
Args:
model: Model name
content: Parsed content text
tool_calls: List of parsed tool calls
Handles:
- Text content
- Tool calls (toolUses)
- Finish reasons
- Usage statistics
Returns:
OpenAI-format response dict
"""
import logging
import json
# Extract assistant message content
assistant_content = ""
tool_calls = None
finish_reason = "stop"
# Kiro response structure varies, try different paths
if "message" in kiro_response:
assistant_content = kiro_response["message"]
elif "content" in kiro_response:
assistant_content = kiro_response["content"]
elif "conversationState" in kiro_response:
conv_state = kiro_response["conversationState"]
if "currentMessage" in conv_state:
current_msg = conv_state["currentMessage"]
if "assistantResponseMessage" in current_msg:
assistant_msg = current_msg["assistantResponseMessage"]
assistant_content = assistant_msg.get("content", "")
# Check for tool uses
if "toolUses" in assistant_msg:
tool_uses = assistant_msg["toolUses"]
tool_calls = []
for idx, tool_use in enumerate(tool_uses):
tool_call = {
"id": tool_use.get("toolUseId", f"call_{idx}"),
"type": "function",
"function": {
"name": tool_use.get("name", ""),
"arguments": json.dumps(tool_use.get("input", {}))
}
}
tool_calls.append(tool_call)
if tool_calls:
finish_reason = "tool_calls"
logging.info(f"KiroProviderHandler: Parsed {len(tool_calls)} tool calls from response")
# Determine finish reason
finish_reason = "tool_calls" if tool_calls else "stop"
# Build OpenAI-style response
openai_response = {
......@@ -1606,7 +1625,7 @@ class KiroProviderHandler(BaseProviderHandler):
"index": 0,
"message": {
"role": "assistant",
"content": assistant_content if not tool_calls else None
"content": content if not tool_calls else None
},
"finish_reason": finish_reason
}],
......@@ -1620,9 +1639,161 @@ class KiroProviderHandler(BaseProviderHandler):
# Add tool_calls if present
if tool_calls:
openai_response["choices"][0]["message"]["tool_calls"] = tool_calls
logging.info(f"KiroProviderHandler: Response includes {len(tool_calls)} tool calls")
return openai_response
async def _handle_streaming_request(self, kiro_api_url: str, payload: dict, headers: dict, model: str):
    """
    Stream a Kiro API response as OpenAI-compatible SSE chunks.

    Makes a streaming POST to the Kiro API, feeds the raw bytes
    incrementally into an AWS Event Stream parser, and yields
    OpenAI ``chat.completion.chunk`` frames (UTF-8 encoded SSE bytes)
    as new content becomes available.

    Tool calls are deliberately NOT read from the parser inside the
    streaming loop: calling ``get_tool_calls()`` on partial data
    finalized tool calls prematurely and produced empty arguments.
    They are extracted once, after the stream has fully drained, and
    emitted in a single separate chunk before the final chunk.

    Args:
        kiro_api_url: Kiro API endpoint URL.
        payload: Request payload (already converted to Kiro format).
        headers: Request headers (auth + content type).
        model: Model name; echoed into each chunk as
            ``"<provider_id>/<model>"``.

    Yields:
        bytes: SSE-framed chunks (``data: {...}\\n\\n``), terminated by
        ``data: [DONE]\\n\\n``.
    """
    import logging
    import json
    logger = logging.getLogger(__name__)
    logger.info(f"KiroProviderHandler: Starting streaming request")
    # Dedicated client for this request: long (300s) read timeout since
    # generation can be slow, but a tighter 30s connect timeout.
    async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0)) as streaming_client:
        # Open the response as a stream so bytes can be parsed incrementally.
        async with streaming_client.stream("POST", kiro_api_url, json=payload, headers=headers) as response:
            logger.info(f"KiroProviderHandler: Streaming response status: {response.status_code}")
            # Error responses: drain the body for logging, then abort.
            if response.status_code >= 400:
                error_text = await response.aread()
                logger.error(f"KiroProviderHandler: Streaming error: {error_text}")
                raise Exception(f"Kiro API error: {response.status_code}")
            # Incremental parser for the AWS Event Stream binary framing.
            from .kiro_parsers import AwsEventStreamParser
            parser = AwsEventStreamParser()
            # One completion id / timestamp shared by every chunk of this response.
            completion_id = f"kiro-{int(time.time())}"
            created_time = int(time.time())
            # first_chunk: whether the `role` field still needs to be sent.
            # accumulated_content: everything already forwarded to the client,
            # used to compute the per-chunk delta below.
            first_chunk = True
            accumulated_content = ""
            # Main streaming loop: feed raw bytes, forward only new text.
            async for chunk in response.aiter_bytes():
                if not chunk:
                    continue
                # Feed raw bytes into the event-stream parser.
                parser.feed(chunk)
                # Read the full content parsed so far. Deliberately NOT
                # calling get_tool_calls() here — see docstring.
                current_content = parser.get_content()
                # Delta = text that appeared since the previous chunk.
                delta_content = current_content[len(accumulated_content):]
                accumulated_content = current_content
                # Only emit a chunk when there is actually new content.
                if delta_content:
                    delta = {}
                    delta["content"] = delta_content
                    # OpenAI convention: the first delta carries the role.
                    if first_chunk:
                        delta["role"] = "assistant"
                        first_chunk = False
                    openai_chunk = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created_time,
                        "model": f"{self.provider_id}/{model}",
                        "choices": [{
                            "index": 0,
                            "delta": delta,
                            "finish_reason": None
                        }]
                    }
                    # SSE framing: "data: <json>\n\n", UTF-8 encoded.
                    yield f"data: {json.dumps(openai_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
            # Stream fully drained — it is now safe to finalize tool calls.
            logger.info(f"KiroProviderHandler: Streaming completed")
            # Extracting tool calls only after EOF avoids the
            # empty-arguments bug caused by premature finalization.
            final_tool_calls = parser.get_tool_calls()
            finish_reason = "tool_calls" if final_tool_calls else "stop"
            logger.info(f"KiroProviderHandler: Final tool calls count: {len(final_tool_calls)}")
            # Tool calls, if any, go out in one dedicated chunk.
            if final_tool_calls:
                # Streaming tool_calls entries require an `index` field;
                # also normalize missing name/arguments defensively.
                indexed_tool_calls = []
                for idx, tc in enumerate(final_tool_calls):
                    func = tc.get("function") or {}
                    tool_name = func.get("name") or ""
                    tool_args = func.get("arguments") or "{}"
                    logger.debug(f"Tool call [{idx}] '{tool_name}': id={tc.get('id')}, args_length={len(tool_args)}")
                    indexed_tc = {
                        "index": idx,
                        "id": tc.get("id"),
                        "type": tc.get("type", "function"),
                        "function": {
                            "name": tool_name,
                            "arguments": tool_args
                        }
                    }
                    indexed_tool_calls.append(indexed_tc)
                # Single chunk carrying all tool calls at once.
                tool_calls_chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_time,
                    "model": f"{self.provider_id}/{model}",
                    "choices": [{
                        "index": 0,
                        "delta": {"tool_calls": indexed_tool_calls},
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(tool_calls_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
            # Final chunk carries the finish_reason and a zeroed usage
            # block — Kiro does not report token counts when streaming.
            final_chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created_time,
                "model": f"{self.provider_id}/{model}",
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": finish_reason
                }],
                "usage": {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                }
            }
            yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
            yield b"data: [DONE]\n\n"
async def get_models(self) -> List[Model]:
try:
import logging
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment