feat(studio): add shared capability inference service

parent 0cfd0f54
......@@ -37,6 +37,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from .models import ChatCompletionRequest, ChatCompletionResponse
from .providers import get_provider_handler, RateLimitError
from .config import config
from .studio import infer_model_capabilities
from .utils import (
count_messages_tokens,
split_messages_into_chunks,
......@@ -1669,139 +1670,8 @@ class RequestHandler:
return 4096
def _detect_capabilities(self, model_name: str, provider_type: str) -> List[str]:
    """Auto-detect model capabilities based on model name and provider type.

    Delegates to ``infer_model_capabilities`` (imported from ``.studio``),
    forwarding only ``model_name`` and ``provider_type``.

    Args:
        model_name: Model identifier to inspect.
        provider_type: Provider type, passed through unchanged.

    Returns:
        The ``capabilities`` list of the result object returned by
        ``infer_model_capabilities``.
    """
    # A handler-local keyword table used to run first and return early, which
    # made this delegation unreachable. Only the delegation is kept.
    # NOTE(review): the service reports its own identifiers (e.g. "chat"
    # where the local table produced "t2t") -- confirm callers accept them.
    result = infer_model_capabilities(model_name=model_name, provider_type=provider_type)
    return result.capabilities
async def handle_generic_proxy(self, request: Request, provider_id: str, endpoint_path: str, body: dict, method: str = "POST") -> JSONResponse:
"""Forward a request to the provider's native endpoint and return the response."""
......@@ -3988,42 +3858,8 @@ class RotationHandler:
return 4096
def _detect_capabilities(self, model_name: str, provider_type: str) -> List[str]:
    """Auto-detect model capabilities based on model name and provider type.

    Delegates to ``infer_model_capabilities`` (imported from ``.studio``),
    forwarding only ``model_name`` and ``provider_type``.

    Args:
        model_name: Model identifier to inspect.
        provider_type: Provider type, passed through unchanged.

    Returns:
        The ``capabilities`` list of the result object returned by
        ``infer_model_capabilities``.
    """
    # The previous inline keyword checks returned before this call could run,
    # leaving the delegation as dead code. They are removed so the call executes.
    # NOTE(review): the service reports its own identifiers (e.g. "chat"
    # where the inline checks produced "t2t") -- confirm callers accept them.
    result = infer_model_capabilities(model_name=model_name, provider_type=provider_type)
    return result.capabilities
class AutoselectHandler:
def __init__(self, user_id=None):
......
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional
# Lookup table from capability identifiers to Studio capability names.
# normalize_capabilities() consults it with ``.get(value, value)``, so any
# identifier that is not a key here passes through unchanged.
STUDIO_CAPABILITY_MAP = {
    # Text and image identifiers; "vision" and "i2t" share one target.
    "t2t": "chat",
    "vision": "vision",
    "i2t": "vision",
    "t2i": "image_generation",
    "i2i": "image_edit",
    # All three *2v identifiers map to the single "video_generation" name.
    "t2v": "video_generation",
    "i2v": "video_generation",
    "v2v": "video_generation",
    "v2t": "video_understanding",
    # Audio identifiers.
    "a2t": "audio_input",
    "transcription": "transcription",
    "tts": "speech_generation",
    "t2a": "audio_generation",
    "a2a": "audio_generation",
    "embeddings": "embeddings",
    "function_calling": "tool_use",
    # Every entry from here on maps an identifier to itself.
    "reasoning": "reasoning",
    "code_generation": "code_generation",
    "code_completion": "code_completion",
    "translation": "translation",
    "summarization": "summarization",
    "classification": "classification",
    "sentiment_analysis": "sentiment_analysis",
    "ner": "ner",
    "question_answering": "question_answering",
    "search": "search",
    "moderation": "moderation",
    "fine_tuning": "fine_tuning",
    "multimodal": "multimodal",
    "ocr": "ocr",
    "image_captioning": "image_captioning",
    "object_detection": "object_detection",
    "segmentation": "segmentation",
    "3d_generation": "3d_generation",
    "animation": "animation",
}
# Provider types for which infer_model_capabilities() falls back to ["chat"]
# when it ends up with no capabilities at all. Membership is an exact,
# case-sensitive string match.
DEFAULT_CHAT_PROVIDER_TYPES = {"openai", "anthropic", "google", "kilo", "claude", "qwen", "codex"}
@dataclass
class StudioCapabilityResult:
    """Outcome of a capability inference, as built by infer_model_capabilities()."""

    # Capability names attributed to the model.
    capabilities: List[str]
    # Where the capabilities came from; infer_model_capabilities() sets one of
    # "explicit", "provider_metadata", "heuristic" or "fallback".
    source: str
    # True when no capability was found and no fallback was applied.
    unknown: bool
    # Free-text remarks; infer_model_capabilities() adds one when nothing
    # could be inferred.
    notes: List[str]
@dataclass
class StudioCapabilityMergeResult:
    """Outcome of combining two capability lists, as built by merge_capabilities()."""

    # Capabilities retained after the merge.
    capabilities: List[str]
    # In "intersection" mode: base capabilities that did not make it into
    # ``capabilities``. merge_capabilities() leaves this empty otherwise.
    partial_capabilities: List[str]
def _dedupe(values: Iterable[str]) -> List[str]:
seen: List[str] = []
for value in values:
if value and value not in seen:
seen.append(value)
return seen
def normalize_capabilities(values: Optional[Iterable[str]]) -> List[str]:
    """Translate capability identifiers to Studio names and drop duplicates.

    Each identifier is looked up in ``STUDIO_CAPABILITY_MAP``; identifiers
    without an entry are kept unchanged. Order of first appearance is preserved.

    Args:
        values: Capability identifiers, a single identifier given as a plain
            string, or ``None``.

    Returns:
        The de-duplicated list of mapped names; empty for ``None`` or empty input.
    """
    if not values:
        return []
    if isinstance(values, str):
        # A bare string is itself iterable; without this guard "chat" would be
        # processed as the four "capabilities" "c", "h", "a", "t".
        values = [values]
    return _dedupe(STUDIO_CAPABILITY_MAP.get(value, value) for value in values)
def infer_model_capabilities(
    model_name: str,
    provider_type: str,
    explicit_capabilities: Optional[Iterable[str]] = None,
    architecture: Optional[Dict[str, Any]] = None,
    provider_metadata: Optional[Dict[str, Any]] = None,
) -> StudioCapabilityResult:
    """Work out which Studio capabilities a model offers.

    Sources are tried in this order:

    1. ``explicit_capabilities`` -- if any survive normalization they are
       returned as-is with source ``"explicit"``.
    2. ``provider_metadata["capabilities"]`` -- source ``"provider_metadata"``.
    3. Substring matches on the lower-cased model name -- source
       ``"heuristic"``; only used when step 2 produced nothing.

    For steps 2 and 3 the ``input_modalities`` / ``output_modalities`` entries
    of ``architecture`` add further capabilities. If the list is still empty, a
    note is recorded and, for provider types in ``DEFAULT_CHAT_PROVIDER_TYPES``,
    the result falls back to ``["chat"]`` with source ``"fallback"``.

    Args:
        model_name: Model identifier; ``None`` or empty is treated as ``""``.
        provider_type: Provider type, used for the fallback and in the note.
        explicit_capabilities: Capabilities declared by the caller, if any.
        architecture: Optional mapping read for the two modality lists.
        provider_metadata: Optional mapping read for a ``"capabilities"`` entry.

    Returns:
        A ``StudioCapabilityResult`` with the capabilities, their source, the
        ``unknown`` flag and any notes.
    """
    declared = normalize_capabilities(explicit_capabilities)
    if declared:
        return StudioCapabilityResult(capabilities=declared, source="explicit", unknown=False, notes=[])

    metadata = provider_metadata or {}
    found = normalize_capabilities(metadata.get("capabilities"))
    origin = "provider_metadata" if found else "heuristic"

    lowered = (model_name or "").lower()
    modalities = architecture or {}
    inputs = modalities.get("input_modalities") or []
    outputs = modalities.get("output_modalities") or []

    def mentions(*tokens: str) -> bool:
        """Tell whether any token occurs as a substring of the model name."""
        return any(token in lowered for token in tokens)

    if not found:
        # Anything not recognizably a non-chat model is assumed to chat.
        # NOTE(review): each token in this exclusion list also fires one of the
        # rules below, so this branch appears to always yield something and
        # the empty-result handling further down looks unreachable -- confirm.
        if not mentions("embedding", "embed", "whisper", "tts", "dall-e", "stable-diffusion"):
            found.append("chat")
        # (name tokens, capabilities granted) -- evaluated in this order.
        keyword_rules = (
            (("vision", "gpt-4-turbo", "gpt-4o", "claude-3", "gemini-1.5", "gemini-2.0", "gemini-pro-vision", "llava", "blip"), ("vision",)),
            (("dall-e", "dalle", "stable-diffusion", "sd-", "sdxl", "midjourney", "imagen", "flux"), ("image_generation",)),
            (("stable-diffusion", "sd-", "sdxl", "controlnet", "img2img"), ("image_edit",)),
            (("sora", "runway", "pika", "text-to-video", "t2v", "video"), ("video_generation",)),
            (("whisper", "transcribe", "speech-to-text", "stt"), ("audio_input", "transcription")),
            (("tts", "text-to-speech", "elevenlabs", "bark", "tortoise", "speech"), ("speech_generation",)),
            (("musicgen", "audiogen", "riffusion", "a2a"), ("audio_generation",)),
            (("embedding", "embed", "ada-002", "bge", "e5", "instructor"), ("embeddings",)),
            (("gpt-4", "gpt-3.5-turbo", "claude-3", "gemini", "function", "tool"), ("tool_use",)),
            (("reasoning", "cot", "o1", "o3"), ("reasoning",)),
        )
        for tokens, granted in keyword_rules:
            if mentions(*tokens):
                found.extend(granted)

    # Declared modalities contribute regardless of where the list came from.
    if "image" in inputs:
        found.append("vision")
    if "audio" in inputs:
        found.append("audio_input")
        if "text" in outputs:
            found.append("transcription")
    if "audio" in outputs:
        found.append("speech_generation")

    found = _dedupe(found)
    remarks: List[str] = []
    nothing_found = not found
    if nothing_found:
        remarks.append(f"No confident Studio capabilities inferred for {provider_type}:{model_name}")
        if provider_type in DEFAULT_CHAT_PROVIDER_TYPES:
            found = ["chat"]
            origin = "fallback"
            nothing_found = False

    return StudioCapabilityResult(
        capabilities=found,
        source=origin,
        unknown=nothing_found,
        notes=remarks,
    )
def merge_capabilities(
    base_capabilities: Optional[Iterable[str]],
    override_capabilities: Optional[Iterable[str]],
    support_mode: str = "union",
) -> StudioCapabilityMergeResult:
    """Combine a base capability list with an override list.

    Both inputs are normalized first. With ``support_mode="intersection"`` and
    a non-empty override, only override capabilities that also appear in the
    base are kept (in override order) and the remaining base capabilities are
    reported as partial. In every other case the result is the de-duplicated
    concatenation of base then override, with no partial capabilities.

    Args:
        base_capabilities: Starting capabilities, or ``None``.
        override_capabilities: Capabilities to merge in, or ``None``.
        support_mode: ``"intersection"`` selects the restrictive merge; any
            other value behaves as a union.

    Returns:
        A ``StudioCapabilityMergeResult`` with the merged and partial lists.
    """
    base = normalize_capabilities(base_capabilities)
    override = normalize_capabilities(override_capabilities)

    restrict = support_mode == "intersection" and override
    if not restrict:
        combined = _dedupe(base + override)
        return StudioCapabilityMergeResult(capabilities=combined, partial_capabilities=[])

    shared = [name for name in override if name in base]
    leftover = [name for name in base if name not in shared]
    return StudioCapabilityMergeResult(capabilities=shared, partial_capabilities=leftover)
def build_catalog_entry(
    scope: str,
    owner_id: Optional[int],
    kind: str,
    source_id: str,
    target_id: str,
    label: str,
    description: Optional[str],
    capabilities: Optional[Iterable[str]],
    availability_state: str,
    availability_reason: Optional[str],
    metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble a catalog entry dictionary from its individual fields.

    The entry ``"id"`` is ``provider/<source_id>/<target_id>``, ``scope`` is
    stored under ``"owner_scope"``, ``capabilities`` are passed through
    ``normalize_capabilities`` and a missing ``metadata`` becomes ``{}``.
    All other arguments are stored under keys of the same name.

    Returns:
        The entry as a plain dictionary.
    """
    entry_id = f"provider/{source_id}/{target_id}"
    normalized = normalize_capabilities(capabilities)
    extra = metadata or {}

    entry: Dict[str, Any] = {
        "id": entry_id,
        "kind": kind,
        "owner_scope": scope,
        "owner_id": owner_id,
        "source_id": source_id,
        "target_id": target_id,
        "label": label,
        "description": description,
        "capabilities": normalized,
        "availability_state": availability_state,
        "availability_reason": availability_reason,
        "metadata": extra,
    }
    return entry
import pytest
from aisbf.studio import (
StudioCapabilityResult,
build_catalog_entry,
infer_model_capabilities,
merge_capabilities,
)
def test_infer_model_capabilities_prefers_explicit_capabilities():
    """Explicit capabilities win over name and architecture hints."""
    declared = ["chat", "vision"]
    hints = {"input_modalities": ["text", "image"], "output_modalities": ["text"]}

    outcome = infer_model_capabilities(
        model_name="gpt-4o",
        provider_type="openai",
        explicit_capabilities=declared,
        architecture=hints,
    )

    assert isinstance(outcome, StudioCapabilityResult)
    assert outcome.capabilities == ["chat", "vision"]
    assert outcome.source == "explicit"
    assert outcome.unknown is False
def test_infer_model_capabilities_uses_name_and_architecture_heuristics_when_explicit_missing():
    """Without explicit capabilities, an audio-in/text-out model is a transcriber."""
    hints = {"input_modalities": ["audio"], "output_modalities": ["text"]}

    outcome = infer_model_capabilities(
        model_name="whisper-large-v3",
        provider_type="openai",
        explicit_capabilities=None,
        architecture=hints,
    )

    for expected in ("audio_input", "transcription"):
        assert expected in outcome.capabilities
    assert outcome.source in {"provider_metadata", "heuristic"}
def test_merge_capabilities_keeps_explicit_values_and_reports_partial_support():
    """Intersection mode keeps shared capabilities and lists the rest as partial."""
    base = ["chat", "vision", "image_generation"]
    override = ["chat", "vision"]

    outcome = merge_capabilities(
        base_capabilities=base,
        override_capabilities=override,
        support_mode="intersection",
    )

    assert outcome.capabilities == ["chat", "vision"]
    assert outcome.partial_capabilities == ["image_generation"]
def test_build_catalog_entry_normalizes_provider_model_payload():
    """A provider-model payload is turned into the expected catalog entry."""
    payload = dict(
        scope="user",
        owner_id=5,
        kind="provider_model",
        source_id="openai",
        target_id="gpt-4o",
        label="GPT-4o",
        description="General multimodal model",
        capabilities=["chat", "vision"],
        availability_state="ready",
        availability_reason=None,
        metadata={"context_length": 128000},
    )

    entry = build_catalog_entry(**payload)

    assert entry["id"] == "provider/openai/gpt-4o"
    assert entry["kind"] == "provider_model"
    assert entry["owner_scope"] == "user"
    assert entry["owner_id"] == 5
    assert entry["capabilities"] == ["chat", "vision"]
    assert entry["metadata"]["context_length"] == 128000
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment