Commit 6cf44e43 authored by Your Name's avatar Your Name

feat: update Google OAuth to use display_name for username generation

parent b6e42536
......@@ -24,7 +24,7 @@ Main application for AISBF.
"""
from typing import Optional
from fastapi import FastAPI, HTTPException, Request, status, Form, Query, UploadFile, File
from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, RedirectResponse
from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, RedirectResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from fastapi.templating import Jinja2Templates
......@@ -1780,6 +1780,49 @@ async def dashboard_login_page(request: Request):
logger.error(f"Error rendering login page: {e}", exc_info=True)
raise
@app.get("/auth/logincheck")
async def auth_logincheck(request: Request):
"""Serve JavaScript that redirects to dashboard if user is logged in"""
# Check if user is logged in
is_logged_in = request.session.get('logged_in', False)
# Check if session has expired
if is_logged_in:
expires_at = request.session.get('expires_at')
if expires_at and int(time.time()) > expires_at:
is_logged_in = False
# Generate JavaScript response
if is_logged_in:
# Get the dashboard path (not full URL)
root_path = request.scope.get("root_path", "")
dashboard_path = f"{root_path}/dashboard"
js_content = f"""
(function() {{
// Redirect to dashboard if logged in
if (window.location.pathname !== '{dashboard_path}') {{
window.location.href = '{dashboard_path}';
}}
}})();
"""
else:
js_content = """
(function() {
// User not logged in, do nothing
})();
"""
return Response(
content=js_content,
media_type="application/javascript",
headers={
"Cache-Control": "no-cache, no-store, must-revalidate",
"Pragma": "no-cache",
"Expires": "0"
}
)
@app.post("/dashboard/login")
async def dashboard_login(request: Request, username: str = Form(...), password: str = Form(...), remember_me: bool = Form(False)):
"""Handle dashboard login"""
......@@ -2878,7 +2921,7 @@ async def oauth2_google_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "Invalid authentication state"}
context={"request": request, "config": config, "error": "Invalid authentication state"}
)
# Restore oauth instance
......@@ -2892,7 +2935,7 @@ async def oauth2_google_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "Failed to authenticate with Google"}
context={"request": request, "config": config, "error": "Failed to authenticate with Google"}
)
# Get user profile
......@@ -2901,11 +2944,13 @@ async def oauth2_google_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "Could not retrieve your profile from Google"}
context={"request": request, "config": config, "error": "Could not retrieve your profile from Google"}
)
email = user_info.get('email')
email_verified = user_info.get('email_verified', False)
# Get display name from Google
display_name = user_info.get('name', '')
db = DatabaseRegistry.get_config_database()
......@@ -2919,6 +2964,7 @@ async def oauth2_google_callback(request: Request, code: str = Query(...), state
request.session['email'] = existing_user.get('email', '')
request.session['role'] = existing_user['role']
request.session['user_id'] = existing_user['id']
request.session['email_verified'] = True # OAuth2 users have verified emails
request.session['expires_at'] = int(time.time()) + 14 * 24 * 60 * 60
else:
# New user - create account automatically (no password required)
......@@ -2926,23 +2972,27 @@ async def oauth2_google_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "Google email must be verified to create an account"}
context={"request": request, "config": config, "error": "Google email must be verified to create an account"}
)
# Generate secure random password for OAuth users (never used for login)
random_password = secrets.token_urlsafe(32)
password_hash = hashlib.sha256(random_password.encode()).hexdigest()
# Generate clean username from display_name with email fallback
google_username = db.generate_username_from_display_name(display_name, email)
final_username = db.find_unique_username(google_username)
# Create user with verified email (no verification required)
user_id = db.create_user(email, password_hash, None)
db.verify_email(email) # Mark email as verified automatically
user_id = db.create_user(final_username, password_hash, 'user', None, email, True, display_name)
# Login the new user
request.session['logged_in'] = True
request.session['username'] = email
request.session['username'] = final_username
request.session['email'] = email
request.session['role'] = 'user'
request.session['user_id'] = user_id
request.session['email_verified'] = True # OAuth2 users have verified emails
request.session['expires_at'] = int(time.time()) + 14 * 24 * 60 * 60
# Cleanup session data
......@@ -2956,7 +3006,7 @@ async def oauth2_google_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "An error occurred during email verification", "config": config.aisbf if config and config.aisbf else {}}
context={"request": request, "config": config, "error": "An error occurred during email verification"}
)
......@@ -3020,7 +3070,7 @@ async def oauth2_github_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "Invalid authentication state"}
context={"request": request, "config": config, "error": "Invalid authentication state"}
)
# Restore oauth instance
......@@ -3033,7 +3083,7 @@ async def oauth2_github_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "Failed to authenticate with GitHub"}
context={"request": request, "config": config, "error": "Failed to authenticate with GitHub"}
)
# Get user profile
......@@ -3042,10 +3092,12 @@ async def oauth2_github_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "Could not retrieve your profile from GitHub. Please ensure your email is public."}
context={"request": request, "config": config, "error": "Could not retrieve your profile from GitHub. Please ensure your email is public."}
)
email = user_info.get('email')
# Use GitHub username or name for the username field
github_username = user_info.get('login') or user_info.get('name') or email.split('@')[0]
db = DatabaseRegistry.get_config_database()
......@@ -3059,6 +3111,7 @@ async def oauth2_github_callback(request: Request, code: str = Query(...), state
request.session['email'] = existing_user.get('email', '')
request.session['role'] = existing_user['role']
request.session['user_id'] = existing_user['id']
request.session['email_verified'] = True # OAuth2 users have verified emails
request.session['expires_at'] = int(time.time()) + 14 * 24 * 60 * 60
else:
# New user - create account automatically (no password required)
......@@ -3066,16 +3119,23 @@ async def oauth2_github_callback(request: Request, code: str = Query(...), state
random_password = secrets.token_urlsafe(32)
password_hash = hashlib.sha256(random_password.encode()).hexdigest()
# Handle duplicate username by appending a number
final_username = github_username
counter = 1
while db.get_user_by_username(final_username):
final_username = f"{github_username}{counter}"
counter += 1
# Create user with verified email (no verification required)
user_id = db.create_user(email, password_hash, None)
db.verify_email(email) # Mark email as verified automatically
user_id = db.create_user(final_username, password_hash, 'user', None, email, True)
# Login the new user
request.session['logged_in'] = True
request.session['username'] = email
request.session['username'] = final_username
request.session['email'] = email
request.session['role'] = 'user'
request.session['user_id'] = user_id
request.session['email_verified'] = True # OAuth2 users have verified emails
request.session['expires_at'] = int(time.time()) + 14 * 24 * 60 * 60
# Cleanup session data
......@@ -3089,7 +3149,7 @@ async def oauth2_github_callback(request: Request, code: str = Query(...), state
return templates.TemplateResponse(
request=request,
name="dashboard/login.html",
context={"request": request, "error": "Authentication failed. Please try again."}
context={"request": request, "config": config, "error": "Authentication failed. Please try again."}
)
def require_dashboard_auth(request: Request):
......@@ -4312,6 +4372,7 @@ async def dashboard_settings_save(
aisbf_config['oauth2']['github']['client_secret'] = oauth2_github_client_secret
elif 'client_secret' not in aisbf_config['oauth2']['github']:
aisbf_config['oauth2']['github']['client_secret'] = ""
aisbf_config['oauth2']['github']['scopes'] = ["user:email", "read:user"]
# Save config
config_path = Path.home() / '.aisbf' / 'aisbf.json'
......@@ -4380,13 +4441,17 @@ async def dashboard_users(request: Request):
# Get all users
users = db.get_users()
# Get all tiers for assignment dropdown
tiers = db.get_all_tiers()
return templates.TemplateResponse(
request=request,
name="dashboard/users.html",
context={
"request": request,
"session": request.session,
"users": users
"users": users,
"tiers": tiers
}
)
......@@ -4489,6 +4554,38 @@ async def dashboard_users_delete(request: Request, user_id: int):
except Exception as e:
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
@app.post("/dashboard/users/{user_id}/tier")
async def dashboard_users_update_tier(request: Request, user_id: int):
"""Update user tier assignment"""
auth_check = require_admin(request)
if auth_check:
return auth_check
from aisbf.database import get_database
db = DatabaseRegistry.get_config_database()
try:
body = await request.json()
tier_id = body.get('tier_id')
if not tier_id:
return JSONResponse({"success": False, "error": "tier_id is required"}, status_code=400)
# Verify tier exists
tier = db.get_tier_by_id(tier_id)
if not tier:
return JSONResponse({"success": False, "error": "Tier not found"}, status_code=404)
# Update user tier
success = db.set_user_tier(user_id, tier_id)
if success:
return JSONResponse({"success": True})
else:
return JSONResponse({"success": False, "error": "Failed to update user tier"}, status_code=500)
except Exception as e:
return JSONResponse({"success": False, "error": str(e)}, status_code=500)
@app.post("/dashboard/restart")
async def dashboard_restart(request: Request):
"""Reload configuration from disk"""
......@@ -5591,7 +5688,8 @@ async def api_create_tier(request: Request):
max_autoselections=body.get('max_autoselections', -1),
max_rotation_models=body.get('max_rotation_models', -1),
max_autoselection_models=body.get('max_autoselection_models', -1),
is_active=body.get('is_active', True)
is_active=body.get('is_active', True),
is_visible=body.get('is_visible', True)
)
return JSONResponse({"success": True, "tier_id": tier_id})
......@@ -5638,6 +5736,8 @@ async def api_update_tier(request: Request, tier_id: int):
update_kwargs['max_autoselection_models'] = body['max_autoselection_models']
if 'is_active' in body:
update_kwargs['is_active'] = body['is_active']
if 'is_visible' in body:
update_kwargs['is_visible'] = body['is_visible']
success = db.update_tier(tier_id, **update_kwargs)
......@@ -5738,7 +5838,8 @@ async def dashboard_admin_tier_save(request: Request):
'max_autoselections': int(form.get('max_autoselections', -1)),
'max_rotation_models': int(form.get('max_rotation_models', -1)),
'max_autoselection_models': int(form.get('max_autoselection_models', -1)),
'is_active': form.get('is_active') == '1'
'is_active': form.get('is_active') == '1',
'is_visible': form.get('is_visible') == '1'
}
if tier_id:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment