Commit fe8b5ea4 authored by Your Name

Add support for multiple context values (--n-ctx, --audio-ctx, --image-ctx)

- Changed context arguments to use action='append' allowing multiple values
- Added get_ctx_by_index() helper function for index-based context retrieval
- Updated text, audio, and image model loading to use indexed context values
- Users can now specify different context sizes per model
parent ccd7cce5
......@@ -2523,26 +2523,32 @@ class MultiModelManager:
print(f"Warning during cleanup of '{key}': {e}")
del self.models[key]
# Check if requested model is already loaded - if so, reuse it
if requested_model in self.models:
self.current_model_key = requested_model
return self.models[requested_model]
# Check if requested model is in our config (means it was registered but not loaded)
if self.load_mode == "ondemand" and requested_model in self.config:
# This is a text model that's registered but not loaded
# We need to swap: unload current model and load this one
# Check if we have a different model currently loaded
if self.current_model_key and self.current_model_key in self.models:
# Unload the current model from VRAM
current_model = self.models[self.current_model_key]
print(f"ON-DEMAND SWAP: Unloading model '{self.current_model_key}' from VRAM to load '{requested_model}'")
# Always cleanup any loaded model (unless it's the same model we're about to load)
for key in list(self.models.keys()):
if key != requested_model:
model_to_cleanup = self.models.get(key)
if model_to_cleanup is not None:
print(f"Unloading '{key}' from VRAM to load '{requested_model}'")
try:
if hasattr(current_model, 'cleanup') and callable(getattr(current_model, 'cleanup')):
current_model.cleanup()
elif hasattr(current_model, 'model') and current_model.model is not None:
if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
model_to_cleanup.cleanup()
elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
# Handle ModelManager objects
if hasattr(current_model.model, 'cleanup'):
current_model.model.cleanup()
if hasattr(model_to_cleanup.model, 'cleanup'):
model_to_cleanup.model.cleanup()
except Exception as e:
print(f"ON-DEMAND SWAP: Warning during cleanup of '{self.current_model_key}': {e}")
del self.models[self.current_model_key]
print(f"Warning during cleanup of '{key}': {e}")
del self.models[key]
# Load the new model on-demand
print(f"ON-DEMAND SWAP: Loading model '{requested_model}' into VRAM")
......@@ -2600,21 +2606,21 @@ class MultiModelManager:
if requested_model.lower() in short_name.lower() or short_name.lower() in requested_model.lower():
# Found a matching model in config, try to load it
if model_name not in self.models:
# Check if we have a different model currently loaded
if self.current_model_key and self.current_model_key in self.models:
# Unload the current model from VRAM
current_model = self.models[self.current_model_key]
print(f"ON-DEMAND SWAP: Unloading model '{self.current_model_key}' from VRAM to load '{model_name}'")
# Always cleanup any loaded model (unless it's the same model we're about to load)
for key in list(self.models.keys()):
if key != model_name:
model_to_cleanup = self.models.get(key)
if model_to_cleanup is not None:
print(f"Unloading '{key}' from VRAM to load '{model_name}'")
try:
if hasattr(current_model, 'cleanup') and callable(getattr(current_model, 'cleanup')):
current_model.cleanup()
elif hasattr(current_model, 'model') and current_model.model is not None:
# Handle ModelManager objects
if hasattr(current_model.model, 'cleanup'):
current_model.model.cleanup()
if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
model_to_cleanup.cleanup()
elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
if hasattr(model_to_cleanup.model, 'cleanup'):
model_to_cleanup.model.cleanup()
except Exception as e:
print(f"ON-DEMAND SWAP: Warning during cleanup of '{self.current_model_key}': {e}")
del self.models[self.current_model_key]
print(f"Warning during cleanup of '{key}': {e}")
del self.models[key]
# Load the new model on-demand
print(f"ON-DEMAND SWAP: Loading model '{model_name}' into VRAM")
......@@ -3794,8 +3800,10 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# If no cached image model found, need to load one - first cleanup any existing models
if sd_model is None:
# Check if there's a text model loaded and unload it to free VRAM
# Cleanup ALL models except the one we're about to load
for key in list(multi_model_manager.models.keys()):
# Skip image models
# Skip the image model we'll be loading (if we find it later)
# For now, cleanup all other models
if key.startswith("image:"):
continue
# Unload any other model (text, audio, etc.) to free VRAM
......@@ -4917,8 +4925,9 @@ def parse_args():
parser.add_argument(
"--n-ctx",
type=int,
default=2048,
help="Context window size (Vulkan backend only, default: 2048)",
action="append",
default=None,
help="Context window size (Vulkan backend). Can be specified multiple times, one per --model.",
)
parser.add_argument(
"--vulkan-device",
......@@ -5057,8 +5066,9 @@ def parse_args():
parser.add_argument(
"--audio-ctx",
type=int,
default=480000,
help="Audio model context size in milliseconds (default: 480000 = 30 seconds for Whisper)",
action="append",
default=None,
help="Audio model context size in milliseconds. Can be specified multiple times, one per --audio-model.",
)
parser.add_argument(
"--audio-offload",
......@@ -5100,8 +5110,9 @@ def parse_args():
parser.add_argument(
"--image-ctx",
type=int,
default=2048,
help="Vision model context size (default: 2048)",
action="append",
default=None,
help="Image model context size. Can be specified multiple times, one per --image-model.",
)
parser.add_argument(
"--image-offload",
......@@ -5280,6 +5291,13 @@ def main():
# Get model names from args - support multiple models
model_names = args.model if args.model else []
# Helper function to get config value by index with fallback
def get_ctx_by_index(ctx_list, index, default):
    """Return the context size configured for the model at position *index*.

    The context options are declared with ``action="append"`` and
    ``default=None``, so each one arrives either as a list holding one value
    per occurrence on the command line, or as ``None`` when the option was
    never given.

    Args:
        ctx_list: Values collected for one context option, or ``None``.
        index: Zero-based position of the model whose value is wanted.
        default: Value returned when no entry exists for *index*.

    Returns:
        ``ctx_list[index]`` when that entry exists, otherwise *default*.
    """
    # The explicit lower bound stops a negative index from wrapping around
    # and silently picking a value from the end of the list.
    if ctx_list and 0 <= index < len(ctx_list):
        return ctx_list[index]
    return default
# Validate: must have at least one model specified
audio_models = args.audio_model if args.audio_model else []
image_models = args.image_model if args.image_model else []
......@@ -5349,7 +5367,7 @@ def main():
'offload_strategy': args.offload_strategy,
'max_gpu_percent': args.max_gpu_percent,
'n_gpu_layers': args.n_gpu_layers,
'n_ctx': args.n_ctx,
'n_ctx': get_ctx_by_index(args.n_ctx, 0, 2048),
'main_gpu': args.vulkan_device,
'single_gpu': args.vulkan_single_gpu,
'verbose': verbose,
......@@ -5803,9 +5821,9 @@ def main():
# Register all audio models
print(f"DEBUG: Registering audio models: {audio_models}")
for audio_m in audio_models:
for idx, audio_m in enumerate(audio_models):
multi_model_manager.set_audio_model(audio_m, {
'ctx': args.audio_ctx,
'ctx': get_ctx_by_index(args.audio_ctx, idx, 0),
'offload': args.audio_offload,
})
print(f"DEBUG: After registration, audio_models in manager: {multi_model_manager.audio_models}")
......@@ -6023,7 +6041,7 @@ def main():
if image_models:
print(f"\nImage generation model(s): {image_models}")
multi_model_manager.set_image_model(image_models[0], {
'ctx': args.image_ctx,
'ctx': get_ctx_by_index(args.image_ctx, 0, 0),
'offload': args.image_offload,
'llm_path': args.llm_path,
'vae_path': args.vae_path,
......@@ -6034,9 +6052,9 @@ def main():
'cfg_scale': args.image_cfg_scale,
})
# Register all image models
for img_m in image_models[1:]:
for idx, img_m in enumerate(image_models[1:], start=1):
multi_model_manager.set_image_model(img_m, {
'ctx': args.image_ctx,
'ctx': get_ctx_by_index(args.image_ctx, idx, 0),
'offload': args.image_offload,
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment