Commit 2b8e8b37 authored by Your Name's avatar Your Name

Add URL download and caching for --clip-l-path and --vae-path

When a URL is passed to --clip-l-path or --vae-path, the model is now automatically downloaded and cached.
parent d784e7e9
...@@ -4653,10 +4653,32 @@ def main(): ...@@ -4653,10 +4653,32 @@ def main():
t5xxl_path = None t5xxl_path = None
vae_path = None vae_path = None
# Use CLI arguments if provided # Use CLI arguments if provided, download and cache if URL
if args.clip_l_path: if args.clip_l_path:
# Check if it's a URL and download if needed
if args.clip_l_path.startswith('http://') or args.clip_l_path.startswith('https://'):
cached = get_cached_model_path(args.clip_l_path)
if cached:
clip_l_path = cached
print(f"Using cached CLIP model: {clip_l_path}")
else:
cache_dir = get_model_cache_dir()
clip_l_path = download_model(args.clip_l_path, cache_dir)
print(f"Downloaded CLIP model to: {clip_l_path}")
else:
clip_l_path = args.clip_l_path clip_l_path = args.clip_l_path
if args.vae_path: if args.vae_path:
# Check if it's a URL and download if needed
if args.vae_path.startswith('http://') or args.vae_path.startswith('https://'):
cached = get_cached_model_path(args.vae_path)
if cached:
vae_path = cached
print(f"Using cached VAE model: {vae_path}")
else:
cache_dir = get_model_cache_dir()
vae_path = download_model(args.vae_path, cache_dir)
print(f"Downloaded VAE model to: {vae_path}")
else:
vae_path = args.vae_path vae_path = args.vae_path
# Look for common file patterns only if CLI args not provided # Look for common file patterns only if CLI args not provided
...@@ -5113,10 +5135,32 @@ def main(): ...@@ -5113,10 +5135,32 @@ def main():
t5xxl_path = None t5xxl_path = None
vae_path = None vae_path = None
# Use CLI arguments if provided # Use CLI arguments if provided, download and cache if URL
if args.clip_l_path: if args.clip_l_path:
# Check if it's a URL and download if needed
if args.clip_l_path.startswith('http://') or args.clip_l_path.startswith('https://'):
cached = get_cached_model_path(args.clip_l_path)
if cached:
clip_l_path = cached
print(f"Using cached CLIP model: {clip_l_path}")
else:
cache_dir = get_model_cache_dir()
clip_l_path = download_model(args.clip_l_path, cache_dir)
print(f"Downloaded CLIP model to: {clip_l_path}")
else:
clip_l_path = args.clip_l_path clip_l_path = args.clip_l_path
if args.vae_path: if args.vae_path:
# Check if it's a URL and download if needed
if args.vae_path.startswith('http://') or args.vae_path.startswith('https://'):
cached = get_cached_model_path(args.vae_path)
if cached:
vae_path = cached
print(f"Using cached VAE model: {vae_path}")
else:
cache_dir = get_model_cache_dir()
vae_path = download_model(args.vae_path, cache_dir)
print(f"Downloaded VAE model to: {vae_path}")
else:
vae_path = args.vae_path vae_path = args.vae_path
# Look for common file patterns only if CLI args not provided # Look for common file patterns only if CLI args not provided
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment