Add auto-disable feature for models that fail 3 times in --auto mode

- Add auto_disable.json to track failure counts and disabled status
- Models that fail 3 times in auto mode are automatically disabled
- Disabled models are skipped during auto model selection
- A successful manual run of a disabled model re-enables it for auto mode
- Model list now shows 'Auto' column with status (Yes, OFF, or X/3)
- Disabled models shown with 🚫 indicator in model list
- New functions: load_auto_disable_data(), save_auto_disable_data(),
  record_model_failure(), is_model_disabled(), re_enable_model(),
  get_model_fail_count()
parent b1e602e5
Pipeline #235 canceled with stages
......@@ -134,6 +134,7 @@ except ImportError:
# Per-user configuration directory; created on demand by ensure_config_dir().
CONFIG_DIR = Path.home() / ".config" / "videogen"
MODELS_CONFIG_FILE = CONFIG_DIR / "models.json"
CACHE_FILE = CONFIG_DIR / "hf_cache.json"
# Auto-disable tracking store (failure counts and disabled status); read by
# load_auto_disable_data() and written by save_auto_disable_data().
AUTO_DISABLE_FILE = CONFIG_DIR / "auto_disable.json"
# Pipeline class to model type mapping
PIPELINE_CLASS_MAP = {
......@@ -166,6 +167,96 @@ def ensure_config_dir():
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
def load_auto_disable_data():
    """Load auto-disable tracking data (failure counts and disabled status).

    Returns:
        dict: The JSON object stored in AUTO_DISABLE_FILE, keyed by model
        key. An empty dict is returned when the file does not exist, cannot
        be read or parsed, or does not contain a JSON object at top level.
    """
    ensure_config_dir()
    if not AUTO_DISABLE_FILE.exists():
        return {}
    try:
        with open(AUTO_DISABLE_FILE, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"⚠️ Could not load auto-disable data: {e}")
        return {}
    if not isinstance(data, dict):
        # Valid JSON of the wrong shape (list, string, null, ...) would make
        # every caller fail on `key in data` / `data[key]`; treat as empty.
        print(
            "⚠️ Could not load auto-disable data: "
            f"expected a JSON object, got {type(data).__name__}"
        )
        return {}
    return data
def save_auto_disable_data(data):
    """Save auto-disable tracking data.

    The data is serialized before any file is opened and is then written to
    a sibling temporary file that replaces AUTO_DISABLE_FILE in one step, so
    a serialization error or an interrupted write cannot truncate the
    previously saved tracking data.

    Args:
        data: JSON-serializable mapping of model key -> tracking entry.

    Failures are reported with a warning on stdout and are not raised
    (best-effort persistence).
    """
    ensure_config_dir()
    try:
        # Serialize first: an unserializable value must fail before the
        # existing file is touched.
        payload = json.dumps(data, indent=2)
        tmp_file = AUTO_DISABLE_FILE.with_name(AUTO_DISABLE_FILE.name + ".tmp")
        with open(tmp_file, 'w', encoding='utf-8') as f:
            f.write(payload)
        # Path.replace overwrites the destination in a single rename.
        tmp_file.replace(AUTO_DISABLE_FILE)
    except (OSError, TypeError, ValueError) as e:
        print(f"⚠️ Could not save auto-disable data: {e}")
def record_model_failure(model_name, model_id, max_failures=3):
    """Record a model failure in auto mode.

    Args:
        model_name: Display name of the model; stored in the entry and used
            as the tracking key when ``model_id`` is empty.
        model_id: Model identifier; preferred tracking key.
        max_failures: Number of recorded failures at which the model is
            marked as disabled (default 3).

    Returns:
        bool: True if the model is disabled after recording this failure.
        False is returned without recording anything when neither
        ``model_id`` nor ``model_name`` is usable as a key.
    """
    # Use model_id as key for consistency
    key = model_id or model_name
    if not key:
        # A None key would be written to JSON as "null" and never match on
        # the next load, so the count could not accumulate across runs.
        return False

    data = load_auto_disable_data()
    if not isinstance(data, dict):
        data = {}

    entry = data.get(key)
    if not isinstance(entry, dict):
        # New (or malformed) entry: start from a clean record.
        entry = data[key] = {
            "fail_count": 0,
            "disabled": False,
            "model_name": model_name,
            "last_failure": None,
        }

    # .get() tolerates hand-edited entries that lack a fail_count field.
    entry["fail_count"] = entry.get("fail_count", 0) + 1
    entry["last_failure"] = str(datetime.now())

    # Disable once the failure threshold is reached
    if entry["fail_count"] >= max_failures:
        entry["disabled"] = True
        print(f" 🚫 Model {model_name} has failed {entry['fail_count']} times - AUTO-DISABLED for --auto mode")

    save_auto_disable_data(data)
    return entry.get("disabled", False)
def is_model_disabled(model_id, model_name=None, *, data=None):
    """Check if a model is disabled for auto mode.

    Args:
        model_id: Model identifier; preferred tracking key.
        model_name: Fallback key used when ``model_id`` is empty.
        data: Optional tracking data previously returned by
            load_auto_disable_data(). Pass it when checking many models in
            a loop so the tracking file is not re-read for every model.
            When omitted, the data is loaded from disk.

    Returns:
        bool: True if the model has an entry whose "disabled" flag is set.
    """
    if data is None:
        data = load_auto_disable_data()
    key = model_id or model_name
    entry = data.get(key) if isinstance(data, dict) else None
    if not isinstance(entry, dict):
        # Unknown model or malformed entry: treat as enabled.
        return False
    return bool(entry.get("disabled", False))
def re_enable_model(model_id, model_name=None):
    """Clear the auto-mode disabled flag of a model, if it is set.

    Intended for the case where a disabled model was selected manually and
    the run succeeded. The entry is looked up by ``model_id``, falling back
    to ``model_name`` when the id is empty.

    Returns:
        bool: True if a disabled entry was found, reset (flag cleared,
        failure count zeroed, "re_enabled" timestamp stored) and saved;
        False if the model has no entry or is not disabled.
    """
    tracking = load_auto_disable_data()
    lookup = model_id or model_name

    currently_disabled = lookup in tracking and tracking[lookup].get("disabled", False)
    if not currently_disabled:
        return False

    record = tracking[lookup]
    record["disabled"] = False
    record["fail_count"] = 0
    record["re_enabled"] = str(datetime.now())
    save_auto_disable_data(tracking)

    print(f" ✅ Model {model_name or model_id} re-enabled for --auto mode (manual selection successful)")
    return True
def get_model_fail_count(model_id, model_name=None, *, data=None):
    """Get the recorded auto-mode failure count for a model.

    Args:
        model_id: Model identifier; preferred tracking key.
        model_name: Fallback key used when ``model_id`` is empty.
        data: Optional tracking data previously returned by
            load_auto_disable_data(). Pass it when querying many models in
            a loop so the tracking file is not re-read for every model.
            When omitted, the data is loaded from disk.

    Returns:
        The stored "fail_count" of the model's entry, or 0 when the model
        has no entry (or the entry is malformed).
    """
    if data is None:
        data = load_auto_disable_data()
    key = model_id or model_name
    entry = data.get(key) if isinstance(data, dict) else None
    if not isinstance(entry, dict):
        return 0
    return entry.get("fail_count", 0)
def load_models_config():
"""Load models from external config file"""
ensure_config_dir()
......@@ -1959,6 +2050,9 @@ def print_model_list(args):
shown = 0
results = []
# Load auto-disable data for showing disabled status
auto_disable_data = load_auto_disable_data()
for name, info in sorted(MODELS.items()):
if args.i2v_only and not info.get("supports_i2v", False):
continue
......@@ -1982,18 +2076,24 @@ def print_model_list(args):
shown += 1
caps = detect_model_type(info)
results.append((name, info, caps))
# Check if model is disabled for auto mode
model_id = info.get("id", "")
is_disabled = is_model_disabled(model_id, name)
fail_count = get_model_fail_count(model_id, name)
results.append((name, info, caps, is_disabled, fail_count))
if shown == 0:
print("No models match the selected filters.")
else:
# Print table header
print(f"{'ID':>4} {'Name':<28} {'VRAM':<11} {'I2V':<4} {'T2V':<4} {'T2I':<4} {'I2I':<4} {'NSFW':<5} {'LoRA':<5}")
print("-" * 95)
# Print table header with new Auto column
print(f"{'ID':>4} {'Name':<26} {'VRAM':<11} {'I2V':<4} {'T2V':<4} {'T2I':<4} {'I2I':<4} {'NSFW':<5} {'LoRA':<5} {'Auto':<6}")
print("-" * 100)
for idx, (name, info, caps) in enumerate(results, 1):
for idx, (name, info, caps, is_disabled, fail_count) in enumerate(results, 1):
# Truncate name if too long
display_name = name[:26] + ".." if len(name) > 28 else name
display_name = name[:24] + ".." if len(name) > 26 else name
vram = info["vram"][:9] if len(info["vram"]) > 9 else info["vram"]
i2v = "Yes" if caps["i2v"] else "-"
......@@ -2003,11 +2103,30 @@ def print_model_list(args):
nsfw = "Yes" if caps["nsfw"] else "-"
lora = "Yes" if caps["lora"] else "-"
print(f"{idx:>4} {display_name:<28} {vram:<11} {i2v:<4} {t2v:<4} {t2i:<4} {i2i:<4} {nsfw:<5} {lora:<5}")
# Show auto status
if is_disabled:
auto_status = "OFF"
elif fail_count > 0:
auto_status = f"{fail_count}/3"
else:
auto_status = "Yes"
# Add indicator for disabled models
if is_disabled:
display_name = f"🚫{display_name[:23]}" if len(display_name) < 26 else f"🚫{display_name[:23]}.."
print("-" * 95)
print(f"{idx:>4} {display_name:<26} {vram:<11} {i2v:<4} {t2v:<4} {t2i:<4} {i2i:<4} {nsfw:<5} {lora:<5} {auto_status:<6}")
print("-" * 100)
print(f"Total shown: {shown} / {len(MODELS)} available")
# Show legend for auto column
disabled_count = sum(1 for _, _, _, is_disabled, _ in results if is_disabled)
if disabled_count > 0:
print(f"\n 🚫 = Auto-disabled (failed 3 times in --auto mode)")
print(f" {disabled_count} model(s) disabled for --auto mode")
print(f" Use --model <name> manually to re-enable a disabled model")
print("\nUse --model <name> to select a model.")
print("Use --show-model <ID|name> to see full model details.")
sys.exit(0)
......@@ -2288,12 +2407,23 @@ def select_best_model(gen_type, models, vram_gb=24, prefer_quality=True, return_
LoRA adapters are now considered alongside base models. When a LoRA is selected,
the returned info includes 'is_lora': True and 'base_model' for the main pipeline
to load the base model first, then apply the LoRA adapter.
Auto-Disable Support:
Models that have been disabled due to repeated failures in auto mode are skipped.
"""
candidates = []
is_nsfw = gen_type.get("is_nsfw", False)
gen_type_str = gen_type.get("type", "t2v")
# Load auto-disable data
auto_disable_data = load_auto_disable_data()
for name, info in models.items():
# Skip models that are disabled for auto mode
model_id = info.get("id", "")
if is_model_disabled(model_id, name):
continue # Skip disabled models
is_lora = info.get("is_lora", False)
base_model_id = info.get("base_model")
......@@ -3659,6 +3789,9 @@ def main(args):
print(f" The model was explicitly provided with --model, not retrying with alternatives.")
print(f" Please verify the model exists or try a different model.")
else:
# Record the failure for auto-disable tracking
record_model_failure(args.model, model_id_to_load)
# If this was a LoRA with a base model, track the failed base model
if is_lora and base_model_id:
failed_base_models.add(base_model_id)
......@@ -3666,13 +3799,20 @@ def main(args):
print(f" ⚠️ Base model failed: {base_model_id}")
print(f" Will skip other LoRAs depending on this base model")
# Find next valid alternative (skip LoRAs with failed base models)
# Find next valid alternative (skip LoRAs with failed base models AND disabled models)
next_model = None
skipped_loras = []
skipped_disabled = []
while alternative_models:
candidate_name, candidate_info, candidate_reason = alternative_models.pop(0)
# Check if this model is disabled for auto mode
candidate_id = candidate_info.get("id", "")
if is_model_disabled(candidate_id, candidate_name):
skipped_disabled.append((candidate_name, candidate_id))
continue # Skip disabled model
# Check if this is a LoRA with a failed base model
if candidate_info.get("is_lora", False):
candidate_base = candidate_info.get("base_model") or candidate_info.get("_inferred_base_model")
......@@ -3690,6 +3830,9 @@ def main(args):
if skipped_loras:
print(f" ⏭️ Skipped {len(skipped_loras)} LoRA(s) with failed base models")
if skipped_disabled:
print(f" ⏭️ Skipped {len(skipped_disabled)} auto-disabled model(s)")
if retry_count < max_retries and next_model:
# We have a valid alternative - retry with it
args._retry_count = retry_count + 1
......@@ -4246,6 +4389,11 @@ def main(args):
timing.print_summary()
print(f"✨ Done! Seed: {seed}")
# Re-enable model if it was disabled and this was a manual selection that succeeded
if not getattr(args, '_auto_mode', False) and getattr(args, '_user_specified_model', False):
model_id = m_info.get("id", "")
re_enable_model(model_id, args.model)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment