feat: Add multi-model support for audio transcription and image generation

- Add --audio-model and --image-model CLI arguments
- Add --loadall, --audio-ctx, --audio-offload, --vision-ctx, --vision-offload args
- Implement MultiModelManager class for dynamic model switching
- Add POST /v1/audio/transcriptions endpoint (OpenAI-compatible)
- Add POST /v1/images/generations endpoint (OpenAI-compatible)
- Update endpoints to use multi_model_manager for model selection
- Audio uses faster-whisper for local transcription
- Images use Stable Diffusion via diffusers
parent eb6b8d85
......@@ -171,6 +171,58 @@ class ModelList(BaseModel):
data: List[ModelInfo]
# =============================================================================
# Audio Transcription Models
# =============================================================================
class TranscriptionRequest(BaseModel):
model: str
file: Optional[bytes] = None
file_path: Optional[str] = None
language: Optional[str] = None
prompt: Optional[str] = None
response_format: Optional[str] = "json"
temperature: Optional[float] = 0.0
timestamp_granularities: Optional[List[str]] = None
class Config:
extra = "allow" # Allow extra fields to prevent 422 errors
class TranscriptionResponse(BaseModel):
text: str
class Config:
extra = "allow"
# =============================================================================
# Image Generation Models
# =============================================================================
class ImageGenerationRequest(BaseModel):
model: str
prompt: str
n: int = 1
size: Optional[str] = "1024x1024"
quality: Optional[str] = "standard"
style: Optional[str] = None
response_format: Optional[str] = "url"
seed: Optional[int] = None
user: Optional[str] = None
class Config:
extra = "allow" # Allow extra fields to prevent 422 errors
class ImageGenerationResponse(BaseModel):
created: int
data: List[Dict]
class Config:
extra = "allow"
# =============================================================================
# Content Filtering Utility
# =============================================================================
......@@ -1704,7 +1756,155 @@ class ModelManager:
self.backend = None
# Global model manager
# =============================================================================
# Multi-Model Manager (supports audio transcription and image generation)
# =============================================================================
class MultiModelManager:
"""
Manages multiple models: main text model, audio transcription, and image generation.
Supports dynamic switching based on request model name.
"""
def __init__(self):
self.models: Dict[str, ModelManager] = {}
self.default_model: Optional[str] = None
self.audio_model: Optional[str] = None
self.image_model: Optional[str] = None
self.tool_parser = ToolCallParser()
self.current_model_key: Optional[str] = None
# Configuration for each model type
self.config: Dict[str, Dict] = {}
def set_default_model(self, model_name: str, config: Dict = None):
"""Set the default/main text model."""
self.default_model = model_name
self.config[model_name] = config or {}
def set_audio_model(self, model_name: str, config: Dict = None):
"""Set the audio transcription model."""
self.audio_model = model_name
self.config[f"audio:{model_name}"] = config or {}
def set_image_model(self, model_name: str, config: Dict = None):
"""Set the image generation model."""
self.image_model = model_name
self.config[f"image:{model_name}"] = config or {}
def get_model_for_request(self, requested_model: str) -> Optional[ModelManager]:
"""
Get the appropriate model manager for a request based on model name.
Model name conventions:
- "default", empty, or matches default model -> use main model
- starts with "audio:" -> use audio model
- starts with "image:" -> use image model
"""
# Handle empty or "default" model names
if not requested_model or requested_model == "default":
if self.default_model and self.default_model in self.models:
self.current_model_key = self.default_model
return self.models[self.default_model]
return None
# Check for specialized models
if requested_model.startswith("audio:"):
audio_name = requested_model[6:] # Remove "audio:" prefix
key = f"audio:{audio_name}"
if key in self.models:
self.current_model_key = key
return self.models[key]
elif self.audio_model:
# Try loading audio model on demand
key = f"audio:{self.audio_model}"
return None # Signal that we need to load
return None
if requested_model.startswith("image:"):
image_name = requested_model[6:] # Remove "image:" prefix
key = f"image:{image_name}"
if key in self.models:
self.current_model_key = key
return self.models[key]
elif self.image_model:
# Try loading image model on demand
key = f"image:{self.image_model}"
return None # Signal that we need to load
return None
# Check if it's the default model
if self.default_model and (requested_model == self.default_model or
requested_model.endswith(self.default_model.split("/")[-1])):
self.current_model_key = self.default_model
return self.models.get(self.default_model)
# Check if any loaded model matches
for key, model in self.models.items():
if requested_model in key or key.endswith(requested_model.split("/")[-1]):
self.current_model_key = key
return model
return None
def add_model(self, key: str, manager: ModelManager):
"""Add a model manager for a specific key."""
self.models[key] = manager
def get_model(self, key: str) -> Optional[ModelManager]:
"""Get a model manager by key."""
return self.models.get(key)
def get_current_model(self) -> Optional[ModelManager]:
"""Get the currently active model."""
if self.current_model_key:
return self.models.get(self.current_model_key)
if self.default_model:
return self.models.get(self.default_model)
return None
def list_models(self) -> List[ModelInfo]:
"""List all available models."""
models = []
# Add default model
if self.default_model:
model_id = self.default_model
# Also add short name
short_name = self.default_model.split("/")[-1] if "/" in self.default_model else self.default_model
if short_name != self.default_model:
models.append(ModelInfo(id=short_name))
models.append(ModelInfo(id=model_id))
models.append(ModelInfo(id="default"))
# Add audio models
if self.audio_model:
audio_id = f"audio:{self.audio_model}"
models.append(ModelInfo(id=audio_id))
# Add image models
if self.image_model:
image_id = f"image:{self.image_model}"
models.append(ModelInfo(id=image_id))
# Add loaded models that aren't in the above categories
for key in self.models:
if key not in [self.default_model, f"audio:{self.audio_model}", f"image:{self.image_model}"]:
models.append(ModelInfo(id=key))
return models if models else [ModelInfo(id="default")]
def cleanup(self):
"""Cleanup all models."""
for model in self.models.values():
model.cleanup()
self.models.clear()
# Global multi-model manager
multi_model_manager = MultiModelManager()
# Global model manager (for backward compatibility)
model_manager = ModelManager()
# Global system prompt (set via --system-prompt flag)
......@@ -1722,6 +1922,7 @@ async def lifespan(app: FastAPI):
# Startup
yield
# Shutdown
multi_model_manager.cleanup()
model_manager.cleanup()
......@@ -1853,19 +2054,267 @@ async def log_requests(request: Request, call_next):
@app.get("/v1/models", response_model=ModelList)
async def list_models():
"""List available models."""
models = []
if model_manager.model_name:
models.append(ModelInfo(id=model_manager.model_name))
else:
models.append(ModelInfo(id="unknown"))
models = multi_model_manager.list_models()
return ModelList(data=models)
# =============================================================================
# Audio Transcription Endpoint
# =============================================================================
from fastapi import UploadFile, File, Form
@app.post("/v1/audio/transcriptions")
async def create_transcription(
model: str = Form(...),
file: UploadFile = File(...),
language: Optional[str] = Form(None),
prompt: Optional[str] = Form(None),
response_format: Optional[str] = Form("json"),
temperature: Optional[float] = Form(0.0),
):
"""
Audio transcription endpoint (OpenAI-compatible).
Supports:
- OpenAI's whisper-1 model (via OpenAI API)
- Local faster-whisper models (when --audio-model is specified)
"""
audio_model = multi_model_manager.audio_model
# If no audio model configured, return an error
if not audio_model:
raise HTTPException(
status_code=400,
detail="Audio transcription not configured. Use --audio-model to specify a model."
)
# Determine model to use
model_to_use = model
if model == "whisper-1" or model.startswith("audio:"):
# Use configured audio model
model_to_use = audio_model
# Read file content
file_content = await file.read()
# Try to use faster-whisper if available
try:
from faster_whisper import WhisperModel
# Determine compute type based on GPU availability
import torch
if torch.cuda.is_available():
compute_type = "float16"
else:
compute_type = "int8"
# Try to load the model (lazy loading)
model_key = f"audio:{model_to_use}"
whisper_model = multi_model_manager.get_model(model_key)
if whisper_model is None:
print(f"Loading faster-whisper model: {model_to_use}")
whisper_model = WhisperModel(
model_to_use,
device="cuda" if torch.cuda.is_available() else "cpu",
compute_type=compute_type
)
# Store in multi_model_manager
multi_model_manager.add_model(model_key, whisper_model)
# Write to temp file
import tempfile
import os
with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file.filename}") as tmp:
tmp.write(file_content)
tmp_path = tmp.name
try:
# Run transcription
segments, info = whisper_model.transcribe(
tmp_path,
language=language,
initial_prompt=prompt,
temperature=temperature or 0.0,
)
# Collect all segments
text_parts = []
for segment in segments:
text_parts.append(segment.text.strip())
full_text = " ".join(text_parts)
return {"text": full_text}
finally:
# Cleanup temp file
os.unlink(tmp_path)
except ImportError:
# faster-whisper not installed
raise HTTPException(
status_code=501,
detail="Audio transcription not available. Install faster-whisper: pip install faster-whisper"
)
except Exception as e:
print(f"Transcription error: {e}")
raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}")
# =============================================================================
# Image Generation Endpoint
# =============================================================================
@app.post("/v1/images/generations")
async def create_image_generation(request: ImageGenerationRequest):
"""
Image generation endpoint (OpenAI-compatible).
Supports:
- Stable Diffusion XL (via local inference with diffusers)
- Other diffusers models
"""
image_model = multi_model_manager.image_model
# If no image model configured, return an error
if not image_model:
raise HTTPException(
status_code=400,
detail="Image generation not configured. Use --image-model to specify a model."
)
# Determine model to use
model_to_use = request.model
if model_to_use.startswith("image:"):
model_to_use = image_model
# Parse size (e.g., "1024x1024")
width, height = 1024, 1024
if request.size:
parts = request.size.split("x")
if len(parts) == 2:
try:
width = int(parts[0])
height = int(parts[1])
except ValueError:
pass
# Try to use diffusers if available
try:
import torch
from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
# Determine model key
model_key = f"image:{model_to_use}"
pipeline = multi_model_manager.get_model(model_key)
if pipeline is None:
print(f"Loading Stable Diffusion model: {model_to_use}")
# Try to load as Stable Diffusion XL first
try:
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_safetensors=True,
)
except Exception:
# Try generic diffusion pipeline
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_safetensors=True,
)
# Move to GPU if available
if torch.cuda.is_available():
pipeline = pipeline.to("cuda")
else:
pipeline = pipeline.to("cpu")
# Enable attention slicing for lower memory usage
if torch.cuda.is_available():
pipeline.enable_attention_slicing()
multi_model_manager.add_model(model_key, pipeline)
# Generate images
generator = None
if request.seed is not None:
generator = torch.Generator(device=pipeline.device).manual_seed(request.seed)
# Quality: "standard" or "hd"
quality = request.quality or "standard"
# Generate
result = pipeline(
prompt=request.prompt,
negative_prompt=None,
num_images_per_prompt=request.n,
height=height,
width=width,
generator=generator,
guidance_scale=7.5 if quality == "standard" else 9.0,
num_inference_steps=30 if quality == "standard" else 50,
)
# Extract images
images = []
for img in result.images:
# Convert to base64
import base64
import io
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
if request.response_format == "base64":
images.append({"b64_json": img_base64})
else:
# For URL format, we'd need to save somewhere
# For now, return base64
images.append({"b64_json": img_base64})
return {
"created": int(time.time()),
"data": images
}
except ImportError as e:
# diffusers not installed
raise HTTPException(
status_code=501,
detail=f"Image generation not available. Install diffusers: pip install diffusers torch accelerate safetensors. Error: {str(e)}"
)
except Exception as e:
print(f"Image generation error: {e}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Image generation error: {str(e)}")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
"""Chat completions endpoint with streaming and tool support."""
if model_manager.backend is None:
raise HTTPException(status_code=503, detail="Model not loaded")
# Get the model for this request
requested_model = request.model
# Try to get the appropriate model
mm = multi_model_manager.get_model_for_request(requested_model)
if mm is None:
# Model not loaded - try to use default
if model_manager.backend is not None:
# Fallback to legacy model_manager
current_manager = model_manager
else:
raise HTTPException(status_code=503, detail="Model not loaded")
else:
current_manager = mm
# Inject system prompt if --system-prompt flag was provided
messages = request.messages
......@@ -1887,6 +2336,9 @@ async def chat_completions(request: ChatCompletionRequest):
if request.tools:
messages = format_tools_for_prompt(request.tools, messages)
# Get the tool_parser from the current manager
tool_parser = current_manager.tool_parser if hasattr(current_manager, 'tool_parser') else ToolCallParser()
# Prepare stop sequences
stop_sequences = []
if request.stop:
......@@ -1962,6 +2414,8 @@ async def chat_completions(request: ChatCompletionRequest):
request.top_p,
stop_sequences,
tools_dict,
current_manager,
tool_parser,
),
media_type="text/event-stream",
)
......@@ -1974,6 +2428,8 @@ async def chat_completions(request: ChatCompletionRequest):
request.top_p,
stop_sequences,
tools_dict,
current_manager,
tool_parser,
)
async def stream_chat_response(
......@@ -1984,6 +2440,8 @@ async def stream_chat_response(
top_p: float,
stop: List[str],
tools: Optional[List[Dict]],
current_manager: ModelManager,
tool_parser: ToolCallParser,
) -> AsyncGenerator[str, None]:
"""Stream chat completion response."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
......@@ -1996,7 +2454,7 @@ async def stream_chat_response(
try:
chunk_count = 0
# Use generate_chat_stream for proper chat template handling
async for chunk in model_manager.generate_chat_stream(
async for chunk in current_manager.generate_chat_stream(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
......@@ -2009,7 +2467,7 @@ async def stream_chat_response(
filtered_chunk = filter_malformed_content(chunk)
# Always filter out tool call format - some may slip through even without tools
filtered_chunk = model_manager.tool_parser.strip_tool_calls_from_content(filtered_chunk)
filtered_chunk = tool_parser.strip_tool_calls_from_content(filtered_chunk)
if not filtered_chunk:
print(f"DEBUG: filtered_chunk was empty (original chunk: {repr(chunk[:50])})")
......@@ -2048,7 +2506,7 @@ async def stream_chat_response(
parameters=t["function"].get("parameters")
)
tool_objects.append(Tool(type=t.get("type", "function"), function=tool_func))
tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tool_objects)
tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
if tool_calls:
# Tool calls were extracted and stripped from content during streaming
# Just send the tool_calls chunk
......@@ -2095,6 +2553,8 @@ async def generate_chat_response(
top_p: float,
stop: List[str],
tools: Optional[List[Dict]],
current_manager: ModelManager,
tool_parser: ToolCallParser,
) -> Dict:
"""Generate non-streaming chat completion response."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
......@@ -2102,7 +2562,7 @@ async def generate_chat_response(
try:
# Use generate_chat for proper chat template handling
generated_text = model_manager.generate_chat(
generated_text = current_manager.generate_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
......@@ -2132,10 +2592,10 @@ async def generate_chat_response(
parameters=t["function"].get("parameters")
)
tool_objects.append(Tool(type=t.get("type", "function"), function=tool_func))
tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tool_objects)
tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
if tool_calls:
# Strip tool call format from content so user doesn't see raw tags
clean_content = model_manager.tool_parser.strip_tool_calls_from_content(generated_text)
clean_content = tool_parser.strip_tool_calls_from_content(generated_text)
response_message["content"] = clean_content if clean_content.strip() else None
response_message["tool_calls"] = tool_calls
finish_reason = "tool_calls"
......@@ -2169,8 +2629,21 @@ async def generate_chat_response(
@app.post("/v1/completions")
async def completions(request: CompletionRequest):
"""Text completions endpoint."""
if model_manager.backend is None:
raise HTTPException(status_code=503, detail="Model not loaded")
# Get the model for this request
requested_model = request.model
# Try to get the appropriate model
mm = multi_model_manager.get_model_for_request(requested_model)
if mm is None:
# Model not loaded - try to use default
if model_manager.backend is not None:
# Fallback to legacy model_manager
current_manager = model_manager
else:
raise HTTPException(status_code=503, detail="Model not loaded")
else:
current_manager = mm
prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt]
stop_sequences = []
......@@ -2186,6 +2659,7 @@ async def completions(request: CompletionRequest):
request.temperature,
request.top_p,
stop_sequences,
current_manager,
),
media_type="text/event-stream",
)
......@@ -2197,6 +2671,7 @@ async def completions(request: CompletionRequest):
request.temperature,
request.top_p,
stop_sequences,
current_manager,
)
......@@ -2207,13 +2682,14 @@ async def stream_completion_response(
temperature: float,
top_p: float,
stop: List[str],
current_manager: ModelManager,
) -> AsyncGenerator[str, None]:
"""Stream completion response."""
completion_id = f"cmpl-{uuid.uuid4().hex}"
created = int(time.time())
try:
async for chunk in model_manager.generate_stream(
async for chunk in current_manager.generate_stream(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
......@@ -2249,13 +2725,14 @@ async def generate_completion_response(
temperature: float,
top_p: float,
stop: List[str],
current_manager: ModelManager,
) -> Dict:
"""Generate non-streaming completion response."""
completion_id = f"cmpl-{uuid.uuid4().hex}"
created = int(time.time())
try:
generated_text = model_manager.generate(
generated_text = current_manager.generate(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
......@@ -2264,9 +2741,9 @@ async def generate_completion_response(
)
# Calculate token counts if tokenizer available
if model_manager.tokenizer:
prompt_tokens = len(model_manager.tokenizer.encode(prompt))
completion_tokens = len(model_manager.tokenizer.encode(generated_text))
if current_manager.tokenizer:
prompt_tokens = len(current_manager.tokenizer.encode(prompt))
completion_tokens = len(current_manager.tokenizer.encode(generated_text))
else:
prompt_tokens = len(prompt.split())
completion_tokens = len(generated_text.split())
......@@ -2402,12 +2879,54 @@ def parse_args():
default=None,
help="Inject a system prompt at the beginning of conversations. Use without a value for a default prompt, or provide custom text.",
)
# Multi-model arguments
parser.add_argument(
"--audio-model",
type=str,
default=None,
help="Model for audio transcription (e.g., whisper-1, or path to faster-whisper model)",
)
parser.add_argument(
"--image-model",
type=str,
default=None,
help="Model for image generation (e.g., stable-diffusion-xl-base-1.0)",
)
parser.add_argument(
"--loadall",
action="store_true",
help="Pre-load all models (main, audio, image) at startup instead of on-demand",
)
parser.add_argument(
"--audio-ctx",
type=int,
default=480000,
help="Audio model context size in milliseconds (default: 480000 = 30 seconds for Whisper)",
)
parser.add_argument(
"--audio-offload",
type=float,
default=None,
help="Audio model GPU offload percentage (0-100). If not set, uses CPU",
)
parser.add_argument(
"--vision-ctx",
type=int,
default=2048,
help="Vision model context size (default: 2048)",
)
parser.add_argument(
"--vision-offload",
type=float,
default=None,
help="Vision model GPU offload percentage (0-100). If not set, loads fully on GPU",
)
return parser.parse_args()
def main():
"""Main entry point."""
global global_system_prompt
global global_system_prompt, model_manager, multi_model_manager
# Optional: set process name if procname is available
try:
......@@ -2462,7 +2981,7 @@ def main():
print(f" [{status}] {name}")
print("")
# Load the model
# Load the main model
load_kwargs = {
'offload_dir': args.offload_dir,
'load_in_4bit': args.load_in_4bit,
......@@ -2483,6 +3002,9 @@ def main():
backend_type=args.backend,
**load_kwargs
)
# Register with multi_model_manager
multi_model_manager.set_default_model(model_name, load_kwargs)
multi_model_manager.add_model(model_name, model_manager)
except Exception as e:
print(f"\nError loading model: {e}")
error_str = str(e).lower()
......@@ -2506,11 +3028,40 @@ def main():
print(f" coderai --backend vulkan --model {model_name}")
sys.exit(1)
# Set up audio model if specified
if args.audio_model:
print(f"\nAudio transcription model: {args.audio_model}")
multi_model_manager.set_audio_model(args.audio_model, {
'ctx': args.audio_ctx,
'offload': args.audio_offload,
})
# Set up image model if specified
if args.image_model:
print(f"\nImage generation model: {args.image_model}")
multi_model_manager.set_image_model(args.image_model, {
'ctx': args.vision_ctx,
'offload': args.vision_offload,
})
# If --loadall, pre-load all models
if args.loadall:
print("\nPre-loading all models...")
# Audio model will be loaded on first request (lazy loading)
# Image model will be loaded on first request (lazy loading)
print(" - Audio model: will load on first request")
print(" - Image model: will load on first request")
# Start the server
import uvicorn
print(f"\nStarting server on http://{args.host}:{args.port}")
print(f"API documentation available at http://{args.host}:{args.port}/docs")
print(f"Using backend: {model_manager.backend_type}")
# Print available models
models = multi_model_manager.list_models()
print(f"Available models: {[m.id for m in models]}")
uvicorn.run(app, host=args.host, port=args.port)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment