Add fallback for models that don't support load_in_4bit/load_in_8bit quantization

Modify _try_load_model() to catch TypeError when quantization arguments
are not supported by the model class. When this happens, the method now:
1. Warns the user about unsupported quantization
2. Retries loading the model without quantization arguments
3. Returns the loaded model if the retry succeeds (or None if the retry runs out of memory)

This fixes issues with models like Qwen3.5 that don't support
bitsandbytes quantization.
parent 33a7e421
......@@ -572,7 +572,7 @@ class NvidiaBackend(ModelBackend):
return max_memory
def _try_load_model(self, model_name: str, load_kwargs: dict, device: str) -> Optional[any]:
"""Try to load model with given settings, return None on OOM."""
"""Try to load model with given settings, return None on OOM or retry without quantization if unsupported."""
import torch
from transformers import AutoModelForCausalLM
......@@ -586,6 +586,35 @@ class NvidiaBackend(ModelBackend):
if "out of memory" in error_msg or "cuda" in error_msg or "oom" in error_msg:
return None
raise
except TypeError as e:
error_msg = str(e).lower()
# Check if the error is about unsupported quantization arguments
if "load_in_4bit" in error_msg or "load_in_8bit" in error_msg or "unexpected keyword argument" in error_msg:
# Check if we have quantization args that need to be removed
if 'load_in_4bit' in load_kwargs or 'load_in_8bit' in load_kwargs:
print(f"Warning: Model does not support bitsandbytes quantization (load_in_4bit/load_in_8bit)")
print("Retrying without quantization...")
# Create a copy of load_kwargs without quantization args
retry_kwargs = load_kwargs.copy()
retry_kwargs.pop('load_in_4bit', None)
retry_kwargs.pop('load_in_8bit', None)
# Retry loading without quantization
try:
model = AutoModelForCausalLM.from_pretrained(model_name, **retry_kwargs)
if device == "cpu" and retry_kwargs.get('device_map') is None:
model = model.to(device)
print("Model loaded successfully without quantization")
return model
except (RuntimeError, torch.cuda.OutOfMemoryError) as e2:
error_msg2 = str(e2).lower()
if "out of memory" in error_msg2 or "cuda" in error_msg2 or "oom" in error_msg2:
return None
raise
except TypeError:
# If it still fails with TypeError, re-raise the original error
raise e
# Re-raise if not a quantization-related error
raise
def _is_moe_model(self, model_name: str) -> bool:
"""Check if model is a MoE (Mixture of Experts) model which needs more VRAM headroom."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment