Add model management and cache control features

parent 48215e18
...@@ -233,6 +233,11 @@ Add to Claude Desktop config (`~/Library/Application Support/Claude/claude_deskt ...@@ -233,6 +233,11 @@ Add to Claude Desktop config (`~/Library/Application Support/Claude/claude_deskt
| `videogen_update_models` | Update model database | | `videogen_update_models` | Update model database |
| `videogen_search_models` | Search HuggingFace | | `videogen_search_models` | Search HuggingFace |
| `videogen_add_model` | Add model to database | | `videogen_add_model` | Add model to database |
| `videogen_disable_model` | Disable model from auto selection |
| `videogen_enable_model` | Enable model for auto selection |
| `videogen_list_cached_models` | List cached models with sizes |
| `videogen_remove_cached_model` | Remove specific cached model |
| `videogen_clear_cache` | Clear entire local cache |
| `videogen_list_tts_voices` | List TTS voices | | `videogen_list_tts_voices` | List TTS voices |
### Skill Documentation ### Skill Documentation
...@@ -312,6 +317,25 @@ python3 videogen --add-model stabilityai/stable-video-diffusion-img2vid-xt-1.1 - ...@@ -312,6 +317,25 @@ python3 videogen --add-model stabilityai/stable-video-diffusion-img2vid-xt-1.1 -
# Show model details # Show model details
python3 videogen --show-model 1 python3 videogen --show-model 1
# Disable a model from auto selection
python3 videogen --disable-model <ID_or_name>
# Enable a model for auto selection
python3 videogen --enable-model <ID_or_name>
```
## Cache Management
```bash
# List cached models with sizes and last accessed times
python3 videogen --list-cached-models
# Remove specific cached model
python3 videogen --remove-cached-model <model_id>
# Clear entire cache
python3 videogen --clear-cache
``` ```
--- ---
......
...@@ -74,10 +74,15 @@ python3 videogen --show-model <model_id_or_name> ...@@ -74,10 +74,15 @@ python3 videogen --show-model <model_id_or_name>
| V2I | `--video input.mp4 --extract-keyframes` | Extract frames from video | | V2I | `--video input.mp4 --extract-keyframes` | Extract frames from video |
| 3D | `--video input.mp4 --convert-3d-sbs` | Convert 2D to 3D | | 3D | `--video input.mp4 --convert-3d-sbs` | Convert 2D to 3D |
### Model Filters ### Model Management
#### List and Manage Models
```bash ```bash
# List models by type # List all available models (with auto mode status)
python3 videogen --model-list
# Filter models
python3 videogen --model-list --t2v-only # Text-to-Video models python3 videogen --model-list --t2v-only # Text-to-Video models
python3 videogen --model-list --i2v-only # Image-to-Video models python3 videogen --model-list --i2v-only # Image-to-Video models
python3 videogen --model-list --t2i-only # Text-to-Image models python3 videogen --model-list --t2i-only # Text-to-Image models
...@@ -86,14 +91,29 @@ python3 videogen --model-list --v2i-only # Video-to-Image models ...@@ -86,14 +91,29 @@ python3 videogen --model-list --v2i-only # Video-to-Image models
python3 videogen --model-list --3d-only # 2D-to-3D models python3 videogen --model-list --3d-only # 2D-to-3D models
python3 videogen --model-list --tts-only # TTS models python3 videogen --model-list --tts-only # TTS models
python3 videogen --model-list --audio-only # Audio models python3 videogen --model-list --audio-only # Audio models
# List by VRAM requirement
python3 videogen --model-list --low-vram # ≤16GB VRAM python3 videogen --model-list --low-vram # ≤16GB VRAM
python3 videogen --model-list --high-vram # >30GB VRAM python3 videogen --model-list --high-vram # >30GB VRAM
python3 videogen --model-list --huge-vram # >55GB VRAM python3 videogen --model-list --huge-vram # >55GB VRAM
# List NSFW-friendly models
python3 videogen --model-list --nsfw-friendly python3 videogen --model-list --nsfw-friendly
# Disable a model from auto selection
python3 videogen --disable-model <ID_or_name>
# Enable a model for auto selection
python3 videogen --enable-model <ID_or_name>
```
#### Cache Management
```bash
# List cached models
python3 videogen --list-cached-models
# Remove specific cached model
python3 videogen --remove-cached-model <model_id>
# Clear entire cache
python3 videogen --clear-cache
``` ```
### Auto Mode ### Auto Mode
......
...@@ -576,7 +576,7 @@ MODELS_CONFIG_FILE = CONFIG_DIR / "models.json" ...@@ -576,7 +576,7 @@ MODELS_CONFIG_FILE = CONFIG_DIR / "models.json"
CACHE_FILE = CONFIG_DIR / "hf_cache.json" CACHE_FILE = CONFIG_DIR / "hf_cache.json"
AUTO_DISABLE_FILE = CONFIG_DIR / "auto_disable.json" AUTO_DISABLE_FILE = CONFIG_DIR / "auto_disable.json"
# Pipeline class to model type mapping # Pipeline class to model type mapping
PIPELINE_CLASS_MAP = { PIPELINE_CLASS_MAP = {
"StableVideoDiffusionPipeline": {"type": "i2v", "default_vram": "~14-18 GB"}, "StableVideoDiffusionPipeline": {"type": "i2v", "default_vram": "~14-18 GB"},
"WanPipeline": {"type": "video", "default_vram": "~10-24 GB"}, "WanPipeline": {"type": "video", "default_vram": "~10-24 GB"},
...@@ -591,9 +591,12 @@ PIPELINE_CLASS_MAP = { ...@@ -591,9 +591,12 @@ PIPELINE_CLASS_MAP = {
"TextToVideoSDPipeline": {"type": "t2v", "default_vram": "~7-9 GB"}, "TextToVideoSDPipeline": {"type": "t2v", "default_vram": "~7-9 GB"},
"TextToVideoZeroPipeline": {"type": "t2v", "default_vram": "~6-8 GB"}, "TextToVideoZeroPipeline": {"type": "t2v", "default_vram": "~6-8 GB"},
"MochiPipeline": {"type": "t2v", "default_vram": "~18-22 GB"}, "MochiPipeline": {"type": "t2v", "default_vram": "~18-22 GB"},
"StableDiffusionXLPipeline": {"type": "image", "default_vram": "~10-16 GB"}, "StableDiffusionXLPipeline": {"type": "t2i", "default_vram": "~10-16 GB"},
"StableDiffusion3Pipeline": {"type": "image", "default_vram": "~15-20 GB"}, "StableDiffusionXLImg2ImgPipeline": {"type": "i2i", "default_vram": "~10-16 GB"},
"FluxPipeline": {"type": "image", "default_vram": "~20-25 GB"}, "StableDiffusion3Pipeline": {"type": "t2i", "default_vram": "~15-20 GB"},
"StableDiffusion3Img2ImgPipeline": {"type": "i2i", "default_vram": "~15-20 GB"},
"FluxPipeline": {"type": "t2i", "default_vram": "~20-25 GB"},
"FluxImg2ImgPipeline": {"type": "i2i", "default_vram": "~20-25 GB"},
"AllegroPipeline": {"type": "t2v", "default_vram": "~35-45 GB"}, "AllegroPipeline": {"type": "t2v", "default_vram": "~35-45 GB"},
"HunyuanDiTPipeline": {"type": "t2v", "default_vram": "~40-55 GB"}, "HunyuanDiTPipeline": {"type": "t2v", "default_vram": "~40-55 GB"},
"OpenSoraPipeline": {"type": "video", "default_vram": "~45-65 GB"}, "OpenSoraPipeline": {"type": "video", "default_vram": "~45-65 GB"},
...@@ -674,8 +677,31 @@ def is_model_disabled(model_id, model_name=None): ...@@ -674,8 +677,31 @@ def is_model_disabled(model_id, model_name=None):
return False return False
def re_enable_model(model_id, model_name=None): def disable_model(model_id, model_name=None):
"""Re-enable a model that was disabled (called when manually selected and successful)""" """Disable a model for auto-selection"""
data = load_auto_disable_data()
key = model_id or model_name
if key not in data:
data[key] = {
"fail_count": 0,
"disabled": True,
"model_name": model_name,
"last_failure": None,
"disabled_by_user": True
}
else:
data[key]["disabled"] = True
data[key]["disabled_by_user"] = True
data[key]["fail_count"] = 0 # Reset fail count when manually disabled
save_auto_disable_data(data)
print(f" ✅ Model {model_name or model_id} disabled for --auto mode")
return True
def enable_model(model_id, model_name=None):
"""Enable a model for auto-selection"""
data = load_auto_disable_data() data = load_auto_disable_data()
key = model_id or model_name key = model_id or model_name
...@@ -683,12 +709,154 @@ def re_enable_model(model_id, model_name=None): ...@@ -683,12 +709,154 @@ def re_enable_model(model_id, model_name=None):
data[key]["disabled"] = False data[key]["disabled"] = False
data[key]["fail_count"] = 0 data[key]["fail_count"] = 0
data[key]["re_enabled"] = str(datetime.now()) data[key]["re_enabled"] = str(datetime.now())
data[key]["disabled_by_user"] = False
save_auto_disable_data(data) save_auto_disable_data(data)
print(f" ✅ Model {model_name or model_id} re-enabled for --auto mode (manual selection successful)") print(f" ✅ Model {model_name or model_id} enabled for --auto mode")
return True
return False
def list_cached_models():
"""List locally cached HuggingFace models with their sizes"""
try:
from huggingface_hub import scan_cache_dir
cache_info = scan_cache_dir()
print("\n📦 Locally cached HuggingFace models:")
print("=" * 100)
print(f"{'Model ID':<50} {'Size':<12} {'Last Accessed':<20} {'Last Modified':<20}")
print("-" * 100)
for repo in cache_info.repos:
repo_id = repo.repo_id
size = f"{repo.size_on_disk / (1024 ** 3):.2f} GB"
last_accessed = "Never"
last_modified = "Unknown"
# Get last accessed and modified from revisions
for rev in repo.revisions:
if rev.snapshot_path.exists():
stat = rev.snapshot_path.stat()
# Use last modified time from snapshot
last_modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
if stat.st_atime > 0:
last_accessed = datetime.fromtimestamp(stat.st_atime).strftime("%Y-%m-%d %H:%M:%S")
print(f"{repo_id:<50} {size:<12} {last_accessed:<20} {last_modified:<20}")
print("=" * 100)
print(f"Total: {len(cache_info.repos)} model(s) taking {cache_info.size_on_disk / (1024 ** 3):.2f} GB")
return True
except ImportError:
print("❌ huggingface-hub not installed. Install with: pip install huggingface-hub")
return False
except Exception as e:
print(f"❌ Error scanning cache: {e}")
return False
def remove_cached_model(model_id):
"""Remove a specific model from the local HuggingFace cache"""
try:
from huggingface_hub import scan_cache_dir, HfApi
import shutil
# Normalize model ID
model_id = model_id.strip().lower()
# Scan cache to find matching repos
cache_info = scan_cache_dir()
matching_repos = []
for repo in cache_info.repos:
if model_id in repo.repo_id.lower():
matching_repos.append(repo)
if not matching_repos:
print(f"❌ No cached model found matching: {model_id}")
print(" Use --list-cached-models to see available models")
return False
print(f"🔍 Found {len(matching_repos)} matching model(s) in cache:")
for repo in matching_repos:
print(f" - {repo.repo_id} ({repo.size_on_disk / (1024 ** 3):.2f} GB)")
# Confirm deletion
confirm = input("\n⚠️ Are you sure you want to delete these models? (y/N): ").strip().lower()
if confirm != 'y' and confirm != 'yes':
print("✅ Aborted - models not deleted")
return False
# Delete matching repos
deleted_count = 0
for repo in matching_repos:
try:
shutil.rmtree(repo.repo_path)
print(f"✅ Deleted: {repo.repo_id}")
deleted_count += 1
except Exception as e:
print(f"❌ Failed to delete {repo.repo_id}: {e}")
print(f"\n✅ Deleted {deleted_count} model(s) from cache")
return True
except ImportError:
print("❌ huggingface-hub not installed. Install with: pip install huggingface-hub")
return False
except Exception as e:
print(f"❌ Error removing cached model: {e}")
return False
def clear_cache():
"""Clear the entire local HuggingFace cache"""
try:
from huggingface_hub import scan_cache_dir
import shutil
cache_info = scan_cache_dir()
if not cache_info.repos:
print("✅ Cache is already empty")
return True return True
total_size = cache_info.size_on_disk / (1024 ** 3)
print(f"⚠️ Cache contains {len(cache_info.repos)} model(s) taking {total_size:.2f} GB")
# Confirm deletion
confirm = input("Are you sure you want to CLEAR THE ENTIRE CACHE? (y/N): ").strip().lower()
if confirm != 'y' and confirm != 'yes':
print("✅ Aborted - cache not cleared")
return False return False
# Get cache directory path
cache_dir = cache_info.repos.pop().repo_path.parent
# Delete all cache contents
for item in cache_dir.iterdir():
if item.is_dir() and item.name.startswith("models--"):
try:
shutil.rmtree(item)
print(f"✅ Deleted: {item.name}")
except Exception as e:
print(f"❌ Failed to delete {item.name}: {e}")
print("✅ Cache cleared successfully")
return True
except ImportError:
print("❌ huggingface-hub not installed. Install with: pip install huggingface-hub")
return False
except Exception as e:
print(f"❌ Error clearing cache: {e}")
return False
def re_enable_model(model_id, model_name=None):
"""Re-enable a model that was disabled (called when manually selected and successful)"""
return enable_model(model_id, model_name)
def get_model_fail_count(model_id, model_name=None): def get_model_fail_count(model_id, model_name=None):
"""Get the failure count for a model""" """Get the failure count for a model"""
...@@ -924,9 +1092,20 @@ def detect_pipeline_class(model_info): ...@@ -924,9 +1092,20 @@ def detect_pipeline_class(model_info):
# Check for specific image models # Check for specific image models
if "flux" in model_id: if "flux" in model_id:
return "FluxPipeline" return "FluxPipeline"
elif "sd3" in model_id or "stable-diffusion-3" in model_id:
return "StableDiffusion3Pipeline"
elif "sdxl" in model_id or "stable-diffusion-xl" in model_id:
return "StableDiffusionXLPipeline" return "StableDiffusionXLPipeline"
elif pipeline_tag_lower == "image-to-image":
return "StableDiffusionXLPipeline" return "StableDiffusionXLPipeline"
elif pipeline_tag_lower == "image-to-image":
# Check for specific image models
if "flux" in model_id:
return "FluxImg2ImgPipeline"
elif "sd3" in model_id or "stable-diffusion-3" in model_id:
return "StableDiffusion3Img2ImgPipeline"
elif "sdxl" in model_id or "stable-diffusion-xl" in model_id:
return "StableDiffusionXLImg2ImgPipeline"
return "StableDiffusionXLImg2ImgPipeline"
# 3. Check model ID patterns (specific models first) # 3. Check model ID patterns (specific models first)
# Wan models (check for version patterns) # Wan models (check for version patterns)
...@@ -996,10 +1175,20 @@ def detect_pipeline_class(model_info): ...@@ -996,10 +1175,20 @@ def detect_pipeline_class(model_info):
if "text-to-image" in tags: if "text-to-image" in tags:
if "flux" in model_id: if "flux" in model_id:
return "FluxPipeline" return "FluxPipeline"
elif "sd3" in model_id or "stable-diffusion-3" in model_id:
return "StableDiffusion3Pipeline"
elif "sdxl" in model_id or "stable-diffusion-xl" in model_id:
return "StableDiffusionXLPipeline"
return "StableDiffusionXLPipeline" return "StableDiffusionXLPipeline"
if "image-to-image" in tags: if "image-to-image" in tags:
return "StableDiffusionXLPipeline" if "flux" in model_id:
return "FluxImg2ImgPipeline"
elif "sd3" in model_id or "stable-diffusion-3" in model_id:
return "StableDiffusion3Img2ImgPipeline"
elif "sdxl" in model_id or "stable-diffusion-xl" in model_id:
return "StableDiffusionXLImg2ImgPipeline"
return "StableDiffusionXLImg2ImgPipeline"
# 5. Check library name # 5. Check library name
if library_name == "diffusers": if library_name == "diffusers":
...@@ -1284,8 +1473,19 @@ def search_hf_safetensors(query, limit=20, hf_token=None): ...@@ -1284,8 +1473,19 @@ def search_hf_safetensors(query, limit=20, hf_token=None):
elif "mochi" in model_name_lower: elif "mochi" in model_name_lower:
pipeline_class = "MochiPipeline" pipeline_class = "MochiPipeline"
elif "flux" in model_name_lower: elif "flux" in model_name_lower:
if "img2img" in model_name_lower or "image-to-image" in model_name_lower:
pipeline_class = "FluxImg2ImgPipeline"
else:
pipeline_class = "FluxPipeline" pipeline_class = "FluxPipeline"
elif "sdxl" in model_name_lower: elif "sd3" in model_name_lower or "stable-diffusion-3" in model_name_lower:
if "img2img" in model_name_lower or "image-to-image" in model_name_lower:
pipeline_class = "StableDiffusion3Img2ImgPipeline"
else:
pipeline_class = "StableDiffusion3Pipeline"
elif "sdxl" in model_name_lower or "stable-diffusion-xl" in model_name_lower:
if "img2img" in model_name_lower or "image-to-image" in model_name_lower:
pipeline_class = "StableDiffusionXLImg2ImgPipeline"
else:
pipeline_class = "StableDiffusionXLPipeline" pipeline_class = "StableDiffusionXLPipeline"
results.append({ results.append({
...@@ -1316,9 +1516,37 @@ def update_all_models(hf_token=None): ...@@ -1316,9 +1516,37 @@ def update_all_models(hf_token=None):
print("🔄 Updating model database from HuggingFace...") print("🔄 Updating model database from HuggingFace...")
print("=" * 60) print("=" * 60)
# Load existing models to preserve them # Load existing models
existing_models = load_models_config() or {} existing_models = load_models_config() or {}
print(f"📁 Preserving {len(existing_models)} existing local models") print(f"📁 Found {len(existing_models)} existing models")
# Validate existing models - check if they still exist on HuggingFace
valid_existing_models = {}
removed_count = 0
print("\n🔍 Validating existing models...")
for name, model in existing_models.items():
model_id = model.get("id")
# Skip validation for local models (not from HuggingFace)
if not model_id or "/" not in model_id:
valid_existing_models[name] = model
continue
print(f" Checking: {model_id}")
# Validate model exists on HuggingFace
model_info = validate_hf_model(model_id, hf_token=hf_token)
if model_info:
valid_existing_models[name] = model
else:
print(f" ❌ Model {model_id} not found - removing from config")
removed_count += 1
print(f"\n✅ Validated {len(valid_existing_models)} existing models")
if removed_count > 0:
print(f"❌ Removed {removed_count} models that no longer exist")
# Search queries for different model types # Search queries for different model types
search_queries = [ search_queries = [
...@@ -2081,8 +2309,8 @@ def update_all_models(hf_token=None): ...@@ -2081,8 +2309,8 @@ def update_all_models(hf_token=None):
print(f"\n" + "=" * 60) print(f"\n" + "=" * 60)
print(f"📊 Found {len(all_models)} new models from HuggingFace") print(f"📊 Found {len(all_models)} new models from HuggingFace")
# Merge with existing models (existing take precedence to preserve local configs) # Merge with valid existing models (existing take precedence to preserve local configs)
final_models = existing_models.copy() final_models = valid_existing_models.copy()
new_count = 0 new_count = 0
for name, entry in all_models.items(): for name, entry in all_models.items():
if name not in final_models: if name not in final_models:
...@@ -2093,8 +2321,10 @@ def update_all_models(hf_token=None): ...@@ -2093,8 +2321,10 @@ def update_all_models(hf_token=None):
save_models_config(final_models) save_models_config(final_models)
print(f"✅ Model database updated!") print(f"✅ Model database updated!")
print(f" Preserved: {len(existing_models)} existing models") print(f" Preserved: {len(valid_existing_models)} existing models")
print(f" Added: {new_count} new models") print(f" Added: {new_count} new models")
if removed_count > 0:
print(f" Removed: {removed_count} models that no longer exist")
print(f" Total models: {len(final_models)}") print(f" Total models: {len(final_models)}")
print(f" Config saved to: {MODELS_CONFIG_FILE}") print(f" Config saved to: {MODELS_CONFIG_FILE}")
...@@ -2194,6 +2424,9 @@ def get_pipeline_class(class_name): ...@@ -2194,6 +2424,9 @@ def get_pipeline_class(class_name):
"StableVideoDiffusionPipeline": ["StableVideoDiffusionImg2VidPipeline"], "StableVideoDiffusionPipeline": ["StableVideoDiffusionImg2VidPipeline"],
"CogVideoXPipeline": ["CogVideoXImageToVideoPipeline", "CogVideoXVideoToVideoPipeline"], "CogVideoXPipeline": ["CogVideoXImageToVideoPipeline", "CogVideoXVideoToVideoPipeline"],
"MochiPipeline": ["Mochi1Pipeline", "MochiVideoPipeline"], "MochiPipeline": ["Mochi1Pipeline", "MochiVideoPipeline"],
"FluxImg2ImgPipeline": ["FluxImageToImagePipeline"],
"StableDiffusion3Img2ImgPipeline": ["StableDiffusion3ImageToImagePipeline"],
"StableDiffusionXLImg2ImgPipeline": ["StableDiffusionXLImageToImagePipeline"],
"DiffusionPipeline": [], # No alternatives needed - it's the generic class "DiffusionPipeline": [], # No alternatives needed - it's the generic class
} }
...@@ -3028,9 +3261,28 @@ def print_model_list(args): ...@@ -3028,9 +3261,28 @@ def print_model_list(args):
# Show legend for auto column # Show legend for auto column
disabled_count = sum(1 for _, _, _, is_disabled, _ in results if is_disabled) disabled_count = sum(1 for _, _, _, is_disabled, _ in results if is_disabled)
if disabled_count > 0: if disabled_count > 0:
print(f"\n 🚫 = Auto-disabled (failed 3 times in --auto mode)") # Count user-disabled vs auto-disabled
print(f" {disabled_count} model(s) disabled for --auto mode") user_disabled_count = 0
print(f" Use --model <name> manually to re-enable a disabled model") auto_disabled_count = 0
auto_disable_data = load_auto_disable_data()
for _, info, _, is_disabled, fail_count in results:
if is_disabled:
model_id = info.get("id", "")
key = model_id or _
if key in auto_disable_data and auto_disable_data[key].get("disabled_by_user", False):
user_disabled_count += 1
else:
auto_disabled_count += 1
print(f"\n 🚫 = Disabled (either by user or auto-disabled)")
if user_disabled_count > 0:
print(f" 👤 {user_disabled_count} model(s) disabled by user")
if auto_disabled_count > 0:
print(f" 🤖 {auto_disabled_count} model(s) auto-disabled (failed 3 times)")
print(f" Use --enable-model <ID|name> to re-enable a disabled model")
print(f" Use --disable-model <ID|name> to disable a model")
print("\nFilters: --t2v-only, --i2v-only, --t2i-only, --v2v-only, --v2i-only, --3d-only, --tts-only, --audio-only") print("\nFilters: --t2v-only, --i2v-only, --t2i-only, --v2v-only, --v2i-only, --3d-only, --tts-only, --audio-only")
print(" --nsfw-friendly, --low-vram, --high-vram, --huge-vram") print(" --nsfw-friendly, --low-vram, --high-vram, --huge-vram")
...@@ -7266,6 +7518,78 @@ def main(args): ...@@ -7266,6 +7518,78 @@ def main(args):
print(f" Detected pipeline: {pipeline or 'Unknown'}") print(f" Detected pipeline: {pipeline or 'Unknown'}")
sys.exit(0) sys.exit(0)
# Handle cached models management
if args.list_cached_models:
list_cached_models()
sys.exit(0)
if args.remove_cached_model:
success = remove_cached_model(args.remove_cached_model)
sys.exit(0 if success else 1)
if args.clear_cache:
success = clear_cache()
sys.exit(0 if success else 1)
# Handle model disable/enable
if args.disable_model:
model_id_or_name = args.disable_model
if model_id_or_name.isdigit():
# Numeric ID
idx = int(model_id_or_name)
if idx > 0 and idx <= len(MODELS):
name = list(sorted(MODELS.keys()))[idx - 1]
info = MODELS[name]
model_id = info.get("id", "")
disable_model(model_id, name)
else:
print(f"❌ Model ID {idx} not found")
elif model_id_or_name in MODELS:
# Model name
info = MODELS[model_id_or_name]
model_id = info.get("id", "")
disable_model(model_id, model_id_or_name)
else:
# Check if it's a model ID
found = False
for name, info in MODELS.items():
if info.get("id", "") == model_id_or_name:
disable_model(model_id_or_name, name)
found = True
break
if not found:
print(f"❌ Model '{model_id_or_name}' not found")
sys.exit(0)
if args.enable_model:
model_id_or_name = args.enable_model
if model_id_or_name.isdigit():
# Numeric ID
idx = int(model_id_or_name)
if idx > 0 and idx <= len(MODELS):
name = list(sorted(MODELS.keys()))[idx - 1]
info = MODELS[name]
model_id = info.get("id", "")
enable_model(model_id, name)
else:
print(f"❌ Model ID {idx} not found")
elif model_id_or_name in MODELS:
# Model name
info = MODELS[model_id_or_name]
model_id = info.get("id", "")
enable_model(model_id, model_id_or_name)
else:
# Check if it's a model ID
found = False
for name, info in MODELS.items():
if info.get("id", "") == model_id_or_name:
enable_model(model_id_or_name, name)
found = True
break
if not found:
print(f"❌ Model '{model_id_or_name}' not found")
sys.exit(0)
# Handle model list # Handle model list
if args.model_list: if args.model_list:
print_model_list(args) print_model_list(args)
...@@ -9227,6 +9551,19 @@ List TTS voices: ...@@ -9227,6 +9551,19 @@ List TTS voices:
parser.add_argument("--remove-model", type=str, default=None, parser.add_argument("--remove-model", type=str, default=None,
metavar="ID_OR_NAME", metavar="ID_OR_NAME",
help="Remove a model from the local database by numeric ID (from --model-list) or name") help="Remove a model from the local database by numeric ID (from --model-list) or name")
parser.add_argument("--disable-model", type=str, default=None,
metavar="ID_OR_NAME",
help="Disable a model from auto-selection by numeric ID (from --model-list) or name")
parser.add_argument("--enable-model", type=str, default=None,
metavar="ID_OR_NAME",
help="Enable a model for auto-selection by numeric ID (from --model-list) or name")
parser.add_argument("--list-cached-models", action="store_true",
help="List locally cached HuggingFace models with their sizes")
parser.add_argument("--remove-cached-model", type=str, default=None,
metavar="MODEL_ID",
help="Remove a specific model from the local HuggingFace cache (e.g., stabilityai/stable-video-diffusion-img2vid-xt-1-1)")
parser.add_argument("--clear-cache", action="store_true",
help="Clear the entire local HuggingFace cache")
parser.add_argument("--update-models", action="store_true", parser.add_argument("--update-models", action="store_true",
help="Search HuggingFace and update model database with I2V, T2V, and NSFW models") help="Search HuggingFace and update model database with I2V, T2V, and NSFW models")
......
...@@ -653,6 +653,69 @@ async def list_tools() -> list: ...@@ -653,6 +653,69 @@ async def list_tools() -> list:
} }
), ),
Tool(
name="videogen_disable_model",
description="Disable a model from auto-selection.",
inputSchema={
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "Model ID (number), name, or HuggingFace ID to disable"
}
},
"required": ["model"]
}
),
Tool(
name="videogen_enable_model",
description="Enable a model for auto-selection.",
inputSchema={
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "Model ID (number), name, or HuggingFace ID to enable"
}
},
"required": ["model"]
}
),
Tool(
name="videogen_list_cached_models",
description="List locally cached HuggingFace models with their sizes.",
inputSchema={
"type": "object",
"properties": {}
}
),
Tool(
name="videogen_remove_cached_model",
description="Remove a specific model from the local HuggingFace cache.",
inputSchema={
"type": "object",
"properties": {
"model_id": {
"type": "string",
"description": "HuggingFace model ID to remove from cache (e.g., stabilityai/stable-video-diffusion-img2vid-xt-1.1)"
}
},
"required": ["model_id"]
}
),
Tool(
name="videogen_clear_cache",
description="Clear the entire local HuggingFace cache.",
inputSchema={
"type": "object",
"properties": {}
}
),
Tool( Tool(
name="videogen_list_tts_voices", name="videogen_list_tts_voices",
description="List all available TTS voices for audio generation.", description="List all available TTS voices for audio generation.",
...@@ -1118,6 +1181,31 @@ async def call_tool(name: str, arguments: dict) -> list: ...@@ -1118,6 +1181,31 @@ async def call_tool(name: str, arguments: dict) -> list:
output, code = run_videogen_command(args) output, code = run_videogen_command(args)
return [TextContent(type="text", text=output)] return [TextContent(type="text", text=output)]
elif name == "videogen_disable_model":
args = ["--disable-model", arguments["model"]]
output, code = run_videogen_command(args)
return [TextContent(type="text", text=output)]
elif name == "videogen_enable_model":
args = ["--enable-model", arguments["model"]]
output, code = run_videogen_command(args)
return [TextContent(type="text", text=output)]
elif name == "videogen_list_cached_models":
args = ["--list-cached-models"]
output, code = run_videogen_command(args)
return [TextContent(type="text", text=output)]
elif name == "videogen_remove_cached_model":
args = ["--remove-cached-model", arguments["model_id"]]
output, code = run_videogen_command(args)
return [TextContent(type="text", text=output)]
elif name == "videogen_clear_cache":
args = ["--clear-cache"]
output, code = run_videogen_command(args)
return [TextContent(type="text", text=output)]
elif name == "videogen_list_tts_voices": elif name == "videogen_list_tts_voices":
args = ["--tts-list"] args = ["--tts-list"]
output, code = run_videogen_command(args) output, code = run_videogen_command(args)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment