Code cleanup and fixes

parent bf4e59fe
......@@ -50,6 +50,40 @@ def get_db_executor():
return _db_executor
class _MySQLConnectionWrapper:
"""Wrapper that gives mysql.connector connections a reliable context manager protocol.
mysql-connector-python's C extension (__enter__/__exit__) has version-dependent
behaviour (some versions return a cursor from __enter__, others close the connection
in __exit__ unexpectedly). This wrapper always yields the raw connection and
handles commit/rollback/close explicitly.
"""
def __init__(self, conn):
self._conn = conn
def __enter__(self):
return self._conn
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
self._conn.commit()
else:
try:
self._conn.rollback()
except Exception:
pass
# Connection is intentionally left open: cursor and conn variables in the
# calling function remain valid after the with-block exits (matching SQLite's
# context-manager behaviour). The connection is closed by GC when the caller
# function returns and conn goes out of scope.
return False
# Forward attribute access so the wrapper can be used directly as well
def __getattr__(self, name):
return getattr(self._conn, name)
class DatabaseManager:
"""
Manages database for persistent tracking of context dimensions and rate limiting.
......@@ -103,9 +137,9 @@ class DatabaseManager:
port=self.db_config['mysql_port'],
user=self.db_config['mysql_user'],
password=self.db_config['mysql_password'],
database=self.db_config['mysql_database']
database=self.db_config['mysql_database'],
)
return conn
return _MySQLConnectionWrapper(conn)
except Exception as e:
logger.error(f"MySQL connection failed: {e}")
raise
......@@ -119,27 +153,20 @@ class DatabaseManager:
async def execute(self, sql: str, params: dict = None):
"""Execute SQL query and return result with mappings (compatible with AsyncSession interface)"""
_params = params or {}
def _sync_execute():
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.row_factory = sqlite3.Row
params = params or {}
# Safe parameter handling - use native database parameter binding
if self.db_type == 'sqlite':
# SQLite natively supports :named parameters directly
cursor.execute(sql, params)
cursor = conn.cursor()
cursor.row_factory = sqlite3.Row
cursor.execute(sql, _params)
else:
# For MySQL, safely convert named parameters to %s placeholders
param_names = []
def replace_param(match):
param_names.append(match.group(1))
return '%s'
cursor = conn.cursor(dictionary=True)
import re
processed_sql = re.sub(r':(\w+)', replace_param, sql)
params_list = [params[name] for name in param_names]
cursor.execute(processed_sql, params_list)
param_names = []
processed_sql = re.sub(r':(\w+)', lambda m: (param_names.append(m.group(1)), '%s')[1], sql)
cursor.execute(processed_sql, [_params[n] for n in param_names])
if cursor.description:
rows = [dict(row) for row in cursor.fetchall()]
# Simulate SQLAlchemy Result object with mappings() method
......@@ -3328,6 +3355,7 @@ def DatabaseManager__init__(self, db_config: Optional[Dict[str, Any]] = None, da
self.db_config = db_config
self.db_type = self.db_config.get('type', 'sqlite').lower()
self.executor = get_db_executor()
if self.db_type == 'mysql':
# Import the module-level MYSQL_AVAILABLE flag
......@@ -3542,13 +3570,18 @@ def DatabaseManager__initialize_database(self):
# ''')
#
# try:
# cursor.execute('''
# CREATE INDEX IF NOT EXISTS idx_model_embeddings_provider_model
# ON model_embeddings(provider_id, model_name)
# ''')
# except:
# pass
# # Index creation moved to separate migration
#
# Create admin settings table for system configuration
cursor.execute(f'''
CREATE TABLE IF NOT EXISTS admin_settings (
id INTEGER PRIMARY KEY {auto_increment},
setting_key VARCHAR(255) UNIQUE NOT NULL,
setting_value TEXT,
updated_at TIMESTAMP DEFAULT {timestamp_default}
)
''')
# Create users table for multi-user management
cursor.execute(f'''
CREATE TABLE IF NOT EXISTS users (
......@@ -3997,9 +4030,8 @@ def DatabaseManager__initialize_database(self):
# Run configuration database migrations if this is a CONFIG database
if self.database_type == DatabaseRegistry.TYPE_CONFIG:
self._run_config_migrations(cursor, auto_increment, timestamp_default, boolean_type)
conn.commit()
logger.info(f"Database tables initialized successfully for {self.database_type} database")
conn.commit()
logger.info(f"Database tables initialized successfully for {self.database_type} database")
def DatabaseManager__create_config_tables(self, cursor, auto_increment, timestamp_default, boolean_type):
......@@ -4176,6 +4208,7 @@ def DatabaseManager__run_config_migrations(self, cursor, auto_increment, timesta
max_autoselections INTEGER DEFAULT -1,
max_rotation_models INTEGER DEFAULT -1,
max_autoselection_models INTEGER DEFAULT -1,
is_visible {boolean_type} DEFAULT 1,
created_at TIMESTAMP DEFAULT {timestamp_default},
updated_at TIMESTAMP DEFAULT {timestamp_default}
)
......@@ -4204,6 +4237,7 @@ def DatabaseManager__run_config_migrations(self, cursor, auto_increment, timesta
max_autoselections INTEGER DEFAULT -1,
max_rotation_models INTEGER DEFAULT -1,
max_autoselection_models INTEGER DEFAULT -1,
is_visible {boolean_type} DEFAULT 1,
created_at TIMESTAMP DEFAULT {timestamp_default},
updated_at TIMESTAMP DEFAULT {timestamp_default}
)
......@@ -4227,7 +4261,8 @@ def DatabaseManager__run_config_migrations(self, cursor, auto_increment, timesta
('max_rotation_models', 'INTEGER DEFAULT -1'),
('max_autoselection_models', 'INTEGER DEFAULT -1'),
('is_default', f'{boolean_type} DEFAULT 0'),
('is_active', f'{boolean_type} DEFAULT 1')
('is_active', f'{boolean_type} DEFAULT 1'),
('is_visible', f'{boolean_type} DEFAULT 1')
]
col_count = 0
for col_name, col_def in tier_columns:
......
......@@ -494,6 +494,8 @@ app = FastAPI(
# Initialize Jinja2 templates with custom globals for proxy-aware URLs
templates = Jinja2Templates(directory="templates")
# Add root templates directory to search path for parent template resolution
templates.env.loader.searchpath.insert(0, "templates")
# Monkey patch TemplateResponse to automatically add dashboard context variables
original_template_response = templates.TemplateResponse
......@@ -8141,9 +8143,13 @@ async def dashboard_wallet(request: Request):
wallet_manager = WalletManager(db)
wallet = await wallet_manager.get_wallet(user_id)
all_gateways = db.get_payment_gateway_settings()
enabled_gateways = {k: v for k, v in all_gateways.items() if v.get('enabled', False)}
return templates.TemplateResponse("dashboard/wallet.html", {
"request": request,
"wallet": wallet
"wallet": wallet,
"enabled_gateways": enabled_gateways,
})
except ImportError:
return HTMLResponse("Wallet functionality not available", status_code=503)
......@@ -8154,33 +8160,165 @@ async def dashboard_wallet(request: Request):
"error": "Failed to load wallet. Please try again later."
}, status_code=500)
@app.post("/dashboard/wallet/topup")
async def dashboard_wallet_topup(request: Request):
"""Session-authenticated wallet top-up — supports all admin-enabled gateways."""
from fastapi.responses import JSONResponse
auth_check = require_dashboard_auth(request)
if auth_check:
return JSONResponse({"error": "Unauthorized"}, status_code=401)
user_id = request.session.get('user_id')
try:
body = await request.json()
except Exception:
return JSONResponse({"error": "Invalid request body"}, status_code=400)
method = (body.get('payment_method') or '').lower()
amount = body.get('amount')
try:
amount = float(amount)
except (TypeError, ValueError):
return JSONResponse({"error": "Invalid amount"}, status_code=400)
if amount < 5 or amount > 500:
return JSONResponse({"error": "Amount must be between $5 and $500"}, status_code=400)
db = DatabaseRegistry.get_config_database()
gateways = db.get_payment_gateway_settings()
gw = gateways.get(method, {})
if not gw.get('enabled', False):
return JSONResponse({"error": f"Payment method '{method}' is not enabled"}, status_code=400)
# Crypto: return deposit address (manual transfer)
crypto_methods = {'bitcoin', 'ethereum', 'usdt', 'usdc'}
if method in crypto_methods:
address = gw.get('address', '')
if not address:
return JSONResponse({"error": "Crypto address not configured"}, status_code=503)
return JSONResponse({
"type": "crypto",
"method": method,
"address": address,
"amount": amount,
"network": gw.get('network', ''),
"confirmations": gw.get('confirmations', 3),
})
# Stripe: create checkout session
if method == 'stripe':
try:
payment_service = getattr(request.app.state, 'payment_service', None)
if payment_service and hasattr(payment_service, 'stripe_handler'):
from decimal import Decimal
intent = await payment_service.stripe_handler.create_payment_intent(
user_id, Decimal(str(amount)), metadata={"type": "wallet_topup"}
)
return JSONResponse({"type": "stripe", "client_secret": intent.client_secret})
# Fallback: redirect to Stripe-hosted checkout via publishable key
import stripe
stripe.api_key = gw.get('secret_key', '')
session = stripe.checkout.Session.create(
payment_method_types=['card'],
line_items=[{
'price_data': {
'currency': 'usd',
'product_data': {'name': 'Wallet Top-Up'},
'unit_amount': int(amount * 100),
},
'quantity': 1,
}],
mode='payment',
success_url=str(request.base_url) + 'dashboard/wallet?topup=success',
cancel_url=str(request.base_url) + 'dashboard/wallet?topup=cancelled',
metadata={'type': 'wallet_topup', 'user_id': str(user_id)},
)
return JSONResponse({"type": "stripe", "checkout_url": session.url})
except Exception as e:
logger.error(f"Stripe top-up error: {e}")
return JSONResponse({"error": "Stripe checkout failed. Please try again."}, status_code=502)
# PayPal: create order
if method == 'paypal':
try:
payment_service = getattr(request.app.state, 'payment_service', None)
if payment_service and hasattr(payment_service, 'paypal_handler'):
from decimal import Decimal
order = await payment_service.paypal_handler.create_order(
user_id, Decimal(str(amount)), metadata={"type": "wallet_topup"}
)
return JSONResponse({"type": "paypal", "order_id": order.id})
# Fallback: direct PayPal redirect
client_id = gw.get('client_id', '')
sandbox = gw.get('sandbox', True)
paypal_base = "https://www.sandbox.paypal.com" if sandbox else "https://www.paypal.com"
return JSONResponse({
"type": "paypal",
"paypal_base": paypal_base,
"client_id": client_id,
"amount": amount,
})
except Exception as e:
logger.error(f"PayPal top-up error: {e}")
return JSONResponse({"error": "PayPal checkout failed. Please try again."}, status_code=502)
return JSONResponse({"error": f"Unsupported payment method: {method}"}, status_code=400)
@app.get("/dashboard/wallet/transactions")
async def dashboard_wallet_transactions(request: Request, limit: int = 50, offset: int = 0):
"""Session-authenticated wallet transaction history (used by the wallet dashboard page)."""
auth_check = require_dashboard_auth(request)
if auth_check:
from fastapi.responses import JSONResponse
return JSONResponse({"error": "Unauthorized"}, status_code=401)
user_id = request.session.get('user_id')
try:
from aisbf.payments.wallet.manager import WalletManager
db = DatabaseRegistry.get_config_database()
wallet_manager = WalletManager(db)
transactions = await wallet_manager.get_transactions(user_id, limit=limit, offset=offset)
return transactions
except Exception as e:
logger.error(f"Failed to load wallet transactions: {e}")
from fastapi.responses import JSONResponse
return JSONResponse({"error": "Failed to load transactions"}, status_code=500)
@app.get("/dashboard/billing")
async def dashboard_billing(request: Request):
"""User payment transaction history page"""
auth_check = require_dashboard_auth(request)
if auth_check:
return auth_check
db = DatabaseRegistry.get_config_database()
user_id = request.session.get('user_id')
# Get user payment methods
payment_methods = db.get_user_payment_methods(user_id)
# Get payment transactions
transactions = db.get_user_payment_transactions(user_id)
# Get enabled payment gateways
enabled_gateways = []
gateways = db.get_payment_gateway_settings()
for gateway, settings in gateways.items():
if settings.get('enabled', False):
enabled_gateways.append(gateway)
# Get user wallet
currency_settings = db.get_currency_settings()
currency_code = currency_settings.get('currency_code', 'EUR')
wallet = db.get_user_wallet(user_id) or {'balance': '0.00', 'currency_code': currency_code, 'auto_topup_enabled': False}
try:
from aisbf.payments.wallet.manager import WalletManager
wallet_manager = WalletManager(db)
wallet = await wallet_manager.get_wallet(user_id)
except Exception:
wallet = {'balance': '0.00', 'currency_code': currency_code, 'auto_topup_enabled': False}
return templates.TemplateResponse(
request=request,
......
......@@ -230,6 +230,8 @@ setup(
'templates/dashboard/add_payment_method.html',
'templates/dashboard/paypal_connect.html',
'templates/dashboard/cache_settings.html',
'templates/dashboard/wallet.html',
'templates/dashboard/error.html',
]),
# Install static files (extension and favicon)
('share/aisbf/static', [
......
{% extends "base.html" %}
{% block title %}Error{% endblock %}
{% block content %}
<div class="container mt-5">
<div class="row justify-content-center">
<div class="col-md-6">
<div class="card">
<div class="card-header bg-danger text-white">
<h5 class="mb-0">Error</h5>
</div>
<div class="card-body">
<div class="alert alert-danger">
{{ error }}
</div>
<div class="d-flex justify-content-between">
<a href="{{ request.url_for('dashboard_index') }}" class="btn btn-primary">
<i class="fas fa-home"></i> Go to Dashboard
</a>
<button onclick="history.back()" class="btn btn-outline-secondary">
<i class="fas fa-arrow-left"></i> Go Back
</button>
</div>
</div>
</div>
</div>
</div>
</div>
{% endblock %}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment