Fix first occurrence of balanced offload strategy (line 9384)

The balanced offload strategy existed in two places in the code.
The first occurrence at line 9384 was still using the old logic.
Now both occurrences have the improved VRAM estimation with:
- LoRA overhead accounting
- Inference overhead (30%)
- Conservative 70% threshold
- OOM fallback protection
parent 15e14a57
......@@ -9390,22 +9390,42 @@ def main(args):
# Get available VRAM
vram_total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
vram_allocated = torch.cuda.memory_allocated() / (1024**3)
vram_reserved = torch.cuda.memory_reserved() / (1024**3)
vram_available = vram_total - vram_allocated
# Estimate model size from VRAM requirements
# Add overhead for LoRA, inference, and model components not in estimate
model_vram_est = parse_vram_estimate(m_info.get("vram", "~10 GB"))
# If model fits comfortably in available VRAM (with 15% buffer), load fully
if model_vram_est < vram_available * 0.85:
print(f" 📦 Balanced mode: Model (~{model_vram_est:.1f}GB) fits in VRAM ({vram_available:.1f}GB available)")
# Account for various overheads:
# - LoRA weights add ~2-4GB
# - Inference activation memory needs ~20-30% extra
# - Text encoder, VAE, scheduler not always in estimate
is_lora = m_info.get("is_lora", False)
lora_overhead = 4.0 if is_lora else 0.0 # LoRA adds significant overhead
inference_overhead = model_vram_est * 0.3 # 30% for activations during inference
total_vram_needed = model_vram_est + lora_overhead + inference_overhead
# Use conservative 70% threshold (30% safety buffer) for "balanced"
# This ensures we don't OOM during inference
vram_threshold = vram_available * 0.70
if total_vram_needed < vram_threshold:
print(f" 📦 Balanced mode: Model (~{total_vram_needed:.1f}GB needed) fits in VRAM ({vram_available:.1f}GB available)")
print(f" Loading fully to GPU (no offloading)")
pipe = pipe.to("cuda")
try:
pipe = pipe.to("cuda")
except torch.cuda.OutOfMemoryError:
# Fallback if moving to GPU fails
print(f" ⚠️ OOM when loading to GPU, falling back to model CPU offload")
torch.cuda.empty_cache()
gc.collect()
pipe.enable_model_cpu_offload()
else:
# Model too large, use sequential offloading but only for necessary layers
print(f" 📦 Balanced mode: Model (~{model_vram_est:.1f}GB) exceeds VRAM ({vram_available:.1f}GB available)")
print(f" Using selective offloading to maximize VRAM usage")
pipe.enable_sequential_cpu_offload()
# Model too large, use model CPU offload (better than sequential for most cases)
print(f" 📦 Balanced mode: Model (~{total_vram_needed:.1f}GB needed) exceeds safe VRAM ({vram_available:.1f}GB available)")
print(f" Using model CPU offload to prevent OOM")
pipe.enable_model_cpu_offload()
else:
pipe.to("cuda" if torch.cuda.is_available() else "cpu")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment