Commit 08496f1f authored by Your Name's avatar Your Name

Add VRAM cleanup before loading image models

- Clean up any existing models (text, audio, etc.) from VRAM before loading
  image models to prevent out-of-memory errors when switching between model types
- Applied to both diffusers and stable-diffusion-cpp loading paths
parent 7cbe5355
...@@ -3629,6 +3629,25 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -3629,6 +3629,25 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Try diffusers first (torch-based, best quality for NVIDIA) # Try diffusers first (torch-based, best quality for NVIDIA)
# Skip if it's a GGUF model (those need stable-diffusion-cpp) # Skip if it's a GGUF model (those need stable-diffusion-cpp)
# First, clean up any non-image models to free VRAM
for key in list(multi_model_manager.models.keys()):
# Skip image models
if key.startswith("image:"):
continue
# Unload any other model (text, audio, etc.) to free VRAM
model_to_cleanup = multi_model_manager.models.get(key)
if model_to_cleanup is not None:
print(f"Unloading '{key}' from VRAM to make room for diffusers image model")
try:
if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
model_to_cleanup.cleanup()
elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
if hasattr(model_to_cleanup.model, 'cleanup'):
model_to_cleanup.model.cleanup()
except Exception as e:
print(f"Warning during cleanup of '{key}': {e}")
del multi_model_manager.models[key]
try: try:
import torch import torch
from diffusers import StableDiffusionXLPipeline, DiffusionPipeline from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
...@@ -3740,6 +3759,27 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -3740,6 +3759,27 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
except ImportError: except ImportError:
pass pass
# If no cached image model was found, one must be loaded - first clean up any existing non-image models
if sd_model is None:
# Unload every non-image model (text, audio, etc.) that is still registered, to free VRAM
for key in list(multi_model_manager.models.keys()):
# Skip image models
if key.startswith("image:"):
continue
# Unload any other model (text, audio, etc.) to free VRAM
model_to_cleanup = multi_model_manager.models.get(key)
if model_to_cleanup is not None:
print(f"Unloading '{key}' from VRAM to make room for image model")
try:
if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
model_to_cleanup.cleanup()
elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
if hasattr(model_to_cleanup.model, 'cleanup'):
model_to_cleanup.model.cleanup()
except Exception as e:
print(f"Warning during cleanup of '{key}': {e}")
del multi_model_manager.models[key]
if sd_model is not None: if sd_model is not None:
# Check if it's a stable-diffusion-cpp model (has generate method from sd.cpp) # Check if it's a stable-diffusion-cpp model (has generate method from sd.cpp)
try: try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment