Commit 65caf41f authored by Your Name's avatar Your Name

Implement multiple audio/image model support with aliases

- Add support for multiple --audio-model arguments (action='append')
- Add support for multiple --image-model arguments (action='append')
- Add 'audio' alias pointing to first audio model
- Add 'vision'/'image' aliases pointing to first image model
- Update MultiModelManager to store audio_models and image_models as lists
- Add audio_model and image_model properties for accessing first model
- Update get_model_for_request to handle aliases
- Update list_models to show all models and aliases
- Fix remaining references in main function to use list-based variables
parent c2bd5ffa
......@@ -1959,9 +1959,9 @@ class MultiModelManager:
def __init__(self):
self.models: Dict[str, ModelManager] = {}
self.default_model: Optional[str] = None
self.audio_model: Optional[str] = None
self.audio_models: List[str] = [] # List of audio model names
self.tts_model: Optional[str] = None
self.image_model: Optional[str] = None
self.image_models: List[str] = [] # List of image model names
self.tool_parser = ToolCallParser()
self.current_model_key: Optional[str] = None
# Configuration for each model type
......@@ -1970,6 +1970,16 @@ class MultiModelManager:
self.load_mode: str = "ondemand" # "ondemand", "loadall", "loadswap"
self.active_in_vram: Optional[str] = None # Which model is currently in VRAM
@property
def audio_model(self) -> Optional[str]:
"""Get the first/default audio model."""
return self.audio_models[0] if self.audio_models else None
@property
def image_model(self) -> Optional[str]:
"""Get the first/default image model."""
return self.image_models[0] if self.image_models else None
def set_load_mode(self, mode: str):
"""Set the load mode: 'ondemand', 'loadall', or 'loadswap'."""
self.load_mode = mode
......@@ -1980,8 +1990,9 @@ class MultiModelManager:
self.config[model_name] = config or {}
def set_audio_model(self, model_name: str, config: Dict = None):
"""Set the audio transcription model."""
self.audio_model = model_name
"""Add an audio transcription model."""
if model_name not in self.audio_models:
self.audio_models.append(model_name)
self.config[f"audio:{model_name}"] = config or {}
def set_tts_model(self, model_name: str, config: Dict = None):
......@@ -1990,8 +2001,9 @@ class MultiModelManager:
self.config[f"tts:{model_name}"] = config or {}
def set_image_model(self, model_name: str, config: Dict = None):
"""Set the image generation model."""
self.image_model = model_name
"""Add an image generation model."""
if model_name not in self.image_models:
self.image_models.append(model_name)
self.config[f"image:{model_name}"] = config or {}
def get_model_for_request(self, requested_model: str) -> Optional[ModelManager]:
......@@ -2000,8 +2012,13 @@ class MultiModelManager:
Model name conventions:
- "default", empty, or matches default model -> use main model
- starts with "audio:" -> use audio model
- starts with "image:" -> use image model
- "audio" -> use first/default audio model
- "audio:modelname" -> use specific audio model
- "vision" or "image" -> use first/default image model
- "vision:modelname" or "image:modelname" -> use specific image model
- "tts" -> use TTS model
- "tts:modelname" -> use specific TTS model
- Otherwise match by model ID in multi_model_manager.models
"""
# Handle empty or "default" model names
if not requested_model or requested_model == "default":
......@@ -2010,17 +2027,50 @@ class MultiModelManager:
return self.models[self.default_model]
return None
# Check for specialized models
# Handle "audio" alias - use first/default audio model
if requested_model == "audio":
if self.audio_models:
first_audio = self.audio_models[0]
key = f"audio:{first_audio}"
if key in self.models:
self.current_model_key = key
return self.models[key]
# Try to load on demand
return None
return None
# Handle "vision" or "image" alias - use first/default image model
if requested_model in ("vision", "image"):
if self.image_models:
first_image = self.image_models[0]
key = f"image:{first_image}"
if key in self.models:
self.current_model_key = key
return self.models[key]
# Try to load on demand
return None
return None
# Handle "tts" alias
if requested_model == "tts":
if self.tts_model:
key = f"tts:{self.tts_model}"
if key in self.models:
self.current_model_key = key
return self.models[key]
return None
return None
# Check for specialized models with prefix
if requested_model.startswith("audio:"):
audio_name = requested_model[6:] # Remove "audio:" prefix
key = f"audio:{audio_name}"
if key in self.models:
self.current_model_key = key
return self.models[key]
elif self.audio_model:
elif audio_name in self.audio_models:
# Try loading audio model on demand
key = f"audio:{self.audio_model}"
return None # Signal that we need to load
return None
return None
if requested_model.startswith("tts:"):
......@@ -2029,22 +2079,24 @@ class MultiModelManager:
if key in self.models:
self.current_model_key = key
return self.models[key]
elif self.tts_model:
# Try loading TTS model on demand
key = f"tts:{self.tts_model}"
return None # Signal that we need to load
elif self.tts_model and tts_name == self.tts_model:
return None
return None
if requested_model.startswith("image:"):
image_name = requested_model[6:] # Remove "image:" prefix
# Handle both "vision:" and "image:" prefixes
if requested_model.startswith("vision:") or requested_model.startswith("image:"):
# Extract the model name (remove either prefix)
if requested_model.startswith("vision:"):
image_name = requested_model[7:] # Remove "vision:" prefix
else:
image_name = requested_model[6:] # Remove "image:" prefix
key = f"image:{image_name}"
if key in self.models:
self.current_model_key = key
return self.models[key]
elif self.image_model:
elif image_name in self.image_models:
# Try loading image model on demand
key = f"image:{self.image_model}"
return None # Signal that we need to load
return None
return None
# Check if it's the default model
......@@ -2081,7 +2133,7 @@ class MultiModelManager:
"""List all available models."""
models = []
# Add default model
# Add default model(s)
if self.default_model:
model_id = self.default_model
# Also add short name
......@@ -2091,25 +2143,37 @@ class MultiModelManager:
models.append(ModelInfo(id=model_id))
models.append(ModelInfo(id="default"))
# Add audio models
if self.audio_model:
audio_id = f"audio:{self.audio_model}"
models.append(ModelInfo(id=audio_id))
# Add aliases for first/default models
if self.audio_models:
models.append(ModelInfo(id="audio")) # Alias for first audio model
# Add all audio models
for audio_id in self.audio_models:
models.append(ModelInfo(id=f"audio:{audio_id}"))
# Add TTS models
if self.tts_model:
models.append(ModelInfo(id="tts")) # Alias for TTS
tts_id = f"tts:{self.tts_model}"
models.append(ModelInfo(id=tts_id))
# Add image models
if self.image_model:
image_id = f"image:{self.image_model}"
models.append(ModelInfo(id=image_id))
# Add vision/image models
if self.image_models:
models.append(ModelInfo(id="vision")) # Alias for first image model
models.append(ModelInfo(id="image")) # Alias for first image model
# Add all image models
for image_id in self.image_models:
models.append(ModelInfo(id=f"image:{image_id}"))
models.append(ModelInfo(id=f"vision:{image_id}"))
# Add loaded models that aren't in the above categories
for key in self.models:
if key not in [self.default_model, f"audio:{self.audio_model}", f"image:{self.image_model}"]:
models.append(ModelInfo(id=key))
# Skip if already added
if key == self.default_model or key.startswith("audio:") or key.startswith("image:") or key.startswith("tts:"):
continue
# Skip short names (already added)
if self.default_model and key == self.default_model.split("/")[-1]:
continue
models.append(ModelInfo(id=key))
return models if models else [ModelInfo(id="default")]
......@@ -3512,8 +3576,9 @@ def parse_args():
parser.add_argument(
"--model",
type=str,
action="append",
default=None,
help="Model name, path, or URL for text-to-text LLM. Optional if only using --audio-model or --image-model",
help="Model name, path, or URL for text-to-text LLM. Can be specified multiple times for multiple models.",
)
parser.add_argument(
"--backend",
......@@ -3614,19 +3679,21 @@ def parse_args():
"--tts-model",
type=str,
default=None,
help="Model for text-to-speech (e.g., kokoro, or path/URL to Kokoro model)",
help="Model for text-to-speech (e.g., kokoro, or path/URL to Kokoro model). Can be specified multiple times.",
)
parser.add_argument(
"--audio-model",
type=str,
action="append",
default=None,
help="Model for audio transcription (e.g., whisper-1, or path to faster-whisper model)",
help="Model for audio transcription (e.g., whisper-1, base, or path to faster-whisper model). Can be specified multiple times for multiple models.",
)
parser.add_argument(
"--image-model",
type=str,
action="append",
default=None,
help="Model for image generation (e.g., stable-diffusion-xl-base-1.0)",
help="Model for image generation (e.g., stable-diffusion-xl-base-1.0). Can be specified multiple times for multiple models.",
)
parser.add_argument(
"--loadall",
......@@ -3728,17 +3795,21 @@ def main():
print(f"Error listing devices: {e}")
sys.exit(0)
# Get model name from args or prompt interactively
model_name = args.model
# Get model names from args - support multiple models
model_names = args.model if args.model else []
# Validate: must have at least one model specified
if model_name is None and args.audio_model is None and args.image_model is None and args.tts_model is None:
audio_models = args.audio_model if args.audio_model else []
image_models = args.image_model if args.image_model else []
if not model_names and not audio_models and not image_models and args.tts_model is None:
print("Error: At least one of --model, --audio-model, --image-model, or --tts-model must be specified.")
print("")
print("For NVIDIA backend (HuggingFace models):")
print(" - microsoft/DialoGPT-medium")
print(" - meta-llama/Llama-2-7b-chat-hf (requires auth)")
print(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print(" - Use multiple --model flags for multiple models")
print("")
print("For Vulkan backend (GGUF models):")
print(" - Local path: ./phi-3-mini-4k-instruct-q4_k_m.gguf")
......@@ -3755,6 +3826,13 @@ def main():
print(" - --image-model stabilityai/stable-diffusion-xl-base-1.0")
sys.exit(1)
# Print loaded models info
if model_names:
print(f"\nText model(s): {model_names}")
if len(model_names) > 1:
# Load mode will be determined below
print(f"Multiple models configured - load mode will be set based on --loadall/--loadswap flags")
# Detect available backends
available = detect_available_backends()
print("\nAvailable backends:")
......@@ -3824,6 +3902,89 @@ def main():
# Set load mode in multi_model_manager
multi_model_manager.set_load_mode(load_mode)
# Load models based on mode and count
if len(model_names) > 1:
# Multiple models - handle based on load mode
print(f"\n=== Multiple Models Mode: {load_mode} ===")
if load_mode == "loadall":
# Load all models into VRAM
for i, model_name in enumerate(model_names):
print(f"\nLoading model {i+1}/{len(model_names)}: {model_name}")
try:
manager = ModelManager()
manager.load_model(
model_name=model_name,
backend_type=args.backend,
**load_kwargs
)
multi_model_manager.add_model(model_name, manager)
print(f"Loaded: {model_name}")
except Exception as e:
print(f"Error loading {model_name}: {e}")
elif load_mode == "loadswap":
# First model in VRAM, others in RAM
for i, model_name in enumerate(model_names):
print(f"\nLoading model {i+1}/{len(model_names)}: {model_name} ({'VRAM' if i == 0 else 'RAM'})")
try:
manager = ModelManager()
# For non-first models, we'll load them in CPU mode initially
if i > 0:
# Modify kwargs for CPU-only loading
swap_kwargs = load_kwargs.copy()
swap_kwargs['n_gpu_layers'] = 0 # Force CPU only
manager.load_model(
model_name=model_name,
backend_type=args.backend,
**swap_kwargs
)
else:
manager.load_model(
model_name=model_name,
backend_type=args.backend,
**load_kwargs
)
multi_model_manager.add_model(model_name, manager)
print(f"Loaded: {model_name} ({'VRAM' if i == 0 else 'RAM'})")
except Exception as e:
print(f"Error loading {model_name}: {e}")
else: # ondemand
# Only load first model, others on-demand
print(f"\nLoading first model (VRAM): {model_names[0]}")
try:
model_manager.load_model(
model_name=model_names[0],
backend_type=args.backend,
**load_kwargs
)
multi_model_manager.set_default_model(model_names[0], load_kwargs)
multi_model_manager.add_model(model_names[0], model_manager)
print(f"Loaded: {model_names[0]}")
# Register other models but don't load them
for model_name in model_names[1:]:
multi_model_manager.set_default_model(model_name, load_kwargs)
print(f"\nOther models will load on-demand: {model_names[1:]}")
except Exception as e:
print(f"Error loading model: {e}")
sys.exit(1)
elif len(model_names) == 1:
# Single model - load it
model_name = model_names[0]
# Determine load mode BEFORE setting up other models
load_mode = "ondemand"
if args.loadall:
load_mode = "loadall"
elif args.loadswap:
load_mode = "loadswap"
# Set load mode in multi_model_manager
multi_model_manager.set_load_mode(load_mode)
# Pre-load models based on mode
if load_mode == "loadall":
# Load all models into VRAM up to full capacity, then offload to CPU RAM
......@@ -3834,13 +3995,13 @@ def main():
print(f"Pre-loading main text model: {model_name}")
# Load image model
if args.image_model:
print(f"Pre-loading image model: {args.image_model}")
if image_models:
print(f"Pre-loading image model: {image_models[0]}")
print(f" Image model will load on first request")
# Load audio model
if args.audio_model:
print(f"Pre-loading audio model: {args.audio_model}")
if audio_models:
print(f"Pre-loading audio model: {audio_models[0]}")
# Load TTS model
if args.tts_model:
......@@ -3851,10 +4012,10 @@ def main():
print("\n=== Load Swap Mode ===")
if model_name:
print(f"Main text model will be in VRAM: {model_name}")
if args.image_model:
print(f"Image model in RAM: {args.image_model}")
if args.audio_model:
print(f"Audio model in RAM: {args.audio_model}")
if image_models:
print(f"Image model in RAM: {image_models[0]}")
if audio_models:
print(f"Audio model in RAM: {audio_models[0]}")
if args.tts_model:
print(f"TTS model in RAM: {args.tts_model}")
......@@ -3864,26 +4025,30 @@ def main():
print("Models will load on first request")
# Set up audio model if specified (with pre-loading if in loadall/loadswap mode)
if args.audio_model:
print(f"\nAudio transcription model: {args.audio_model}")
if audio_models:
print(f"\nAudio transcription model(s): {audio_models}")
# Set up Vulkan device for Whisper if using Vulkan backend
if hasattr(args, 'audio_vulkan_device') and args.audio_vulkan_device is not None:
os.environ['GGML_VULKAN_DEVICE'] = str(args.audio_vulkan_device)
print(f" Using Vulkan device: {args.audio_vulkan_device}")
multi_model_manager.set_audio_model(args.audio_model, {
'ctx': args.audio_ctx,
'offload': args.audio_offload,
})
# Pre-load audio model at startup if:
# Register all audio models
for audio_m in audio_models:
multi_model_manager.set_audio_model(audio_m, {
'ctx': args.audio_ctx,
'offload': args.audio_offload,
})
# Pre-load first audio model at startup if:
# - Using loadall or loadswap mode, OR
# - No main model is specified (only audio model configured)
should_preload = load_mode in ("loadall", "loadswap") or (model_name is None and args.audio_model)
should_preload = load_mode in ("loadall", "loadswap") or (model_name is None and audio_models)
if should_preload:
print(f"Pre-loading audio model...")
print(f"Pre-loading audio model... {audio_models[0]}")
# Check if model is a GGUF file - faster-whisper doesn't support GGUF format
model_to_use = args.audio_model
# Use first audio model for pre-loading
model_to_use = audio_models[0]
is_gguf_model = model_to_use.endswith('.gguf') or 'gguf' in model_to_use.lower()
if is_gguf_model:
......@@ -3899,7 +4064,7 @@ def main():
from faster_whisper import WhisperModel
import torch
model_to_use = args.audio_model
model_to_use = audio_models[0]
model_path = None
# Check if model is a URL - handle caching
......@@ -3926,7 +4091,7 @@ def main():
)
# Store in multi_model_manager
model_key = f"audio:{args.audio_model}"
model_key = f"audio:{audio_models[0]}"
multi_model_manager.add_model(model_key, whisper_model)
print(f"Audio model loaded successfully (faster-whisper)")
......@@ -3969,7 +4134,7 @@ def main():
try:
import whispercpp
model_to_use = args.audio_model
model_to_use = audio_models[0]
model_path = None
# Check if model is a URL - handle caching
......@@ -4012,7 +4177,7 @@ def main():
whisper_model = whispercpp.Whisper.from_pretrained(model_path)
# Store in multi_model_manager
model_key = f"audio:{args.audio_model}"
model_key = f"audio:{audio_models[0]}"
multi_model_manager.add_model(model_key, whisper_model)
print(f"Audio model loaded successfully (whispercpp)")
if whisper_vulkan_available:
......@@ -4046,20 +4211,26 @@ def main():
multi_model_manager.set_tts_model(args.tts_model, {})
# Pre-load TTS model if it's the only model configured
if model_name is None and not args.audio_model and not args.image_model:
if model_name is None and not audio_models and not image_models:
print(f"Pre-loading TTS model...")
# TTS models load on-demand, but we can pre-download if needed
# Set up image model if specified
if args.image_model:
print(f"\nImage generation model: {args.image_model}")
multi_model_manager.set_image_model(args.image_model, {
if image_models:
print(f"\nImage generation model(s): {image_models}")
multi_model_manager.set_image_model(image_models[0], {
'ctx': args.vision_ctx,
'offload': args.vision_offload,
})
# Register all image models
for img_m in image_models[1:]:
multi_model_manager.set_image_model(img_m, {
'ctx': args.vision_ctx,
'offload': args.vision_offload,
})
# Pre-load image model if it's the only model configured
if model_name is None and not args.audio_model and not args.tts_model:
if model_name is None and not audio_models and not args.tts_model:
print(f"Pre-loading image model...")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment