Fix rotation edit

parent 3d690b74
...@@ -4,6 +4,7 @@ All ASGI middleware functions extracted from main.py. ...@@ -4,6 +4,7 @@ All ASGI middleware functions extracted from main.py.
import time import time
import logging import logging
import threading import threading
import hmac as _hmac
from typing import Optional from typing import Optional
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
...@@ -21,7 +22,8 @@ _client_rl_lock = threading.Lock() ...@@ -21,7 +22,8 @@ _client_rl_lock = threading.Lock()
def _get_real_client_ip(request: Request) -> str: def _get_real_client_ip(request: Request) -> str:
xff = request.headers.get('X-Forwarded-For', '') xff = request.headers.get('X-Forwarded-For', '')
if xff: if xff:
return xff.split(',')[0].strip() # Use rightmost IP (appended by the trusted upstream proxy) to prevent spoofing
return xff.split(',')[-1].strip()
client = request.scope.get('client') client = request.scope.get('client')
return client[0] if client else 'unknown' return client[0] if client else 'unknown'
...@@ -62,7 +64,8 @@ _BLOCK_MESSAGE = "We do not support the Israeli genocide of Palestinian people." ...@@ -62,7 +64,8 @@ _BLOCK_MESSAGE = "We do not support the Israeli genocide of Palestinian people."
def _get_client_ip(request: Request) -> Optional[str]: def _get_client_ip(request: Request) -> Optional[str]:
xff = request.headers.get("X-Forwarded-For") xff = request.headers.get("X-Forwarded-For")
if xff: if xff:
return xff.split(",")[0].strip() # Use rightmost IP (appended by the trusted upstream proxy) to prevent spoofing
return xff.split(",")[-1].strip()
client = request.scope.get("client") client = request.scope.get("client")
return client[0] if client else None return client[0] if client else None
...@@ -195,7 +198,10 @@ def make_auth_middleware(get_server_config, get_config, get_db, url_for_fn): ...@@ -195,7 +198,10 @@ def make_auth_middleware(get_server_config, get_config, get_db, url_for_fn):
token = auth_header.replace('Bearer ', '') token = auth_header.replace('Bearer ', '')
allowed_tokens = server_config.get('auth_tokens', []) allowed_tokens = server_config.get('auth_tokens', [])
if token in allowed_tokens: _token_valid = False
for _t in allowed_tokens:
_token_valid |= _hmac.compare_digest(token, _t)
if _token_valid:
request.state.user_id = None request.state.user_id = None
request.state.token_id = None request.state.token_id = None
request.state.is_global_token = True request.state.is_global_token = True
......
...@@ -42,39 +42,6 @@ try: ...@@ -42,39 +42,6 @@ try:
except ImportError: except ImportError:
HAS_CURL_CFFI = False HAS_CURL_CFFI = False
# Configuration matching the official Claude CLI
# Try to load client_id from credentials file first, fallback to generated UUID
import json
import os
from pathlib import Path
def _load_client_id_from_credentials():
"""Attempt to load client_id from existing Claude credentials file"""
try:
creds_path = Path.home() / ".claude" / ".credentials.json"
if creds_path.exists():
with open(creds_path, 'r') as f:
creds = json.load(f)
# Try to extract client_id from various possible locations
if 'client_id' in creds:
return creds['client_id']
elif 'oauth' in creds and 'client_id' in creds['oauth']:
return creds['oauth']['client_id']
elif 'claudeAiOauth' in creds and 'client_id' in creds['claudeAiOauth']:
return creds['claudeAiOauth']['client_id']
except Exception:
pass
return None
def _generate_client_id():
"""Generate a stable client_id UUID based on machine characteristics"""
# Use machine hostname and platform to generate a stable UUID
import uuid
import platform
machine_id = f"{platform.node()}-{platform.machine()}-claude-code"
# Generate UUID5 (name-based) from the machine ID
return str(uuid.uuid5(uuid.NAMESPACE_DNS, machine_id))
# Claude OAuth2 Configuration # Claude OAuth2 Configuration
# These values match the official claude-cli implementation # These values match the official claude-cli implementation
CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" # Official Claude Code client ID CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" # Official Claude Code client ID
......
...@@ -28,14 +28,11 @@ import logging ...@@ -28,14 +28,11 @@ import logging
# On read, detect format by attempting JSON first so legacy pickle data still works. # On read, detect format by attempting JSON first so legacy pickle data still works.
def _cache_encode(value: any) -> bytes: def _cache_encode(value: any) -> bytes:
"""Encode a cache value. Prefers JSON; falls back to pickle.""" """Encode a cache value using JSON only."""
try: return b'\x00' + json.dumps(value, ensure_ascii=False).encode('utf-8')
return b'\x00' + json.dumps(value, ensure_ascii=False).encode('utf-8')
except (TypeError, ValueError):
return b'\x01' + pickle.dumps(value)
def _cache_decode(data: bytes) -> any: def _cache_decode(data: bytes) -> any:
"""Decode a cache value encoded by _cache_encode, or legacy raw pickle bytes.""" """Decode a cache value encoded by _cache_encode. Legacy pickle entries are discarded."""
if isinstance(data, memoryview): if isinstance(data, memoryview):
data = bytes(data) data = bytes(data)
if not data: if not data:
...@@ -43,12 +40,15 @@ def _cache_decode(data: bytes) -> any: ...@@ -43,12 +40,15 @@ def _cache_decode(data: bytes) -> any:
if data[0:1] == b'\x00': if data[0:1] == b'\x00':
return json.loads(data[1:].decode('utf-8')) return json.loads(data[1:].decode('utf-8'))
if data[0:1] == b'\x01': if data[0:1] == b'\x01':
return pickle.loads(data[1:]) # Legacy pickle-encoded entry — discard; will be recalculated on next miss
# Legacy: no prefix — assume raw pickle logger.warning("Discarding legacy pickle-encoded cache entry (will be recalculated)")
return None
# Legacy: no prefix — try JSON, discard if unparseable
try: try:
return pickle.loads(data)
except Exception:
return json.loads(data.decode('utf-8')) return json.loads(data.decode('utf-8'))
except Exception:
logger.warning("Discarding unrecognised legacy cache entry (will be recalculated)")
return None
from typing import Any, Optional, Dict, List from typing import Any, Optional, Dict, List
from pathlib import Path from pathlib import Path
import time import time
......
...@@ -222,7 +222,7 @@ class DatabaseManager: ...@@ -222,7 +222,7 @@ class DatabaseManager:
async def __aenter__(self): async def __aenter__(self):
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
pass return False # never suppress exceptions
return TransactionContext() return TransactionContext()
...@@ -416,24 +416,11 @@ class DatabaseManager: ...@@ -416,24 +416,11 @@ class DatabaseManager:
completion_tokens: Optional number of output/completion tokens completion_tokens: Optional number of output/completion tokens
actual_cost: Optional actual cost returned by provider (in USD) actual_cost: Optional actual cost returned by provider (in USD)
""" """
logger.info(f"💾 DB.record_token_usage ENTERED: provider={provider_id}, tokens={tokens_used}, user_id={user_id}") logger.debug(f"DB.record_token_usage: provider={provider_id}, tokens={tokens_used}, user_id={user_id}")
try: try:
# Convert latency to int for storage # Convert latency to int for storage
latency_int = int(latency_ms) if latency_ms else 0 latency_int = int(latency_ms) if latency_ms else 0
logger.info(f"🔍 DB.record_token_usage FULL PARAMETERS:") logger.debug(f"DB.record_token_usage params: provider={provider_id}, model={model_name}, tokens={tokens_used}, user={user_id}, success={success}")
logger.info(f" provider_id: {provider_id}")
logger.info(f" model_name: {model_name}")
logger.info(f" tokens_used: {tokens_used}")
logger.info(f" user_id: {user_id}")
logger.info(f" success: {success}")
logger.info(f" latency_ms: {latency_ms} → latency_int: {latency_int}")
logger.info(f" error_type: {error_type}")
logger.info(f" token_id: {token_id}")
logger.info(f" prompt_tokens: {prompt_tokens}")
logger.info(f" completion_tokens: {completion_tokens}")
logger.info(f" actual_cost: {actual_cost}")
logger.info(f" db_type: {self.db_type}")
logger.info(f"DB.record_token_usage: About to execute SQL - provider={provider_id}, tokens={tokens_used}, success={success}")
with self._get_connection() as conn: with self._get_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
...@@ -451,31 +438,26 @@ class DatabaseManager: ...@@ -451,31 +438,26 @@ class DatabaseManager:
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, CURRENT_TIMESTAMP) VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, CURRENT_TIMESTAMP)
''' '''
params = (user_id, provider_id, model_name, tokens_used, prompt_tokens, completion_tokens, actual_cost, success, latency_int, error_type, token_id, rotation_id, autoselect_id) params = (user_id, provider_id, model_name, tokens_used, prompt_tokens, completion_tokens, actual_cost, success, latency_int, error_type, token_id, rotation_id, autoselect_id)
logger.info(f"🔍 Trying full INSERT with {len(params)} parameters") logger.debug(f"Trying full INSERT with {len(params)} parameters")
logger.debug(f"🔍 SQL: {sql}")
logger.debug(f"🔍 Params: {params}")
cursor.execute(sql, params) cursor.execute(sql, params)
logger.info(f"✅ Inserted with full column set, rows affected: {cursor.rowcount}") logger.debug(f"Inserted with full column set, rows affected: {cursor.rowcount}")
except Exception as full_insert_error: except Exception as full_insert_error:
logger.warning(f"⚠️ Full column insert failed: {full_insert_error}") logger.warning(f"⚠️ Full column insert failed: {full_insert_error}")
logger.warning(f"⚠️ Full insert error type: {type(full_insert_error).__name__}") logger.warning(f"⚠️ Full insert error type: {type(full_insert_error).__name__}")
import traceback import traceback
logger.warning(f"⚠️ Full insert traceback: {traceback.format_exc()}") logger.warning(f"⚠️ Full insert traceback: {traceback.format_exc()}")
logger.info(f"🔍 Falling back to basic insert") logger.debug("Falling back to basic insert")
# Fallback to basic columns only # Fallback to basic columns only
sql = f''' sql = f'''
INSERT INTO token_usage (user_id, provider_id, model_name, tokens_used, timestamp) INSERT INTO token_usage (user_id, provider_id, model_name, tokens_used, timestamp)
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, CURRENT_TIMESTAMP) VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, CURRENT_TIMESTAMP)
''' '''
params = (user_id, provider_id, model_name, tokens_used) params = (user_id, provider_id, model_name, tokens_used)
logger.info(f"🔍 Trying basic INSERT with {len(params)} parameters")
logger.debug(f"🔍 SQL: {sql}")
logger.debug(f"🔍 Params: {params}")
cursor.execute(sql, params) cursor.execute(sql, params)
logger.info(f"✅ Inserted with basic column set, rows affected: {cursor.rowcount}") logger.debug(f"Inserted with basic column set, rows affected: {cursor.rowcount}")
conn.commit() conn.commit()
logger.info(f"✅ Successfully recorded token usage for {provider_id}/{model_name}: {tokens_used} tokens (user_id={user_id})") logger.info(f"Recorded token usage: {provider_id}/{model_name} {tokens_used} tokens (user_id={user_id})")
except Exception as e: except Exception as e:
logger.error(f"❌ Failed to record token usage for {provider_id}/{model_name}: {e}") logger.error(f"❌ Failed to record token usage for {provider_id}/{model_name}: {e}")
logger.error(f"Error details - user_id={user_id}, tokens={tokens_used}, success={success}") logger.error(f"Error details - user_id={user_id}, tokens={tokens_used}, success={success}")
...@@ -485,7 +467,7 @@ class DatabaseManager: ...@@ -485,7 +467,7 @@ class DatabaseManager:
test_cursor = test_conn.cursor() test_cursor = test_conn.cursor()
test_cursor.execute("INSERT INTO token_usage (provider_id, model_name, tokens_used, success) VALUES (?, 'test', 1, 1)" if self.db_type == 'sqlite' else "INSERT INTO token_usage (provider_id, model_name, tokens_used, success) VALUES (%s, 'test', 1, 1)", (f"test-{provider_id}",)) test_cursor.execute("INSERT INTO token_usage (provider_id, model_name, tokens_used, success) VALUES (?, 'test', 1, 1)" if self.db_type == 'sqlite' else "INSERT INTO token_usage (provider_id, model_name, tokens_used, success) VALUES (%s, 'test', 1, 1)", (f"test-{provider_id}",))
test_conn.commit() test_conn.commit()
logger.info("✅ Test database insert succeeded") logger.debug("Test database insert succeeded")
except Exception as test_e: except Exception as test_e:
logger.error(f"❌ Even test database insert failed: {test_e}") logger.error(f"❌ Even test database insert failed: {test_e}")
raise raise
......
...@@ -540,7 +540,6 @@ class RequestHandler: ...@@ -540,7 +540,6 @@ class RequestHandler:
# Apply rate limiting # Apply rate limiting
logger.info("Applying rate limiting...") logger.info("Applying rate limiting...")
await handler.apply_rate_limit() await handler.apply_rate_limit()
await handler.apply_rate_limit()
logger.info("Rate limiting applied") logger.info("Rate limiting applied")
logger.info(f"Sending request to provider handler...") logger.info(f"Sending request to provider handler...")
...@@ -729,7 +728,14 @@ class RequestHandler: ...@@ -729,7 +728,14 @@ class RequestHandler:
else: else:
provider_config = self.config.get_provider(provider_id) provider_config = self.config.get_provider(provider_id)
if provider_config.api_key_required: if isinstance(provider_config, dict):
api_key_required = provider_config.get('api_key_required', False)
_provider_type = provider_config.get('type', '')
else:
api_key_required = provider_config.api_key_required
_provider_type = provider_config.type
if api_key_required:
api_key = request_data.get('api_key') or request.headers.get('Authorization', '').replace('Bearer ', '') api_key = request_data.get('api_key') or request.headers.get('Authorization', '').replace('Bearer ', '')
if not api_key: if not api_key:
raise HTTPException(status_code=401, detail="API key required") raise HTTPException(status_code=401, detail="API key required")
...@@ -745,11 +751,11 @@ class RequestHandler: ...@@ -745,11 +751,11 @@ class RequestHandler:
# If seed is present in request, generate unique fingerprint per request # If seed is present in request, generate unique fingerprint per request
seed = request_data.get('seed') seed = request_data.get('seed')
system_fingerprint = generate_system_fingerprint(provider_id, seed) system_fingerprint = generate_system_fingerprint(provider_id, seed)
# Get context configuration and calculate effective context # Get context configuration and calculate effective context
model = request_data.get('model') model = request_data.get('model')
messages = request_data.get('messages', []) messages = request_data.get('messages', [])
context_config = get_context_config_for_model( context_config = get_context_config_for_model(
model_name=model, model_name=model,
provider_config=provider_config, provider_config=provider_config,
...@@ -807,12 +813,12 @@ class RequestHandler: ...@@ -807,12 +813,12 @@ class RequestHandler:
# Check if this is a Google streaming response by checking provider type from config # Check if this is a Google streaming response by checking provider type from config
# This is more reliable than checking response iterability which can cause false positives # This is more reliable than checking response iterability which can cause false positives
is_google_stream = provider_config.type == 'google' is_google_stream = _provider_type == 'google'
is_kiro_stream = provider_config.type == 'kiro' is_kiro_stream = _provider_type == 'kiro'
is_kilo_stream = provider_config.type in ('kilo', 'kilocode') is_kilo_stream = _provider_type in ('kilo', 'kilocode')
logger.info(f"Is Google streaming response: {is_google_stream} (provider type: {provider_config.type})") logger.info(f"Is Google streaming response: {is_google_stream} (provider type: {_provider_type})")
logger.info(f"Is Kiro streaming response: {is_kiro_stream} (provider type: {provider_config.type})") logger.info(f"Is Kiro streaming response: {is_kiro_stream} (provider type: {_provider_type})")
logger.info(f"Is Kilo streaming response: {is_kilo_stream} (provider type: {provider_config.type})") logger.info(f"Is Kilo streaming response: {is_kilo_stream} (provider type: {_provider_type})")
if is_kilo_stream: if is_kilo_stream:
# Handle Kilo/KiloCode streaming response # Handle Kilo/KiloCode streaming response
...@@ -1797,6 +1803,56 @@ class RequestHandler: ...@@ -1797,6 +1803,56 @@ class RequestHandler:
return capabilities return capabilities
async def handle_generic_proxy(self, request: Request, provider_id: str, endpoint_path: str, body: dict, method: str = "POST") -> JSONResponse:
"""Forward a request to the provider's native endpoint and return the response."""
import httpx
import logging
logger = logging.getLogger(__name__)
# Support user-defined providers (dict format) and global providers (object format)
if self.user_id and provider_id in self.user_providers:
provider_config = self.user_providers[provider_id]
base_url = (provider_config.get('endpoint') or '').rstrip('/')
api_key_required = provider_config.get('api_key_required', False)
config_api_key = provider_config.get('api_key')
else:
provider_config = self.config.get_provider(provider_id)
base_url = (getattr(provider_config, 'endpoint', '') or '').rstrip('/')
api_key_required = getattr(provider_config, 'api_key_required', False)
config_api_key = getattr(provider_config, 'api_key', None)
# Strip trailing /chat/completions or /completions to get the real base
for suffix in ['/chat/completions', '/completions']:
if base_url.endswith(suffix):
base_url = base_url[:-len(suffix)]
break
url = f"{base_url}/{endpoint_path.lstrip('/')}"
headers = {'Content-Type': 'application/json'}
if api_key_required:
api_key = request.headers.get('Authorization', '').replace('Bearer ', '') or config_api_key
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
logger.info(f"Generic proxy [{method}]: {provider_id} -> {url}")
try:
async with httpx.AsyncClient(timeout=300) as client:
if method == "GET":
resp = await client.get(url, headers=headers)
elif method == "DELETE":
resp = await client.delete(url, headers=headers)
else:
resp = await client.post(url, json=body, headers=headers)
try:
content = resp.json()
except Exception:
content = {"detail": resp.text}
return JSONResponse(status_code=resp.status_code, content=content)
except Exception as e:
logger.error(f"Generic proxy error: {e}", exc_info=True)
raise HTTPException(status_code=502, detail=str(e))
async def handle_audio_transcription(self, request: Request, provider_id: str, form_data) -> Dict: async def handle_audio_transcription(self, request: Request, provider_id: str, form_data) -> Dict:
"""Handle audio transcription requests""" """Handle audio transcription requests"""
import logging import logging
......
...@@ -429,9 +429,27 @@ class PayPalPaymentHandler: ...@@ -429,9 +429,27 @@ class PayPalPaymentHandler:
'Wallet top up via PayPal') 'Wallet top up via PayPal')
async def _handle_order_approved(self, resource: dict): async def _handle_order_approved(self, resource: dict):
"""Handle approved order (capture pending).""" """Handle approved order — record pending capture state."""
order_id = resource.get('id') order_id = resource.get('id')
logger.info(f"PayPal order approved: {order_id}") logger.info(f"PayPal order approved (awaiting capture): {order_id}")
try:
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(f"""
INSERT OR IGNORE INTO payment_transactions
(gateway, gateway_transaction_id, status, created_at)
VALUES ({placeholder}, {placeholder}, 'pending_capture', CURRENT_TIMESTAMP)
ON CONFLICT(gateway_transaction_id) DO UPDATE SET status='pending_capture'
""", ('paypal', order_id)) if self.db.db_type == 'sqlite' else cursor.execute(f"""
INSERT INTO payment_transactions
(gateway, gateway_transaction_id, status, created_at)
VALUES ({placeholder}, {placeholder}, 'pending_capture', CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE status='pending_capture'
""", ('paypal', order_id))
conn.commit()
except Exception as e:
logger.warning(f"PayPal: could not record approved order {order_id}: {e}")
async def _handle_payment_capture_completed(self, resource: dict): async def _handle_payment_capture_completed(self, resource: dict):
"""Handle completed payment capture — credit wallet.""" """Handle completed payment capture — credit wallet."""
...@@ -501,9 +519,24 @@ class PayPalPaymentHandler: ...@@ -501,9 +519,24 @@ class PayPalPaymentHandler:
logger.warning(f"PayPal refund: cannot apply refund {refund_id} — missing user_id/amount") logger.warning(f"PayPal refund: cannot apply refund {refund_id} — missing user_id/amount")
async def _handle_vault_token_created(self, resource: dict): async def _handle_vault_token_created(self, resource: dict):
"""Handle vault token creation.""" """Handle vault token creation — store as a payment method."""
token_id = resource.get('id') token_id = resource.get('id')
logger.info(f"PayPal vault token created: {token_id}") logger.info(f"PayPal vault token created: {token_id}")
customer = resource.get('customer', {})
merchant_customer_id = customer.get('merchant_customer_id') or resource.get('metadata', {}).get('merchant_customer_id')
if not (token_id and merchant_customer_id):
logger.warning(f"PayPal vault token {token_id}: missing merchant_customer_id, skipping save")
return
try:
user_id = int(merchant_customer_id)
except (ValueError, TypeError):
logger.warning(f"PayPal vault token {token_id}: invalid merchant_customer_id {merchant_customer_id!r}")
return
try:
self.db.add_payment_method(user_id, 'paypal', token_id, is_default=False, metadata={'paypal_vault_token': token_id})
logger.info(f"Stored PayPal vault token {token_id} as payment method for user {user_id}")
except Exception as e:
logger.error(f"PayPal: failed to store vault token {token_id} for user {user_id}: {e}")
async def _handle_vault_token_deleted(self, resource: dict): async def _handle_vault_token_deleted(self, resource: dict):
"""Handle vault token deletion — deactivate matching payment method.""" """Handle vault token deletion — deactivate matching payment method."""
...@@ -590,41 +623,36 @@ class PayPalPaymentHandler: ...@@ -590,41 +623,36 @@ class PayPalPaymentHandler:
logger.error(f"Error creating PayPal top up order: {e}") logger.error(f"Error creating PayPal top up order: {e}")
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
async def _handle_order_completed(self, resource: dict):
"""Handle completed order (Vault v3)"""
order_id = resource.get('id')
logger.info(f"PayPal order completed: {order_id}")
# Check if this is a top up order
purchase_units = resource.get('purchase_units', [])
if purchase_units and 'Wallet top up' in purchase_units[0].get('description', ''):
amount = Decimal(purchase_units[0]['amount']['value'])
user_id = int(resource.get('custom_id', 0))
if user_id > 0:
from aisbf.payments.wallet.manager import WalletManager
from sqlalchemy.ext.asyncio import AsyncSession
async with AsyncSession(self.db.engine) as session:
wallet_manager = WalletManager(session)
await wallet_manager.credit_wallet(
user_id=user_id,
amount=amount,
transaction_details={
'payment_gateway': 'paypal',
'gateway_transaction_id': order_id,
'description': 'Wallet top up via PayPal',
'metadata': {'order_id': order_id}
}
)
await session.commit()
logger.info(f"Wallet credited successfully for user {user_id}, amount {amount}")
async def _handle_payment_completed(self, resource: dict): async def _handle_payment_completed(self, resource: dict):
"""Handle completed payment (legacy)""" """Handle completed payment (legacy PAYMENT.SALE.COMPLETED) — credit wallet if applicable."""
logger.info(f"PayPal payment completed: {resource.get('id')}") payment_id = resource.get('id')
logger.info(f"PayPal payment completed: {payment_id}")
custom_id = resource.get('custom', '') or resource.get('custom_id', '')
amount_obj = resource.get('amount', {})
try:
amount = Decimal(amount_obj.get('total', amount_obj.get('value', '0')))
user_id = int(custom_id) if custom_id else 0
except (ValueError, TypeError):
user_id = 0
if user_id > 0 and amount > 0:
await self._credit_wallet_for_paypal(user_id, amount, payment_id, 'Payment via PayPal')
else:
logger.debug(f"PayPal PAYMENT.SALE.COMPLETED {payment_id}: no user_id/amount to credit")
async def _handle_payment_denied(self, resource: dict): async def _handle_payment_denied(self, resource: dict):
"""Handle denied payment (legacy)""" """Handle denied payment (legacy PAYMENT.SALE.DENIED) — queue for retry."""
logger.warning(f"PayPal payment denied: {resource.get('id')}") payment_id = resource.get('id')
logger.warning(f"PayPal payment denied: {payment_id}")
try:
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(f"""
INSERT INTO payment_retry_queue
(gateway, gateway_transaction_id, status, next_retry_at, created_at)
VALUES ({placeholder}, {placeholder}, 'pending',
CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""", ('paypal', payment_id))
conn.commit()
except Exception as e:
logger.error(f"PayPal: failed to queue denied payment {payment_id} for retry: {e}")
...@@ -319,5 +319,21 @@ class StripePaymentHandler: ...@@ -319,5 +319,21 @@ class StripePaymentHandler:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
async def _handle_payment_failed(self, payment_intent: dict): async def _handle_payment_failed(self, payment_intent: dict):
"""Handle failed payment""" """Handle failed payment — queue for retry and log the failure reason."""
logger.warning(f"Payment failed: {payment_intent['id']}") intent_id = payment_intent.get('id', '')
error = payment_intent.get('last_payment_error', {}) or {}
reason = error.get('message', 'unknown')
logger.warning(f"Stripe payment failed: {intent_id} — {reason}")
try:
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(f"""
INSERT INTO payment_retry_queue
(gateway, gateway_transaction_id, status, next_retry_at, created_at)
VALUES ({placeholder}, {placeholder}, 'pending',
CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""", ('stripe', intent_id))
conn.commit()
except Exception as e:
logger.error(f"Stripe: failed to queue failed payment {intent_id} for retry: {e}")
...@@ -88,19 +88,45 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -88,19 +88,45 @@ class CodexProviderHandler(BaseProviderHandler):
) )
# Determine mode: API key mode or OAuth2 mode # Determine mode: API key mode or OAuth2 mode
# Treat empty strings and placeholder values as "no key"
def _is_real_key(k):
return bool(k) and str(k).strip() not in ('', 'placeholder', 'YOUR_API_KEY', 'none', 'null')
_cfg_api_key = (provider_config.get('api_key') if isinstance(provider_config, dict) _cfg_api_key = (provider_config.get('api_key') if isinstance(provider_config, dict)
else getattr(provider_config, 'api_key', None)) if provider_config else None else getattr(provider_config, 'api_key', None)) if provider_config else None
self._use_api_key_mode = bool(api_key or _cfg_api_key) self._use_api_key_mode = _is_real_key(api_key) or _is_real_key(_cfg_api_key)
self._account_id = None # Will be extracted from ID token in OAuth2 mode self._account_id = None # Will be extracted from ID token in OAuth2 mode
# Base URL for API requests # Base URL for API requests
_endpoint = (provider_config.get('endpoint') if isinstance(provider_config, dict) _endpoint = (provider_config.get('endpoint') if isinstance(provider_config, dict)
else getattr(provider_config, 'endpoint', None)) if provider_config else None else getattr(provider_config, 'endpoint', None)) if provider_config else None
self.base_url = (_endpoint or 'https://chatgpt.com/backend-api').rstrip('/')
CHATGPT_BACKEND = 'https://chatgpt.com/backend-api'
OPENAI_API = 'https://api.openai.com/v1'
def _is_chatgpt_backend(url: str) -> bool:
return url.rstrip('/').startswith(CHATGPT_BACKEND.rstrip('/'))
if self._use_api_key_mode:
# In API key mode, use OpenAI API for any chatgpt.com/backend-api URL
# (including subpaths like /codex) — the ChatGPT backend does not support
# the standard OpenAI /chat/completions format.
if _endpoint and not _is_chatgpt_backend(_endpoint):
self.base_url = _endpoint.rstrip('/')
else:
self.base_url = OPENAI_API
else:
# In OAuth2 mode, always use the bare ChatGPT backend base URL.
# Any /codex or other suffix in the configured endpoint is stripped here;
# the specific API path (/codex/responses) is appended later at call time.
if _endpoint and not _is_chatgpt_backend(_endpoint):
self.base_url = _endpoint.rstrip('/')
else:
self.base_url = CHATGPT_BACKEND
# Initialize OpenAI client for API key mode # Initialize OpenAI client for API key mode
if self._use_api_key_mode: if self._use_api_key_mode:
effective_key = api_key or _cfg_api_key effective_key = (api_key if _is_real_key(api_key) else None) or (_cfg_api_key if _is_real_key(_cfg_api_key) else None)
self.client = OpenAI(api_key=effective_key, base_url=self.base_url) self.client = OpenAI(api_key=effective_key, base_url=self.base_url)
else: else:
self.client = None self.client = None
......
This diff is collapsed.
from fastapi import APIRouter, Request, Form, Query, UploadFile, File, HTTPException from fastapi import APIRouter, Request, Form, Query, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, RedirectResponse, HTMLResponse, Response from fastapi.responses import JSONResponse, RedirectResponse, HTMLResponse, Response
from typing import Optional from typing import Optional
import time, logging, secrets, hashlib, os, re import time, logging, secrets, hashlib, os, re, hmac
from pathlib import Path from pathlib import Path
from datetime import datetime, timedelta from datetime import datetime, timedelta
from aisbf.database import DatabaseRegistry from aisbf.database import DatabaseRegistry
...@@ -753,7 +753,7 @@ async def dashboard_change_password_save(request: Request, current_password: str ...@@ -753,7 +753,7 @@ async def dashboard_change_password_save(request: Request, current_password: str
try: try:
if not db.verify_user_password(user_id, current_password): if not db.verify_user_password(user_id, current_password):
return RedirectResponse(url=url_for(request, "/dashboard/change-password?error=Current password is incorrect"), status_code=303) return RedirectResponse(url=url_for(request, "/dashboard/change-password?error=Current password is incorrect"), status_code=303)
db.update_user_password(user_id, new_password) db.update_user_password(user_id, _db_hash_password(new_password))
return RedirectResponse(url=url_for(request, "/dashboard/change-password?success=Password changed successfully"), status_code=303) return RedirectResponse(url=url_for(request, "/dashboard/change-password?success=Password changed successfully"), status_code=303)
except Exception as e: except Exception as e:
return RedirectResponse(url=url_for(request, f"/dashboard/change-password?error=Failed to change password: {str(e)}"), status_code=303) return RedirectResponse(url=url_for(request, f"/dashboard/change-password?error=Failed to change password: {str(e)}"), status_code=303)
...@@ -961,7 +961,7 @@ async def oauth2_google_callback(request: Request, code: str = Query(...), state ...@@ -961,7 +961,7 @@ async def oauth2_google_callback(request: Request, code: str = Query(...), state
redirect_uri = f"{base_url}/auth/oauth2/google/callback" redirect_uri = f"{base_url}/auth/oauth2/google/callback"
session_state = request.session.get('oauth2_google', {}).get('state') session_state = request.session.get('oauth2_google', {}).get('state')
if state != session_state: if not hmac.compare_digest(state, session_state or ''):
return _templates.TemplateResponse(request=request, name="dashboard/login.html", return _templates.TemplateResponse(request=request, name="dashboard/login.html",
context={"request": request, "config": _config, "error": "Invalid authentication state"}) context={"request": request, "config": _config, "error": "Invalid authentication state"})
...@@ -1093,7 +1093,7 @@ async def oauth2_github_callback(request: Request, code: str = Query(...), state ...@@ -1093,7 +1093,7 @@ async def oauth2_github_callback(request: Request, code: str = Query(...), state
redirect_uri = f"{base_url}/auth/oauth2/github/callback" redirect_uri = f"{base_url}/auth/oauth2/github/callback"
session_state = request.session.get('oauth2_github', {}).get('state') session_state = request.session.get('oauth2_github', {}).get('state')
if state != session_state: if not hmac.compare_digest(state, session_state or ''):
return _templates.TemplateResponse(request=request, name="dashboard/login.html", return _templates.TemplateResponse(request=request, name="dashboard/login.html",
context={"request": request, "config": _config, "error": "Invalid authentication state"}) context={"request": request, "config": _config, "error": "Invalid authentication state"})
......
...@@ -20,13 +20,20 @@ except ImportError: ...@@ -20,13 +20,20 @@ except ImportError:
router = APIRouter() router = APIRouter()
_config = None _config = None
_templates = None _templates = None
_payment_service = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def init(config, templates): def init(config, templates, payment_service=None):
global _config, _templates global _config, _templates, _payment_service
_config = config _config = config
_templates = templates _templates = templates
_payment_service = payment_service
def set_payment_service(service):
global _payment_service
_payment_service = service
@router.get("/dashboard/billing/add-method", response_class=HTMLResponse) @router.get("/dashboard/billing/add-method", response_class=HTMLResponse)
...@@ -99,8 +106,8 @@ async def dashboard_add_payment_method_stripe(request: Request): ...@@ -99,8 +106,8 @@ async def dashboard_add_payment_method_stripe(request: Request):
try: try:
# Attach the PM to the Stripe customer so it can be charged later # Attach the PM to the Stripe customer so it can be charged later
if payment_service and payment_service.stripe_handler: if _payment_service and _payment_service.stripe_handler:
customer_id = await payment_service.stripe_handler._get_or_create_customer(user_id) customer_id = await _payment_service.stripe_handler._get_or_create_customer(user_id)
import stripe as _stripe import stripe as _stripe
import asyncio as _asyncio import asyncio as _asyncio
try: try:
......
This diff is collapsed.
...@@ -137,8 +137,8 @@ else ...@@ -137,8 +137,8 @@ else
fi fi
# Remove _share directory (PyPI packaging artifacts) # Remove _share directory (PyPI packaging artifacts)
if [ -d "_share" ]; then if [ -d "aisbf/_share" ]; then
echo "Removing _share/ directory..." echo "Removing aisbf/_share/ directory..."
rm -rf _share rm -rf _share
echo " ✓ _share/ removed" echo " ✓ _share/ removed"
else else
......
...@@ -254,15 +254,15 @@ function buildProviderSelectHtml(uid, currentValue, onChangeFn) { ...@@ -254,15 +254,15 @@ function buildProviderSelectHtml(uid, currentValue, onChangeFn) {
const opts = availableProviders.map(p => const opts = availableProviders.map(p =>
`<option value="${escHtmlAttr(p)}" ${currentValue === p ? 'selected' : ''}>${escHtmlAttr(p)}</option>` `<option value="${escHtmlAttr(p)}" ${currentValue === p ? 'selected' : ''}>${escHtmlAttr(p)}</option>`
).join(''); ).join('');
return `<select id="${uid}" onchange="${onChangeFn}(this.value)" required> return `<select id="${uid}" onchange="(${onChangeFn})(this.value)" required>
<option value="">${window.i18n.t('rotations.select_provider')}</option>${opts}</select>`; <option value="">${window.i18n.t('rotations.select_provider')}</option>${opts}</select>`;
} else { } else {
const dlOpts = availableProviders.map(p => `<option value="${escHtmlAttr(p)}">`).join(''); const dlOpts = availableProviders.map(p => `<option value="${escHtmlAttr(p)}">`).join('');
return `<div style="position:relative;"> return `<div style="position:relative;">
<input type="text" id="${uid}" value="${escHtmlAttr(currentValue)}" list="${uid}-dl" <input type="text" id="${uid}" value="${escHtmlAttr(currentValue)}" list="${uid}-dl"
placeholder="${window.i18n.t('rotations.type_search_provider')}" placeholder="${window.i18n.t('rotations.type_search_provider')}"
oninput="handleProviderInput('${uid}', this.value, ${onChangeFn})" oninput="handleProviderInput('${uid}', this.value, (${onChangeFn}))"
onchange="handleProviderInput('${uid}', this.value, ${onChangeFn})" onchange="handleProviderInput('${uid}', this.value, (${onChangeFn}))"
style="width:100%;padding:8px;border:1px solid var(--color-border);border-radius:3px;background:var(--bg-page);color:var(--color-text);font-size:14px;"> style="width:100%;padding:8px;border:1px solid var(--color-border);border-radius:3px;background:var(--bg-page);color:var(--color-text);font-size:14px;">
<datalist id="${uid}-dl">${dlOpts}</datalist> <datalist id="${uid}-dl">${dlOpts}</datalist>
</div>`; </div>`;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment