Commit 69fe4af0 authored by Your Name's avatar Your Name

Add model capability detection

- Add ModelCapabilities dataclass to represent model capabilities
- Add detect_model_capabilities() function to detect:
  - text_generation (LLM)
  - vision (image understanding)
  - image_generation (Stable Diffusion)
  - speech_to_text (whisper)
  - text_to_speech (TTS)
- Use capability detection for better error messages in image generation endpoint
parent ad4ec2a5
......@@ -35,6 +35,79 @@ queue_flags = {"model_1": False, "image_1": False, "audio_1": False, "tts_1": Fa
# Model Cache Directory
# =============================================================================
# =============================================================================
# Model Capability Detection
# =============================================================================
from typing import Set, Optional
from dataclasses import dataclass, field
@dataclass
class ModelCapabilities:
    """Set of boolean flags describing what a model can do.

    Every flag defaults to ``False``. ``str()`` renders the enabled flags
    as a comma-separated list of short labels, or ``"none"`` when no flag
    is set.
    """

    text_generation: bool = False   # LLM/chat completion
    vision: bool = False            # Image understanding
    image_generation: bool = False  # Text-to-image (Stable Diffusion)
    speech_to_text: bool = False    # Audio transcription
    text_to_speech: bool = False    # Speech synthesis

    def __str__(self) -> str:
        """Return the enabled capabilities as ``"a, b, c"`` or ``"none"``."""
        # Pairs are listed in the order the labels must appear in the output.
        flag_labels = (
            (self.text_generation, "text"),
            (self.vision, "vision"),
            (self.image_generation, "image"),
            (self.speech_to_text, "speech-to-text"),
            (self.text_to_speech, "text-to-speech"),
        )
        enabled = [label for flag, label in flag_labels if flag]
        return ", ".join(enabled) or "none"
def detect_model_capabilities(model_name: str) -> ModelCapabilities:
    """
    Detect model capabilities based on model name/type.

    This is a heuristic detection - actual capabilities may vary. The
    lower-cased name is matched against keyword lists in priority order
    (image generation, vision, text-to-speech, speech-to-text) and the
    first matching category wins. Names that match nothing are assumed
    to be text generation models.

    Args:
        model_name: Model identifier to classify. An empty (or ``None``)
            name yields a result with no capabilities set.

    Returns:
        A ModelCapabilities instance with the detected flags enabled.
    """
    caps = ModelCapabilities()
    if not model_name:
        return caps

    name_lower = model_name.lower()
    is_whisper = 'whisper' in name_lower

    # Check for image generation models (Stable Diffusion, SDXL, etc.)
    # 'turbo' is ambiguous: Whisper checkpoints such as
    # 'whisper-large-v3-turbo' carry it too, so whisper names must not be
    # claimed by this check or they would never reach the whisper check.
    if not is_whisper and any(
        x in name_lower
        for x in ('stable-diffusion', 'sd15', 'sdxl', 'sd-xl', 'turbo', 'playground')
    ):
        caps.image_generation = True
        return caps  # Usually SD models are dedicated

    # Check for vision models
    if any(
        x in name_lower
        for x in ('vision', 'vl-', '-vl', 'llava', 'qwen2-vl', 'qwen-vl',
                  'phi-4-mini', 'pixtral', 'clip')
    ):
        caps.vision = True
        caps.text_generation = True  # Vision models are also LLMs
        return caps

    # Check for TTS models
    if any(x in name_lower for x in ('kokoro', 'tts', 'speech', 'voice')):
        caps.text_to_speech = True
        return caps

    # Check for whisper models (speech-to-text); 'whisper' also covers
    # the 'faster-whisper' and 'distil-whisper' variants.
    if is_whisper:
        caps.speech_to_text = True
        return caps

    # Default: assume text generation (most HF models are LLMs).
    # GGUF models are typically text models and land here as well.
    caps.text_generation = True
    return caps
def get_model_cache_dir() -> str:
"""Get or create the model cache directory."""
# Use XDG_CACHE_HOME if set, otherwise use ~/.cache/coderai
......@@ -3352,6 +3425,14 @@ async def create_image_generation(request: ImageGenerationRequest):
# If still no image model configured, return an error
if not image_model:
# Try to get capabilities of requested model for better error message
requested = request.model if request.model else "default"
caps = detect_model_capabilities(requested)
if caps.text_generation and not caps.image_generation:
raise HTTPException(
status_code=400,
detail=f"Model '{requested}' is a text generation model (capabilities: {caps}). Use --image-model to specify an image generation model like 'stable-diffusion-xl'."
)
raise HTTPException(
status_code=400,
detail="Image generation not configured. Use --image-model to specify a model."
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment