Commit 55a39eeb authored by Your Name's avatar Your Name

Add steps and guidance_scale to image generation request

- Add 'steps' parameter to ImageGenerationRequest (overrides quality-based default)
- Add 'guidance_scale' parameter to ImageGenerationRequest (overrides CLI --image-cfg-scale)
- Use request values in diffusers pipeline call
parent 9a749ea4
...@@ -395,6 +395,8 @@ class ImageGenerationRequest(BaseModel): ...@@ -395,6 +395,8 @@ class ImageGenerationRequest(BaseModel):
prompt: str prompt: str
n: int = 1 n: int = 1
size: Optional[str] = "1024x1024" size: Optional[str] = "1024x1024"
steps: Optional[int] = None # Number of inference steps (overrides quality-based default)
guidance_scale: Optional[float] = None # CFG scale (overrides quality-based default)
quality: Optional[str] = "standard" quality: Optional[str] = "standard"
style: Optional[str] = None style: Optional[str] = None
response_format: Optional[str] = "url" response_format: Optional[str] = "url"
...@@ -3780,6 +3782,12 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -3780,6 +3782,12 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Quality: "standard" or "hd" # Quality: "standard" or "hd"
quality = request.quality or "standard" quality = request.quality or "standard"
# Use request parameters if provided, otherwise fall back to quality-based defaults
num_steps = request.steps if request.steps else (30 if quality == "standard" else 50)
cfg_scale = request.guidance_scale if request.guidance_scale else (
getattr(global_args, 'image_cfg_scale', 7.5) if quality == "standard" else 9.0
)
# Generate # Generate
result = pipeline( result = pipeline(
prompt=request.prompt, prompt=request.prompt,
...@@ -3788,8 +3796,8 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -3788,8 +3796,8 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
height=height, height=height,
width=width, width=width,
generator=generator, generator=generator,
guidance_scale=getattr(global_args, 'image_cfg_scale', 7.5) if quality == "standard" else 9.0, guidance_scale=cfg_scale,
num_inference_steps=30 if quality == "standard" else 50, num_inference_steps=num_steps,
) )
# Extract images # Extract images
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment