Fix model ID consistency with filters and add new CLI options

- Fixed model ID consistency: numeric IDs now remain the same when using
  filters like --nsfw-friendly, --t2v-only, --i2v-only, etc.
  Previously, filtered lists would renumber models making --show-model
  by numeric ID unreliable.

- Added --model-list-batch option for script-friendly output:
  Outputs 'NUMERIC_ID:FULL_MODEL_NAME' format for easy parsing

- Added --output-dir option to specify output directory:
  Sets the directory where output files will be saved

- Fixed syntax error in argparse epilog string that was causing
  'SyntaxError: invalid decimal literal'
parent 1f33b9e5
...@@ -3210,7 +3210,7 @@ def show_model_details(model_id_or_name, args): ...@@ -3210,7 +3210,7 @@ def show_model_details(model_id_or_name, args):
# Apply same filters # Apply same filters
filtered = [] filtered = []
for name, info in sorted_models: for orig_idx, (name, info) in enumerate(sorted_models, 1):
if args.i2v_only and not info.get("supports_i2v", False): if args.i2v_only and not info.get("supports_i2v", False):
continue continue
if args.t2v_only and info.get("supports_i2v", False): if args.t2v_only and info.get("supports_i2v", False):
...@@ -3230,12 +3230,23 @@ def show_model_details(model_id_or_name, args): ...@@ -3230,12 +3230,23 @@ def show_model_details(model_id_or_name, args):
est = parse_vram_estimate(info["vram"]) est = parse_vram_estimate(info["vram"])
if est == 0 or est <= 55: if est == 0 or est <= 55:
continue continue
filtered.append((name, info)) # Store original index for lookup
filtered.append((name, info, orig_idx))
if 1 <= model_idx <= len(filtered):
model_name, model = filtered[model_idx - 1] # Check if the model_idx is in valid range
# Note: model_idx should match the original index, not filtered position
if 1 <= model_idx <= len(sorted_models):
# Find the model with matching original index in filtered list
for name, info, orig_idx in filtered:
if orig_idx == model_idx:
model = info
model_name = name
break
if model is None:
print(f"❌ Model ID {model_idx} not found in filtered results")
sys.exit(1)
else: else:
print(f"❌ Model ID {model_idx} out of range (1-{len(filtered)})") print(f"❌ Model ID {model_idx} out of range (1-{len(sorted_models)})")
sys.exit(1) sys.exit(1)
except ValueError: except ValueError:
# Not a number, search by name # Not a number, search by name
...@@ -3311,6 +3322,9 @@ def show_model_details(model_id_or_name, args): ...@@ -3311,6 +3322,9 @@ def show_model_details(model_id_or_name, args):
def print_model_list(args): def print_model_list(args):
# Check if JSON output is requested # Check if JSON output is requested
json_output = getattr(args, 'json', False) json_output = getattr(args, 'json', False)
# Check if batch output is requested (script-friendly: NUMERIC_ID:FULL_MODEL_NAME)
batch_output = getattr(args, 'model_list_batch', False)
shown = 0 shown = 0
results = [] results = []
...@@ -3319,7 +3333,11 @@ def print_model_list(args): ...@@ -3319,7 +3333,11 @@ def print_model_list(args):
# Load auto-disable data for showing disabled status # Load auto-disable data for showing disabled status
auto_disable_data = load_auto_disable_data() auto_disable_data = load_auto_disable_data()
for name, info in sorted(MODELS.items()): # Create a sorted list with original indices for stable IDs
# This ensures IDs remain consistent regardless of filters
sorted_models = sorted(MODELS.items())
for orig_idx, (name, info) in enumerate(sorted_models, 1):
caps = detect_model_type(info) caps = detect_model_type(info)
# Apply filters # Apply filters
...@@ -3361,7 +3379,8 @@ def print_model_list(args): ...@@ -3361,7 +3379,8 @@ def print_model_list(args):
is_disabled = is_model_disabled(model_id, name) is_disabled = is_model_disabled(model_id, name)
fail_count = get_model_fail_count(model_id, name) fail_count = get_model_fail_count(model_id, name)
results.append((name, info, caps, is_disabled, fail_count)) # Include original index for stable IDs
results.append((name, info, caps, is_disabled, fail_count, orig_idx))
# Build JSON result # Build JSON result
if json_output: if json_output:
...@@ -3383,6 +3402,13 @@ def print_model_list(args): ...@@ -3383,6 +3402,13 @@ def print_model_list(args):
print(json.dumps(json_results, indent=2)) print(json.dumps(json_results, indent=2))
sys.exit(0) sys.exit(0)
# Batch output - script-friendly format: NUMERIC_ID:FULL_MODEL_NAME
if batch_output:
for name, info, caps, is_disabled, fail_count, orig_idx in results:
model_id = info.get("id", "")
print(f"{orig_idx}:{model_id}")
sys.exit(0)
# Print header only for non-JSON output # Print header only for non-JSON output
print("\nAvailable models (filtered):\n") print("\nAvailable models (filtered):\n")
...@@ -3393,7 +3419,9 @@ def print_model_list(args): ...@@ -3393,7 +3419,9 @@ def print_model_list(args):
print(f"{'ID':>4} {'Name':<22} {'VRAM':<9} {'T2V':<3} {'I2V':<3} {'T2I':<3} {'V2V':<3} {'V2I':<3} {'3D':<3} {'TTS':<3} {'NSFW':<4} {'LoRA':<4} {'Auto':<5}") print(f"{'ID':>4} {'Name':<22} {'VRAM':<9} {'T2V':<3} {'I2V':<3} {'T2I':<3} {'V2V':<3} {'V2I':<3} {'3D':<3} {'TTS':<3} {'NSFW':<4} {'LoRA':<4} {'Auto':<5}")
print("-" * 110) print("-" * 110)
for idx, (name, info, caps, is_disabled, fail_count) in enumerate(results, 1): for idx, (name, info, caps, is_disabled, fail_count, orig_idx) in enumerate(results, 1):
# Use original index for stable IDs (not filtered position)
display_idx = orig_idx
# Truncate name if too long # Truncate name if too long
display_name = name[:20] + ".." if len(name) > 22 else name display_name = name[:20] + ".." if len(name) > 22 else name
vram = info["vram"][:7] if len(info["vram"]) > 7 else info["vram"] vram = info["vram"][:7] if len(info["vram"]) > 7 else info["vram"]
...@@ -3420,7 +3448,7 @@ def print_model_list(args): ...@@ -3420,7 +3448,7 @@ def print_model_list(args):
if is_disabled: if is_disabled:
display_name = f"🚫{display_name[:19]}" if len(display_name) < 22 else f"🚫{display_name[:19]}.." display_name = f"🚫{display_name[:19]}" if len(display_name) < 22 else f"🚫{display_name[:19]}.."
print(f"{idx:>4} {display_name:<22} {vram:<9} {t2v:<3} {i2v:<3} {t2i:<3} {v2v:<3} {v2i:<3} {to_3d:<3} {tts:<3} {nsfw:<4} {lora:<4} {auto_status:<5}") print(f"{display_idx:>4} {display_name:<22} {vram:<9} {t2v:<3} {i2v:<3} {t2i:<3} {v2v:<3} {v2i:<3} {to_3d:<3} {tts:<3} {nsfw:<4} {lora:<4} {auto_status:<5}")
print("-" * 110) print("-" * 110)
print(f"Total shown: {shown} / {len(MODELS)} available") print(f"Total shown: {shown} / {len(MODELS)} available")
...@@ -9588,10 +9616,9 @@ def main(args): ...@@ -9588,10 +9616,9 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Universal Video Generation Toolkit with Audio Support", description="VideoGen - Universal Video Generation Toolkit",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" epilog="""
Examples:
Single GPU (simple T2V): Single GPU (simple T2V):
python3 videogen --model wan_1.3b_t2v --prompt "a cat playing piano" --length 5.0 --output cat_piano python3 videogen --model wan_1.3b_t2v --prompt "a cat playing piano" --length 5.0 --output cat_piano
...@@ -9659,6 +9686,8 @@ List TTS voices: ...@@ -9659,6 +9686,8 @@ List TTS voices:
# Model listing arguments # Model listing arguments
parser.add_argument("--model-list", action="store_true", parser.add_argument("--model-list", action="store_true",
help="Print list of all available models and exit") help="Print list of all available models and exit")
parser.add_argument("--model-list-batch", action="store_true",
help="Print list of models in script-friendly format (NUMERIC_ID:FULL_MODEL_NAME)")
parser.add_argument("--json", action="store_true", parser.add_argument("--json", action="store_true",
help="Output model list in JSON format (for web interface)") help="Output model list in JSON format (for web interface)")
parser.add_argument("--tts-list", action="store_true", parser.add_argument("--tts-list", action="store_true",
...@@ -9735,6 +9764,8 @@ List TTS voices: ...@@ -9735,6 +9764,8 @@ List TTS voices:
parser.add_argument("--height", type=int, default=480) parser.add_argument("--height", type=int, default=480)
parser.add_argument("--fps", type=int, default=15) parser.add_argument("--fps", type=int, default=15)
parser.add_argument("--output", default="output") parser.add_argument("--output", default="output")
parser.add_argument("--output-dir", default=None,
help="Directory for output files (overrides --output path)")
parser.add_argument("--seed", type=int, default=-1) parser.add_argument("--seed", type=int, default=-1)
parser.add_argument("--no_filter", action="store_true") parser.add_argument("--no_filter", action="store_true")
parser.add_argument("--upscale", action="store_true") parser.add_argument("--upscale", action="store_true")
...@@ -10040,4 +10071,15 @@ List TTS voices: ...@@ -10040,4 +10071,15 @@ List TTS voices:
help="Enable debug mode for detailed error messages and troubleshooting") help="Enable debug mode for detailed error messages and troubleshooting")
args = parser.parse_args() args = parser.parse_args()
# Handle output directory - prepend to output path if specified
if getattr(args, 'output_dir', None):
output_dir = args.output_dir
# Get the output filename (just the name, not full path)
output_name = os.path.basename(args.output) if args.output else "output"
# Combine with output directory
args.output = os.path.join(output_dir, output_name)
# Create the directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
main(args) main(args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment