Township: per-profile LoRA training, Matches page, modals; fix SDXL LoRA dtype

LoRA training (codai/api/loras.py):
- Fix instant crash "mat1 and mat2 must have the same dtype" (Half != float):
  components were loaded in their checkpoint's native dtype (fp16 text
  encoders + fp32 UNet). Force a consistent fp32 (the stable LoRA reference)
  for both SD1.5 and SDXL, and free the VAE/text encoders after pre-encoding
  to keep memory in budget.

Web UI (tools/gen_township_fighters.py):
- Per-profile LoRA training launched from the Characters/Environments cards
  (steps/rank, live server progress, status badge); the "Train LoRAs" step now
  trains regardless of the consistency setting and logs what it does. Profiles
  are uploaded to CoderAI first so server-side training has reference images.
- New Matches page: final short/long videos, single clips, and outcome outputs
  per match. Re-render clips/outputs with the video model (reuses the render
  path, honoring keyframe/LoRA settings) or reassemble finals from existing
  clips (no model).
- Replace browser alert/confirm/prompt with proper in-page modal dialogs
  (uiConfirm/uiAlert/uiPrompt) on every page.
- Apply changed connection/model settings (image/video/text model) to the live
  session on Save config / Start, so per-profile jobs use them immediately.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 278fab3e
...@@ -311,10 +311,12 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -311,10 +311,12 @@ def _train_sd15(req, base_path, images, instance_prompt,
name = req.name name = req.name
g = torch.Generator(device=device).manual_seed(seed) g = torch.Generator(device=device).manual_seed(seed)
# Consistent fp32 precision (see _train_sdxl) to avoid mixed-dtype crashes.
weight_dtype = torch.float32
tokenizer = CLIPTokenizer.from_pretrained(base_path, subfolder="tokenizer") tokenizer = CLIPTokenizer.from_pretrained(base_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder").to(device) text_encoder = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder").to(device, dtype=weight_dtype)
vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae").to(device) vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae").to(device, dtype=weight_dtype)
unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet").to(device) unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet").to(device, dtype=weight_dtype)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler") noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
vae.requires_grad_(False) vae.requires_grad_(False)
...@@ -338,6 +340,13 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -338,6 +340,13 @@ def _train_sd15(req, base_path, images, instance_prompt,
return_tensors="pt").input_ids.to(device) return_tensors="pt").input_ids.to(device)
encoder_hidden_states = text_encoder(tok)[0] encoder_hidden_states = text_encoder(tok)[0]
# VAE + text encoder are done; free them so only the UNet trains resident.
del vae, text_encoder
try:
torch.cuda.empty_cache()
except Exception:
pass
_set_progress(status="training", message="training (SD1.5)") _set_progress(status="training", message="training (SD1.5)")
unet.train() unet.train()
n = len(latents_list) n = len(latents_list)
...@@ -379,7 +388,7 @@ def _train_sd15(req, base_path, images, instance_prompt, ...@@ -379,7 +388,7 @@ def _train_sd15(req, base_path, images, instance_prompt,
_write_meta(name, req, base_path, len(images), "sd15", instance_prompt) _write_meta(name, req, base_path, len(images), "sd15", instance_prompt)
# Release training tensors. # Release training tensors.
del unet, vae, text_encoder, optimizer, latents_list del unet, optimizer, latents_list
try: try:
torch.cuda.empty_cache() torch.cuda.empty_cache()
except Exception: except Exception:
...@@ -405,12 +414,19 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -405,12 +414,19 @@ def _train_sdxl(req, base_path, images, instance_prompt,
name = req.name name = req.name
g = torch.Generator(device=device).manual_seed(seed) g = torch.Generator(device=device).manual_seed(seed)
# Train in a single consistent precision. Checkpoints can store each
# component in a different native dtype (e.g. fp16 text encoders + fp32
# UNet), which otherwise crashes with "mat1 and mat2 have the same dtype"
# (Half != float) in cross-attention. fp32 is the stable reference for
# LoRA fine-tuning.
weight_dtype = torch.float32
tokenizer_1 = CLIPTokenizer.from_pretrained(base_path, subfolder="tokenizer") tokenizer_1 = CLIPTokenizer.from_pretrained(base_path, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(base_path, subfolder="tokenizer_2") tokenizer_2 = CLIPTokenizer.from_pretrained(base_path, subfolder="tokenizer_2")
text_encoder_1 = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder").to(device) text_encoder_1 = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder").to(device, dtype=weight_dtype)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2").to(device) text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2").to(device, dtype=weight_dtype)
vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae").to(device) vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae").to(device, dtype=weight_dtype)
unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet").to(device) unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet").to(device, dtype=weight_dtype)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler") noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
for m in (vae, text_encoder_1, text_encoder_2, unet): for m in (vae, text_encoder_1, text_encoder_2, unet):
...@@ -458,6 +474,14 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -458,6 +474,14 @@ def _train_sdxl(req, base_path, images, instance_prompt,
device=device, dtype=prompt_embeds.dtype, device=device, dtype=prompt_embeds.dtype,
) )
# VAE + text encoders are no longer needed during the training loop; free
# them so only the UNet stays resident (keeps SDXL fp32 training in budget).
del vae, text_encoder_1, text_encoder_2
try:
torch.cuda.empty_cache()
except Exception:
pass
_set_progress(status="training", message="training (SDXL)") _set_progress(status="training", message="training (SDXL)")
unet.train() unet.train()
n = len(latents_list) n = len(latents_list)
...@@ -498,7 +522,7 @@ def _train_sdxl(req, base_path, images, instance_prompt, ...@@ -498,7 +522,7 @@ def _train_sdxl(req, base_path, images, instance_prompt,
safe_serialization=True) safe_serialization=True)
_write_meta(name, req, base_path, len(images), "sdxl", instance_prompt) _write_meta(name, req, base_path, len(images), "sdxl", instance_prompt)
del unet, vae, text_encoder_1, text_encoder_2, optimizer, latents_list del unet, optimizer, latents_list
try: try:
torch.cuda.empty_cache() torch.cuda.empty_cache()
except Exception: except Exception:
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment