Commit 63460a13 authored by Your Name

Complete fix: Add ondemand mode model switching to audio and TTS endpoints

- Added model resolution and unload logic to /v1/audio/transcriptions
- Added model resolution and unload logic to /v1/audio/speech (TTS)
- Now ALL endpoints (text, image, audio, TTS) properly handle model switching
- In ondemand mode, ANY model type switch triggers unload first (e.g., text->audio, TTS->image, etc.)
parent a37085b4
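For context, here is a minimal client-side sketch of the switch this commit handles. It assumes the server is OpenAI-compatible and exposes a text endpoint at /v1/chat/completions alongside the audio endpoints above; the base URL and model names below are placeholders, not part of this commit:

```python
import requests

BASE = "http://localhost:8000/v1"  # placeholder address for the local server

# 1) A text request loads a text model (ondemand mode).
requests.post(
    f"{BASE}/chat/completions",
    json={
        "model": "my-text-model",  # placeholder model name
        "messages": [{"role": "user", "content": "hello"}],
    },
)

# 2) A transcription request that targets an audio model should now
#    fully unload the text model before the audio model is loaded.
with open("sample.wav", "rb") as f:
    requests.post(
        f"{BASE}/audio/transcriptions",
        files={"file": f},
        data={"model": "audio:whisper-large-v3"},  # placeholder audio model
    )
```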
@@ -61,6 +61,40 @@ async def create_transcription(
detail="Audio transcription not configured. Use --audio-model or --whisper-server." detail="Audio transcription not configured. Use --audio-model or --whisper-server."
) )
    # Get load mode to determine if we need to unload other models first
    from codai.api.state import get_load_mode
    from codai.models.manager import model_manager
    load_mode = get_load_mode()
    # In ondemand mode, if ANY model is loaded and it's different from what we need, unload first
    if load_mode == "ondemand":
        has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
        if has_any_model:
            # Resolve both the requested audio model and the currently loaded model to their canonical names
            requested_canonical = multi_model_manager.resolve_model_name(f"audio:{audio_model}")
            loaded_canonical = multi_model_manager.get_currently_loaded_model_name()
            # Also check the legacy model_manager
            if not loaded_canonical and model_manager.backend is not None:
                loaded_canonical = "legacy_model_manager"
            # Compare: if they're different models, unload first
            already_loaded = (requested_canonical and loaded_canonical and
                              requested_canonical == loaded_canonical)
            if not already_loaded:
                print("In ondemand mode - model switch detected:")
                print(f"  Requested: 'audio:{audio_model}' (resolved to: '{requested_canonical}')")
                print(f"  Loaded: '{loaded_canonical}'")
                print("  -> Fully unloading current model(s) before loading audio model...")
                multi_model_manager.unload_all_models()
                if model_manager.backend is not None:
                    try:
                        model_manager.cleanup()
                    except Exception:
                        pass
    # Determine model to use
    model_to_use = model
    if model_to_use.startswith("whisper:") or model_to_use.startswith("audio:"):
...
@@ -74,6 +74,40 @@ async def create_speech(request: TTSRequest):
detail="TTS not configured. Use --tts-model to specify a model." detail="TTS not configured. Use --tts-model to specify a model."
) )
    # Get load mode to determine if we need to unload other models first
    from codai.api.state import get_load_mode
    from codai.models.manager import model_manager
    load_mode = get_load_mode()
    # In ondemand mode, if ANY model is loaded and it's different from what we need, unload first
    if load_mode == "ondemand":
        has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
        if has_any_model:
            # Resolve both the requested TTS model and the currently loaded model to their canonical names
            requested_canonical = multi_model_manager.resolve_model_name(f"tts:{tts_model}")
            loaded_canonical = multi_model_manager.get_currently_loaded_model_name()
            # Also check the legacy model_manager
            if not loaded_canonical and model_manager.backend is not None:
                loaded_canonical = "legacy_model_manager"
            # Compare: if they're different models, unload first
            already_loaded = (requested_canonical and loaded_canonical and
                              requested_canonical == loaded_canonical)
            if not already_loaded:
                print("In ondemand mode - model switch detected:")
                print(f"  Requested: 'tts:{tts_model}' (resolved to: '{requested_canonical}')")
                print(f"  Loaded: '{loaded_canonical}'")
                print("  -> Fully unloading current model(s) before loading TTS model...")
                multi_model_manager.unload_all_models()
                if model_manager.backend is not None:
                    try:
                        model_manager.cleanup()
                    except Exception:
                        pass
    # Determine model to use
    model_to_use = request.model
    if model_to_use.startswith("tts:"):
...
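Since the unload-on-switch block is identical in both endpoints, one possible follow-up is to factor it into a shared helper. The sketch below is only illustrative: the helper name and its placement are hypothetical, and the body simply mirrors the logic added in this commit.

```python
from codai.api.state import get_load_mode
from codai.models.manager import model_manager


def maybe_unload_for_switch(multi_model_manager, requested: str) -> None:
    """In ondemand mode, unload everything when `requested` differs from the loaded model.

    `requested` is a prefixed name such as "audio:whisper-large" or "tts:my-voice"
    (both placeholders). Hypothetical helper, not part of this commit.
    """
    if get_load_mode() != "ondemand":
        return
    has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
    if not has_any_model:
        return
    requested_canonical = multi_model_manager.resolve_model_name(requested)
    loaded_canonical = multi_model_manager.get_currently_loaded_model_name()
    # Also check the legacy model_manager
    if not loaded_canonical and model_manager.backend is not None:
        loaded_canonical = "legacy_model_manager"
    if requested_canonical and loaded_canonical and requested_canonical == loaded_canonical:
        return  # requested model is already loaded; nothing to do
    multi_model_manager.unload_all_models()
    if model_manager.backend is not None:
        try:
            model_manager.cleanup()
        except Exception:
            pass  # best-effort cleanup of the legacy backend
```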