Commit b3f747ba authored by Your Name's avatar Your Name

feat: Add Phase 1 and 2 improvements for Claude provider

Phase 1.2 - Automatic retry with exponential backoff:
- Add _request_with_retry() method for non-streaming requests
- Retries on 429 (with x-should-retry header), 529, 503 errors
- Exponential backoff with jitter (1s, 2s, 4s, capped at 30s)
- Handles timeouts and HTTP errors gracefully

Phase 1.3 - Streaming idle watchdog:
- Add 90s idle timeout detection (matches vendors/claude)
- Tracks last_event_time and raises TimeoutError on idle
- Prevents indefinite hangs on dropped connections

Phase 2.3 - Cache token tracking:
- Add cache_stats dict to track cache hits/misses
- Track cache_tokens_read and cache_tokens_created
- Add get_cache_stats() method for analytics
- Updates stats during streaming message_delta events

Also includes:
- Temperature fix (skip 0.0 when thinking beta active)
- Rate limit config update (5s default for Claude)
parent 63dc05af
......@@ -2327,6 +2327,18 @@ class ClaudeProviderHandler(BaseProviderHandler):
# HTTP client for direct API requests (kilocode method)
self.client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
# Streaming idle watchdog configuration (Phase 1.3)
self.stream_idle_timeout = 90.0 # seconds - matches vendors/claude
# Cache token tracking for analytics (Phase 2.3)
self.cache_stats = {
'cache_hits': 0,
'cache_misses': 0,
'cache_tokens_read': 0,
'cache_tokens_created': 0,
'total_requests': 0,
}
def _get_auth_headers(self, stream: bool = False):
"""
......@@ -3307,34 +3319,8 @@ class ClaudeProviderHandler(BaseProviderHandler):
logger.info(f"Payload: {json.dumps(payload, indent=2)}")
logger.info(f"=== END NON-STREAMING REQUEST DEBUG ===")
# Non-streaming request
response = await self.client.post(api_url, headers=headers, json=payload)
logger.info(f"ClaudeProviderHandler: Response status: {response.status_code}")
# Check for 429 rate limit error before raising
if response.status_code == 429:
try:
response_data = response.json()
except Exception:
response_data = response.text
self.handle_429_error(response_data, dict(response.headers))
response.raise_for_status()
# Log error details for non-2xx responses
if response.status_code >= 400:
try:
error_body = response.json()
error_message = error_body.get('error', {}).get('message', 'Unknown error')
error_type = error_body.get('error', {}).get('type', 'unknown')
logger.error(f"ClaudeProviderHandler: API error response: {json.dumps(error_body, indent=2)}")
logger.error(f"ClaudeProviderHandler: Error type: {error_type}")
logger.error(f"ClaudeProviderHandler: Error message: {error_message}")
except Exception:
logger.error(f"ClaudeProviderHandler: API error response (text): {response.text}")
response.raise_for_status()
# Non-streaming request with automatic retry (Phase 1.2)
response = await self._request_with_retry(api_url, headers, payload, max_retries=3)
claude_response = response.json()
......@@ -3349,6 +3335,102 @@ class ClaudeProviderHandler(BaseProviderHandler):
return openai_response
async def _request_with_retry(self, api_url: str, headers: Dict, payload: Dict, max_retries: int = 3):
    """
    Non-streaming POST with automatic retry for transient errors (Phase 1.2).

    Retries on:
    - 429 rate limit errors (only when the API sends x-should-retry: true)
    - 529 overloaded errors
    - 503 service unavailable
    - Connection timeouts and other transport-level failures

    Waits between attempts using exponential backoff with jitter
    (1s, 2s, 4s, capped at 30s), preferring the server-provided
    Retry-After header when it is present and numeric.

    Args:
        api_url: Fully-qualified messages endpoint URL.
        headers: Request headers (auth, beta flags, etc.).
        payload: JSON-serializable request body.
        max_retries: Total number of attempts before giving up (default 3).

    Returns:
        The httpx.Response for the first successful (2xx) attempt.

    Raises:
        httpx.HTTPStatusError: For non-retryable API errors, or retryable
            ones after retries are exhausted.
        httpx.TimeoutException: If every attempt times out.
        httpx.HTTPError: For transport errors after retries are exhausted.
    """
    import logging
    logger = logging.getLogger(__name__)

    last_error = None

    for attempt in range(max_retries):
        try:
            response = await self.client.post(api_url, headers=headers, json=payload)
            logger.info(f"ClaudeProviderHandler: Response status: {response.status_code} (attempt {attempt + 1}/{max_retries})")

            # Transient / rate-limit statuses that may be worth retrying
            if response.status_code in (429, 529, 503):
                # 429 is only retried when the API explicitly allows it;
                # 529/503 are always considered transient.
                should_retry = response.headers.get('x-should-retry', 'false').lower() == 'true'

                if (should_retry or response.status_code in (529, 503)) and attempt < max_retries - 1:
                    # Exponential backoff with jitter, capped at 30s
                    wait_time = min(2 ** attempt + random.uniform(0, 1), 30)
                    # Prefer the server's Retry-After header when valid
                    retry_after = response.headers.get('retry-after')
                    if retry_after:
                        try:
                            wait_time = min(float(retry_after), 30)
                        except ValueError:
                            pass  # non-numeric Retry-After: keep backoff value

                    # Best-effort: surface the API's own error message
                    try:
                        error_data = response.json()
                        error_message = error_data.get('error', {}).get('message', '')
                        logger.warning(f"ClaudeProviderHandler: Retryable error: {error_message}")
                    except Exception:
                        pass

                    logger.info(f"ClaudeProviderHandler: Retrying in {wait_time:.1f}s (attempt {attempt + 1}/{max_retries})")
                    await asyncio.sleep(wait_time)
                    continue

                # Non-retryable 429 (no x-should-retry) or retries exhausted:
                # always record the rate-limit details before raising, matching
                # the pre-retry behavior of the non-streaming request path.
                try:
                    response_data = response.json()
                except Exception:
                    response_data = response.text
                self.handle_429_error(response_data, dict(response.headers))
                response.raise_for_status()

            # Any other API error: log details and raise without retrying
            if response.status_code >= 400:
                try:
                    error_body = response.json()
                    error_message = error_body.get('error', {}).get('message', 'Unknown error')
                    error_type = error_body.get('error', {}).get('type', 'unknown')
                    logger.error(f"ClaudeProviderHandler: API error response: {json.dumps(error_body, indent=2)}")
                    logger.error(f"ClaudeProviderHandler: Error type: {error_type}")
                    logger.error(f"ClaudeProviderHandler: Error message: {error_message}")
                except Exception:
                    logger.error(f"ClaudeProviderHandler: API error response (text): {response.text}")
                response.raise_for_status()

            # Success (2xx)
            return response

        except httpx.TimeoutException as e:
            last_error = e
            if attempt < max_retries - 1:
                wait_time = min(2 ** attempt + random.uniform(0, 1), 30)
                logger.warning(f"ClaudeProviderHandler: Request timeout, retrying in {wait_time:.1f}s")
                await asyncio.sleep(wait_time)
                continue
            else:
                logger.error(f"ClaudeProviderHandler: Request timeout after {max_retries} attempts")
                raise

        except httpx.HTTPStatusError:
            # raise_for_status() errors were already classified above as
            # non-retryable (or had retries exhausted) -- do NOT let the
            # generic HTTPError handler below burn retry attempts on them.
            raise

        except httpx.HTTPError as e:
            # Transport-level failures (connect errors, protocol errors, ...)
            last_error = e
            if attempt < max_retries - 1:
                wait_time = min(2 ** attempt + random.uniform(0, 1), 30)
                logger.warning(f"ClaudeProviderHandler: HTTP error, retrying in {wait_time:.1f}s: {e}")
                await asyncio.sleep(wait_time)
                continue
            else:
                logger.error(f"ClaudeProviderHandler: HTTP error after {max_retries} attempts: {e}")
                raise

    # Defensive: the loop either returns or raises, but keep a hard stop
    raise last_error or Exception("Request failed after max retries")
async def _handle_streaming_request_with_retry(self, api_url: str, payload: Dict, headers: Dict, model: str):
"""
Wrapper for streaming request that catches rate limit errors at the call site.
......@@ -3438,7 +3520,16 @@ class ClaudeProviderHandler(BaseProviderHandler):
content_block_index = 0
current_tool_calls = []
# Streaming idle watchdog (Phase 1.3)
last_event_time = time.time()
idle_timeout = self.stream_idle_timeout
async for line in response.aiter_lines():
# Check for idle timeout (Phase 1.3)
if time.time() - last_event_time > idle_timeout:
logger.error(f"ClaudeProviderHandler: Stream idle timeout ({idle_timeout}s)")
raise TimeoutError(f"Stream idle for {idle_timeout}s")
if not line or not line.startswith('data: '):
continue
......@@ -3451,6 +3542,9 @@ class ClaudeProviderHandler(BaseProviderHandler):
try:
chunk_data = json.loads(data_str)
# Update idle watchdog (Phase 1.3)
last_event_time = time.time()
# Handle different event types
event_type = chunk_data.get('type')
......@@ -3586,6 +3680,16 @@ class ClaudeProviderHandler(BaseProviderHandler):
usage = chunk_data.get('usage', {})
if usage:
logger.debug(f"ClaudeProviderHandler: Streaming usage update: {usage}")
# Track cache tokens for analytics (Phase 2.3)
cache_read = usage.get('cache_read_input_tokens', 0)
cache_creation = usage.get('cache_creation_input_tokens', 0)
if cache_read > 0:
self.cache_stats['cache_hits'] += 1
self.cache_stats['cache_tokens_read'] += cache_read
if cache_creation > 0:
self.cache_stats['cache_misses'] += 1
self.cache_stats['cache_tokens_created'] += cache_creation
elif event_type == 'message_stop':
# Final chunk
......@@ -3855,6 +3959,22 @@ class ClaudeProviderHandler(BaseProviderHandler):
# Close SDK client after streaming completes
await sdk_client.close()
def get_cache_stats(self) -> Dict:
"""
Get cache usage statistics (Phase 2.3).
Returns:
Dict with cache statistics including hits, misses, and token counts.
"""
total = self.cache_stats['cache_hits'] + self.cache_stats['cache_misses']
hit_rate = (self.cache_stats['cache_hits'] / total * 100) if total > 0 else 0
return {
**self.cache_stats,
'total_cache_events': total,
'cache_hit_rate_percent': round(hit_rate, 2),
}
def _get_models_cache_path(self) -> str:
"""Get the path to the models cache file."""
import os
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment