LoRA transport: upload / id / url / inline file (no shared filesystem)

Previously a per-request LoRA could only be a local path or HF id, which
assumed the client shared the server's filesystem. Add a content-addressed
store so remote clients can supply LoRAs by value or handle.

The request `loras` spec now accepts the following forms (resolved server-side, in priority order):
  id "name:<registered>"  -> a LoRA trained on this server (path-independent)
  id "sha256:<hex>"       -> a previously uploaded blob
  file/data (base64)      -> inline weights, cached in the blob store
  url                     -> server downloads (cached by content hash)
  model/path              -> legacy local path / HF id (unchanged)

- loras.py: blob store (save_lora_blob / lora_blob_exists / _lora_blob_path),
  resolve_lora_ref(), resolve_request_loras() (in-place -> clean 400 on a
  missing blob / unknown name). New POST /v1/loras/upload (multipart / JSON
  base64 / raw, dedup) and GET /v1/loras/blob/{hash} existence check.
- LoraConfig / VideoLoraConfig: model now optional; add id/url/file/data/path.
- image + video handlers resolve_request_loras() before model work, so
  signature dedup / VRAM reserve / load_lora_weights read lora.model as before.
- gen_township_fighters.py: reference trained LoRAs by id "name:<registered>"
  (derived from the server path) with the raw path kept as a co-located
  fallback, so the script works client/server-split.

Also harden video load: float(cfg.get('balanced_gpu_percent', 80)) crashed on
an explicit null (admin UI writes null for blank fields); use `or 80`.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 2b28cae2
...@@ -1250,6 +1250,11 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -1250,6 +1250,11 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# headroom for base weights + adapters before the pipeline loads. # headroom for base weights + adapters before the pipeline loads.
_lora_extra_gb = 0.0 _lora_extra_gb = 0.0
if getattr(request, 'loras', None): if getattr(request, 'loras', None):
# Resolve id/url/inline-file LoRA refs to local paths now (clean 400 on
# a missing blob / unknown name) before any model work; _apply_loras
# then reads lora.model as usual.
from codai.api.loras import resolve_request_loras
resolve_request_loras(request.loras)
try: try:
_lora_extra_gb = multi_model_manager._lora_vram_gb(request.loras) _lora_extra_gb = multi_model_manager._lora_vram_gb(request.loras)
except Exception: except Exception:
......
...@@ -190,6 +190,149 @@ def _lora_weight_file(name: str) -> Optional[str]: ...@@ -190,6 +190,149 @@ def _lora_weight_file(name: str) -> Optional[str]:
return None return None
# ── Content-addressed LoRA blob store ─────────────────────────────────────────
# Lets clients on a *different* machine supply a LoRA without sharing a
# filesystem: upload the file once (POST /v1/loras/upload), then reference it by
# its sha256 in generation requests. Identical files dedupe automatically, so a
# match's env/fighter LoRAs are transmitted at most once.
def _lora_blob_dir() -> str:
    """Return the blob-store directory, creating it if it does not exist yet.

    The store is the `_blobs` sub-directory of whatever `_loras_dir()` returns.
    """
    blob_dir = os.path.join(_loras_dir(), "_blobs")
    # exist_ok: every caller goes through here, so the directory usually exists.
    os.makedirs(blob_dir, exist_ok=True)
    return blob_dir
def _lora_blob_path(h: str) -> str:
    """Map a blob hash to its file path inside the blob store.

    `h` is a hex sha256, optionally carrying a `sha256:` prefix (matched
    case-insensitively). Surrounding whitespace is stripped and the digest is
    lower-cased. The path is returned whether or not the file exists.

    NOTE(review): the digest is not checked for being 64 hex characters here, so
    callers passing client-supplied values should validate them first.
    """
    prefix = "sha256:"
    digest = (h or "").strip()
    if digest.lower().startswith(prefix):
        digest = digest[len(prefix):]
    return os.path.join(_lora_blob_dir(), digest.lower() + ".safetensors")
def lora_blob_exists(h: str) -> bool:
    """Return True when a blob with sha256 `h` is present in the blob store.

    `h` may be a bare hex digest or carry a `sha256:` prefix (case-insensitive).
    Anything that is not exactly 64 hex characters after normalisation is
    reported as absent without touching the filesystem. The value comes from
    request bodies and URL paths, so an unchecked `../...` must never be turned
    into a path and probed.
    """
    import re

    digest = (h or "").strip().lower()
    if digest.startswith("sha256:"):
        digest = digest[len("sha256:"):]
    if not re.fullmatch(r"[0-9a-f]{64}", digest):
        return False
    return os.path.isfile(_lora_blob_path(digest))
def save_lora_blob(data: bytes) -> tuple:
    """Write `data` to the content-addressed store.

    Returns `(hex_sha256, existed)`. `existed` is True when a blob with the same
    hash was already on disk, in which case nothing is written.

    The bytes are written to a uniquely named temporary file in the blob
    directory and then moved into place with `os.replace`. A unique name (rather
    than a fixed `<path>.tmp`) keeps two concurrent uploads of the same content
    from clobbering each other's temp file. The temp file is removed if the
    write or the rename fails.
    """
    import hashlib
    import tempfile

    h = hashlib.sha256(data).hexdigest()
    p = _lora_blob_path(h)
    if os.path.isfile(p):
        return h, True

    # Same directory as the destination so os.replace stays on one filesystem.
    # NOTE(review): mkstemp creates the file with owner-only permissions —
    # confirm nothing outside this process needs to read blobs.
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(p), prefix=f"{h}.", suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp, p)
    except BaseException:
        # Don't leave a partial temp file behind; the original error still wins.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise
    return h, False
def _spec_get(spec, key):
"""Read a field from a LoRA spec that may be a dict OR a pydantic object."""
if isinstance(spec, dict):
return spec.get(key)
return getattr(spec, key, None)
def _strip_data_uri(s: str) -> str:
"""Drop a leading 'data:...;base64,' prefix if present."""
if isinstance(s, str) and s.startswith("data:"):
comma = s.find(",")
if comma != -1:
return s[comma + 1:]
return s
# Per-socket-operation timeout (seconds) for server-side LoRA downloads, so an
# unresponsive host cannot hang a request indefinitely.
_LORA_URL_TIMEOUT_S = 60


def resolve_lora_ref(spec) -> Optional[str]:
    """Resolve a LoRA request spec to a concrete local path (or HF id) usable by
    diffusers' load_lora_weights().

    Accepts, in priority order, on a dict or LoraConfig/VideoLoraConfig object:
      * id           — "name:<registered>" (a trained LoRA on this server),
                       "sha256:<hex>" or a bare 64-char hex (an uploaded blob)
      * file / data  — base64 (optionally a data: URI) of the .safetensors bytes;
                       stored in the blob cache and used from there
      * url          — http(s) URL the server downloads (and caches by content hash)
      * model / path — a local path or HuggingFace repo id (legacy behaviour)

    When an `id` is given but cannot be resolved and the spec also carries a
    legacy `model`/`path`, that legacy value is returned instead. Clients send
    both so that co-located setups keep working with LoRAs the server has not
    registered under that name.

    Returns None when `spec` is None or carries no reference at all.

    Raises HTTPException(400) when:
      * an `id` cannot be resolved and there is no legacy fallback,
      * inline data is not valid base64 or decodes to zero bytes,
      * a `url` is not http(s), cannot be downloaded, or returns an empty body.
    """
    if spec is None:
        return None

    # Legacy reference: the last-priority source, and the fallback for a bad id.
    legacy = _spec_get(spec, "model") or _spec_get(spec, "path")

    # 1. id reference
    ref_id = _spec_get(spec, "id")
    if ref_id:
        ref_id = str(ref_id).strip()
        if ref_id.startswith("name:"):
            name = ref_id[5:].strip()
            wf = None
            if name:  # an empty name can never refer to a registered LoRA
                wf = _lora_weight_file(name) or (
                    _lora_dir(name) if os.path.isdir(_lora_dir(name)) else None)
            if wf:
                return wf
            if legacy:
                return legacy
            raise HTTPException(status_code=400,
                                detail=f"LoRA '{name}' not found on server")

        # sha256:<hex> or bare hex
        import re
        h = ref_id[7:] if ref_id.lower().startswith("sha256:") else ref_id
        h = h.strip().lower()
        # Only a well-formed digest is ever turned into a filesystem path.
        if re.fullmatch(r"[0-9a-f]{64}", h) and lora_blob_exists(h):
            return _lora_blob_path(h)
        if legacy:
            return legacy
        raise HTTPException(
            status_code=400,
            detail=f"LoRA blob '{ref_id}' not on server — upload it via "
                   f"POST /v1/loras/upload first")

    # 2. inline base64 file
    blob = _spec_get(spec, "file") or _spec_get(spec, "data")
    if blob:
        import base64
        try:
            raw = base64.b64decode(_strip_data_uri(str(blob)))
        except Exception as e:
            raise HTTPException(status_code=400,
                                detail=f"invalid base64 LoRA file: {e}") from e
        if not raw:
            # b64decode silently drops non-alphabet characters, so garbage input
            # can decode to nothing; don't store an empty blob.
            raise HTTPException(status_code=400,
                                detail="inline LoRA file decoded to zero bytes")
        h, _ = save_lora_blob(raw)
        return _lora_blob_path(h)

    # 3. url
    url = _spec_get(spec, "url")
    if url:
        import urllib.parse
        import urllib.request
        url = str(url).strip()
        try:
            scheme = urllib.parse.urlsplit(url).scheme.lower()
        except ValueError:
            scheme = ""
        # urlopen also understands file:// and ftp://; only http(s) is part of
        # the documented contract, and file:// would read server-local files.
        if scheme not in ("http", "https"):
            raise HTTPException(
                status_code=400,
                detail=f"unsupported LoRA url '{url}': only http(s) URLs are allowed")
        try:
            with urllib.request.urlopen(url, timeout=_LORA_URL_TIMEOUT_S) as resp:
                raw = resp.read()
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"could not download LoRA from {url}: {e}") from e
        if not raw:
            raise HTTPException(
                status_code=400,
                detail=f"could not download LoRA from {url}: empty response")
        h, _ = save_lora_blob(raw)
        return _lora_blob_path(h)

    # 4. legacy path / HF id
    return legacy
def resolve_request_loras(loras) -> None:
    """Resolve every LoRA spec in a request's `loras` list to a concrete local
    path IN PLACE, writing it back onto each spec's `model` field.

    Call this in the async request handler (before the heavy generation work) so
    a missing blob / unknown name surfaces as an HTTP 400 raised by
    `resolve_lora_ref`. Afterwards, downstream code that reads `lora.model` can
    stay unaware of how the client supplied the weights.

    Specs for which nothing resolves (no reference given) are left untouched.
    An empty or None `loras` is a no-op.
    """
    if not loras:
        return
    for lora in loras:
        path = resolve_lora_ref(lora)
        if not path:
            continue
        if isinstance(lora, dict):
            lora["model"] = path
            continue
        try:
            lora.model = path
        except Exception:
            # Still best-effort (the request is not failed here), but leave a
            # trace: downstream code will read the spec's unresolved `model`.
            import logging
            logging.getLogger(__name__).warning(
                "could not store resolved LoRA path %r on %s spec",
                path, type(lora).__name__, exc_info=True)
def _require_api_auth(request: Request) -> None: def _require_api_auth(request: Request) -> None:
"""Raise 401 if auth is enabled and the request carries no valid credential.""" """Raise 401 if auth is enabled and the request carries no valid credential."""
try: try:
...@@ -1535,6 +1678,53 @@ async def lora_progress(job: Optional[str] = None, session: Optional[str] = None ...@@ -1535,6 +1678,53 @@ async def lora_progress(job: Optional[str] = None, session: Optional[str] = None
return dict(_progress) return dict(_progress)
@router.post("/v1/loras/upload")
async def upload_lora(request: Request, _auth=Depends(_require_api_auth)):
    """Upload a LoRA file into the content-addressed store so a remote client can
    use it without sharing a filesystem.

    Accepts the file as multipart/form-data (field `file`), as JSON
    {"file": "<base64>"} or {"data": "<base64>"} (optionally a data: URI), or as
    a raw request body.

    Returns {id: "sha256:<hex>", bytes, existed}. Reference the returned id as
    {"id": "<id>", "weight": ...} in image/video generation requests.

    Raises HTTPException(400) for a missing or non-file multipart `file` field,
    an unparsable or non-object JSON body, missing or invalid base64, or an
    empty payload.
    """
    # Media types are case-insensitive; this lowered copy is only used to pick
    # the branch.
    ctype = request.headers.get("content-type", "").lower()
    data = b""
    if "multipart/form-data" in ctype:
        form = await request.form()
        up = form.get("file")
        if up is None:
            raise HTTPException(status_code=400, detail="multipart upload missing 'file' field")
        if hasattr(up, "read"):
            data = await up.read()
        elif isinstance(up, (bytes, bytearray, memoryview)):
            data = bytes(up)
        else:
            # A plain text form field arrives as str; bytes(str) would raise
            # TypeError and surface as a 500.
            raise HTTPException(status_code=400,
                                detail="multipart 'file' field must be a file upload")
    elif "application/json" in ctype:
        try:
            body = await request.json()
        except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
            raise HTTPException(status_code=400, detail=f"invalid JSON body: {e}") from e
        if not isinstance(body, dict):
            raise HTTPException(status_code=400,
                                detail="JSON upload must be an object with 'file'/'data' (base64)")
        blob = body.get("file") or body.get("data")
        if not blob:
            raise HTTPException(status_code=400, detail="JSON upload missing 'file'/'data' (base64)")
        import base64
        try:
            data = base64.b64decode(_strip_data_uri(str(blob)))
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"invalid base64: {e}") from e
    else:
        data = await request.body()
    if not data:
        raise HTTPException(status_code=400, detail="empty upload")
    h, existed = save_lora_blob(data)
    return {"id": f"sha256:{h}", "bytes": len(data), "existed": existed}
@router.get("/v1/loras/blob/{hash}")
async def lora_blob_info(hash: str, _auth=Depends(_require_api_auth)):
    """Report whether an uploaded LoRA blob is on the server.

    Responds with {id, bytes, exists} when the blob is present and with a 404
    otherwise, so a client can skip re-uploading a file the server already has.
    `hash` may be a hex sha256 or 'sha256:<hex>'.
    """
    if not lora_blob_exists(hash):
        raise HTTPException(status_code=404, detail="blob not found")
    blob_path = _lora_blob_path(hash)
    # The file name is "<digest>.safetensors"; recover the normalised digest.
    file_name = os.path.basename(blob_path)
    digest = file_name[:-len('.safetensors')]
    size = os.path.getsize(blob_path)
    return {"id": f"sha256:{digest}", "bytes": size, "exists": True}
@router.get("/v1/loras") @router.get("/v1/loras")
async def list_loras(_auth=Depends(_require_api_auth)): async def list_loras(_auth=Depends(_require_api_auth)):
out = [] out = []
......
...@@ -1922,6 +1922,11 @@ async def video_generations(request: VideoGenerationRequest, ...@@ -1922,6 +1922,11 @@ async def video_generations(request: VideoGenerationRequest,
# headroom for base weights + adapters before the pipeline loads. # headroom for base weights + adapters before the pipeline loads.
_lora_extra_gb = 0.0 _lora_extra_gb = 0.0
if getattr(request, 'loras', None): if getattr(request, 'loras', None):
# Resolve any id/url/inline-file LoRA refs to concrete local paths now, in
# the async handler, so a missing blob / unknown name returns a clean 400
# before we touch the model. Downstream code then reads lora.model as usual.
from codai.api.loras import resolve_request_loras
resolve_request_loras(request.loras)
try: try:
_lora_extra_gb = multi_model_manager._lora_vram_gb(request.loras) _lora_extra_gb = multi_model_manager._lora_vram_gb(request.loras)
except Exception: except Exception:
......
...@@ -22,7 +22,16 @@ from pydantic import BaseModel, ConfigDict ...@@ -22,7 +22,16 @@ from pydantic import BaseModel, ConfigDict
class LoraConfig(BaseModel): class LoraConfig(BaseModel):
model: str """A LoRA adapter for one image request. Weights may be supplied (resolved
server-side, in priority) via `id` ("name:<registered>" or "sha256:<hex>"),
inline `file`/`data` base64, a `url`, or the legacy `model`/`path` (local path
/ HF id) — so a remote client needn't share the server's filesystem."""
model: Optional[str] = None
path: Optional[str] = None
id: Optional[str] = None
url: Optional[str] = None
file: Optional[str] = None
data: Optional[str] = None
weight: float = 1.0 weight: float = 1.0
name: Optional[str] = None name: Optional[str] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
......
...@@ -21,10 +21,21 @@ from pydantic import BaseModel, ConfigDict ...@@ -21,10 +21,21 @@ from pydantic import BaseModel, ConfigDict
class VideoLoraConfig(BaseModel): class VideoLoraConfig(BaseModel):
"""A LoRA adapter to apply to the video diffusion pipeline for one request.""" """A LoRA adapter to apply to the video diffusion pipeline for one request.
model: str # path or HF id of the LoRA weights
The weights may be supplied in several ways (resolved server-side, in this
priority): `id` ("name:<registered>" trained LoRA, or "sha256:<hex>" uploaded
blob), inline `file`/`data` base64, a `url` to download, or the legacy
`model`/`path` local path / HF id. This lets a client on a different machine
use a LoRA without sharing a filesystem with the server."""
model: Optional[str] = None # legacy: local path or HF id of the weights
path: Optional[str] = None # alias of model
id: Optional[str] = None # "name:<registered>" or "sha256:<hex>"
url: Optional[str] = None # http(s) URL the server downloads
file: Optional[str] = None # base64 of the .safetensors (or data: URI)
data: Optional[str] = None # alias of file
weight: float = 1.0 weight: float = 1.0
name: Optional[str] = None name: Optional[str] = None # adapter name
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
......
...@@ -1440,13 +1440,46 @@ def _set_lora_job_id(out_dir: Path, lora_name: str, job_id): ...@@ -1440,13 +1440,46 @@ def _set_lora_job_id(out_dir: Path, lora_name: str, job_id):
pass pass
def _lora_name_ref(path: str):
"""Derive the server-side registered LoRA name from a trained-LoRA path.
LoRAs are trained on the coderai server, which saves them under
<loras_dir>/<registered_name>/pytorch_lora_weights.safetensors and returns
that path. The *registered name* (the directory holding the weights) is a
filesystem-path-independent handle the server can resolve no matter where it
runs — so we reference LoRAs by `id: "name:<registered>"` instead of a raw
path that's only meaningful if client and server share a disk. Returns None
when the path doesn't look like a trained-LoRA layout."""
if not path:
return None
p = str(path).rstrip("/\\")
base = os.path.basename(p)
if base.lower().endswith((".safetensors", ".bin", ".pt", ".ckpt")):
reg = os.path.basename(os.path.dirname(p))
else:
reg = base # path is the LoRA directory itself
return f"name:{reg}" if reg else None
def _lora_spec(path: str, weight: float, name: str) -> dict:
    """Build one `loras` request entry for the LoRA stored at `path`.

    The entry always carries `weight`, `name` and the raw `path` as `model`.
    When a registered name can be derived from the path it is also sent as
    `id` ("name:<registered>"), which does not depend on client and server
    sharing a filesystem.

    NOTE(review): the raw path is meant as a fallback for co-located setups;
    whether it is consulted when the id fails is decided by the server.
    """
    by_name = _lora_name_ref(path)
    handle = {"id": by_name} if by_name else {}
    # Key order is kept as weight, name, [id], model.
    return {"weight": float(weight), "name": name, **handle, "model": path}
def _lora_specs_for(fighters: list, lora_map: dict, weight: float) -> list:
    """Build the `loras` request list for the fighters appearing in a clip.

    Fighters with no (or an empty) entry in `lora_map` are skipped; a None
    `lora_map` is treated as empty. Entries keep the order of `fighters`.
    """
    paths = lora_map or {}
    return [_lora_spec(paths[f], weight, f) for f in fighters if paths.get(f)]
...@@ -1456,7 +1489,7 @@ def _env_lora_specs_for(env: str, env_lora_map: dict, weight: float) -> list: ...@@ -1456,7 +1489,7 @@ def _env_lora_specs_for(env: str, env_lora_map: dict, weight: float) -> list:
return [] return []
path = (env_lora_map or {}).get(env) path = (env_lora_map or {}).get(env)
if path: if path:
return [{"model": path, "weight": float(weight), "name": f"env_{env}"}] return [_lora_spec(path, weight, f"env_{env}")]
return [] return []
...@@ -1478,7 +1511,7 @@ def _video_lora_specs_for(fighters: list, vmap: dict, slug: str, weight: float) ...@@ -1478,7 +1511,7 @@ def _video_lora_specs_for(fighters: list, vmap: dict, slug: str, weight: float)
for f in fighters: for f in fighters:
path = _video_lora_path((vmap or {}).get(f), slug) path = _video_lora_path((vmap or {}).get(f), slug)
if path: if path:
specs.append({"model": path, "weight": float(weight), "name": f}) specs.append(_lora_spec(path, weight, f))
return specs return specs
...@@ -1487,7 +1520,7 @@ def _env_video_lora_specs_for(env: str, env_vmap: dict, slug: str, weight: float ...@@ -1487,7 +1520,7 @@ def _env_video_lora_specs_for(env: str, env_vmap: dict, slug: str, weight: float
return [] return []
path = _video_lora_path((env_vmap or {}).get(env), slug) path = _video_lora_path((env_vmap or {}).get(env), slug)
if path: if path:
return [{"model": path, "weight": float(weight), "name": f"env_{env}"}] return [_lora_spec(path, weight, f"env_{env}")]
return [] return []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment