Commit 2b07302d authored by Your Name's avatar Your Name

Providers for users ok

parent d39326be
...@@ -101,12 +101,14 @@ class ClaudeAuth: ...@@ -101,12 +101,14 @@ class ClaudeAuth:
REDIRECT_URI = DEFAULT_REDIRECT_URI REDIRECT_URI = DEFAULT_REDIRECT_URI
CLI_USER_AGENT = CLI_USER_AGENT CLI_USER_AGENT = CLI_USER_AGENT
def __init__(self, credentials_file: Optional[str] = None, redirect_uri: Optional[str] = None): def __init__(self, credentials_file: Optional[str] = None, redirect_uri: Optional[str] = None, skip_initial_load: bool = False, save_callback: Optional[callable] = None):
""" """
Initialize Claude authentication. Initialize Claude authentication.
Args: Args:
credentials_file: Path to credentials file (default: ~/.aisbf/claude_credentials.json) credentials_file: Path to credentials file (default: ~/.aisbf/claude_credentials.json)
skip_initial_load: If True, do not load credentials from file on initialization
save_callback: Optional callback to save credentials instead of writing to file
""" """
if credentials_file: if credentials_file:
self.credentials_file = Path(credentials_file).expanduser() self.credentials_file = Path(credentials_file).expanduser()
...@@ -117,6 +119,9 @@ class ClaudeAuth: ...@@ -117,6 +119,9 @@ class ClaudeAuth:
# Allow overriding redirect URI for reverse proxy deployments # Allow overriding redirect URI for reverse proxy deployments
self.redirect_uri = redirect_uri if redirect_uri is not None else DEFAULT_REDIRECT_URI self.redirect_uri = redirect_uri if redirect_uri is not None else DEFAULT_REDIRECT_URI
self.tokens = None
self._save_callback = save_callback
if not skip_initial_load:
self.tokens = self._load_credentials() self.tokens = self._load_credentials()
self._oauth_state = None # Store state for OAuth flow self._oauth_state = None # Store state for OAuth flow
self._code_verifier = None # Store verifier for OAuth flow self._code_verifier = None # Store verifier for OAuth flow
...@@ -142,7 +147,25 @@ class ClaudeAuth: ...@@ -142,7 +147,25 @@ class ClaudeAuth:
return None return None
def _save_credentials(self, data: Dict): def _save_credentials(self, data: Dict):
"""Save credentials to file with file locking to prevent race conditions.""" """
Save credentials:
- If save_callback is provided, use it (database save for user providers)
- Otherwise, save to file with file locking to prevent race conditions
"""
self.tokens = data
if self._save_callback:
# User provider: ONLY use callback, NO file fallback EVER
try:
self._save_callback({'tokens': data})
logger.info("ClaudeAuth: Saved credentials via callback")
return
except Exception as e:
logger.error(f"ClaudeAuth: Failed to save credentials to database: {e}")
# DO NOT FALLBACK TO FILE SAVE FOR REGULAR USERS
raise
# Admin/global provider ONLY: save to file
try: try:
self.tokens = data self.tokens = data
# Store id_token if received (contains account info) # Store id_token if received (contains account info)
......
...@@ -54,13 +54,15 @@ class CodexOAuth2: ...@@ -54,13 +54,15 @@ class CodexOAuth2:
Supports authentication with OpenAI's Codex OAuth2 endpoints. Supports authentication with OpenAI's Codex OAuth2 endpoints.
""" """
def __init__(self, credentials_file: Optional[str] = None, issuer: Optional[str] = None): def __init__(self, credentials_file: Optional[str] = None, issuer: Optional[str] = None, skip_initial_load: bool = False, save_callback: Optional[callable] = None):
""" """
Initialize Codex OAuth2 client. Initialize Codex OAuth2 client.
Args: Args:
credentials_file: Path to credentials JSON file (default: ~/.aisbf/codex_credentials.json) credentials_file: Path to credentials JSON file (default: ~/.aisbf/codex_credentials.json)
issuer: OAuth2 issuer URL (default: https://auth.openai.com) issuer: OAuth2 issuer URL (default: https://auth.openai.com)
skip_initial_load: If True, do not load credentials from file on initialization
save_callback: Optional callback to save credentials instead of writing to file
""" """
# Expand and resolve path immediately to absolute path # Expand and resolve path immediately to absolute path
default_path = os.path.expanduser("~/.aisbf/codex_credentials.json") default_path = os.path.expanduser("~/.aisbf/codex_credentials.json")
...@@ -74,6 +76,8 @@ class CodexOAuth2: ...@@ -74,6 +76,8 @@ class CodexOAuth2:
self.issuer = (issuer or DEFAULT_ISSUER).rstrip("/") self.issuer = (issuer or DEFAULT_ISSUER).rstrip("/")
self.credentials = None self.credentials = None
self._save_callback = save_callback
if not skip_initial_load:
self._load_credentials() self._load_credentials()
def _load_credentials(self) -> None: def _load_credentials(self) -> None:
...@@ -89,11 +93,27 @@ class CodexOAuth2: ...@@ -89,11 +93,27 @@ class CodexOAuth2:
def _save_credentials(self, credentials: Dict[str, Any]) -> None: def _save_credentials(self, credentials: Dict[str, Any]) -> None:
""" """
Save credentials to file with secure permissions. Save credentials:
- If save_callback is provided, use it (database save for user providers)
- Otherwise, save to file with secure permissions (admin/global providers)
Args: Args:
credentials: Credentials dict to save credentials: Credentials dict to save
""" """
self.credentials = credentials
if self._save_callback:
# User provider: ONLY use callback, NO file fallback EVER
try:
self._save_callback(credentials)
logger.info(f"CodexOAuth2: Saved credentials via callback")
return
except Exception as e:
logger.error(f"CodexOAuth2: Failed to save credentials to database: {e}")
# DO NOT FALLBACK TO FILE SAVE FOR REGULAR USERS
raise
# Admin/global provider ONLY: save to file
try: try:
# Path is already expanded and absolute from __init__ # Path is already expanded and absolute from __init__
resolved_path = self.credentials_file resolved_path = self.credentials_file
......
...@@ -40,17 +40,20 @@ class KiloOAuth2: ...@@ -40,17 +40,20 @@ class KiloOAuth2:
Supports authentication with Kilo Gateway at https://api.kilo.ai. Supports authentication with Kilo Gateway at https://api.kilo.ai.
""" """
def __init__(self, credentials_file: Optional[str] = None, api_base: Optional[str] = None): def __init__(self, credentials_file: Optional[str] = None, api_base: Optional[str] = None, skip_initial_load: bool = False, save_callback: Optional[callable] = None):
""" """
Initialize Kilo OAuth2 client. Initialize Kilo OAuth2 client.
Args: Args:
credentials_file: Path to credentials JSON file (default: ~/.kilo_credentials.json) credentials_file: Path to credentials JSON file (default: ~/.kilo_credentials.json)
api_base: Base URL for Kilo API (default: https://api.kilo.ai) api_base: Base URL for Kilo API (default: https://api.kilo.ai)
skip_initial_load: If True, do not load credentials from file on initialization
""" """
self.credentials_file = os.path.expanduser(credentials_file) if credentials_file else os.path.expanduser("~/.kilo_credentials.json") self.credentials_file = os.path.expanduser(credentials_file) if credentials_file else os.path.expanduser("~/.kilo_credentials.json")
self.api_base = api_base or os.environ.get("KILO_API_URL", "https://api.kilo.ai") self.api_base = api_base or os.environ.get("KILO_API_URL", "https://api.kilo.ai")
self.credentials = None self.credentials = None
self._save_callback = save_callback
if not skip_initial_load:
self._load_credentials() self._load_credentials()
def _load_credentials(self) -> None: def _load_credentials(self) -> None:
...@@ -66,11 +69,27 @@ class KiloOAuth2: ...@@ -66,11 +69,27 @@ class KiloOAuth2:
def _save_credentials(self, credentials: Dict[str, Any]) -> None: def _save_credentials(self, credentials: Dict[str, Any]) -> None:
""" """
Save credentials to file with secure permissions. Save credentials:
- If save_callback is provided, use it (database save for user providers)
- Otherwise, save to file with secure permissions (admin/global providers)
Args: Args:
credentials: Credentials dict to save credentials: Credentials dict to save
""" """
self.credentials = credentials
if self._save_callback:
# User provider: ONLY use callback, NO file fallback EVER
try:
self._save_callback(credentials)
logger.info(f"KiloOAuth2: Saved credentials via callback")
return
except Exception as e:
logger.error(f"KiloOAuth2: Failed to save credentials to database: {e}")
# DO NOT FALLBACK TO FILE SAVE FOR REGULAR USERS
raise
# Admin/global provider ONLY: save to file
try: try:
# Ensure directory exists # Ensure directory exists
cred_dir = os.path.dirname(self.credentials_file) cred_dir = os.path.dirname(self.credentials_file)
......
...@@ -96,7 +96,7 @@ class QwenOAuth2: ...@@ -96,7 +96,7 @@ class QwenOAuth2:
Supports authentication with Qwen's OAuth2 endpoints and automatic token refresh. Supports authentication with Qwen's OAuth2 endpoints and automatic token refresh.
""" """
def __init__(self, credentials_file: Optional[str] = None): def __init__(self, credentials_file: Optional[str] = None, skip_initial_load: bool = False, save_callback: Optional[callable] = None):
""" """
Initialize Qwen OAuth2 client. Initialize Qwen OAuth2 client.
...@@ -105,6 +105,8 @@ class QwenOAuth2: ...@@ -105,6 +105,8 @@ class QwenOAuth2:
Args: Args:
credentials_file: Path to credentials JSON file (default: ~/.aisbf/qwen_credentials.json) credentials_file: Path to credentials JSON file (default: ~/.aisbf/qwen_credentials.json)
skip_initial_load: If True, do not load credentials from file on initialization
save_callback: Optional callback to save credentials instead of writing to file
""" """
logger.warning( logger.warning(
"⚠️ Qwen OAuth2 service has been discontinued by Qwen. " "⚠️ Qwen OAuth2 service has been discontinued by Qwen. "
...@@ -116,6 +118,8 @@ class QwenOAuth2: ...@@ -116,6 +118,8 @@ class QwenOAuth2:
self.credentials = None self.credentials = None
self._file_mod_time = 0 self._file_mod_time = 0
self._last_check = 0 self._last_check = 0
self._save_callback = save_callback
if not skip_initial_load:
self._load_credentials() self._load_credentials()
def _load_credentials(self) -> None: def _load_credentials(self) -> None:
...@@ -136,11 +140,27 @@ class QwenOAuth2: ...@@ -136,11 +140,27 @@ class QwenOAuth2:
def _save_credentials(self, credentials: Dict[str, Any]) -> None: def _save_credentials(self, credentials: Dict[str, Any]) -> None:
""" """
Save credentials to file with secure permissions and file locking. Save credentials:
- If save_callback is provided, use it (database save for user providers)
- Otherwise, save to file with secure permissions and file locking (admin/global providers)
Args: Args:
credentials: Credentials dict to save credentials: Credentials dict to save
""" """
self.credentials = credentials
if self._save_callback:
# User provider: ONLY use callback, NO file fallback EVER
try:
self._save_callback(credentials)
logger.info(f"QwenOAuth2: Saved credentials via callback")
return
except Exception as e:
logger.error(f"QwenOAuth2: Failed to save credentials to database: {e}")
# DO NOT FALLBACK TO FILE SAVE FOR REGULAR USERS
raise
# Admin/global provider ONLY: save to file
try: try:
# Ensure directory exists # Ensure directory exists
os.makedirs(os.path.dirname(self.credentials_file), exist_ok=True) os.makedirs(os.path.dirname(self.credentials_file), exist_ok=True)
......
...@@ -102,12 +102,21 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -102,12 +102,21 @@ class ClaudeProviderHandler(BaseProviderHandler):
def _load_auth_from_db(self, provider_id: str, credentials_file: str): def _load_auth_from_db(self, provider_id: str, credentials_file: str):
""" """
Load OAuth2 credentials from database for non-admin users. Load OAuth2 credentials:
Falls back to file-based credentials if not found in database. - Admin users (user_id=None): ONLY load from file
- Regular users: ONLY load from database, NO file fallback
""" """
from ..auth.claude import ClaudeAuth
import logging
if self.user_id is None:
# Admin user: ONLY use file-based credentials
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Admin user, loading credentials from file: {credentials_file}")
return ClaudeAuth(credentials_file=credentials_file)
# Regular user: ONLY use database credentials, NO file fallback
try: try:
from ..database import get_database from ..database import get_database
from ..auth.claude import ClaudeAuth
db = get_database() db = get_database()
if db: if db:
db_creds = db.get_user_oauth2_credentials( db_creds = db.get_user_oauth2_credentials(
...@@ -116,22 +125,24 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -116,22 +125,24 @@ class ClaudeProviderHandler(BaseProviderHandler):
auth_type='claude_oauth2' auth_type='claude_oauth2'
) )
if db_creds and db_creds.get('credentials'): if db_creds and db_creds.get('credentials'):
# Create auth instance with database credentials # Create auth instance with skip_initial_load=True to avoid file read
auth = ClaudeAuth(credentials_file=credentials_file) # Pass save callback to save credentials back to database
# Override the loaded credentials with database credentials auth = ClaudeAuth(
credentials_file=credentials_file,
skip_initial_load=True,
save_callback=lambda creds: self._save_auth_to_db(creds)
)
# Set tokens directly from database
auth.tokens = db_creds['credentials'].get('tokens', {}) auth.tokens = db_creds['credentials'].get('tokens', {})
import logging import logging
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Loaded credentials from database for user {self.user_id}") logging.getLogger(__name__).info(f"ClaudeProviderHandler: Loaded credentials from database for user {self.user_id}")
return auth return auth
except Exception as e: except Exception as e:
import logging
logging.getLogger(__name__).warning(f"ClaudeProviderHandler: Failed to load credentials from database: {e}") logging.getLogger(__name__).warning(f"ClaudeProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials # For regular users, NO file fallback - return empty auth instance
from ..auth.claude import ClaudeAuth logging.getLogger(__name__).info(f"ClaudeProviderHandler: No database credentials found for user {self.user_id}, returning unauthenticated instance")
import logging return ClaudeAuth(credentials_file=credentials_file, skip_initial_load=True)
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Falling back to file-based credentials for user {self.user_id}")
return ClaudeAuth(credentials_file=credentials_file)
def _init_session_identifiers(self): def _init_session_identifiers(self):
"""Initialize persistent session identifiers (device_id, account_uuid, session_id).""" """Initialize persistent session identifiers (device_id, account_uuid, session_id)."""
......
...@@ -112,9 +112,22 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -112,9 +112,22 @@ class CodexProviderHandler(BaseProviderHandler):
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, issuer: str) -> CodexOAuth2: def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, issuer: str) -> CodexOAuth2:
""" """
Load OAuth2 credentials from database for non-admin users. Load OAuth2 credentials:
Falls back to file-based credentials if not found in database. - Admin users (user_id=None): ONLY load from file
- Regular users: ONLY load from database, NO file fallback
""" """
from ..auth.codex import CodexOAuth2
import logging
if self.user_id is None:
# Admin user: ONLY use file-based credentials
logging.getLogger(__name__).info(f"CodexProviderHandler: Admin user, loading credentials from file: {credentials_file}")
return CodexOAuth2(
credentials_file=credentials_file,
issuer=issuer,
)
# Regular user: ONLY use database credentials, NO file fallback
try: try:
from ..database import get_database from ..database import get_database
db = get_database() db = get_database()
...@@ -125,23 +138,28 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -125,23 +138,28 @@ class CodexProviderHandler(BaseProviderHandler):
auth_type='codex_oauth2' auth_type='codex_oauth2'
) )
if db_creds and db_creds.get('credentials'): if db_creds and db_creds.get('credentials'):
# Create OAuth2 instance with database credentials # Create OAuth2 instance with skip_initial_load=True to avoid file read
# Pass save callback to save credentials back to database
oauth2 = CodexOAuth2( oauth2 = CodexOAuth2(
credentials_file=credentials_file, credentials_file=credentials_file,
issuer=issuer, issuer=issuer,
skip_initial_load=True,
save_callback=lambda creds: self._save_oauth2_to_db(creds)
) )
# Override the loaded credentials with database credentials # Set credentials directly from database
oauth2.credentials = db_creds['credentials'] oauth2.credentials = db_creds['credentials']
logger.info(f"CodexProviderHandler: Loaded credentials from database for user {self.user_id}") logging.getLogger(__name__).info(f"CodexProviderHandler: Loaded credentials from database for user {self.user_id}")
return oauth2 return oauth2
except Exception as e: except Exception as e:
logger.warning(f"CodexProviderHandler: Failed to load credentials from database: {e}") logging.getLogger(__name__).warning(f"CodexProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials # For regular users, NO file fallback - return empty auth instance
logger.info(f"CodexProviderHandler: Falling back to file-based credentials for user {self.user_id}") logging.getLogger(__name__).info(f"CodexProviderHandler: No database credentials found for user {self.user_id}, returning unauthenticated instance")
return CodexOAuth2( return CodexOAuth2(
credentials_file=credentials_file, credentials_file=credentials_file,
issuer=issuer, issuer=issuer,
skip_initial_load=True,
save_callback=lambda creds: self._save_oauth2_to_db(creds)
) )
async def _get_valid_api_key(self) -> str: async def _get_valid_api_key(self) -> str:
......
...@@ -110,12 +110,21 @@ class KiloProviderHandler(BaseProviderHandler): ...@@ -110,12 +110,21 @@ class KiloProviderHandler(BaseProviderHandler):
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, api_base: str): def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, api_base: str):
""" """
Load OAuth2 credentials from database for non-admin users. Load OAuth2 credentials:
Falls back to file-based credentials if not found in database. - Admin users (user_id=None): ONLY load from file
- Regular users: ONLY load from database, NO file fallback
""" """
from ..auth.kilo import KiloOAuth2
import logging
if self.user_id is None:
# Admin user: ONLY use file-based credentials
logging.getLogger(__name__).info(f"KiloProviderHandler: Admin user, loading credentials from file: {credentials_file}")
return KiloOAuth2(credentials_file=credentials_file, api_base=api_base)
# Regular user: ONLY use database credentials, NO file fallback
try: try:
from ..database import get_database from ..database import get_database
from ..auth.kilo import KiloOAuth2
db = get_database() db = get_database()
if db: if db:
db_creds = db.get_user_oauth2_credentials( db_creds = db.get_user_oauth2_credentials(
...@@ -124,22 +133,29 @@ class KiloProviderHandler(BaseProviderHandler): ...@@ -124,22 +133,29 @@ class KiloProviderHandler(BaseProviderHandler):
auth_type='kilo_oauth2' auth_type='kilo_oauth2'
) )
if db_creds and db_creds.get('credentials'): if db_creds and db_creds.get('credentials'):
# Create OAuth2 instance with database credentials # Create OAuth2 instance with skip_initial_load=True to avoid file read
oauth2 = KiloOAuth2(credentials_file=credentials_file, api_base=api_base) # Pass save callback to save credentials back to database
# Override the loaded credentials with database credentials oauth2 = KiloOAuth2(
credentials_file=credentials_file,
api_base=api_base,
skip_initial_load=True,
save_callback=lambda creds: self._save_oauth2_to_db(creds)
)
# Set credentials directly from database
oauth2.credentials = db_creds['credentials'] oauth2.credentials = db_creds['credentials']
import logging
logging.getLogger(__name__).info(f"KiloProviderHandler: Loaded credentials from database for user {self.user_id}") logging.getLogger(__name__).info(f"KiloProviderHandler: Loaded credentials from database for user {self.user_id}")
return oauth2 return oauth2
except Exception as e: except Exception as e:
import logging
logging.getLogger(__name__).warning(f"KiloProviderHandler: Failed to load credentials from database: {e}") logging.getLogger(__name__).warning(f"KiloProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials # For regular users, NO file fallback - return empty auth instance
from ..auth.kilo import KiloOAuth2 logging.getLogger(__name__).info(f"KiloProviderHandler: No database credentials found for user {self.user_id}, returning unauthenticated instance")
import logging return KiloOAuth2(
logging.getLogger(__name__).info(f"KiloProviderHandler: Falling back to file-based credentials for user {self.user_id}") credentials_file=credentials_file,
return KiloOAuth2(credentials_file=credentials_file, api_base=api_base) api_base=api_base,
skip_initial_load=True,
save_callback=lambda creds: self._save_oauth2_to_db(creds)
)
def _save_oauth2_to_db(self, credentials: Dict) -> None: def _save_oauth2_to_db(self, credentials: Dict) -> None:
""" """
......
...@@ -98,12 +98,21 @@ class QwenProviderHandler(BaseProviderHandler): ...@@ -98,12 +98,21 @@ class QwenProviderHandler(BaseProviderHandler):
def _load_auth_from_db(self, provider_id: str, credentials_file: str): def _load_auth_from_db(self, provider_id: str, credentials_file: str):
""" """
Load OAuth2 credentials from database for non-admin users. Load OAuth2 credentials:
Falls back to file-based credentials if not found in database. - Admin users (user_id=None): ONLY load from file
- Regular users: ONLY load from database, NO file fallback
""" """
from ..auth.qwen import QwenOAuth2
import logging
if self.user_id is None:
# Admin user: ONLY use file-based credentials
logging.getLogger(__name__).info(f"QwenProviderHandler: Admin user, loading credentials from file: {credentials_file}")
return QwenOAuth2(credentials_file=credentials_file)
# Regular user: ONLY use database credentials, NO file fallback
try: try:
from ..database import get_database from ..database import get_database
from ..auth.qwen import QwenOAuth2
db = get_database() db = get_database()
if db: if db:
db_creds = db.get_user_oauth2_credentials( db_creds = db.get_user_oauth2_credentials(
...@@ -112,22 +121,52 @@ class QwenProviderHandler(BaseProviderHandler): ...@@ -112,22 +121,52 @@ class QwenProviderHandler(BaseProviderHandler):
auth_type='qwen_oauth2' auth_type='qwen_oauth2'
) )
if db_creds and db_creds.get('credentials'): if db_creds and db_creds.get('credentials'):
# Create auth instance with database credentials # Create auth instance with skip_initial_load=True to avoid file read
auth = QwenOAuth2(credentials_file=credentials_file) # Pass save callback to save credentials back to database
# Override the loaded credentials with database credentials auth = QwenOAuth2(
credentials_file=credentials_file,
skip_initial_load=True,
save_callback=lambda creds: self._save_auth_to_db(creds)
)
# Set credentials directly from database
auth.credentials = db_creds['credentials'] auth.credentials = db_creds['credentials']
import logging
logging.getLogger(__name__).info(f"QwenProviderHandler: Loaded credentials from database for user {self.user_id}") logging.getLogger(__name__).info(f"QwenProviderHandler: Loaded credentials from database for user {self.user_id}")
return auth return auth
except Exception as e: except Exception as e:
import logging
logging.getLogger(__name__).warning(f"QwenProviderHandler: Failed to load credentials from database: {e}") logging.getLogger(__name__).warning(f"QwenProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials # For regular users, NO file fallback - return empty auth instance
from ..auth.qwen import QwenOAuth2 logging.getLogger(__name__).info(f"QwenProviderHandler: No database credentials found for user {self.user_id}, returning unauthenticated instance")
return QwenOAuth2(
credentials_file=credentials_file,
skip_initial_load=True,
save_callback=lambda creds: self._save_auth_to_db(creds)
)
def _save_auth_to_db(self, credentials: Dict) -> None:
"""
Save OAuth2 credentials to database for non-admin users.
This is called after successful device flow authentication.
"""
if self.user_id is None:
# Admin user uses file-based credentials, nothing to save to DB
return
try:
from ..database import get_database
db = get_database()
if db:
db.save_user_oauth2_credentials(
user_id=self.user_id,
provider_id=self.provider_id,
auth_type='qwen_oauth2',
credentials=credentials
)
import logging import logging
logging.getLogger(__name__).info(f"QwenProviderHandler: Falling back to file-based credentials for user {self.user_id}") logging.getLogger(__name__).info(f"QwenProviderHandler: Saved credentials to database for user {self.user_id}")
return QwenOAuth2(credentials_file=credentials_file) except Exception as e:
import logging
logging.getLogger(__name__).warning(f"QwenProviderHandler: Failed to save credentials to database: {e}")
async def _get_sdk_client(self): async def _get_sdk_client(self):
"""Get or create an OpenAI SDK client configured with authentication (OAuth2 or API key).""" """Get or create an OpenAI SDK client configured with authentication (OAuth2 or API key)."""
......
...@@ -3786,10 +3786,8 @@ async def dashboard_providers(request: Request): ...@@ -3786,10 +3786,8 @@ async def dashboard_providers(request: Request):
db = DatabaseRegistry.get_config_database() db = DatabaseRegistry.get_config_database()
user_providers = db.get_user_providers(current_user_id) user_providers = db.get_user_providers(current_user_id)
# Convert to the format expected by the frontend # Always pass raw user providers format to the template (array)
providers_data = {} providers_data = user_providers
for provider in user_providers:
providers_data[provider['provider_id']] = provider['config']
# Check for success parameter # Check for success parameter
success = request.query_params.get('success') success = request.query_params.get('success')
...@@ -4020,9 +4018,6 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)): ...@@ -4020,9 +4018,6 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)):
) )
else: else:
success_msg = "Configuration saved successfully!" success_msg = "Configuration saved successfully!"
from aisbf.database import get_database
db = DatabaseRegistry.get_config_database()
user_providers = db.get_user_providers(current_user_id)
return templates.TemplateResponse( return templates.TemplateResponse(
request=request, request=request,
...@@ -4031,7 +4026,7 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)): ...@@ -4031,7 +4026,7 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)):
"request": request, "request": request,
"session": request.session, "session": request.session,
"__version__": __version__, "__version__": __version__,
"user_providers_json": json.dumps(user_providers), "user_providers_json": json.dumps(providers_data),
"user_id": current_user_id, "user_id": current_user_id,
"success": success_msg "success": success_msg
} }
...@@ -4101,20 +4096,19 @@ async def dashboard_providers_get_models(request: Request): ...@@ -4101,20 +4096,19 @@ async def dashboard_providers_get_models(request: Request):
"error": "provider_key is required" "error": "provider_key is required"
}, status_code=400) }, status_code=400)
# Check if provider exists in config # Get user ID from session
if not config or provider_key not in config.providers: current_user_id = request.session.get('user_id')
return JSONResponse({
"success": False,
"error": f"Provider '{provider_key}' not found in configuration"
}, status_code=404)
# Get provider handler # Get provider handler - pass user_id to automatically handle user-specific providers
from aisbf.providers import get_provider_handler from aisbf.providers import get_provider_handler
provider_config = config.providers[provider_key] try:
api_key = provider_config.api_key if hasattr(provider_config, 'api_key') else None handler = get_provider_handler(provider_key, user_id=current_user_id)
except ValueError as e:
handler = get_provider_handler(provider_key, api_key) return JSONResponse({
"success": False,
"error": str(e)
}, status_code=404)
# Fetch models from provider # Fetch models from provider
models_result = await handler.get_models() models_result = await handler.get_models()
...@@ -6131,14 +6125,28 @@ async def dashboard_provider_file_delete( ...@@ -6131,14 +6125,28 @@ async def dashboard_provider_file_delete(
@app.get("/dashboard/providers/{provider_name}/auth/check") @app.get("/dashboard/providers/{provider_name}/auth/check")
async def dashboard_provider_auth_check(request: Request, provider_name: str): async def dashboard_provider_auth_check(request: Request, provider_name: str):
"""Check OAuth authentication status for a provider""" """Check OAuth authentication status for a provider"""
auth_check = require_admin(request) auth_check = require_dashboard_auth(request)
if auth_check: if auth_check:
return auth_check return auth_check
try: try:
# Load current provider configuration # Get user ID from session
config = Config() current_user_id = request.session.get('user_id')
provider_config = config.providers.get(provider_name)
# Load provider configuration
provider_config = None
if current_user_id is None:
# Admin: check global config
global_config = Config()
provider_config = global_config.providers.get(provider_name)
else:
# Regular user: get from user providers
from aisbf.database import DatabaseRegistry
db = DatabaseRegistry.get_config_database()
user_provider = db.get_user_provider(current_user_id, provider_name)
if user_provider:
provider_config = user_provider['config']
if not provider_config: if not provider_config:
return JSONResponse( return JSONResponse(
...@@ -6146,10 +6154,18 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str): ...@@ -6146,10 +6154,18 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str):
content={"authenticated": False, "error": f"Provider '{provider_name}' not found"} content={"authenticated": False, "error": f"Provider '{provider_name}' not found"}
) )
# Handle both dict (user providers) and object (global providers)
if isinstance(provider_config, dict):
provider_type = provider_config.get('type')
else:
provider_type = provider_config.type provider_type = provider_config.type
if provider_type == 'claude': if provider_type == 'claude':
from aisbf.auth.claude import ClaudeAuth from aisbf.auth.claude import ClaudeAuth
# Handle dict vs object
if isinstance(provider_config, dict):
claude_config = provider_config.get('claude_config', {})
else:
claude_config = provider_config.claude_config or {} claude_config = provider_config.claude_config or {}
auth = ClaudeAuth(credentials_file=claude_config.get('credentials_file', '~/.claude_credentials.json')) auth = ClaudeAuth(credentials_file=claude_config.get('credentials_file', '~/.claude_credentials.json'))
is_auth = auth.is_authenticated() is_auth = auth.is_authenticated()
...@@ -6160,6 +6176,10 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str): ...@@ -6160,6 +6176,10 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str):
elif provider_type == 'kilocode': elif provider_type == 'kilocode':
from aisbf.auth.kilo import KiloOAuth2 from aisbf.auth.kilo import KiloOAuth2
# Handle dict vs object
if isinstance(provider_config, dict):
kilo_config = provider_config.get('kilo_config', {})
else:
kilo_config = provider_config.kilo_config or {} kilo_config = provider_config.kilo_config or {}
auth = KiloOAuth2(credentials_file=kilo_config.get('credentials_file', '~/.kilo_credentials.json')) auth = KiloOAuth2(credentials_file=kilo_config.get('credentials_file', '~/.kilo_credentials.json'))
is_auth = auth.is_authenticated() is_auth = auth.is_authenticated()
...@@ -6172,6 +6192,10 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str): ...@@ -6172,6 +6192,10 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str):
elif provider_type == 'qwen': elif provider_type == 'qwen':
from aisbf.auth.qwen import QwenOAuth2 from aisbf.auth.qwen import QwenOAuth2
# Handle dict vs object
if isinstance(provider_config, dict):
qwen_config = provider_config.get('qwen_config', {})
else:
qwen_config = provider_config.qwen_config or {} qwen_config = provider_config.qwen_config or {}
auth = QwenOAuth2(credentials_file=qwen_config.get('credentials_file', '~/.aisbf/qwen_credentials.json')) auth = QwenOAuth2(credentials_file=qwen_config.get('credentials_file', '~/.aisbf/qwen_credentials.json'))
is_auth = auth.is_authenticated() is_auth = auth.is_authenticated()
...@@ -6185,6 +6209,10 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str): ...@@ -6185,6 +6209,10 @@ async def dashboard_provider_auth_check(request: Request, provider_name: str):
elif provider_type == 'codex': elif provider_type == 'codex':
from aisbf.auth.codex import CodexOAuth2 from aisbf.auth.codex import CodexOAuth2
# Handle dict vs object
if isinstance(provider_config, dict):
codex_config = provider_config.get('codex_config', {})
else:
codex_config = provider_config.codex_config or {} codex_config = provider_config.codex_config or {}
auth = CodexOAuth2(credentials_file=codex_config.get('credentials_file', '~/.aisbf/codex_credentials.json')) auth = CodexOAuth2(credentials_file=codex_config.get('credentials_file', '~/.aisbf/codex_credentials.json'))
is_auth = auth.is_authenticated() is_auth = auth.is_authenticated()
...@@ -6388,8 +6416,9 @@ async def dashboard_user_tokens_delete(request: Request, token_id: int): ...@@ -6388,8 +6416,9 @@ async def dashboard_user_tokens_delete(request: Request, token_id: int):
try: try:
db.delete_user_api_token(user_id, token_id) db.delete_user_api_token(user_id, token_id)
} return JSONResponse(content={"success": True})
) except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/dashboard/response-cache/stats") @app.get("/dashboard/response-cache/stats")
async def dashboard_response_cache_stats(request: Request): async def dashboard_response_cache_stats(request: Request):
...@@ -12047,9 +12076,7 @@ async def dashboard_claude_auth_start(request: Request): ...@@ -12047,9 +12076,7 @@ async def dashboard_claude_auth_start(request: Request):
from aisbf.auth.claude import ClaudeAuth from aisbf.auth.claude import ClaudeAuth
# Create auth instance # Create auth instance
auth = ClaudeAuth() auth = ClaudeAuth(credentials_file=credentials_file, skip_initial_load=True)
# Override credentials file if specified
auth.credentials_file = Path(credentials_file).expanduser()
# Generate PKCE challenge # Generate PKCE challenge
verifier, challenge = auth._generate_pkce() verifier, challenge = auth._generate_pkce()
...@@ -12366,7 +12393,7 @@ async def dashboard_kilo_auth_start(request: Request): ...@@ -12366,7 +12393,7 @@ async def dashboard_kilo_auth_start(request: Request):
from aisbf.auth.kilo import KiloOAuth2 from aisbf.auth.kilo import KiloOAuth2
# Create auth instance # Create auth instance
auth = KiloOAuth2(credentials_file=credentials_file) auth = KiloOAuth2(credentials_file=credentials_file, skip_initial_load=True)
# Initiate device authorization (async method) # Initiate device authorization (async method)
device_auth = await auth.initiate_device_auth() device_auth = await auth.initiate_device_auth()
...@@ -12437,7 +12464,7 @@ async def dashboard_kilo_auth_poll(request: Request): ...@@ -12437,7 +12464,7 @@ async def dashboard_kilo_auth_poll(request: Request):
from aisbf.auth.kilo import KiloOAuth2 from aisbf.auth.kilo import KiloOAuth2
# Create auth instance # Create auth instance
auth = KiloOAuth2(credentials_file=credentials_file) auth = KiloOAuth2(credentials_file=credentials_file, skip_initial_load=True)
# Poll device authorization status (async method) # Poll device authorization status (async method)
result = await auth.poll_device_auth(device_code) result = await auth.poll_device_auth(device_code)
...@@ -12645,7 +12672,7 @@ async def dashboard_kilo_auth_logout(request: Request): ...@@ -12645,7 +12672,7 @@ async def dashboard_kilo_auth_logout(request: Request):
from aisbf.auth.kilo import KiloOAuth2 from aisbf.auth.kilo import KiloOAuth2
# Create auth instance # Create auth instance
auth = KiloOAuth2(credentials_file=credentials_file) auth = KiloOAuth2(credentials_file=credentials_file, skip_initial_load=True)
# Logout (clear credentials) # Logout (clear credentials)
auth.logout() auth.logout()
...@@ -12687,7 +12714,7 @@ async def dashboard_codex_auth_start(request: Request): ...@@ -12687,7 +12714,7 @@ async def dashboard_codex_auth_start(request: Request):
from aisbf.auth.codex import CodexOAuth2 from aisbf.auth.codex import CodexOAuth2
# Create auth instance # Create auth instance
auth = CodexOAuth2(credentials_file=credentials_file, issuer=issuer) auth = CodexOAuth2(credentials_file=credentials_file, issuer=issuer, skip_initial_load=True)
# Request device code (returns immediately) # Request device code (returns immediately)
device_info = await auth.request_device_code_flow() device_info = await auth.request_device_code_flow()
...@@ -12759,7 +12786,7 @@ async def dashboard_codex_auth_poll(request: Request): ...@@ -12759,7 +12786,7 @@ async def dashboard_codex_auth_poll(request: Request):
from aisbf.auth.codex import CodexOAuth2 from aisbf.auth.codex import CodexOAuth2
# Create auth instance # Create auth instance
auth = CodexOAuth2(credentials_file=credentials_file, issuer=issuer) auth = CodexOAuth2(credentials_file=credentials_file, issuer=issuer, skip_initial_load=True)
# Set device auth info on the instance (required for poll_device_code_completion) # Set device auth info on the instance (required for poll_device_code_completion)
auth._device_auth_id = device_auth_id auth._device_auth_id = device_auth_id
...@@ -13048,7 +13075,7 @@ async def dashboard_qwen_auth_start(request: Request): ...@@ -13048,7 +13075,7 @@ async def dashboard_qwen_auth_start(request: Request):
from aisbf.auth.qwen import QwenOAuth2 from aisbf.auth.qwen import QwenOAuth2
# Create auth instance # Create auth instance
auth = QwenOAuth2(credentials_file=credentials_file) auth = QwenOAuth2(credentials_file=credentials_file, skip_initial_load=True)
logger.info(f"QwenOAuth2: Requesting device code for provider: {provider_key}") logger.info(f"QwenOAuth2: Requesting device code for provider: {provider_key}")
...@@ -13125,7 +13152,7 @@ async def dashboard_qwen_auth_poll(request: Request): ...@@ -13125,7 +13152,7 @@ async def dashboard_qwen_auth_poll(request: Request):
from aisbf.auth.qwen import QwenOAuth2 from aisbf.auth.qwen import QwenOAuth2
# Create auth instance # Create auth instance
auth = QwenOAuth2(credentials_file=credentials_file) auth = QwenOAuth2(credentials_file=credentials_file, skip_initial_load=True)
# Poll for token - returns token dict if approved, None if still pending # Poll for token - returns token dict if approved, None if still pending
result = await auth.poll_device_token(device_code, code_verifier) result = await auth.poll_device_token(device_code, code_verifier)
......
...@@ -134,6 +134,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>. ...@@ -134,6 +134,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
{% block extra_js %} {% block extra_js %}
<script> <script>
let providersData = {}; let providersData = {};
let expandedProviders = new Set();
let rawProviders = {{ user_providers_json | safe }}; let rawProviders = {{ user_providers_json | safe }};
// Convert user providers format to the format expected by the UI // Convert user providers format to the format expected by the UI
...@@ -141,8 +142,6 @@ rawProviders.forEach(provider => { ...@@ -141,8 +142,6 @@ rawProviders.forEach(provider => {
providersData[provider.provider_id] = provider.config; providersData[provider.provider_id] = provider.config;
}); });
let expandedProviders = new Set();
// Chunk size: 512KB chunks for maximum compatibility with restrictive proxies // Chunk size: 512KB chunks for maximum compatibility with restrictive proxies
const CHUNK_SIZE = 512 * 1024; const CHUNK_SIZE = 512 * 1024;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment