Commit 671388fd authored by Your Name's avatar Your Name

Fix: Use DiffusionPipeline for custom model support (ZImagePipeline)

parent afb2eead
...@@ -314,15 +314,20 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -314,15 +314,20 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Try to load the model # Try to load the model
load_error = None load_error = None
try: try:
if is_xl: # Use DiffusionPipeline.from_pretrained which auto-detects the correct pipeline class
# Load SDXL # from model_index.json (supports custom pipelines like ZImagePipeline)
pipeline = StableDiffusionXLPipeline.from_pretrained( from diffusers import DiffusionPipeline
model_to_use,
torch_dtype=dtype, print(f"Loading diffusers model: {model_to_use}")
)
# Determine compute type
if torch.cuda.is_available():
dtype = torch.float16
else: else:
# Load SD 1.5 dtype = torch.float32
pipeline = StableDiffusionPipeline.from_pretrained(
# Use DiffusionPipeline for auto-detection of pipeline class
pipeline = DiffusionPipeline.from_pretrained(
model_to_use, model_to_use,
torch_dtype=dtype, torch_dtype=dtype,
) )
...@@ -341,13 +346,14 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -341,13 +346,14 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Try with default resolution # Try with default resolution
try: try:
from diffusers import DiffusionPipeline
if is_xl: if is_xl:
pipeline = StableDiffusionXLPipeline.from_pretrained( pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use, model_to_use,
torch_dtype=dtype, torch_dtype=dtype,
) )
else: else:
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
model_to_use, model_to_use,
torch_dtype=dtype, torch_dtype=dtype,
) )
...@@ -361,7 +367,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -361,7 +367,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
safety_checker=None, safety_checker=None,
) )
else: else:
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
model_to_use, model_to_use,
torch_dtype=dtype, torch_dtype=dtype,
safety_checker=None, safety_checker=None,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment