downloads: dedup re-downloads + kill orphaned workers; single-line load progress

Re-downloading a model that was already in progress spawned a second
download_worker. Both contend for huggingface_hub's per-blob file lock —
the first downloads, the second blocks on the lock and reports 0% forever
("Downloading full repository…"). Two causes, both fixed:

- Same-process re-download click: api_download_model now dedups via
  _active_download_session(model_id, file_pattern) and attaches the client
  to the live session instead of spawning a rival worker.
- Restart case: workers were plain Popen children with no parent-death
  signal, so a server/engine restart orphaned them (still holding the lock)
  while the new instance lost its in-memory dedup state. Workers now spawn
  with PR_SET_PDEATHSIG=SIGKILL so they die with the server; the re-download
  then resumes cleanly from the .incomplete blob.

Also render engine "Loading weights" tqdm progress as a single updating
line on a TTY (in-place \r) and throttle to whole-percent changes when
piped, instead of one line per update.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 3020f828
......@@ -49,15 +49,44 @@ _download_cancelled: set = set() # session_ids the user has requested to cancel
_download_procs: dict = {} # session_id → multiprocessing.Process running the download
def _worker_preexec():
    """Child-side ``preexec_fn`` hook: ask the kernel to SIGKILL us when the parent dies.

    Download workers are launched as ordinary subprocesses. Without a
    parent-death signal they outlive a server/engine restart as orphans, keep
    holding huggingface_hub's per-blob file lock, and cause the next
    re-download to hang at 0%. Binding the worker's lifetime to the parent
    lets a restart clean them up so the re-download can resume from the
    ``.incomplete`` blob.

    The call is strictly best effort: any failure (libc not loadable, prctl
    unavailable) is ignored so the worker still starts.

    NOTE(review): per prctl(2) the death signal is tied to the *thread* that
    created the child - confirm the spawning thread outlives the worker.
    """
    pr_set_pdeathsig = 1  # prctl option number for PR_SET_PDEATHSIG
    sigkill = 9           # signal number delivered on parent death
    try:
        import ctypes
        libc = ctypes.CDLL("libc.so.6")
        libc.prctl(pr_set_pdeathsig, sigkill, 0, 0, 0)
    except Exception:
        # Deliberately swallowed: an exception here must not stop the worker
        # from being spawned.
        pass
def get_active_download_model_ids() -> set:
    """Return the set of model IDs whose download is currently in progress.

    A session counts as in progress while its status is ``"starting"`` or
    ``"downloading"``, so a model is reported from the moment its status
    entry is registered rather than only once the transfer is under way.

    Returns:
        A set of ``model_id`` values taken from ``_download_status``; empty
        when nothing is active.
    """
    active_states = ("starting", "downloading")
    # Iterate over a snapshot: _run_download_thread inserts new entries into
    # _download_status, and iterating the live dict while it grows raises
    # RuntimeError.
    return {
        s["model_id"]
        for s in list(_download_status.values())
        if s.get("status") in active_states
    }
def _active_download_session(model_id: str, file_pattern: str):
    """Return the session_id of a live download for (model_id, file_pattern), or None.

    Used to dedup re-download clicks: a second worker for a model that is
    already downloading would only block on huggingface_hub's per-blob file
    lock and sit at 0% forever, so the caller attaches the new client to the
    running download instead.

    A session is considered live when its status is ``"starting"`` or
    ``"downloading"`` and it still has an entry in ``_download_sessions``.
    A missing/empty ``file_pattern`` on either side is treated as ``""``.

    Args:
        model_id: Model identifier to look for.
        file_pattern: File pattern of the requested download (may be empty
            or None).

    Returns:
        The matching session id, or None when no live session matches.
    """
    wanted_pattern = file_pattern or ""
    # Iterate over a snapshot: _run_download_thread inserts new entries into
    # _download_status, and iterating the live dict while it grows raises
    # RuntimeError.
    for sid, s in list(_download_status.items()):
        if s.get("model_id") != model_id:
            continue
        if (s.get("file_pattern") or "") != wanted_pattern:
            continue
        if s.get("status") not in ("starting", "downloading"):
            continue
        if sid in _download_sessions:
            return sid
    return None
def _url(request: Request, path: str) -> str:
"""Return a proxy-aware absolute path (root_path prefix + path)."""
from codai.api.urlutils import get_public_prefix
......@@ -797,7 +826,8 @@ def _run_download_thread(session_id: str, model_id: str, file_pattern: str, pq):
import time
import os
status = {"session_id": session_id, "model_id": model_id, "status": "starting",
status = {"session_id": session_id, "model_id": model_id, "file_pattern": file_pattern,
"status": "starting",
"percent": 0, "filename": "", "rate": 0, "eta": None}
_download_status[session_id] = status
......@@ -856,6 +886,7 @@ def _run_download_thread(session_id: str, model_id: str, file_pattern: str, pq):
proc = _sp.Popen(
[_sys.executable, "-m", "codai.admin.download_worker", model_id, file_pattern or ""],
stdout=_sp.PIPE, stderr=_sp.STDOUT, text=True, bufsize=1, env=env, cwd=_repo_root,
preexec_fn=_worker_preexec if os.name == "posix" else None,
)
_download_procs[session_id] = proc
terminal = None
......@@ -940,6 +971,15 @@ async def api_download_model(
if not model_id:
raise HTTPException(status_code=400, detail="Model ID required")
# Dedup: if this exact model is already downloading (e.g. the previous attempt
# survived a page reload, or the user clicked "download" again), attach to the
# live session instead of spawning a second worker. A duplicate worker would
# only deadlock on huggingface_hub's per-blob file lock and show 0% forever
# while the first worker quietly finishes.
existing = _active_download_session(model_id, file_pattern)
if existing:
return {"session_id": existing, "attached": True}
session_id = str(_uuid.uuid4())
pq = _q.Queue()
_download_sessions[session_id] = pq
......
......@@ -133,6 +133,12 @@ class EngineSupervisor:
self._poll_thread = None
self._logs = {} # engine_id -> deque tail
self._restart_lock = threading.RLock()
# Serialise terminal writes across engine pump threads + track whether the
# last thing we printed was an in-place tqdm progress line (so the next
# normal line finalises it with a newline).
self._log_lock = threading.Lock()
self._log_progress_tag = None # tag currently owning the \r line, or None
self._log_last_pct = {} # tag -> last printed % (non-TTY throttle)
def _assign_models(self, engines) -> None:
"""Give each engine the set of models it owns (via CODERAI_ENGINE_MODELS), so
......@@ -330,8 +336,41 @@ class EngineSupervisor:
if not line:
continue
tail.append(line)
print(f"[{tag}] {line}", flush=True)
self._note_load_progress(engine, line)
self._emit_log(tag, line)
def _emit_log(self, tag, line):
    """Print one engine log line, collapsing tqdm progress output.

    Lines matching ``self._PROGRESS_RE`` are treated as progress updates:

    * on a TTY they are redrawn in place (carriage return, then erase to end
      of line) so the bar occupies a single updating line;
    * otherwise (output redirected) they stay newline-terminated but are only
      printed when the rounded integer percent for this tag changes.

    Any other line is printed normally; if an in-place progress line is
    currently on screen, a newline is emitted first to finalise it.

    Args:
        tag: Label printed in square brackets before the line.
        line: The log line text.
    """
    match = self._PROGRESS_RE.match(line)
    with self._log_lock:
        try:
            on_tty = sys.stdout.isatty()
        except Exception:
            on_tty = False

        if match:
            if on_tty:
                # Redraw over the current line; the escape sequence erases
                # whatever is left of a longer previous line.
                print(f"\r[{tag}] {line}\033[K", end="", flush=True)
                self._log_progress_tag = tag
                return
            # Not a TTY: drop updates that do not move the integer percent.
            try:
                done = int(match.group(2))
                total = max(1, int(match.group(3)))
                percent = int(round(done / total * 100))
            except Exception:
                percent = -1
            if self._log_last_pct.get(tag) == percent:
                return
            self._log_last_pct[tag] = percent

        # About to print a regular newline-terminated line: close any
        # in-place progress line first.
        if self._log_progress_tag is not None:
            print(flush=True)
            self._log_progress_tag = None
        print(f"[{tag}] {line}", flush=True)
def _note_load_progress(self, engine, line):
"""Track model-load progress from the engine's log stream so the front can
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment