Add LoRA adapter detection and base model extraction from HuggingFace tags

- Detect LoRA adapters from tags (lora, LoRA) and files (*.safetensors)
- Extract base model from tags (format: base_model:org/model-name)
- Skip model_index.json fetch for LoRA-only repos
- Determine pipeline class from base model for LoRA adapters
- Improves handling of models like enhanceaiteam/Flux-Uncensored-V2
parent 33ec35a2
Pipeline #233 canceled with stages
......@@ -202,6 +202,7 @@ def validate_hf_model(model_id, hf_token=None, debug=False):
- Library name (diffusers, transformers, etc.)
- Model config (if available)
- siblings (files in the repo)
- LoRA detection and base model extraction
"""
headers = {}
if hf_token:
......@@ -232,20 +233,54 @@ def validate_hf_model(model_id, hf_token=None, debug=False):
print(f" [DEBUG] Pipeline tag: {data.get('pipeline_tag', 'N/A')}")
print(f" [DEBUG] Library: {data.get('library_name', 'N/A')}")
# Check if this is a LoRA adapter
tags = data.get("tags", [])
is_lora = "lora" in tags or "LoRA" in tags
siblings = data.get("siblings", [])
files = [s.get("rfilename", "") for s in siblings]
# Check for LoRA-specific files
has_lora_file = any(f.endswith(".safetensors") and "lora" in f.lower() for f in files)
has_model_index = any(f == "model_index.json" for f in files)
# Detect base model from tags (format: "base_model:org/model-name")
base_model_from_tags = None
for tag in tags:
if tag.startswith("base_model:"):
base_model_from_tags = tag.replace("base_model:", "")
if debug:
print(f" [DEBUG] Found base model in tags: {base_model_from_tags}")
break
# Mark as LoRA if detected from tags or files
if is_lora or has_lora_file:
data["_is_lora"] = True
if base_model_from_tags:
data["_base_model"] = base_model_from_tags
if debug:
print(f" [DEBUG] Detected LoRA adapter")
if base_model_from_tags:
print(f" [DEBUG] Base model: {base_model_from_tags}")
# Try to fetch model_index.json for diffusers models
# This contains the actual pipeline class name
if data.get("library_name") == "diffusers" or "diffusers" in data.get("tags", []):
try:
config_url = f"https://huggingface.co/{model_id}/raw/main/model_index.json"
config_req = urllib.request.Request(config_url, headers=headers)
with urllib.request.urlopen(config_req, timeout=10) as config_response:
config_data = json.loads(config_response.read().decode())
data["model_index"] = config_data
if debug:
print(f" [DEBUG] model_index.json found: {config_data.get('_class_name', 'N/A')}")
except:
# Skip for LoRA-only repos (they don't have model_index.json)
if data.get("library_name") == "diffusers" or "diffusers" in tags:
# Skip if this is a LoRA-only repo (no model_index.json)
if data.get("_is_lora") and not has_model_index:
if debug:
print(f" [DEBUG] No model_index.json found (may not be a diffusers model)")
print(f" [DEBUG] Skipping model_index.json for LoRA-only repo")
else:
try:
config_url = f"https://huggingface.co/{model_id}/raw/main/model_index.json"
config_req = urllib.request.Request(config_url, headers=headers)
with urllib.request.urlopen(config_req, timeout=10) as config_response:
config_data = json.loads(config_response.read().decode())
data["model_index"] = config_data
if debug:
print(f" [DEBUG] model_index.json found: {config_data.get('_class_name', 'N/A')}")
except:
if debug:
print(f" [DEBUG] No model_index.json found (may be a LoRA-only repo or non-diffusers model)")
return data
except urllib.error.HTTPError as e:
......@@ -461,12 +496,9 @@ def add_model_from_hf(model_id_or_url, name=None, hf_token=None, debug=False):
if not model_info:
return None
# Detect pipeline class
pipeline_class = detect_pipeline_class(model_info)
if not pipeline_class:
print(f"⚠️ Could not auto-detect pipeline class for {model_id}")
print(f" Available classes: {', '.join(PIPELINE_CLASS_MAP.keys())}")
pipeline_class = "WanPipeline" # Default fallback
# Check if this is a LoRA adapter (from validation)
is_lora = model_info.get("_is_lora", False)
base_model = model_info.get("_base_model") # Extracted from tags
# Get model name
if not name:
......@@ -476,25 +508,63 @@ def add_model_from_hf(model_id_or_url, name=None, hf_token=None, debug=False):
tags = model_info.get("tags", [])
is_i2v = any(t in tags for t in ["image-to-video", "i2v"]) or "i2v" in model_id.lower()
# For LoRA adapters, determine pipeline class from base model
if is_lora:
if base_model:
print(f" 📦 LoRA adapter detected")
print(f" Base model: {base_model}")
else:
# Try to infer base model from LoRA name
if "wan" in model_id.lower():
base_model = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if is_i2v else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
elif "svd" in model_id.lower() or "stable-video" in model_id.lower():
base_model = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
elif "flux" in model_id.lower():
base_model = "black-forest-labs/FLUX.1-dev"
elif "sdxl" in model_id.lower() or "xl" in model_id.lower():
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
if base_model:
print(f" 📦 LoRA adapter detected (inferred base model)")
print(f" Base model: {base_model}")
else:
print(f" ⚠️ LoRA adapter detected but could not determine base model")
print(f" Please specify --base-model when using this LoRA")
# Detect pipeline class
pipeline_class = detect_pipeline_class(model_info)
# For LoRA adapters, use the base model's pipeline class
if is_lora and base_model:
# Determine pipeline class from base model
base_model_lower = base_model.lower()
if "wan" in base_model_lower:
pipeline_class = "WanPipeline"
elif "svd" in base_model_lower or "stable-video-diffusion" in base_model_lower:
pipeline_class = "StableVideoDiffusionPipeline"
elif "flux" in base_model_lower:
pipeline_class = "FluxPipeline"
elif "sdxl" in base_model_lower or "stable-diffusion-xl" in base_model_lower:
pipeline_class = "StableDiffusionXLPipeline"
elif "sd3" in base_model_lower or "stable-diffusion-3" in base_model_lower:
pipeline_class = "StableDiffusion3Pipeline"
else:
# Default to FluxPipeline for unknown image LoRAs
if "text-to-image" in tags or "image-to-image" in tags:
pipeline_class = "FluxPipeline"
if not pipeline_class:
print(f"⚠️ Could not auto-detect pipeline class for {model_id}")
print(f" Available classes: {', '.join(PIPELINE_CLASS_MAP.keys())}")
pipeline_class = "WanPipeline" # Default fallback
# Get VRAM estimate
vram_est = PIPELINE_CLASS_MAP.get(pipeline_class, {}).get("default_vram", "~10-20 GB")
# Check for NSFW indicators
nsfw_keywords = ["nsfw", "adult", "uncensored", "porn", "explicit", "nude", "erotic"]
is_nsfw = any(kw in model_id.lower() or kw in model_info.get("description", "").lower() for kw in nsfw_keywords)
# Detect if this is a LoRA
is_lora = "lora" in model_id.lower() or any(t in tags for t in ["lora", "LoRA"])
base_model = None
if is_lora:
# Try to detect base model for LoRA
if "wan" in model_id.lower():
base_model = "Wan-AI/Wan2.1-I2V-14B-Diffusers" if is_i2v else "Wan-AI/Wan2.1-T2V-14B-Diffusers"
elif "svd" in model_id.lower() or "stable-video" in model_id.lower():
base_model = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
elif "sdxl" in model_id.lower() or "xl" in model_id.lower():
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
is_nsfw = is_nsfw or any(kw in str(tags).lower() for kw in nsfw_keywords)
# Build model entry
model_entry = {
......@@ -521,6 +591,8 @@ def add_model_from_hf(model_id_or_url, name=None, hf_token=None, debug=False):
print(f" VRAM: {vram_est}")
print(f" I2V: {is_i2v}")
print(f" NSFW-friendly: {is_nsfw}")
if is_lora:
print(f" LoRA: Yes (base: {base_model or 'unknown'})")
return name, model_entry
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment