Commit 096b75d2 authored by Your Name's avatar Your Name

Add OOM handling and sequential offload for diffusers

- Enable sequential CPU offload if --offload-strategy or --offload-dir is specified
- Add retry logic: on OOM, retry with attention_slicing, then with sequential_offload
- Clear CUDA cache between retry attempts
parent 782612ea
......@@ -3752,35 +3752,74 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
torch_dtype = precision_map.get(precision, torch.float32)
print(f"Using precision: {precision} ({torch_dtype})")
# Try to load as Stable Diffusion XL first
try:
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=torch_dtype,
use_safetensors=True,
)
except Exception:
# Try generic diffusion pipeline
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=torch_dtype,
use_safetensors=True,
)
# Check if offload strategy is specified (for auto-OOM handling)
offload_strategy = getattr(global_args, 'offload_strategy', None)
offload_dir = getattr(global_args, 'offload_dir', None)
use_sequential_offload = offload_strategy is not None or offload_dir is not None
# Move to GPU if available
if torch.cuda.is_available():
pipeline = pipeline.to("cuda")
else:
pipeline = pipeline.to("cpu")
# Track loading attempts for OOM handling
load_attempt = 0
max_attempts = 3
pipeline = None
# Enable attention slicing for lower memory usage
if torch.cuda.is_available():
pipeline.enable_attention_slicing()
while pipeline is None and load_attempt < max_attempts:
try:
load_attempt += 1
print(f"Loading attempt {load_attempt}/{max_attempts}...")
# Try to load as Stable Diffusion XL first
try:
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=torch_dtype,
use_safetensors=True,
)
except Exception:
# Try generic diffusion pipeline
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=torch_dtype,
use_safetensors=True,
)
# Apply memory optimizations based on attempt
if torch.cuda.is_available():
if load_attempt >= 2:
# Second attempt: enable attention slicing
print("Enabling attention slicing for lower VRAM usage...")
pipeline.enable_attention_slicing()
if load_attempt >= 3 or use_sequential_offload:
# Third attempt or offload requested: enable sequential CPU offload
print("Enabling sequential CPU offload for lower VRAM usage...")
pipeline.enable_sequential_cpu_offload()
else:
# First attempt: try regular GPU
pipeline = pipeline.to("cuda")
else:
pipeline = pipeline.to("cpu")
except Exception as load_error:
error_msg = str(load_error).lower()
is_oom = any(x in error_msg for x in ['out of memory', 'oom', 'cuda error', 'cudamalloc'])
if is_oom and load_attempt < max_attempts:
print(f"OOM during model loading: {load_error}")
print(f"Retrying with more aggressive memory optimization...")
pipeline = None # Reset for retry
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
raise load_error
# Enable VAE tiling if requested (for lower VRAM usage)
if getattr(global_args, 'vae_tiling', False):
print("Enabling VAE tiling for lower VRAM usage...")
pipeline.enable_vae_tiling()
try:
pipeline.enable_vae_tiling()
except Exception as e:
print(f"Warning: Could not enable VAE tiling: {e}")
multi_model_manager.add_model(model_key, pipeline)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment