Fix image generation to handle LoRA adapters - load base model first, then apply LoRA weights

parent ad95e206
Pipeline #234 canceled with stages
...@@ -4030,6 +4030,34 @@ def main(args): ...@@ -4030,6 +4030,34 @@ def main(args):
timing.begin_step("image_generation") timing.begin_step("image_generation")
img_info = MODELS[args.image_model] img_info = MODELS[args.image_model]
# ---- Resolve the image model: plain checkpoint vs. LoRA adapter ----
# A LoRA entry cannot be passed to from_pretrained() directly: the base
# pipeline must be loaded first, then the adapter weights attached.
img_is_lora = img_info.get("is_lora", False)
img_lora_id = None
img_base_model_id = None
img_model_id_to_load = img_info["id"]
if img_is_lora:
    img_lora_id = img_info["id"]
    img_base_model_id = img_info.get("base_model")
    # No explicit base model in the registry entry: infer it from the
    # LoRA repo name. NOTE(review): substring matching is heuristic —
    # confirm ambiguous names (e.g. containing both "flux" and "sdxl")
    # resolve to the intended base in the model registry.
    if not img_base_model_id:
        lora_name = img_lora_id.lower()  # hoist: avoid repeated .lower() calls
        if "flux" in lora_name:
            img_base_model_id = "black-forest-labs/FLUX.1-dev"
        elif "sdxl" in lora_name:
            img_base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        elif "sd15" in lora_name or "sd1.5" in lora_name:
            img_base_model_id = "runwayml/stable-diffusion-v1-5"
        else:
            # Default to Flux for unknown image LoRAs
            img_base_model_id = "black-forest-labs/FLUX.1-dev"
    # F541 fix: this message has no placeholders, so a plain string is used.
    print(" 📦 Image model is a LoRA adapter")
    print(f" LoRA: {img_lora_id}")
    print(f" Base model: {img_base_model_id}")
    # from_pretrained() below must receive the base model, not the adapter.
    img_model_id_to_load = img_base_model_id
ImgCls = get_pipeline_class(img_info["class"]) ImgCls = get_pipeline_class(img_info["class"])
if not ImgCls: if not ImgCls:
print(f"❌ Pipeline class '{img_info['class']}' not found for image model.") print(f"❌ Pipeline class '{img_info['class']}' not found for image model.")
...@@ -4051,7 +4079,18 @@ def main(args): ...@@ -4051,7 +4079,18 @@ def main(args):
img_kwargs["low_cpu_mem_usage"] = True img_kwargs["low_cpu_mem_usage"] = True
try: try:
img_pipe = ImgCls.from_pretrained(img_info["id"], **img_kwargs) img_pipe = ImgCls.from_pretrained(img_model_id_to_load, **img_kwargs)
# Attach the LoRA adapter weights on top of the freshly loaded base pipeline.
# Failure here is deliberately non-fatal: generation proceeds with the base model.
if img_is_lora and img_lora_id:
    print(f" Loading image LoRA adapter: {img_lora_id}")
    try:
        img_pipe.load_lora_weights(img_lora_id)
        # F541 fix: no placeholders, so plain string literals below.
        print(" ✅ Image LoRA applied successfully")
    except Exception as lora_e:
        # Broad catch is intentional best-effort: diffusers can raise many
        # exception types here (download, format, incompatible base), and
        # all of them should degrade gracefully rather than abort the run.
        print(f" ⚠️ Image LoRA loading failed: {lora_e}")
        print(" Continuing with base image model...")
img_pipe.enable_model_cpu_offload() img_pipe.enable_model_cpu_offload()
img_prompt = ", ".join(args.prompt_image) if args.prompt_image else main_prompt img_prompt = ", ".join(args.prompt_image) if args.prompt_image else main_prompt
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment