Commit 00775972 authored by Your Name

Fix: Centralize model unloading - properly handle all model types in ondemand mode

- Added unload_all_models() to MultiModelManager that handles ALL model types:
  ModelManager, diffusers pipelines, sd.cpp StableDiffusion, and any other objects
- Text endpoints now properly unload image models before loading text models
- Image endpoints now properly unload text models before loading image models
- The rule: in ondemand mode, if the model in VRAM differs from the requested
  model (regardless of type), fully unload before loading the new one (see the
  sketch below)
- Cleanup includes gc.collect(), torch.cuda.empty_cache(), and a 1s settle delay
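
The check, in sketch form (condensed from the diff below; already_loaded() is a
hypothetical shorthand for the inline default-model check, not a real helper):

    if load_mode == "ondemand":
        has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
        if has_any_model and not already_loaded(requested_model):
            multi_model_manager.unload_all_models()  # full VRAM cleanup before the new load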
parent 7d838962
@@ -299,63 +299,36 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
     from codai.api.state import get_load_mode
     load_mode = get_load_mode()
 
-    # Check if there's an image model already loaded in VRAM
-    current_image_model = None
-    for key in multi_model_manager.models.keys():
-        if key.startswith("image:"):
-            current_image_model = key
-            break
-
-    # Check if legacy model_manager has a model loaded
-    has_legacy_model = model_manager.backend is not None
-
-    # In ondemand mode, if any model (text, image, etc.) is already loaded and we're requesting a different model,
-    # we should unload the current model first to free VRAM
-    needs_full_unload = (load_mode == "ondemand" and (current_image_model is not None or has_legacy_model))
-
-    # If we're requesting a text model and there's an image model loaded, unload it first
-    if needs_full_unload:
-        print(f"In ondemand mode - fully unloading current model before loading text model...")
-
-        # Full cleanup: remove all models from VRAM
-        for key in list(multi_model_manager.models.keys()):
-            model_to_cleanup = multi_model_manager.models.get(key)
-            if model_to_cleanup is not None:
-                print(f"Unloading '{key}' from VRAM...")
-                try:
-                    if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
-                        model_to_cleanup.cleanup()
-                except Exception as e:
-                    print(f"Warning during cleanup of '{key}': {e}")
-            del multi_model_manager.models[key]
-
-        # Also cleanup legacy model_manager
-        if model_manager.backend is not None:
-            print("Unloading legacy model_manager from VRAM...")
-            try:
-                if hasattr(model_manager.backend, 'unload'):
-                    model_manager.backend.unload()
-                elif hasattr(model_manager.backend, 'cleanup'):
-                    model_manager.backend.cleanup()
-            except Exception as e:
-                print(f"Warning during legacy model cleanup: {e}")
-            model_manager.backend = None
-
-        # Force garbage collection and clear CUDA cache
-        import gc
-        gc.collect()
-        try:
-            import torch
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-                torch.cuda.empty_cache()
-                print("CUDA cache cleared")
-        except:
-            pass
-
-        # Add delay to let VRAM settle
-        import time
-        time.sleep(1)
+    # In ondemand mode (no --load-all or --loadswap), if ANY model is loaded in VRAM
+    # and it's different from what we need, fully unload it first to free VRAM
+    if load_mode == "ondemand":
+        has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
+
+        if has_any_model:
+            # Check if the requested model is already loaded (no need to unload)
+            already_loaded = False
+            if requested_model and requested_model in multi_model_manager.models:
+                already_loaded = True
+            elif multi_model_manager.default_model and (
+                not requested_model or requested_model == "default" or
+                requested_model == multi_model_manager.default_model
+            ):
+                if multi_model_manager.default_model in multi_model_manager.models:
+                    already_loaded = True
+
+            if not already_loaded:
+                print(f"In ondemand mode - fully unloading current model(s) before loading text model '{requested_model}'...")
+
+                # Use centralized unload method
+                multi_model_manager.unload_all_models()
+
+                # Also cleanup legacy model_manager
+                if model_manager.backend is not None:
+                    print("Unloading legacy model_manager from VRAM...")
+                    try:
+                        model_manager.cleanup()
+                    except Exception as e:
+                        print(f"Warning during legacy model cleanup: {e}")
 
     # Try to get the appropriate model
     mm = multi_model_manager.get_model_for_request(requested_model)
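
For illustration, a client sequence that exercises this path (a sketch only: the
host, port, image route, and model names below are assumptions, not taken from
this repo):

    import requests

    BASE = "http://localhost:8000"  # assumed server address

    # 1) An image request loads an image model into VRAM (route name assumed).
    requests.post(f"{BASE}/v1/images/generations",
                  json={"model": "sdxl", "prompt": "a lighthouse at dusk"})

    # 2) A chat request for a text model then arrives; in ondemand mode the server
    #    should fully unload the image model before loading the text model.
    r = requests.post(f"{BASE}/v1/chat/completions",
                      json={"model": "llama-3-8b",
                            "messages": [{"role": "user", "content": "hello"}]})
    print(r.json())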
@@ -1754,53 +1727,29 @@ async def completions(request: CompletionRequest):
     from codai.api.state import get_load_mode
     load_mode = get_load_mode()
 
-    # Check if there's an image model already loaded in VRAM
-    current_image_model = None
-    for key in multi_model_manager.models.keys():
-        if key.startswith("image:"):
-            current_image_model = key
-            break
-
-    # In ondemand mode, if any model is already loaded, unload it first
-    needs_full_unload = (load_mode == "ondemand" and current_image_model is not None)
-
-    if needs_full_unload:
-        print(f"In ondemand mode - fully unloading current model before loading text model...")
-
-        # Full cleanup
-        for key in list(multi_model_manager.models.keys()):
-            model_to_cleanup = multi_model_manager.models.get(key)
-            if model_to_cleanup is not None:
-                print(f"Unloading '{key}' from VRAM...")
-                try:
-                    if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
-                        model_to_cleanup.cleanup()
-                except Exception as e:
-                    print(f"Warning during cleanup of '{key}': {e}")
-            del multi_model_manager.models[key]
-
-        # Also cleanup legacy model_manager
-        if model_manager.backend is not None:
-            print("Unloading legacy model_manager from VRAM...")
-            try:
-                if hasattr(model_manager.backend, 'unload'):
-                    model_manager.backend.unload()
-                elif hasattr(model_manager.backend, 'cleanup'):
-                    model_manager.backend.cleanup()
-            except:
-                pass
-            model_manager.backend = None
-
-        # Force garbage collection and clear CUDA cache
-        import gc
-        gc.collect()
-        try:
-            import torch
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-                torch.cuda.empty_cache()
-        except:
-            pass
+    # In ondemand mode, if ANY model is loaded and it's different from what we need, unload first
+    if load_mode == "ondemand":
+        has_any_model = len(multi_model_manager.models) > 0 or model_manager.backend is not None
+
+        if has_any_model:
+            already_loaded = False
+            if requested_model and requested_model in multi_model_manager.models:
+                already_loaded = True
+            elif multi_model_manager.default_model and (
+                not requested_model or requested_model == "default" or
+                requested_model == multi_model_manager.default_model
+            ):
+                if multi_model_manager.default_model in multi_model_manager.models:
+                    already_loaded = True
+
+            if not already_loaded:
+                print(f"In ondemand mode - fully unloading current model(s) before loading text model '{requested_model}'...")
+                multi_model_manager.unload_all_models()
+                if model_manager.backend is not None:
+                    try:
+                        model_manager.cleanup()
+                    except:
+                        pass
 
     # Try to get the appropriate model
     mm = multi_model_manager.get_model_for_request(requested_model)
@@ -711,6 +711,78 @@ class MultiModelManager:
         # Model not found - try to load it as a new model
         return self._load_model_by_name(requested_model)
 
+    def unload_all_models(self):
+        """
+        Fully unload ALL models from VRAM. Used in ondemand mode when switching
+        between different model types (e.g., text -> image or image -> text).
+
+        This handles all model types:
+        - ModelManager instances (have cleanup() method)
+        - Diffusers pipelines (need to be moved to CPU and deleted)
+        - stable-diffusion-cpp StableDiffusion instances
+        - Any other model objects
+        """
+        print("=== FULL VRAM CLEANUP: Unloading all models ===")
+
+        for key in list(self.models.keys()):
+            model_obj = self.models.get(key)
+            if model_obj is None:
+                continue
+
+            print(f"Unloading '{key}' from VRAM...")
+            try:
+                # Method 1: ModelManager with cleanup()
+                if hasattr(model_obj, 'cleanup') and callable(getattr(model_obj, 'cleanup')):
+                    model_obj.cleanup()
+                # Method 2: Diffusers pipeline (has 'to' method to move to CPU)
+                elif hasattr(model_obj, 'to') and callable(getattr(model_obj, 'to')):
+                    try:
+                        model_obj.to('cpu')
+                    except:
+                        pass
+                    del model_obj
+                # Method 3: Object with 'model' attribute (e.g., wrapper)
+                elif hasattr(model_obj, 'model') and model_obj.model is not None:
+                    if hasattr(model_obj.model, 'cleanup'):
+                        model_obj.model.cleanup()
+                    elif hasattr(model_obj.model, 'to'):
+                        try:
+                            model_obj.model.to('cpu')
+                        except:
+                            pass
+                    del model_obj
+                # Method 4: Just delete it
+                else:
+                    del model_obj
+            except Exception as e:
+                print(f"Warning during cleanup of '{key}': {e}")
+
+            # Remove from dict
+            if key in self.models:
+                del self.models[key]
+
+        # Reset tracking state
+        self.current_model_key = None
+        self.active_in_vram = None
+
+        # Force garbage collection
+        for _ in range(3):
+            gc.collect()
+
+        # Clear CUDA cache if available
+        try:
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+                torch.cuda.empty_cache()
+                print("CUDA cache cleared")
+        except:
+            pass
+
+        # Small delay to let GPU memory settle
+        time.sleep(1)
+
+        print("=== FULL VRAM CLEANUP: Complete ===")
+
     def add_model(self, key: str, manager: ModelManager):
         """Add a model manager for a specific key."""
         self.models[key] = manager
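
One subtlety in unload_all_models(): del model_obj only removes the local name; the
object (and its VRAM) stays alive while self.models[key] still references it, which
is why the dict entry is deleted afterwards and gc.collect() is forced. A standalone
sketch of that reference behavior (generic Python, not specific to this repo):

    import gc

    class Pipeline:  # stands in for a loaded model object
        pass

    models = {"image:sdxl": Pipeline()}
    obj = models["image:sdxl"]

    del obj                        # drops the local name only; the dict still holds
                                   # a strong reference, so nothing is freed yet
    assert "image:sdxl" in models

    del models["image:sdxl"]       # last strong reference removed; the object can
    gc.collect()                   # now be reclaimed, and only then can a backend
                                   # like torch return its memory to the allocator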