Fix effective_context variable scope and calculate total tokens for streaming responses

- Pass effective_context as parameter to stream_generator functions
- Update _create_streaming_response signature to accept effective_context
- Update all calls to _create_streaming_response to pass effective_context
- Track accumulated response text for token counting in streaming
- Calculate completion tokens for Google responses (since Google doesn't provide them)
- Calculate completion tokens for non-Google providers when they don't provide token counts
- Include prompt_tokens, completion_tokens, total_tokens, and effective_context in final chunk
- Fixes 'name effective_context is not defined' error in streaming responses
- Fixes issue where streaming responses had null token counts
parent d9cdab8b
......@@ -381,7 +381,7 @@ class RequestHandler:
# Update request_data with condensed messages
request_data['messages'] = messages
async def stream_generator():
async def stream_generator(effective_context):
import logging
import time
import json
......@@ -415,6 +415,10 @@ class RequestHandler:
created_time = int(time.time())
response_id = f"google-{request_data['model']}-{created_time}"
# Track completion tokens for Google responses (since Google doesn't provide them)
completion_tokens = 0
accumulated_response_text = "" # Track full response for token counting
# Collect all chunks first to know when we're at the last one
chunks_list = []
async for chunk in response:
......@@ -486,6 +490,10 @@ class RequestHandler:
chunk_id += 1
logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})")
# Track completion tokens for Google responses
if delta_text:
accumulated_response_text += delta_text
# Serialize as JSON
yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8')
......@@ -500,6 +508,10 @@ class RequestHandler:
continue
# Send final chunk with usage statistics (empty content)
# Calculate completion tokens for Google responses (count tokens in full response)
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], request_data['model'])
total_tokens = effective_context + completion_tokens
final_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
......@@ -508,9 +520,9 @@ class RequestHandler:
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
"prompt_tokens": effective_context,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"effective_context": effective_context
},
"provider": provider_id,
......@@ -533,6 +545,7 @@ class RequestHandler:
# Handle OpenAI/Anthropic streaming responses
# OpenAI SDK returns a sync Stream object, not an async iterator
# So we use a regular for loop, not async for
accumulated_response_text = "" # Track full response for token counting
for chunk in response:
try:
# Debug: Log chunk type and content before serialization
......@@ -543,6 +556,15 @@ class RequestHandler:
# Convert chunk to dict and serialize as JSON
chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk
# Track response content for token calculation
if isinstance(chunk_dict, dict):
choices = chunk_dict.get('choices', [])
if choices:
delta = choices[0].get('delta', {})
delta_content = delta.get('content', '')
if delta_content:
accumulated_response_text += delta_content
# Add effective_context to the last chunk (when finish_reason is present)
if isinstance(chunk_dict, dict):
choices = chunk_dict.get('choices', [])
......@@ -552,6 +574,18 @@ class RequestHandler:
chunk_dict['usage'] = {}
chunk_dict['usage']['effective_context'] = effective_context
# If provider doesn't provide token counts, calculate them
if chunk_dict['usage'].get('total_tokens') is None:
# Calculate completion tokens from accumulated response
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], request_data['model'])
else:
completion_tokens = 0
total_tokens = effective_context + completion_tokens
chunk_dict['usage']['prompt_tokens'] = effective_context
chunk_dict['usage']['completion_tokens'] = completion_tokens
chunk_dict['usage']['total_tokens'] = total_tokens
yield f"data: {json.dumps(chunk_dict)}\n\n".encode('utf-8')
except Exception as chunk_error:
# Handle errors during chunk serialization
......@@ -567,7 +601,7 @@ class RequestHandler:
error_dict = {"error": str(e)}
yield f"data: {json.dumps(error_dict)}\n\n".encode('utf-8')
return StreamingResponse(stream_generator(), media_type="text/event-stream")
return StreamingResponse(stream_generator(effective_context), media_type="text/event-stream")
async def handle_model_list(self, request: Request, provider_id: str) -> List[Dict]:
provider_config = self.config.get_provider(provider_id)
......@@ -1039,7 +1073,8 @@ class RotationHandler:
provider_id=provider_id,
model_name=model_name,
handler=handler,
request_data=request_data
request_data=request_data,
effective_context=effective_context
)
else:
logger.info("Returning non-streaming response")
......@@ -1100,7 +1135,8 @@ class RotationHandler:
provider_id=provider_id,
model_name=model_name,
handler=handler,
request_data=request_data
request_data=request_data,
effective_context=effective_context
)
else:
logger.info("Returning non-streaming response")
......@@ -1136,7 +1172,7 @@ class RotationHandler:
detail=f"All providers in rotation failed after {max_retries} attempts. Last error: {last_error}"
)
def _create_streaming_response(self, response, provider_type: str, provider_id: str, model_name: str, handler, request_data: Dict):
def _create_streaming_response(self, response, provider_type: str, provider_id: str, model_name: str, handler, request_data: Dict, effective_context: int):
"""
Create a StreamingResponse with proper handling based on provider type.
......@@ -1147,6 +1183,7 @@ class RotationHandler:
model_name: The model name being used
handler: The provider handler (for recording success/failure)
request_data: The original request data
effective_context: The effective context (total tokens used)
Returns:
StreamingResponse with appropriate generator for the provider type
......@@ -1165,7 +1202,7 @@ class RotationHandler:
seed = request_data.get('seed')
system_fingerprint = generate_system_fingerprint(provider_id, seed)
async def stream_generator():
async def stream_generator(effective_context):
try:
if is_google_provider:
# Handle Google's streaming response
......@@ -1176,6 +1213,10 @@ class RotationHandler:
created_time = int(time.time())
response_id = f"google-{model_name}-{created_time}"
# Track completion tokens for Google responses (since Google doesn't provide them)
completion_tokens = 0
accumulated_response_text = "" # Track full response for token counting
# Collect all chunks first to know when we're at the last one
chunks_list = []
async for chunk in response:
......@@ -1247,6 +1288,10 @@ class RotationHandler:
chunk_id += 1
logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})")
# Track completion tokens for Google responses
if delta_text:
accumulated_response_text += delta_text
yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8')
chunk_idx += 1
......@@ -1259,6 +1304,10 @@ class RotationHandler:
continue
# Send final chunk with usage statistics (empty content)
# Calculate completion tokens for Google responses (count tokens in full response)
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name)
total_tokens = effective_context + completion_tokens
final_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
......@@ -1267,9 +1316,9 @@ class RotationHandler:
"service_tier": None,
"system_fingerprint": system_fingerprint,
"usage": {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
"prompt_tokens": effective_context,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"effective_context": effective_context
},
"provider": provider_id,
......@@ -1292,6 +1341,7 @@ class RotationHandler:
# Handle OpenAI/Anthropic streaming responses
# OpenAI SDK returns a sync Stream object, not an async iterator
# So we use a regular for loop, not async for
accumulated_response_text = "" # Track full response for token counting
for chunk in response:
try:
logger.debug(f"Chunk type: {type(chunk)}")
......@@ -1300,6 +1350,15 @@ class RotationHandler:
# For OpenAI-compatible providers, just pass through the raw chunk
chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk
# Track response content for token calculation
if isinstance(chunk_dict, dict):
choices = chunk_dict.get('choices', [])
if choices:
delta = choices[0].get('delta', {})
delta_content = delta.get('content', '')
if delta_content:
accumulated_response_text += delta_content
# Add effective_context to the last chunk (when finish_reason is present)
if isinstance(chunk_dict, dict):
choices = chunk_dict.get('choices', [])
......@@ -1309,6 +1368,18 @@ class RotationHandler:
chunk_dict['usage'] = {}
chunk_dict['usage']['effective_context'] = effective_context
# If provider doesn't provide token counts, calculate them
if chunk_dict['usage'].get('total_tokens') is None:
# Calculate completion tokens from accumulated response
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name)
else:
completion_tokens = 0
total_tokens = effective_context + completion_tokens
chunk_dict['usage']['prompt_tokens'] = effective_context
chunk_dict['usage']['completion_tokens'] = completion_tokens
chunk_dict['usage']['total_tokens'] = total_tokens
yield f"data: {json.dumps(chunk_dict)}\n\n".encode('utf-8')
except Exception as chunk_error:
error_msg = str(chunk_error)
......@@ -1323,7 +1394,7 @@ class RotationHandler:
error_dict = {"error": str(e)}
yield f"data: {json.dumps(error_dict)}\n\n".encode('utf-8')
return StreamingResponse(stream_generator(), media_type="text/event-stream")
return StreamingResponse(stream_generator(effective_context), media_type="text/event-stream")
async def handle_rotation_model_list(self, rotation_id: str) -> List[Dict]:
rotation_config = self.config.get_rotation(rotation_id)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment