Commit 2b07302d authored by Your Name's avatar Your Name

Providers for users ok

parent d39326be
...@@ -101,26 +101,31 @@ class ClaudeAuth: ...@@ -101,26 +101,31 @@ class ClaudeAuth:
REDIRECT_URI = DEFAULT_REDIRECT_URI REDIRECT_URI = DEFAULT_REDIRECT_URI
CLI_USER_AGENT = CLI_USER_AGENT CLI_USER_AGENT = CLI_USER_AGENT
def __init__(self, credentials_file: Optional[str] = None, redirect_uri: Optional[str] = None): def __init__(self, credentials_file: Optional[str] = None, redirect_uri: Optional[str] = None, skip_initial_load: bool = False, save_callback: Optional[callable] = None):
""" """
Initialize Claude authentication. Initialize Claude authentication.
Args: Args:
credentials_file: Path to credentials file (default: ~/.aisbf/claude_credentials.json) credentials_file: Path to credentials file (default: ~/.aisbf/claude_credentials.json)
skip_initial_load: If True, do not load credentials from file on initialization
save_callback: Optional callback to save credentials instead of writing to file
""" """
if credentials_file: if credentials_file:
self.credentials_file = Path(credentials_file).expanduser() self.credentials_file = Path(credentials_file).expanduser()
else: else:
# Store credentials in ~/.aisbf/ directory (AISBF config directory) # Store credentials in ~/.aisbf/ directory (AISBF config directory)
self.credentials_file = Path.home() / ".aisbf" / "claude_credentials.json" self.credentials_file = Path.home() / ".aisbf" / "claude_credentials.json"
# Allow overriding redirect URI for reverse proxy deployments # Allow overriding redirect URI for reverse proxy deployments
self.redirect_uri = redirect_uri if redirect_uri is not None else DEFAULT_REDIRECT_URI self.redirect_uri = redirect_uri if redirect_uri is not None else DEFAULT_REDIRECT_URI
self.tokens = self._load_credentials() self.tokens = None
self._save_callback = save_callback
if not skip_initial_load:
self.tokens = self._load_credentials()
self._oauth_state = None # Store state for OAuth flow self._oauth_state = None # Store state for OAuth flow
self._code_verifier = None # Store verifier for OAuth flow self._code_verifier = None # Store verifier for OAuth flow
# Log TLS fingerprinting capability # Log TLS fingerprinting capability
if HAS_CURL_CFFI: if HAS_CURL_CFFI:
logger.info(f"ClaudeAuth initialized with TLS fingerprinting (curl_cffi) - credentials: {self.credentials_file}") logger.info(f"ClaudeAuth initialized with TLS fingerprinting (curl_cffi) - credentials: {self.credentials_file}")
...@@ -142,7 +147,25 @@ class ClaudeAuth: ...@@ -142,7 +147,25 @@ class ClaudeAuth:
return None return None
def _save_credentials(self, data: Dict): def _save_credentials(self, data: Dict):
"""Save credentials to file with file locking to prevent race conditions.""" """
Save credentials:
- If save_callback is provided, use it (database save for user providers)
- Otherwise, save to file with file locking to prevent race conditions
"""
self.tokens = data
if self._save_callback:
# User provider: ONLY use callback, NO file fallback EVER
try:
self._save_callback({'tokens': data})
logger.info("ClaudeAuth: Saved credentials via callback")
return
except Exception as e:
logger.error(f"ClaudeAuth: Failed to save credentials to database: {e}")
# DO NOT FALLBACK TO FILE SAVE FOR REGULAR USERS
raise
# Admin/global provider ONLY: save to file
try: try:
self.tokens = data self.tokens = data
# Store id_token if received (contains account info) # Store id_token if received (contains account info)
......
...@@ -54,13 +54,15 @@ class CodexOAuth2: ...@@ -54,13 +54,15 @@ class CodexOAuth2:
Supports authentication with OpenAI's Codex OAuth2 endpoints. Supports authentication with OpenAI's Codex OAuth2 endpoints.
""" """
def __init__(self, credentials_file: Optional[str] = None, issuer: Optional[str] = None): def __init__(self, credentials_file: Optional[str] = None, issuer: Optional[str] = None, skip_initial_load: bool = False, save_callback: Optional[callable] = None):
""" """
Initialize Codex OAuth2 client. Initialize Codex OAuth2 client.
Args: Args:
credentials_file: Path to credentials JSON file (default: ~/.aisbf/codex_credentials.json) credentials_file: Path to credentials JSON file (default: ~/.aisbf/codex_credentials.json)
issuer: OAuth2 issuer URL (default: https://auth.openai.com) issuer: OAuth2 issuer URL (default: https://auth.openai.com)
skip_initial_load: If True, do not load credentials from file on initialization
save_callback: Optional callback to save credentials instead of writing to file
""" """
# Expand and resolve path immediately to absolute path # Expand and resolve path immediately to absolute path
default_path = os.path.expanduser("~/.aisbf/codex_credentials.json") default_path = os.path.expanduser("~/.aisbf/codex_credentials.json")
...@@ -71,10 +73,12 @@ class CodexOAuth2: ...@@ -71,10 +73,12 @@ class CodexOAuth2:
self.credentials_file = os.path.abspath(expanded) self.credentials_file = os.path.abspath(expanded)
else: else:
self.credentials_file = default_path self.credentials_file = default_path
self.issuer = (issuer or DEFAULT_ISSUER).rstrip("/") self.issuer = (issuer or DEFAULT_ISSUER).rstrip("/")
self.credentials = None self.credentials = None
self._load_credentials() self._save_callback = save_callback
if not skip_initial_load:
self._load_credentials()
def _load_credentials(self) -> None: def _load_credentials(self) -> None:
"""Load credentials from file if it exists.""" """Load credentials from file if it exists."""
...@@ -89,11 +93,27 @@ class CodexOAuth2: ...@@ -89,11 +93,27 @@ class CodexOAuth2:
def _save_credentials(self, credentials: Dict[str, Any]) -> None: def _save_credentials(self, credentials: Dict[str, Any]) -> None:
""" """
Save credentials to file with secure permissions. Save credentials:
- If save_callback is provided, use it (database save for user providers)
- Otherwise, save to file with secure permissions (admin/global providers)
Args: Args:
credentials: Credentials dict to save credentials: Credentials dict to save
""" """
self.credentials = credentials
if self._save_callback:
# User provider: ONLY use callback, NO file fallback EVER
try:
self._save_callback(credentials)
logger.info(f"CodexOAuth2: Saved credentials via callback")
return
except Exception as e:
logger.error(f"CodexOAuth2: Failed to save credentials to database: {e}")
# DO NOT FALLBACK TO FILE SAVE FOR REGULAR USERS
raise
# Admin/global provider ONLY: save to file
try: try:
# Path is already expanded and absolute from __init__ # Path is already expanded and absolute from __init__
resolved_path = self.credentials_file resolved_path = self.credentials_file
......
...@@ -40,18 +40,21 @@ class KiloOAuth2: ...@@ -40,18 +40,21 @@ class KiloOAuth2:
Supports authentication with Kilo Gateway at https://api.kilo.ai. Supports authentication with Kilo Gateway at https://api.kilo.ai.
""" """
def __init__(self, credentials_file: Optional[str] = None, api_base: Optional[str] = None): def __init__(self, credentials_file: Optional[str] = None, api_base: Optional[str] = None, skip_initial_load: bool = False, save_callback: Optional[callable] = None):
""" """
Initialize Kilo OAuth2 client. Initialize Kilo OAuth2 client.
Args: Args:
credentials_file: Path to credentials JSON file (default: ~/.kilo_credentials.json) credentials_file: Path to credentials JSON file (default: ~/.kilo_credentials.json)
api_base: Base URL for Kilo API (default: https://api.kilo.ai) api_base: Base URL for Kilo API (default: https://api.kilo.ai)
skip_initial_load: If True, do not load credentials from file on initialization
""" """
self.credentials_file = os.path.expanduser(credentials_file) if credentials_file else os.path.expanduser("~/.kilo_credentials.json") self.credentials_file = os.path.expanduser(credentials_file) if credentials_file else os.path.expanduser("~/.kilo_credentials.json")
self.api_base = api_base or os.environ.get("KILO_API_URL", "https://api.kilo.ai") self.api_base = api_base or os.environ.get("KILO_API_URL", "https://api.kilo.ai")
self.credentials = None self.credentials = None
self._load_credentials() self._save_callback = save_callback
if not skip_initial_load:
self._load_credentials()
def _load_credentials(self) -> None: def _load_credentials(self) -> None:
"""Load credentials from file if it exists.""" """Load credentials from file if it exists."""
...@@ -66,17 +69,33 @@ class KiloOAuth2: ...@@ -66,17 +69,33 @@ class KiloOAuth2:
def _save_credentials(self, credentials: Dict[str, Any]) -> None: def _save_credentials(self, credentials: Dict[str, Any]) -> None:
""" """
Save credentials to file with secure permissions. Save credentials:
- If save_callback is provided, use it (database save for user providers)
- Otherwise, save to file with secure permissions (admin/global providers)
Args: Args:
credentials: Credentials dict to save credentials: Credentials dict to save
""" """
self.credentials = credentials
if self._save_callback:
# User provider: ONLY use callback, NO file fallback EVER
try:
self._save_callback(credentials)
logger.info(f"KiloOAuth2: Saved credentials via callback")
return
except Exception as e:
logger.error(f"KiloOAuth2: Failed to save credentials to database: {e}")
# DO NOT FALLBACK TO FILE SAVE FOR REGULAR USERS
raise
# Admin/global provider ONLY: save to file
try: try:
# Ensure directory exists # Ensure directory exists
cred_dir = os.path.dirname(self.credentials_file) cred_dir = os.path.dirname(self.credentials_file)
if cred_dir: # Only create if there's a directory component if cred_dir: # Only create if there's a directory component
os.makedirs(cred_dir, exist_ok=True) os.makedirs(cred_dir, exist_ok=True)
# Write credentials # Write credentials
with open(self.credentials_file, 'w') as f: with open(self.credentials_file, 'w') as f:
json.dump(credentials, f, indent=2) json.dump(credentials, f, indent=2)
......
...@@ -96,15 +96,17 @@ class QwenOAuth2: ...@@ -96,15 +96,17 @@ class QwenOAuth2:
Supports authentication with Qwen's OAuth2 endpoints and automatic token refresh. Supports authentication with Qwen's OAuth2 endpoints and automatic token refresh.
""" """
def __init__(self, credentials_file: Optional[str] = None): def __init__(self, credentials_file: Optional[str] = None, skip_initial_load: bool = False, save_callback: Optional[callable] = None):
""" """
Initialize Qwen OAuth2 client. Initialize Qwen OAuth2 client.
⚠️ WARNING: OAuth2 authentication for Qwen has been discontinued. ⚠️ WARNING: OAuth2 authentication for Qwen has been discontinued.
This client will not work with DashScope API. Use API key authentication instead. This client will not work with DashScope API. Use API key authentication instead.
Args: Args:
credentials_file: Path to credentials JSON file (default: ~/.aisbf/qwen_credentials.json) credentials_file: Path to credentials JSON file (default: ~/.aisbf/qwen_credentials.json)
skip_initial_load: If True, do not load credentials from file on initialization
save_callback: Optional callback to save credentials instead of writing to file
""" """
logger.warning( logger.warning(
"⚠️ Qwen OAuth2 service has been discontinued by Qwen. " "⚠️ Qwen OAuth2 service has been discontinued by Qwen. "
...@@ -116,7 +118,9 @@ class QwenOAuth2: ...@@ -116,7 +118,9 @@ class QwenOAuth2:
self.credentials = None self.credentials = None
self._file_mod_time = 0 self._file_mod_time = 0
self._last_check = 0 self._last_check = 0
self._load_credentials() self._save_callback = save_callback
if not skip_initial_load:
self._load_credentials()
def _load_credentials(self) -> None: def _load_credentials(self) -> None:
"""Load credentials from file if it exists.""" """Load credentials from file if it exists."""
...@@ -136,11 +140,27 @@ class QwenOAuth2: ...@@ -136,11 +140,27 @@ class QwenOAuth2:
def _save_credentials(self, credentials: Dict[str, Any]) -> None: def _save_credentials(self, credentials: Dict[str, Any]) -> None:
""" """
Save credentials to file with secure permissions and file locking. Save credentials:
- If save_callback is provided, use it (database save for user providers)
- Otherwise, save to file with secure permissions and file locking (admin/global providers)
Args: Args:
credentials: Credentials dict to save credentials: Credentials dict to save
""" """
self.credentials = credentials
if self._save_callback:
# User provider: ONLY use callback, NO file fallback EVER
try:
self._save_callback(credentials)
logger.info(f"QwenOAuth2: Saved credentials via callback")
return
except Exception as e:
logger.error(f"QwenOAuth2: Failed to save credentials to database: {e}")
# DO NOT FALLBACK TO FILE SAVE FOR REGULAR USERS
raise
# Admin/global provider ONLY: save to file
try: try:
# Ensure directory exists # Ensure directory exists
os.makedirs(os.path.dirname(self.credentials_file), exist_ok=True) os.makedirs(os.path.dirname(self.credentials_file), exist_ok=True)
......
...@@ -102,12 +102,21 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -102,12 +102,21 @@ class ClaudeProviderHandler(BaseProviderHandler):
def _load_auth_from_db(self, provider_id: str, credentials_file: str): def _load_auth_from_db(self, provider_id: str, credentials_file: str):
""" """
Load OAuth2 credentials from database for non-admin users. Load OAuth2 credentials:
Falls back to file-based credentials if not found in database. - Admin users (user_id=None): ONLY load from file
- Regular users: ONLY load from database, NO file fallback
""" """
from ..auth.claude import ClaudeAuth
import logging
if self.user_id is None:
# Admin user: ONLY use file-based credentials
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Admin user, loading credentials from file: {credentials_file}")
return ClaudeAuth(credentials_file=credentials_file)
# Regular user: ONLY use database credentials, NO file fallback
try: try:
from ..database import get_database from ..database import get_database
from ..auth.claude import ClaudeAuth
db = get_database() db = get_database()
if db: if db:
db_creds = db.get_user_oauth2_credentials( db_creds = db.get_user_oauth2_credentials(
...@@ -116,22 +125,24 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -116,22 +125,24 @@ class ClaudeProviderHandler(BaseProviderHandler):
auth_type='claude_oauth2' auth_type='claude_oauth2'
) )
if db_creds and db_creds.get('credentials'): if db_creds and db_creds.get('credentials'):
# Create auth instance with database credentials # Create auth instance with skip_initial_load=True to avoid file read
auth = ClaudeAuth(credentials_file=credentials_file) # Pass save callback to save credentials back to database
# Override the loaded credentials with database credentials auth = ClaudeAuth(
credentials_file=credentials_file,
skip_initial_load=True,
save_callback=lambda creds: self._save_auth_to_db(creds)
)
# Set tokens directly from database
auth.tokens = db_creds['credentials'].get('tokens', {}) auth.tokens = db_creds['credentials'].get('tokens', {})
import logging import logging
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Loaded credentials from database for user {self.user_id}") logging.getLogger(__name__).info(f"ClaudeProviderHandler: Loaded credentials from database for user {self.user_id}")
return auth return auth
except Exception as e: except Exception as e:
import logging
logging.getLogger(__name__).warning(f"ClaudeProviderHandler: Failed to load credentials from database: {e}") logging.getLogger(__name__).warning(f"ClaudeProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials # For regular users, NO file fallback - return empty auth instance
from ..auth.claude import ClaudeAuth logging.getLogger(__name__).info(f"ClaudeProviderHandler: No database credentials found for user {self.user_id}, returning unauthenticated instance")
import logging return ClaudeAuth(credentials_file=credentials_file, skip_initial_load=True)
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Falling back to file-based credentials for user {self.user_id}")
return ClaudeAuth(credentials_file=credentials_file)
def _init_session_identifiers(self): def _init_session_identifiers(self):
"""Initialize persistent session identifiers (device_id, account_uuid, session_id).""" """Initialize persistent session identifiers (device_id, account_uuid, session_id)."""
......
...@@ -112,9 +112,22 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -112,9 +112,22 @@ class CodexProviderHandler(BaseProviderHandler):
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, issuer: str) -> CodexOAuth2: def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, issuer: str) -> CodexOAuth2:
""" """
Load OAuth2 credentials from database for non-admin users. Load OAuth2 credentials:
Falls back to file-based credentials if not found in database. - Admin users (user_id=None): ONLY load from file
- Regular users: ONLY load from database, NO file fallback
""" """
from ..auth.codex import CodexOAuth2
import logging
if self.user_id is None:
# Admin user: ONLY use file-based credentials
logging.getLogger(__name__).info(f"CodexProviderHandler: Admin user, loading credentials from file: {credentials_file}")
return CodexOAuth2(
credentials_file=credentials_file,
issuer=issuer,
)
# Regular user: ONLY use database credentials, NO file fallback
try: try:
from ..database import get_database from ..database import get_database
db = get_database() db = get_database()
...@@ -125,23 +138,28 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -125,23 +138,28 @@ class CodexProviderHandler(BaseProviderHandler):
auth_type='codex_oauth2' auth_type='codex_oauth2'
) )
if db_creds and db_creds.get('credentials'): if db_creds and db_creds.get('credentials'):
# Create OAuth2 instance with database credentials # Create OAuth2 instance with skip_initial_load=True to avoid file read
# Pass save callback to save credentials back to database
oauth2 = CodexOAuth2( oauth2 = CodexOAuth2(
credentials_file=credentials_file, credentials_file=credentials_file,
issuer=issuer, issuer=issuer,
skip_initial_load=True,
save_callback=lambda creds: self._save_oauth2_to_db(creds)
) )
# Override the loaded credentials with database credentials # Set credentials directly from database
oauth2.credentials = db_creds['credentials'] oauth2.credentials = db_creds['credentials']
logger.info(f"CodexProviderHandler: Loaded credentials from database for user {self.user_id}") logging.getLogger(__name__).info(f"CodexProviderHandler: Loaded credentials from database for user {self.user_id}")
return oauth2 return oauth2
except Exception as e: except Exception as e:
logger.warning(f"CodexProviderHandler: Failed to load credentials from database: {e}") logging.getLogger(__name__).warning(f"CodexProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials # For regular users, NO file fallback - return empty auth instance
logger.info(f"CodexProviderHandler: Falling back to file-based credentials for user {self.user_id}") logging.getLogger(__name__).info(f"CodexProviderHandler: No database credentials found for user {self.user_id}, returning unauthenticated instance")
return CodexOAuth2( return CodexOAuth2(
credentials_file=credentials_file, credentials_file=credentials_file,
issuer=issuer, issuer=issuer,
skip_initial_load=True,
save_callback=lambda creds: self._save_oauth2_to_db(creds)
) )
async def _get_valid_api_key(self) -> str: async def _get_valid_api_key(self) -> str:
......
...@@ -110,12 +110,21 @@ class KiloProviderHandler(BaseProviderHandler): ...@@ -110,12 +110,21 @@ class KiloProviderHandler(BaseProviderHandler):
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, api_base: str): def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, api_base: str):
""" """
Load OAuth2 credentials from database for non-admin users. Load OAuth2 credentials:
Falls back to file-based credentials if not found in database. - Admin users (user_id=None): ONLY load from file
- Regular users: ONLY load from database, NO file fallback
""" """
from ..auth.kilo import KiloOAuth2
import logging
if self.user_id is None:
# Admin user: ONLY use file-based credentials
logging.getLogger(__name__).info(f"KiloProviderHandler: Admin user, loading credentials from file: {credentials_file}")
return KiloOAuth2(credentials_file=credentials_file, api_base=api_base)
# Regular user: ONLY use database credentials, NO file fallback
try: try:
from ..database import get_database from ..database import get_database
from ..auth.kilo import KiloOAuth2
db = get_database() db = get_database()
if db: if db:
db_creds = db.get_user_oauth2_credentials( db_creds = db.get_user_oauth2_credentials(
...@@ -124,22 +133,29 @@ class KiloProviderHandler(BaseProviderHandler): ...@@ -124,22 +133,29 @@ class KiloProviderHandler(BaseProviderHandler):
auth_type='kilo_oauth2' auth_type='kilo_oauth2'
) )
if db_creds and db_creds.get('credentials'): if db_creds and db_creds.get('credentials'):
# Create OAuth2 instance with database credentials # Create OAuth2 instance with skip_initial_load=True to avoid file read
oauth2 = KiloOAuth2(credentials_file=credentials_file, api_base=api_base) # Pass save callback to save credentials back to database
# Override the loaded credentials with database credentials oauth2 = KiloOAuth2(
credentials_file=credentials_file,
api_base=api_base,
skip_initial_load=True,
save_callback=lambda creds: self._save_oauth2_to_db(creds)
)
# Set credentials directly from database
oauth2.credentials = db_creds['credentials'] oauth2.credentials = db_creds['credentials']
import logging
logging.getLogger(__name__).info(f"KiloProviderHandler: Loaded credentials from database for user {self.user_id}") logging.getLogger(__name__).info(f"KiloProviderHandler: Loaded credentials from database for user {self.user_id}")
return oauth2 return oauth2
except Exception as e: except Exception as e:
import logging
logging.getLogger(__name__).warning(f"KiloProviderHandler: Failed to load credentials from database: {e}") logging.getLogger(__name__).warning(f"KiloProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials # For regular users, NO file fallback - return empty auth instance
from ..auth.kilo import KiloOAuth2 logging.getLogger(__name__).info(f"KiloProviderHandler: No database credentials found for user {self.user_id}, returning unauthenticated instance")
import logging return KiloOAuth2(
logging.getLogger(__name__).info(f"KiloProviderHandler: Falling back to file-based credentials for user {self.user_id}") credentials_file=credentials_file,
return KiloOAuth2(credentials_file=credentials_file, api_base=api_base) api_base=api_base,
skip_initial_load=True,
save_callback=lambda creds: self._save_oauth2_to_db(creds)
)
def _save_oauth2_to_db(self, credentials: Dict) -> None: def _save_oauth2_to_db(self, credentials: Dict) -> None:
""" """
......
...@@ -98,12 +98,21 @@ class QwenProviderHandler(BaseProviderHandler): ...@@ -98,12 +98,21 @@ class QwenProviderHandler(BaseProviderHandler):
def _load_auth_from_db(self, provider_id: str, credentials_file: str): def _load_auth_from_db(self, provider_id: str, credentials_file: str):
""" """
Load OAuth2 credentials from database for non-admin users. Load OAuth2 credentials:
Falls back to file-based credentials if not found in database. - Admin users (user_id=None): ONLY load from file
- Regular users: ONLY load from database, NO file fallback
""" """
from ..auth.qwen import QwenOAuth2
import logging
if self.user_id is None:
# Admin user: ONLY use file-based credentials
logging.getLogger(__name__).info(f"QwenProviderHandler: Admin user, loading credentials from file: {credentials_file}")
return QwenOAuth2(credentials_file=credentials_file)
# Regular user: ONLY use database credentials, NO file fallback
try: try:
from ..database import get_database from ..database import get_database
from ..auth.qwen import QwenOAuth2
db = get_database() db = get_database()
if db: if db:
db_creds = db.get_user_oauth2_credentials( db_creds = db.get_user_oauth2_credentials(
...@@ -112,22 +121,52 @@ class QwenProviderHandler(BaseProviderHandler): ...@@ -112,22 +121,52 @@ class QwenProviderHandler(BaseProviderHandler):
auth_type='qwen_oauth2' auth_type='qwen_oauth2'
) )
if db_creds and db_creds.get('credentials'): if db_creds and db_creds.get('credentials'):
# Create auth instance with database credentials # Create auth instance with skip_initial_load=True to avoid file read
auth = QwenOAuth2(credentials_file=credentials_file) # Pass save callback to save credentials back to database
# Override the loaded credentials with database credentials auth = QwenOAuth2(
credentials_file=credentials_file,
skip_initial_load=True,
save_callback=lambda creds: self._save_auth_to_db(creds)
)
# Set credentials directly from database
auth.credentials = db_creds['credentials'] auth.credentials = db_creds['credentials']
import logging
logging.getLogger(__name__).info(f"QwenProviderHandler: Loaded credentials from database for user {self.user_id}") logging.getLogger(__name__).info(f"QwenProviderHandler: Loaded credentials from database for user {self.user_id}")
return auth return auth
except Exception as e: except Exception as e:
import logging
logging.getLogger(__name__).warning(f"QwenProviderHandler: Failed to load credentials from database: {e}") logging.getLogger(__name__).warning(f"QwenProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials # For regular users, NO file fallback - return empty auth instance
from ..auth.qwen import QwenOAuth2 logging.getLogger(__name__).info(f"QwenProviderHandler: No database credentials found for user {self.user_id}, returning unauthenticated instance")
import logging return QwenOAuth2(
logging.getLogger(__name__).info(f"QwenProviderHandler: Falling back to file-based credentials for user {self.user_id}") credentials_file=credentials_file,
return QwenOAuth2(credentials_file=credentials_file) skip_initial_load=True,
save_callback=lambda creds: self._save_auth_to_db(creds)
)
def _save_auth_to_db(self, credentials: Dict) -> None:
"""
Save OAuth2 credentials to database for non-admin users.
This is called after successful device flow authentication.
"""
if self.user_id is None:
# Admin user uses file-based credentials, nothing to save to DB
return
try:
from ..database import get_database
db = get_database()
if db:
db.save_user_oauth2_credentials(
user_id=self.user_id,
provider_id=self.provider_id,
auth_type='qwen_oauth2',
credentials=credentials
)
import logging
logging.getLogger(__name__).info(f"QwenProviderHandler: Saved credentials to database for user {self.user_id}")
except Exception as e:
import logging
logging.getLogger(__name__).warning(f"QwenProviderHandler: Failed to save credentials to database: {e}")
async def _get_sdk_client(self): async def _get_sdk_client(self):
"""Get or create an OpenAI SDK client configured with authentication (OAuth2 or API key).""" """Get or create an OpenAI SDK client configured with authentication (OAuth2 or API key)."""
......
...@@ -3786,10 +3786,8 @@ async def dashboard_providers(request: Request): ...@@ -3786,10 +3786,8 @@ async def dashboard_providers(request: Request):
db = DatabaseRegistry.get_config_database() db = DatabaseRegistry.get_config_database()
user_providers = db.get_user_providers(current_user_id) user_providers = db.get_user_providers(current_user_id)
# Convert to the format expected by the frontend # Always pass raw user providers format to the template (array)
providers_data = {} providers_data = user_providers
for provider in user_providers:
providers_data[provider['provider_id']] = provider['config']
# Check for success parameter # Check for success parameter
success = request.query_params.get('success') success = request.query_params.get('success')
...@@ -4020,9 +4018,6 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)): ...@@ -4020,9 +4018,6 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)):
) )
else: else:
success_msg = "Configuration saved successfully!" success_msg = "Configuration saved successfully!"
from aisbf.database import get_database
db = DatabaseRegistry.get_config_database()
user_providers = db.get_user_providers(current_user_id)
return templates.TemplateResponse( return templates.TemplateResponse(
request=request, request=request,
...@@ -4031,7 +4026,7 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)): ...@@ -4031,7 +4026,7 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)):
"request": request, "request": request,
"session": request.session, "session": request.session,
"__version__": __version__, "__version__": __version__,
"user_providers_json": json.dumps(user_providers), "user_providers_json": json.dumps(providers_data),
"user_id": current_user_id, "user_id": current_user_id,
"success": success_msg "success": success_msg
} }
...@@ -4101,20 +4096,19 @@ async def dashboard_providers_get_models(request: Request): ...@@ -4101,20 +4096,19 @@ async def dashboard_providers_get_models(request: Request):
"error": "provider_key is required" "error": "provider_key is required"
}, status_code=400) }, status_code=400)
# Check if provider exists in config # Get user ID from session
if not config or provider_key not in config.providers: current_user_id = request.session.get('user_id')
return JSONResponse({
"success": False,
"error": f"Provider '{provider_key}' not found in configuration"
}, status_code=404)
# Get provider handler # Get provider handler - pass user_id to automatically handle user-specific providers
from aisbf.providers import get_provider_handler from aisbf.providers import get_provider_handler
provider_config = config.providers[provider_key] try:
api_key = provider_config.api_key if hasattr(provider_config, 'api_key') else None handler = get_provider_handler(provider_key, user_id=current_user_id)
except ValueError as e:
handler = get_provider_handler(provider_key, api_key) return JSONResponse({
"success": False,
"error": str(e)
}, status_code=404)
# Fetch models from provider # Fetch models from provider
models_result = await handler.get_models() models_result = await handler.get_models()
...@@ -6131,14 +6125,28 @@ async def dashboard_provider_file_delete( ...@@ -6131,14 +6125,28 @@ async def dashboard_provider_file_delete(
@app.get("/dashboard/providers/{provider_name}/auth/check") @app.get("/dashboard/providers/{provider_name}/auth/check")
async def dashboard_provider_auth_check(request: Request, provider_name: str): async def dashboard_provider_auth_check(request: Request, provider_name: str):
"""Check OAuth authentication status for a provider""" """Check OAuth authentication status for a provider"""
auth_check = require_admin(request) auth_check = require_dashboard_auth(request)
if auth_check: if auth_check:
return auth_check return auth_check
try: try:
# Load current provider configuration # Get user ID from session
config = Config() current_user_id = request.session.get('user_id')
provider_config = config.providers.get(provider_name)
# Load provider configuration
provider_config = None
if current_user_id is None:
# Admin: check global config
global_config = Config()
provider_config = global_config.providers.get(provider_name)
else:
# Regular user: get from user providers
from aisbf.database import DatabaseRegistry
db = DatabaseRegistry.get_config_database()
user_provider = db.get_user_provider(current_user_id, provider_name)
if user_provider:
provider_config = user_provider['config']
if not provider_config: if not provider_config:
return JSONResponse( return JSONResponse(
...@@ -6146,11 +6154,19 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str): ...@@ -6146,11 +6154,19 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str):
content={"authenticated": False, "error": f"Provider '{provider_name}' not found"} content={"authenticated": False, "error": f"Provider '{provider_name}' not found"}
) )
provider_type = provider_config.type # Handle both dict (user providers) and object (global providers)
if isinstance(provider_config, dict):
provider_type = provider_config.get('type')
else:
provider_type = provider_config.type
if provider_type == 'claude': if provider_type == 'claude':
from aisbf.auth.claude import ClaudeAuth from aisbf.auth.claude import ClaudeAuth
claude_config = provider_config.claude_config or {} # Handle dict vs object
if isinstance(provider_config, dict):
claude_config = provider_config.get('claude_config', {})
else:
claude_config = provider_config.claude_config or {}
auth = ClaudeAuth(credentials_file=claude_config.get('credentials_file', '~/.claude_credentials.json')) auth = ClaudeAuth(credentials_file=claude_config.get('credentials_file', '~/.claude_credentials.json'))
is_auth = auth.is_authenticated() is_auth = auth.is_authenticated()
result = {"authenticated": is_auth} result = {"authenticated": is_auth}
...@@ -6160,7 +6176,11 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str): ...@@ -6160,7 +6176,11 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str):
elif provider_type == 'kilocode': elif provider_type == 'kilocode':
from aisbf.auth.kilo import KiloOAuth2 from aisbf.auth.kilo import KiloOAuth2
kilo_config = provider_config.kilo_config or {} # Handle dict vs object
if isinstance(provider_config, dict):
kilo_config = provider_config.get('kilo_config', {})
else:
kilo_config = provider_config.kilo_config or {}
auth = KiloOAuth2(credentials_file=kilo_config.get('credentials_file', '~/.kilo_credentials.json')) auth = KiloOAuth2(credentials_file=kilo_config.get('credentials_file', '~/.kilo_credentials.json'))
is_auth = auth.is_authenticated() is_auth = auth.is_authenticated()
result = {"authenticated": is_auth} result = {"authenticated": is_auth}
...@@ -6172,7 +6192,11 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str): ...@@ -6172,7 +6192,11 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str):
elif provider_type == 'qwen': elif provider_type == 'qwen':
from aisbf.auth.qwen import QwenOAuth2 from aisbf.auth.qwen import QwenOAuth2
qwen_config = provider_config.qwen_config or {} # Handle dict vs object
if isinstance(provider_config, dict):
qwen_config = provider_config.get('qwen_config', {})
else:
qwen_config = provider_config.qwen_config or {}
auth = QwenOAuth2(credentials_file=qwen_config.get('credentials_file', '~/.aisbf/qwen_credentials.json')) auth = QwenOAuth2(credentials_file=qwen_config.get('credentials_file', '~/.aisbf/qwen_credentials.json'))
is_auth = auth.is_authenticated() is_auth = auth.is_authenticated()
result = {"authenticated": is_auth} result = {"authenticated": is_auth}
...@@ -6185,7 +6209,11 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str): ...@@ -6185,7 +6209,11 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str):
elif provider_type == 'codex': elif provider_type == 'codex':
from aisbf.auth.codex import CodexOAuth2 from aisbf.auth.codex import CodexOAuth2
codex_config = provider_config.codex_config or {} # Handle dict vs object
if isinstance(provider_config, dict):
codex_config = provider_config.get('codex_config', {})
else:
codex_config = provider_config.codex_config or {}
auth = CodexOAuth2(credentials_file=codex_config.get('credentials_file', '~/.aisbf/codex_credentials.json')) auth = CodexOAuth2(credentials_file=codex_config.get('credentials_file', '~/.aisbf/codex_credentials.json'))
is_auth = auth.is_authenticated() is_auth = auth.is_authenticated()
result = {"authenticated": is_auth} result = {"authenticated": is_auth}
...@@ -6388,8 +6416,9 @@ async def dashboard_user_tokens_delete(request: Request, token_id: int): ...@@ -6388,8 +6416,9 @@ async def dashboard_user_tokens_delete(request: Request, token_id: int):
try: try:
db.delete_user_api_token(user_id, token_id) db.delete_user_api_token(user_id, token_id)
} return JSONResponse(content={"success": True})
) except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/dashboard/response-cache/stats") @app.get("/dashboard/response-cache/stats")
async def dashboard_response_cache_stats(request: Request): async def dashboard_response_cache_stats(request: Request):
...@@ -12047,9 +12076,7 @@ async def dashboard_claude_auth_start(request: Request): ...@@ -12047,9 +12076,7 @@ async def dashboard_claude_auth_start(request: Request):
from aisbf.auth.claude import ClaudeAuth from aisbf.auth.claude import ClaudeAuth
# Create auth instance # Create auth instance
auth = ClaudeAuth() auth = ClaudeAuth(credentials_file=credentials_file, skip_initial_load=True)
# Override credentials file if specified
auth.credentials_file = Path(credentials_file).expanduser()
# Generate PKCE challenge # Generate PKCE challenge
verifier, challenge = auth._generate_pkce() verifier, challenge = auth._generate_pkce()
...@@ -12366,7 +12393,7 @@ async def dashboard_kilo_auth_start(request: Request): ...@@ -12366,7 +12393,7 @@ async def dashboard_kilo_auth_start(request: Request):
from aisbf.auth.kilo import KiloOAuth2 from aisbf.auth.kilo import KiloOAuth2
# Create auth instance # Create auth instance
auth = KiloOAuth2(credentials_file=credentials_file) auth = KiloOAuth2(credentials_file=credentials_file, skip_initial_load=True)
# Initiate device authorization (async method) # Initiate device authorization (async method)
device_auth = await auth.initiate_device_auth() device_auth = await auth.initiate_device_auth()
...@@ -12437,7 +12464,7 @@ async def dashboard_kilo_auth_poll(request: Request): ...@@ -12437,7 +12464,7 @@ async def dashboard_kilo_auth_poll(request: Request):
from aisbf.auth.kilo import KiloOAuth2 from aisbf.auth.kilo import KiloOAuth2
# Create auth instance # Create auth instance
auth = KiloOAuth2(credentials_file=credentials_file) auth = KiloOAuth2(credentials_file=credentials_file, skip_initial_load=True)
# Poll device authorization status (async method) # Poll device authorization status (async method)
result = await auth.poll_device_auth(device_code) result = await auth.poll_device_auth(device_code)
...@@ -12645,7 +12672,7 @@ async def dashboard_kilo_auth_logout(request: Request): ...@@ -12645,7 +12672,7 @@ async def dashboard_kilo_auth_logout(request: Request):
from aisbf.auth.kilo import KiloOAuth2 from aisbf.auth.kilo import KiloOAuth2
# Create auth instance # Create auth instance
auth = KiloOAuth2(credentials_file=credentials_file) auth = KiloOAuth2(credentials_file=credentials_file, skip_initial_load=True)
# Logout (clear credentials) # Logout (clear credentials)
auth.logout() auth.logout()
...@@ -12687,8 +12714,8 @@ async def dashboard_codex_auth_start(request: Request): ...@@ -12687,8 +12714,8 @@ async def dashboard_codex_auth_start(request: Request):
from aisbf.auth.codex import CodexOAuth2 from aisbf.auth.codex import CodexOAuth2
# Create auth instance # Create auth instance
auth = CodexOAuth2(credentials_file=credentials_file, issuer=issuer) auth = CodexOAuth2(credentials_file=credentials_file, issuer=issuer, skip_initial_load=True)
# Request device code (returns immediately) # Request device code (returns immediately)
device_info = await auth.request_device_code_flow() device_info = await auth.request_device_code_flow()
...@@ -12759,7 +12786,7 @@ async def dashboard_codex_auth_poll(request: Request): ...@@ -12759,7 +12786,7 @@ async def dashboard_codex_auth_poll(request: Request):
from aisbf.auth.codex import CodexOAuth2 from aisbf.auth.codex import CodexOAuth2
# Create auth instance # Create auth instance
auth = CodexOAuth2(credentials_file=credentials_file, issuer=issuer) auth = CodexOAuth2(credentials_file=credentials_file, issuer=issuer, skip_initial_load=True)
# Set device auth info on the instance (required for poll_device_code_completion) # Set device auth info on the instance (required for poll_device_code_completion)
auth._device_auth_id = device_auth_id auth._device_auth_id = device_auth_id
...@@ -13046,10 +13073,10 @@ async def dashboard_qwen_auth_start(request: Request): ...@@ -13046,10 +13073,10 @@ async def dashboard_qwen_auth_start(request: Request):
# Import QwenOAuth2 # Import QwenOAuth2
from aisbf.auth.qwen import QwenOAuth2 from aisbf.auth.qwen import QwenOAuth2
# Create auth instance # Create auth instance
auth = QwenOAuth2(credentials_file=credentials_file) auth = QwenOAuth2(credentials_file=credentials_file, skip_initial_load=True)
logger.info(f"QwenOAuth2: Requesting device code for provider: {provider_key}") logger.info(f"QwenOAuth2: Requesting device code for provider: {provider_key}")
# Request device code # Request device code
...@@ -13125,7 +13152,7 @@ async def dashboard_qwen_auth_poll(request: Request): ...@@ -13125,7 +13152,7 @@ async def dashboard_qwen_auth_poll(request: Request):
from aisbf.auth.qwen import QwenOAuth2 from aisbf.auth.qwen import QwenOAuth2
# Create auth instance # Create auth instance
auth = QwenOAuth2(credentials_file=credentials_file) auth = QwenOAuth2(credentials_file=credentials_file, skip_initial_load=True)
# Poll for token - returns token dict if approved, None if still pending # Poll for token - returns token dict if approved, None if still pending
result = await auth.poll_device_token(device_code, code_verifier) result = await auth.poll_device_token(device_code, code_verifier)
......
...@@ -134,6 +134,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>. ...@@ -134,6 +134,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
{% block extra_js %} {% block extra_js %}
<script> <script>
let providersData = {}; let providersData = {};
let expandedProviders = new Set();
let rawProviders = {{ user_providers_json | safe }}; let rawProviders = {{ user_providers_json | safe }};
// Convert user providers format to the format expected by the UI // Convert user providers format to the format expected by the UI
...@@ -141,8 +142,6 @@ rawProviders.forEach(provider => { ...@@ -141,8 +142,6 @@ rawProviders.forEach(provider => {
providersData[provider.provider_id] = provider.config; providersData[provider.provider_id] = provider.config;
}); });
let expandedProviders = new Set();
// Chunk size: 512KB chunks for maximum compatibility with restrictive proxies // Chunk size: 512KB chunks for maximum compatibility with restrictive proxies
const CHUNK_SIZE = 512 * 1024; const CHUNK_SIZE = 512 * 1024;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment