Commit faf1ac4b authored by Your Name's avatar Your Name

v0.9.9: User-based configuration routing for providers, rotations,...

v0.9.9: User-based configuration routing for providers, rotations, autoselects, and OAuth2 credentials

- Config admin (from aisbf.json, user_id=None) saves configurations to JSON files
- Database users save configurations to the database (user_providers, user_rotations, user_autoselects tables)
- Dashboard endpoints check user type and route accordingly
- File upload endpoint supports both config admin (files) and database users (database)
- MCP server tools accept user_id parameter and route to appropriate storage
- OAuth2 credential handling already implemented this pattern (Claude, Kilo, Codex)
- Updated CHANGELOG.md, setup.py, and pyproject.toml
parent 72d1dd85
Requirement already satisfied: curl_cffi in /home/nextime/aisbf/venv/lib/python3.13/site-packages (0.14.0)
Requirement already satisfied: cffi>=1.12.0 in /home/nextime/aisbf/venv/lib/python3.13/site-packages (from curl_cffi) (2.0.0)
Requirement already satisfied: certifi>=2024.2.2 in /home/nextime/aisbf/venv/lib/python3.13/site-packages (from curl_cffi) (2026.2.25)
Requirement already satisfied: pycparser in /home/nextime/aisbf/venv/lib/python3.13/site-packages (from cffi>=1.12.0->curl_cffi) (3.0)
......@@ -7,6 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.9.9] - 2026-04-04
### Added
- **User-Based Configuration Routing**: All provider, rotation, and autoselect configurations are now saved and retrieved based on user type
- Config admin (defined in aisbf.json, `user_id=None`) saves configurations to JSON files (`~/.aisbf/providers.json`, `~/.aisbf/rotations.json`, `~/.aisbf/autoselect.json`)
- Database users (any user from the database, including admin role users) save configurations to the database via `user_providers`, `user_rotations`, and `user_autoselects` tables
- Dashboard endpoints (`/dashboard/providers`, `/dashboard/rotations`, `/dashboard/autoselect`) now check user type and route accordingly
- File upload endpoint (`/dashboard/providers/{provider_name}/upload`) supports both config admin (saves to files) and database users (saves to database with metadata)
- MCP server tools (`set_provider_config`, `set_rotation_config`, `set_autoselect_config`) now accept `user_id` parameter and route to appropriate storage
- OAuth2 credential handling already implemented this pattern (Claude, Kilo, Codex providers)
### Changed
- **Version Bump**: Updated version to 0.9.9 in setup.py and pyproject.toml
## [0.9.8] - 2026-04-04
### Added
......
......@@ -252,17 +252,25 @@ class CodexOAuth2:
"""
url = f"{self.issuer}/api/accounts/deviceauth/token"
# Include client_id to properly identify the application
payload = {
"client_id": CLIENT_ID,
"device_auth_id": device_auth_id,
"user_code": user_code,
}
logger.debug(f"CodexOAuth2: Polling token endpoint - URL: {url}, Payload: {payload}")
async with httpx.AsyncClient() as client:
response = await client.post(
url,
headers={"Content-Type": "application/json"},
json={
"device_auth_id": device_auth_id,
"user_code": user_code,
},
json=payload,
timeout=30.0
)
logger.debug(f"CodexOAuth2: Poll response - Status: {response.status_code}, Body: {response.text[:500]}")
if response.status_code == 200:
return response.json()
......@@ -311,6 +319,8 @@ class CodexOAuth2:
if not hasattr(self, '_device_auth_id') or not self._device_auth_id:
return {"status": "error", "error": "No device authorization in progress"}
logger.debug(f"CodexOAuth2: Polling for completion - device_auth_id: {self._device_auth_id}, user_code: {self._device_user_code}")
try:
token_resp = await self.poll_device_code_token(
device_auth_id=self._device_auth_id,
......@@ -318,6 +328,8 @@ class CodexOAuth2:
interval=1, # We control polling interval from outside
)
logger.info(f"CodexOAuth2: Token response received - keys: {list(token_resp.keys())}")
# Step 3: Exchange for tokens
redirect_uri = f"{self.issuer}/deviceauth/callback"
tokens = await self.exchange_code_for_tokens(
......@@ -326,10 +338,13 @@ class CodexOAuth2:
code_verifier=token_resp["code_verifier"],
)
logger.info(f"CodexOAuth2: Tokens exchanged successfully")
# Step 4: Optionally obtain API key
api_key = None
try:
api_key = await self.obtain_api_key(tokens["id_token"])
logger.info(f"CodexOAuth2: API key obtained")
except Exception as e:
logger.warning(f"CodexOAuth2: Failed to obtain API key: {e}")
......@@ -356,6 +371,7 @@ class CodexOAuth2:
except Exception as e:
error_msg = str(e)
logger.debug(f"CodexOAuth2: Poll exception - {type(e).__name__}: {error_msg}")
# 403/404 means still pending
if "403" in error_msg or "404" in error_msg or "pending" in error_msg.lower():
return {"status": "pending"}
......
......@@ -327,6 +327,29 @@ class DatabaseManager:
except:
pass # Index might already exist
# Create user_oauth2_credentials table for storing OAuth2 tokens per user/provider
cursor.execute(f'''
CREATE TABLE IF NOT EXISTS user_oauth2_credentials (
id INTEGER PRIMARY KEY {auto_increment},
user_id INTEGER NOT NULL,
provider_id VARCHAR(255) NOT NULL,
auth_type VARCHAR(50) NOT NULL,
credentials TEXT NOT NULL,
created_at TIMESTAMP DEFAULT {timestamp_default},
updated_at TIMESTAMP DEFAULT {timestamp_default},
FOREIGN KEY (user_id) REFERENCES users(id),
UNIQUE(user_id, provider_id, auth_type)
)
''')
try:
cursor.execute('''
CREATE INDEX idx_user_oauth2_user_provider
ON user_oauth2_credentials(user_id, provider_id)
''')
except:
pass
conn.commit()
logger.info("User auth files table initialized")
......@@ -1497,6 +1520,148 @@ class DatabaseManager:
conn.commit()
return cursor.rowcount
# User OAuth2 credential methods
def save_user_oauth2_credentials(self, user_id: int, provider_id: str, auth_type: str, credentials: Dict) -> int:
"""
Save OAuth2 credentials for a user/provider combination.
Args:
user_id: User ID
provider_id: Provider identifier (e.g., 'codex', 'kilo', 'claude')
auth_type: Auth type (e.g., 'codex_oauth2', 'kilo_oauth2', 'claude_oauth2')
credentials: Credentials dictionary
Returns:
Record ID
"""
with self._get_connection() as conn:
cursor = conn.cursor()
credentials_json = json.dumps(credentials)
placeholder = '?' if self.db_type == 'sqlite' else '%s'
if self.db_type == 'sqlite':
cursor.execute(f'''
INSERT OR REPLACE INTO user_oauth2_credentials
(user_id, provider_id, auth_type, credentials, updated_at)
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, CURRENT_TIMESTAMP)
''', (user_id, provider_id, auth_type, credentials_json))
else: # mysql
cursor.execute(f'''
INSERT INTO user_oauth2_credentials
(user_id, provider_id, auth_type, credentials, updated_at)
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE
credentials=VALUES(credentials), updated_at=CURRENT_TIMESTAMP
''', (user_id, provider_id, auth_type, credentials_json))
conn.commit()
return cursor.lastrowid
def get_user_oauth2_credentials(self, user_id: int, provider_id: str, auth_type: str = None) -> Optional[Dict]:
"""
Get OAuth2 credentials for a user/provider combination.
Args:
user_id: User ID
provider_id: Provider identifier
auth_type: Optional auth type filter
Returns:
Credentials dictionary or None
"""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
if auth_type:
cursor.execute(f'''
SELECT id, auth_type, credentials, created_at, updated_at
FROM user_oauth2_credentials
WHERE user_id = {placeholder} AND provider_id = {placeholder} AND auth_type = {placeholder}
''', (user_id, provider_id, auth_type))
else:
cursor.execute(f'''
SELECT id, auth_type, credentials, created_at, updated_at
FROM user_oauth2_credentials
WHERE user_id = {placeholder} AND provider_id = {placeholder}
ORDER BY updated_at DESC
LIMIT 1
''', (user_id, provider_id))
row = cursor.fetchone()
if row:
return {
'id': row[0],
'auth_type': row[1],
'credentials': json.loads(row[2]),
'created_at': row[3],
'updated_at': row[4]
}
return None
def get_all_user_oauth2_credentials(self, user_id: int) -> List[Dict]:
"""
Get all OAuth2 credentials for a user.
Args:
user_id: User ID
Returns:
List of credential dictionaries
"""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
cursor.execute(f'''
SELECT id, provider_id, auth_type, credentials, created_at, updated_at
FROM user_oauth2_credentials
WHERE user_id = {placeholder}
ORDER BY provider_id, auth_type
''', (user_id,))
credentials = []
for row in cursor.fetchall():
credentials.append({
'id': row[0],
'provider_id': row[1],
'auth_type': row[2],
'credentials': json.loads(row[3]),
'created_at': row[4],
'updated_at': row[5]
})
return credentials
def delete_user_oauth2_credentials(self, user_id: int, provider_id: str, auth_type: str = None) -> int:
"""
Delete OAuth2 credentials for a user/provider combination.
Args:
user_id: User ID
provider_id: Provider identifier
auth_type: Optional auth type filter (if None, deletes all for provider)
Returns:
Number of records deleted
"""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
if auth_type:
cursor.execute(f'''
DELETE FROM user_oauth2_credentials
WHERE user_id = {placeholder} AND provider_id = {placeholder} AND auth_type = {placeholder}
''', (user_id, provider_id, auth_type))
else:
cursor.execute(f'''
DELETE FROM user_oauth2_credentials
WHERE user_id = {placeholder} AND provider_id = {placeholder}
''', (user_id, provider_id))
conn.commit()
return cursor.rowcount
# Global database manager instance
_db_manager: Optional[DatabaseManager] = None
......
......@@ -351,7 +351,7 @@ class RequestHandler:
logger.info("No API key required for this provider")
logger.info(f"Getting provider handler for {provider_id}")
handler = get_provider_handler(provider_id, api_key)
handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
logger.info(f"Provider handler obtained: {handler.__class__.__name__}")
if handler.is_rate_limited():
......@@ -518,7 +518,7 @@ class RequestHandler:
else:
api_key = None
handler = get_provider_handler(provider_id, api_key)
handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
......@@ -1117,7 +1117,7 @@ class RequestHandler:
else:
api_key = None
handler = get_provider_handler(provider_id, api_key)
handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
try:
# Apply rate limiting
await handler.apply_rate_limit()
......@@ -1398,7 +1398,7 @@ class RequestHandler:
else:
api_key = None
handler = get_provider_handler(provider_id, api_key)
handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
......@@ -1427,7 +1427,7 @@ class RequestHandler:
else:
api_key = None
handler = get_provider_handler(provider_id, api_key)
handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
......@@ -1456,7 +1456,7 @@ class RequestHandler:
else:
api_key = None
handler = get_provider_handler(provider_id, api_key)
handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
......@@ -1489,7 +1489,7 @@ class RequestHandler:
else:
api_key = None
handler = get_provider_handler(provider_id, api_key)
handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
......@@ -2025,7 +2025,7 @@ class RotationHandler:
api_key = self._get_api_key(provider_id, provider.get('api_key'))
# Check if provider is rate limited/deactivated
provider_handler = get_provider_handler(provider_id, api_key)
provider_handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if provider_handler.is_rate_limited():
logger.warning(f" [SKIPPED] Provider {provider_id} is rate limited/deactivated")
logger.warning(f" Reason: Provider has exceeded failure threshold or is in cooldown period")
......@@ -2367,7 +2367,7 @@ class RotationHandler:
model_name = current_model['name']
logger.info(f"Getting provider handler for {provider_id}")
handler = get_provider_handler(provider_id, api_key)
handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
logger.info(f"Provider handler obtained: {handler.__class__.__name__}")
if handler.is_rate_limited():
......@@ -2800,9 +2800,13 @@ class RotationHandler:
})
}
yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
# Send [DONE] marker
yield b"data: [DONE]\n\n"
# Final flush to ensure all buffered data reaches the client
await asyncio.sleep(0)
return StreamingResponse(error_stream_generator(), media_type="text/event-stream", status_code=status_code)
......@@ -3106,6 +3110,8 @@ class RotationHandler:
}]
}
yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8')
# Yield control to event loop to ensure final chunk is flushed to client
await asyncio.sleep(0)
elif is_kilo_provider:
# Handle Kilo/KiloCode streaming response
# Kilo returns an async generator that yields OpenAI-compatible SSE bytes
......
......@@ -676,13 +676,13 @@ class MCPServer:
'get_rotation_settings': self._get_rotation_settings,
})
# Add fullconfig-level tools
# Add fullconfig-level tools (now support user_id for routing)
if auth_level >= MCPAuthLevel.FULLCONFIG:
handlers.update({
'get_providers_config': self._get_providers_config,
'set_autoselect_config': self._set_autoselect_config,
'set_rotation_config': self._set_rotation_config,
'set_provider_config': self._set_provider_config,
'set_autoselect_config': lambda args: self._set_autoselect_config(args, user_id),
'set_rotation_config': lambda args: self._set_rotation_config(args, user_id),
'set_provider_config': lambda args: self._set_provider_config(args, user_id),
'get_server_config': self._get_server_config,
'set_server_config': self._set_server_config,
'get_tor_status': self._get_tor_status,
......@@ -719,7 +719,7 @@ class MCPServer:
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
handler = handlers[tool_name]
return await handler(arguments, user_id)
return await handler(arguments)
async def _list_models(self, args: Dict) -> Dict:
"""List all available models"""
......@@ -994,15 +994,19 @@ class MCPServer:
else:
return {"providers": providers_data}
async def _set_autoselect_config(self, args: Dict) -> Dict:
"""Set autoselect configuration"""
async def _set_autoselect_config(self, args: Dict, user_id: Optional[int] = None) -> Dict:
"""Set autoselect configuration. Config admin saves to files, other users save to database."""
autoselect_id = args.get('autoselect_id')
autoselect_data = args.get('autoselect_data')
if not autoselect_id or not autoselect_data:
raise HTTPException(status_code=400, detail="autoselect_id and autoselect_data are required")
# Load existing config
# Check if user is config admin (user_id is None means config admin)
is_config_admin = user_id is None
if is_config_admin:
# Config admin: save to JSON files
config_path = Path.home() / '.aisbf' / 'autoselect.json'
if not config_path.exists():
config_path = Path(__file__).parent.parent / 'config' / 'autoselect.json'
......@@ -1021,22 +1025,31 @@ class MCPServer:
with open(save_path, 'w') as f:
json.dump(full_config, f, indent=2)
return {"status": "success", "message": f"Autoselect '{autoselect_id}' saved. Restart server for changes to take effect."}
return {"status": "success", "message": f"Autoselect '{autoselect_id}' saved to file. Restart server for changes to take effect."}
else:
# Database user: save to database
from .database import get_database
db = get_database()
db.save_user_autoselect(user_id, autoselect_id, autoselect_data)
return {"status": "success", "message": f"Autoselect '{autoselect_id}' saved to database for user {user_id}."}
async def _set_rotation_config(self, args: Dict) -> Dict:
"""Set rotation configuration"""
async def _set_rotation_config(self, args: Dict, user_id: Optional[int] = None) -> Dict:
"""Set rotation configuration. Config admin saves to files, other users save to database."""
rotation_id = args.get('rotation_id')
rotation_data = args.get('rotation_data')
if not rotation_id or not rotation_data:
raise HTTPException(status_code=400, detail="rotation_id and rotation_data are required")
# Load existing config
# Check if user is config admin (user_id is None means config admin)
is_config_admin = user_id is None
if is_config_admin:
# Config admin: save to JSON files
config_path = Path.home() / '.aisbf' / 'rotations.json'
if not config_path.exists():
with open(config_path) as f:
full_config = json.load(f)
else:
config_path = Path(__file__).parent.parent / 'config' / 'rotations.json'
with open(config_path) as f:
full_config = json.load(f)
......@@ -1051,17 +1064,27 @@ class MCPServer:
with open(save_path, 'w') as f:
json.dump(full_config, f, indent=2)
return {"status": "success", "message": f"Rotation '{rotation_id}' saved. Restart server for changes to take effect."}
return {"status": "success", "message": f"Rotation '{rotation_id}' saved to file. Restart server for changes to take effect."}
else:
# Database user: save to database
from .database import get_database
db = get_database()
db.save_user_rotation(user_id, rotation_id, rotation_data)
return {"status": "success", "message": f"Rotation '{rotation_id}' saved to database for user {user_id}."}
async def _set_provider_config(self, args: Dict) -> Dict:
"""Set provider configuration"""
async def _set_provider_config(self, args: Dict, user_id: Optional[int] = None) -> Dict:
"""Set provider configuration. Config admin saves to files, other users save to database."""
provider_id = args.get('provider_id')
provider_data = args.get('provider_data')
if not provider_id or not provider_data:
raise HTTPException(status_code=400, detail="provider_id and provider_data are required")
# Load existing config
# Check if user is config admin (user_id is None means config admin)
is_config_admin = user_id is None
if is_config_admin:
# Config admin: save to JSON files
config_path = Path.home() / '.aisbf' / 'providers.json'
if not config_path.exists():
config_path = Path(__file__).parent.parent / 'config' / 'providers.json'
......@@ -1080,7 +1103,13 @@ class MCPServer:
with open(save_path, 'w') as f:
json.dump(full_config, f, indent=2)
return {"status": "success", "message": f"Provider '{provider_id}' saved. Restart server for changes to take effect."}
return {"status": "success", "message": f"Provider '{provider_id}' saved to file. Restart server for changes to take effect."}
else:
# Database user: save to database
from .database import get_database
db = get_database()
db.save_user_provider(user_id, provider_id, provider_data)
return {"status": "success", "message": f"Provider '{provider_id}' saved to database for user {user_id}."}
async def _get_server_config(self, args: Dict) -> Dict:
"""Get server configuration"""
......
......@@ -56,12 +56,13 @@ PROVIDER_HANDLERS = {
}
def get_provider_handler(provider_id: str, api_key: Optional[str] = None) -> BaseProviderHandler:
def get_provider_handler(provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None) -> BaseProviderHandler:
import logging
logger = logging.getLogger(__name__)
logger.info(f"=== get_provider_handler START ===")
logger.info(f"Provider ID: {provider_id}")
logger.info(f"API key provided: {bool(api_key)}")
logger.info(f"User ID: {user_id}")
provider_config = config.get_provider(provider_id)
logger.info(f"Provider config: {provider_config}")
......@@ -76,7 +77,13 @@ def get_provider_handler(provider_id: str, api_key: Optional[str] = None) -> Bas
logger.error(f"Unsupported provider type: {provider_config.type}")
raise ValueError(f"Unsupported provider type: {provider_config.type}")
# All handlers now accept api_key as optional parameter
# Check if handler supports user_id parameter (CodexProviderHandler does)
import inspect
sig = inspect.signature(handler_class.__init__)
if 'user_id' in sig.parameters:
logger.info(f"Creating handler with provider_id, optional api_key, and user_id")
handler = handler_class(provider_id, api_key, user_id=user_id)
else:
logger.info(f"Creating handler with provider_id and optional api_key")
handler = handler_class(provider_id, api_key)
......
......@@ -39,12 +39,16 @@ class ClaudeProviderHandler(BaseProviderHandler):
the official Anthropic Python SDK. OAuth2 access tokens are passed as
the api_key parameter to the SDK, which handles proper message formatting,
retries, and streaming.
For admin users (user_id=None), credentials are loaded from file.
For non-admin users, credentials are loaded from the database.
"""
# NOTE: OAuth2 API uses its own model naming scheme that differs from standard Anthropic API
def __init__(self, provider_id: str, api_key: Optional[str] = None):
def __init__(self, provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None):
super().__init__(provider_id, api_key)
self.user_id = user_id
self.provider_config = config.get_provider(provider_id)
# Get credentials file path from config
......@@ -53,7 +57,12 @@ class ClaudeProviderHandler(BaseProviderHandler):
if claude_config and isinstance(claude_config, dict):
credentials_file = claude_config.get('credentials_file')
# Initialize ClaudeAuth with credentials file (handles OAuth2 flow)
# Only the ONE config admin (user_id=None from aisbf.json) uses file-based credentials
# All other users (including database admins with user_id) use database credentials
if user_id is not None:
self.auth = self._load_auth_from_db(provider_id, credentials_file)
else:
# Config admin (from aisbf.json): use file-based credentials
from ..auth.claude import ClaudeAuth
self.auth = ClaudeAuth(credentials_file=credentials_file)
......@@ -92,6 +101,39 @@ class ClaudeProviderHandler(BaseProviderHandler):
# Initialize persistent identifiers for metadata
self._init_session_identifiers()
def _load_auth_from_db(self, provider_id: str, credentials_file: str):
"""
Load OAuth2 credentials from database for non-admin users.
Falls back to file-based credentials if not found in database.
"""
try:
from ..database import get_database
from ..auth.claude import ClaudeAuth
db = get_database()
if db:
db_creds = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=provider_id,
auth_type='claude_oauth2'
)
if db_creds and db_creds.get('credentials'):
# Create auth instance with database credentials
auth = ClaudeAuth(credentials_file=credentials_file)
# Override the loaded credentials with database credentials
auth.tokens = db_creds['credentials'].get('tokens', {})
import logging
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Loaded credentials from database for user {self.user_id}")
return auth
except Exception as e:
import logging
logging.getLogger(__name__).warning(f"ClaudeProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials
from ..auth.claude import ClaudeAuth
import logging
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Falling back to file-based credentials for user {self.user_id}")
return ClaudeAuth(credentials_file=credentials_file)
def _init_session_identifiers(self):
"""Initialize persistent session identifiers (device_id, account_uuid, session_id)."""
import uuid
......@@ -1190,7 +1232,11 @@ class ClaudeProviderHandler(BaseProviderHandler):
}
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
yield b"data: [DONE]\n\n"
# Final flush to ensure all buffered data reaches the client
await asyncio.sleep(0)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse streaming chunk: {e}")
......@@ -1553,7 +1599,11 @@ class ClaudeProviderHandler(BaseProviderHandler):
}
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
yield b"data: [DONE]\n\n"
# Final flush to ensure all buffered data reaches the client
await asyncio.sleep(0)
logger.info(f"ClaudeProviderHandler: SDK streaming completed successfully")
self.record_success()
......
......@@ -40,10 +40,14 @@ class CodexProviderHandler(BaseProviderHandler):
Uses the same OpenAI-compatible protocol but authenticates via OAuth2
using the Codex OAuth2 flow (device code or browser-based PKCE).
For admin users (user_id=None), credentials are loaded from file.
For non-admin users, credentials are loaded from the database.
"""
def __init__(self, provider_id: str, api_key: Optional[str] = None):
def __init__(self, provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None):
super().__init__(provider_id, api_key)
self.user_id = user_id
# Get provider config
provider_config = config.providers.get(provider_id)
......@@ -54,6 +58,12 @@ class CodexProviderHandler(BaseProviderHandler):
credentials_file = codex_config.get('credentials_file', '~/.aisbf/codex_credentials.json')
issuer = codex_config.get('issuer', 'https://auth.openai.com')
# Only the ONE config admin (user_id=None from aisbf.json) uses file-based credentials
# All other users (including database admins with user_id) use database credentials
if user_id is not None:
self.oauth2 = self._load_oauth2_from_db(provider_id, credentials_file, issuer)
else:
# Config admin (from aisbf.json): use file-based credentials
self.oauth2 = CodexOAuth2(
credentials_file=credentials_file,
issuer=issuer,
......@@ -73,6 +83,40 @@ class CodexProviderHandler(BaseProviderHandler):
self.client = OpenAI(base_url=endpoint, api_key=resolved_api_key or "dummy")
self._oauth2_enabled = not api_key and provider_config and not provider_config.api_key_required
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, issuer: str) -> CodexOAuth2:
"""
Load OAuth2 credentials from database for non-admin users.
Falls back to file-based credentials if not found in database.
"""
try:
from ..database import get_database
db = get_database()
if db:
db_creds = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=provider_id,
auth_type='codex_oauth2'
)
if db_creds and db_creds.get('credentials'):
# Create OAuth2 instance with database credentials
oauth2 = CodexOAuth2(
credentials_file=credentials_file,
issuer=issuer,
)
# Override the loaded credentials with database credentials
oauth2.credentials = db_creds['credentials']
logger.info(f"CodexProviderHandler: Loaded credentials from database for user {self.user_id}")
return oauth2
except Exception as e:
logger.warning(f"CodexProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials
logger.info(f"CodexProviderHandler: Falling back to file-based credentials for user {self.user_id}")
return CodexOAuth2(
credentials_file=credentials_file,
issuer=issuer,
)
async def _get_valid_api_key(self) -> str:
"""Get a valid API key, refreshing OAuth2 if needed."""
# If we have an API key from config, use it
......
......@@ -32,10 +32,14 @@ from .base import BaseProviderHandler, AISBF_DEBUG
class KiloProviderHandler(BaseProviderHandler):
"""
Handler for Kilo Gateway (OpenAI-compatible with OAuth2 support).
For admin users (user_id=None), credentials are loaded from file.
For non-admin users, credentials are loaded from the database.
"""
def __init__(self, provider_id: str, api_key: Optional[str] = None):
def __init__(self, provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None):
super().__init__(provider_id, api_key)
self.user_id = user_id
self.provider_config = config.get_provider(provider_id)
kilo_config = getattr(self.provider_config, 'kilo_config', None)
......@@ -47,6 +51,12 @@ class KiloProviderHandler(BaseProviderHandler):
credentials_file = kilo_config.get('credentials_file')
api_base = kilo_config.get('api_base')
# Only the ONE config admin (user_id=None from aisbf.json) uses file-based credentials
# All other users (including database admins with user_id) use database credentials
if user_id is not None:
self.oauth2 = self._load_oauth2_from_db(provider_id, credentials_file, api_base)
else:
# Config admin (from aisbf.json): use file-based credentials
from ..auth.kilo import KiloOAuth2
self.oauth2 = KiloOAuth2(credentials_file=credentials_file, api_base=api_base)
......@@ -62,6 +72,39 @@ class KiloProviderHandler(BaseProviderHandler):
self.client = OpenAI(base_url=endpoint, api_key=api_key or "placeholder")
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, api_base: str):
"""
Load OAuth2 credentials from database for non-admin users.
Falls back to file-based credentials if not found in database.
"""
try:
from ..database import get_database
from ..auth.kilo import KiloOAuth2
db = get_database()
if db:
db_creds = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=provider_id,
auth_type='kilo_oauth2'
)
if db_creds and db_creds.get('credentials'):
# Create OAuth2 instance with database credentials
oauth2 = KiloOAuth2(credentials_file=credentials_file, api_base=api_base)
# Override the loaded credentials with database credentials
oauth2.credentials = db_creds['credentials']
import logging
logging.getLogger(__name__).info(f"KiloProviderHandler: Loaded credentials from database for user {self.user_id}")
return oauth2
except Exception as e:
import logging
logging.getLogger(__name__).warning(f"KiloProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials
from ..auth.kilo import KiloOAuth2
import logging
logging.getLogger(__name__).info(f"KiloProviderHandler: Falling back to file-based credentials for user {self.user_id}")
return KiloOAuth2(credentials_file=credentials_file, api_base=api_base)
async def _ensure_authenticated(self) -> str:
"""Ensure user is authenticated and return valid token."""
import logging
......
......@@ -22,6 +22,7 @@ Why did the programmer quit his job? Because he didn't get arrays!
"""
import httpx
import asyncio
import time
import os
import json
......@@ -363,6 +364,8 @@ class KiroProviderHandler(BaseProviderHandler):
}]
}
yield f"data: {json.dumps(tool_calls_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
final_chunk = {
"id": completion_id,
......@@ -382,7 +385,11 @@ class KiroProviderHandler(BaseProviderHandler):
}
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
yield b"data: [DONE]\n\n"
# Final flush to ensure all buffered data reaches the client
await asyncio.sleep(0)
def _get_models_cache_path(self) -> str:
"""Get the path to the models cache file."""
......
This diff is collapsed.
......@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "aisbf"
version = "0.9.8"
version = "0.9.9"
description = "AISBF - AI Service Broker Framework || AI Should Be Free - A modular proxy server for managing multiple AI provider integrations"
readme = "README.md"
license = "GPL-3.0-or-later"
......
......@@ -49,7 +49,7 @@ class InstallCommand(_install):
setup(
name="aisbf",
version="0.9.8",
version="0.9.9",
author="AISBF Contributors",
author_email="stefy@nexlab.net",
description="AISBF - AI Service Broker Framework || AI Should Be Free - A modular proxy server for managing multiple AI provider integrations",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment