Commit 52eb402a authored by Your Name's avatar Your Name

Fix --remove-model to work with HuggingFace repo IDs

- Updated remove_cached_model() to search by repo_id for HuggingFace models
- Moved cache management options (--list-cached-models, --remove-model, --remove-all-models) to run before heavy imports
- Improved cache operations to use centralized functions in codai.models.cache module
- Fixed model removal to work with full repo IDs like 'TheBloke/Llama-2-7B-GGUF'
parent e509279a
...@@ -25,7 +25,7 @@ def main(): ...@@ -25,7 +25,7 @@ def main():
args = parse_args() args = parse_args()
# Handle --list-cached-models early (before heavy imports) # Handle early exit options (before heavy imports)
if args.list_cached_models: if args.list_cached_models:
print("\n=== Listing Cached Models ===") print("\n=== Listing Cached Models ===")
...@@ -72,6 +72,34 @@ def main(): ...@@ -72,6 +72,34 @@ def main():
sys.exit(0) sys.exit(0)
# Handle --remove-all-models early
if args.remove_all_models:
print("\n=== Removing All Cached Models ===")
from codai.models.cache import remove_all_cached_models
total_removed = remove_all_cached_models()
print(f"\n=== Removed {total_removed} item(s) from all caches ===")
sys.exit(0)
# Handle --remove-model early
if args.remove_model:
print(f"\n=== Removing Cached Model Matching: {args.remove_model} ===")
from codai.models.cache import remove_cached_model
removed = remove_cached_model(args.remove_model)
if not removed:
print(f"No cached models found matching: {args.remove_model}")
print(f"\nUse --list-cached-models to see available models.")
sys.exit(0)
total_size = sum(size for _, _, size in removed)
print(f"\nRemoved {len(removed)} cached model file(s), freeing {total_size / (1024*1024):.1f} MB")
sys.exit(0)
# Import globals from codai modules (only after early exits) # Import globals from codai modules (only after early exits)
from codai.api import app from codai.api import app
from codai.api.state import ( from codai.api.state import (
...@@ -173,106 +201,6 @@ def main(): ...@@ -173,106 +201,6 @@ def main():
print(f"Error listing devices: {e}") print(f"Error listing devices: {e}")
sys.exit(0) sys.exit(0)
# Handle --remove-all-models
if args.remove_all_models:
print("\n=== Removing All Cached Models ===")
import shutil
caches = get_all_cache_dirs()
if not caches:
print("No cache directories found.")
sys.exit(0)
total_removed = 0
for cache_name, cache_dir in caches.items():
if not os.path.exists(cache_dir):
continue
files = os.listdir(cache_dir)
if not files:
continue
print(f"\nRemoving from {cache_name} cache ({cache_dir})...")
print(f" Found {len(files)} file(s). Deleting...")
# For diffusers, remove entire directory tree
if cache_name == 'diffusers':
for item in os.listdir(cache_dir):
item_path = os.path.join(cache_dir, item)
if os.path.isdir(item_path):
shutil.rmtree(item_path)
else:
os.remove(item_path)
print(f" Deleted: {item}")
total_removed += 1
else:
for f in files:
filepath = os.path.join(cache_dir, f)
os.remove(filepath)
print(f" Deleted: {f}")
total_removed += 1
print(f"\n=== Removed {total_removed} item(s) from all caches ===")
sys.exit(0)
# Handle --remove-model
if args.remove_model:
print(f"\n=== Removing Cached Model Matching: {args.remove_model} ===")
import shutil
caches = get_all_cache_dirs()
if not caches:
print("No cache directories found.")
sys.exit(0)
all_matching = []
for cache_name, cache_dir in caches.items():
if not os.path.exists(cache_dir):
continue
# For diffusers and huggingface, search recursively
if cache_name in ('diffusers', 'huggingface'):
for root, dirs, files in os.walk(cache_dir):
for f in files:
if args.remove_model.lower() in f.lower():
filepath = os.path.join(root, f)
rel_path = os.path.relpath(filepath, cache_dir)
size = os.path.getsize(filepath)
all_matching.append((cache_name, rel_path, filepath, size))
else:
files = os.listdir(cache_dir)
for f in files:
if args.remove_model.lower() in f.lower():
filepath = os.path.join(cache_dir, f)
if os.path.isfile(filepath):
size = os.path.getsize(filepath)
all_matching.append((cache_name, f, filepath, size))
if not all_matching:
print(f"No cached models found matching: {args.remove_model}")
print(f"\nUse --list-cached-models to see available models.")
sys.exit(0)
print(f"\nFound {len(all_matching)} matching file(s):")
for cache_name, filename, filepath, size in all_matching:
print(f" [{cache_name}] {filename} ({size / (1024*1024):.1f} MB)")
# Confirm before deleting
print(f"\nDeleting {len(all_matching)} file(s)...")
for cache_name, filename, filepath, size in all_matching:
try:
os.remove(filepath)
print(f" Deleted: [{cache_name}] {filename}")
except Exception as e:
print(f" Failed to delete {filename}: {e}")
print(f"\nRemoved {len(all_matching)} cached model file(s).")
sys.exit(0)
# Get model names from args - support multiple models # Get model names from args - support multiple models
model_names = args.model if args.model else [] model_names = args.model if args.model else []
......
...@@ -207,31 +207,58 @@ def list_cached_models() -> Tuple[List[Tuple[str, str, int]], int]: ...@@ -207,31 +207,58 @@ def list_cached_models() -> Tuple[List[Tuple[str, str, int]], int]:
def remove_cached_model(match_term: str) -> List[Tuple[str, str, int]]: def remove_cached_model(match_term: str) -> List[Tuple[str, str, int]]:
""" """
Remove cached models matching the given term. Remove cached models matching the given term.
Args: Args:
match_term: String to match against cached model names match_term: String to match against cached model names
Returns: Returns:
List of (cache_name, filename, size) for removed files List of (cache_name, filename, size) for removed files
""" """
import shutil import shutil
caches = get_all_cache_dirs() caches = get_all_cache_dirs()
all_matching = [] all_matching = []
for cache_name, cache_dir in caches.items(): for cache_name, cache_dir in caches.items():
if not os.path.exists(cache_dir): if not os.path.exists(cache_dir):
continue continue
# For diffusers and huggingface, search recursively # For diffusers and huggingface, search recursively
if cache_name in ('diffusers', 'huggingface'): if cache_name in ('diffusers', 'huggingface'):
for root, dirs, files in os.walk(cache_dir): # First, try to match by repo ID for HuggingFace models
for f in files: if cache_name == 'huggingface':
if match_term.lower() in f.lower(): try:
filepath = os.path.join(root, f) from huggingface_hub import scan_cache_dir
rel_path = os.path.relpath(filepath, cache_dir) cache_info = scan_cache_dir(cache_dir)
size = os.path.getsize(filepath)
all_matching.append((cache_name, rel_path, filepath, size)) # Check if match_term matches any repo_id
for repo in cache_info.repos:
if match_term.lower() in repo.repo_id.lower():
# Found matching repo, add all its files
for revision in repo.revisions:
for file_info in revision.files:
filepath = os.path.join(cache_dir, file_info.file_path)
if os.path.exists(filepath):
size = os.path.getsize(filepath)
rel_path = file_info.file_path
all_matching.append((cache_name, rel_path, filepath, size))
break # Only match one repo per search term
except ImportError:
# huggingface_hub not available, fall back to filename search
pass
except Exception:
# Error scanning HF cache, fall back to filename search
pass
# Fall back to filename search for both diffusers and huggingface
if not all_matching: # Only if we didn't find repo matches
for root, dirs, files in os.walk(cache_dir):
for f in files:
if match_term.lower() in f.lower():
filepath = os.path.join(root, f)
rel_path = os.path.relpath(filepath, cache_dir)
size = os.path.getsize(filepath)
all_matching.append((cache_name, rel_path, filepath, size))
else: else:
files = os.listdir(cache_dir) files = os.listdir(cache_dir)
for f in files: for f in files:
...@@ -240,7 +267,7 @@ def remove_cached_model(match_term: str) -> List[Tuple[str, str, int]]: ...@@ -240,7 +267,7 @@ def remove_cached_model(match_term: str) -> List[Tuple[str, str, int]]:
if os.path.isfile(filepath): if os.path.isfile(filepath):
size = os.path.getsize(filepath) size = os.path.getsize(filepath)
all_matching.append((cache_name, f, filepath, size)) all_matching.append((cache_name, f, filepath, size))
# Remove matching files # Remove matching files
removed = [] removed = []
for cache_name, filename, filepath, size in all_matching: for cache_name, filename, filepath, size in all_matching:
...@@ -250,7 +277,7 @@ def remove_cached_model(match_term: str) -> List[Tuple[str, str, int]]: ...@@ -250,7 +277,7 @@ def remove_cached_model(match_term: str) -> List[Tuple[str, str, int]]:
print(f" Deleted: [{cache_name}] {filename}") print(f" Deleted: [{cache_name}] {filename}")
except Exception as e: except Exception as e:
print(f" Error deleting [{cache_name}] {filename}: {e}") print(f" Error deleting [{cache_name}] {filename}: {e}")
return removed return removed
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment