Commit 82735770 authored by Your Name

Unify cache functions to work with both GGUF and HuggingFace caches

- Updated get_cached_model_path() to check both coderai and HF caches
- Updated download_model() to handle both URLs and HF model IDs automatically
- Made download_huggingface_model() consistent with unified API
- Updated module docstring to reflect unified cache functionality
- All cache functions now work seamlessly with both cache types
parent 52eb402a
"""
Model Cache - Handles model caching, downloading, and management.
Model Cache - Unified model caching, downloading, and management.
This module provides functions for:
- Model cache directory management
- Checking for cached models
- Downloading models from URLs and HuggingFace
- Listing and removing cached models
This module provides unified functions that work with both:
- CoderAI GGUF cache (~/.cache/coderai/models) for downloaded GGUF files
- HuggingFace cache (~/.cache/huggingface/hub) for HF models and datasets
Functions available:
- Cache directory management for all cache types
- Checking for cached models (URLs and HF model IDs)
- Downloading models from URLs or HuggingFace
- Listing and removing cached models across all cache types
"""
import os
......@@ -53,12 +57,45 @@ def get_all_cache_dirs() -> dict:
return caches
def get_cached_model_path(url: str) -> Optional[str]:
"""Check if a model URL is already cached. Returns path if cached, None otherwise."""
def get_cached_model_path(model_path: str) -> Optional[str]:
"""
Check if a model is already cached. Works with both URLs and HuggingFace model IDs.
Args:
model_path: URL or HuggingFace model ID (e.g., 'TheBloke/Llama-2-7B-GGUF')
Returns:
Path to cached model if found, None otherwise
"""
# Check if it's a HuggingFace model ID
if is_huggingface_model_id(model_path):
# For HF models, check the HF cache
caches = get_all_cache_dirs()
hf_dir = caches.get('huggingface')
if hf_dir:
try:
from huggingface_hub import scan_cache_dir
cache_info = scan_cache_dir(hf_dir)
# Look for the model in the cache
for repo in cache_info.repos:
if repo.repo_id == model_path:
# Return the path to the first file (typically the model file)
if repo.revisions:
latest_rev = repo.revisions[0] # Use first revision
if latest_rev.files:
file_path = os.path.join(hf_dir, latest_rev.files[0].file_path)
if os.path.exists(file_path):
print(f"Using cached HF model: {file_path}")
return file_path
except (ImportError, Exception):
pass
else:
# For URLs, check the coderai cache
cache_dir = get_model_cache_dir()
# Create a safe filename from the URL using SHA-256 hash
url_hash = hashlib.sha256(url.encode()).hexdigest()
url_path = url.split('?')[0] # Remove query params
url_hash = hashlib.sha256(model_path.encode()).hexdigest()
url_path = model_path.split('?')[0] # Remove query params
filename = os.path.basename(url_path)
if not filename:
filename = "model.gguf"
......@@ -69,6 +106,7 @@ def get_cached_model_path(url: str) -> Optional[str]:
if os.path.exists(cached_path):
print(f"Using cached model: {cached_path}")
return cached_path
return None
......@@ -78,8 +116,14 @@ def is_huggingface_model_id(path: str) -> bool:
return '/' in path and not path.startswith('http://') and not path.startswith('https://')
def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str = '.gguf') -> Optional[str]:
def download_huggingface_model(model_id: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
"""Download a model from Hugging Face by model ID. Returns cached path or None on failure."""
if cache_dir is None:
caches = get_all_cache_dirs()
cache_dir = caches.get('huggingface')
if not cache_dir:
print("No HuggingFace cache directory found")
return None
try:
from huggingface_hub import hf_hub_download, list_repo_files
......@@ -103,10 +147,38 @@ def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str
return None
def download_model(url: str, cache_dir: str) -> str:
"""Download a model from URL with progress reporting. Returns cached path."""
def download_model(model_path: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
"""
Download a model from URL or HuggingFace. Works with both URLs and model IDs.
Args:
model_path: URL or HuggingFace model ID (e.g., 'TheBloke/Llama-2-7B-GGUF')
cache_dir: Cache directory to use (auto-detected if None)
file_pattern: File pattern to download for HF models (default: '.gguf')
Returns:
Path to downloaded/cached model, or None on failure
"""
# Check if it's a HuggingFace model ID
if is_huggingface_model_id(model_path):
# Use HF cache if no specific cache_dir provided
if cache_dir is None:
caches = get_all_cache_dirs()
cache_dir = caches.get('huggingface')
if not cache_dir:
print("No HuggingFace cache directory found")
return None
return download_huggingface_model(model_path, cache_dir, file_pattern)
else:
# It's a URL - use coderai cache
if cache_dir is None:
cache_dir = get_model_cache_dir()
# Download from URL
import requests
url = model_path
url_path = url.split('?')[0]
filename = os.path.basename(url_path)
......@@ -126,12 +198,12 @@ def download_model(url: str, cache_dir: str) -> str:
# Create safe filename in cache
url_hash = hashlib.sha256(url.encode()).hexdigest()
cached_filename = f"{url_hash}_{filename}"
model_path = os.path.join(cache_dir, cached_filename)
cached_path = os.path.join(cache_dir, cached_filename)
# Check if already cached
if os.path.exists(model_path):
print(f"Using cached model: {model_path}")
return model_path
if os.path.exists(cached_path):
print(f"Using cached model: {cached_path}")
return cached_path
# Download
print(f"Downloading model: {url}")
......@@ -144,7 +216,7 @@ def download_model(url: str, cache_dir: str) -> str:
downloaded = 0
start_time = time.time()
with open(model_path, 'wb') as f:
with open(cached_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192*1024):
if chunk:
f.write(chunk)
......@@ -158,11 +230,11 @@ def download_model(url: str, cache_dir: str) -> str:
print(f"Downloaded: {percent:.1f}% ({dl_mb:.1f}/{total_mb:.1f} MB) at {speed_mb:.1f} MB/s", end='\r')
print() # New line after progress
print(f"Downloaded and cached to: {model_path}")
print(f"Downloaded and cached to: {cached_path}")
if total_mb > 0:
print(f"File size: {total_mb:.1f} MB")
return model_path
return cached_path
def list_cached_models() -> Tuple[List[Tuple[str, str, int]], int]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment