Commit 782612ea authored by Your Name

Add --image-precision option and VAE tiling support for diffusers

- Add --image-precision with choices: bf16, f32, f16, f8 (default: f32)
- bf16 recommended for modern GPUs (RTX 30/40 series) to avoid NaN issues
- Enable VAE tiling for diffusers when --vae-tiling is specified
parent df8b4875
...@@ -3741,18 +3741,29 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -3741,18 +3741,29 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
if pipeline is None: if pipeline is None:
print(f"Loading Stable Diffusion model: {model_to_use}") print(f"Loading Stable Diffusion model: {model_to_use}")
# Determine precision from CLI argument
precision = getattr(global_args, 'image_precision', 'f32') or 'f32'
precision_map = {
'bf16': torch.bfloat16,
'f32': torch.float32,
'f16': torch.float16,
'f8': torch.float8_e4m3fn,
}
torch_dtype = precision_map.get(precision, torch.float32)
print(f"Using precision: {precision} ({torch_dtype})")
# Try to load as Stable Diffusion XL first # Try to load as Stable Diffusion XL first
try: try:
pipeline = StableDiffusionXLPipeline.from_pretrained( pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use, model_to_use,
torch_dtype=torch.float32, torch_dtype=torch_dtype,
use_safetensors=True, use_safetensors=True,
) )
except Exception: except Exception:
# Try generic diffusion pipeline # Try generic diffusion pipeline
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
model_to_use, model_to_use,
torch_dtype=torch.float32, torch_dtype=torch_dtype,
use_safetensors=True, use_safetensors=True,
) )
...@@ -3766,6 +3777,11 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -3766,6 +3777,11 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
if torch.cuda.is_available(): if torch.cuda.is_available():
pipeline.enable_attention_slicing() pipeline.enable_attention_slicing()
# Enable VAE tiling if requested (for lower VRAM usage)
if getattr(global_args, 'vae_tiling', False):
print("Enabling VAE tiling for lower VRAM usage...")
pipeline.enable_vae_tiling()
multi_model_manager.add_model(model_key, pipeline) multi_model_manager.add_model(model_key, pipeline)
# Get timestamp BEFORE calling diffusers (to avoid scope issues) # Get timestamp BEFORE calling diffusers (to avoid scope issues)
...@@ -5087,6 +5103,13 @@ def parse_args(): ...@@ -5087,6 +5103,13 @@ def parse_args():
default=1.0, default=1.0,
help="CFG scale for image generation (default: 1.0 for Z-Image Turbo).", help="CFG scale for image generation (default: 1.0 for Z-Image Turbo).",
) )
parser.add_argument(
"--image-precision",
type=str,
default="f32",
choices=["bf16", "f32", "f16", "f8"],
help="Model precision for image generation (default: f32). bf16 recommended for modern GPUs.",
)
parser.add_argument( parser.add_argument(
"--image-seed", "--image-seed",
type=int, type=int,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment