Commit e004541a authored by Your Name's avatar Your Name

Centralize model resolution and VRAM management in MultiModelManager.request_model()

- Added request_model() method to MultiModelManager that handles:
  1. Alias resolution (image, audio, tts, vision, default, custom aliases)
  2. VRAM management (unloading previous models in ondemand mode)
  3. Checking if model is already loaded

- Simplified codai/api/images.py:
  - Uses request_model() for model resolution and VRAM management
  - Extracted helper functions: _is_gguf_model(), _load_diffusers_pipeline(),
    _generate_with_diffusers(), _generate_with_sdcpp(), _load_sdcpp_model()
  - Removed duplicated sd.cpp generation code
  - Fixed semaphore scope (all generation now inside semaphore block)

- Simplified codai/api/tts.py:
  - Uses request_model() instead of duplicated VRAM management code
  - Removed duplicate get_cached_model_path() and get_model_cache_dir() wrappers

- Simplified codai/api/transcriptions.py:
  - Uses request_model() instead of duplicated VRAM management code

- Simplified codai/api/text.py:
  - Both /v1/chat/completions and /v1/completions use request_model()
  - Removed duplicated VRAM management blocks
parent a5b64c4c
This diff is collapsed.
......@@ -295,46 +295,13 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
# Get the model for this request
requested_model = request.model
# Get load mode to determine if we need to unload other models first
from codai.api.state import get_load_mode
load_mode = get_load_mode()
# In ondemand mode (no --load-all or --loadswap), if ANY model is loaded in VRAM
# and it's different from what we need, fully unload it first to free VRAM
if load_mode == "ondemand":
has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
if has_any_model:
# Resolve both the requested model and currently loaded model to their canonical names
requested_canonical = multi_model_manager.resolve_model_name(requested_model)
loaded_canonical = multi_model_manager.get_currently_loaded_model_name()
# Also check legacy model_manager
if not loaded_canonical and model_manager.backend is not None:
loaded_canonical = "legacy_model_manager"
# Compare: if they're different models (even if same type), unload first
already_loaded = (requested_canonical and loaded_canonical and
requested_canonical == loaded_canonical)
if not already_loaded:
print(f"In ondemand mode - model switch detected:")
print(f" Requested: '{requested_model}' (resolved to: '{requested_canonical}')")
print(f" Loaded: '{loaded_canonical}'")
print(f" -> Fully unloading current model(s) before loading new model...")
# Use centralized unload method
multi_model_manager.unload_all_models()
# Also cleanup legacy model_manager
if model_manager.backend is not None:
print("Unloading legacy model_manager from VRAM...")
try:
model_manager.cleanup()
except Exception as e:
print(f"Warning during legacy model cleanup: {e}")
# Use the manager to resolve the model and manage VRAM (handles ondemand unloading)
model_info = multi_model_manager.request_model(
requested_model=requested_model,
model_type="text"
)
# Try to get the appropriate model
# Try to get the appropriate model (request_model handles VRAM cleanup)
mm = multi_model_manager.get_model_for_request(requested_model)
if mm is None:
......@@ -1727,40 +1694,13 @@ async def completions(request: CompletionRequest):
# Get the model for this request
requested_model = request.model
# Get load mode to determine if we need to unload other models first
from codai.api.state import get_load_mode
load_mode = get_load_mode()
# In ondemand mode, if ANY model is loaded and it's different from what we need, unload first
if load_mode == "ondemand":
has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
if has_any_model:
# Resolve both the requested model and currently loaded model to their canonical names
requested_canonical = multi_model_manager.resolve_model_name(requested_model)
loaded_canonical = multi_model_manager.get_currently_loaded_model_name()
# Also check legacy model_manager
if not loaded_canonical and model_manager.backend is not None:
loaded_canonical = "legacy_model_manager"
# Compare: if they're different models (even if same type), unload first
already_loaded = (requested_canonical and loaded_canonical and
requested_canonical == loaded_canonical)
if not already_loaded:
print(f"In ondemand mode - model switch detected:")
print(f" Requested: '{requested_model}' (resolved to: '{requested_canonical}')")
print(f" Loaded: '{loaded_canonical}'")
print(f" -> Fully unloading current model(s) before loading new model...")
multi_model_manager.unload_all_models()
if model_manager.backend is not None:
try:
model_manager.cleanup()
except:
pass
# Use the manager to resolve the model and manage VRAM (handles ondemand unloading)
model_info = multi_model_manager.request_model(
requested_model=requested_model,
model_type="text"
)
# Try to get the appropriate model
# Try to get the appropriate model (request_model handles VRAM cleanup)
mm = multi_model_manager.get_model_for_request(requested_model)
if mm is None:
......
......@@ -54,52 +54,22 @@ async def create_transcription(
raise HTTPException(status_code=500, detail=result["error"])
return {"text": result.get("text", "")}
audio_model = multi_model_manager.audio_models[0] if multi_model_manager.audio_models else None
if not audio_model:
# Use the manager to resolve the model and manage VRAM
model_info = multi_model_manager.request_model(
requested_model=model,
model_type="audio"
)
model_name = model_info['model_name']
model_key = model_info['model_key']
whisper_model = model_info['model_object']
if not model_name:
raise HTTPException(
status_code=400,
detail="Audio transcription not configured. Use --audio-model or --whisper-server."
)
# Get load mode to determine if we need to unload other models first
from codai.api.state import get_load_mode
from codai.models.manager import model_manager
load_mode = get_load_mode()
# In ondemand mode, if ANY model is loaded and it's different from what we need, unload first
if load_mode == "ondemand":
has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
if has_any_model:
# Resolve both the requested audio model and currently loaded model to their canonical names
requested_canonical = multi_model_manager.resolve_model_name(f"audio:{audio_model}")
loaded_canonical = multi_model_manager.get_currently_loaded_model_name()
# Also check legacy model_manager
if not loaded_canonical and model_manager.backend is not None:
loaded_canonical = "legacy_model_manager"
# Compare: if they're different models, unload first
already_loaded = (requested_canonical and loaded_canonical and
requested_canonical == loaded_canonical)
if not already_loaded:
print(f"In ondemand mode - model switch detected:")
print(f" Requested: 'audio:{audio_model}' (resolved to: '{requested_canonical}')")
print(f" Loaded: '{loaded_canonical}'")
print(f" -> Fully unloading current model(s) before loading audio model...")
multi_model_manager.unload_all_models()
if model_manager.backend is not None:
try:
model_manager.cleanup()
except:
pass
# Determine model to use
model_to_use = model
if model_to_use.startswith("whisper:") or model_to_use.startswith("audio:"):
model_to_use = audio_model
# Read the uploaded file
file_content = await file.read()
......@@ -113,26 +83,23 @@ async def create_transcription(
try:
from faster_whisper import WhisperModel
# Determine model key
model_key = f"audio:{model_to_use}"
whisper_model = multi_model_manager.get_model(model_key)
if whisper_model is None:
print(f"Loading faster-whisper model: {model_to_use}")
print(f"Loading faster-whisper model: {model_name}")
# Determine compute type - always use int8 for CPU
compute_type = "int8"
# Load the model
whisper_model = WhisperModel(
model_to_use,
model_name,
device="cpu", # Always use CPU - faster-whisper CUDA doesn't work with AMD
compute_type=compute_type,
)
# Cache the model
multi_model_manager.add_model(model_key, whisper_model)
print(f"Loaded faster-whisper model: {model_to_use}")
multi_model_manager.current_model_key = model_key
print(f"Loaded faster-whisper model: {model_name}")
# Run transcription
segments, info = whisper_model.transcribe(
......@@ -160,24 +127,21 @@ async def create_transcription(
try:
import whispercpp
# Determine model key
model_key = f"audio:{model_to_use}"
whisper_model = multi_model_manager.get_model(model_key)
if whisper_model is None:
print(f"Loading whispercpp model: {model_to_use}")
print(f"Loading whispercpp model: {model_name}")
# Check if it's a built-in model name
if model_to_use in ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large']:
if model_name in ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large']:
# It's a built-in model name
whisper_model = whispercpp.Whisper.from_pretrained(model_to_use)
whisper_model = whispercpp.Whisper.from_pretrained(model_name)
else:
# It's a path to a GGUF file
whisper_model = whispercpp.Whisper.from_pretrained(model_to_use)
whisper_model = whispercpp.Whisper.from_pretrained(model_name)
# Cache the model
multi_model_manager.add_model(model_key, whisper_model)
print(f"Loaded whispercpp model: {model_to_use}")
multi_model_manager.current_model_key = model_key
print(f"Loaded whispercpp model: {model_name}")
# Run transcription
result = whisper_model.transcribe(tmp_path)
......
......@@ -16,18 +16,6 @@ from codai.models.manager import multi_model_manager
global_args = None
def get_cached_model_path(url: str) -> str:
"""Get cached model path if available."""
from codai.models.cache import get_cached_model_path as cache_get_cached_model_path
return cache_get_cached_model_path(url)
def get_model_cache_dir() -> str:
"""Get model cache directory."""
from codai.models.cache import get_model_cache_dir
return get_model_cache_dir()
def set_global_args(args):
"""Set global args from coderai."""
global global_args
......@@ -65,80 +53,46 @@ async def create_speech(request: TTSRequest):
Supports:
- Kokoro TTS models (when --tts-model is specified)
"""
tts_model = multi_model_manager.tts_model
# Use the manager to resolve the model and manage VRAM
model_info = multi_model_manager.request_model(
requested_model=request.model,
model_type="tts"
)
model_name = model_info['model_name']
model_key = model_info['model_key']
kokoro_model = model_info['model_object']
# If no TTS model configured, return an error
if not tts_model:
if not model_name:
raise HTTPException(
status_code=400,
detail="TTS not configured. Use --tts-model to specify a model."
)
# Get load mode to determine if we need to unload other models first
from codai.api.state import get_load_mode
from codai.models.manager import model_manager
load_mode = get_load_mode()
# In ondemand mode, if ANY model is loaded and it's different from what we need, unload first
if load_mode == "ondemand":
has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
if has_any_model:
# Resolve both the requested TTS model and currently loaded model to their canonical names
requested_canonical = multi_model_manager.resolve_model_name(f"tts:{tts_model}")
loaded_canonical = multi_model_manager.get_currently_loaded_model_name()
# Also check legacy model_manager
if not loaded_canonical and model_manager.backend is not None:
loaded_canonical = "legacy_model_manager"
# Compare: if they're different models, unload first
already_loaded = (requested_canonical and loaded_canonical and
requested_canonical == loaded_canonical)
if not already_loaded:
print(f"In ondemand mode - model switch detected:")
print(f" Requested: 'tts:{tts_model}' (resolved to: '{requested_canonical}')")
print(f" Loaded: '{loaded_canonical}'")
print(f" -> Fully unloading current model(s) before loading TTS model...")
multi_model_manager.unload_all_models()
if model_manager.backend is not None:
try:
model_manager.cleanup()
except:
pass
# Determine model to use
model_to_use = request.model
if model_to_use.startswith("tts:"):
model_to_use = tts_model
# Try to use kokoro if available
try:
from kokoro import Kokoro
# Determine model key
model_key = f"tts:{model_to_use}"
kokoro_model = multi_model_manager.get_model(model_key)
if kokoro_model is None:
print(f"Loading Kokoro TTS model: {model_to_use}")
print(f"Loading Kokoro TTS model: {model_name}")
# Check if model_to_use is a URL - download it (with caching)
# Check if model_name is a URL - download it (with caching)
model_path = None
if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
print(f"Loading model from URL: {model_to_use}")
if model_name.startswith('http://') or model_name.startswith('https://'):
print(f"Loading model from URL: {model_name}")
from codai.models.cache import load_model
model_path = load_model(model_to_use)
model_path = load_model(model_name)
if not model_path:
raise Exception(f"Failed to load model from {model_to_use}")
raise Exception(f"Failed to load model from {model_name}")
else:
# Use local path or model name
model_path = model_to_use
model_path = model_name
# Load the Kokoro model
kokoro_model = Kokoro(model_path if model_path else model_to_use)
kokoro_model = Kokoro(model_path if model_path else model_name)
multi_model_manager.add_model(model_key, kokoro_model)
multi_model_manager.current_model_key = model_key
# Generate speech
voice = request.voice or "af_sarah"
......
......@@ -883,6 +883,133 @@ class MultiModelManager:
return load_model(model_path, cache_dir, file_pattern)
def request_model(self, requested_model: str, model_type: str = None) -> Dict[str, Any]:
"""
Central method for API modules to request a model.
Handles:
1. Alias resolution (e.g., "image" -> "Tongyi-MAI/Z-Image-Turbo")
2. VRAM management (unloading previous models in ondemand mode)
3. Checking if model is already loaded
Args:
requested_model: The model name/alias from the API request
model_type: The type of model being requested ("image", "text", "audio", "tts", "vision")
Used to resolve empty/None model names to the appropriate default.
Returns:
Dict with:
- 'model_key': The key used to store/retrieve the model in self.models
- 'model_name': The resolved model name/path/HF ID
- 'model_object': The loaded model object if already loaded, None otherwise
- 'config': The stored configuration for this model
- 'already_loaded': True if the model is already loaded in VRAM
"""
from codai.api.state import get_load_mode
mode = get_load_mode()
# Step 1: Resolve the model name from aliases
resolved_name = None
model_key = None
# If no model specified, use the default for the given type
if not requested_model or requested_model == model_type:
if model_type == "image":
resolved_name = self.image_models[0] if self.image_models else None
elif model_type == "audio":
resolved_name = self.audio_models[0] if self.audio_models else None
elif model_type == "tts":
resolved_name = self.tts_model
elif model_type == "vision":
resolved_name = self.vision_models[0] if self.vision_models else None
else:
resolved_name = self.default_model
else:
# Resolve custom aliases
if requested_model in self.model_aliases:
requested_model = self.model_aliases[requested_model]
# Handle "default" alias
if requested_model == "default":
resolved_name = self.default_model
# Handle type-specific aliases
elif requested_model == "image":
resolved_name = self.image_models[0] if self.image_models else None
elif requested_model == "audio":
resolved_name = self.audio_models[0] if self.audio_models else None
elif requested_model == "tts":
resolved_name = self.tts_model
elif requested_model == "vision":
resolved_name = self.vision_models[0] if self.vision_models else None
# Handle prefixed models (e.g., "image:model_name")
elif requested_model.startswith("image:"):
resolved_name = requested_model[6:]
elif requested_model.startswith("audio:"):
resolved_name = requested_model[6:]
elif requested_model.startswith("tts:"):
resolved_name = requested_model[4:]
elif requested_model.startswith("vision:"):
resolved_name = requested_model[7:]
else:
resolved_name = requested_model
if not resolved_name:
return {
'model_key': None,
'model_name': None,
'model_object': None,
'config': {},
'already_loaded': False,
}
# Step 2: Build the model key (prefixed with type)
if model_type and model_type != "text":
model_key = f"{model_type}:{resolved_name}"
else:
model_key = resolved_name
# Step 3: Check if already loaded
existing_model = self.models.get(model_key)
if existing_model is not None:
self.current_model_key = model_key
return {
'model_key': model_key,
'model_name': resolved_name,
'model_object': existing_model,
'config': self.config.get(model_key, {}),
'already_loaded': True,
}
# Step 4: In ondemand mode, unload any currently loaded model
if mode == "ondemand":
has_any_model = len(self.models) > 0 or model_manager.backend is not None
if has_any_model:
loaded_canonical = self.get_currently_loaded_model_name()
if not loaded_canonical and model_manager.backend is not None:
loaded_canonical = "legacy_model_manager"
if loaded_canonical and loaded_canonical != model_key:
print(f"Ondemand mode - model switch detected:")
print(f" Requested: '{model_key}' (resolved: '{resolved_name}')")
print(f" Currently loaded: '{loaded_canonical}'")
print(f" -> Unloading current model(s) before loading new model...")
self.unload_all_models()
if model_manager.backend is not None:
try:
model_manager.cleanup()
except:
pass
# Step 5: Return info for the caller to load the model
return {
'model_key': model_key,
'model_name': resolved_name,
'model_object': None,
'config': self.config.get(model_key, {}),
'already_loaded': False,
}
def unload_all_models(self):
"""
Fully unload ALL models from VRAM. Used in ondemand mode when switching
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment