feat: add RunPod provider runtime management

parent 3156c83c
RunPod implementation recovery plan for next session.
Goal
- Add a new provider type `runpod`
- Support multiple RunPod accounts by allowing multiple AISBF providers of type `runpod`
- Support two modes:
- pod-backed/serverless-backed wrapper provider with one wrapper mode per provider: `openai`, `coderai`, or `ollama`
- `runpod_public` provider represented as one AISBF provider with many discovered models/endpoints
- Auto-start stopped pods on request and wait until ready
- Cache pod/endpoint status in DB/cache so behavior is consistent across multiple AISBF instances
- Stop idle pods after configurable inactivity
- Allow serverless endpoint template usage as an alternative to pod-backed mode
Product decisions already made
- Scope: full lifecycle now
- Wrapper mode:
- pod-backed `runpod` providers store one wrapper mode per provider
- `runpod_public` auto-detects protocol per discovered model, with optional manual override per model
- Cold start behavior: auto start + wait
- `runpod_public` shape: one provider, many discovered models
- Management API preference: use the most recent/current supported RunPod management API surface between GraphQL and REST/OpenAPI
- Do not hardcode GraphQL if REST/OpenAPI is newer
Critical first step next session
- Verify which RunPod management API is the current supported one:
- inspect current REST/OpenAPI docs/spec
- inspect current GraphQL docs/spec
- use whichever is the newer/current supported API surface
- Then map exact operations for:
- pod status/start/stop
- template lookup/use
- endpoint discovery
- serverless endpoint creation/use
- public endpoint metadata and request format
Docs already identified
- `https://docs.runpod.io/api-reference/overview`
- `https://docs.runpod.io/llms.txt`
- `https://docs.runpod.io/public-endpoints/requests`
- `https://rest.runpod.io/v1/openapi.json`
Implementation map in AISBF
- `aisbf/config.py`
- extend `ProviderConfig` with `runpod_config: Optional[Dict] = None`
- `aisbf/providers/__init__.py`
- register new provider type `runpod`
- new file `aisbf/providers/runpod.py`
- main handler/orchestrator
- `templates/dashboard/providers.html`
- add `runpod` provider type option and config UI
- `aisbf/routes/dashboard/providers.py`
- add any RunPod-specific dashboard actions/status endpoints if needed
- `aisbf/app/model_cache.py`
- integrate caching/refresh for `runpod_public` discovered models
- `aisbf/database.py`
- add persistent lifecycle/runtime state for runpod providers
Planned `runpod_config` structure
Example target shape:
```json
{
"mode": "pod",
"wrapper_mode": "openai",
"account_name": "personal-runpod",
"management_api": "auto",
"idle_shutdown_ms": 900000,
"startup_poll_interval_ms": 3000,
"startup_timeout_ms": 300000,
"pod_id": "abc123",
"template_id": "tmpl_xyz",
"endpoint_id": "",
"serverless_template_id": "",
"public_endpoint_protocol_default": "auto",
"public_models": {
"model-slug": {
"protocol": "openai",
"capabilities": ["chat", "vision"]
}
}
}
```
Modes
- `pod`
- `serverless_template`
- `public`
Wrapper modes for non-public
- `openai`
- `ollama`
- `coderai`
Representation rules
- Non-public runpod providers:
- one wrapper mode per provider
- lifecycle managed by AISBF
- `runpod_public`:
- one provider with many discovered models/endpoints
- protocol auto-detected per model
- optional manual override per model in config
Architecture to implement
1. `RunpodProviderHandler` as orchestrator
- It should handle lifecycle and dispatch, not just protocol forwarding
- Responsibilities:
- load `runpod_config`
- ensure pod/endpoint is ready before forwarding requests
- cache status/discovery
- delegate to existing protocol behavior
2. Delegation model
- For pod/serverless-backed providers:
- once ready, speak protocol based on provider-level `wrapper_mode`
- delegate internally to existing handlers:
- `OpenAIProviderHandler`
- `OllamaProviderHandler`
- `CoderAIProviderHandler`
- For `runpod_public`:
- discover public models/endpoints
- resolve protocol per model
- dispatch request using model-specific protocol behavior
3. Readiness lifecycle
- On request for pod-backed provider:
- read cached status from DB/cache
- if running and endpoint known, reuse
- if stopped, start pod
- poll until ready or timeout
- persist status/ready endpoint back to DB/cache
- On request for serverless-template mode:
- resolve or create usable endpoint from template as configured
- cache endpoint metadata
4. Idle shutdown
- Store persistent last-used timestamps and runtime state in DB
- Add background loop that:
- scans runpod provider state
- if `now - last_used_at > idle_shutdown_ms` and provider is pod-backed and running
- stop the pod
- persist updated status
Database work needed
Add a new table in `aisbf/database.py`, e.g. `runpod_provider_state` with fields like:
- `provider_scope` (`global` / `user`)
- `owner_user_id`
- `provider_id`
- `mode`
- `wrapper_mode`
- `resource_id`
- `resource_kind` (`pod`, `endpoint`, `public`)
- `status`
- `endpoint_url`
- `public_catalog_json`
- `metadata`
- `last_used_at`
- `last_status_sync_at`
- `updated_at`
- unique on `(owner_user_id, provider_id)`
Add helpers:
- `get_runpod_provider_state(...)`
- `save_runpod_provider_state(...)`
- `touch_runpod_provider_state(...)`
- `list_runpod_provider_states(...)`
This DB-backed state is required for:
- round-robin multi-instance consistency
- idle shutdown scanning
- readiness caching
- public endpoint discovery caching
Cache/model discovery work
For `runpod_public` in `aisbf/app/model_cache.py`:
- cache discovered public models
- refresh periodically or on-demand
- store enough metadata per model:
- model id/slug
- protocol
- capabilities
- route base
- request mode (`runsync`, `run`, `status`)
- parameter/schema hints if available
Dashboard work
In `templates/dashboard/providers.html`:
- add provider type option: `runpod`
- add description text for `runpod`
- add UI section for `runpod_config`
- likely fields:
- account label
- mode (`pod`, `serverless_template`, `public`)
- wrapper mode (`openai`, `ollama`, `coderai`) for non-public
- API key field if not top-level
- pod id
- template id
- endpoint id
- serverless template id
- idle shutdown ms
- startup timeout ms
- poll interval ms
- auto-discovery toggle
- per-model protocol override editor for public models
Potential server-side additions in `aisbf/routes/dashboard/providers.py`
- refresh RunPod public discovery
- show RunPod lifecycle status
- optional manual start/stop actions later if useful
Protocol behavior plan
1. Pod-backed `openai`
- after pod ready, delegate to OpenAI-compatible request/model list flow
- endpoint likely `/v1/...`
2. Pod-backed `ollama`
- after pod ready, delegate to Ollama flow
- endpoint likely `/api/...`
3. Pod-backed `coderai`
- after pod ready, delegate to CoderAI flow
- endpoint/path depends on service running in the pod
4. `runpod_public`
- public endpoints are not one uniform protocol
- implement model-level protocol metadata
- auto-detect protocol from endpoint metadata/docs/naming where possible
- allow manual override per model
- request path likely uses `https://api.runpod.ai/v2/<endpoint>/...`
- do not fake this part; implement from verified docs only
Suggested next-session execution order
1. Verify RunPod API contract and choose the current supported management API surface
2. Add `runpod_config` to `aisbf/config.py`
3. Add DB-backed `runpod_provider_state` table and helpers in `aisbf/database.py`
4. Create `aisbf/providers/runpod.py`
5. Register `runpod` in `aisbf/providers/__init__.py`
6. Add idle shutdown background task in startup/background task area
7. Add dashboard UI/config save support in `templates/dashboard/providers.html`
8. Hook `runpod_public` discovery into `aisbf/app/model_cache.py`
9. Validate with compile/tests
Recommended tests to add
- config validation for `runpod_config`
- DB CRUD for `runpod_provider_state`
- lifecycle tests:
- stopped pod -> start called
- running pod -> no start
- idle timeout -> stop called
- public model discovery parsing
- protocol selection:
- public model auto-detect
- public model manual override
- delegation tests:
- `wrapper_mode=openai`
- `wrapper_mode=ollama`
- `wrapper_mode=coderai`
Files already reviewed for this work
- `aisbf/config.py`
- `aisbf/providers/__init__.py`
- `aisbf/providers/openai.py`
- `aisbf/providers/ollama.py`
- `aisbf/providers/coderai.py`
- `aisbf/providers/base.py`
- `aisbf/app/model_cache.py`
- `aisbf/routes/dashboard/providers.py`
- `templates/dashboard/providers.html`
Suggested next-session prompt
"Implement full RunPod provider support for AISBF. First determine whether RunPod REST/OpenAPI or GraphQL is the newer/current supported management API, then use that API for pod lifecycle, endpoint discovery, and template/serverless management. Add a new `runpod` provider type with `runpod_config`, DB-backed lifecycle state, auto-start/wait, idle shutdown, wrapper-mode delegation (`openai`, `ollama`, `coderai`), and `runpod_public` as one provider with many discovered models and per-model protocol auto-detect/manual override. Preserve multi-instance consistency by storing lifecycle state in the database."
...@@ -2069,6 +2069,189 @@ class DatabaseManager: ...@@ -2069,6 +2069,189 @@ class DatabaseManager:
''', (user_id, provider_name)) ''', (user_id, provider_name))
conn.commit() conn.commit()
def get_runpod_provider_state(self, provider_scope: str, owner_user_id: Optional[int], provider_id: str) -> Optional[Dict[str, Any]]:
"""Get stored RunPod runtime state for a provider."""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
owner_clause = 'owner_user_id IS NULL' if owner_user_id is None else f'owner_user_id = {placeholder}'
params = [provider_scope]
if owner_user_id is not None:
params.append(owner_user_id)
params.append(provider_id)
cursor.execute(f'''
SELECT provider_scope, owner_user_id, provider_id, mode, wrapper_mode, resource_id,
resource_kind, status, endpoint_url, public_catalog_json, metadata,
last_used_at, last_status_sync_at, updated_at
FROM runpod_provider_state
WHERE provider_scope = {placeholder} AND {owner_clause} AND provider_id = {placeholder}
LIMIT 1
''', tuple(params))
row = cursor.fetchone()
return self._row_to_runpod_provider_state(row)
def save_runpod_provider_state(self, provider_scope: str, owner_user_id: Optional[int], provider_id: str, mode: str,
wrapper_mode: Optional[str], resource_id: Optional[str], resource_kind: str,
status: str, endpoint_url: Optional[str] = None, public_catalog_json: Optional[Any] = None,
metadata: Optional[Dict[str, Any]] = None, last_used_at: Optional[Any] = None,
last_status_sync_at: Optional[Any] = None) -> Dict[str, Any]:
"""Persist RunPod runtime state for a provider."""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
catalog_payload = json.dumps(public_catalog_json or [])
metadata_payload = json.dumps(metadata or {})
sync_value = last_status_sync_at if last_status_sync_at is not None else time.time()
used_value = last_used_at
if owner_user_id is None:
cursor.execute(f'''
UPDATE runpod_provider_state
SET provider_scope = {placeholder},
mode = {placeholder},
wrapper_mode = {placeholder},
resource_id = {placeholder},
resource_kind = {placeholder},
status = {placeholder},
endpoint_url = {placeholder},
public_catalog_json = {placeholder},
metadata = {placeholder},
last_used_at = {placeholder},
last_status_sync_at = {placeholder},
updated_at = CURRENT_TIMESTAMP
WHERE provider_scope = {placeholder} AND owner_user_id IS NULL AND provider_id = {placeholder}
''', (
provider_scope, mode, wrapper_mode, resource_id, resource_kind, status,
endpoint_url, catalog_payload, metadata_payload, used_value, sync_value,
provider_scope, provider_id,
))
if cursor.rowcount:
conn.commit()
return self.get_runpod_provider_state(provider_scope, owner_user_id, provider_id)
if self.db_type == 'sqlite':
cursor.execute(f'''
INSERT INTO runpod_provider_state (
provider_scope, owner_user_id, provider_id, mode, wrapper_mode, resource_id,
resource_kind, status, endpoint_url, public_catalog_json, metadata,
last_used_at, last_status_sync_at, updated_at
) VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder},
{placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder},
{placeholder}, {placeholder}, CURRENT_TIMESTAMP)
ON CONFLICT(owner_user_id, provider_id) DO UPDATE SET
provider_scope = excluded.provider_scope,
mode = excluded.mode,
wrapper_mode = excluded.wrapper_mode,
resource_id = excluded.resource_id,
resource_kind = excluded.resource_kind,
status = excluded.status,
endpoint_url = excluded.endpoint_url,
public_catalog_json = excluded.public_catalog_json,
metadata = excluded.metadata,
last_used_at = excluded.last_used_at,
last_status_sync_at = excluded.last_status_sync_at,
updated_at = CURRENT_TIMESTAMP
''', (provider_scope, owner_user_id, provider_id, mode, wrapper_mode, resource_id, resource_kind,
status, endpoint_url, catalog_payload, metadata_payload, used_value, sync_value))
else:
cursor.execute(f'''
INSERT INTO runpod_provider_state (
provider_scope, owner_user_id, provider_id, mode, wrapper_mode, resource_id,
resource_kind, status, endpoint_url, public_catalog_json, metadata,
last_used_at, last_status_sync_at, updated_at
) VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder},
{placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder},
{placeholder}, {placeholder}, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE
provider_scope = VALUES(provider_scope),
mode = VALUES(mode),
wrapper_mode = VALUES(wrapper_mode),
resource_id = VALUES(resource_id),
resource_kind = VALUES(resource_kind),
status = VALUES(status),
endpoint_url = VALUES(endpoint_url),
public_catalog_json = VALUES(public_catalog_json),
metadata = VALUES(metadata),
last_used_at = VALUES(last_used_at),
last_status_sync_at = VALUES(last_status_sync_at),
updated_at = CURRENT_TIMESTAMP
''', (provider_scope, owner_user_id, provider_id, mode, wrapper_mode, resource_id, resource_kind,
status, endpoint_url, catalog_payload, metadata_payload, used_value, sync_value))
conn.commit()
return self.get_runpod_provider_state(provider_scope, owner_user_id, provider_id)
def touch_runpod_provider_state(self, provider_scope: str, owner_user_id: Optional[int], provider_id: str,
last_used_at: Optional[Any] = None) -> None:
"""Update last-used timestamp for a RunPod state row."""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
owner_clause = 'owner_user_id IS NULL' if owner_user_id is None else f'owner_user_id = {placeholder}'
params = [last_used_at if last_used_at is not None else time.time(), provider_scope]
if owner_user_id is not None:
params.append(owner_user_id)
params.append(provider_id)
cursor.execute(f'''
UPDATE runpod_provider_state
SET last_used_at = {placeholder}, updated_at = CURRENT_TIMESTAMP
WHERE provider_scope = {placeholder} AND {owner_clause} AND provider_id = {placeholder}
''', tuple(params))
conn.commit()
def list_runpod_provider_states(self, provider_scope: Optional[str] = None) -> List[Dict[str, Any]]:
"""List stored RunPod runtime state rows."""
with self._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db_type == 'sqlite' else '%s'
if provider_scope is None:
cursor.execute('''
SELECT provider_scope, owner_user_id, provider_id, mode, wrapper_mode, resource_id,
resource_kind, status, endpoint_url, public_catalog_json, metadata,
last_used_at, last_status_sync_at, updated_at
FROM runpod_provider_state
ORDER BY updated_at DESC
''')
else:
cursor.execute(f'''
SELECT provider_scope, owner_user_id, provider_id, mode, wrapper_mode, resource_id,
resource_kind, status, endpoint_url, public_catalog_json, metadata,
last_used_at, last_status_sync_at, updated_at
FROM runpod_provider_state
WHERE provider_scope = {placeholder}
ORDER BY updated_at DESC
''', (provider_scope,))
return [state for state in (self._row_to_runpod_provider_state(row) for row in cursor.fetchall()) if state]
def _row_to_runpod_provider_state(self, row) -> Optional[Dict[str, Any]]:
if not row:
return None
public_catalog_payload = row[9]
metadata_payload = row[10]
try:
public_catalog = json.loads(public_catalog_payload) if public_catalog_payload else []
except Exception:
public_catalog = []
try:
metadata = json.loads(metadata_payload) if metadata_payload else {}
except Exception:
metadata = {}
return {
'provider_scope': row[0],
'owner_user_id': row[1],
'provider_id': row[2],
'mode': row[3],
'wrapper_mode': row[4],
'resource_id': row[5],
'resource_kind': row[6],
'status': row[7],
'endpoint_url': row[8],
'public_catalog_json': public_catalog,
'metadata': metadata,
'last_used_at': row[11],
'last_status_sync_at': row[12],
'updated_at': row[13],
}
# User-specific rotation methods # User-specific rotation methods
def save_user_rotation(self, user_id: int, rotation_name: str, config: Dict): def save_user_rotation(self, user_id: int, rotation_name: str, config: Dict):
""" """
...@@ -5602,6 +5785,62 @@ def DatabaseManager__run_config_migrations(self, cursor, auto_increment, timesta ...@@ -5602,6 +5785,62 @@ def DatabaseManager__run_config_migrations(self, cursor, auto_increment, timesta
except Exception as e: except Exception as e:
logger.warning(f"Migration check for provider_disabled_state table: {e}") logger.warning(f"Migration check for provider_disabled_state table: {e}")
# Migration: Create runpod_provider_state table if missing
try:
if self.db_type == 'sqlite':
cursor.execute("PRAGMA table_info(runpod_provider_state)")
if not cursor.fetchall():
cursor.execute(f'''
CREATE TABLE runpod_provider_state (
id INTEGER PRIMARY KEY {auto_increment},
provider_scope VARCHAR(16) NOT NULL,
owner_user_id INTEGER,
provider_id VARCHAR(255) NOT NULL,
mode VARCHAR(64) NOT NULL,
wrapper_mode VARCHAR(64),
resource_id VARCHAR(255),
resource_kind VARCHAR(64) NOT NULL,
status VARCHAR(64) NOT NULL,
endpoint_url TEXT,
public_catalog_json TEXT,
metadata TEXT,
last_used_at REAL,
last_status_sync_at REAL,
updated_at TIMESTAMP DEFAULT {timestamp_default},
UNIQUE(owner_user_id, provider_id)
)
''')
logger.info("✅ Migration: Created runpod_provider_state table")
else:
cursor.execute("""
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'runpod_provider_state'
""")
if not cursor.fetchone():
cursor.execute(f'''
CREATE TABLE runpod_provider_state (
id INTEGER PRIMARY KEY {auto_increment},
provider_scope VARCHAR(16) NOT NULL,
owner_user_id INTEGER NULL,
provider_id VARCHAR(255) NOT NULL,
mode VARCHAR(64) NOT NULL,
wrapper_mode VARCHAR(64) NULL,
resource_id VARCHAR(255) NULL,
resource_kind VARCHAR(64) NOT NULL,
status VARCHAR(64) NOT NULL,
endpoint_url TEXT NULL,
public_catalog_json LONGTEXT NULL,
metadata LONGTEXT NULL,
last_used_at DOUBLE NULL,
last_status_sync_at DOUBLE NULL,
updated_at TIMESTAMP DEFAULT {timestamp_default},
UNIQUE KEY uniq_runpod_provider_state (owner_user_id, provider_id)
)
''')
logger.info("✅ Migration: Created runpod_provider_state table")
except Exception as e:
logger.warning(f"Migration check for runpod_provider_state table: {e}")
# Migration: Create user_sort_order table if missing # Migration: Create user_sort_order table if missing
try: try:
if self.db_type == 'sqlite': if self.db_type == 'sqlite':
......
...@@ -43,6 +43,7 @@ from .ollama import OllamaProviderHandler ...@@ -43,6 +43,7 @@ from .ollama import OllamaProviderHandler
from .codex import CodexProviderHandler from .codex import CodexProviderHandler
from .coderai import CoderAIProviderHandler from .coderai import CoderAIProviderHandler
from .qwen import QwenProviderHandler from .qwen import QwenProviderHandler
from .runpod import RunpodProviderHandler
from ..config import config from ..config import config
...@@ -57,7 +58,8 @@ PROVIDER_HANDLERS = { ...@@ -57,7 +58,8 @@ PROVIDER_HANDLERS = {
'kilocode': KiloProviderHandler, # Kilocode provider with OAuth2 support 'kilocode': KiloProviderHandler, # Kilocode provider with OAuth2 support
'codex': CodexProviderHandler, # Codex provider with OAuth2 support (OpenAI protocol) 'codex': CodexProviderHandler, # Codex provider with OAuth2 support (OpenAI protocol)
'coderai': CoderAIProviderHandler, # CoderAI provider with HTTP/WebSocket bridge support 'coderai': CoderAIProviderHandler, # CoderAI provider with HTTP/WebSocket bridge support
'qwen': QwenProviderHandler # Qwen provider with OAuth2 support (OpenAI-compatible) 'qwen': QwenProviderHandler, # Qwen provider with OAuth2 support (OpenAI-compatible)
'runpod': RunpodProviderHandler,
} }
......
...@@ -29,15 +29,17 @@ from .base import BaseProviderHandler, AISBF_DEBUG ...@@ -29,15 +29,17 @@ from .base import BaseProviderHandler, AISBF_DEBUG
class OllamaProviderHandler(BaseProviderHandler): class OllamaProviderHandler(BaseProviderHandler):
def __init__(self, provider_id: str, api_key: Optional[str] = None): def __init__(self, provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None, provider_config=None):
super().__init__(provider_id, api_key) self.provider_config = provider_config if provider_config is not None else config.providers[provider_id]
super().__init__(provider_id, api_key, user_id=user_id)
timeout = httpx.Timeout( timeout = httpx.Timeout(
connect=60.0, connect=60.0,
read=300.0, read=300.0,
write=60.0, write=60.0,
pool=60.0 pool=60.0
) )
self.client = httpx.AsyncClient(base_url=config.providers[provider_id].endpoint, timeout=timeout) endpoint = self.provider_config.get("endpoint") if isinstance(self.provider_config, dict) else self.provider_config.endpoint
self.client = httpx.AsyncClient(base_url=endpoint, timeout=timeout)
def validate_credentials(self) -> bool: def validate_credentials(self) -> bool:
""" """
......
...@@ -30,9 +30,11 @@ from .base import BaseProviderHandler, AISBF_DEBUG ...@@ -30,9 +30,11 @@ from .base import BaseProviderHandler, AISBF_DEBUG
class OpenAIProviderHandler(BaseProviderHandler): class OpenAIProviderHandler(BaseProviderHandler):
def __init__(self, provider_id: str, api_key: str): def __init__(self, provider_id: str, api_key: str, user_id: Optional[int] = None, provider_config=None):
super().__init__(provider_id, api_key) self.provider_config = provider_config if provider_config is not None else config.providers[provider_id]
self.client = OpenAI(base_url=config.providers[provider_id].endpoint, api_key=api_key) super().__init__(provider_id, api_key, user_id=user_id)
endpoint = self.provider_config.get("endpoint") if isinstance(self.provider_config, dict) else self.provider_config.endpoint
self.client = OpenAI(base_url=endpoint, api_key=api_key)
def validate_credentials(self) -> bool: def validate_credentials(self) -> bool:
"""Validate OpenAI API key presence.""" """Validate OpenAI API key presence."""
......
"""
Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
AISBF - AI Service Broker Framework || AI Should Be Free
RunPod provider handler.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urljoin
import httpx
from ..config import config
from ..database import DatabaseRegistry
from ..models import Model
from .base import BaseProviderHandler
from .coderai import CoderAIProviderHandler
from .ollama import OllamaProviderHandler
from .openai import OpenAIProviderHandler
logger = logging.getLogger(__name__)
RUNPOD_MANAGEMENT_BASE = "https://rest.runpod.io/v1"
RUNPOD_PUBLIC_BASE = "https://api.runpod.ai/v2"
RUNPOD_PUBLIC_PROTOCOLS = {"runpod_public", "openai", "ollama", "coderai"}
RUNPOD_WRAPPER_MODES = {"openai", "ollama", "coderai"}
RUNPOD_RUNTIME_RUNNING = {"running", "ready", "completed", "healthy", "active"}
RUNPOD_RUNTIME_STOPPED = {"stopped", "exited", "terminated", "idle"}
class RunpodProviderHandler(BaseProviderHandler):
def __init__(self, provider_id: str, api_key: Optional[str] = None, user_id: Optional[int] = None, provider_config: Optional[Any] = None):
self.provider_config = provider_config if provider_config is not None else config.get_provider(provider_id)
super().__init__(provider_id, api_key, user_id=user_id)
self.user_provider_config = provider_config if isinstance(provider_config, dict) else self.user_provider_config
self._runpod_config = self._resolve_runpod_config()
self._mode = str(self._runpod_config.get("mode") or "pod").strip().lower()
self._wrapper_mode = str(self._runpod_config.get("wrapper_mode") or "openai").strip().lower()
self._management_api = str(self._runpod_config.get("management_api") or "auto").strip().lower()
self._startup_poll_interval_ms = int(self._runpod_config.get("startup_poll_interval_ms") or 3000)
self._startup_timeout_ms = int(self._runpod_config.get("startup_timeout_ms") or 300000)
self._idle_shutdown_ms = int(self._runpod_config.get("idle_shutdown_ms") or 900000)
self._public_endpoint_protocol_default = str(self._runpod_config.get("public_endpoint_protocol_default") or "auto").strip().lower()
self._management_base = (self._get_provider_value("endpoint") or RUNPOD_MANAGEMENT_BASE).rstrip("/")
def _get_provider_value(self, key: str, default: Any = None) -> Any:
if isinstance(self.provider_config, dict):
return self.provider_config.get(key, default)
return getattr(self.provider_config, key, default)
def _resolve_runpod_config(self) -> Dict[str, Any]:
if isinstance(self.provider_config, dict):
raw = self.provider_config.get("runpod_config") or {}
else:
raw = getattr(self.provider_config, "runpod_config", None) or {}
return raw if isinstance(raw, dict) else {}
def validate_credentials(self) -> bool:
key = self.api_key or self._get_provider_value("api_key")
if not isinstance(key, str) or not key.strip() or key.strip().startswith("YOUR_"):
logger.error("[%s] RunPod API key required but not configured", self.provider_id)
return False
if self._mode not in {"pod", "serverless_template", "public"}:
logger.error("[%s] Unsupported RunPod mode: %s", self.provider_id, self._mode)
return False
if self._mode != "public" and self._wrapper_mode not in RUNPOD_WRAPPER_MODES:
logger.error("[%s] Unsupported RunPod wrapper mode: %s", self.provider_id, self._wrapper_mode)
return False
return True
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None, temperature: Optional[float] = 1.0, stream: Optional[bool] = False, tools: Optional[List[Dict]] = None, tool_choice: Optional[Any] = None):
await self.apply_rate_limit()
if self._mode == "public":
response = await self._handle_public_request(model, messages, max_tokens=max_tokens, temperature=temperature)
self.record_success()
return response
await self._ensure_non_public_resource_ready()
delegate = self._build_delegate_handler(self._wrapper_mode)
response = await delegate.handle_request(model, messages, max_tokens=max_tokens, temperature=temperature, stream=stream, tools=tools, tool_choice=tool_choice)
self.record_success()
return response
async def get_models(self) -> List[Model]:
await self.apply_rate_limit()
if self._mode == "public":
catalog = await self._get_public_catalog()
return [self._catalog_entry_to_model(entry) for entry in catalog]
await self._ensure_non_public_resource_ready()
delegate = self._build_delegate_handler(self._wrapper_mode)
return await delegate.get_models()
async def refresh_public_catalog(self) -> List[Dict[str, Any]]:
state = self._db().get_runpod_provider_state(self._provider_scope(), self.user_id, self.provider_id)
metadata = dict((state or {}).get("metadata") or {})
try:
catalog = self._apply_public_model_overrides(
self._normalize_public_catalog(await self._fetch_live_public_catalog_entries())
)
except Exception as exc:
metadata["catalog_refresh_error"] = str(exc)
if state:
self._db().save_runpod_provider_state(
provider_scope=self._provider_scope(),
owner_user_id=self.user_id,
provider_id=self.provider_id,
mode=state.get("mode") or "public",
wrapper_mode=state.get("wrapper_mode"),
resource_id=state.get("resource_id") or "public-catalog",
resource_kind=state.get("resource_kind") or "public",
status=state.get("status") or "ready",
endpoint_url=state.get("endpoint_url"),
public_catalog_json=state.get("public_catalog_json") or [],
metadata=metadata,
last_used_at=state.get("last_used_at"),
last_status_sync_at=state.get("last_status_sync_at"),
)
raise
metadata["catalog_refreshed_at"] = int(time.time())
metadata["catalog_item_count"] = len(catalog)
metadata["catalog_source"] = "live"
metadata.pop("catalog_refresh_error", None)
self._db().save_runpod_provider_state(
provider_scope=self._provider_scope(),
owner_user_id=self.user_id,
provider_id=self.provider_id,
mode="public",
wrapper_mode=None,
resource_id="public-catalog",
resource_kind="public",
status="ready",
endpoint_url=RUNPOD_PUBLIC_BASE,
public_catalog_json=catalog,
metadata=metadata,
)
return catalog
async def _fetch_live_public_catalog_entries(self) -> List[Dict[str, Any]]:
seed_entries = self._runpod_config.get("public_catalog_seed") or []
return [entry for entry in seed_entries if isinstance(entry, dict)]
def _apply_public_model_overrides(self, catalog: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
overrides = self._runpod_config.get("public_models") or {}
if not isinstance(overrides, dict):
return catalog
normalized = []
for entry in catalog:
patched = dict(entry)
override = overrides.get(entry.get("id")) or {}
if not isinstance(override, dict):
normalized.append(patched)
continue
protocol = override.get("protocol")
if protocol:
patched["protocol"] = str(protocol).strip().lower()
if override.get("capabilities"):
patched["capabilities"] = list(override.get("capabilities") or [])
normalized.append(patched)
return normalized
async def poll_idle_shutdown(self) -> bool:
if self._mode != "pod" or self._idle_shutdown_ms <= 0:
return False
state = self._db().get_runpod_provider_state(self._provider_scope(), self.user_id, self.provider_id)
if not state or state.get("last_used_at") is None:
return False
now = time.time()
last_used_at = self._to_timestamp(state.get("last_used_at"))
if last_used_at is None or (now - last_used_at) * 1000 < self._idle_shutdown_ms:
return False
if not self._is_running_status(state.get("status")):
return False
pod_id = self._runpod_config.get("pod_id") or state.get("resource_id")
if not pod_id:
return False
await self._management_request("POST", f"/pods/{pod_id}/stop")
self._db().save_runpod_provider_state(
provider_scope=self._provider_scope(),
owner_user_id=self.user_id,
provider_id=self.provider_id,
mode=self._mode,
wrapper_mode=self._wrapper_mode,
resource_id=pod_id,
resource_kind="pod",
status="stopped",
endpoint_url=state.get("endpoint_url"),
public_catalog_json=state.get("public_catalog_json"),
metadata=state.get("metadata") or {},
last_used_at=state.get("last_used_at"),
)
return True
def current_runtime_state(self) -> Optional[Dict[str, Any]]:
return self._db().get_runpod_provider_state(self._provider_scope(), self.user_id, self.provider_id)
def build_runtime_status(self) -> Dict[str, Any]:
state = self.current_runtime_state() or {}
metadata = dict(state.get("metadata") or {})
catalog = self._normalize_public_catalog(state.get("public_catalog_json"))
return {
"provider_id": self.provider_id,
"provider_scope": self._provider_scope(),
"mode": state.get("mode") or self._mode,
"wrapper_mode": state.get("wrapper_mode") or (None if (state.get("mode") or self._mode) == "public" else self._wrapper_mode),
"resource_id": state.get("resource_id"),
"resource_kind": state.get("resource_kind"),
"status": state.get("status") or "unknown",
"endpoint_url": state.get("endpoint_url"),
"last_used_at": state.get("last_used_at"),
"last_status_sync_at": state.get("last_status_sync_at"),
"updated_at": state.get("updated_at"),
"catalog": {
"item_count": metadata.get("catalog_item_count", len(catalog)),
"refreshed_at": metadata.get("catalog_refreshed_at"),
"source": metadata.get("catalog_source"),
"refresh_error": metadata.get("catalog_refresh_error"),
"models": catalog,
},
"metadata": metadata,
}
async def _handle_public_request(self, model: str, messages: List[Dict], max_tokens: Optional[int], temperature: Optional[float]) -> Dict[str, Any]:
entry = await self._resolve_public_model(model)
protocol = self._resolve_public_protocol(entry)
if protocol == "openai":
delegate = self._build_delegate_handler("openai", endpoint_override=entry.get("route_base"), api_key_override=self.api_key)
return await delegate.handle_request(model, messages, max_tokens=max_tokens, temperature=temperature, stream=False)
if protocol == "ollama":
delegate = self._build_delegate_handler("ollama", endpoint_override=entry.get("route_base"), api_key_override=self.api_key)
return await delegate.handle_request(model, messages, max_tokens=max_tokens, temperature=temperature, stream=False)
payload = {
"input": {
"messages": messages,
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
}
}
route_base = (entry.get("route_base") or f"{RUNPOD_PUBLIC_BASE}/{entry.get('id')}").rstrip("/")
request_mode = str(entry.get("request_mode") or "runsync").strip().lower()
if request_mode not in {"run", "runsync"}:
request_mode = "runsync"
response = await self._public_request("POST", f"{route_base}/{request_mode}", json_body=payload)
return self._wrap_public_response(model, entry, response)
async def _ensure_non_public_resource_ready(self) -> Dict[str, Any]:
if self._mode == "serverless_template":
return await self._ensure_serverless_endpoint_ready()
return await self._ensure_pod_ready()
async def _ensure_pod_ready(self) -> Dict[str, Any]:
pod_id = self._runpod_config.get("pod_id")
if not pod_id:
raise ValueError(f"RunPod provider '{self.provider_id}' requires runpod_config.pod_id")
state = self._db().get_runpod_provider_state(self._provider_scope(), self.user_id, self.provider_id)
pod = await self._management_request("GET", f"/pods/{pod_id}", params={"includeTemplate": True})
status = self._extract_pod_status(pod)
if self._is_stopped_status(status):
await self._management_request("POST", f"/pods/{pod_id}/start")
pod = await self._wait_for_pod_ready(pod_id)
status = self._extract_pod_status(pod)
elif not self._is_running_status(status):
pod = await self._wait_for_pod_ready(pod_id)
status = self._extract_pod_status(pod)
endpoint_url = self._derive_pod_endpoint_url(pod)
metadata = dict((state or {}).get("metadata") or {})
metadata["pod"] = pod
now = time.time()
self._db().save_runpod_provider_state(
provider_scope=self._provider_scope(),
owner_user_id=self.user_id,
provider_id=self.provider_id,
mode="pod",
wrapper_mode=self._wrapper_mode,
resource_id=pod_id,
resource_kind="pod",
status=status,
endpoint_url=endpoint_url,
public_catalog_json=(state or {}).get("public_catalog_json"),
metadata=metadata,
last_used_at=now,
)
return {"status": status, "endpoint_url": endpoint_url, "pod": pod}
async def _wait_for_pod_ready(self, pod_id: str) -> Dict[str, Any]:
deadline = time.time() + (self._startup_timeout_ms / 1000.0)
interval = max(self._startup_poll_interval_ms / 1000.0, 0.1)
last_pod = None
while time.time() < deadline:
last_pod = await self._management_request("GET", f"/pods/{pod_id}", params={"includeTemplate": True})
if self._pod_is_ready(last_pod):
return last_pod
await asyncio.sleep(interval)
raise TimeoutError(f"RunPod pod '{pod_id}' did not become ready within {self._startup_timeout_ms}ms")
async def _ensure_serverless_endpoint_ready(self) -> Dict[str, Any]:
endpoint_id = self._runpod_config.get("endpoint_id")
if not endpoint_id:
endpoint_id = await self._resolve_endpoint_from_serverless_template()
endpoint = await self._management_request("GET", f"/endpoints/{endpoint_id}", params={"includeTemplate": True, "includeWorkers": True})
endpoint_url = self._derive_serverless_endpoint_url(endpoint)
state = self._db().get_runpod_provider_state(self._provider_scope(), self.user_id, self.provider_id)
metadata = dict((state or {}).get("metadata") or {})
metadata["endpoint"] = endpoint
self._db().save_runpod_provider_state(
provider_scope=self._provider_scope(),
owner_user_id=self.user_id,
provider_id=self.provider_id,
mode="serverless_template",
wrapper_mode=self._wrapper_mode,
resource_id=endpoint_id,
resource_kind="endpoint",
status="ready",
endpoint_url=endpoint_url,
public_catalog_json=(state or {}).get("public_catalog_json"),
metadata=metadata,
last_used_at=time.time(),
)
return {"status": "ready", "endpoint_url": endpoint_url, "endpoint": endpoint}
async def _resolve_endpoint_from_serverless_template(self) -> str:
configured_endpoint = self._runpod_config.get("endpoint_id")
if configured_endpoint:
return configured_endpoint
template_id = self._runpod_config.get("serverless_template_id") or self._runpod_config.get("template_id")
if not template_id:
raise ValueError(f"RunPod provider '{self.provider_id}' requires endpoint_id or serverless_template_id in serverless_template mode")
endpoints = await self._management_request("GET", "/endpoints", params={"includeTemplate": True})
items = self._extract_items(endpoints)
for endpoint in items:
if str(endpoint.get("templateId") or "") == str(template_id):
return endpoint.get("id")
payload = {
"name": f"AISBF {self.provider_id}",
"templateId": template_id,
}
created = await self._management_request("POST", "/endpoints", json_body=payload)
endpoint_id = created.get("id")
if not endpoint_id:
raise ValueError(f"RunPod endpoint creation for provider '{self.provider_id}' did not return an endpoint id")
return endpoint_id
async def _get_public_catalog(self) -> List[Dict[str, Any]]:
state = self._db().get_runpod_provider_state(self._provider_scope(), self.user_id, self.provider_id)
cached = self._normalize_public_catalog((state or {}).get("public_catalog_json"))
if cached:
return cached
return await self.refresh_public_catalog()
async def _resolve_public_model(self, model: str) -> Dict[str, Any]:
requested = model.split("/", 1)[-1]
catalog = await self._get_public_catalog()
for entry in catalog:
if entry.get("id") == requested or entry.get("name") == requested:
self._touch_runtime_state()
return entry
raise ValueError(f"RunPod public model '{requested}' not found for provider '{self.provider_id}'")
def _resolve_public_protocol(self, entry: Dict[str, Any]) -> str:
manual = ((self._runpod_config.get("public_models") or {}).get(entry.get("id") or "") or {}).get("protocol")
candidate = str(manual or entry.get("protocol") or self._public_endpoint_protocol_default or "runpod_public").strip().lower()
if candidate == "auto":
candidate = self._infer_public_protocol(entry)
if candidate not in RUNPOD_PUBLIC_PROTOCOLS:
candidate = "runpod_public"
return candidate
def _infer_public_protocol(self, entry: Dict[str, Any]) -> str:
route_base = str(entry.get("route_base") or "").lower()
schema = entry.get("schema") or {}
if any(token in route_base for token in ("/v1", "openai")):
return "openai"
if str(schema).lower().find("messages") >= 0:
return "openai"
return "runpod_public"
def _catalog_entry_to_model(self, entry: Dict[str, Any]) -> Model:
identifier = str(entry.get("id") or entry.get("name") or "unknown")
return Model(
id=identifier,
name=entry.get("name") or identifier,
provider_id=self.provider_id,
description=entry.get("description"),
architecture=entry.get("architecture"),
pricing=entry.get("pricing"),
context_length=entry.get("context_length"),
context_size=entry.get("context_length"),
supported_parameters=entry.get("supported_parameters"),
)
def _wrap_public_response(self, model: str, entry: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
output = response.get("output")
if isinstance(output, dict) and "choices" in output:
return output
text = None
if isinstance(output, dict):
text = output.get("text") or output.get("response") or output.get("content")
elif isinstance(output, str):
text = output
payload = {
"id": response.get("id") or f"runpod-{entry.get('id')}-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": f"{self.provider_id}/{model.split('/', 1)[-1]}",
"choices": [{"index": 0, "message": {"role": "assistant", "content": text or json.dumps(output or response)}, "finish_reason": "stop"}],
"usage": response.get("usage") or {},
"runpod": response,
}
return payload
def _build_delegate_handler(self, wrapper_mode: str, endpoint_override: Optional[str] = None, api_key_override: Optional[str] = None):
provider_dict = self._provider_dict_with_endpoint(endpoint_override)
api_key = api_key_override or self.api_key or self._get_provider_value("api_key")
if wrapper_mode == "openai":
return OpenAIProviderHandler(self.provider_id, api_key, provider_config=provider_dict)
if wrapper_mode == "ollama":
return OllamaProviderHandler(self.provider_id, api_key, provider_config=provider_dict)
if wrapper_mode == "coderai":
return CoderAIProviderHandler(self.provider_id, api_key, user_id=self.user_id, provider_config=provider_dict)
raise ValueError(f"Unsupported RunPod wrapper mode '{wrapper_mode}'")
def _provider_dict_with_endpoint(self, endpoint_override: Optional[str]) -> Dict[str, Any]:
if isinstance(self.provider_config, dict):
data = dict(self.provider_config)
else:
data = self.provider_config.model_dump() if hasattr(self.provider_config, "model_dump") else dict(self.provider_config.dict())
runtime_state = self.current_runtime_state() or {}
data["endpoint"] = endpoint_override or runtime_state.get("endpoint_url") or data.get("endpoint")
return data
def _provider_scope(self) -> str:
return "user" if self.user_id is not None else "global"
def _db(self):
return DatabaseRegistry.get_config_database()
def _touch_runtime_state(self) -> None:
self._db().touch_runpod_provider_state(self._provider_scope(), self.user_id, self.provider_id)
async def _management_request(self, method: str, path: str, params: Optional[Dict[str, Any]] = None, json_body: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
headers = {"Authorization": f"Bearer {self.api_key or self._get_provider_value('api_key')}", "Content-Type": "application/json"}
url = f"{self._management_base}{path}"
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.request(method.upper(), url, headers=headers, params=params, json=json_body)
response.raise_for_status()
if not response.content:
return {}
return response.json()
async def _public_request(self, method: str, url: str, json_body: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
headers = {"Authorization": f"Bearer {self.api_key or self._get_provider_value('api_key')}", "Content-Type": "application/json"}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.request(method.upper(), url, headers=headers, json=json_body)
response.raise_for_status()
return response.json()
def _extract_pod_status(self, pod: Dict[str, Any]) -> str:
return str(pod.get("desiredStatus") or pod.get("status") or "unknown").strip().lower()
def _extract_items(self, payload: Any) -> List[Dict[str, Any]]:
if isinstance(payload, list):
return [item for item in payload if isinstance(item, dict)]
if isinstance(payload, dict):
for key in ("data", "items", "pods", "endpoints", "templates"):
value = payload.get(key)
if isinstance(value, list):
return [item for item in value if isinstance(item, dict)]
return []
def _derive_pod_endpoint_url(self, pod: Dict[str, Any]) -> str:
port_mappings = pod.get("portMappings") or []
public_ip = pod.get("publicIp")
if public_ip and isinstance(port_mappings, list) and port_mappings:
first_mapping = port_mappings[0]
if isinstance(first_mapping, dict):
public_port = first_mapping.get("publicPort") or first_mapping.get("port") or first_mapping.get("containerPort")
else:
public_port = None
if public_port:
scheme = "http"
if self._wrapper_mode == "openai":
return f"{scheme}://{public_ip}:{public_port}/v1"
if self._wrapper_mode == "ollama":
return f"{scheme}://{public_ip}:{public_port}"
return f"{scheme}://{public_ip}:{public_port}"
return ""
def _derive_serverless_endpoint_url(self, endpoint: Dict[str, Any]) -> str:
endpoint_id = endpoint.get("id") or self._runpod_config.get("endpoint_id")
if not endpoint_id:
raise ValueError(f"RunPod provider '{self.provider_id}' serverless endpoint is missing an id")
if self._wrapper_mode == "openai":
return f"{RUNPOD_PUBLIC_BASE}/{endpoint_id}/openai/v1"
if self._wrapper_mode == "ollama":
return f"{RUNPOD_PUBLIC_BASE}/{endpoint_id}"
return f"{RUNPOD_PUBLIC_BASE}/{endpoint_id}"
def _pod_is_ready(self, pod: Dict[str, Any]) -> bool:
status = self._extract_pod_status(pod)
if status not in RUNPOD_RUNTIME_RUNNING:
return False
return bool(self._derive_pod_endpoint_url(pod))
def _is_running_status(self, status: Optional[str]) -> bool:
return str(status or "").strip().lower() in RUNPOD_RUNTIME_RUNNING
def _is_stopped_status(self, status: Optional[str]) -> bool:
return str(status or "").strip().lower() in RUNPOD_RUNTIME_STOPPED
def _normalize_public_catalog(self, entries: Any) -> List[Dict[str, Any]]:
normalized = []
for entry in entries or []:
if not isinstance(entry, dict):
continue
identifier = str(entry.get("id") or entry.get("name") or "").strip()
if not identifier:
continue
route_base = str(entry.get("route_base") or f"{RUNPOD_PUBLIC_BASE}/{identifier}").rstrip("/")
protocol = str(entry.get("protocol") or "auto").lower()
if protocol == "auto":
protocol = self._infer_public_protocol({**entry, "id": identifier, "route_base": route_base})
normalized.append({
"id": identifier,
"name": entry.get("name") or identifier,
"protocol": protocol,
"capabilities": list(entry.get("capabilities") or []),
"route_base": route_base,
"request_mode": str(entry.get("request_mode") or "runsync").lower(),
"description": entry.get("description"),
"pricing": entry.get("pricing"),
"architecture": entry.get("architecture"),
"context_length": entry.get("context_length"),
"supported_parameters": entry.get("supported_parameters"),
"schema": entry.get("schema"),
})
return normalized
@staticmethod
def _to_timestamp(value: Any) -> Optional[float]:
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
try:
return float(value)
except ValueError:
return None
return None
async def runpod_idle_shutdown_loop(poll_interval_seconds: float = 30.0) -> None:
while True:
try:
db = DatabaseRegistry.get_config_database()
for state in db.list_runpod_provider_states():
if (state or {}).get("mode") != "pod":
continue
provider_id = state.get("provider_id")
owner_user_id = state.get("owner_user_id")
api_key = None
provider_config = None
if owner_user_id is None:
provider_config = config.providers.get(provider_id)
api_key = getattr(provider_config, "api_key", None) if provider_config else None
else:
provider = db.get_user_provider(owner_user_id, provider_id)
if provider:
provider_config = provider.get("config")
api_key = (provider_config or {}).get("api_key")
if not provider_config:
continue
handler = RunpodProviderHandler(provider_id, api_key, user_id=owner_user_id, provider_config=provider_config)
await handler.poll_idle_shutdown()
except Exception:
logger.exception("RunPod idle shutdown loop iteration failed")
await asyncio.sleep(max(poll_interval_seconds, 1.0))
...@@ -16,6 +16,7 @@ from aisbf.app.startup import _reload_global_config, _apply_condense_defaults_pr ...@@ -16,6 +16,7 @@ from aisbf.app.startup import _reload_global_config, _apply_condense_defaults_pr
from aisbf.app.middleware import _is_local_client from aisbf.app.middleware import _is_local_client
from aisbf.app.model_cache import fetch_provider_models from aisbf.app.model_cache import fetch_provider_models
from aisbf.routes.auth import require_dashboard_auth, require_api_auth, require_api_admin, require_admin from aisbf.routes.auth import require_dashboard_auth, require_api_auth, require_api_admin, require_admin
from aisbf.providers.runpod import RunpodProviderHandler
import httpx import httpx
router = APIRouter() router = APIRouter()
...@@ -116,6 +117,55 @@ def _ensure_coderai_token(provider_config: dict) -> dict: ...@@ -116,6 +117,55 @@ def _ensure_coderai_token(provider_config: dict) -> dict:
return stamped return stamped
def _normalize_runpod_provider_config(provider_id: str, provider_config: dict) -> dict:
stamped = dict(provider_config or {})
if stamped.get('type') != 'runpod':
return stamped
runpod_config = stamped.get('runpod_config')
if not isinstance(runpod_config, dict):
runpod_config = {}
mode = str(runpod_config.get('mode') or 'pod').strip().lower()
wrapper_mode = str(runpod_config.get('wrapper_mode') or 'openai').strip().lower()
runpod_config['mode'] = mode
runpod_config['management_api'] = str(runpod_config.get('management_api') or 'auto').strip().lower() or 'auto'
runpod_config['account_name'] = str(runpod_config.get('account_name') or provider_id).strip() or provider_id
runpod_config['startup_poll_interval_ms'] = int(runpod_config.get('startup_poll_interval_ms') or 3000)
runpod_config['startup_timeout_ms'] = int(runpod_config.get('startup_timeout_ms') or 300000)
runpod_config['idle_shutdown_ms'] = int(runpod_config.get('idle_shutdown_ms') or 900000)
runpod_config['public_endpoint_protocol_default'] = str(runpod_config.get('public_endpoint_protocol_default') or 'auto').strip().lower() or 'auto'
if mode == 'public':
public_models = runpod_config.get('public_models')
if not isinstance(public_models, dict):
runpod_config['public_models'] = {}
else:
runpod_config['wrapper_mode'] = wrapper_mode
stamped['runpod_config'] = runpod_config
if not stamped.get('endpoint'):
stamped['endpoint'] = 'https://rest.runpod.io/v1'
return stamped
def _validate_runpod_provider_config(provider_id: str, provider_config: dict) -> None:
if not isinstance(provider_config, dict) or provider_config.get('type') != 'runpod':
return
runpod_config = provider_config.get('runpod_config') or {}
mode = str(runpod_config.get('mode') or 'pod').strip().lower()
if mode not in {'pod', 'serverless_template', 'public'}:
raise ValueError(f"RunPod provider '{provider_id}' has unsupported mode '{mode}'")
if mode != 'public':
wrapper_mode = str(runpod_config.get('wrapper_mode') or 'openai').strip().lower()
if wrapper_mode not in {'openai', 'ollama', 'coderai'}:
raise ValueError(f"RunPod provider '{provider_id}' has unsupported wrapper_mode '{wrapper_mode}'")
if mode == 'pod' and not str(runpod_config.get('pod_id') or '').strip():
raise ValueError(f"RunPod provider '{provider_id}' requires runpod_config.pod_id in pod mode")
if mode == 'serverless_template' and not (str(runpod_config.get('endpoint_id') or '').strip() or str(runpod_config.get('serverless_template_id') or '').strip() or str(runpod_config.get('template_id') or '').strip()):
raise ValueError(f"RunPod provider '{provider_id}' requires endpoint_id or template_id in serverless_template mode")
def _validate_coderai_provider_config(provider_id: str, provider_config: dict) -> None: def _validate_coderai_provider_config(provider_id: str, provider_config: dict) -> None:
if not isinstance(provider_config, dict) or provider_config.get('type') != 'coderai': if not isinstance(provider_config, dict) or provider_config.get('type') != 'coderai':
return return
...@@ -189,6 +239,34 @@ def _apply_usage_disable(db, user_id, provider_id: str, usage_data: dict): ...@@ -189,6 +239,34 @@ def _apply_usage_disable(db, user_id, provider_id: str, usage_data: dict):
pass pass
def _resolve_dashboard_provider_config(request: Request, provider_id: str) -> tuple[dict, Optional[int]]:
current_user_id = request.session.get('user_id')
db = DatabaseRegistry.get_config_database()
if current_user_id is None:
provider = _config.providers.get(provider_id) if _config else None
if provider is None:
raise HTTPException(status_code=404, detail="Provider not found")
if hasattr(provider, "model_dump"):
return provider.model_dump(), None
if hasattr(provider, "dict"):
return provider.dict(), None
return dict(provider), None
provider_row = db.get_user_provider(current_user_id, provider_id)
if not provider_row:
raise HTTPException(status_code=404, detail="Provider not found")
return dict(provider_row.get("config") or {}), current_user_id
def _build_dashboard_runpod_handler(request: Request, provider_id: str) -> RunpodProviderHandler:
provider_config, owner_user_id = _resolve_dashboard_provider_config(request, provider_id)
if provider_config.get("type") != "runpod":
raise HTTPException(status_code=404, detail="RunPod provider not found")
api_key = provider_config.get("api_key")
return RunpodProviderHandler(provider_id, api_key=api_key, user_id=owner_user_id, provider_config=provider_config)
@router.get("/dashboard", response_class=HTMLResponse) @router.get("/dashboard", response_class=HTMLResponse)
async def dashboard_index(request: Request): async def dashboard_index(request: Request):
"""Dashboard overview page""" """Dashboard overview page"""
...@@ -628,7 +706,9 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)): ...@@ -628,7 +706,9 @@ async def dashboard_providers_save(request: Request, config: str = Form(...)):
# Apply defaults: if condense_method is set but condense_context is not, default to 80 # Apply defaults: if condense_method is set but condense_context is not, default to 80
for provider_key, provider in providers_data.items(): for provider_key, provider in providers_data.items():
provider = _ensure_coderai_token(provider) provider = _ensure_coderai_token(provider)
provider = _normalize_runpod_provider_config(provider_key, provider)
_validate_coderai_provider_config(provider_key, provider) _validate_coderai_provider_config(provider_key, provider)
_validate_runpod_provider_config(provider_key, provider)
if 'models' in provider and isinstance(provider['models'], list): if 'models' in provider and isinstance(provider['models'], list):
for model in provider['models']: for model in provider['models']:
if 'condense_method' in model and model.get('condense_method'): if 'condense_method' in model and model.get('condense_method'):
...@@ -961,6 +1041,41 @@ async def search_provider_models_api(request: Request, provider_id: str, query: ...@@ -961,6 +1041,41 @@ async def search_provider_models_api(request: Request, provider_id: str, query:
return JSONResponse({"models": models[:200], "fetched_live": fetched_live}) return JSONResponse({"models": models[:200], "fetched_live": fetched_live})
@router.get("/dashboard/providers/{provider_id}/runpod-status")
async def api_runpod_provider_status(provider_id: str, request: Request):
auth_check = require_dashboard_auth(request)
if auth_check:
return JSONResponse({"success": False, "error": "Not authenticated"}, status_code=401)
try:
handler = _build_dashboard_runpod_handler(request, provider_id)
return JSONResponse({"success": True, "status": handler.build_runtime_status()})
except HTTPException as exc:
return JSONResponse({"success": False, "error": exc.detail}, status_code=exc.status_code)
except Exception as exc:
return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
@router.post("/dashboard/providers/{provider_id}/runpod-refresh")
async def api_runpod_provider_refresh(provider_id: str, request: Request):
auth_check = require_dashboard_auth(request)
if auth_check:
return JSONResponse({"success": False, "error": "Not authenticated"}, status_code=401)
try:
handler = _build_dashboard_runpod_handler(request, provider_id)
catalog = await handler.refresh_public_catalog()
return JSONResponse({
"success": True,
"catalog_count": len(catalog),
"status": handler.build_runtime_status(),
})
except HTTPException as exc:
return JSONResponse({"success": False, "error": exc.detail}, status_code=exc.status_code)
except Exception as exc:
return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
@router.get("/dashboard/search-all-models") @router.get("/dashboard/search-all-models")
async def search_all_models_api(request: Request, query: str = "", refresh: bool = False): async def search_all_models_api(request: Request, query: str = "", refresh: bool = False):
"""Return all available models (rotations + provider models) for autoselect, with optional live refresh.""" """Return all available models (rotations + provider models) for autoselect, with optional live refresh."""
......
...@@ -52,6 +52,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>. ...@@ -52,6 +52,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
<option value="qwen">Qwen (OAuth2)</option> <option value="qwen">Qwen (OAuth2)</option>
<option value="codex">Codex (OpenAI OAuth2)</option> <option value="codex">Codex (OpenAI OAuth2)</option>
<option value="coderai">CoderAI</option> <option value="coderai">CoderAI</option>
<option value="runpod">RunPod</option>
</select> </select>
<small style="color: var(--color-muted); display: block; margin-top: 5px;">Select the type of provider to configure appropriate settings</small> <small style="color: var(--color-muted); display: block; margin-top: 5px;">Select the type of provider to configure appropriate settings</small>
</div> </div>
...@@ -190,6 +191,98 @@ function formatPerfNumber(value, digits = 1, suffix = '') { ...@@ -190,6 +191,98 @@ function formatPerfNumber(value, digits = 1, suffix = '') {
return `${num.toFixed(digits)}${suffix}`; return `${num.toFixed(digits)}${suffix}`;
} }
function formatRunpodTimestamp(value) {
if (value == null || value === '') return 'Never';
const date = new Date(Number(value) * 1000);
if (Number.isNaN(date.getTime())) return 'Unknown';
return date.toLocaleString();
}
function runpodStatusBadgeColor(status) {
const normalized = String(status || '').toLowerCase();
if (['ready', 'running', 'healthy', 'active'].includes(normalized)) return '#22c55e';
if (['starting', 'pending', 'provisioning'].includes(normalized)) return '#f59e0b';
if (['stopped', 'exited', 'terminated', 'idle'].includes(normalized)) return '#94a3b8';
return '#f87171';
}
function renderRunpodRuntimeStatus(providerKey, payload, errorMessage) {
const container = document.getElementById(`runpod-runtime-status-${providerKey}`);
if (!container) return;
if (errorMessage) {
container.innerHTML = `<div style="padding:12px; border-radius:6px; background:rgba(248,113,113,0.08); border:1px solid rgba(248,113,113,0.35); color:#fca5a5;">${escHtmlAttr(errorMessage)}</div>`;
return;
}
const status = payload || {};
const catalog = status.catalog || {};
const badgeColor = runpodStatusBadgeColor(status.status);
const models = Array.isArray(catalog.models) ? catalog.models : [];
const modelPreview = models.slice(0, 6).map(model => `<code style="background:var(--bg-page);padding:2px 6px;border-radius:4px;">${escHtmlAttr(model.id || model.name || 'unknown')}</code>`).join(' ');
container.innerHTML = `
<div style="padding:14px; border-radius:6px; background: var(--bg-page); border:1px solid var(--color-border); display:flex; flex-direction:column; gap:10px;">
<div style="display:flex; justify-content:space-between; gap:10px; align-items:center; flex-wrap:wrap;">
<strong style="color:var(--color-text);">RunPod Runtime Status</strong>
<span style="display:inline-flex; align-items:center; gap:8px; background:rgba(15,23,42,0.25); border:1px solid var(--color-border); border-radius:999px; padding:4px 10px; color:var(--color-text);">
<span style="width:10px; height:10px; border-radius:50%; background:${badgeColor}; display:inline-block;"></span>
${escHtmlAttr(status.status || 'unknown')}
</span>
</div>
<div style="display:grid; grid-template-columns:repeat(auto-fit, minmax(220px, 1fr)); gap:10px; font-size:13px; color:var(--color-text); line-height:1.5;">
<div><strong>Mode:</strong> ${escHtmlAttr(status.mode || 'unknown')}</div>
<div><strong>Wrapper:</strong> ${escHtmlAttr(status.wrapper_mode || 'n/a')}</div>
<div><strong>Resource:</strong> ${escHtmlAttr(status.resource_kind || 'n/a')} ${status.resource_id ? `(${escHtmlAttr(status.resource_id)})` : ''}</div>
<div><strong>Endpoint:</strong> ${escHtmlAttr(status.endpoint_url || 'Unknown')}</div>
<div><strong>Catalog Count:</strong> ${escHtmlAttr(String(catalog.item_count ?? models.length ?? 0))}</div>
<div><strong>Catalog Source:</strong> ${escHtmlAttr(catalog.source || 'unknown')}</div>
<div><strong>Catalog Refreshed:</strong> ${escHtmlAttr(formatRunpodTimestamp(catalog.refreshed_at))}</div>
<div><strong>Last Used:</strong> ${escHtmlAttr(formatRunpodTimestamp(status.last_used_at))}</div>
</div>
${catalog.refresh_error ? `<div style="padding:10px; border-radius:6px; background:rgba(245,158,11,0.1); border:1px solid rgba(245,158,11,0.35); color:#fcd34d;"><strong>Last Refresh Error:</strong> ${escHtmlAttr(catalog.refresh_error)}</div>` : ''}
${modelPreview ? `<div style="display:flex; flex-wrap:wrap; gap:6px; align-items:center;"><strong style="color:var(--color-text);">Catalog Models:</strong> ${modelPreview}</div>` : ''}
</div>
`;
}
async function loadRunpodRuntimeStatus(providerKey) {
const container = document.getElementById(`runpod-runtime-status-${providerKey}`);
if (!container) return;
container.innerHTML = `<div style="color: var(--color-muted);">Loading RunPod runtime status...</div>`;
try {
const result = await apiCall('GET', `${BASE_PATH}/dashboard/providers/${encodeURIComponent(providerKey)}/runpod-status`);
if (!result.success) throw new Error(result.error || 'Failed to load RunPod status');
renderRunpodRuntimeStatus(providerKey, result.status, null);
} catch (error) {
renderRunpodRuntimeStatus(providerKey, null, `Unable to load RunPod status: ${error.message}`);
}
}
async function refreshRunpodCatalog(providerKey) {
const actionStatus = document.getElementById(`runpod-runtime-action-${providerKey}`);
if (actionStatus) {
actionStatus.style.color = 'var(--color-muted)';
actionStatus.textContent = 'Refreshing public catalog...';
}
try {
const result = await apiCall('POST', `${BASE_PATH}/dashboard/providers/${encodeURIComponent(providerKey)}/runpod-refresh`, {});
if (!result.success) throw new Error(result.error || 'Failed to refresh catalog');
renderRunpodRuntimeStatus(providerKey, result.status, null);
if (actionStatus) {
actionStatus.style.color = '#4ade80';
actionStatus.textContent = `Public catalog refreshed (${result.catalog_count} models).`;
}
} catch (error) {
if (actionStatus) {
actionStatus.style.color = '#f87171';
actionStatus.textContent = `Refresh failed: ${error.message}`;
}
}
}
function renderGpuSummary(metadata) { function renderGpuSummary(metadata) {
const gpus = Array.isArray(metadata?.gpus) ? metadata.gpus : []; const gpus = Array.isArray(metadata?.gpus) ? metadata.gpus : [];
if (!gpus.length) return 'None reported'; if (!gpus.length) return 'None reported';
...@@ -717,6 +810,7 @@ function renderProviderDetails(key) { ...@@ -717,6 +810,7 @@ function renderProviderDetails(key) {
const isQwenProvider = provider.type === 'qwen'; const isQwenProvider = provider.type === 'qwen';
const isCodexProvider = provider.type === 'codex'; const isCodexProvider = provider.type === 'codex';
const isCoderAIProvider = provider.type === 'coderai'; const isCoderAIProvider = provider.type === 'coderai';
const isRunpodProvider = provider.type === 'runpod';
// Initialize kiro_config if this is a kiro provider and doesn't have it // Initialize kiro_config if this is a kiro provider and doesn't have it
if (isKiroProvider && !provider.kiro_config) { if (isKiroProvider && !provider.kiro_config) {
...@@ -776,6 +870,23 @@ function renderProviderDetails(key) { ...@@ -776,6 +870,23 @@ function renderProviderDetails(key) {
registration_token: '' registration_token: ''
}; };
} }
if (isRunpodProvider && !provider.runpod_config) {
provider.runpod_config = {
mode: 'pod',
wrapper_mode: 'openai',
account_name: key,
management_api: 'auto',
idle_shutdown_ms: 900000,
startup_poll_interval_ms: 3000,
startup_timeout_ms: 300000,
pod_id: '',
template_id: '',
endpoint_id: '',
serverless_template_id: '',
public_endpoint_protocol_default: 'auto',
public_models: {}
};
}
const kiroConfig = provider.kiro_config || {}; const kiroConfig = provider.kiro_config || {};
const claudeConfig = provider.claude_config || {}; const claudeConfig = provider.claude_config || {};
...@@ -783,6 +894,7 @@ function renderProviderDetails(key) { ...@@ -783,6 +894,7 @@ function renderProviderDetails(key) {
const qwenConfig = provider.qwen_config || {}; const qwenConfig = provider.qwen_config || {};
const codexConfig = provider.codex_config || {}; const codexConfig = provider.codex_config || {};
const coderaiConfig = provider.coderai_config || {}; const coderaiConfig = provider.coderai_config || {};
const runpodConfig = provider.runpod_config || {};
const brokerSession = coderaiConfig.broker_session || {}; const brokerSession = coderaiConfig.broker_session || {};
const brokerConnected = !!brokerSession.connected; const brokerConnected = !!brokerSession.connected;
const ownerLabel = brokerSession.owner_user_id == null ? 'Global admin' : `User #${brokerSession.owner_user_id}`; const ownerLabel = brokerSession.owner_user_id == null ? 'Global admin' : `User #${brokerSession.owner_user_id}`;
...@@ -1121,6 +1233,128 @@ function renderProviderDetails(key) { ...@@ -1121,6 +1233,128 @@ function renderProviderDetails(key) {
</div> </div>
</div> </div>
`; `;
} else if (isRunpodProvider) {
const publicModels = runpodConfig.public_models || {};
const overrideRows = Object.entries(publicModels).map(([modelId, modelConfig]) => `
<div style="display:grid; grid-template-columns:minmax(220px, 1fr) minmax(180px, 220px); gap:10px; margin-bottom:10px; align-items:center;">
<input type="text" value="${escHtmlAttr(modelId)}" onchange="renameRunpodPublicModelOverride('${key}', '${escHtmlAttr(modelId)}', this.value)" placeholder="model-slug">
<select onchange="updateRunpodPublicModelProtocol('${key}', '${escHtmlAttr(modelId)}', this.value)">
<option value="auto" ${(modelConfig.protocol || 'auto') === 'auto' ? 'selected' : ''}>Auto detect</option>
<option value="runpod_public" ${modelConfig.protocol === 'runpod_public' ? 'selected' : ''}>RunPod public</option>
<option value="openai" ${modelConfig.protocol === 'openai' ? 'selected' : ''}>OpenAI</option>
<option value="ollama" ${modelConfig.protocol === 'ollama' ? 'selected' : ''}>Ollama</option>
<option value="coderai" ${modelConfig.protocol === 'coderai' ? 'selected' : ''}>CoderAI</option>
</select>
</div>
`).join('');
authFieldsHtml = `
<div style="background: var(--bg-panel); padding: 15px; border-radius: 5px; margin-bottom: 15px; border-left: 3px solid #4a9eff;">
<h4 style="margin: 0 0 15px 0; color: var(--color-link);">RunPod Runtime Configuration</h4>
<small style="color: var(--color-muted); display: block; margin-bottom: 15px;">
RunPod providers use the REST management API to start pods, resolve serverless endpoints, and cache discovered public models across AISBF instances.
</small>
<div class="form-group">
<label>Account Label</label>
<input type="text" value="${runpodConfig.account_name || key}" onchange="updateRunpodConfig('${key}', 'account_name', this.value)" placeholder="personal-runpod">
</div>
<div class="form-group">
<label>Management API Key</label>
<input type="password" value="${provider.api_key || ''}" onchange="updateProvider('${key}', 'api_key', this.value)" placeholder="RunPod API key">
<small style="color: var(--color-muted); display: block; margin-top: 5px;">Used for RunPod REST management calls and public endpoint requests.</small>
</div>
<div class="form-group">
<label>Mode</label>
<select onchange="updateRunpodMode('${key}', this.value)">
<option value="pod" ${runpodConfig.mode === 'pod' ? 'selected' : ''}>Pod-backed</option>
<option value="serverless_template" ${runpodConfig.mode === 'serverless_template' ? 'selected' : ''}>Serverless template</option>
<option value="public" ${runpodConfig.mode === 'public' ? 'selected' : ''}>Public catalog</option>
</select>
</div>
<div id="runpod-non-public-${key}" style="display:${runpodConfig.mode === 'public' ? 'none' : 'block'};">
<div class="form-group">
<label>Wrapper Mode</label>
<select onchange="updateRunpodConfig('${key}', 'wrapper_mode', this.value)">
<option value="openai" ${(runpodConfig.wrapper_mode || 'openai') === 'openai' ? 'selected' : ''}>OpenAI</option>
<option value="ollama" ${runpodConfig.wrapper_mode === 'ollama' ? 'selected' : ''}>Ollama</option>
<option value="coderai" ${runpodConfig.wrapper_mode === 'coderai' ? 'selected' : ''}>CoderAI</option>
</select>
</div>
<div class="form-group" style="display:${runpodConfig.mode === 'pod' ? 'block' : 'none'};" id="runpod-pod-id-group-${key}">
<label>Pod ID</label>
<input type="text" value="${runpodConfig.pod_id || ''}" onchange="updateRunpodConfig('${key}', 'pod_id', this.value)" placeholder="abc123">
</div>
<div class="form-group" style="display:${runpodConfig.mode === 'serverless_template' ? 'block' : 'none'};" id="runpod-endpoint-id-group-${key}">
<label>Endpoint ID</label>
<input type="text" value="${runpodConfig.endpoint_id || ''}" onchange="updateRunpodConfig('${key}', 'endpoint_id', this.value)" placeholder="endpoint id (optional if template set)">
</div>
<div class="form-group">
<label>Template ID</label>
<input type="text" value="${runpodConfig.template_id || ''}" onchange="updateRunpodConfig('${key}', 'template_id', this.value)" placeholder="tmpl_xyz">
</div>
<div class="form-group" style="display:${runpodConfig.mode === 'serverless_template' ? 'block' : 'none'};" id="runpod-serverless-template-group-${key}">
<label>Serverless Template ID</label>
<input type="text" value="${runpodConfig.serverless_template_id || ''}" onchange="updateRunpodConfig('${key}', 'serverless_template_id', this.value)" placeholder="serverless template id">
</div>
<div class="form-group">
<label>Idle Shutdown (ms)</label>
<input type="number" value="${runpodConfig.idle_shutdown_ms || 900000}" onchange="updateRunpodNumberConfig('${key}', 'idle_shutdown_ms', this.value)">
</div>
<div class="form-group">
<label>Startup Poll Interval (ms)</label>
<input type="number" value="${runpodConfig.startup_poll_interval_ms || 3000}" onchange="updateRunpodNumberConfig('${key}', 'startup_poll_interval_ms', this.value)">
</div>
<div class="form-group">
<label>Startup Timeout (ms)</label>
<input type="number" value="${runpodConfig.startup_timeout_ms || 300000}" onchange="updateRunpodNumberConfig('${key}', 'startup_timeout_ms', this.value)">
</div>
</div>
<div id="runpod-public-${key}" style="display:${runpodConfig.mode === 'public' ? 'block' : 'none'};">
<div class="form-group">
<label>Default Public Protocol</label>
<select onchange="updateRunpodConfig('${key}', 'public_endpoint_protocol_default', this.value)">
<option value="auto" ${(runpodConfig.public_endpoint_protocol_default || 'auto') === 'auto' ? 'selected' : ''}>Auto detect</option>
<option value="runpod_public" ${runpodConfig.public_endpoint_protocol_default === 'runpod_public' ? 'selected' : ''}>RunPod public</option>
<option value="openai" ${runpodConfig.public_endpoint_protocol_default === 'openai' ? 'selected' : ''}>OpenAI</option>
<option value="ollama" ${runpodConfig.public_endpoint_protocol_default === 'ollama' ? 'selected' : ''}>Ollama</option>
<option value="coderai" ${runpodConfig.public_endpoint_protocol_default === 'coderai' ? 'selected' : ''}>CoderAI</option>
</select>
</div>
<div class="form-group">
<label>Per-model Protocol Overrides</label>
<div style="background: var(--bg-page); border: 1px solid var(--color-border); border-radius: 6px; padding: 12px;">
${overrideRows || '<small style="color: var(--color-muted); display:block;">No overrides configured.</small>'}
<button type="button" class="btn btn-secondary" style="margin-top:10px;" onclick="addRunpodPublicModelOverride('${key}')">Add Override</button>
</div>
</div>
</div>
<div class="form-group" style="margin-top: 20px;">
<div style="display:flex; justify-content:space-between; gap:10px; align-items:center; flex-wrap:wrap; margin-bottom:10px;">
<label style="margin:0;">RunPod Runtime Status</label>
<div style="display:flex; gap:8px; flex-wrap:wrap;">
<button type="button" class="btn btn-secondary" onclick="loadRunpodRuntimeStatus('${safeKey}')">Refresh Status</button>
${runpodConfig.mode === 'public' ? `<button type="button" class="btn" onclick="refreshRunpodCatalog('${safeKey}')">Refresh Public Catalog</button>` : ''}
</div>
</div>
<div id="runpod-runtime-status-${key}" style="margin-top:10px;"></div>
<div id="runpod-runtime-action-${key}" style="margin-top:8px; font-size:13px; color: var(--color-muted);"></div>
</div>
</div>
`;
} else if (isQwenProvider) { } else if (isQwenProvider) {
// Qwen authentication fields - supports both API key and OAuth2 // Qwen authentication fields - supports both API key and OAuth2
authFieldsHtml = ` authFieldsHtml = `
...@@ -1313,6 +1547,7 @@ function renderProviderDetails(key) { ...@@ -1313,6 +1547,7 @@ function renderProviderDetails(key) {
<option value="qwen" ${provider.type === 'qwen' ? 'selected' : ''}>Qwen (OAuth2)</option> <option value="qwen" ${provider.type === 'qwen' ? 'selected' : ''}>Qwen (OAuth2)</option>
<option value="codex" ${provider.type === 'codex' ? 'selected' : ''}>Codex (OpenAI OAuth2)</option> <option value="codex" ${provider.type === 'codex' ? 'selected' : ''}>Codex (OpenAI OAuth2)</option>
<option value="coderai" ${provider.type === 'coderai' ? 'selected' : ''}>CoderAI</option> <option value="coderai" ${provider.type === 'coderai' ? 'selected' : ''}>CoderAI</option>
<option value="runpod" ${provider.type === 'runpod' ? 'selected' : ''}>RunPod</option>
</select> </select>
</div> </div>
...@@ -1470,6 +1705,10 @@ function renderProviderDetails(key) { ...@@ -1470,6 +1705,10 @@ function renderProviderDetails(key) {
renderModels(key); renderModels(key);
if (isRunpodProvider) {
setTimeout(() => loadRunpodRuntimeStatus(key), 0);
}
// Fetch full usage data for codex providers when panel is expanded // Fetch full usage data for codex providers when panel is expanded
if (isCodexProvider) { if (isCodexProvider) {
const cached = _usageCache[key]; const cached = _usageCache[key];
...@@ -1643,7 +1882,8 @@ function updateNewProviderDefaults() { ...@@ -1643,7 +1882,8 @@ function updateNewProviderDefaults() {
'kilocode': 'Kilocode provider. Uses OAuth2 Device Authorization Grant. Endpoint: https://api.kilo.ai/api/gateway', 'kilocode': 'Kilocode provider. Uses OAuth2 Device Authorization Grant. Endpoint: https://api.kilo.ai/api/gateway',
'qwen': 'Qwen provider. Uses OAuth2 Device Authorization Grant or API key. Endpoint: https://dashscope.aliyuncs.com/compatible-mode/v1', 'qwen': 'Qwen provider. Uses OAuth2 Device Authorization Grant or API key. Endpoint: https://dashscope.aliyuncs.com/compatible-mode/v1',
'codex': 'Codex provider. Uses OAuth2 Device Authorization Grant (same protocol as OpenAI). Endpoint: https://api.openai.com/v1', 'codex': 'Codex provider. Uses OAuth2 Device Authorization Grant (same protocol as OpenAI). Endpoint: https://api.openai.com/v1',
'coderai': 'CoderAI provider. In broker mode, CoderAI connects inbound to AISBF using a provider-scoped registration token. In direct mode, AISBF calls it as an OpenAI-compatible endpoint. Default endpoint: http://127.0.0.1:11437' 'coderai': 'CoderAI provider. In broker mode, CoderAI connects inbound to AISBF using a provider-scoped registration token. In direct mode, AISBF calls it as an OpenAI-compatible endpoint. Default endpoint: http://127.0.0.1:11437',
'runpod': 'RunPod provider. Uses the RunPod REST management API for pod lifecycle, serverless endpoints, and public endpoint catalogs. Default endpoint: https://rest.runpod.io/v1'
}; };
descriptionEl.textContent = descriptions[providerType] || window.i18n.t('providers.standard_config'); descriptionEl.textContent = descriptions[providerType] || window.i18n.t('providers.standard_config');
...@@ -2369,6 +2609,61 @@ function updateCoderAIConfig(key, field, value) { ...@@ -2369,6 +2609,61 @@ function updateCoderAIConfig(key, field, value) {
providersData[key].coderai_config[field] = value; providersData[key].coderai_config[field] = value;
} }
function updateRunpodConfig(key, field, value) {
if (!providersData[key].runpod_config) {
providersData[key].runpod_config = {};
}
providersData[key].runpod_config[field] = value;
}
function updateRunpodNumberConfig(key, field, value) {
if (!providersData[key].runpod_config) {
providersData[key].runpod_config = {};
}
providersData[key].runpod_config[field] = value === '' ? null : parseInt(value, 10);
}
function updateRunpodMode(key, value) {
if (!providersData[key].runpod_config) {
providersData[key].runpod_config = {};
}
providersData[key].runpod_config.mode = value;
renderProviderDetails(key);
}
function addRunpodPublicModelOverride(key) {
if (!providersData[key].runpod_config) providersData[key].runpod_config = {};
if (!providersData[key].runpod_config.public_models) providersData[key].runpod_config.public_models = {};
let candidate = 'model-slug';
let counter = 1;
while (providersData[key].runpod_config.public_models[candidate]) {
counter += 1;
candidate = `model-slug-${counter}`;
}
providersData[key].runpod_config.public_models[candidate] = { protocol: 'auto', capabilities: [] };
renderProviderDetails(key);
}
function renameRunpodPublicModelOverride(key, previousModelId, nextModelId) {
if (!providersData[key].runpod_config || !providersData[key].runpod_config.public_models) return;
const current = providersData[key].runpod_config.public_models[previousModelId];
if (!current) return;
const trimmed = String(nextModelId || '').trim();
if (!trimmed || trimmed === previousModelId) return;
delete providersData[key].runpod_config.public_models[previousModelId];
providersData[key].runpod_config.public_models[trimmed] = current;
renderProviderDetails(key);
}
function updateRunpodPublicModelProtocol(key, modelId, protocol) {
if (!providersData[key].runpod_config) providersData[key].runpod_config = {};
if (!providersData[key].runpod_config.public_models) providersData[key].runpod_config.public_models = {};
if (!providersData[key].runpod_config.public_models[modelId]) {
providersData[key].runpod_config.public_models[modelId] = { capabilities: [] };
}
providersData[key].runpod_config.public_models[modelId].protocol = protocol;
}
function updateCoderAIBrokerMode(key, enabled) { function updateCoderAIBrokerMode(key, enabled) {
if (!providersData[key].coderai_config) { if (!providersData[key].coderai_config) {
providersData[key].coderai_config = {}; providersData[key].coderai_config = {};
...@@ -2443,6 +2738,23 @@ function updateProviderType(key, value) { ...@@ -2443,6 +2738,23 @@ function updateProviderType(key, value) {
if (defaultEndpoint) { if (defaultEndpoint) {
providersData[key].endpoint = defaultEndpoint; providersData[key].endpoint = defaultEndpoint;
} }
if (value === 'runpod' && !providersData[key].runpod_config) {
providersData[key].runpod_config = {
mode: 'pod',
wrapper_mode: 'openai',
account_name: key,
management_api: 'auto',
idle_shutdown_ms: 900000,
startup_poll_interval_ms: 3000,
startup_timeout_ms: 300000,
pod_id: '',
template_id: '',
endpoint_id: '',
serverless_template_id: '',
public_endpoint_protocol_default: 'auto',
public_models: {}
};
}
// Re-render to update the configuration fields // Re-render to update the configuration fields
renderProvidersList(); renderProvidersList();
} }
...@@ -2551,6 +2863,22 @@ async function confirmAddProvider() { ...@@ -2551,6 +2863,22 @@ async function confirmAddProvider() {
registration_path: '/coderai/register', registration_path: '/coderai/register',
registration_token: '' registration_token: ''
}; };
} else if (type === 'runpod') {
providersData[key].runpod_config = {
mode: 'pod',
wrapper_mode: 'openai',
account_name: key,
management_api: 'auto',
idle_shutdown_ms: 900000,
startup_poll_interval_ms: 3000,
startup_timeout_ms: 300000,
pod_id: '',
template_id: '',
endpoint_id: '',
serverless_template_id: '',
public_endpoint_protocol_default: 'auto',
public_models: {}
};
} }
// Immediately persist the new provider // Immediately persist the new provider
...@@ -2587,7 +2915,8 @@ function getDefaultEndpoint(type) { ...@@ -2587,7 +2915,8 @@ function getDefaultEndpoint(type) {
'kilocode': 'https://api.kilo.ai/api/gateway', 'kilocode': 'https://api.kilo.ai/api/gateway',
'qwen': 'https://dashscope.aliyuncs.com/compatible-mode/v1', 'qwen': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'codex': 'https://api.openai.com/v1', 'codex': 'https://api.openai.com/v1',
'coderai': 'http://127.0.0.1:11437' 'coderai': 'http://127.0.0.1:11437',
'runpod': 'https://rest.runpod.io/v1'
}; };
return defaults[type] || ''; return defaults[type] || '';
} }
......
import json
import sqlite3
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from aisbf.config import ProviderConfig, config
from aisbf.database import DatabaseManager, DatabaseRegistry
from aisbf.providers import PROVIDER_HANDLERS, get_provider_handler
from aisbf.providers.base import BaseProviderHandler
from aisbf.providers.runpod import RunpodProviderHandler
from aisbf.app.model_cache import get_provider_models, _model_cache, _model_cache_timestamps, _endpoint_model_cache
from fastapi.responses import HTMLResponse
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from main import app
class StubRunpodHandler(BaseProviderHandler):
def __init__(self, provider_id: str, api_key: str | None = None, user_id: int | None = None, provider_config=None):
self.provider_config = provider_config
self.user_provider_config = provider_config if isinstance(provider_config, dict) else None
super().__init__(provider_id, api_key, user_id=user_id)
def validate_credentials(self) -> bool:
return True
async def get_models(self):
return []
@pytest.fixture(autouse=True)
def reset_runpod_state(monkeypatch, tmp_path):
original_handlers = dict(PROVIDER_HANDLERS)
original_provider = config.providers.get("runpod-test")
original_error = config.error_tracking.get("runpod-test")
original_instances = dict(DatabaseRegistry._instances)
PROVIDER_HANDLERS["runpod"] = StubRunpodHandler
config.providers["runpod-test"] = ProviderConfig(
id="runpod-test",
name="RunPod Test",
endpoint="https://rest.runpod.io/v1",
type="runpod",
api_key_required=True,
api_key="test-key",
rate_limit=0,
runpod_config={
"mode": "public",
"account_name": "test-account",
"public_endpoint_protocol_default": "auto",
},
)
config.error_tracking["runpod-test"] = {
"enabled": True,
"max_errors": 5,
"cooldown_seconds": 60,
"failures": 0,
"last_failure": 0,
"disabled_until": None,
}
_model_cache.clear()
_model_cache_timestamps.clear()
_endpoint_model_cache.clear()
db_path = tmp_path / "runpod-test.db"
DatabaseRegistry._instances = {}
db = DatabaseRegistry.get_config_database({"type": "sqlite", "sqlite_path": str(db_path)})
yield db
PROVIDER_HANDLERS.clear()
PROVIDER_HANDLERS.update(original_handlers)
if original_provider is None:
config.providers.pop("runpod-test", None)
else:
config.providers["runpod-test"] = original_provider
if original_error is None:
config.error_tracking.pop("runpod-test", None)
else:
config.error_tracking["runpod-test"] = original_error
DatabaseRegistry._instances = original_instances
def test_provider_config_accepts_runpod_config():
provider = ProviderConfig(
id="runpod-test",
name="RunPod",
endpoint="https://rest.runpod.io/v1",
type="runpod",
api_key_required=True,
api_key="key",
rate_limit=0,
runpod_config={
"mode": "pod",
"wrapper_mode": "openai",
"pod_id": "pod-123",
},
)
assert provider.runpod_config["wrapper_mode"] == "openai"
assert provider.runpod_config["pod_id"] == "pod-123"
def test_get_provider_handler_supports_runpod(reset_runpod_state):
handler = get_provider_handler("runpod-test")
assert isinstance(handler, StubRunpodHandler)
assert handler.provider_id == "runpod-test"
def test_database_migration_creates_runpod_provider_state_table(reset_runpod_state):
conn = sqlite3.connect(reset_runpod_state.db_config["sqlite_path"])
try:
rows = conn.execute("PRAGMA table_info(runpod_provider_state)").fetchall()
finally:
conn.close()
column_names = {row[1] for row in rows}
assert rows
assert {"provider_id", "resource_kind", "status", "metadata", "updated_at"}.issubset(column_names)
@pytest.mark.asyncio
async def test_runpod_public_models_use_cached_catalog(reset_runpod_state):
reset_runpod_state.save_runpod_provider_state(
provider_scope="global",
owner_user_id=None,
provider_id="runpod-test",
mode="public",
wrapper_mode=None,
resource_id="public-catalog",
resource_kind="public",
status="ready",
endpoint_url="https://api.runpod.ai/v2",
public_catalog_json=[
{
"id": "black-forest-labs-flux-1-dev",
"name": "Flux Dev",
"protocol": "runpod_public",
"capabilities": ["image"],
"route_base": "https://api.runpod.ai/v2/black-forest-labs-flux-1-dev",
"request_mode": "runsync",
}
],
metadata={"source": "test"},
)
models = await get_provider_models("runpod-test", config.providers["runpod-test"], config)
assert len(models) == 1
assert models[0]["id"] == "runpod-test/black-forest-labs-flux-1-dev"
assert models[0]["source"] == "api_cache"
assert models[0]["capabilities"] == ["image"]
@pytest.mark.asyncio
async def test_runpod_refresh_public_catalog_normalizes_live_entries(reset_runpod_state, monkeypatch):
handler = RunpodProviderHandler("runpod-test", api_key="test-key", provider_config=config.providers["runpod-test"])
async def fake_public_catalog_source():
return [
{
"id": "black-forest-labs-flux-1-dev",
"name": "Flux Dev",
"route_base": "https://api.runpod.ai/v2/black-forest-labs-flux-1-dev",
"description": "Image model",
"schema": {"input": {"prompt": "string"}},
}
]
monkeypatch.setattr(handler, "_fetch_live_public_catalog_entries", fake_public_catalog_source)
catalog = await handler.refresh_public_catalog()
assert len(catalog) == 1
assert catalog[0]["id"] == "black-forest-labs-flux-1-dev"
assert catalog[0]["route_base"] == "https://api.runpod.ai/v2/black-forest-labs-flux-1-dev"
assert catalog[0]["protocol"] in {"runpod_public", "openai"}
state = reset_runpod_state.get_runpod_provider_state("global", None, "runpod-test")
assert state["metadata"]["catalog_item_count"] == 1
@pytest.mark.asyncio
async def test_runpod_public_refresh_applies_manual_protocol_override(reset_runpod_state, monkeypatch):
provider = config.providers["runpod-test"]
provider.runpod_config["public_models"] = {
"black-forest-labs-flux-1-dev": {"protocol": "openai"}
}
handler = RunpodProviderHandler("runpod-test", api_key="test-key", provider_config=provider)
async def fake_public_catalog_source():
return [
{
"id": "black-forest-labs-flux-1-dev",
"route_base": "https://api.runpod.ai/v2/black-forest-labs-flux-1-dev",
"protocol": "runpod_public",
}
]
monkeypatch.setattr(handler, "_fetch_live_public_catalog_entries", fake_public_catalog_source)
catalog = await handler.refresh_public_catalog()
assert catalog[0]["protocol"] == "openai"
@pytest.mark.asyncio
async def test_runpod_public_refresh_preserves_cached_catalog_on_failure(reset_runpod_state, monkeypatch):
reset_runpod_state.save_runpod_provider_state(
provider_scope="global",
owner_user_id=None,
provider_id="runpod-test",
mode="public",
wrapper_mode=None,
resource_id="public-catalog",
resource_kind="public",
status="ready",
endpoint_url="https://api.runpod.ai/v2",
public_catalog_json=[
{
"id": "cached-model",
"name": "Cached Model",
"protocol": "runpod_public",
"route_base": "https://api.runpod.ai/v2/cached-model",
"request_mode": "runsync",
"capabilities": [],
}
],
metadata={"catalog_source": "cached"},
)
handler = RunpodProviderHandler("runpod-test", api_key="test-key", provider_config=config.providers["runpod-test"])
async def fake_public_catalog_source():
raise RuntimeError("upstream unavailable")
monkeypatch.setattr(handler, "_fetch_live_public_catalog_entries", fake_public_catalog_source)
with pytest.raises(RuntimeError):
await handler.refresh_public_catalog()
states = [
state
for state in reset_runpod_state.list_runpod_provider_states()
if state["provider_scope"] == "global" and state["provider_id"] == "runpod-test"
]
state = next(
state for state in states if "catalog_refresh_error" in (state.get("metadata") or {})
)
assert state["public_catalog_json"][0]["id"] == "cached-model"
assert "upstream unavailable" in state["metadata"]["catalog_refresh_error"]
def test_runpod_global_state_save_updates_existing_row(reset_runpod_state):
reset_runpod_state.save_runpod_provider_state(
provider_scope="global",
owner_user_id=None,
provider_id="runpod-test",
mode="public",
wrapper_mode=None,
resource_id="public-catalog",
resource_kind="public",
status="ready",
endpoint_url="https://api.runpod.ai/v2",
public_catalog_json=[{"id": "cached-model"}],
metadata={"catalog_source": "cached"},
)
reset_runpod_state.save_runpod_provider_state(
provider_scope="global",
owner_user_id=None,
provider_id="runpod-test",
mode="public",
wrapper_mode=None,
resource_id="public-catalog",
resource_kind="public",
status="ready",
endpoint_url="https://api.runpod.ai/v2",
public_catalog_json=[{"id": "live-model"}],
metadata={"catalog_source": "live", "catalog_refresh_error": "upstream unavailable"},
)
state = reset_runpod_state.get_runpod_provider_state("global", None, "runpod-test")
matching_states = [
item
for item in reset_runpod_state.list_runpod_provider_states()
if item["provider_scope"] == "global" and item["provider_id"] == "runpod-test"
]
assert len(matching_states) == 1
assert state["public_catalog_json"][0]["id"] == "live-model"
assert state["metadata"]["catalog_source"] == "live"
assert state["metadata"]["catalog_refresh_error"] == "upstream unavailable"
@pytest.mark.asyncio
async def test_runpod_pod_mode_starts_stopped_pod_and_waits_until_ready(reset_runpod_state, monkeypatch):
provider = config.providers["runpod-test"]
provider.runpod_config.update({"mode": "pod", "wrapper_mode": "openai", "pod_id": "pod-123"})
handler = RunpodProviderHandler("runpod-test", api_key="test-key", provider_config=provider)
calls = []
responses = iter([
{"id": "pod-123", "desiredStatus": "EXITED", "publicIp": None, "portMappings": []},
{"id": "pod-123", "desiredStatus": "RUNNING", "publicIp": None, "portMappings": []},
{"id": "pod-123", "desiredStatus": "RUNNING", "publicIp": "1.2.3.4", "portMappings": [{"publicPort": 8000}]},
])
async def fake_management_request(method, path, params=None, json_body=None):
calls.append((method, path))
if method == "POST" and path == "/pods/pod-123/start":
return {"ok": True}
return next(responses)
monkeypatch.setattr(handler, "_management_request", fake_management_request)
runtime = await handler._ensure_pod_ready()
assert ("POST", "/pods/pod-123/start") in calls
assert runtime["endpoint_url"] == "http://1.2.3.4:8000/v1"
@pytest.mark.asyncio
async def test_runpod_idle_shutdown_stops_running_pod_after_threshold(reset_runpod_state, monkeypatch):
provider = config.providers["runpod-test"]
provider.runpod_config.update({"mode": "pod", "wrapper_mode": "openai", "pod_id": "pod-123", "idle_shutdown_ms": 1000})
reset_runpod_state.save_runpod_provider_state(
provider_scope="global",
owner_user_id=None,
provider_id="runpod-test",
mode="pod",
wrapper_mode="openai",
resource_id="pod-123",
resource_kind="pod",
status="running",
endpoint_url="http://1.2.3.4:8000/v1",
public_catalog_json=[],
metadata={},
last_used_at=0,
)
handler = RunpodProviderHandler("runpod-test", api_key="test-key", provider_config=provider)
calls = []
async def fake_management_request(method, path, params=None, json_body=None):
calls.append((method, path))
return {"ok": True}
monkeypatch.setattr(handler, "_management_request", fake_management_request)
monkeypatch.setattr("aisbf.providers.runpod.time.time", lambda: 5)
stopped = await handler.poll_idle_shutdown()
assert stopped is True
assert calls == [("POST", "/pods/pod-123/stop")]
states = [
state
for state in reset_runpod_state.list_runpod_provider_states()
if state["provider_scope"] == "global" and state["provider_id"] == "runpod-test"
]
state = next(state for state in states if state["status"] == "stopped")
assert state["resource_id"] == "pod-123"
def test_runpod_build_runtime_status_serializes_cached_state(reset_runpod_state):
provider = config.providers["runpod-test"]
provider.runpod_config.update({"mode": "public"})
reset_runpod_state.save_runpod_provider_state(
provider_scope="global",
owner_user_id=None,
provider_id="runpod-test",
mode="public",
wrapper_mode=None,
resource_id="public-catalog",
resource_kind="public",
status="ready",
endpoint_url="https://api.runpod.ai/v2",
public_catalog_json=[
{
"id": "black-forest-labs-flux-1-dev",
"name": "Flux Dev",
"protocol": "runpod_public",
"route_base": "https://api.runpod.ai/v2/black-forest-labs-flux-1-dev",
"request_mode": "runsync",
"capabilities": ["image"],
}
],
metadata={
"catalog_source": "live",
"catalog_refreshed_at": 123,
"catalog_item_count": 1,
"catalog_refresh_error": "stale failure",
},
last_used_at=111,
last_status_sync_at=222,
)
handler = RunpodProviderHandler("runpod-test", api_key="test-key", provider_config=provider)
status = handler.build_runtime_status()
assert status["provider_id"] == "runpod-test"
assert status["mode"] == "public"
assert status["status"] == "ready"
assert status["resource_kind"] == "public"
assert status["endpoint_url"] == "https://api.runpod.ai/v2"
assert status["catalog"]["item_count"] == 1
assert status["catalog"]["refreshed_at"] == 123
assert status["catalog"]["source"] == "live"
assert status["catalog"]["refresh_error"] == "stale failure"
assert status["catalog"]["models"][0]["id"] == "black-forest-labs-flux-1-dev"
assert status["last_used_at"] == 111
assert status["last_status_sync_at"] == 222
def test_runpod_build_delegate_handler_uses_runtime_endpoint_for_openai(reset_runpod_state):
provider = config.providers["runpod-test"]
provider.runpod_config.update({"mode": "pod", "wrapper_mode": "openai", "pod_id": "pod-123"})
handler = RunpodProviderHandler("runpod-test", api_key="test-key", provider_config=provider)
reset_runpod_state.save_runpod_provider_state(
provider_scope="global",
owner_user_id=None,
provider_id="runpod-test",
mode="pod",
wrapper_mode="openai",
resource_id="pod-123",
resource_kind="pod",
status="running",
endpoint_url="http://1.2.3.4:8000/v1",
public_catalog_json=[],
metadata={},
)
delegate = handler._build_delegate_handler("openai")
assert str(delegate.client.base_url).rstrip("/") == "http://1.2.3.4:8000/v1"
def test_runpod_build_delegate_handler_uses_runtime_endpoint_for_ollama(reset_runpod_state):
provider = config.providers["runpod-test"]
provider.runpod_config.update({"mode": "pod", "wrapper_mode": "ollama", "pod_id": "pod-123"})
handler = RunpodProviderHandler("runpod-test", api_key="test-key", provider_config=provider)
reset_runpod_state.save_runpod_provider_state(
provider_scope="global",
owner_user_id=None,
provider_id="runpod-test",
mode="pod",
wrapper_mode="ollama",
resource_id="pod-123",
resource_kind="pod",
status="running",
endpoint_url="http://1.2.3.4:11434",
public_catalog_json=[],
metadata={},
)
delegate = handler._build_delegate_handler("ollama")
assert str(delegate.client.base_url).rstrip("/") == "http://1.2.3.4:11434"
def test_dashboard_save_preserves_runpod_config(reset_runpod_state, monkeypatch):
client = TestClient(app)
client.cookies.set("session", "stub")
def fake_require_dashboard_auth(request):
request.session.update({"logged_in": True, "user_id": 1, "username": "alice", "role": "user"})
return None
from aisbf.routes.dashboard import providers as dashboard_providers
monkeypatch.setattr(dashboard_providers, "require_dashboard_auth", fake_require_dashboard_auth)
monkeypatch.setattr(
dashboard_providers,
"_templates",
type("TemplatesStub", (), {"TemplateResponse": staticmethod(lambda *args, **kwargs: HTMLResponse("ok"))})(),
)
saved = {}
def fake_save_user_provider(user_id, provider_key, provider_config):
saved[provider_key] = provider_config
monkeypatch.setattr(reset_runpod_state, "get_user_providers", lambda user_id: [])
monkeypatch.setattr(reset_runpod_state, "save_user_provider", fake_save_user_provider)
response = client.post(
"/dashboard/providers",
data={
"config": json.dumps({
"runpod-test": {
"type": "runpod",
"name": "RunPod",
"endpoint": "https://rest.runpod.io/v1",
"api_key_required": True,
"api_key": "test-key",
"runpod_config": {
"mode": "pod",
"wrapper_mode": "ollama",
"pod_id": "pod-123",
},
}
})
},
)
assert response.status_code == 200
assert saved["runpod-test"]["runpod_config"]["wrapper_mode"] == "ollama"
def test_dashboard_runpod_status_returns_runtime_state(reset_runpod_state, monkeypatch):
client = TestClient(app)
client.cookies.set("session", "stub")
def fake_require_dashboard_auth(request):
request.session.update({"logged_in": True, "user_id": 1, "username": "alice", "role": "user"})
return None
from aisbf.routes.dashboard import providers as dashboard_providers
monkeypatch.setattr(dashboard_providers, "require_dashboard_auth", fake_require_dashboard_auth)
saved_provider = {
"provider_id": "runpod-test",
"config": {
"type": "runpod",
"name": "RunPod",
"endpoint": "https://rest.runpod.io/v1",
"api_key": "test-key",
"runpod_config": {"mode": "public"},
},
}
monkeypatch.setattr(reset_runpod_state, "get_user_provider", lambda user_id, provider_id: saved_provider if provider_id == "runpod-test" else None)
class StubRunpodStatusHandler:
def __init__(self, provider_id, api_key=None, user_id=None, provider_config=None):
self.provider_id = provider_id
def build_runtime_status(self):
return {
"provider_id": self.provider_id,
"mode": "public",
"status": "ready",
"catalog": {"item_count": 1},
}
monkeypatch.setattr(dashboard_providers, "RunpodProviderHandler", StubRunpodStatusHandler, raising=False)
response = client.get("/dashboard/providers/runpod-test/runpod-status")
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["status"]["provider_id"] == "runpod-test"
assert response.json()["status"]["catalog"]["item_count"] == 1
def test_dashboard_runpod_refresh_returns_refreshed_status(reset_runpod_state, monkeypatch):
client = TestClient(app)
client.cookies.set("session", "stub")
def fake_require_dashboard_auth(request):
request.session.update({"logged_in": True, "user_id": 1, "username": "alice", "role": "user"})
return None
from aisbf.routes.dashboard import providers as dashboard_providers
monkeypatch.setattr(dashboard_providers, "require_dashboard_auth", fake_require_dashboard_auth)
saved_provider = {
"provider_id": "runpod-test",
"config": {
"type": "runpod",
"name": "RunPod",
"endpoint": "https://rest.runpod.io/v1",
"api_key": "test-key",
"runpod_config": {"mode": "public"},
},
}
monkeypatch.setattr(reset_runpod_state, "get_user_provider", lambda user_id, provider_id: saved_provider if provider_id == "runpod-test" else None)
class StubRunpodRefreshHandler:
def __init__(self, provider_id, api_key=None, user_id=None, provider_config=None):
self.provider_id = provider_id
async def refresh_public_catalog(self):
return [{"id": "black-forest-labs-flux-1-dev"}]
def build_runtime_status(self):
return {
"provider_id": self.provider_id,
"mode": "public",
"status": "ready",
"catalog": {"item_count": 1, "models": [{"id": "black-forest-labs-flux-1-dev"}]},
}
monkeypatch.setattr(dashboard_providers, "RunpodProviderHandler", StubRunpodRefreshHandler, raising=False)
response = client.post("/dashboard/providers/runpod-test/runpod-refresh")
assert response.status_code == 200
body = response.json()
assert body["success"] is True
assert body["catalog_count"] == 1
assert body["status"]["provider_id"] == "runpod-test"
assert body["status"]["catalog"]["models"][0]["id"] == "black-forest-labs-flux-1-dev"
def test_dashboard_providers_page_includes_runpod_runtime_ui(reset_runpod_state, monkeypatch):
client = TestClient(app)
client.cookies.set("session", "stub")
from aisbf.routes.dashboard import providers as dashboard_providers
def fake_require_dashboard_auth(request):
request.session.update({"logged_in": True, "user_id": None, "username": "admin", "role": "admin"})
return None
class TemplatesStub:
@staticmethod
def TemplateResponse(*args, **kwargs):
context = kwargs.get("context") or {}
html = "\n".join([
str(context.get("providers_json", "")),
"RunPod Runtime Status",
"refreshRunpodCatalog",
"loadRunpodRuntimeStatus",
])
return HTMLResponse(html)
env = type("EnvStub", (), {"cache": {}})()
monkeypatch.setattr(dashboard_providers, "require_dashboard_auth", fake_require_dashboard_auth)
monkeypatch.setattr(dashboard_providers, "_templates", TemplatesStub())
response = client.get("/dashboard/providers")
assert response.status_code == 200
assert "RunPod Runtime Status" in response.text
assert "refreshRunpodCatalog" in response.text
assert "loadRunpodRuntimeStatus" in response.text
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment