fix(studio): restore capability inference coverage

parent 8628bdc3
......@@ -43,6 +43,23 @@ STUDIO_CAPABILITY_MAP = {
}
DEFAULT_CHAT_PROVIDER_TYPES = {"openai", "anthropic", "google", "kilo", "claude", "qwen", "codex"}
NON_CHAT_MEDIA_TOKENS = {
"dall-e",
"dalle",
"stable-diffusion",
"sd-",
"sdxl",
"midjourney",
"imagen",
"flux",
"sora",
"runway",
"pika",
"text-to-video",
"t2v",
"video-llama",
"video-chat",
}
@dataclass
......@@ -96,7 +113,7 @@ def infer_model_capabilities(
output_modalities = architecture.get("output_modalities") or []
if not capabilities:
if not any(token in name for token in ["embedding", "embed", "whisper", "tts", "dall-e", "stable-diffusion"]):
if not any(token in name for token in ["embedding", "embed", "whisper", "tts", *NON_CHAT_MEDIA_TOKENS]):
capabilities.append("chat")
if any(token in name for token in ["vision", "gpt-4-turbo", "gpt-4o", "claude-3", "gemini-1.5", "gemini-2.0", "gemini-pro-vision", "llava", "blip"]):
capabilities.append("vision")
......@@ -104,8 +121,10 @@ def infer_model_capabilities(
capabilities.append("image_generation")
if any(token in name for token in ["stable-diffusion", "sd-", "sdxl", "controlnet", "img2img"]):
capabilities.append("image_edit")
if any(token in name for token in ["sora", "runway", "pika", "text-to-video", "t2v", "video"]):
if any(token in name for token in ["sora", "runway", "pika", "text-to-video", "t2v"]):
capabilities.append("video_generation")
if any(token in name for token in ["video-llama", "video-chat", "v2t"]):
capabilities.append("video_understanding")
if any(token in name for token in ["whisper", "transcribe", "speech-to-text", "stt"]):
capabilities.extend(["audio_input", "transcription"])
if any(token in name for token in ["tts", "text-to-speech", "elevenlabs", "bark", "tortoise", "speech"]):
......@@ -116,6 +135,8 @@ def infer_model_capabilities(
capabilities.append("embeddings")
if any(token in name for token in ["gpt-4", "gpt-3.5-turbo", "claude-3", "gemini", "function", "tool"]):
capabilities.append("tool_use")
if any(token in name for token in ["codex", "code-", "starcoder", "codellama", "deepseek-coder", "phind"]):
capabilities.extend(["code_generation", "code_completion"])
if any(token in name for token in ["reasoning", "cot", "o1", "o3"]):
capabilities.append("reasoning")
......
......@@ -67,3 +67,33 @@ def test_build_catalog_entry_normalizes_provider_model_payload():
assert entry["owner_id"] == 5
assert entry["capabilities"] == ["chat", "vision"]
assert entry["metadata"]["context_length"] == 128000
def test_infer_model_capabilities_restores_code_model_families():
result = infer_model_capabilities(
model_name="deepseek-coder-33b-instruct",
provider_type="openai",
)
assert "code_generation" in result.capabilities
assert "code_completion" in result.capabilities
def test_infer_model_capabilities_does_not_treat_video_understanding_models_as_generation():
result = infer_model_capabilities(
model_name="video-llama-2",
provider_type="openai",
)
assert "video_understanding" in result.capabilities
assert "video_generation" not in result.capabilities
@pytest.mark.parametrize("model_name", ["dalle-3", "runway-gen3"])
def test_infer_model_capabilities_does_not_fallback_to_chat_for_media_models(model_name):
result = infer_model_capabilities(
model_name=model_name,
provider_type="openai",
)
assert "chat" not in result.capabilities
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment