Commit bf1d3f52 authored by Your Name

Fix offload-strategy parameter passing to CUDA backend

- Add offload_strategy to kwargs in _load_default_model and _load_model_by_name
- Fix parameter name: ram -> manual_ram_gb to match backend expectation
- Also pass load_in_4bit, load_in_8bit, and max_gpu_percent
parent beded066
......@@ -527,11 +527,19 @@ class MultiModelManager:
if hasattr(global_args, 'offload_dir'):
kwargs['offload_dir'] = global_args.offload_dir
if hasattr(global_args, 'ram'):
kwargs['ram'] = global_args.ram
kwargs['manual_ram_gb'] = global_args.ram
if hasattr(global_args, 'flash_attn'):
kwargs['flash_attn'] = global_args.flash_attn
if hasattr(global_args, 'no_ram'):
kwargs['no_ram'] = global_args.no_ram
if hasattr(global_args, 'offload_strategy'):
kwargs['offload_strategy'] = global_args.offload_strategy
if hasattr(global_args, 'load_in_4bit'):
kwargs['load_in_4bit'] = global_args.load_in_4bit
if hasattr(global_args, 'load_in_8bit'):
kwargs['load_in_8bit'] = global_args.load_in_8bit
if hasattr(global_args, 'max_gpu_percent'):
kwargs['max_gpu_percent'] = global_args.max_gpu_percent
print(f"Loading default model on demand: {self.default_model}")
model_manager.load_model(self.default_model, backend_type=backend_type, **kwargs)
......@@ -578,11 +586,19 @@ class MultiModelManager:
if hasattr(global_args, 'offload_dir'):
kwargs['offload_dir'] = global_args.offload_dir
if hasattr(global_args, 'ram'):
kwargs['ram'] = global_args.ram
kwargs['manual_ram_gb'] = global_args.ram
if hasattr(global_args, 'flash_attn'):
kwargs['flash_attn'] = global_args.flash_attn
if hasattr(global_args, 'no_ram'):
kwargs['no_ram'] = global_args.no_ram
if hasattr(global_args, 'offload_strategy'):
kwargs['offload_strategy'] = global_args.offload_strategy
if hasattr(global_args, 'load_in_4bit'):
kwargs['load_in_4bit'] = global_args.load_in_4bit
if hasattr(global_args, 'load_in_8bit'):
kwargs['load_in_8bit'] = global_args.load_in_8bit
if hasattr(global_args, 'max_gpu_percent'):
kwargs['max_gpu_percent'] = global_args.max_gpu_percent
print(f"Loading model on demand: {model_name}")
model_manager.load_model(model_name, backend_type=backend_type, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment