fix: reuse cached provider models before auto-detect refresh

parent 5ce4da23
...@@ -26,25 +26,77 @@ def _is_broker_only_coderai(provider_config) -> bool: ...@@ -26,25 +26,77 @@ def _is_broker_only_coderai(provider_config) -> bool:
return bool(coderai_config.get('broker_mode', False)) return bool(coderai_config.get('broker_mode', False))
def _cache_key_for_provider(provider_id: str, user_id: Optional[int] = None) -> str:
return f"{provider_id}:{user_id}" if user_id is not None else provider_id
def _endpoint_cache_key(provider_config) -> Optional[str]:
if provider_config is None:
return None
prov_type = getattr(provider_config, 'type', '') or ''
endpoint = getattr(provider_config, 'endpoint', '') or ''
if not prov_type and not endpoint:
return None
return f"{prov_type}:{endpoint}"
def _get_cached_provider_models(cache_key: str) -> Optional[list]:
cached_at = _model_cache_timestamps.get(cache_key)
if cached_at is None:
return None
if time.time() - cached_at >= _cache_refresh_interval:
return None
return _model_cache.get(cache_key)
def _store_provider_models_in_cache(cache_key: str, models: list, cached_at: Optional[float] = None) -> float:
now = cached_at if cached_at is not None else time.time()
_model_cache[cache_key] = models
_model_cache_timestamps[cache_key] = now
return now
def _get_cached_endpoint_models(endpoint_key: Optional[str]) -> Optional[tuple[list, float]]:
if not endpoint_key:
return None
cached = _endpoint_model_cache.get(endpoint_key)
if not cached:
return None
cached_models, cached_at = cached
if time.time() - cached_at >= _cache_refresh_interval:
return None
return cached_models, cached_at
def _store_endpoint_models_in_cache(endpoint_key: Optional[str], models: list, cached_at: float) -> None:
if not endpoint_key:
return
_endpoint_model_cache[endpoint_key] = (models, cached_at)
async def fetch_provider_models(provider_id: str, config, user_id: Optional[int] = None) -> list: async def fetch_provider_models(provider_id: str, config, user_id: Optional[int] = None) -> list:
global _model_cache, _model_cache_timestamps, _endpoint_model_cache global _model_cache, _model_cache_timestamps, _endpoint_model_cache
cache_key = f"{provider_id}:{user_id}" if user_id else provider_id cache_key = _cache_key_for_provider(provider_id, user_id)
try: try:
if not user_id and config is not None: cached_models = _get_cached_provider_models(cache_key)
if cached_models is not None:
return cached_models
provider_config = None
endpoint_key = None
if config is not None:
try: try:
prov_cfg = config.get_provider(provider_id) provider_config = config.get_provider(provider_id)
prov_type = getattr(prov_cfg, 'type', '') endpoint_key = _endpoint_cache_key(provider_config) if user_id is None else None
endpoint = getattr(prov_cfg, 'endpoint', '') or ''
endpoint_key = f"{prov_type}:{endpoint}"
if endpoint_key and endpoint_key in _endpoint_model_cache:
cached_models, cached_at = _endpoint_model_cache[endpoint_key]
if time.time() - cached_at < _cache_refresh_interval:
_model_cache[cache_key] = cached_models
_model_cache_timestamps[cache_key] = cached_at
return cached_models
except Exception: except Exception:
pass provider_config = None
endpoint_cached = _get_cached_endpoint_models(endpoint_key)
if endpoint_cached is not None:
cached_models, cached_at = endpoint_cached
_store_provider_models_in_cache(cache_key, cached_models, cached_at)
return cached_models
from aisbf.handlers import RequestHandler from aisbf.handlers import RequestHandler
from starlette.requests import Request as StarletteRequest from starlette.requests import Request as StarletteRequest
...@@ -56,20 +108,8 @@ async def fetch_provider_models(provider_id: str, config, user_id: Optional[int] ...@@ -56,20 +108,8 @@ async def fetch_provider_models(provider_id: str, config, user_id: Optional[int]
models = await request_handler.handle_model_list(dummy_request, provider_id) models = await request_handler.handle_model_list(dummy_request, provider_id)
now = time.time() now = _store_provider_models_in_cache(cache_key, models)
_model_cache[cache_key] = models _store_endpoint_models_in_cache(endpoint_key, models, now)
_model_cache_timestamps[cache_key] = now
if not user_id and config is not None:
try:
prov_cfg = config.get_provider(provider_id)
prov_type = getattr(prov_cfg, 'type', '')
endpoint = getattr(prov_cfg, 'endpoint', '') or ''
endpoint_key = f"{prov_type}:{endpoint}"
if endpoint_key and endpoint_key not in _endpoint_model_cache:
_endpoint_model_cache[endpoint_key] = (models, now)
except Exception:
pass
logger.info(f"Cached {len(models)} models from provider: {provider_id}") logger.info(f"Cached {len(models)} models from provider: {provider_id}")
return models return models
...@@ -127,24 +167,21 @@ async def get_provider_models(provider_id: str, provider_config, config, user_id ...@@ -127,24 +167,21 @@ async def get_provider_models(provider_id: str, provider_config, config, user_id
for model in provider_config.models for model in provider_config.models
] ]
cache_key = f"{provider_id}:{user_id}" if user_id else provider_id cache_key = _cache_key_for_provider(provider_id, user_id)
if cache_key in _model_cache: cached_models = _get_cached_provider_models(cache_key)
cache_age = time.time() - _model_cache_timestamps.get(cache_key, 0) if cached_models:
if cache_age < _cache_refresh_interval: models = []
cached_models = _model_cache[cache_key] for model in cached_models:
if cached_models: mc = model.copy()
models = [] mc['id'] = f"{provider_id}/{model.get('id', model.get('name', ''))}"
for model in cached_models: mc.setdefault('object', 'model')
mc = model.copy() mc.setdefault('created', current_time)
mc['id'] = f"{provider_id}/{model.get('id', model.get('name', ''))}" mc.setdefault('owned_by', provider_config.name)
mc.setdefault('object', 'model') mc['provider'] = provider_id
mc.setdefault('created', current_time) mc['type'] = 'provider'
mc.setdefault('owned_by', provider_config.name) mc['source'] = 'api_cache'
mc['provider'] = provider_id models.append(mc)
mc['type'] = 'provider' return models
mc['source'] = 'api_cache'
models.append(mc)
return models
api_key = getattr(provider_config, 'api_key', None) api_key = getattr(provider_config, 'api_key', None)
api_key_required = getattr(provider_config, 'api_key_required', True) api_key_required = getattr(provider_config, 'api_key_required', True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment