Commit b7fbde39 authored by Your Name's avatar Your Name

Fix: Use DiffusionPipeline for custom model support (ZImagePipeline) - was...

Fix: Use DiffusionPipeline for custom model support (ZImagePipeline) - was using it in original code
parent 671388fd
...@@ -314,8 +314,8 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -314,8 +314,8 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Try to load the model # Try to load the model
load_error = None load_error = None
try: try:
# Use DiffusionPipeline.from_pretrained which auto-detects the correct pipeline class # Use DiffusionPipeline which auto-detects the correct pipeline class from model_index.json
# from model_index.json (supports custom pipelines like ZImagePipeline) # This supports custom pipelines like ZImagePipeline (DiT-based) which use 'transformer' instead of 'unet'
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
print(f"Loading diffusers model: {model_to_use}") print(f"Loading diffusers model: {model_to_use}")
...@@ -353,13 +353,15 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -353,13 +353,15 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
torch_dtype=dtype, torch_dtype=dtype,
) )
else: else:
# Fall back to DiffusionPipeline for custom pipelines
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
model_to_use, model_to_use,
torch_dtype=dtype, torch_dtype=dtype,
) )
except Exception as retry_error: except Exception as retry_error:
# If it still fails, try without safety checker # If it still fails, try DiffusionPipeline (for custom pipelines like ZImagePipeline)
print(f"Warning: Retry failed: {retry_error}, trying without safety checker...") print(f"Warning: Retry failed: {retry_error}, trying DiffusionPipeline for custom pipelines...")
from diffusers import DiffusionPipeline
if is_xl: if is_xl:
pipeline = StableDiffusionXLPipeline.from_pretrained( pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use, model_to_use,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment