Fix coderai support

parent ce6a2513
......@@ -23,7 +23,9 @@ def _is_broker_only_coderai(provider_config) -> bool:
coderai_config = getattr(provider_config, 'coderai_config', None) or {}
if not isinstance(coderai_config, dict):
return False
return bool(coderai_config.get('broker_mode', False))
# broker_preferred also defers to the incoming WSS connection; direct fallback
# is blocked, so there is no point attempting a prefetch before a session exists.
return bool(coderai_config.get('broker_mode', False) or coderai_config.get('broker_preferred', False))
def _is_runpod_public_provider(provider_config) -> bool:
......
......@@ -403,9 +403,12 @@ class CoderAIBroker:
candidates.sort(key=lambda session: session.last_seen, reverse=True)
return candidates[0]
# Session exists on another broker node: we have no local WebSocket handle,
# so we cannot deliver requests directly. Return None; send_request will
# route through the Redis queue to the owning node.
snapshot = await self.get_session_snapshot(provider_id, client_id)
if not snapshot or not snapshot.get("connected"):
return None
if snapshot and snapshot.get("connected"):
logger.debug(f"CoderAI session for provider={provider_id} is connected on a remote broker node")
return None
async def list_sessions(self) -> list[Dict[str, Any]]:
......
......@@ -463,6 +463,8 @@ class CoderAIProviderHandler(BaseProviderHandler):
raise Exception(message.get("error") or "CoderAI WebSocket request failed")
self.record_success()
return message.get("payload") or {}
if self._broker_preferred:
raise RuntimeError(f"[{self.provider_id}] No active CoderAI broker session; direct fallback not allowed with broker_preferred=True")
response = self.client.chat.completions.create(**payload)
self.record_success()
......@@ -484,6 +486,8 @@ class CoderAIProviderHandler(BaseProviderHandler):
if (message.get("status") or "ok") == "error":
raise Exception(message.get("error") or "CoderAI model discovery failed")
return self._extract_models(message.get("payload") or {})
if self._broker_preferred:
raise RuntimeError(f"[{self.provider_id}] No active CoderAI broker session; direct fallback not allowed with broker_preferred=True")
models = self.client.models.list()
payload = {"data": [m.model_dump() if hasattr(m, "model_dump") else m for m in models]}
return self._extract_models(payload)
......@@ -504,8 +508,8 @@ class CoderAIProviderHandler(BaseProviderHandler):
if (message.get("status") or "ok") == "error":
raise Exception(message.get("error") or "CoderAI capability discovery failed")
return message.get("payload") or {}
if self._broker_mode:
raise Exception("CoderAI broker mode requires an active broker session")
if self._broker_mode or self._broker_preferred:
raise Exception(f"[{self.provider_id}] No active CoderAI broker session; direct fallback not allowed")
return await self._http_json("GET", "/coderai/capabilities", timeout=self._model_timeout)
async def register_client(self) -> Dict[str, Any]:
......@@ -525,8 +529,8 @@ class CoderAIProviderHandler(BaseProviderHandler):
if (message.get("status") or "ok") == "error":
raise Exception(message.get("error") or "CoderAI registration failed")
return message.get("payload") or {}
if self._broker_mode:
raise Exception("CoderAI broker mode does not support outbound registration")
if self._broker_mode or self._broker_preferred:
raise Exception(f"[{self.provider_id}] No active CoderAI broker session; direct fallback not allowed")
return await self._http_json("POST", self._registration_path, payload, timeout=self._model_timeout)
async def proxy_native_request(
......@@ -570,7 +574,7 @@ class CoderAIProviderHandler(BaseProviderHandler):
raise Exception(message.get("error") or "CoderAI proxy request failed")
envelope = message.get("payload") or {}
return int(envelope.get("status_code") or 200), envelope
if self._broker_mode:
raise Exception("CoderAI broker mode requires an active broker session")
if self._broker_mode or self._broker_preferred:
raise Exception(f"[{self.provider_id}] No active CoderAI broker session; direct fallback not allowed")
response = await self._http_json(method.upper(), endpoint_path, body or {}, timeout=self._request_timeout)
return 200, response
......@@ -18,6 +18,23 @@ router = APIRouter()
logger = logging.getLogger(__name__)
async def _broker_refresh_models(provider_id: str, user_id: Optional[int]) -> None:
"""Fetch and cache the model list for a provider that just connected via broker."""
try:
from aisbf.app.model_cache import fetch_provider_models
from aisbf.config import config as aisbf_config
models = await fetch_provider_models(provider_id, aisbf_config, user_id=user_id)
logger.info(
"CoderAI broker model refresh: provider=%s user=%s models=%d",
provider_id, user_id, len(models),
)
except Exception:
logger.warning(
"CoderAI broker model refresh failed for provider=%s user=%s",
provider_id, user_id, exc_info=True,
)
def _coderai_register_payload(
provider_id: str,
client_id: str,
......@@ -38,6 +55,30 @@ def _coderai_register_payload(
}
def _extract_register_metadata(
payload: dict,
owner_user_id: Optional[int],
username: str,
scope_name: str,
proxy_scheme: str,
) -> dict:
hardware = payload.get("hardware") or {}
return {
"endpoint": payload.get("endpoint"),
"transport": payload.get("transport"),
"studio_endpoints": payload.get("studio_endpoints") or [],
"hardware": hardware,
"gpus": payload.get("gpus") or hardware.get("gpus") or [],
"gpu_count": payload.get("gpu_count") or hardware.get("gpu_count"),
"total_vram_mb": payload.get("total_vram_mb") or hardware.get("total_vram_mb"),
"available_vram_mb": payload.get("available_vram_mb") or hardware.get("available_vram_mb"),
"owner_user_id": owner_user_id,
"username": username,
"scope_name": scope_name,
"proxy_scheme": proxy_scheme,
}
async def _coderai_broker_websocket_impl(websocket: WebSocket, scope_name: str):
provider_id = websocket.query_params.get("provider_id") or websocket.headers.get("x-coderai-provider-id") or "coderai"
client_id = websocket.query_params.get("client_id") or websocket.headers.get("x-coderai-client-id") or f"anon-{int(time.time())}"
......@@ -49,25 +90,59 @@ async def _coderai_broker_websocket_impl(websocket: WebSocket, scope_name: str):
return
await websocket.accept()
expected_scope = scope_name
session = await broker.register(websocket, provider_id, client_id, metadata={"source": "websocket", "owner_user_id": owner_user_id, "username": username, "scope_name": expected_scope, "proxy_scheme": websocket.url.scheme})
# Client speaks first: wait for op=register before creating the session.
# This ensures the session is stored with full capability and hardware metadata
# from the very beginning rather than via a follow-up touch().
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
except asyncio.TimeoutError:
await websocket.close(code=1008, reason="registration timeout: expected op=register")
return
first_msg = json.loads(raw)
if first_msg.get("op") != "register":
await websocket.close(code=1008, reason=f"expected op=register, got {first_msg.get('op')!r}")
return
first_payload = first_msg.get("payload") or {}
payload_token = first_payload.get("registration_token") or first_msg.get("registration_token")
if payload_token and payload_token != presented_token:
await websocket.close(code=1008, reason="registration token mismatch")
return
capabilities = first_payload.get("capabilities") or first_msg.get("capabilities") or {}
metadata = _extract_register_metadata(first_payload, owner_user_id, username, expected_scope, websocket.url.scheme)
session = await broker.register(
websocket, provider_id, client_id,
metadata=metadata, capabilities=capabilities,
)
# Respond to op=register with event=registered.
registered_payload = _coderai_register_payload(session.provider_id, session.client_id, username, owner_user_id, expected_scope)
registered_payload.update({"session_id": session.session_id, "status": "ok", "request_id": first_msg.get("request_id")})
await websocket.send_text(json.dumps(registered_payload))
logger.info(
"CoderAI broker registered provider=%s client=%s session_id=%s scope=%s",
provider_id, client_id, session.session_id, expected_scope,
)
# Populate the model cache in the background now that the session is live.
asyncio.create_task(_broker_refresh_models(session.provider_id, owner_user_id))
async def _drain_broker_queue() -> None:
while True:
queued = await broker.consume_request(session.session_id, timeout=1)
if queued is not None:
await websocket.send_text(json.dumps(queued))
await broker.touch(session.session_id, metadata={"proxy_scheme": websocket.url.scheme, "username": username, "scope_name": expected_scope})
try:
payload = _coderai_register_payload(session.provider_id, session.client_id, username, owner_user_id, expected_scope)
payload["session_id"] = session.session_id
await websocket.send_text(json.dumps(payload))
queue_task = asyncio.create_task(_drain_broker_queue())
try:
while True:
raw = await websocket.receive_text()
message = json.loads(raw)
op = message.get("op")
if op == "register":
# Re-registration during an active session (e.g. after model reload).
payload = message.get("payload") or {}
payload_token = payload.get("registration_token") or message.get("registration_token")
if payload_token and payload_token != presented_token:
......@@ -78,31 +153,13 @@ async def _coderai_broker_websocket_impl(websocket: WebSocket, scope_name: str):
"error": "Registration token mismatch",
}))
continue
capabilities = payload.get("capabilities") or message.get("capabilities") or {}
metadata = {
"endpoint": payload.get("endpoint"),
"transport": payload.get("transport"),
"studio_endpoints": payload.get("studio_endpoints") or [],
"hardware": payload.get("hardware") or {},
"gpus": payload.get("gpus") or ((payload.get("hardware") or {}).get("gpus")) or [],
"gpu_count": payload.get("gpu_count") or ((payload.get("hardware") or {}).get("gpu_count")),
"total_vram_mb": payload.get("total_vram_mb") or ((payload.get("hardware") or {}).get("total_vram_mb")),
"available_vram_mb": payload.get("available_vram_mb") or ((payload.get("hardware") or {}).get("available_vram_mb")),
"owner_user_id": owner_user_id,
"username": username,
"scope_name": expected_scope,
"proxy_scheme": websocket.url.scheme,
}
await broker.touch(session.session_id, metadata=metadata, capabilities=capabilities)
await websocket.send_text(json.dumps({
"v": 1,
"request_id": message.get("request_id"),
"status": "ok",
"payload": {
**_coderai_register_payload(session.provider_id, session.client_id, username, owner_user_id, expected_scope),
"session_id": session.session_id,
},
}))
new_caps = payload.get("capabilities") or message.get("capabilities") or {}
new_meta = _extract_register_metadata(payload, owner_user_id, username, expected_scope, websocket.url.scheme)
await broker.touch(session.session_id, metadata=new_meta, capabilities=new_caps)
re_payload = _coderai_register_payload(session.provider_id, session.client_id, username, owner_user_id, expected_scope)
re_payload.update({"session_id": session.session_id, "status": "ok", "request_id": message.get("request_id")})
await websocket.send_text(json.dumps(re_payload))
asyncio.create_task(_broker_refresh_models(session.provider_id, owner_user_id))
continue
if op == "heartbeat":
await broker.touch(session.session_id, metadata=message.get("payload") or {})
......@@ -121,7 +178,6 @@ async def _coderai_broker_websocket_impl(websocket: WebSocket, scope_name: str):
except Exception as e:
logger.error(f"CoderAI broker websocket error: {e}", exc_info=True)
finally:
if 'queue_task' in locals():
queue_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await queue_task
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment