Fix effective_context variable scope and calculate total tokens for streaming responses

- Pass effective_context as a parameter to the stream_generator functions
- Update _create_streaming_response signature to accept effective_context
- Update all calls to _create_streaming_response to pass effective_context
- Track accumulated response text for token counting in streaming
- Calculate completion tokens for Google responses (since Google doesn't provide them)
- Calculate completion tokens for non-Google providers when they don't provide token counts
- Include prompt_tokens, completion_tokens, total_tokens, and effective_context in final chunk
- Fixes the NameError "name 'effective_context' is not defined" raised in streaming responses
- Fixes issue where streaming responses had null token counts
parent d9cdab8b
...@@ -381,7 +381,7 @@ class RequestHandler: ...@@ -381,7 +381,7 @@ class RequestHandler:
# Update request_data with condensed messages # Update request_data with condensed messages
request_data['messages'] = messages request_data['messages'] = messages
async def stream_generator(): async def stream_generator(effective_context):
import logging import logging
import time import time
import json import json
...@@ -415,6 +415,10 @@ class RequestHandler: ...@@ -415,6 +415,10 @@ class RequestHandler:
created_time = int(time.time()) created_time = int(time.time())
response_id = f"google-{request_data['model']}-{created_time}" response_id = f"google-{request_data['model']}-{created_time}"
# Track completion tokens for Google responses (since Google doesn't provide them)
completion_tokens = 0
accumulated_response_text = "" # Track full response for token counting
# Collect all chunks first to know when we're at the last one # Collect all chunks first to know when we're at the last one
chunks_list = [] chunks_list = []
async for chunk in response: async for chunk in response:
...@@ -486,6 +490,10 @@ class RequestHandler: ...@@ -486,6 +490,10 @@ class RequestHandler:
chunk_id += 1 chunk_id += 1
logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})") logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})")
# Track completion tokens for Google responses
if delta_text:
accumulated_response_text += delta_text
# Serialize as JSON # Serialize as JSON
yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8') yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8')
...@@ -500,6 +508,10 @@ class RequestHandler: ...@@ -500,6 +508,10 @@ class RequestHandler:
continue continue
# Send final chunk with usage statistics (empty content) # Send final chunk with usage statistics (empty content)
# Calculate completion tokens for Google responses (count tokens in full response)
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], request_data['model'])
total_tokens = effective_context + completion_tokens
final_chunk = { final_chunk = {
"id": response_id, "id": response_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
...@@ -508,9 +520,9 @@ class RequestHandler: ...@@ -508,9 +520,9 @@ class RequestHandler:
"service_tier": None, "service_tier": None,
"system_fingerprint": system_fingerprint, "system_fingerprint": system_fingerprint,
"usage": { "usage": {
"prompt_tokens": None, "prompt_tokens": effective_context,
"completion_tokens": None, "completion_tokens": completion_tokens,
"total_tokens": None, "total_tokens": total_tokens,
"effective_context": effective_context "effective_context": effective_context
}, },
"provider": provider_id, "provider": provider_id,
...@@ -533,6 +545,7 @@ class RequestHandler: ...@@ -533,6 +545,7 @@ class RequestHandler:
# Handle OpenAI/Anthropic streaming responses # Handle OpenAI/Anthropic streaming responses
# OpenAI SDK returns a sync Stream object, not an async iterator # OpenAI SDK returns a sync Stream object, not an async iterator
# So we use a regular for loop, not async for # So we use a regular for loop, not async for
accumulated_response_text = "" # Track full response for token counting
for chunk in response: for chunk in response:
try: try:
# Debug: Log chunk type and content before serialization # Debug: Log chunk type and content before serialization
...@@ -543,6 +556,15 @@ class RequestHandler: ...@@ -543,6 +556,15 @@ class RequestHandler:
# Convert chunk to dict and serialize as JSON # Convert chunk to dict and serialize as JSON
chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk
# Track response content for token calculation
if isinstance(chunk_dict, dict):
choices = chunk_dict.get('choices', [])
if choices:
delta = choices[0].get('delta', {})
delta_content = delta.get('content', '')
if delta_content:
accumulated_response_text += delta_content
# Add effective_context to the last chunk (when finish_reason is present) # Add effective_context to the last chunk (when finish_reason is present)
if isinstance(chunk_dict, dict): if isinstance(chunk_dict, dict):
choices = chunk_dict.get('choices', []) choices = chunk_dict.get('choices', [])
...@@ -551,6 +573,18 @@ class RequestHandler: ...@@ -551,6 +573,18 @@ class RequestHandler:
if 'usage' not in chunk_dict: if 'usage' not in chunk_dict:
chunk_dict['usage'] = {} chunk_dict['usage'] = {}
chunk_dict['usage']['effective_context'] = effective_context chunk_dict['usage']['effective_context'] = effective_context
# If provider doesn't provide token counts, calculate them
if chunk_dict['usage'].get('total_tokens') is None:
# Calculate completion tokens from accumulated response
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], request_data['model'])
else:
completion_tokens = 0
total_tokens = effective_context + completion_tokens
chunk_dict['usage']['prompt_tokens'] = effective_context
chunk_dict['usage']['completion_tokens'] = completion_tokens
chunk_dict['usage']['total_tokens'] = total_tokens
yield f"data: {json.dumps(chunk_dict)}\n\n".encode('utf-8') yield f"data: {json.dumps(chunk_dict)}\n\n".encode('utf-8')
except Exception as chunk_error: except Exception as chunk_error:
...@@ -567,7 +601,7 @@ class RequestHandler: ...@@ -567,7 +601,7 @@ class RequestHandler:
error_dict = {"error": str(e)} error_dict = {"error": str(e)}
yield f"data: {json.dumps(error_dict)}\n\n".encode('utf-8') yield f"data: {json.dumps(error_dict)}\n\n".encode('utf-8')
return StreamingResponse(stream_generator(), media_type="text/event-stream") return StreamingResponse(stream_generator(effective_context), media_type="text/event-stream")
async def handle_model_list(self, request: Request, provider_id: str) -> List[Dict]: async def handle_model_list(self, request: Request, provider_id: str) -> List[Dict]:
provider_config = self.config.get_provider(provider_id) provider_config = self.config.get_provider(provider_id)
...@@ -1039,7 +1073,8 @@ class RotationHandler: ...@@ -1039,7 +1073,8 @@ class RotationHandler:
provider_id=provider_id, provider_id=provider_id,
model_name=model_name, model_name=model_name,
handler=handler, handler=handler,
request_data=request_data request_data=request_data,
effective_context=effective_context
) )
else: else:
logger.info("Returning non-streaming response") logger.info("Returning non-streaming response")
...@@ -1100,7 +1135,8 @@ class RotationHandler: ...@@ -1100,7 +1135,8 @@ class RotationHandler:
provider_id=provider_id, provider_id=provider_id,
model_name=model_name, model_name=model_name,
handler=handler, handler=handler,
request_data=request_data request_data=request_data,
effective_context=effective_context
) )
else: else:
logger.info("Returning non-streaming response") logger.info("Returning non-streaming response")
...@@ -1136,7 +1172,7 @@ class RotationHandler: ...@@ -1136,7 +1172,7 @@ class RotationHandler:
detail=f"All providers in rotation failed after {max_retries} attempts. Last error: {last_error}" detail=f"All providers in rotation failed after {max_retries} attempts. Last error: {last_error}"
) )
def _create_streaming_response(self, response, provider_type: str, provider_id: str, model_name: str, handler, request_data: Dict): def _create_streaming_response(self, response, provider_type: str, provider_id: str, model_name: str, handler, request_data: Dict, effective_context: int):
""" """
Create a StreamingResponse with proper handling based on provider type. Create a StreamingResponse with proper handling based on provider type.
...@@ -1147,6 +1183,7 @@ class RotationHandler: ...@@ -1147,6 +1183,7 @@ class RotationHandler:
model_name: The model name being used model_name: The model name being used
handler: The provider handler (for recording success/failure) handler: The provider handler (for recording success/failure)
request_data: The original request data request_data: The original request data
effective_context: The effective context (total tokens used)
Returns: Returns:
StreamingResponse with appropriate generator for the provider type StreamingResponse with appropriate generator for the provider type
...@@ -1165,7 +1202,7 @@ class RotationHandler: ...@@ -1165,7 +1202,7 @@ class RotationHandler:
seed = request_data.get('seed') seed = request_data.get('seed')
system_fingerprint = generate_system_fingerprint(provider_id, seed) system_fingerprint = generate_system_fingerprint(provider_id, seed)
async def stream_generator(): async def stream_generator(effective_context):
try: try:
if is_google_provider: if is_google_provider:
# Handle Google's streaming response # Handle Google's streaming response
...@@ -1176,6 +1213,10 @@ class RotationHandler: ...@@ -1176,6 +1213,10 @@ class RotationHandler:
created_time = int(time.time()) created_time = int(time.time())
response_id = f"google-{model_name}-{created_time}" response_id = f"google-{model_name}-{created_time}"
# Track completion tokens for Google responses (since Google doesn't provide them)
completion_tokens = 0
accumulated_response_text = "" # Track full response for token counting
# Collect all chunks first to know when we're at the last one # Collect all chunks first to know when we're at the last one
chunks_list = [] chunks_list = []
async for chunk in response: async for chunk in response:
...@@ -1247,6 +1288,10 @@ class RotationHandler: ...@@ -1247,6 +1288,10 @@ class RotationHandler:
chunk_id += 1 chunk_id += 1
logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})") logger.debug(f"OpenAI chunk (delta length: {len(delta_text)}, finish: {chunk_finish_reason})")
# Track completion tokens for Google responses
if delta_text:
accumulated_response_text += delta_text
yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8') yield f"data: {json.dumps(openai_chunk)}\n\n".encode('utf-8')
chunk_idx += 1 chunk_idx += 1
...@@ -1259,6 +1304,10 @@ class RotationHandler: ...@@ -1259,6 +1304,10 @@ class RotationHandler:
continue continue
# Send final chunk with usage statistics (empty content) # Send final chunk with usage statistics (empty content)
# Calculate completion tokens for Google responses (count tokens in full response)
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name)
total_tokens = effective_context + completion_tokens
final_chunk = { final_chunk = {
"id": response_id, "id": response_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
...@@ -1267,9 +1316,9 @@ class RotationHandler: ...@@ -1267,9 +1316,9 @@ class RotationHandler:
"service_tier": None, "service_tier": None,
"system_fingerprint": system_fingerprint, "system_fingerprint": system_fingerprint,
"usage": { "usage": {
"prompt_tokens": None, "prompt_tokens": effective_context,
"completion_tokens": None, "completion_tokens": completion_tokens,
"total_tokens": None, "total_tokens": total_tokens,
"effective_context": effective_context "effective_context": effective_context
}, },
"provider": provider_id, "provider": provider_id,
...@@ -1292,6 +1341,7 @@ class RotationHandler: ...@@ -1292,6 +1341,7 @@ class RotationHandler:
# Handle OpenAI/Anthropic streaming responses # Handle OpenAI/Anthropic streaming responses
# OpenAI SDK returns a sync Stream object, not an async iterator # OpenAI SDK returns a sync Stream object, not an async iterator
# So we use a regular for loop, not async for # So we use a regular for loop, not async for
accumulated_response_text = "" # Track full response for token counting
for chunk in response: for chunk in response:
try: try:
logger.debug(f"Chunk type: {type(chunk)}") logger.debug(f"Chunk type: {type(chunk)}")
...@@ -1300,6 +1350,15 @@ class RotationHandler: ...@@ -1300,6 +1350,15 @@ class RotationHandler:
# For OpenAI-compatible providers, just pass through the raw chunk # For OpenAI-compatible providers, just pass through the raw chunk
chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk
# Track response content for token calculation
if isinstance(chunk_dict, dict):
choices = chunk_dict.get('choices', [])
if choices:
delta = choices[0].get('delta', {})
delta_content = delta.get('content', '')
if delta_content:
accumulated_response_text += delta_content
# Add effective_context to the last chunk (when finish_reason is present) # Add effective_context to the last chunk (when finish_reason is present)
if isinstance(chunk_dict, dict): if isinstance(chunk_dict, dict):
choices = chunk_dict.get('choices', []) choices = chunk_dict.get('choices', [])
...@@ -1308,6 +1367,18 @@ class RotationHandler: ...@@ -1308,6 +1367,18 @@ class RotationHandler:
if 'usage' not in chunk_dict: if 'usage' not in chunk_dict:
chunk_dict['usage'] = {} chunk_dict['usage'] = {}
chunk_dict['usage']['effective_context'] = effective_context chunk_dict['usage']['effective_context'] = effective_context
# If provider doesn't provide token counts, calculate them
if chunk_dict['usage'].get('total_tokens') is None:
# Calculate completion tokens from accumulated response
if accumulated_response_text:
completion_tokens = count_messages_tokens([{"role": "assistant", "content": accumulated_response_text}], model_name)
else:
completion_tokens = 0
total_tokens = effective_context + completion_tokens
chunk_dict['usage']['prompt_tokens'] = effective_context
chunk_dict['usage']['completion_tokens'] = completion_tokens
chunk_dict['usage']['total_tokens'] = total_tokens
yield f"data: {json.dumps(chunk_dict)}\n\n".encode('utf-8') yield f"data: {json.dumps(chunk_dict)}\n\n".encode('utf-8')
except Exception as chunk_error: except Exception as chunk_error:
...@@ -1323,7 +1394,7 @@ class RotationHandler: ...@@ -1323,7 +1394,7 @@ class RotationHandler:
error_dict = {"error": str(e)} error_dict = {"error": str(e)}
yield f"data: {json.dumps(error_dict)}\n\n".encode('utf-8') yield f"data: {json.dumps(error_dict)}\n\n".encode('utf-8')
return StreamingResponse(stream_generator(), media_type="text/event-stream") return StreamingResponse(stream_generator(effective_context), media_type="text/event-stream")
async def handle_rotation_model_list(self, rotation_id: str) -> List[Dict]: async def handle_rotation_model_list(self, rotation_id: str) -> List[Dict]:
rotation_config = self.config.get_rotation(rotation_id) rotation_config = self.config.get_rotation(rotation_id)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment