Commit f498093e authored by Your Name

Add GGUF model support and extended stable-diffusion-cpp options

- Detect GGUF models and skip diffusers for them, using stable-diffusion-cpp instead
- Add HuggingFace model ID resolution for GGUF files
- Add support for VAE, LLM, T5XXL paths from CLI args
- Add clip_on_cpu support for VRAM savings
- Use all available CPU cores instead of hardcoded 4 threads
parent 2dfbce90
......@@ -3618,10 +3618,19 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
pass
# Try diffusers first (torch-based, best quality for NVIDIA)
# Skip if it's a GGUF model (those need stable-diffusion-cpp)
try:
import torch
from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
# Check if this is a GGUF model - skip diffusers for those
is_gguf_model = (model_to_use.endswith('.gguf') or 'gguf' in model_to_use.lower() or
(model_to_use.startswith('http') and '.gguf' in model_to_use))
if is_gguf_model:
print(f"GGUF model detected ({model_to_use}), skipping diffusers, using stable-diffusion-cpp...")
raise Exception("GGUF model - use stable-diffusion-cpp instead")
# Determine model key
model_key = f"image:{model_to_use}"
pipeline = multi_model_manager.get_model(model_key)
......@@ -3796,6 +3805,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
from stable_diffusion_cpp import StableDiffusion
# Check if model_to_use is a URL and get cached path
# Also handle HuggingFace model IDs that need to be resolved
model_path = None
if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
cached_path = get_cached_model_path(model_to_use)
......@@ -3810,6 +3820,40 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
print(f"Downloaded to: {model_path}")
elif os.path.isfile(model_to_use):
model_path = model_to_use
else:
# Try to resolve as HuggingFace model ID
print(f"Trying to resolve as HuggingFace model ID: {model_to_use}")
try:
from huggingface_hub import hf_hub_download, list_repo_files
# Parse model name (format: "org/model" or "org/model/filename.gguf")
parts = model_to_use.split('/')
if len(parts) >= 2:
repo_id = f"{parts[0]}/{parts[1]}"
# First check if there's a cached GGUF file for this model
# Try common GGUF file patterns
files = list_repo_files(repo_id)
gguf_files = [f for f in files if f.endswith('.gguf')]
if gguf_files:
# Try to find a cached version first
for gguf_file in gguf_files:
# Construct potential URL and check cache
potential_url = f"https://huggingface.co/{repo_id}/resolve/main/{gguf_file}"
cached = get_cached_model_path(potential_url)
if cached:
model_path = cached
print(f"Using cached GGUF model: {model_path}")
break
# If not cached, download the first GGUF file
if not model_path:
print(f"Downloading GGUF model from HF: {gguf_files[0]}")
model_path = hf_hub_download(repo_id=repo_id, filename=gguf_files[0])
print(f"Downloaded to: {model_path}")
except Exception as e:
print(f"Could not resolve as HuggingFace model: {e}")
if model_path is None:
print("Warning: Could not resolve sd.cpp model path")
......@@ -3829,11 +3873,62 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
else:
print(f"Using Vulkan backend for sd.cpp image generation")
sd_model = StableDiffusion(
model_path=model_path,
vae_path=None,
n_threads=4,
)
# Build kwargs for stable-diffusion-cpp with CLI args
sd_kwargs = {'diffusion_model_path': model_path}
# Add VAE path from CLI args if provided
vae_path = getattr(global_args, 'vae_path', None)
if vae_path:
# Check if it's a URL and download if needed
if vae_path.startswith('http://') or vae_path.startswith('https://'):
cached = get_cached_model_path(vae_path)
if cached:
sd_kwargs['vae_path'] = cached
print(f"Using cached VAE model: {cached}")
else:
cache_dir = get_model_cache_dir()
sd_kwargs['vae_path'] = download_model(vae_path, cache_dir)
else:
sd_kwargs['vae_path'] = vae_path
# Add LLM/CLIP path from CLI args if provided
llm_path = getattr(global_args, 'llm_path', None)
if llm_path:
if llm_path.startswith('http://') or llm_path.startswith('https://'):
cached = get_cached_model_path(llm_path)
if cached:
sd_kwargs['llm_path'] = cached
print(f"Using cached LLM model: {cached}")
else:
cache_dir = get_model_cache_dir()
sd_kwargs['llm_path'] = download_model(llm_path, cache_dir)
else:
sd_kwargs['llm_path'] = llm_path
# Add T5XXL path from CLI args if provided
t5xxl_path = getattr(global_args, 't5xxl_path', None)
if t5xxl_path:
if t5xxl_path.startswith('http://') or t5xxl_path.startswith('https://'):
cached = get_cached_model_path(t5xxl_path)
if cached:
sd_kwargs['t5xxl_path'] = cached
print(f"Using cached T5XXL model: {cached}")
else:
cache_dir = get_model_cache_dir()
sd_kwargs['t5xxl_path'] = download_model(t5xxl_path, cache_dir)
else:
sd_kwargs['t5xxl_path'] = t5xxl_path
# Add clip_on_cpu if specified
if getattr(global_args, 'clip_on_cpu', False):
sd_kwargs['keep_clip_on_cpu'] = True
print(f"DEBUG: Running CLIP on CPU to save VRAM (keep_clip_on_cpu=True)")
# Use all available CPU cores
import psutil
sd_kwargs['n_threads'] = psutil.cpu_count()
sd_model = StableDiffusion(**sd_kwargs)
print(f"Using stable-diffusion-cpp-python for image generation")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment