Commit 7c150a4d authored by Your Name's avatar Your Name

Add aggressive VRAM cleanup for model switching

- Added _aggressive_vram_cleanup method to properly clear VRAM
- Moves model to CPU before deletion
- Deletes pipeline, vae, text_encoder, tokenizer explicitly
- Multiple rounds of gc.collect()
- Uses torch.cuda.synchronize() before clearing cache
- Increased delay to 5 seconds after cleanup
parent 804dac03
......@@ -2164,6 +2164,78 @@ class ModelManager:
self.backend: Optional[ModelBackend] = None
self.backend_type: Optional[str] = None
self.tool_parser = ToolCallParser()
def _aggressive_vram_cleanup(self, model_manager):
"""
Aggressively cleanup VRAM when switching between different model types.
This is more thorough than a simple cleanup() call.
"""
import gc
import time
try:
import torch
# First, try to move model to CPU if it has a model attribute
if hasattr(model_manager, 'model') and model_manager.model is not None:
model = model_manager.model
# If it's a diffusers pipeline, try to move to CPU first
if hasattr(model, 'to'):
try:
model.to('cpu')
except:
pass
# Delete the model
del model
# Also handle backend directly if it's different
if hasattr(model_manager, 'backend') and model_manager.backend is not None:
backend = model_manager.backend
if hasattr(backend, 'model') and backend.model is not None:
model = backend.model
if hasattr(model, 'to'):
try:
model.to('cpu')
except:
pass
del model
if hasattr(backend, 'pipeline') and backend.pipeline is not None:
del backend.pipeline
if hasattr(backend, 'vae') and backend.vae is not None:
del backend.vae
if hasattr(backend, 'text_encoder') and backend.text_encoder is not None:
del backend.text_encoder
if hasattr(backend, 'tokenizer') and backend.tokenizer is not None:
del backend.tokenizer
# Force multiple rounds of garbage collection
for _ in range(3):
gc.collect()
# Clear PyTorch cache
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
# Add delay to allow Vulkan to release memory
time.sleep(2)
except Exception as e:
print(f"Warning during aggressive VRAM cleanup: {e}")
finally:
# Try to cleanup the model manager itself
try:
if hasattr(model_manager, 'cleanup'):
model_manager.cleanup()
except:
pass
def load_model(self, model_name: str, backend_type: str = "auto", **kwargs):
"""
......@@ -2607,29 +2679,12 @@ class MultiModelManager:
model_to_cleanup = self.models.get(key)
if model_to_cleanup is not None:
print(f"Unloading image model '{key}' from VRAM to reload text model")
try:
if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
model_to_cleanup.cleanup()
elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
if hasattr(model_to_cleanup.model, 'cleanup'):
model_to_cleanup.model.cleanup()
except Exception as e:
print(f"Warning during cleanup of '{key}': {e}")
self._aggressive_vram_cleanup(model_to_cleanup)
del self.models[key]
# Force garbage collection and clear GPU cache
import gc
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
except:
pass
# Add a longer delay to allow VRAM to be freed (Vulkan needs more time)
import time
time.sleep(3)
time.sleep(5)
# Now try to reload the default model
try:
......@@ -2751,14 +2806,7 @@ class MultiModelManager:
model_to_cleanup = self.models.get(key)
if model_to_cleanup is not None:
print(f"Unloading image model '{key}' from VRAM to make room for text model")
try:
if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
model_to_cleanup.cleanup()
elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
if hasattr(model_to_cleanup.model, 'cleanup'):
model_to_cleanup.model.cleanup()
except Exception as e:
print(f"Warning during cleanup of '{key}': {e}")
self._aggressive_vram_cleanup(model_to_cleanup)
del self.models[key]
# Force garbage collection and clear GPU cache
......@@ -2771,9 +2819,9 @@ class MultiModelManager:
except:
pass
# Add a small delay to allow VRAM to be freed
# Add a longer delay to allow VRAM to be freed (Vulkan needs more time)
import time
time.sleep(1)
time.sleep(5)
# Check if requested model is already loaded - if so, reuse it
if requested_model in self.models:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment