Commit 616a3a42 authored by Your Name's avatar Your Name

Fix cross-domain OAuth2 callback handling - use state as unique key

parent 2765eea2
......@@ -12009,7 +12009,15 @@ def _start_localhost_callback_server():
logger.info(f"Localhost callback server received - Code: {code[:10] if code else 'None'}...")
# Store the callback data
# Store the callback data using state as key AND also store latest for backward compatibility
if state:
_pending_oauth2_callbacks[state] = {
'code': code,
'state': state,
'error': error,
'timestamp': time.time()
}
# Also store latest for backward compatibility
_pending_oauth2_callbacks['latest'] = {
'code': code,
'state': state,
......@@ -12132,15 +12140,27 @@ async def dashboard_oauth2_callback(
status_code=400
)
# Store the code in session for the auth completion
request.session['oauth2_code'] = code
request.session['oauth2_state'] = state
# Store callback in global storage using STATE as unique key
# This works across domains even when session cookie is not available
_pending_oauth2_callbacks[state] = {
'code': code,
'state': state,
'error': error,
'timestamp': time.time()
}
# Also try to store in session if cookie is available (same domain)
try:
request.session['oauth2_code'] = code
request.session['oauth2_state'] = state
except:
pass
# Detect if this is a direct localhost callback (no extension involved)
referer = request.headers.get('referer', '')
is_direct_callback = 'localhost:54545' in referer or '127.0.0.1:54545' in referer
logger.info(f"OAuth2 callback received - Direct: {is_direct_callback}, Code: {code[:10]}...")
logger.info(f"OAuth2 callback received - Direct: {is_direct_callback}, State: {state[:10]}..., Code: {code[:10]}...")
# Return success page with auto-close script
return HTMLResponse(
......@@ -12306,25 +12326,26 @@ async def dashboard_claude_auth_complete(request: Request):
return auth_check
try:
# Get code from session (stored by callback endpoint) or from localhost callback server
code = request.session.get('oauth2_code')
verifier = request.session.get('oauth2_verifier')
# Get state from user's session (always available since user was the one who started auth)
state = request.session.get('oauth2_state')
verifier = request.session.get('oauth2_verifier')
credentials_file = request.session.get('oauth2_credentials_file', '~/.claude_credentials.json')
# Check for callback data from localhost server if not in session
if not code and 'latest' in _pending_oauth2_callbacks:
callback_data = _pending_oauth2_callbacks['latest']
# Get code from session OR from global storage using state
code = request.session.get('oauth2_code')
# Check global storage for THIS user's state
if not code and state and state in _pending_oauth2_callbacks:
callback_data = _pending_oauth2_callbacks[state]
# Only use if received within the last 5 minutes
if time.time() - callback_data.get('timestamp', 0) < 300:
code = callback_data.get('code')
state = callback_data.get('state') or state # Use callback state if available
if callback_data.get('error'):
return JSONResponse(
status_code=400,
content={"success": False, "error": f"OAuth2 error: {callback_data['error']}"}
)
logger.info(f"Using code from localhost callback server: {code[:10] if code else 'None'}...")
logger.info(f"Using code from global callback storage for state {state[:10]}...: {code[:10] if code else 'None'}...")
if not code or not verifier:
return JSONResponse(
......@@ -12385,8 +12406,9 @@ async def dashboard_claude_auth_complete(request: Request):
request.session.pop('oauth2_provider', None)
request.session.pop('oauth2_credentials_file', None)
# Clear pending callback data
_pending_oauth2_callbacks.pop('latest', None)
# Clear pending callback data for THIS user's state
if state:
_pending_oauth2_callbacks.pop(state, None)
return JSONResponse({
"success": True,
......@@ -12414,9 +12436,12 @@ async def dashboard_claude_auth_callback_status(request: Request):
if auth_check:
return auth_check
# Check if we have callback data from the localhost server
if 'latest' in _pending_oauth2_callbacks:
callback_data = _pending_oauth2_callbacks['latest']
# Get expected state from user's session (this is what we generated when auth started)
expected_state = request.session.get('oauth2_state')
# Check if we have callback data matching THIS user's state
if expected_state and expected_state in _pending_oauth2_callbacks:
callback_data = _pending_oauth2_callbacks[expected_state]
# Only valid if received within the last 5 minutes
if time.time() - callback_data.get('timestamp', 0) < 300:
if callback_data.get('error'):
......@@ -12430,13 +12455,20 @@ async def dashboard_claude_auth_callback_status(request: Request):
"has_code": True
})
# Also check session (for extension flow)
# Also check session (for same domain flow)
if request.session.get('oauth2_code'):
return JSONResponse({
"received": True,
"has_code": True
})
# Garbage collect stale entries older than 10 minutes
now = time.time()
stale_states = [k for k, v in _pending_oauth2_callbacks.items()
if k != 'latest' and now - v.get('timestamp', 0) > 600]
for stale in stale_states:
_pending_oauth2_callbacks.pop(stale, None)
return JSONResponse({
"received": False
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment