Commit b3f747ba authored by Your Name's avatar Your Name

feat: Add Phase 1 and 2 improvements for Claude provider

Phase 1.2 - Automatic retry with exponential backoff:
- Add _request_with_retry() method for non-streaming requests
- Retries on 429 (with x-should-retry header), 529, 503 errors
- Exponential backoff with jitter (1s, 2s, 4s, capped at 30s)
- Handles timeouts and HTTP errors gracefully

Phase 1.3 - Streaming idle watchdog:
- Add 90s idle timeout detection (matches vendors/claude)
- Tracks last_event_time and raises TimeoutError on idle
- Prevents indefinite hangs on dropped connections

Phase 2.3 - Cache token tracking:
- Add cache_stats dict to track cache hits/misses
- Track cache_tokens_read and cache_tokens_created
- Add get_cache_stats() method for analytics
- Updates stats during streaming message_delta events

Also includes:
- Temperature fix (skip 0.0 when thinking beta active)
- Rate limit config update (5s default for Claude)
parent 63dc05af
...@@ -2328,6 +2328,18 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -2328,6 +2328,18 @@ class ClaudeProviderHandler(BaseProviderHandler):
# HTTP client for direct API requests (kilocode method) # HTTP client for direct API requests (kilocode method)
self.client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0)) self.client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
# Streaming idle watchdog configuration (Phase 1.3)
self.stream_idle_timeout = 90.0 # seconds - matches vendors/claude
# Cache token tracking for analytics (Phase 2.3)
self.cache_stats = {
'cache_hits': 0,
'cache_misses': 0,
'cache_tokens_read': 0,
'cache_tokens_created': 0,
'total_requests': 0,
}
def _get_auth_headers(self, stream: bool = False): def _get_auth_headers(self, stream: bool = False):
""" """
Get HTTP headers with OAuth2 Bearer token. Get HTTP headers with OAuth2 Bearer token.
...@@ -3307,13 +3319,68 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -3307,13 +3319,68 @@ class ClaudeProviderHandler(BaseProviderHandler):
logger.info(f"Payload: {json.dumps(payload, indent=2)}") logger.info(f"Payload: {json.dumps(payload, indent=2)}")
logger.info(f"=== END NON-STREAMING REQUEST DEBUG ===") logger.info(f"=== END NON-STREAMING REQUEST DEBUG ===")
# Non-streaming request # Non-streaming request with automatic retry (Phase 1.2)
response = await self._request_with_retry(api_url, headers, payload, max_retries=3)
claude_response = response.json()
if AISBF_DEBUG:
logger.info(f"ClaudeProviderHandler: API response: {json.dumps(claude_response, indent=2)}")
logger.info(f"ClaudeProviderHandler: Response received successfully via direct HTTP")
self.record_success()
# Convert Claude API response to OpenAI format
openai_response = self._convert_to_openai_format(claude_response, model)
return openai_response
async def _request_with_retry(self, api_url: str, headers: Dict, payload: Dict, max_retries: int = 3):
"""
Non-streaming request with automatic retry for transient errors (Phase 1.2).
Retries on:
- 429 rate limit errors (with x-should-retry: true header)
- 529 overloaded errors
- 503 service unavailable
- Connection timeouts
Uses exponential backoff with jitter between retries.
"""
import logging
logger = logging.getLogger(__name__)
last_error = None
for attempt in range(max_retries):
try:
response = await self.client.post(api_url, headers=headers, json=payload) response = await self.client.post(api_url, headers=headers, json=payload)
logger.info(f"ClaudeProviderHandler: Response status: {response.status_code}") logger.info(f"ClaudeProviderHandler: Response status: {response.status_code} (attempt {attempt + 1}/{max_retries})")
# Check for 429 rate limit error before raising # Check for retryable errors
if response.status_code == 429: if response.status_code in (429, 529, 503):
# Check if we should retry
should_retry = response.headers.get('x-should-retry', 'false').lower() == 'true'
if should_retry or response.status_code in (529, 503):
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff + jitter
wait_time = min(2 ** attempt + random.uniform(0, 1), 30)
# Try to get wait time from response
try:
error_data = response.json()
error_message = error_data.get('error', {}).get('message', '')
logger.warning(f"ClaudeProviderHandler: Retryable error: {error_message}")
except Exception:
pass
logger.info(f"ClaudeProviderHandler: Retrying in {wait_time:.1f}s (attempt {attempt + 1}/{max_retries})")
await asyncio.sleep(wait_time)
continue
else:
# Max retries exceeded, handle the error
try: try:
response_data = response.json() response_data = response.json()
except Exception: except Exception:
...@@ -3322,7 +3389,7 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -3322,7 +3389,7 @@ class ClaudeProviderHandler(BaseProviderHandler):
self.handle_429_error(response_data, dict(response.headers)) self.handle_429_error(response_data, dict(response.headers))
response.raise_for_status() response.raise_for_status()
# Log error details for non-2xx responses # Check for other errors
if response.status_code >= 400: if response.status_code >= 400:
try: try:
error_body = response.json() error_body = response.json()
...@@ -3336,18 +3403,33 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -3336,18 +3403,33 @@ class ClaudeProviderHandler(BaseProviderHandler):
response.raise_for_status() response.raise_for_status()
claude_response = response.json() # Success
return response
if AISBF_DEBUG:
logger.info(f"ClaudeProviderHandler: API response: {json.dumps(claude_response, indent=2)}")
logger.info(f"ClaudeProviderHandler: Response received successfully via direct HTTP") except httpx.TimeoutException as e:
self.record_success() last_error = e
if attempt < max_retries - 1:
wait_time = min(2 ** attempt + random.uniform(0, 1), 30)
logger.warning(f"ClaudeProviderHandler: Request timeout, retrying in {wait_time:.1f}s")
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"ClaudeProviderHandler: Request timeout after {max_retries} attempts")
raise
# Convert Claude API response to OpenAI format except httpx.HTTPError as e:
openai_response = self._convert_to_openai_format(claude_response, model) last_error = e
if attempt < max_retries - 1:
wait_time = min(2 ** attempt + random.uniform(0, 1), 30)
logger.warning(f"ClaudeProviderHandler: HTTP error, retrying in {wait_time:.1f}s: {e}")
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"ClaudeProviderHandler: HTTP error after {max_retries} attempts: {e}")
raise
return openai_response # Should not reach here, but just in case
raise last_error or Exception("Request failed after max retries")
async def _handle_streaming_request_with_retry(self, api_url: str, payload: Dict, headers: Dict, model: str): async def _handle_streaming_request_with_retry(self, api_url: str, payload: Dict, headers: Dict, model: str):
""" """
...@@ -3438,7 +3520,16 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -3438,7 +3520,16 @@ class ClaudeProviderHandler(BaseProviderHandler):
content_block_index = 0 content_block_index = 0
current_tool_calls = [] current_tool_calls = []
# Streaming idle watchdog (Phase 1.3)
last_event_time = time.time()
idle_timeout = self.stream_idle_timeout
async for line in response.aiter_lines(): async for line in response.aiter_lines():
# Check for idle timeout (Phase 1.3)
if time.time() - last_event_time > idle_timeout:
logger.error(f"ClaudeProviderHandler: Stream idle timeout ({idle_timeout}s)")
raise TimeoutError(f"Stream idle for {idle_timeout}s")
if not line or not line.startswith('data: '): if not line or not line.startswith('data: '):
continue continue
...@@ -3451,6 +3542,9 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -3451,6 +3542,9 @@ class ClaudeProviderHandler(BaseProviderHandler):
try: try:
chunk_data = json.loads(data_str) chunk_data = json.loads(data_str)
# Update idle watchdog (Phase 1.3)
last_event_time = time.time()
# Handle different event types # Handle different event types
event_type = chunk_data.get('type') event_type = chunk_data.get('type')
...@@ -3587,6 +3681,16 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -3587,6 +3681,16 @@ class ClaudeProviderHandler(BaseProviderHandler):
if usage: if usage:
logger.debug(f"ClaudeProviderHandler: Streaming usage update: {usage}") logger.debug(f"ClaudeProviderHandler: Streaming usage update: {usage}")
# Track cache tokens for analytics (Phase 2.3)
cache_read = usage.get('cache_read_input_tokens', 0)
cache_creation = usage.get('cache_creation_input_tokens', 0)
if cache_read > 0:
self.cache_stats['cache_hits'] += 1
self.cache_stats['cache_tokens_read'] += cache_read
if cache_creation > 0:
self.cache_stats['cache_misses'] += 1
self.cache_stats['cache_tokens_created'] += cache_creation
elif event_type == 'message_stop': elif event_type == 'message_stop':
# Final chunk # Final chunk
finish_reason = 'stop' finish_reason = 'stop'
...@@ -3855,6 +3959,22 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -3855,6 +3959,22 @@ class ClaudeProviderHandler(BaseProviderHandler):
# Close SDK client after streaming completes # Close SDK client after streaming completes
await sdk_client.close() await sdk_client.close()
def get_cache_stats(self) -> Dict:
    """
    Return prompt-cache usage statistics (Phase 2.3).

    Returns:
        Dict combining the raw ``self.cache_stats`` counters with two
        derived fields: ``total_cache_events`` (hits + misses) and
        ``cache_hit_rate_percent`` (hits as a percentage of events,
        rounded to two decimals; 0 when no cache events occurred).
    """
    # Copy the counters so callers cannot mutate the live stats dict.
    stats = dict(self.cache_stats)
    hits = stats['cache_hits']
    events = hits + stats['cache_misses']
    stats['total_cache_events'] = events
    # Guard against division by zero when no cache events have been seen.
    stats['cache_hit_rate_percent'] = round(hits / events * 100, 2) if events > 0 else 0
    return stats
def _get_models_cache_path(self) -> str: def _get_models_cache_path(self) -> str:
"""Get the path to the models cache file.""" """Get the path to the models cache file."""
import os import os
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment