Better progress stats during generation from the web UI

parent cb542996
This diff is collapsed.
...@@ -37,6 +37,32 @@ router = APIRouter() ...@@ -37,6 +37,32 @@ router = APIRouter()
global_args = None global_args = None
global_file_path = None global_file_path = None
# =============================================================================
# Audio generation progress tracking
# =============================================================================
# Shared, module-level progress state for the audio generation in flight.
_aud_progress: dict = {
    "current": 0,        # last step reported through _aud_progress_step
    "total": 0,          # step count handed to _aud_progress_reset
    "active": False,     # True between _aud_progress_reset and _aud_progress_done
    "started_at": 0.0,   # time.monotonic() value captured at reset
    "it_per_s": 0.0,     # step / elapsed seconds, rounded to 2 decimals
    "unit": "it",        # label chosen by the caller of _aud_progress_reset
}
def _aud_progress_reset(total: int, unit: str = "it"):
    """Start tracking a new audio generation run.

    Args:
        total: Number of steps the run is expected to take.
        unit: Label stored alongside the counters (defaults to ``"it"``).
    """
    _aud_progress.update({
        "current": 0,
        "total": total,
        "active": True,
        "started_at": time.monotonic(),
        "it_per_s": 0.0,
        "unit": unit,
    })
def _aud_progress_done():
    """Mark the audio run as finished.

    ``current`` is raised to ``total`` if it has not reached it yet, and the
    ``active`` flag is cleared.
    """
    expected_steps = _aud_progress["total"]
    if _aud_progress["current"] < expected_steps:
        _aud_progress["current"] = expected_steps
    _aud_progress["active"] = False
def _aud_progress_step(step: int):
    """Record the latest completed step and refresh the measured speed.

    Args:
        step: Number of steps completed so far.
    """
    _aud_progress["current"] = step
    if step <= 0:
        # No completed step yet, so there is no rate to compute.
        return
    seconds = time.monotonic() - _aud_progress["started_at"]
    if seconds > 0:
        _aud_progress["it_per_s"] = round(step / seconds, 2)
def set_global_args(args): def set_global_args(args):
global global_args global global_args
...@@ -124,6 +150,8 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest): ...@@ -124,6 +150,8 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
temperature=request.temperature, temperature=request.temperature,
cfg_coef=request.cfg_coef, cfg_coef=request.cfg_coef,
) )
# MusicGen/AudioGen generate in one shot — track elapsed only
_aud_progress_reset(0, unit="s")
if request.melody and model_type == 'musicgen': if request.melody and model_type == 'musicgen':
import torchaudio, torch import torchaudio, torch
raw = _decode_b64_or_url(request.melody) raw = _decode_b64_or_url(request.melody)
...@@ -135,10 +163,18 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest): ...@@ -135,10 +163,18 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
sr = pipe.sample_rate sr = pipe.sample_rate
elif model_type == 'audioldm': elif model_type == 'audioldm':
num_steps = 50
_aud_progress_reset(num_steps, unit="it")
def _aud_step_cb(pipe, step_index, timestep, callback_kwargs):
_aud_progress_step(step_index + 1)
return callback_kwargs
result = pipe( result = pipe(
request.prompt, request.prompt,
num_inference_steps=50, num_inference_steps=num_steps,
audio_length_in_s=request.duration, audio_length_in_s=request.duration,
callback_on_step_end=_aud_step_cb,
) )
audio_np = result.audios[0] audio_np = result.audios[0]
sr = 16000 sr = 16000
...@@ -162,6 +198,23 @@ def _decode_b64_or_url(data: str) -> bytes: ...@@ -162,6 +198,23 @@ def _decode_b64_or_url(data: str) -> bytes:
return base64.b64decode(data) return base64.b64decode(data)
@router.get("/v1/audio/progress")
async def get_audio_progress():
    """Return current audio generation progress including speed.

    Returns:
        A dict with ``current``, ``total``, ``active``, ``pct`` (integer
        percentage; 0 when ``total`` is 0, capped at 100), ``it_per_s``,
        ``elapsed`` (seconds since the last reset, rounded to one decimal;
        0.0 when no generation is active) and ``unit``.
    """
    # Copy the state once: generation is dispatched with run_in_executor, so
    # the shared dict can change while this response is being assembled.
    # Deriving every field from one copy keeps them consistent with each other.
    state = dict(_aud_progress)
    current = state["current"]
    total = state["total"]
    is_active = state["active"]

    elapsed = time.monotonic() - state["started_at"] if is_active else 0.0
    # Cap at 100 so a callback reporting more steps than `total` cannot push
    # the percentage past completion.
    pct = min(100, int(current / total * 100)) if total > 0 else 0

    return {
        "current": current,
        "total": total,
        "active": is_active,
        "pct": pct,
        "it_per_s": state["it_per_s"],
        "elapsed": round(elapsed, 1),
        "unit": state["unit"],
    }
@router.post("/v1/audio/generate", response_model=AudioGenerationResponse) @router.post("/v1/audio/generate", response_model=AudioGenerationResponse)
async def audio_generate(request: AudioGenerationRequest, http_request: Request = None): async def audio_generate(request: AudioGenerationRequest, http_request: Request = None):
""" """
...@@ -196,7 +249,10 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request ...@@ -196,7 +249,10 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
audio_bytes, ext = await asyncio.get_event_loop().run_in_executor( audio_bytes, ext = await asyncio.get_event_loop().run_in_executor(
None, _generate_audio, pipe, model_name, request) None, _generate_audio, pipe, model_name, request)
except Exception as e: except Exception as e:
_aud_progress_done()
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
finally:
_aud_progress_done()
result = _save_audio_response(audio_bytes, ext, http_request) result = _save_audio_response(audio_bytes, ext, http_request)
......
...@@ -118,12 +118,19 @@ queue_flags = {} ...@@ -118,12 +118,19 @@ queue_flags = {}
# ============================================================================= # =============================================================================
# Generation progress tracking # Generation progress tracking
# ============================================================================= # =============================================================================
_gen_progress: dict = {"current": 0, "total": 0, "active": False} import time as _time
# Shared, module-level progress state for the image generation in flight.
_gen_progress: dict = {
    "current": 0,        # last step reported through _progress_step
    "total": 0,          # step count handed to _progress_reset
    "active": False,     # set to True by _progress_reset
    "started_at": 0.0,   # _time.monotonic() value captured at reset
    "it_per_s": 0.0,     # step / elapsed seconds, rounded to 2 decimals
}
def _progress_reset(total: int):
    """Start tracking a new image generation run of ``total`` steps."""
    _gen_progress.update({
        "current": 0,
        "total": total,
        "active": True,
        "started_at": _time.monotonic(),
        "it_per_s": 0.0,
    })
def _progress_done(): def _progress_done():
_gen_progress["current"] = _gen_progress["total"] _gen_progress["current"] = _gen_progress["total"]
...@@ -131,6 +138,9 @@ def _progress_done(): ...@@ -131,6 +138,9 @@ def _progress_done():
def _progress_step(step: int):
    """Record the latest completed step and refresh the measured speed.

    Args:
        step: Number of steps completed so far.
    """
    _gen_progress["current"] = step
    if step <= 0:
        # No completed step yet, so there is no rate to compute.
        return
    seconds = _time.monotonic() - _gen_progress["started_at"]
    if seconds > 0:
        _gen_progress["it_per_s"] = round(step / seconds, 2)
# ============================================================================= # =============================================================================
...@@ -884,13 +894,16 @@ router = APIRouter() ...@@ -884,13 +894,16 @@ router = APIRouter()
@router.get("/v1/images/progress")
async def get_image_progress():
    """Return current image generation step progress including speed.

    Returns:
        A dict with ``current``, ``total``, ``active``, ``pct`` (integer
        percentage; 0 when ``total`` is 0, capped at 100), ``it_per_s`` and
        ``elapsed`` (seconds since the last reset, rounded to one decimal;
        0.0 when no generation is active).
    """
    # Copy the state once so every field comes from the same read, even if
    # the generator updates the shared dict while this request is handled.
    state = dict(_gen_progress)
    current = state["current"]
    total = state["total"]
    is_active = state["active"]

    elapsed = _time.monotonic() - state["started_at"] if is_active else 0.0
    # Cap at 100 so a step count above `total` cannot exceed completion.
    pct = min(100, int(current / total * 100)) if total > 0 else 0

    return {
        "current": current,
        "total": total,
        "active": is_active,
        "pct": pct,
        "it_per_s": state["it_per_s"],
        "elapsed": round(elapsed, 1),
    }
......
...@@ -51,6 +51,31 @@ router = APIRouter() ...@@ -51,6 +51,31 @@ router = APIRouter()
global_args = None global_args = None
global_file_path = None global_file_path = None
# =============================================================================
# Video generation progress tracking
# =============================================================================
# Shared, module-level progress state for the video generation in flight.
_vid_progress: dict = {
    "current": 0,        # last step reported through _vid_progress_step
    "total": 0,          # step count handed to _vid_progress_reset
    "active": False,     # True between _vid_progress_reset and _vid_progress_done
    "started_at": 0.0,   # time.monotonic() value captured at reset
    "it_per_s": 0.0,     # step / elapsed seconds, rounded to 2 decimals
}
def _vid_progress_reset(total: int):
    """Start tracking a new video generation run of ``total`` steps."""
    _vid_progress.update({
        "current": 0,
        "total": total,
        "active": True,
        "started_at": time.monotonic(),
        "it_per_s": 0.0,
    })
def _vid_progress_done():
    """Mark the video run as finished: ``current`` jumps to ``total`` and
    the ``active`` flag is cleared."""
    _vid_progress.update({
        "current": _vid_progress["total"],
        "active": False,
    })
def _vid_progress_step(step: int):
    """Record the latest completed step and refresh the measured speed.

    Args:
        step: Number of steps completed so far.
    """
    _vid_progress["current"] = step
    if step <= 0:
        # No completed step yet, so there is no rate to compute.
        return
    seconds = time.monotonic() - _vid_progress["started_at"]
    if seconds > 0:
        _vid_progress["it_per_s"] = round(step / seconds, 2)
def set_global_args(args): def set_global_args(args):
global global_args global global_args
...@@ -322,6 +347,17 @@ def _generate_video(pipe, request: VideoGenerationRequest): ...@@ -322,6 +347,17 @@ def _generate_video(pipe, request: VideoGenerationRequest):
kw.setdefault('guidance_scale', 7.5) kw.setdefault('guidance_scale', 7.5)
kw.setdefault('num_frames', 16) kw.setdefault('num_frames', 16)
_vid_progress_reset(kw['num_inference_steps'])
def _vid_step_cb(pipe, step_index, timestep, callback_kwargs):
_vid_progress_step(step_index + 1)
return callback_kwargs
try:
kw['callback_on_step_end'] = _vid_step_cb
except Exception:
pass
_apply_camera_motion(kw, request.camera_motion) _apply_camera_motion(kw, request.camera_motion)
char_images, char_names = _resolve_character_inputs(request) char_images, char_names = _resolve_character_inputs(request)
...@@ -355,6 +391,7 @@ def _generate_video(pipe, request: VideoGenerationRequest): ...@@ -355,6 +391,7 @@ def _generate_video(pipe, request: VideoGenerationRequest):
kw['strength'] = request.strength kw['strength'] = request.strength
frames = _run_pipeline(pipe, kw) frames = _run_pipeline(pipe, kw)
_vid_progress_done()
return frames, fps return frames, fps
...@@ -781,6 +818,25 @@ def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str: ...@@ -781,6 +818,25 @@ def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str:
return srt_path return srt_path
# =============================================================================
# Progress endpoint
# =============================================================================
@router.get("/v1/video/progress")
async def get_video_progress():
    """Return current video generation step progress including speed.

    Returns:
        A dict with ``current``, ``total``, ``active``, ``pct`` (integer
        percentage; 0 when ``total`` is 0, capped at 100), ``it_per_s`` and
        ``elapsed`` (seconds since the last reset, rounded to one decimal;
        0.0 when no generation is active).
    """
    # Copy the state once: generation is dispatched with run_in_executor, so
    # the shared dict can change while this response is being assembled.
    # Deriving every field from one copy keeps them consistent with each other.
    state = dict(_vid_progress)
    current = state["current"]
    total = state["total"]
    is_active = state["active"]

    elapsed = time.monotonic() - state["started_at"] if is_active else 0.0
    # Cap at 100 so a callback reporting more steps than `total` cannot push
    # the percentage past completion.
    pct = min(100, int(current / total * 100)) if total > 0 else 0

    return {
        "current": current,
        "total": total,
        "active": is_active,
        "pct": pct,
        "it_per_s": state["it_per_s"],
        "elapsed": round(elapsed, 1),
    }
# ============================================================================= # =============================================================================
# Main generation endpoint # Main generation endpoint
# ============================================================================= # =============================================================================
...@@ -836,6 +892,7 @@ async def video_generations(request: VideoGenerationRequest, ...@@ -836,6 +892,7 @@ async def video_generations(request: VideoGenerationRequest,
frames, fps = await asyncio.get_event_loop().run_in_executor( frames, fps = await asyncio.get_event_loop().run_in_executor(
None, _generate_video, pipe, request) None, _generate_video, pipe, request)
except Exception as e: except Exception as e:
_vid_progress_done()
raise HTTPException(status_code=500, detail=f"Video generation failed: {e}") raise HTTPException(status_code=500, detail=f"Video generation failed: {e}")
# Encode raw frames to MP4 # Encode raw frames to MP4
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment