Add automatic OOM handling with progressive VRAM reduction fallback for NVIDIA backend

parent 320ca0e7
......@@ -571,8 +571,24 @@ class NvidiaBackend(ModelBackend):
return max_memory
def _try_load_model(self, model_name: str, load_kwargs: dict, device: str) -> Optional[any]:
"""Try to load model with given settings, return None on OOM."""
import torch
from transformers import AutoModelForCausalLM
try:
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
if device == "cpu" and load_kwargs.get('device_map') is None:
model = model.to(device)
return model
except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
error_msg = str(e).lower()
if "out of memory" in error_msg or "cuda" in error_msg or "oom" in error_msg:
return None
raise
def load_model(self, model_name: str, **kwargs) -> None:
"""Load the model using HuggingFace Transformers."""
"""Load the model using HuggingFace Transformers with automatic OOM handling."""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
......@@ -606,16 +622,6 @@ class NvidiaBackend(ModelBackend):
# Prepare model loading arguments
load_kwargs = {'trust_remote_code': True}
# Setup memory management: GPU (93%) → CPU (limit) → Disk
if self.device == "cuda":
max_memory = self._get_gpu_memory_map()
load_kwargs['max_memory'] = max_memory
load_kwargs['device_map'] = 'auto'
print(f" Memory strategy: GPU (93% VRAM) → CPU → Disk")
else:
# CPU-only mode
load_kwargs['device_map'] = None
if load_in_4bit or load_in_8bit:
try:
import bitsandbytes as bnb
......@@ -635,25 +641,90 @@ class NvidiaBackend(ModelBackend):
if offload_dir:
os.makedirs(offload_dir, exist_ok=True)
load_kwargs['offload_folder'] = offload_dir
print(f"Disk offload directory: {offload_dir} (used only when GPU+CPU full)")
# Add Flash Attention 2 if enabled
if self.use_flash_attn and self.flash_attn_available:
load_kwargs['attn_implementation'] = "flash_attention_2"
print("Using Flash Attention 2")
# Load model
self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
if self.device == "cpu" and load_kwargs.get('device_map') is None:
self.model = self.model.to(self.device)
# Try loading with automatic fallback on OOM
model = None
vram_percentages = [0.93, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
for vram_pct in vram_percentages:
if self.device != "cuda":
# CPU-only mode
load_kwargs['device_map'] = None
print("Loading model in CPU-only mode...")
model = self._try_load_model(model_name, load_kwargs, self.device)
if model is not None:
break
# GPU mode with varying VRAM limits
if vram_pct > 0:
max_memory = self._get_gpu_memory_map_with_limit(vram_pct)
load_kwargs['max_memory'] = max_memory
load_kwargs['device_map'] = 'auto'
print(f"\nTrying with GPU limit: {vram_pct*100:.0f}% VRAM")
if offload_dir:
print(f" Disk offload directory: {offload_dir}")
model = self._try_load_model(model_name, load_kwargs, self.device)
if model is not None:
print(f" ✓ Model loaded successfully with {vram_pct*100:.0f}% GPU VRAM limit")
if vram_pct < 0.93:
print(f" (Reduced from 93% due to memory constraints)")
break
else:
print(f" ✗ Out of memory with {vram_pct*100:.0f}% GPU VRAM, trying lower limit...")
# Clear CUDA cache before retry
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# Last resort: CPU-only mode with offloading
print("\nFalling back to CPU-only mode (no GPU layers)...")
load_kwargs['max_memory'] = {0: 0, 'cpu': int((manual_ram_gb or 48) * 1e9)}
load_kwargs['device_map'] = 'auto'
model = self._try_load_model(model_name, load_kwargs, "cpu")
if model is not None:
print(" ✓ Model loaded successfully on CPU")
break
if model is None:
raise RuntimeError("Failed to load model: Out of memory even with minimum GPU usage")
self.model = model
self.model.eval()
self.model_name = model_name
print(f"\nModel loaded successfully")
print(f"Model device: {next(self.model.parameters()).device}")
def _get_gpu_memory_map_with_limit(self, vram_fraction: float) -> Dict:
"""Get max_memory dict with specified VRAM fraction limit."""
import torch
max_memory = {}
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
total_vram = props.total_memory
usable_vram = int(total_vram * vram_fraction)
max_memory[i] = usable_vram
# CPU memory
manual_ram_gb = getattr(self, '_pending_ram_gb', None)
if manual_ram_gb:
max_memory['cpu'] = int(manual_ram_gb * 1e9)
else:
import psutil
available_ram = psutil.virtual_memory().available
usable_ram = max(0, available_ram - int(4e9))
max_memory['cpu'] = usable_ram
return max_memory
def format_messages(self, messages: List[ChatMessage]) -> str:
"""Format messages into a prompt string."""
formatted = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment