Add HF_TOKEN authentication support for gated/private models

- Add HF_TOKEN support to main pipeline loading (pipe_kwargs)
- Add HF_TOKEN support to VAE loading for Wan models
- Add HF_TOKEN support to image model loading for I2V mode
- Enhanced pipeline detection with multiple strategies
- Improved error messages for authentication errors (401, gated models)
- Added debug output for HF token status
parent bcbae548
Pipeline #225 canceled with stages
......@@ -146,6 +146,7 @@ PIPELINE_CLASS_MAP = {
"TextToVideoZeroPipeline": {"type": "t2v", "default_vram": "~6-8 GB"},
"MochiPipeline": {"type": "t2v", "default_vram": "~18-22 GB"},
"StableDiffusionXLPipeline": {"type": "image", "default_vram": "~10-16 GB"},
"StableDiffusion3Pipeline": {"type": "image", "default_vram": "~15-20 GB"},
"FluxPipeline": {"type": "image", "default_vram": "~20-25 GB"},
"AllegroPipeline": {"type": "t2v", "default_vram": "~35-45 GB"},
"HunyuanDiTPipeline": {"type": "t2v", "default_vram": "~40-55 GB"},
......@@ -154,6 +155,7 @@ PIPELINE_CLASS_MAP = {
"StepVideoPipeline": {"type": "t2v", "default_vram": "~90-140 GB"},
"CogVideoXPipeline": {"type": "t2v", "default_vram": "~20-30 GB"},
"HotshotXLPipeline": {"type": "video", "default_vram": "~8-12 GB"},
"LattePipeline": {"type": "t2v", "default_vram": "~20-30 GB"},
# Generic pipeline - auto-detects model type from loaded model
"DiffusionPipeline": {"type": "auto", "default_vram": "~10-30 GB"},
}
......@@ -192,7 +194,15 @@ def save_models_config(models):
def validate_hf_model(model_id, hf_token=None, debug=False):
"""Validate if a HuggingFace model exists and get its info"""
"""Validate if a HuggingFace model exists and get its info
Fetches comprehensive model information including:
- Basic metadata (tags, downloads, likes)
- Pipeline tag (text-to-video, image-to-video, etc.)
- Library name (diffusers, transformers, etc.)
- Model config (if available)
- siblings (files in the repo)
"""
headers = {}
if hf_token:
headers["Authorization"] = f"Bearer {hf_token}"
......@@ -211,12 +221,32 @@ def validate_hf_model(model_id, hf_token=None, debug=False):
if debug:
print(f" [DEBUG] Sending request...")
with urllib.request.urlopen(req, timeout=10) as response:
with urllib.request.urlopen(req, timeout=15) as response:
if debug:
print(f" [DEBUG] Response status: {response.status}")
data = json.loads(response.read().decode())
if debug:
print(f" [DEBUG] Model found! Tags: {data.get('tags', [])[:5]}")
print(f" [DEBUG] Model found!")
print(f" [DEBUG] Tags: {data.get('tags', [])[:5]}")
print(f" [DEBUG] Pipeline tag: {data.get('pipeline_tag', 'N/A')}")
print(f" [DEBUG] Library: {data.get('library_name', 'N/A')}")
# Try to fetch model_index.json for diffusers models
# This contains the actual pipeline class name
if data.get("library_name") == "diffusers" or "diffusers" in data.get("tags", []):
try:
config_url = f"https://huggingface.co/{model_id}/raw/main/model_index.json"
config_req = urllib.request.Request(config_url, headers=headers)
with urllib.request.urlopen(config_req, timeout=10) as config_response:
config_data = json.loads(config_response.read().decode())
data["model_index"] = config_data
if debug:
print(f" [DEBUG] model_index.json found: {config_data.get('_class_name', 'N/A')}")
except:
if debug:
print(f" [DEBUG] No model_index.json found (may not be a diffusers model)")
return data
except urllib.error.HTTPError as e:
if debug:
......@@ -251,20 +281,70 @@ def validate_hf_model(model_id, hf_token=None, debug=False):
def detect_pipeline_class(model_info):
"""Try to detect the pipeline class from model info"""
"""Try to detect the pipeline class from model info
Uses multiple strategies in order of reliability:
1. Check model_index.json config from HuggingFace API
2. Check pipeline_tag from HuggingFace API
3. Check model ID patterns
4. Check tags
5. Check library_name
"""
tags = model_info.get("tags", [])
library_name = model_info.get("library_name", "")
model_id = model_info.get("id", "").lower()
# Check model ID patterns
if "svd" in model_id or "stable-video-diffusion" in model_id:
return "StableVideoDiffusionPipeline"
if "wan" in model_id and "i2v" in model_id:
return "WanPipeline"
if "wan" in model_id and "t2v" in model_id:
pipeline_tag = model_info.get("pipeline_tag", "")
# 1. Check for explicit pipeline class in model config (most reliable)
# Some models have this in their config
config = model_info.get("config", {})
if config:
# Check for diffusers pipeline info
if "diffusers" in config:
diffusers_config = config["diffusers"]
if isinstance(diffusers_config, dict) and "pipeline_class" in diffusers_config:
return diffusers_config["pipeline_class"]
# Check for model_index.json info
if "model_index" in config:
model_index = config["model_index"]
if isinstance(model_index, dict) and "_class_name" in model_index:
class_name = model_index["_class_name"]
if class_name and class_name in PIPELINE_CLASS_MAP:
return class_name
# 2. Check pipeline_tag from HuggingFace API (very reliable)
if pipeline_tag:
pipeline_tag_lower = pipeline_tag.lower()
if pipeline_tag_lower == "text-to-video":
# Check if it's I2V or T2V
if "image-to-video" in tags or "i2v" in model_id:
return "StableVideoDiffusionPipeline"
return "WanPipeline"
elif pipeline_tag_lower == "image-to-video":
return "StableVideoDiffusionPipeline"
elif pipeline_tag_lower == "text-to-image":
# Check for specific image models
if "flux" in model_id:
return "FluxPipeline"
return "StableDiffusionXLPipeline"
elif pipeline_tag_lower == "image-to-image":
return "StableDiffusionXLPipeline"
# 3. Check model ID patterns (specific models first)
# Wan models (check for version patterns)
if "wan2.1" in model_id or "wan2.2" in model_id:
return "WanPipeline"
if "wan2" in model_id:
return "WanPipeline"
if "wan" in model_id:
return "WanPipeline"
# Stable Video Diffusion
if "stable-video-diffusion" in model_id or "svd" in model_id:
return "StableVideoDiffusionPipeline"
# Other video models
if "i2vgen" in model_id:
return "I2VGenXLPipeline"
if "ltx-video" in model_id or "ltxvideo" in model_id:
......@@ -273,19 +353,13 @@ def detect_pipeline_class(model_info):
return "AnimateDiffPipeline"
if "mochi" in model_id:
return "MochiPipeline"
if "flux" in model_id:
return "FluxPipeline"
if "pony" in model_id or "animagine" in model_id:
return "StableDiffusionXLPipeline"
if "sdxl" in model_id or "xl" in model_id:
return "StableDiffusionXLPipeline"
if "allegro" in model_id:
return "AllegroPipeline"
if "hunyuan" in model_id:
return "HunyuanDiTPipeline"
if "open-sora" in model_id or "opensora" in model_id:
return "OpenSoraPipeline"
if "cogvideox" in model_id:
if "cogvideox" in model_id or "cogvideo" in model_id:
return "CogVideoXPipeline"
if "hotshot" in model_id:
return "HotshotXLPipeline"
......@@ -293,24 +367,47 @@ def detect_pipeline_class(model_info):
return "TextToVideoZeroPipeline"
if "modelscope" in model_id or "text-to-video-ms" in model_id:
return "TextToVideoSDPipeline"
# Check for qwen or other diffusers models that use generic DiffusionPipeline
if "qwen" in model_id or "diffusers" in model_id:
return "DiffusionPipeline"
if "lumina" in model_id:
return "LuminaVideoPipeline"
if "stepvideo" in model_id or "step-video" in model_id:
return "StepVideoPipeline"
# Check tags
# Image models
if "flux" in model_id:
return "FluxPipeline"
if "pony" in model_id or "animagine" in model_id:
return "StableDiffusionXLPipeline"
if "sdxl" in model_id or "stable-diffusion-xl" in model_id:
return "StableDiffusionXLPipeline"
if "sd3" in model_id or "stable-diffusion-3" in model_id:
return "StableDiffusion3Pipeline"
# 4. Check tags for model type
if "video" in tags:
if "image-to-video" in tags:
if "image-to-video" in tags or "i2v" in tags:
return "StableVideoDiffusionPipeline"
if "text-to-video" in tags:
return "WanPipeline"
return "WanPipeline"
if "text-to-image" in tags:
if "flux" in model_id:
return "FluxPipeline"
return "StableDiffusionXLPipeline"
# Check library
if "image-to-image" in tags:
return "StableDiffusionXLPipeline"
# 5. Check library name
if library_name == "diffusers":
# Use generic DiffusionPipeline for diffusers models
# This allows loading any diffusers-compatible model
return "DiffusionPipeline"
# 6. Check for specific patterns that indicate generic diffusers
if "diffusers" in model_id:
return "DiffusionPipeline"
return None
......@@ -460,6 +557,26 @@ def search_hf_models(query, limit=20, hf_token=None):
nsfw_keywords = ["nsfw", "adult", "uncensored", "porn", "explicit"]
is_nsfw = any(kw in model_id.lower() for kw in nsfw_keywords)
# Build model_info dict for detect_pipeline_class
model_info = {
"id": model_id,
"tags": tags,
"pipeline_tag": m.get("pipeline_tag", ""),
"library_name": m.get("library_name", ""),
"config": m.get("config", {}),
}
# Try to fetch model_index.json for diffusers models
if m.get("library_name") == "diffusers" or "diffusers" in tags:
try:
config_url = f"https://huggingface.co/{model_id}/raw/main/model_index.json"
config_req = urllib.request.Request(config_url, headers=headers)
with urllib.request.urlopen(config_req, timeout=5) as config_response:
config_data = json.loads(config_response.read().decode())
model_info["model_index"] = config_data
except:
pass
results.append({
"id": model_id,
"downloads": m.get("downloads", 0),
......@@ -469,7 +586,9 @@ def search_hf_models(query, limit=20, hf_token=None):
"is_video": is_video,
"is_image": is_image,
"is_nsfw": is_nsfw,
"pipeline_class": detect_pipeline_class(m) or "Unknown",
"pipeline_class": detect_pipeline_class(model_info) or "Unknown",
"pipeline_tag": m.get("pipeline_tag", ""),
"library_name": m.get("library_name", ""),
})
return results
......@@ -874,7 +993,7 @@ def update_all_models(hf_token=None):
print(" (These models may require significant VRAM - 40GB to 140GB)")
print()
for model_id, pipeline_class, vram_est, description in known_large_models:
for model_id, default_pipeline_class, vram_est, description in known_large_models:
if model_id in seen_ids:
continue
seen_ids.add(model_id)
......@@ -891,19 +1010,30 @@ def update_all_models(hf_token=None):
name = f"{base_name}_{counter}"
counter += 1
# Try to validate (but don't skip if it fails)
# Try to validate and detect actual pipeline class
model_info = validate_hf_model(model_id, hf_token=hf_token)
detected_pipeline = None
if model_info:
tags = model_info.get("tags", [])
downloads = model_info.get("downloads", 0)
likes = model_info.get("likes", 0)
is_i2v = any(t in tags for t in ["image-to-video", "i2v"]) or "i2v" in model_id.lower()
# Try to detect actual pipeline class from model_index.json
detected_pipeline = detect_pipeline_class(model_info)
if detected_pipeline:
pipeline_class = detected_pipeline
print(f" 🔍 Detected pipeline: {pipeline_class} for {model_id}")
else:
pipeline_class = default_pipeline_class
else:
# Add anyway with defaults
tags = ["video", "text-to-video", "large-model"]
downloads = 0
likes = 0
is_i2v = "i2v" in model_id.lower()
pipeline_class = default_pipeline_class
print(f" ⚠️ Could not validate {model_id} - adding with defaults")
# Build entry
......@@ -921,7 +1051,7 @@ def update_all_models(hf_token=None):
}
all_models[name] = model_entry
print(f" ✅ {name}: {model_id} ({vram_est})")
print(f" ✅ {name}: {model_id} ({vram_est}) [{pipeline_class}]")
for query, limit in search_queries:
print(f"\n🔍 Searching: '{query}' (limit: {limit})")
......@@ -944,7 +1074,9 @@ def update_all_models(hf_token=None):
"TextToVideoSDPipeline", "TextToVideoZeroPipeline",
"HotshotXLPipeline", "AllegroPipeline",
"HunyuanDiTPipeline", "OpenSoraPipeline",
"LuminaVideoPipeline", "StepVideoPipeline"]
"LuminaVideoPipeline", "StepVideoPipeline",
"DiffusionPipeline", "FluxPipeline",
"StableDiffusionXLPipeline", "StableDiffusion3Pipeline"]
if not (is_video_model or is_nsfw_model or is_known_pipeline):
continue
......@@ -961,12 +1093,20 @@ def update_all_models(hf_token=None):
name = f"{base_name}_{counter}"
counter += 1
# Determine pipeline class
# Use pipeline class from search results (already detected via detect_pipeline_class)
pipeline_class = m["pipeline_class"]
if pipeline_class == "Unknown":
pipeline_class = "WanPipeline" if m["is_video"] else "StableDiffusionXLPipeline"
# Fallback based on model type
if m["is_i2v"]:
pipeline_class = "StableVideoDiffusionPipeline"
elif m["is_video"]:
pipeline_class = "WanPipeline"
elif m["is_image"]:
pipeline_class = "StableDiffusionXLPipeline"
else:
pipeline_class = "DiffusionPipeline"
# Determine VRAM estimate
# Determine VRAM estimate from pipeline class
vram_est = PIPELINE_CLASS_MAP.get(pipeline_class, {}).get("default_vram", "~10-20 GB")
# Detect if LoRA
......@@ -991,13 +1131,15 @@ def update_all_models(hf_token=None):
"likes": m.get("likes", 0),
"is_lora": is_lora,
"auto_added": True,
"pipeline_tag": m.get("pipeline_tag", ""),
"library_name": m.get("library_name", ""),
}
if base_model:
model_entry["base_model"] = base_model
all_models[name] = model_entry
print(f" ✅ {name}: {model_id}")
print(f" ✅ {name}: {model_id} [{pipeline_class}]")
# Also search for safetensors files (community models)
print(f"\n" + "-" * 60)
......@@ -1046,10 +1188,15 @@ def update_all_models(hf_token=None):
name = f"{base_name}_{counter}"
counter += 1
# Determine pipeline class
# Use pipeline class from search results
pipeline_class = m["pipeline_class"]
if pipeline_class == "Unknown":
pipeline_class = "WanPipeline" if m["is_video"] else "StableDiffusionXLPipeline"
if m["is_i2v"]:
pipeline_class = "StableVideoDiffusionPipeline"
elif m["is_video"]:
pipeline_class = "WanPipeline"
else:
pipeline_class = "StableDiffusionXLPipeline"
# Determine VRAM estimate
vram_est = PIPELINE_CLASS_MAP.get(pipeline_class, {}).get("default_vram", "~10-20 GB")
......@@ -1080,7 +1227,7 @@ def update_all_models(hf_token=None):
model_entry["file_url"] = f"https://huggingface.co/{model_id}/blob/main/{primary_file}"
all_models[name] = model_entry
print(f" ✅ [safetensors] {name}: {model_id} ({len(safetensor_files)} files)")
print(f" ✅ [safetensors] {name}: {model_id} ({len(safetensor_files)} files) [{pipeline_class}]")
print(f"\n" + "=" * 60)
print(f"📊 Found {len(all_models)} new models from HuggingFace")
......@@ -2955,12 +3102,19 @@ def main(args):
print(f" Model estimated VRAM: {parse_vram_estimate(m_info['vram']):.1f} GB")
print(f" low_cpu_mem_usage: {use_low_mem} ({reason})")
# Get HF token for authenticated model access
hf_token = os.environ.get("HF_TOKEN")
pipe_kwargs = {
"torch_dtype": torch.bfloat16 if any(x in args.model for x in ["mochi", "wan", "flux"]) else torch.float16,
"device_map": device_map,
"max_memory": max_mem,
"offload_folder": args.offload_dir,
}
# Add auth token if available (for gated/private models)
if hf_token:
pipe_kwargs["use_auth_token"] = hf_token
if use_low_mem:
pipe_kwargs["low_cpu_mem_usage"] = True
......@@ -3010,7 +3164,10 @@ def main(args):
if extra.get("use_custom_vae"):
try:
vae_model_id = model_id_to_load if is_lora else m_info["id"]
vae = AutoencoderKLWan.from_pretrained(vae_model_id, subfolder="vae", torch_dtype=pipe_kwargs["torch_dtype"])
vae_kwargs = {"subfolder": "vae", "torch_dtype": pipe_kwargs["torch_dtype"]}
if hf_token:
vae_kwargs["use_auth_token"] = hf_token
vae = AutoencoderKLWan.from_pretrained(vae_model_id, **vae_kwargs)
pipe_kwargs["vae"] = vae
except Exception as e:
print(f"Custom Wan VAE load failed: {e}")
......@@ -3053,16 +3210,31 @@ def main(args):
print()
# Check if we should retry with an alternative model (auto mode)
if getattr(args, '_auto_mode', False) and getattr(args, '_retry_count', 0) < getattr(args, '_max_retries', 3):
# This applies to ALL error types - we try alternatives before giving up
if getattr(args, '_auto_mode', False):
retry_count = getattr(args, '_retry_count', 0)
max_retries = getattr(args, '_max_retries', 3)
alternative_models = getattr(args, '_auto_alternative_models', [])
if alternative_models:
args._retry_count += 1
if retry_count < max_retries and alternative_models:
# We have alternatives available - retry with next model
args._retry_count = retry_count + 1
next_model_name, next_model_info, next_reason = alternative_models.pop(0)
args._auto_alternative_models = alternative_models # Update the list
print(f"\n⚠️ Model loading failed: {model_id_to_load}")
print(f" Error: {error_str[:100]}...")
print(f"\n🔄 Retrying with alternative model ({args._retry_count}/{args._max_retries})...")
# Print appropriate error message based on error type
if "404" in error_str or "Entry Not Found" in error_str:
print(f"❌ Model not found on HuggingFace: {model_id_to_load}")
elif "401" in error_str or "Unauthorized" in error_str:
print(f"❌ Model requires authentication: {model_id_to_load}")
elif "FrozenDict" in error_str or "scale_factor" in error_str or "has no attribute" in error_str:
print(f"❌ Pipeline compatibility error: {model_id_to_load}")
print(f" This model uses an incompatible pipeline architecture.")
else:
print(f"❌ Model loading failed: {model_id_to_load}")
print(f" Error: {error_str[:100]}...")
print(f"\n🔄 Retrying with alternative model ({args._retry_count}/{max_retries})...")
print(f" New model: {next_model_name}")
print(f" {next_reason}")
......@@ -3073,8 +3245,11 @@ def main(args):
torch.cuda.empty_cache()
# Retry main() with the new model
return main(args)
# No more alternatives or retries exhausted
print(f"\n❌ All model retries exhausted ({retry_count}/{max_retries} attempts)")
# Check for common errors and provide helpful messages
# Print detailed error message for the user
if "404" in error_str or "Entry Not Found" in error_str:
print(f"❌ Model not found on HuggingFace: {model_id_to_load}")
print(f" This model may have been removed or the ID is incorrect.")
......@@ -3127,16 +3302,7 @@ def main(args):
print(f"\n [DEBUG] Full traceback:")
traceback.print_exc()
# If we've exhausted all retries, exit with error
if getattr(args, '_auto_mode', False):
retry_count = getattr(args, '_retry_count', 0)
max_retries = getattr(args, '_max_retries', 3)
alternative_models = getattr(args, '_auto_alternative_models', [])
if retry_count >= max_retries or not alternative_models:
print(f"\n❌ All model retries exhausted ({retry_count}/{max_retries} attempts)")
print(f" Try searching for alternative models: videogen --search-models <query>")
print(f"\n 💡 Try searching for alternative models: videogen --search-models <query>")
sys.exit(1)
timing.end_step() # model_loading
......@@ -3403,6 +3569,10 @@ def main(args):
"max_memory": max_mem,
"offload_folder": args.offload_dir,
}
# Add auth token if available (for gated/private models)
if hf_token:
img_kwargs["use_auth_token"] = hf_token
if use_low_mem:
img_kwargs["low_cpu_mem_usage"] = True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment