Commit 362b8452 authored by Your Name

Implement on-demand model swapping for multiple models

- Add model_backend_types dict to track backend for each model
- Update set_default_model to accept backend_type parameter
- Modify get_model_for_request to swap models on-demand when in ondemand mode
- Unload the current model from VRAM and load the new model when a request arrives for a different model
- Respect --backend flag when loading models on-demand
- On-demand swapping only activates when neither --loadall nor --loadswap is specified
parent ebfa6892
......@@ -2304,7 +2304,7 @@ class MultiModelManager:
Supports dynamic switching based on request model name.
Modes:
- default: Load models on-demand
- default: Load models on-demand (swap models in VRAM when request changes)
- loadall: Pre-load all models in VRAM at startup
- loadswap: Keep all models in memory (CPU RAM), swap active model to VRAM
"""
......@@ -2327,6 +2327,8 @@ class MultiModelManager:
self.model_aliases: Dict[str, str] = {}
# Whisper server manager
self.whisper_server: Optional[WhisperServerManager] = None
# Track backend type for each model (needed for on-demand loading)
self.model_backend_types: Dict[str, str] = {}
@property
def audio_model(self) -> Optional[str]:
......@@ -2347,10 +2349,11 @@ class MultiModelManager:
"""Set the load mode: 'ondemand', 'loadall', or 'loadswap'."""
self.load_mode = mode
def set_default_model(self, model_name: str, config: Dict = None):
def set_default_model(self, model_name: str, config: Dict = None, backend_type: str = "auto"):
"""Set the default/main text model."""
self.default_model = model_name
self.config[model_name] = config or {}
self.model_backend_types[model_name] = backend_type
def set_audio_model(self, model_name: str, config: Dict = None):
"""Add an audio transcription model."""
......@@ -2393,7 +2396,14 @@ class MultiModelManager:
- "tts:modelname" -> use specific TTS model
- Custom aliases -> resolve to actual model name
- Otherwise match by model ID in multi_model_manager.models
In ondemand mode with multiple text models:
- If requested model is different from currently loaded model,
unload current and load new model on-demand (respecting --backend)
"""
# Declare the module-level global_args; it is read below as the fallback for an "auto" backend
global global_args
# Resolve custom aliases first
if requested_model in self.model_aliases:
requested_model = self.model_aliases[requested_model]
......@@ -2489,6 +2499,101 @@ class MultiModelManager:
self.current_model_key = key
return model
# === ON-DEMAND MODEL SWITCHING FOR TEXT MODELS ===
# If we're in ondemand mode and the requested model is in config but not loaded,
# we should try to load it on-demand (swap from current model)
# Only for text models (not audio/image/tts which have their own handling)
# Check if requested model is in our config (means it was registered but not loaded)
if self.load_mode == "ondemand" and requested_model in self.config:
# This is a text model that's registered but not loaded
# We need to swap: unload current model and load this one
# Check if we have a different model currently loaded
if self.current_model_key and self.current_model_key in self.models:
# Unload the current model from VRAM
current_model = self.models[self.current_model_key]
print(f"ON-DEMAND SWAP: Unloading model '{self.current_model_key}' from VRAM to load '{requested_model}'")
current_model.cleanup()
del self.models[self.current_model_key]
# Load the new model on-demand
print(f"ON-DEMAND SWAP: Loading model '{requested_model}' into VRAM")
# Get the backend type for this model
backend_type = self.model_backend_types.get(requested_model, "auto")
# Get config for this model
model_config = self.config.get(requested_model, {})
effective_backend = backend_type
if effective_backend == "auto" and global_args:
effective_backend = getattr(global_args, 'backend', 'auto')
try:
# Create new model manager and load the model
new_manager = ModelManager()
new_manager.load_model(
model_name=requested_model,
backend_type=effective_backend,
**model_config
)
self.models[requested_model] = new_manager
self.current_model_key = requested_model
self.active_in_vram = requested_model
print(f"ON-DEMAND SWAP: Successfully loaded model '{requested_model}' with backend '{effective_backend}'")
return new_manager
except Exception as e:
print(f"ON-DEMAND SWAP: Failed to load model '{requested_model}': {e}")
# NOTE: any previously loaded model was already unloaded above and is NOT restored here
return None
# Also check if the model matches by short name (e.g., "Phi-3" matches "microsoft/Phi-3-mini-4k-instruct")
if self.load_mode == "ondemand":
for model_name in self.config.keys():
# Only check text models (not audio:, image:, tts: prefixes)
if ":" not in model_name:
short_name = model_name.split("/")[-1] if "/" in model_name else model_name
if requested_model.lower() in short_name.lower() or short_name.lower() in requested_model.lower():
# Found a matching model in config, try to load it
if model_name not in self.models:
# Check if we have a different model currently loaded
if self.current_model_key and self.current_model_key in self.models:
# Unload the current model from VRAM
current_model = self.models[self.current_model_key]
print(f"ON-DEMAND SWAP: Unloading model '{self.current_model_key}' from VRAM to load '{model_name}'")
current_model.cleanup()
del self.models[self.current_model_key]
# Load the new model on-demand
print(f"ON-DEMAND SWAP: Loading model '{model_name}' into VRAM")
# Get the backend type for this model
backend_type = self.model_backend_types.get(model_name, "auto")
# Get config for this model
model_config = self.config.get(model_name, {})
effective_backend = backend_type
if effective_backend == "auto" and global_args:
effective_backend = getattr(global_args, 'backend', 'auto')
try:
new_manager = ModelManager()
new_manager.load_model(
model_name=model_name,
backend_type=effective_backend,
**model_config
)
self.models[model_name] = new_manager
self.current_model_key = model_name
self.active_in_vram = model_name
print(f"ON-DEMAND SWAP: Successfully loaded model '{model_name}' with backend '{effective_backend}'")
return new_manager
except Exception as e:
print(f"ON-DEMAND SWAP: Failed to load model '{model_name}': {e}")
return None
return None
def add_model(self, key: str, manager: ModelManager):
......@@ -3356,7 +3461,23 @@ def save_image_response(img, request_format="base64"):
file_path = os.path.join(global_file_path, filename)
img.save(file_path, format="PNG")
# Add URL to response
result["url"] = f"/v1/files/{filename}"
# Determine base URL based on --url argument
url_setting = getattr(global_args, 'url', 'auto') if global_args else 'auto'
if url_setting == 'auto':
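# Use client IP from request as the URL host
# NOTE(review): this is the caller's address, not the server's — confirm this is the intended host for download URLs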
if http_request:
client_host = http_request.client.host if http_request.client else '127.0.0.1'
# Check if HTTPS is enabled
use_https = getattr(global_args, 'https', False) or getattr(global_args, 'pubkey', None)
protocol = "https" if use_https else "http"
port = getattr(global_args, 'port', 8000)
base_url = f"{protocol}://{client_host}:{port}"
else:
base_url = "http://127.0.0.1:8000"
else:
# Use explicitly provided URL (strip trailing slash if present)
base_url = url_setting.rstrip('/')
result["url"] = f"{base_url}/v1/files/{filename}"
# If client explicitly requested base64, include it
# Otherwise, only return URL when file-path is set
......@@ -3377,7 +3498,7 @@ def save_image_response(img, request_format="base64"):
return result
@app.post("/v1/images/generations")
async def create_image_generation(request: ImageGenerationRequest):
async def create_image_generation(request: ImageGenerationRequest, http_request: Request = None):
"""
Image generation endpoint (OpenAI-compatible).
......@@ -4520,6 +4641,12 @@ def parse_args():
default=8000,
help="Port to bind to (default: 8000)",
)
parser.add_argument(
"--url",
type=str,
default="auto",
help="Base URL for media downloads: 'auto' (use request IP) or explicit URL (e.g., http://myserver:8000)",
)
parser.add_argument(
"--https",
action="store_true",
......@@ -5029,7 +5156,7 @@ def main():
**load_kwargs
)
# Register with multi_model_manager
multi_model_manager.set_default_model(first_model_name, load_kwargs)
multi_model_manager.set_default_model(first_model_name, load_kwargs, args.backend)
multi_model_manager.add_model(first_model_name, model_manager)
print(f"\nMain text model loaded: {first_model_name}")
except Exception as e:
......@@ -5121,7 +5248,7 @@ def main():
# Register other models but don't load them
for model_name in model_names[1:]:
multi_model_manager.set_default_model(model_name, load_kwargs)
multi_model_manager.set_default_model(model_name, load_kwargs, args.backend)
print(f"Other models will load on-demand: {model_names[1:]}")
# Model is already loaded at lines 4274-4281
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment