refactor(auth): separate ClaudeAuth refresh/no-refresh methods

parent 8bc9a2df
......@@ -311,26 +311,43 @@ class ClaudeAuth:
logger.error(f"Token refresh failed after {max_retries} attempts")
return False
async def get_valid_token(self, auto_login: bool = False) -> str:
def get_valid_token(self) -> Optional[str]:
"""
Get a valid access token without attempting refresh.
This is a quick check method that returns the token if valid,
or None if expired. It does NOT attempt to refresh the token.
Returns:
Valid access token or None if expired/not authenticated
"""
if not self.tokens:
return None
# Check if token is expired (with 5 minute buffer)
if time.time() > (self.tokens.get('expires_at', 0) - 300):
return None
return self.tokens.get('access_token')
async def get_valid_token_with_refresh(self, auto_login: bool = False) -> Optional[str]:
"""
Get a valid access token, refreshing it if necessary.
Args:
auto_login: If True, automatically trigger login flow when no credentials exist.
If False, raise an exception instead (default: False for security).
If False, return None instead (default: False for security).
Returns:
Valid access token
Raises:
Exception: If no credentials exist and auto_login is False
Valid access token or None if refresh fails
"""
if not self.tokens:
if not auto_login:
logger.error("No Claude credentials available. Please authenticate via dashboard or MCP.")
raise Exception("Claude authentication required. Please authenticate via /dashboard/claude/auth/start or MCP tool.")
return None
logger.info("No tokens available, starting login flow")
self.login()
return self.tokens.get('access_token') if self.tokens else None
# Refresh if less than 5 minutes remain
if time.time() > (self.tokens.get('expires_at', 0) - 300):
......@@ -338,11 +355,12 @@ class ClaudeAuth:
if not await self.refresh_token():
if not auto_login:
logger.error("Token refresh failed and auto_login is disabled")
raise Exception("Claude token refresh failed. Please re-authenticate via /dashboard/claude/auth/start or MCP tool.")
return None
logger.warning("Refresh failed, re-authenticating...")
self.login()
return self.tokens.get('access_token') if self.tokens else None
return self.tokens['access_token']
return self.tokens.get('access_token')
def get_account_id(self) -> Optional[str]:
"""
......
import pytest
import time
import json
import tempfile
import os
from aisbf.auth.claude import ClaudeAuth
@pytest.fixture
def temp_credentials_file():
"""Create a temporary credentials file."""
fd, path = tempfile.mkstemp(suffix='.json')
os.close(fd)
yield path
if os.path.exists(path):
os.remove(path)
@pytest.fixture
def valid_tokens():
"""Valid tokens with future expiry."""
return {
"access_token": "valid_access_token",
"refresh_token": "valid_refresh_token",
"expires_in": 3600,
"expires_at": time.time() + 3600, # Expires in 1 hour
"token_type": "Bearer"
}
@pytest.fixture
def expired_tokens():
"""Expired tokens."""
return {
"access_token": "expired_access_token",
"refresh_token": "expired_refresh_token",
"expires_in": 3600,
"expires_at": time.time() - 3600, # Expired 1 hour ago
"token_type": "Bearer"
}
def test_get_valid_token_returns_token_when_valid(temp_credentials_file, valid_tokens):
"""Test that get_valid_token returns token when valid (no refresh)."""
# Save valid tokens
with open(temp_credentials_file, 'w') as f:
json.dump(valid_tokens, f)
auth = ClaudeAuth(credentials_file=temp_credentials_file)
token = auth.get_valid_token()
assert token == "valid_access_token"
def test_get_valid_token_returns_none_when_expired(temp_credentials_file, expired_tokens):
"""Test that get_valid_token returns None when expired (no refresh attempt)."""
# Save expired tokens
with open(temp_credentials_file, 'w') as f:
json.dump(expired_tokens, f)
auth = ClaudeAuth(credentials_file=temp_credentials_file)
token = auth.get_valid_token()
assert token is None
@pytest.mark.asyncio
async def test_get_valid_token_with_refresh_returns_valid_token(temp_credentials_file, valid_tokens):
"""Test that get_valid_token_with_refresh returns token when valid."""
# Save valid tokens
with open(temp_credentials_file, 'w') as f:
json.dump(valid_tokens, f)
auth = ClaudeAuth(credentials_file=temp_credentials_file)
token = await auth.get_valid_token_with_refresh()
assert token == "valid_access_token"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment