feat: resolve market references at runtime

parent e6fd64d9
...@@ -234,6 +234,27 @@ def _apply_listing_derived_fields(listing: dict, db): ...@@ -234,6 +234,27 @@ def _apply_listing_derived_fields(listing: dict, db):
return _attach_analytics_snapshot(listing) return _attach_analytics_snapshot(listing)
async def resolve_market_reference(reference_id: int, user_id: int) -> dict:
db = DatabaseRegistry.get_config_database()
reference = db.get_market_import_reference(reference_id)
if not reference or reference.get('user_id') != user_id:
raise ValueError('market reference not found')
listing = db.get_market_listing(reference.get('listing_id'))
if not listing or not listing.get('is_active'):
raise ValueError('market reference unavailable')
return {
'reference': reference,
'listing': listing,
'listing_id': listing.get('id'),
'owner_user_id': listing.get('owner_user_id'),
'owner_username': listing.get('owner_username'),
'source_type': listing.get('source_type'),
'source_id': listing.get('source_id'),
}
def _listing_scope_priority(listing: dict) -> int: def _listing_scope_priority(listing: dict) -> int:
return 1 if listing.get('source_scope') == 'global' else 0 return 1 if listing.get('source_scope') == 'global' else 0
......
...@@ -6,6 +6,7 @@ from aisbf.models import ChatCompletionRequest ...@@ -6,6 +6,7 @@ from aisbf.models import ChatCompletionRequest
from aisbf.database import DatabaseRegistry from aisbf.database import DatabaseRegistry
from aisbf.app.model_cache import get_provider_models from aisbf.app.model_cache import get_provider_models
from aisbf.studio_services import studio_service from aisbf.studio_services import studio_service
from aisbf.routes.dashboard.market import resolve_market_reference
router = APIRouter() router = APIRouter()
_config = None _config = None
...@@ -50,6 +51,27 @@ def parse_provider_from_model(model: str) -> tuple[str, str]: ...@@ -50,6 +51,27 @@ def parse_provider_from_model(model: str) -> tuple[str, str]:
return None, model return None, model
async def _resolve_runtime_market_reference(handler, resource_id: str, user_id: int, expected_type: str) -> dict | None:
if not user_id or not resource_id.startswith('market-ref:'):
return None
if resource_id not in getattr(handler, 'user_providers', {}) and resource_id not in getattr(handler, 'rotations', {}) and resource_id not in getattr(handler, 'autoselects', {}):
return None
try:
reference_id = int(resource_id.split(':', 1)[1])
except (TypeError, ValueError):
raise HTTPException(status_code=400, detail='Invalid market reference id')
try:
resolved = await resolve_market_reference(reference_id, user_id)
except ValueError as exc:
message = str(exc)
if 'unavailable' in message:
raise HTTPException(status_code=400, detail='Market reference unavailable')
raise HTTPException(status_code=404, detail='Market reference not found')
if resolved.get('source_type') != expected_type:
raise HTTPException(status_code=400, detail='Market reference type mismatch')
return resolved
def _normalize_studio_proxy_body(endpoint_path: str, body: dict) -> dict: def _normalize_studio_proxy_body(endpoint_path: str, body: dict) -> dict:
normalized = dict(body or {}) normalized = dict(body or {})
...@@ -349,6 +371,12 @@ async def user_chat_completions(request: Request, username: str, body: ChatCompl ...@@ -349,6 +371,12 @@ async def user_chat_completions(request: Request, username: str, body: ChatCompl
body_dict = body.model_dump() body_dict = body.model_dump()
if provider_id == "user-autoselect": if provider_id == "user-autoselect":
handler = _get_user_handler('autoselect', user_id) handler = _get_user_handler('autoselect', user_id)
market_reference = await _resolve_runtime_market_reference(handler, actual_model, user_id, 'autoselect')
if market_reference:
owner_handler = _get_user_handler('autoselect', market_reference['owner_user_id'])
body_dict['model'] = market_reference['source_id']
token_id = getattr(request.state, 'token_id', None)
return await owner_handler.handle_autoselect_request(market_reference['source_id'], body_dict, user_id, token_id)
if actual_model not in handler.user_autoselects: if actual_model not in handler.user_autoselects:
raise HTTPException(status_code=400, detail=f"User autoselect '{actual_model}' not found. Available: {list(handler.user_autoselects.keys())}") raise HTTPException(status_code=400, detail=f"User autoselect '{actual_model}' not found. Available: {list(handler.user_autoselects.keys())}")
body_dict['model'] = actual_model body_dict['model'] = actual_model
...@@ -359,6 +387,12 @@ async def user_chat_completions(request: Request, username: str, body: ChatCompl ...@@ -359,6 +387,12 @@ async def user_chat_completions(request: Request, username: str, body: ChatCompl
return await handler.handle_autoselect_request(actual_model, body_dict, user_id, token_id) return await handler.handle_autoselect_request(actual_model, body_dict, user_id, token_id)
if provider_id == "user-rotation": if provider_id == "user-rotation":
handler = _get_user_handler('rotation', user_id) handler = _get_user_handler('rotation', user_id)
market_reference = await _resolve_runtime_market_reference(handler, actual_model, user_id, 'rotation')
if market_reference:
owner_handler = _get_user_handler('rotation', market_reference['owner_user_id'])
body_dict['model'] = market_reference['source_id']
token_id = getattr(request.state, 'token_id', None)
return await owner_handler.handle_rotation_request(market_reference['source_id'], body_dict, user_id, token_id)
if actual_model not in handler.rotations: if actual_model not in handler.rotations:
raise HTTPException(status_code=400, detail=f"User rotation '{actual_model}' not found. Available: {list(handler.rotations.keys())}") raise HTTPException(status_code=400, detail=f"User rotation '{actual_model}' not found. Available: {list(handler.rotations.keys())}")
body_dict['model'] = actual_model body_dict['model'] = actual_model
...@@ -366,6 +400,13 @@ async def user_chat_completions(request: Request, username: str, body: ChatCompl ...@@ -366,6 +400,13 @@ async def user_chat_completions(request: Request, username: str, body: ChatCompl
return await handler.handle_rotation_request(actual_model, body_dict, user_id, token_id) return await handler.handle_rotation_request(actual_model, body_dict, user_id, token_id)
if provider_id == "user-provider": if provider_id == "user-provider":
handler = _get_user_handler('request', user_id) handler = _get_user_handler('request', user_id)
market_reference = await _resolve_runtime_market_reference(handler, actual_model, user_id, 'provider')
if market_reference:
owner_handler = _get_user_handler('request', market_reference['owner_user_id'])
body_dict['model'] = market_reference['source_id']
if body.stream:
return await owner_handler.handle_streaming_chat_completion(request, market_reference['source_id'], body_dict)
return await owner_handler.handle_chat_completion(request, market_reference['source_id'], body_dict)
if actual_model not in handler.user_providers: if actual_model not in handler.user_providers:
raise HTTPException(status_code=400, detail=f"User provider '{actual_model}' not found. Available: {list(handler.user_providers.keys())}") raise HTTPException(status_code=400, detail=f"User provider '{actual_model}' not found. Available: {list(handler.user_providers.keys())}")
body_dict['model'] = actual_model body_dict['model'] = actual_model
......
...@@ -234,7 +234,7 @@ def test_ensure_market_enabled_respects_market_settings_enabled_contract(monkeyp ...@@ -234,7 +234,7 @@ def test_ensure_market_enabled_respects_market_settings_enabled_contract(monkeyp
assert settings["enabled"] is False assert settings["enabled"] is False
def test_admin_payment_settings_embeds_market_admin_controls_and_listings(monkeypatch): def test_admin_payment_settings_links_to_dedicated_market_admin_page(monkeypatch):
db = MarketSettingsDbStub() db = MarketSettingsDbStub()
templates = TemplateCapture() templates = TemplateCapture()
client = TestClient(app) client = TestClient(app)
...@@ -248,10 +248,10 @@ def test_admin_payment_settings_embeds_market_admin_controls_and_listings(monkey ...@@ -248,10 +248,10 @@ def test_admin_payment_settings_embeds_market_admin_controls_and_listings(monkey
assert response.status_code == 200 assert response.status_code == 200
assert templates.calls[-1]["name"] == "dashboard/admin_payment_settings.html" assert templates.calls[-1]["name"] == "dashboard/admin_payment_settings.html"
assert templates.calls[-1]["context"]["market_settings"]["enabled"] is True assert templates.calls[-1]["context"]["market_settings"]["enabled"] is True
assert templates.calls[-1]["context"]["market_listings"][0]["title"] == "Flux Dev Pack" assert "market_listings" not in templates.calls[-1]["context"]
assert "Enable market" in response.text assert "Enable market" in response.text
assert "Market Administration" in response.text assert "Market Administration" in response.text
assert "Flux Dev Pack" in response.text assert "Open Market Administration" in response.text
def test_base_nav_hides_market_admin_link(monkeypatch): def test_base_nav_hides_market_admin_link(monkeypatch):
......
...@@ -2,11 +2,13 @@ import json ...@@ -2,11 +2,13 @@ import json
import sys import sys
from base64 import b64encode from base64 import b64encode
from pathlib import Path from pathlib import Path
import pytest
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from itsdangerous import TimestampSigner from itsdangerous import TimestampSigner
from jinja2 import Environment, FileSystemLoader, select_autoescape from jinja2 import Environment, FileSystemLoader, select_autoescape
from aisbf.models import ChatCompletionRequest
from aisbf.routes.dashboard import market as dashboard_market from aisbf.routes.dashboard import market as dashboard_market
...@@ -165,6 +167,13 @@ class MarketReferenceImportDbStub: ...@@ -165,6 +167,13 @@ class MarketReferenceImportDbStub:
def get_market_settings(self): def get_market_settings(self):
return dict(self.market_settings) return dict(self.market_settings)
def get_user_by_id(self, user_id):
if user_id == 11:
return {"id": 11, "username": "buyer"}
if user_id == 7:
return {"id": 7, "username": "seller"}
return None
def get_market_listing(self, listing_id): def get_market_listing(self, listing_id):
listings = { listings = {
self.listing["id"]: self.listing, self.listing["id"]: self.listing,
...@@ -243,6 +252,43 @@ class RegistryStub: ...@@ -243,6 +252,43 @@ class RegistryStub:
return self._db return self._db
class RuntimeUserHandlerStub:
def __init__(self):
self.calls = []
self.user_providers = {}
self.rotations = {}
self.autoselects = {}
self.user_autoselects = {}
async def handle_chat_completion(self, request, provider_id, body_dict):
self.calls.append(("provider", provider_id, body_dict))
return {"ok": True, "provider_id": provider_id, "model": body_dict.get("model")}
async def handle_rotation_request(self, rotation_id, body_dict, user_id, token_id):
self.calls.append(("rotation", rotation_id, body_dict, user_id, token_id))
return {"ok": True, "rotation_id": rotation_id}
async def handle_autoselect_request(self, autoselect_id, body_dict, user_id, token_id):
self.calls.append(("autoselect", autoselect_id, body_dict, user_id, token_id))
return {"ok": True, "autoselect_id": autoselect_id}
class MarketReferenceRuntimeDbStub(MarketReferenceImportDbStub):
def admin_set_market_listing_active(self, listing_id, is_active):
listing = self.get_market_listing(listing_id)
if not listing:
return False
if listing_id == self.listing["id"]:
self.listing["is_active"] = bool(is_active)
elif listing_id == self.rotation_listing["id"]:
self.rotation_listing["is_active"] = bool(is_active)
elif listing_id == self.model_listing["id"]:
self.model_listing["is_active"] = bool(is_active)
elif listing_id == self.autoselect_listing["id"]:
self.autoselect_listing["is_active"] = bool(is_active)
return True
def _find_session_secret() -> str: def _find_session_secret() -> str:
for middleware in app.user_middleware: for middleware in app.user_middleware:
kwargs = getattr(middleware, "kwargs", {}) kwargs = getattr(middleware, "kwargs", {})
...@@ -272,6 +318,155 @@ def _login_as_user(client: TestClient, user_id: int = 11) -> None: ...@@ -272,6 +318,155 @@ def _login_as_user(client: TestClient, user_id: int = 11) -> None:
) )
def _login_user_api_request(client: TestClient, username: str = "buyer", user_id: int = 11) -> None:
_set_session_cookie(
client,
{
"logged_in": True,
"username": username,
"role": "user",
"user_id": user_id,
"expires_at": 4102444800,
},
)
@pytest.fixture
def runtime_fixture(monkeypatch):
db = MarketReferenceRuntimeDbStub()
reference_id = db.create_market_import_reference(
user_id=11,
listing_id=db.listing["id"],
reference_type="provider",
display_name=db.listing["title"],
owner_username=db.listing["owner_username"],
source_type=db.listing["source_type"],
source_id=db.listing["source_id"],
)
monkeypatch.setattr(dashboard_market, "DatabaseRegistry", RegistryStub(db))
return {
"db": db,
"handler": dashboard_market,
"reference_id": reference_id,
"buyer_user_id": 11,
"listing_id": db.listing["id"],
"seller_user_id": db.listing["owner_user_id"],
}
@pytest.mark.asyncio
async def test_market_reference_resolves_to_source_listing_at_runtime(runtime_fixture):
handler = runtime_fixture["handler"]
resolved = await handler.resolve_market_reference(runtime_fixture["reference_id"], runtime_fixture["buyer_user_id"])
assert resolved["listing_id"] == runtime_fixture["listing_id"]
assert resolved["owner_user_id"] == runtime_fixture["seller_user_id"]
assert resolved["source_id"] == "seller-provider"
@pytest.mark.asyncio
async def test_market_reference_rejects_disabled_listing(runtime_fixture):
runtime_fixture["db"].admin_set_market_listing_active(runtime_fixture["listing_id"], False)
handler = runtime_fixture["handler"]
with pytest.raises(ValueError, match="unavailable"):
await handler.resolve_market_reference(runtime_fixture["reference_id"], runtime_fixture["buyer_user_id"])
class DirectRequestStub:
def __init__(self, user_id: int):
self.state = type("State", (), {})()
self.state.user_id = user_id
self.state.is_admin = False
self.state.is_global_token = False
self.state.token_id = None
@pytest.mark.asyncio
async def test_user_provider_market_reference_executes_via_seller_handler(monkeypatch):
db = MarketReferenceRuntimeDbStub()
buyer_handler = RuntimeUserHandlerStub()
seller_handler = RuntimeUserHandlerStub()
reference_id = db.create_market_import_reference(
user_id=11,
listing_id=db.listing["id"],
reference_type="provider",
display_name=db.listing["title"],
owner_username=db.listing["owner_username"],
source_type=db.listing["source_type"],
source_id=db.listing["source_id"],
)
monkeypatch.setattr(dashboard_market, "DatabaseRegistry", RegistryStub(db))
from aisbf.routes import user_api
monkeypatch.setattr(user_api, "DatabaseRegistry", RegistryStub(db))
def fake_get_user_handler(kind, user_id=None):
if kind == "request" and user_id in (None, 11):
buyer_handler.user_providers = {f"market-ref:{reference_id}": {"market_reference": True}}
return buyer_handler
if kind == "request" and user_id == 7:
return seller_handler
raise AssertionError((kind, user_id))
monkeypatch.setattr(user_api, "_get_user_handler", fake_get_user_handler)
result = await user_api.user_chat_completions(
DirectRequestStub(11),
"buyer",
ChatCompletionRequest(model=f"user-provider/market-ref:{reference_id}", messages=[{"role": "user", "content": "hi"}]),
)
assert result["provider_id"] == "seller-provider"
assert len(seller_handler.calls) == 1
call_kind, provider_id, payload = seller_handler.calls[0]
assert call_kind == "provider"
assert provider_id == "seller-provider"
assert payload["model"] == "seller-provider"
assert payload["messages"][0]["content"] == "hi"
@pytest.mark.asyncio
async def test_user_provider_market_reference_rejects_unavailable_listing_at_runtime(monkeypatch):
db = MarketReferenceRuntimeDbStub()
buyer_handler = RuntimeUserHandlerStub()
reference_id = db.create_market_import_reference(
user_id=11,
listing_id=db.listing["id"],
reference_type="provider",
display_name=db.listing["title"],
owner_username=db.listing["owner_username"],
source_type=db.listing["source_type"],
source_id=db.listing["source_id"],
)
db.admin_set_market_listing_active(db.listing["id"], False)
monkeypatch.setattr(dashboard_market, "DatabaseRegistry", RegistryStub(db))
from aisbf.routes import user_api
monkeypatch.setattr(user_api, "DatabaseRegistry", RegistryStub(db))
def fake_get_user_handler(kind, user_id=None):
if kind == "request" and user_id in (None, 11):
buyer_handler.user_providers = {f"market-ref:{reference_id}": {"market_reference": True}}
return buyer_handler
raise AssertionError((kind, user_id))
monkeypatch.setattr(user_api, "_get_user_handler", fake_get_user_handler)
with pytest.raises(Exception) as exc_info:
await user_api.user_chat_completions(
DirectRequestStub(11),
"buyer",
ChatCompletionRequest(model=f"user-provider/market-ref:{reference_id}", messages=[{"role": "user", "content": "hi"}]),
)
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Market reference unavailable"
def _seed_dashboard_market_reference_mix(db: MarketReferenceImportDbStub) -> None: def _seed_dashboard_market_reference_mix(db: MarketReferenceImportDbStub) -> None:
provider_reference = { provider_reference = {
"id": 1, "id": 1,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment