Commit 671388fd authored by Your Name's avatar Your Name

Fix: Use DiffusionPipeline for custom model support (ZImagePipeline)

parent afb2eead
......@@ -314,18 +314,23 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Try to load the model
load_error = None
try:
if is_xl:
# Load SDXL
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
# Use DiffusionPipeline.from_pretrained which auto-detects the correct pipeline class
# from model_index.json (supports custom pipelines like ZImagePipeline)
from diffusers import DiffusionPipeline
print(f"Loading diffusers model: {model_to_use}")
# Determine compute type
if torch.cuda.is_available():
dtype = torch.float16
else:
# Load SD 1.5
pipeline = StableDiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
dtype = torch.float32
# Use DiffusionPipeline for auto-detection of pipeline class
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
except Exception as load_error:
# Try with revised model resolution for custom models
print(f"Warning: First model load attempt failed: {load_error}")
......@@ -341,13 +346,14 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Try with default resolution
try:
from diffusers import DiffusionPipeline
if is_xl:
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
else:
pipeline = StableDiffusionPipeline.from_pretrained(
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
......@@ -361,7 +367,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
safety_checker=None,
)
else:
pipeline = StableDiffusionPipeline.from_pretrained(
pipeline = DiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
safety_checker=None,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment