Commit 671388fd authored by Your Name's avatar Your Name

Fix: Use DiffusionPipeline for custom model support (ZImagePipeline)

parent afb2eead
......@@ -314,15 +314,20 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Try to load the model
load_error = None
try:
if is_xl:
# Load SDXL
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
# Use DiffusionPipeline.from_pretrained which auto-detects the correct pipeline class
# from model_index.json (supports custom pipelines like ZImagePipeline)
from diffusers import DiffusionPipeline
print(f"Loading diffusers model: {model_to_use}")
# Determine compute type
if torch.cuda.is_available():
dtype = torch.float16
else:
# Load SD 1.5
pipeline = StableDiffusionPipeline.from_pretrained(
dtype = torch.float32
# Use DiffusionPipeline for auto-detection of pipeline class
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
......@@ -341,13 +346,14 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Try with default resolution
try:
from diffusers import DiffusionPipeline
if is_xl:
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
else:
pipeline = StableDiffusionPipeline.from_pretrained(
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
......@@ -361,7 +367,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
safety_checker=None,
)
else:
pipeline = StableDiffusionPipeline.from_pretrained(
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
safety_checker=None,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment