Better progression stats during generation from web

parent cb542996
This diff is collapsed.
......@@ -37,6 +37,32 @@ router = APIRouter()
global_args = None
global_file_path = None
# =============================================================================
# Audio generation progress tracking
# =============================================================================
# Module-level progress state for the current audio generation run.
#   current    : number of steps reported so far
#   total      : number of steps expected (0 for the one-shot MusicGen/AudioGen
#                path, which resets with total=0)
#   active     : True between _aud_progress_reset() and _aud_progress_done()
#   started_at : time.monotonic() reading taken at reset
#   it_per_s   : reported steps divided by elapsed seconds, rounded to 2 decimals
#   unit       : label given to _aud_progress_reset ("it" by default)
_aud_progress: dict = {
"current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0, "unit": "it",
}
def _aud_progress_reset(total: int, unit: str = "it"):
    """Start tracking a new audio generation run.

    Zeroes the step counter and the measured rate, stores ``total`` as the
    expected number of steps and ``unit`` as the label reported to clients,
    flags the tracker as active and stamps the start time from the monotonic
    clock.
    """
    _aud_progress.update({
        "current": 0,
        "total": total,
        "active": True,
        "started_at": time.monotonic(),
        "it_per_s": 0.0,
        "unit": unit,
    })
def _aud_progress_done():
    """Mark the audio run as finished.

    The step counter is raised to ``total`` when it has not reached it yet
    (it is never lowered), then the tracker is flagged inactive.
    """
    state = _aud_progress
    if state["current"] < state["total"]:
        state["current"] = state["total"]
    state["active"] = False
def _aud_progress_step(step: int):
    """Record that ``step`` steps are complete and refresh the measured rate.

    The rate (steps per second since the last reset, rounded to two decimals)
    is recomputed only when both the step count and the elapsed time are
    positive; otherwise the previous value is left as it was.
    """
    _aud_progress["current"] = step
    seconds = time.monotonic() - _aud_progress["started_at"]
    if step <= 0 or seconds <= 0:
        return
    _aud_progress["it_per_s"] = round(step / seconds, 2)
def set_global_args(args):
global global_args
......@@ -124,6 +150,8 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
temperature=request.temperature,
cfg_coef=request.cfg_coef,
)
# MusicGen/AudioGen generate in one shot — track elapsed only
_aud_progress_reset(0, unit="s")
if request.melody and model_type == 'musicgen':
import torchaudio, torch
raw = _decode_b64_or_url(request.melody)
......@@ -135,10 +163,18 @@ def _generate_audio(pipe, model_name: str, request: AudioGenerationRequest):
sr = pipe.sample_rate
elif model_type == 'audioldm':
num_steps = 50
_aud_progress_reset(num_steps, unit="it")
# Per-step hook handed to the pipeline call as callback_on_step_end.
# step_index is shifted by one before being stored, so the tracker holds a
# count of finished steps rather than an index (assumes the pipeline numbers
# its steps from 0 — TODO confirm). callback_kwargs is returned untouched.
def _aud_step_cb(pipe, step_index, timestep, callback_kwargs):
_aud_progress_step(step_index + 1)
return callback_kwargs
result = pipe(
request.prompt,
num_inference_steps=50,
num_inference_steps=num_steps,
audio_length_in_s=request.duration,
callback_on_step_end=_aud_step_cb,
)
audio_np = result.audios[0]
sr = 16000
......@@ -162,6 +198,23 @@ def _decode_b64_or_url(data: str) -> bytes:
return base64.b64decode(data)
@router.get("/v1/audio/progress")
async def get_audio_progress():
    """Report the state of the audio progress tracker.

    The payload carries the step counters, the active flag, an integer
    percentage (0 when no total is known), the measured rate, the unit label
    and the seconds elapsed since the run started (0.0 while inactive).
    """
    if _aud_progress["active"]:
        elapsed = time.monotonic() - _aud_progress["started_at"]
    else:
        elapsed = 0.0

    total = _aud_progress["total"]
    current = _aud_progress["current"]
    # Guard the division: the one-shot path runs with total == 0.
    pct = int(current / total * 100) if total > 0 else 0

    return {
        "current": current,
        "total": total,
        "active": _aud_progress["active"],
        "pct": pct,
        "it_per_s": _aud_progress["it_per_s"],
        "elapsed": round(elapsed, 1),
        "unit": _aud_progress["unit"],
    }
@router.post("/v1/audio/generate", response_model=AudioGenerationResponse)
async def audio_generate(request: AudioGenerationRequest, http_request: Request = None):
"""
......@@ -196,7 +249,10 @@ async def audio_generate(request: AudioGenerationRequest, http_request: Request
audio_bytes, ext = await asyncio.get_event_loop().run_in_executor(
None, _generate_audio, pipe, model_name, request)
except Exception as e:
_aud_progress_done()
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
finally:
_aud_progress_done()
result = _save_audio_response(audio_bytes, ext, http_request)
......
......@@ -118,12 +118,19 @@ queue_flags = {}
# =============================================================================
# Generation progress tracking
# =============================================================================
# NOTE(review): this three-key binding is replaced by the five-key definition
# two statements below — looks like a leftover of the previous version;
# confirm and drop.
_gen_progress: dict = {"current": 0, "total": 0, "active": False}
import time as _time
# Module-level progress state for the current image generation run.
#   current    : number of steps reported so far
#   total      : number of steps expected
#   active     : set to True by _progress_reset()
#   started_at : _time.monotonic() reading taken at reset
#   it_per_s   : reported steps divided by elapsed seconds, rounded to 2 decimals
_gen_progress: dict = {
"current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0,
}
def _progress_reset(total: int):
    """Begin tracking a new image generation run of ``total`` steps.

    Clears the step counter and the measured rate, flags the tracker as
    active and stamps the start time from the monotonic clock.
    """
    _gen_progress.update({
        "current": 0,
        "total": total,
        "active": True,
        "started_at": _time.monotonic(),
        "it_per_s": 0.0,
    })
def _progress_done():
_gen_progress["current"] = _gen_progress["total"]
......@@ -131,6 +138,9 @@ def _progress_done():
def _progress_step(step: int):
    """Store ``step`` as the current step and refresh the measured rate.

    The rate is ``step`` divided by the seconds since the last reset, rounded
    to two decimals. It is left unchanged unless both values are positive.
    """
    _gen_progress["current"] = step
    seconds = _time.monotonic() - _gen_progress["started_at"]
    if step <= 0 or seconds <= 0:
        return
    _gen_progress["it_per_s"] = round(step / seconds, 2)
# =============================================================================
......@@ -884,13 +894,16 @@ router = APIRouter()
@router.get("/v1/images/progress")
async def get_image_progress():
    """Return current image generation step progress including speed.

    Returns:
        A dict with ``current`` and ``total`` step counts, the ``active``
        flag, ``pct`` (integer percentage, 0 when ``total`` is not positive),
        ``it_per_s`` (measured steps per second) and ``elapsed`` (seconds
        since the run started, rounded to one decimal; 0.0 while inactive).
    """
    if _gen_progress["active"]:
        elapsed = _time.monotonic() - _gen_progress["started_at"]
    else:
        elapsed = 0.0

    current = _gen_progress["current"]
    total = _gen_progress["total"]

    # Each key appears exactly once; the guard avoids dividing by a zero total.
    return {
        "current": current,
        "total": total,
        "active": _gen_progress["active"],
        "pct": int(current / total * 100) if total > 0 else 0,
        "it_per_s": _gen_progress["it_per_s"],
        "elapsed": round(elapsed, 1),
    }
......
......@@ -51,6 +51,31 @@ router = APIRouter()
global_args = None
global_file_path = None
# =============================================================================
# Video generation progress tracking
# =============================================================================
# Module-level progress state for the current video generation run.
#   current    : number of steps reported so far
#   total      : number of steps expected
#   active     : True between _vid_progress_reset() and _vid_progress_done()
#   started_at : time.monotonic() reading taken at reset
#   it_per_s   : reported steps divided by elapsed seconds, rounded to 2 decimals
_vid_progress: dict = {
"current": 0, "total": 0, "active": False,
"started_at": 0.0, "it_per_s": 0.0,
}
def _vid_progress_reset(total: int):
    """Begin tracking a new video generation run of ``total`` steps.

    Clears the step counter and the measured rate, flags the tracker as
    active and stamps the start time from the monotonic clock.
    """
    _vid_progress.update({
        "current": 0,
        "total": total,
        "active": True,
        "started_at": time.monotonic(),
        "it_per_s": 0.0,
    })
def _vid_progress_done():
    """Mark the video run as finished.

    The step counter is set to ``total`` and the tracker is flagged inactive.
    """
    _vid_progress.update(current=_vid_progress["total"], active=False)
def _vid_progress_step(step: int):
    """Store ``step`` as the current step and refresh the measured rate.

    The rate is ``step`` divided by the seconds since the last reset, rounded
    to two decimals. It is left unchanged unless both values are positive.
    """
    _vid_progress["current"] = step
    seconds = time.monotonic() - _vid_progress["started_at"]
    if step <= 0 or seconds <= 0:
        return
    _vid_progress["it_per_s"] = round(step / seconds, 2)
def set_global_args(args):
global global_args
......@@ -322,6 +347,17 @@ def _generate_video(pipe, request: VideoGenerationRequest):
kw.setdefault('guidance_scale', 7.5)
kw.setdefault('num_frames', 16)
_vid_progress_reset(kw['num_inference_steps'])
# Per-step hook stored in kw['callback_on_step_end'] just below.
# step_index is shifted by one before being stored, so the tracker holds a
# count of finished steps rather than an index (assumes the pipeline numbers
# its steps from 0 — TODO confirm). callback_kwargs is returned untouched.
def _vid_step_cb(pipe, step_index, timestep, callback_kwargs):
_vid_progress_step(step_index + 1)
return callback_kwargs
try:
kw['callback_on_step_end'] = _vid_step_cb
except Exception:
pass
_apply_camera_motion(kw, request.camera_motion)
char_images, char_names = _resolve_character_inputs(request)
......@@ -355,6 +391,7 @@ def _generate_video(pipe, request: VideoGenerationRequest):
kw['strength'] = request.strength
frames = _run_pipeline(pipe, kw)
_vid_progress_done()
return frames, fps
......@@ -781,6 +818,25 @@ def _translate_srt(srt_path: str, target_lang: str, temps: list) -> str:
return srt_path
# =============================================================================
# Progress endpoint
# =============================================================================
@router.get("/v1/video/progress")
async def get_video_progress():
    """Report the state of the video progress tracker.

    The payload carries the step counters, the active flag, an integer
    percentage (0 when the total is not positive), the measured rate and the
    seconds elapsed since the run started (0.0 while inactive).
    """
    if _vid_progress["active"]:
        elapsed = time.monotonic() - _vid_progress["started_at"]
    else:
        elapsed = 0.0

    payload = {
        "current": _vid_progress["current"],
        "total": _vid_progress["total"],
        "active": _vid_progress["active"],
    }
    if _vid_progress["total"] > 0:
        payload["pct"] = int(
            _vid_progress["current"] / _vid_progress["total"] * 100
        )
    else:
        payload["pct"] = 0
    payload["it_per_s"] = _vid_progress["it_per_s"]
    payload["elapsed"] = round(elapsed, 1)
    return payload
# =============================================================================
# Main generation endpoint
# =============================================================================
......@@ -836,6 +892,7 @@ async def video_generations(request: VideoGenerationRequest,
frames, fps = await asyncio.get_event_loop().run_in_executor(
None, _generate_video, pipe, request)
except Exception as e:
_vid_progress_done()
raise HTTPException(status_code=500, detail=f"Video generation failed: {e}")
# Encode raw frames to MP4
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment