Commit 82735770 authored by Your Name

Unify cache functions to work with both GGUF and HuggingFace caches

- Updated get_cached_model_path() to check both coderai and HF caches
- Updated download_model() to handle both URLs and HF model IDs automatically
- Made download_huggingface_model() consistent with unified API
- Updated module docstring to reflect unified cache functionality
- All cache functions now work seamlessly with both cache types
parent 52eb402a
""" """
Model Cache - Handles model caching, downloading, and management. Model Cache - Unified model caching, downloading, and management.
This module provides functions for: This module provides unified functions that work with both:
- Model cache directory management - CoderAI GGUF cache (~/.cache/coderai/models) for downloaded GGUF files
- Checking for cached models - HuggingFace cache (~/.cache/huggingface/hub) for HF models and datasets
- Downloading models from URLs and HuggingFace
- Listing and removing cached models Functions available:
- Cache directory management for all cache types
- Checking for cached models (URLs and HF model IDs)
- Downloading models from URLs or HuggingFace
- Listing and removing cached models across all cache types
""" """
import os import os
...@@ -53,22 +57,56 @@ def get_all_cache_dirs() -> dict: ...@@ -53,22 +57,56 @@ def get_all_cache_dirs() -> dict:
return caches return caches
def _candidate_cache_filenames(url: str) -> List[str]:
    """Return the cache filenames a URL may have been stored under, in lookup order."""
    url_hash = hashlib.sha256(url.encode()).hexdigest()
    basename = os.path.basename(url.split('?')[0])  # Remove query params

    # Mirror the extension detection used by download_model so that files it
    # wrote under the "model{ext}" fallback name are also found here.
    lowered = url.lower()
    if 'gguf' in lowered:
        ext = '.gguf'
    elif 'bin' in lowered:
        ext = '.bin'
    elif 'ggml' in lowered:
        ext = '.ggml'
    else:
        ext = '.bin'
    download_name = basename if basename.endswith(ext) else f"model{ext}"

    # {hash}_{filename} format; dict.fromkeys de-duplicates while keeping order.
    names = dict.fromkeys([basename or "model.gguf", download_name])
    return [f"{url_hash}_{name}" for name in names]


def get_cached_model_path(model_path: str) -> Optional[str]:
    """
    Check if a model is already cached. Works with both URLs and HuggingFace model IDs.

    Args:
        model_path: URL or HuggingFace model ID (e.g., 'TheBloke/Llama-2-7B-GGUF')

    Returns:
        Path to cached model if found, None otherwise
    """
    if is_huggingface_model_id(model_path):
        # For HF models, check the HF cache
        hf_dir = get_all_cache_dirs().get('huggingface')
        if not hf_dir:
            return None
        try:
            from huggingface_hub import scan_cache_dir

            cache_info = scan_cache_dir(hf_dir)
            for repo in cache_info.repos:
                if repo.repo_id != model_path or not repo.revisions:
                    continue
                # revisions and files are frozensets: they cannot be indexed,
                # so pick the most recently modified revision explicitly.
                latest_rev = max(repo.revisions, key=lambda rev: rev.last_modified)
                files = sorted(latest_rev.files, key=lambda info: str(info.file_path))
                if not files:
                    continue
                # Prefer an actual model file over README/config files; fall
                # back to the first file (by path) when no GGUF is present.
                gguf_files = [info for info in files if info.file_name.lower().endswith('.gguf')]
                chosen = (gguf_files or files)[0]
                file_path = os.path.join(hf_dir, os.fspath(chosen.file_path))
                if os.path.exists(file_path):
                    print(f"Using cached HF model: {file_path}")
                    return file_path
        except ImportError:
            # huggingface_hub is optional; without it the HF cache cannot be inspected.
            return None
        except Exception as exc:
            # Best-effort lookup: a missing or corrupt cache means "not cached".
            print(f"Could not inspect HuggingFace cache: {exc}")
            return None
        return None

    # For URLs, check the coderai cache
    cache_dir = get_model_cache_dir()
    for cached_filename in _candidate_cache_filenames(model_path):
        cached_path = os.path.join(cache_dir, cached_filename)
        if os.path.exists(cached_path):
            print(f"Using cached model: {cached_path}")
            return cached_path

    return None
...@@ -78,8 +116,14 @@ def is_huggingface_model_id(path: str) -> bool: ...@@ -78,8 +116,14 @@ def is_huggingface_model_id(path: str) -> bool:
return '/' in path and not path.startswith('http://') and not path.startswith('https://') return '/' in path and not path.startswith('http://') and not path.startswith('https://')
def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str = '.gguf') -> Optional[str]: def download_huggingface_model(model_id: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
"""Download a model from Hugging Face by model ID. Returns cached path or None on failure.""" """Download a model from Hugging Face by model ID. Returns cached path or None on failure."""
if cache_dir is None:
caches = get_all_cache_dirs()
cache_dir = caches.get('huggingface')
if not cache_dir:
print("No HuggingFace cache directory found")
return None
try: try:
from huggingface_hub import hf_hub_download, list_repo_files from huggingface_hub import hf_hub_download, list_repo_files
...@@ -103,66 +147,94 @@ def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str ...@@ -103,66 +147,94 @@ def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str
return None return None
def download_model(model_path: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
    """
    Download a model from URL or HuggingFace. Works with both URLs and model IDs.

    Args:
        model_path: URL or HuggingFace model ID (e.g., 'TheBloke/Llama-2-7B-GGUF')
        cache_dir: Cache directory to use (auto-detected if None)
        file_pattern: File pattern to download for HF models (default: '.gguf')

    Returns:
        Path to downloaded/cached model, or None on failure

    Raises:
        requests.exceptions.RequestException: If the URL download fails
            (connection error, timeout, or non-2xx HTTP status).
        OSError: If the cache file cannot be written.
    """
    # Check if it's a HuggingFace model ID
    if is_huggingface_model_id(model_path):
        # Use HF cache if no specific cache_dir provided
        if cache_dir is None:
            cache_dir = get_all_cache_dirs().get('huggingface')
            if not cache_dir:
                print("No HuggingFace cache directory found")
                return None
        return download_huggingface_model(model_path, cache_dir, file_pattern)

    # It's a URL - use coderai cache
    if cache_dir is None:
        cache_dir = get_model_cache_dir()
    # A caller-supplied directory may not exist yet.
    os.makedirs(cache_dir, exist_ok=True)

    import requests

    url = model_path
    url_path = url.split('?')[0]
    filename = os.path.basename(url_path)

    # Determine file extension
    lowered_url = url.lower()
    if 'gguf' in lowered_url:
        ext = '.gguf'
    elif 'bin' in lowered_url:
        ext = '.bin'
    elif 'ggml' in lowered_url:
        ext = '.ggml'
    else:
        ext = '.bin'

    if not filename.endswith(ext):
        filename = f"model{ext}"

    # Create safe filename in cache
    url_hash = hashlib.sha256(url.encode()).hexdigest()
    cached_filename = f"{url_hash}_{filename}"
    cached_path = os.path.join(cache_dir, cached_filename)

    # Check if already cached
    if os.path.exists(cached_path):
        print(f"Using cached model: {cached_path}")
        return cached_path

    # Download to a temporary ".part" file and rename it only on success, so an
    # interrupted download is never mistaken for a complete cached model.
    partial_path = f"{cached_path}.part"
    print(f"Downloading model: {url}")

    total_mb = 0
    try:
        # (connect, read) timeouts: fail instead of hanging on a stalled server.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))
            total_mb = total_size / (1024 * 1024) if total_size > 0 else 0

            downloaded = 0
            start_time = time.time()

            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192 * 1024):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        elapsed = time.time() - start_time
                        speed = downloaded / elapsed if elapsed > 0 else 0
                        speed_mb = speed / (1024 * 1024)
                        dl_mb = downloaded / (1024 * 1024)
                        print(f"Downloaded: {percent:.1f}% ({dl_mb:.1f}/{total_mb:.1f} MB) at {speed_mb:.1f} MB/s", end='\r')

        # Atomic on the same filesystem: the final path appears fully written or not at all.
        os.replace(partial_path, cached_path)
    except BaseException:
        # Also covers KeyboardInterrupt: clean up the partial file, then re-raise.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise

    print()  # New line after progress
    print(f"Downloaded and cached to: {cached_path}")
    if total_mb > 0:
        print(f"File size: {total_mb:.1f} MB")

    return cached_path
def list_cached_models() -> Tuple[List[Tuple[str, str, int]], int]: def list_cached_models() -> Tuple[List[Tuple[str, str, int]], int]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment