Commit faf1ac4b authored by Your Name's avatar Your Name

v0.9.9: User-based configuration routing for providers, rotations,...

v0.9.9: User-based configuration routing for providers, rotations, autoselects, and OAuth2 credentials

- Config admin (from aisbf.json, user_id=None) saves configurations to JSON files
- Database users save configurations to the database (user_providers, user_rotations, user_autoselects tables)
- Dashboard endpoints check user type and route accordingly
- File upload endpoint supports both config admin (files) and database users (database)
- MCP server tools accept user_id parameter and route to appropriate storage
- OAuth2 credential handling already implemented this pattern (Claude, Kilo, Codex)
- Updated CHANGELOG.md, setup.py, and pyproject.toml
parent 72d1dd85
Requirement already satisfied: curl_cffi in /home/nextime/aisbf/venv/lib/python3.13/site-packages (0.14.0)
Requirement already satisfied: cffi>=1.12.0 in /home/nextime/aisbf/venv/lib/python3.13/site-packages (from curl_cffi) (2.0.0)
Requirement already satisfied: certifi>=2024.2.2 in /home/nextime/aisbf/venv/lib/python3.13/site-packages (from curl_cffi) (2026.2.25)
Requirement already satisfied: pycparser in /home/nextime/aisbf/venv/lib/python3.13/site-packages (from cffi>=1.12.0->curl_cffi) (3.0)
...@@ -7,6 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ...@@ -7,6 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [0.9.9] - 2026-04-04
### Added
- **User-Based Configuration Routing**: All provider, rotation, and autoselect configurations are now saved and retrieved based on user type
- Config admin (defined in aisbf.json, `user_id=None`) saves configurations to JSON files (`~/.aisbf/providers.json`, `~/.aisbf/rotations.json`, `~/.aisbf/autoselect.json`)
- Database users (any user from the database, including admin role users) save configurations to the database via `user_providers`, `user_rotations`, and `user_autoselects` tables
- Dashboard endpoints (`/dashboard/providers`, `/dashboard/rotations`, `/dashboard/autoselect`) now check user type and route accordingly
- File upload endpoint (`/dashboard/providers/{provider_name}/upload`) supports both config admin (saves to files) and database users (saves to database with metadata)
- MCP server tools (`set_provider_config`, `set_rotation_config`, `set_autoselect_config`) now accept `user_id` parameter and route to appropriate storage
- OAuth2 credential handling already implemented this pattern (Claude, Kilo, Codex providers)
### Changed
- **Version Bump**: Updated version to 0.9.9 in setup.py and pyproject.toml
## [0.9.8] - 2026-04-04 ## [0.9.8] - 2026-04-04
### Added ### Added
......
...@@ -252,17 +252,25 @@ class CodexOAuth2: ...@@ -252,17 +252,25 @@ class CodexOAuth2:
""" """
url = f"{self.issuer}/api/accounts/deviceauth/token" url = f"{self.issuer}/api/accounts/deviceauth/token"
# Include client_id to properly identify the application
payload = {
"client_id": CLIENT_ID,
"device_auth_id": device_auth_id,
"user_code": user_code,
}
logger.debug(f"CodexOAuth2: Polling token endpoint - URL: {url}, Payload: {payload}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
url, url,
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
json={ json=payload,
"device_auth_id": device_auth_id,
"user_code": user_code,
},
timeout=30.0 timeout=30.0
) )
logger.debug(f"CodexOAuth2: Poll response - Status: {response.status_code}, Body: {response.text[:500]}")
if response.status_code == 200: if response.status_code == 200:
return response.json() return response.json()
...@@ -311,6 +319,8 @@ class CodexOAuth2: ...@@ -311,6 +319,8 @@ class CodexOAuth2:
if not hasattr(self, '_device_auth_id') or not self._device_auth_id: if not hasattr(self, '_device_auth_id') or not self._device_auth_id:
return {"status": "error", "error": "No device authorization in progress"} return {"status": "error", "error": "No device authorization in progress"}
logger.debug(f"CodexOAuth2: Polling for completion - device_auth_id: {self._device_auth_id}, user_code: {self._device_user_code}")
try: try:
token_resp = await self.poll_device_code_token( token_resp = await self.poll_device_code_token(
device_auth_id=self._device_auth_id, device_auth_id=self._device_auth_id,
...@@ -318,6 +328,8 @@ class CodexOAuth2: ...@@ -318,6 +328,8 @@ class CodexOAuth2:
interval=1, # We control polling interval from outside interval=1, # We control polling interval from outside
) )
logger.info(f"CodexOAuth2: Token response received - keys: {list(token_resp.keys())}")
# Step 3: Exchange for tokens # Step 3: Exchange for tokens
redirect_uri = f"{self.issuer}/deviceauth/callback" redirect_uri = f"{self.issuer}/deviceauth/callback"
tokens = await self.exchange_code_for_tokens( tokens = await self.exchange_code_for_tokens(
...@@ -326,10 +338,13 @@ class CodexOAuth2: ...@@ -326,10 +338,13 @@ class CodexOAuth2:
code_verifier=token_resp["code_verifier"], code_verifier=token_resp["code_verifier"],
) )
logger.info(f"CodexOAuth2: Tokens exchanged successfully")
# Step 4: Optionally obtain API key # Step 4: Optionally obtain API key
api_key = None api_key = None
try: try:
api_key = await self.obtain_api_key(tokens["id_token"]) api_key = await self.obtain_api_key(tokens["id_token"])
logger.info(f"CodexOAuth2: API key obtained")
except Exception as e: except Exception as e:
logger.warning(f"CodexOAuth2: Failed to obtain API key: {e}") logger.warning(f"CodexOAuth2: Failed to obtain API key: {e}")
...@@ -356,6 +371,7 @@ class CodexOAuth2: ...@@ -356,6 +371,7 @@ class CodexOAuth2:
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
logger.debug(f"CodexOAuth2: Poll exception - {type(e).__name__}: {error_msg}")
# 403/404 means still pending # 403/404 means still pending
if "403" in error_msg or "404" in error_msg or "pending" in error_msg.lower(): if "403" in error_msg or "404" in error_msg or "pending" in error_msg.lower():
return {"status": "pending"} return {"status": "pending"}
......
...@@ -327,6 +327,29 @@ class DatabaseManager: ...@@ -327,6 +327,29 @@ class DatabaseManager:
except: except:
pass # Index might already exist pass # Index might already exist
# Create user_oauth2_credentials table for storing OAuth2 tokens per user/provider
cursor.execute(f'''
CREATE TABLE IF NOT EXISTS user_oauth2_credentials (
id INTEGER PRIMARY KEY {auto_increment},
user_id INTEGER NOT NULL,
provider_id VARCHAR(255) NOT NULL,
auth_type VARCHAR(50) NOT NULL,
credentials TEXT NOT NULL,
created_at TIMESTAMP DEFAULT {timestamp_default},
updated_at TIMESTAMP DEFAULT {timestamp_default},
FOREIGN KEY (user_id) REFERENCES users(id),
UNIQUE(user_id, provider_id, auth_type)
)
''')
try:
cursor.execute('''
CREATE INDEX idx_user_oauth2_user_provider
ON user_oauth2_credentials(user_id, provider_id)
''')
except:
pass
conn.commit() conn.commit()
logger.info("User auth files table initialized") logger.info("User auth files table initialized")
...@@ -1497,6 +1520,148 @@ class DatabaseManager: ...@@ -1497,6 +1520,148 @@ class DatabaseManager:
conn.commit() conn.commit()
return cursor.rowcount return cursor.rowcount
# User OAuth2 credential methods
def save_user_oauth2_credentials(self, user_id: int, provider_id: str, auth_type: str, credentials: Dict) -> int:
"""
Save OAuth2 credentials for a user/provider combination.
Args:
user_id: User ID
provider_id: Provider identifier (e.g., 'codex', 'kilo', 'claude')
auth_type: Auth type (e.g., 'codex_oauth2', 'kilo_oauth2', 'claude_oauth2')
credentials: Credentials dictionary
Returns:
Record ID
"""
with self._get_connection() as conn:
cursor = conn.cursor()
credentials_json = json.dumps(credentials)
placeholder = '?' if self.db_type == 'sqlite' else '%s'
if self.db_type == 'sqlite':
cursor.execute(f'''
INSERT OR REPLACE INTO user_oauth2_credentials
(user_id, provider_id, auth_type, credentials, updated_at)
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, CURRENT_TIMESTAMP)
''', (user_id, provider_id, auth_type, credentials_json))
else: # mysql
cursor.execute(f'''
INSERT INTO user_oauth2_credentials
(user_id, provider_id, auth_type, credentials, updated_at)
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE
credentials=VALUES(credentials), updated_at=CURRENT_TIMESTAMP
''', (user_id, provider_id, auth_type, credentials_json))
conn.commit()
return cursor.lastrowid
def get_user_oauth2_credentials(self, user_id: int, provider_id: str, auth_type: str = None) -> Optional[Dict]:
"""
Get OAuth2 credentials for a user/provider combination.
Args:
user_id: User ID
provider_id: Provider identifier
auth_type: Optional auth type filter
Returns:
Credentials dictionary or None
"""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
if auth_type:
cursor.execute(f'''
SELECT id, auth_type, credentials, created_at, updated_at
FROM user_oauth2_credentials
WHERE user_id = {placeholder} AND provider_id = {placeholder} AND auth_type = {placeholder}
''', (user_id, provider_id, auth_type))
else:
cursor.execute(f'''
SELECT id, auth_type, credentials, created_at, updated_at
FROM user_oauth2_credentials
WHERE user_id = {placeholder} AND provider_id = {placeholder}
ORDER BY updated_at DESC
LIMIT 1
''', (user_id, provider_id))
row = cursor.fetchone()
if row:
return {
'id': row[0],
'auth_type': row[1],
'credentials': json.loads(row[2]),
'created_at': row[3],
'updated_at': row[4]
}
return None
def get_all_user_oauth2_credentials(self, user_id: int) -> List[Dict]:
"""
Get all OAuth2 credentials for a user.
Args:
user_id: User ID
Returns:
List of credential dictionaries
"""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
cursor.execute(f'''
SELECT id, provider_id, auth_type, credentials, created_at, updated_at
FROM user_oauth2_credentials
WHERE user_id = {placeholder}
ORDER BY provider_id, auth_type
''', (user_id,))
credentials = []
for row in cursor.fetchall():
credentials.append({
'id': row[0],
'provider_id': row[1],
'auth_type': row[2],
'credentials': json.loads(row[3]),
'created_at': row[4],
'updated_at': row[5]
})
return credentials
def delete_user_oauth2_credentials(self, user_id: int, provider_id: str, auth_type: str = None) -> int:
"""
Delete OAuth2 credentials for a user/provider combination.
Args:
user_id: User ID
provider_id: Provider identifier
auth_type: Optional auth type filter (if None, deletes all for provider)
Returns:
Number of records deleted
"""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
if auth_type:
cursor.execute(f'''
DELETE FROM user_oauth2_credentials
WHERE user_id = {placeholder} AND provider_id = {placeholder} AND auth_type = {placeholder}
''', (user_id, provider_id, auth_type))
else:
cursor.execute(f'''
DELETE FROM user_oauth2_credentials
WHERE user_id = {placeholder} AND provider_id = {placeholder}
''', (user_id, provider_id))
conn.commit()
return cursor.rowcount
# Global database manager instance # Global database manager instance
_db_manager: Optional[DatabaseManager] = None _db_manager: Optional[DatabaseManager] = None
......
...@@ -351,7 +351,7 @@ class RequestHandler: ...@@ -351,7 +351,7 @@ class RequestHandler:
logger.info("No API key required for this provider") logger.info("No API key required for this provider")
logger.info(f"Getting provider handler for {provider_id}") logger.info(f"Getting provider handler for {provider_id}")
handler = get_provider_handler(provider_id, api_key) handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
logger.info(f"Provider handler obtained: {handler.__class__.__name__}") logger.info(f"Provider handler obtained: {handler.__class__.__name__}")
if handler.is_rate_limited(): if handler.is_rate_limited():
...@@ -518,7 +518,7 @@ class RequestHandler: ...@@ -518,7 +518,7 @@ class RequestHandler:
else: else:
api_key = None api_key = None
handler = get_provider_handler(provider_id, api_key) handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited(): if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable") raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
...@@ -1117,7 +1117,7 @@ class RequestHandler: ...@@ -1117,7 +1117,7 @@ class RequestHandler:
else: else:
api_key = None api_key = None
handler = get_provider_handler(provider_id, api_key) handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
try: try:
# Apply rate limiting # Apply rate limiting
await handler.apply_rate_limit() await handler.apply_rate_limit()
...@@ -1398,7 +1398,7 @@ class RequestHandler: ...@@ -1398,7 +1398,7 @@ class RequestHandler:
else: else:
api_key = None api_key = None
handler = get_provider_handler(provider_id, api_key) handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited(): if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable") raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
...@@ -1427,7 +1427,7 @@ class RequestHandler: ...@@ -1427,7 +1427,7 @@ class RequestHandler:
else: else:
api_key = None api_key = None
handler = get_provider_handler(provider_id, api_key) handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited(): if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable") raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
...@@ -1456,7 +1456,7 @@ class RequestHandler: ...@@ -1456,7 +1456,7 @@ class RequestHandler:
else: else:
api_key = None api_key = None
handler = get_provider_handler(provider_id, api_key) handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited(): if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable") raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
...@@ -1489,7 +1489,7 @@ class RequestHandler: ...@@ -1489,7 +1489,7 @@ class RequestHandler:
else: else:
api_key = None api_key = None
handler = get_provider_handler(provider_id, api_key) handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if handler.is_rate_limited(): if handler.is_rate_limited():
raise HTTPException(status_code=503, detail="Provider temporarily unavailable") raise HTTPException(status_code=503, detail="Provider temporarily unavailable")
...@@ -2025,7 +2025,7 @@ class RotationHandler: ...@@ -2025,7 +2025,7 @@ class RotationHandler:
api_key = self._get_api_key(provider_id, provider.get('api_key')) api_key = self._get_api_key(provider_id, provider.get('api_key'))
# Check if provider is rate limited/deactivated # Check if provider is rate limited/deactivated
provider_handler = get_provider_handler(provider_id, api_key) provider_handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
if provider_handler.is_rate_limited(): if provider_handler.is_rate_limited():
logger.warning(f" [SKIPPED] Provider {provider_id} is rate limited/deactivated") logger.warning(f" [SKIPPED] Provider {provider_id} is rate limited/deactivated")
logger.warning(f" Reason: Provider has exceeded failure threshold or is in cooldown period") logger.warning(f" Reason: Provider has exceeded failure threshold or is in cooldown period")
...@@ -2367,7 +2367,7 @@ class RotationHandler: ...@@ -2367,7 +2367,7 @@ class RotationHandler:
model_name = current_model['name'] model_name = current_model['name']
logger.info(f"Getting provider handler for {provider_id}") logger.info(f"Getting provider handler for {provider_id}")
handler = get_provider_handler(provider_id, api_key) handler = get_provider_handler(provider_id, api_key, user_id=self.user_id)
logger.info(f"Provider handler obtained: {handler.__class__.__name__}") logger.info(f"Provider handler obtained: {handler.__class__.__name__}")
if handler.is_rate_limited(): if handler.is_rate_limited():
...@@ -2800,9 +2800,13 @@ class RotationHandler: ...@@ -2800,9 +2800,13 @@ class RotationHandler:
}) })
} }
yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8') yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
# Send [DONE] marker # Send [DONE] marker
yield b"data: [DONE]\n\n" yield b"data: [DONE]\n\n"
# Final flush to ensure all buffered data reaches the client
await asyncio.sleep(0)
return StreamingResponse(error_stream_generator(), media_type="text/event-stream", status_code=status_code) return StreamingResponse(error_stream_generator(), media_type="text/event-stream", status_code=status_code)
...@@ -3106,6 +3110,8 @@ class RotationHandler: ...@@ -3106,6 +3110,8 @@ class RotationHandler:
}] }]
} }
yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8') yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8')
# Yield control to event loop to ensure final chunk is flushed to client
await asyncio.sleep(0)
elif is_kilo_provider: elif is_kilo_provider:
# Handle Kilo/KiloCode streaming response # Handle Kilo/KiloCode streaming response
# Kilo returns an async generator that yields OpenAI-compatible SSE bytes # Kilo returns an async generator that yields OpenAI-compatible SSE bytes
......
...@@ -676,13 +676,13 @@ class MCPServer: ...@@ -676,13 +676,13 @@ class MCPServer:
'get_rotation_settings': self._get_rotation_settings, 'get_rotation_settings': self._get_rotation_settings,
}) })
# Add fullconfig-level tools # Add fullconfig-level tools (now support user_id for routing)
if auth_level >= MCPAuthLevel.FULLCONFIG: if auth_level >= MCPAuthLevel.FULLCONFIG:
handlers.update({ handlers.update({
'get_providers_config': self._get_providers_config, 'get_providers_config': self._get_providers_config,
'set_autoselect_config': self._set_autoselect_config, 'set_autoselect_config': lambda args: self._set_autoselect_config(args, user_id),
'set_rotation_config': self._set_rotation_config, 'set_rotation_config': lambda args: self._set_rotation_config(args, user_id),
'set_provider_config': self._set_provider_config, 'set_provider_config': lambda args: self._set_provider_config(args, user_id),
'get_server_config': self._get_server_config, 'get_server_config': self._get_server_config,
'set_server_config': self._set_server_config, 'set_server_config': self._set_server_config,
'get_tor_status': self._get_tor_status, 'get_tor_status': self._get_tor_status,
...@@ -719,7 +719,7 @@ class MCPServer: ...@@ -719,7 +719,7 @@ class MCPServer:
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found") raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
handler = handlers[tool_name] handler = handlers[tool_name]
return await handler(arguments, user_id) return await handler(arguments)
async def _list_models(self, args: Dict) -> Dict: async def _list_models(self, args: Dict) -> Dict:
"""List all available models""" """List all available models"""
...@@ -994,15 +994,19 @@ class MCPServer: ...@@ -994,15 +994,19 @@ class MCPServer:
else: else:
return {"providers": providers_data} return {"providers": providers_data}
async def _set_autoselect_config(self, args: Dict) -> Dict: async def _set_autoselect_config(self, args: Dict, user_id: Optional[int] = None) -> Dict:
"""Set autoselect configuration""" """Set autoselect configuration. Config admin saves to files, other users save to database."""
autoselect_id = args.get('autoselect_id') autoselect_id = args.get('autoselect_id')
autoselect_data = args.get('autoselect_data') autoselect_data = args.get('autoselect_data')
if not autoselect_id or not autoselect_data: if not autoselect_id or not autoselect_data:
raise HTTPException(status_code=400, detail="autoselect_id and autoselect_data are required") raise HTTPException(status_code=400, detail="autoselect_id and autoselect_data are required")
# Load existing config # Check if user is config admin (user_id is None means config admin)
is_config_admin = user_id is None
if is_config_admin:
# Config admin: save to JSON files
config_path = Path.home() / '.aisbf' / 'autoselect.json' config_path = Path.home() / '.aisbf' / 'autoselect.json'
if not config_path.exists(): if not config_path.exists():
config_path = Path(__file__).parent.parent / 'config' / 'autoselect.json' config_path = Path(__file__).parent.parent / 'config' / 'autoselect.json'
...@@ -1021,22 +1025,31 @@ class MCPServer: ...@@ -1021,22 +1025,31 @@ class MCPServer:
with open(save_path, 'w') as f: with open(save_path, 'w') as f:
json.dump(full_config, f, indent=2) json.dump(full_config, f, indent=2)
return {"status": "success", "message": f"Autoselect '{autoselect_id}' saved. Restart server for changes to take effect."} return {"status": "success", "message": f"Autoselect '{autoselect_id}' saved to file. Restart server for changes to take effect."}
else:
# Database user: save to database
from .database import get_database
db = get_database()
db.save_user_autoselect(user_id, autoselect_id, autoselect_data)
return {"status": "success", "message": f"Autoselect '{autoselect_id}' saved to database for user {user_id}."}
async def _set_rotation_config(self, args: Dict) -> Dict: async def _set_rotation_config(self, args: Dict, user_id: Optional[int] = None) -> Dict:
"""Set rotation configuration""" """Set rotation configuration. Config admin saves to files, other users save to database."""
rotation_id = args.get('rotation_id') rotation_id = args.get('rotation_id')
rotation_data = args.get('rotation_data') rotation_data = args.get('rotation_data')
if not rotation_id or not rotation_data: if not rotation_id or not rotation_data:
raise HTTPException(status_code=400, detail="rotation_id and rotation_data are required") raise HTTPException(status_code=400, detail="rotation_id and rotation_data are required")
# Load existing config # Check if user is config admin (user_id is None means config admin)
is_config_admin = user_id is None
if is_config_admin:
# Config admin: save to JSON files
config_path = Path.home() / '.aisbf' / 'rotations.json' config_path = Path.home() / '.aisbf' / 'rotations.json'
if not config_path.exists(): if not config_path.exists():
with open(config_path) as f: config_path = Path(__file__).parent.parent / 'config' / 'rotations.json'
full_config = json.load(f)
else:
with open(config_path) as f: with open(config_path) as f:
full_config = json.load(f) full_config = json.load(f)
...@@ -1051,17 +1064,27 @@ class MCPServer: ...@@ -1051,17 +1064,27 @@ class MCPServer:
with open(save_path, 'w') as f: with open(save_path, 'w') as f:
json.dump(full_config, f, indent=2) json.dump(full_config, f, indent=2)
return {"status": "success", "message": f"Rotation '{rotation_id}' saved. Restart server for changes to take effect."} return {"status": "success", "message": f"Rotation '{rotation_id}' saved to file. Restart server for changes to take effect."}
else:
# Database user: save to database
from .database import get_database
db = get_database()
db.save_user_rotation(user_id, rotation_id, rotation_data)
return {"status": "success", "message": f"Rotation '{rotation_id}' saved to database for user {user_id}."}
async def _set_provider_config(self, args: Dict) -> Dict: async def _set_provider_config(self, args: Dict, user_id: Optional[int] = None) -> Dict:
"""Set provider configuration""" """Set provider configuration. Config admin saves to files, other users save to database."""
provider_id = args.get('provider_id') provider_id = args.get('provider_id')
provider_data = args.get('provider_data') provider_data = args.get('provider_data')
if not provider_id or not provider_data: if not provider_id or not provider_data:
raise HTTPException(status_code=400, detail="provider_id and provider_data are required") raise HTTPException(status_code=400, detail="provider_id and provider_data are required")
# Load existing config # Check if user is config admin (user_id is None means config admin)
is_config_admin = user_id is None
if is_config_admin:
# Config admin: save to JSON files
config_path = Path.home() / '.aisbf' / 'providers.json' config_path = Path.home() / '.aisbf' / 'providers.json'
if not config_path.exists(): if not config_path.exists():
config_path = Path(__file__).parent.parent / 'config' / 'providers.json' config_path = Path(__file__).parent.parent / 'config' / 'providers.json'
...@@ -1080,7 +1103,13 @@ class MCPServer: ...@@ -1080,7 +1103,13 @@ class MCPServer:
with open(save_path, 'w') as f: with open(save_path, 'w') as f:
json.dump(full_config, f, indent=2) json.dump(full_config, f, indent=2)
return {"status": "success", "message": f"Provider '{provider_id}' saved. Restart server for changes to take effect."} return {"status": "success", "message": f"Provider '{provider_id}' saved to file. Restart server for changes to take effect."}
else:
# Database user: save to database
from .database import get_database
db = get_database()
db.save_user_provider(user_id, provider_id, provider_data)
return {"status": "success", "message": f"Provider '{provider_id}' saved to database for user {user_id}."}
async def _get_server_config(self, args: Dict) -> Dict: async def _get_server_config(self, args: Dict) -> Dict:
"""Get server configuration""" """Get server configuration"""
......
...@@ -56,12 +56,13 @@ PROVIDER_HANDLERS = { ...@@ -56,12 +56,13 @@ PROVIDER_HANDLERS = {
} }
def get_provider_handler(provider_id: str, api_key: Optional[str] = None) -> BaseProviderHandler: def get_provider_handler(provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None) -> BaseProviderHandler:
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.info(f"=== get_provider_handler START ===") logger.info(f"=== get_provider_handler START ===")
logger.info(f"Provider ID: {provider_id}") logger.info(f"Provider ID: {provider_id}")
logger.info(f"API key provided: {bool(api_key)}") logger.info(f"API key provided: {bool(api_key)}")
logger.info(f"User ID: {user_id}")
provider_config = config.get_provider(provider_id) provider_config = config.get_provider(provider_id)
logger.info(f"Provider config: {provider_config}") logger.info(f"Provider config: {provider_config}")
...@@ -76,7 +77,13 @@ def get_provider_handler(provider_id: str, api_key: Optional[str] = None) -> Bas ...@@ -76,7 +77,13 @@ def get_provider_handler(provider_id: str, api_key: Optional[str] = None) -> Bas
logger.error(f"Unsupported provider type: {provider_config.type}") logger.error(f"Unsupported provider type: {provider_config.type}")
raise ValueError(f"Unsupported provider type: {provider_config.type}") raise ValueError(f"Unsupported provider type: {provider_config.type}")
# All handlers now accept api_key as optional parameter # Check if handler supports user_id parameter (CodexProviderHandler does)
import inspect
sig = inspect.signature(handler_class.__init__)
if 'user_id' in sig.parameters:
logger.info(f"Creating handler with provider_id, optional api_key, and user_id")
handler = handler_class(provider_id, api_key, user_id=user_id)
else:
logger.info(f"Creating handler with provider_id and optional api_key") logger.info(f"Creating handler with provider_id and optional api_key")
handler = handler_class(provider_id, api_key) handler = handler_class(provider_id, api_key)
......
...@@ -39,12 +39,16 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -39,12 +39,16 @@ class ClaudeProviderHandler(BaseProviderHandler):
the official Anthropic Python SDK. OAuth2 access tokens are passed as the official Anthropic Python SDK. OAuth2 access tokens are passed as
the api_key parameter to the SDK, which handles proper message formatting, the api_key parameter to the SDK, which handles proper message formatting,
retries, and streaming. retries, and streaming.
For admin users (user_id=None), credentials are loaded from file.
For non-admin users, credentials are loaded from the database.
""" """
# NOTE: OAuth2 API uses its own model naming scheme that differs from standard Anthropic API # NOTE: OAuth2 API uses its own model naming scheme that differs from standard Anthropic API
def __init__(self, provider_id: str, api_key: Optional[str] = None): def __init__(self, provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None):
super().__init__(provider_id, api_key) super().__init__(provider_id, api_key)
self.user_id = user_id
self.provider_config = config.get_provider(provider_id) self.provider_config = config.get_provider(provider_id)
# Get credentials file path from config # Get credentials file path from config
...@@ -53,7 +57,12 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -53,7 +57,12 @@ class ClaudeProviderHandler(BaseProviderHandler):
if claude_config and isinstance(claude_config, dict): if claude_config and isinstance(claude_config, dict):
credentials_file = claude_config.get('credentials_file') credentials_file = claude_config.get('credentials_file')
# Initialize ClaudeAuth with credentials file (handles OAuth2 flow) # Only the ONE config admin (user_id=None from aisbf.json) uses file-based credentials
# All other users (including database admins with user_id) use database credentials
if user_id is not None:
self.auth = self._load_auth_from_db(provider_id, credentials_file)
else:
# Config admin (from aisbf.json): use file-based credentials
from ..auth.claude import ClaudeAuth from ..auth.claude import ClaudeAuth
self.auth = ClaudeAuth(credentials_file=credentials_file) self.auth = ClaudeAuth(credentials_file=credentials_file)
...@@ -92,6 +101,39 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -92,6 +101,39 @@ class ClaudeProviderHandler(BaseProviderHandler):
# Initialize persistent identifiers for metadata # Initialize persistent identifiers for metadata
self._init_session_identifiers() self._init_session_identifiers()
def _load_auth_from_db(self, provider_id: str, credentials_file: str):
"""
Load OAuth2 credentials from database for non-admin users.
Falls back to file-based credentials if not found in database.
"""
try:
from ..database import get_database
from ..auth.claude import ClaudeAuth
db = get_database()
if db:
db_creds = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=provider_id,
auth_type='claude_oauth2'
)
if db_creds and db_creds.get('credentials'):
# Create auth instance with database credentials
auth = ClaudeAuth(credentials_file=credentials_file)
# Override the loaded credentials with database credentials
auth.tokens = db_creds['credentials'].get('tokens', {})
import logging
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Loaded credentials from database for user {self.user_id}")
return auth
except Exception as e:
import logging
logging.getLogger(__name__).warning(f"ClaudeProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials
from ..auth.claude import ClaudeAuth
import logging
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Falling back to file-based credentials for user {self.user_id}")
return ClaudeAuth(credentials_file=credentials_file)
def _init_session_identifiers(self): def _init_session_identifiers(self):
"""Initialize persistent session identifiers (device_id, account_uuid, session_id).""" """Initialize persistent session identifiers (device_id, account_uuid, session_id)."""
import uuid import uuid
...@@ -1190,7 +1232,11 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -1190,7 +1232,11 @@ class ClaudeProviderHandler(BaseProviderHandler):
} }
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8') yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
yield b"data: [DONE]\n\n" yield b"data: [DONE]\n\n"
# Final flush to ensure all buffered data reaches the client
await asyncio.sleep(0)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning(f"Failed to parse streaming chunk: {e}") logger.warning(f"Failed to parse streaming chunk: {e}")
...@@ -1553,7 +1599,11 @@ class ClaudeProviderHandler(BaseProviderHandler): ...@@ -1553,7 +1599,11 @@ class ClaudeProviderHandler(BaseProviderHandler):
} }
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8') yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
yield b"data: [DONE]\n\n" yield b"data: [DONE]\n\n"
# Final flush to ensure all buffered data reaches the client
await asyncio.sleep(0)
logger.info(f"ClaudeProviderHandler: SDK streaming completed successfully") logger.info(f"ClaudeProviderHandler: SDK streaming completed successfully")
self.record_success() self.record_success()
......
...@@ -40,10 +40,14 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -40,10 +40,14 @@ class CodexProviderHandler(BaseProviderHandler):
Uses the same OpenAI-compatible protocol but authenticates via OAuth2 Uses the same OpenAI-compatible protocol but authenticates via OAuth2
using the Codex OAuth2 flow (device code or browser-based PKCE). using the Codex OAuth2 flow (device code or browser-based PKCE).
For admin users (user_id=None), credentials are loaded from file.
For non-admin users, credentials are loaded from the database.
""" """
def __init__(self, provider_id: str, api_key: Optional[str] = None): def __init__(self, provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None):
super().__init__(provider_id, api_key) super().__init__(provider_id, api_key)
self.user_id = user_id
# Get provider config # Get provider config
provider_config = config.providers.get(provider_id) provider_config = config.providers.get(provider_id)
...@@ -54,6 +58,12 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -54,6 +58,12 @@ class CodexProviderHandler(BaseProviderHandler):
credentials_file = codex_config.get('credentials_file', '~/.aisbf/codex_credentials.json') credentials_file = codex_config.get('credentials_file', '~/.aisbf/codex_credentials.json')
issuer = codex_config.get('issuer', 'https://auth.openai.com') issuer = codex_config.get('issuer', 'https://auth.openai.com')
# Only the ONE config admin (user_id=None from aisbf.json) uses file-based credentials
# All other users (including database admins with user_id) use database credentials
if user_id is not None:
self.oauth2 = self._load_oauth2_from_db(provider_id, credentials_file, issuer)
else:
# Config admin (from aisbf.json): use file-based credentials
self.oauth2 = CodexOAuth2( self.oauth2 = CodexOAuth2(
credentials_file=credentials_file, credentials_file=credentials_file,
issuer=issuer, issuer=issuer,
...@@ -73,6 +83,40 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -73,6 +83,40 @@ class CodexProviderHandler(BaseProviderHandler):
self.client = OpenAI(base_url=endpoint, api_key=resolved_api_key or "dummy") self.client = OpenAI(base_url=endpoint, api_key=resolved_api_key or "dummy")
self._oauth2_enabled = not api_key and provider_config and not provider_config.api_key_required self._oauth2_enabled = not api_key and provider_config and not provider_config.api_key_required
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, issuer: str) -> CodexOAuth2:
"""
Load OAuth2 credentials from database for non-admin users.
Falls back to file-based credentials if not found in database.
"""
try:
from ..database import get_database
db = get_database()
if db:
db_creds = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=provider_id,
auth_type='codex_oauth2'
)
if db_creds and db_creds.get('credentials'):
# Create OAuth2 instance with database credentials
oauth2 = CodexOAuth2(
credentials_file=credentials_file,
issuer=issuer,
)
# Override the loaded credentials with database credentials
oauth2.credentials = db_creds['credentials']
logger.info(f"CodexProviderHandler: Loaded credentials from database for user {self.user_id}")
return oauth2
except Exception as e:
logger.warning(f"CodexProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials
logger.info(f"CodexProviderHandler: Falling back to file-based credentials for user {self.user_id}")
return CodexOAuth2(
credentials_file=credentials_file,
issuer=issuer,
)
async def _get_valid_api_key(self) -> str: async def _get_valid_api_key(self) -> str:
"""Get a valid API key, refreshing OAuth2 if needed.""" """Get a valid API key, refreshing OAuth2 if needed."""
# If we have an API key from config, use it # If we have an API key from config, use it
......
...@@ -32,10 +32,14 @@ from .base import BaseProviderHandler, AISBF_DEBUG ...@@ -32,10 +32,14 @@ from .base import BaseProviderHandler, AISBF_DEBUG
class KiloProviderHandler(BaseProviderHandler): class KiloProviderHandler(BaseProviderHandler):
""" """
Handler for Kilo Gateway (OpenAI-compatible with OAuth2 support). Handler for Kilo Gateway (OpenAI-compatible with OAuth2 support).
For admin users (user_id=None), credentials are loaded from file.
For non-admin users, credentials are loaded from the database.
""" """
def __init__(self, provider_id: str, api_key: Optional[str] = None): def __init__(self, provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None):
super().__init__(provider_id, api_key) super().__init__(provider_id, api_key)
self.user_id = user_id
self.provider_config = config.get_provider(provider_id) self.provider_config = config.get_provider(provider_id)
kilo_config = getattr(self.provider_config, 'kilo_config', None) kilo_config = getattr(self.provider_config, 'kilo_config', None)
...@@ -47,6 +51,12 @@ class KiloProviderHandler(BaseProviderHandler): ...@@ -47,6 +51,12 @@ class KiloProviderHandler(BaseProviderHandler):
credentials_file = kilo_config.get('credentials_file') credentials_file = kilo_config.get('credentials_file')
api_base = kilo_config.get('api_base') api_base = kilo_config.get('api_base')
# Only the ONE config admin (user_id=None from aisbf.json) uses file-based credentials
# All other users (including database admins with user_id) use database credentials
if user_id is not None:
self.oauth2 = self._load_oauth2_from_db(provider_id, credentials_file, api_base)
else:
# Config admin (from aisbf.json): use file-based credentials
from ..auth.kilo import KiloOAuth2 from ..auth.kilo import KiloOAuth2
self.oauth2 = KiloOAuth2(credentials_file=credentials_file, api_base=api_base) self.oauth2 = KiloOAuth2(credentials_file=credentials_file, api_base=api_base)
...@@ -62,6 +72,39 @@ class KiloProviderHandler(BaseProviderHandler): ...@@ -62,6 +72,39 @@ class KiloProviderHandler(BaseProviderHandler):
self.client = OpenAI(base_url=endpoint, api_key=api_key or "placeholder") self.client = OpenAI(base_url=endpoint, api_key=api_key or "placeholder")
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, api_base: str):
"""
Load OAuth2 credentials from database for non-admin users.
Falls back to file-based credentials if not found in database.
"""
try:
from ..database import get_database
from ..auth.kilo import KiloOAuth2
db = get_database()
if db:
db_creds = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=provider_id,
auth_type='kilo_oauth2'
)
if db_creds and db_creds.get('credentials'):
# Create OAuth2 instance with database credentials
oauth2 = KiloOAuth2(credentials_file=credentials_file, api_base=api_base)
# Override the loaded credentials with database credentials
oauth2.credentials = db_creds['credentials']
import logging
logging.getLogger(__name__).info(f"KiloProviderHandler: Loaded credentials from database for user {self.user_id}")
return oauth2
except Exception as e:
import logging
logging.getLogger(__name__).warning(f"KiloProviderHandler: Failed to load credentials from database: {e}")
# Fall back to file-based credentials
from ..auth.kilo import KiloOAuth2
import logging
logging.getLogger(__name__).info(f"KiloProviderHandler: Falling back to file-based credentials for user {self.user_id}")
return KiloOAuth2(credentials_file=credentials_file, api_base=api_base)
async def _ensure_authenticated(self) -> str: async def _ensure_authenticated(self) -> str:
"""Ensure user is authenticated and return valid token.""" """Ensure user is authenticated and return valid token."""
import logging import logging
......
...@@ -22,6 +22,7 @@ Why did the programmer quit his job? Because he didn't get arrays! ...@@ -22,6 +22,7 @@ Why did the programmer quit his job? Because he didn't get arrays!
""" """
import httpx import httpx
import asyncio
import time import time
import os import os
import json import json
...@@ -363,6 +364,8 @@ class KiroProviderHandler(BaseProviderHandler): ...@@ -363,6 +364,8 @@ class KiroProviderHandler(BaseProviderHandler):
}] }]
} }
yield f"data: {json.dumps(tool_calls_chunk, ensure_ascii=False)}\n\n".encode('utf-8') yield f"data: {json.dumps(tool_calls_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
final_chunk = { final_chunk = {
"id": completion_id, "id": completion_id,
...@@ -382,7 +385,11 @@ class KiroProviderHandler(BaseProviderHandler): ...@@ -382,7 +385,11 @@ class KiroProviderHandler(BaseProviderHandler):
} }
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8') yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n".encode('utf-8')
# Yield control to event loop to ensure chunk is flushed to client
await asyncio.sleep(0)
yield b"data: [DONE]\n\n" yield b"data: [DONE]\n\n"
# Final flush to ensure all buffered data reaches the client
await asyncio.sleep(0)
def _get_models_cache_path(self) -> str: def _get_models_cache_path(self) -> str:
"""Get the path to the models cache file.""" """Get the path to the models cache file."""
......
This diff is collapsed.
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "aisbf" name = "aisbf"
version = "0.9.8" version = "0.9.9"
description = "AISBF - AI Service Broker Framework || AI Should Be Free - A modular proxy server for managing multiple AI provider integrations" description = "AISBF - AI Service Broker Framework || AI Should Be Free - A modular proxy server for managing multiple AI provider integrations"
readme = "README.md" readme = "README.md"
license = "GPL-3.0-or-later" license = "GPL-3.0-or-later"
......
...@@ -49,7 +49,7 @@ class InstallCommand(_install): ...@@ -49,7 +49,7 @@ class InstallCommand(_install):
setup( setup(
name="aisbf", name="aisbf",
version="0.9.8", version="0.9.9",
author="AISBF Contributors", author="AISBF Contributors",
author_email="stefy@nexlab.net", author_email="stefy@nexlab.net",
description="AISBF - AI Service Broker Framework || AI Should Be Free - A modular proxy server for managing multiple AI provider integrations", description="AISBF - AI Service Broker Framework || AI Should Be Free - A modular proxy server for managing multiple AI provider integrations",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment