Commit bff24350 authored by Your Name

Implement intelligent model loading for local files, URLs, and HF IDs

- Updated load_model() to handle three input types:
  1. Local files: Use directly without caching
  2. URLs: Download to cache if not cached, then use
  3. HF model IDs: Download via HF API if not cached, then use
- Updated get_cached_model_path() to validate local files
- Enhanced module documentation to reflect new capabilities
- All model types (text, image, audio, etc.) can now use any input type
parent 3e3067a9
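
The three-way source detection described in the commit message can be sketched roughly as follows. This is a minimal illustration, not code from the commit: `classify_model_source` and the `_HF_ID_RE` pattern are hypothetical names, and the module's real `is_huggingface_model_id()` is not shown in this diff, so its rules may differ. Note also that the commit's `load_model()` treats anything that is neither an existing file nor a URL as an HF model ID, while this sketch adds an explicit "unknown" case for illustration.

```python
import os
import re

# Hypothetical pattern for "owner/name" style IDs. The real
# is_huggingface_model_id() is not shown in the diff and may differ.
_HF_ID_RE = re.compile(r"^[\w.-]+/[\w.-]+$")


def classify_model_source(model_path: str) -> str:
    """Return 'local', 'url', 'hf_id', or 'unknown' for a model source string."""
    # 1. An existing file on disk wins, mirroring the order used in the commit.
    if os.path.isfile(model_path):
        return "local"
    # 2. http/https prefixes are treated as download URLs.
    if model_path.startswith(("http://", "https://")):
        return "url"
    # 3. Anything shaped like "owner/name" is assumed to be an HF model ID.
    if _HF_ID_RE.match(model_path):
        return "hf_id"
    # Assumption of this sketch only: report unrecognized input explicitly.
    return "unknown"
```

One consequence of checking `os.path.isfile()` first is that a mistyped local path such as `models/missing.gguf` no longer counts as local and falls through to the later checks, where it may be interpreted as an HF model ID.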
"""
Model Cache - Unified model caching, downloading, and management.
Model Cache - Unified model loading, caching, downloading, and management.
This module provides unified functions that work with both:
- CoderAI GGUF cache (~/.cache/coderai/models) for downloaded GGUF files
- HuggingFace cache (~/.cache/huggingface/hub) for HF models and datasets
This module provides unified functions that handle three types of model sources:
1. Local files: Used directly without caching
2. URLs: Downloaded to CoderAI cache (~/.cache/coderai/models) if not cached
3. HF model IDs: Downloaded to HF cache (~/.cache/huggingface/hub) if not cached
All functions work seamlessly across:
- CoderAI GGUF cache for downloaded files
- HuggingFace cache for HF models and datasets
- Direct local file access
Functions available:
- Intelligent model loading with automatic source detection
- Cache directory management for all cache types
- Checking for cached models (URLs and HF model IDs)
- Downloading models from URLs or HuggingFace
- Checking for cached models (URLs, HF IDs) and local file validation
- Downloading models from URLs or HuggingFace with caching
- Listing and removing cached models across all cache types
"""
......@@ -59,15 +66,19 @@ def get_all_cache_dirs() -> dict:
def get_cached_model_path(model_path: str) -> Optional[str]:
    """
    Check if a model is already cached. Works with both URLs and HuggingFace model IDs.
    Check if a model is already cached. Works with URLs, HF model IDs, and local files.
    Args:
        model_path: URL or HuggingFace model ID (e.g., 'TheBloke/Llama-2-7B-GGUF')
        model_path: Local path, URL, or HuggingFace model ID
    Returns:
        Path to cached model if found, None otherwise
        Path to cached model if found, local path if local file exists, None otherwise
    """
    # Check if it's a HuggingFace model ID
    # 1. Check if it's a local file - return directly if it exists
    if os.path.isfile(model_path):
        return model_path
    # 2. Check if it's a HuggingFace model ID
    if is_huggingface_model_id(model_path):
        # For HF models, check the HF cache
        caches = get_all_cache_dirs()
......@@ -91,7 +102,7 @@ def get_cached_model_path(model_path: str) -> Optional[str]:
        except (ImportError, Exception):
            pass
    else:
        # For URLs, check the coderai cache
        # 3. For URLs, check the coderai cache
        cache_dir = get_model_cache_dir()
        # Create a safe filename from the URL using SHA-256 hash
        url_hash = hashlib.sha256(model_path.encode()).hexdigest()
......@@ -149,24 +160,41 @@ def download_huggingface_model(model_id: str, cache_dir: Optional[str] = None, f
def load_model(model_path: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
    """
    Load a model - check cache first, download if needed. Works with URLs and HF model IDs.
    Load a model from various sources with intelligent caching.
    This is the main entry point for model loading that handles caching transparently.
    Handles three types of inputs:
    1. Local file path: Use directly (no caching)
    2. URL: Download if not cached, then use cached version
    3. HF model ID: Download via HF API if not cached, then use
    Args:
        model_path: URL or HuggingFace model ID (e.g., 'TheBloke/Llama-2-7B-GGUF')
        model_path: Local path, URL, or HuggingFace model ID
        cache_dir: Specific cache directory (auto-detected if None)
        file_pattern: File pattern for HF downloads (default: '.gguf')
    Returns:
        Path to cached/downloaded model, or None on failure
        Path to model file (local or cached), or None on failure
    """
    # First check if already cached
    # 1. Check if it's a local file path
    if os.path.isfile(model_path):
        # Local file - use directly without caching
        return model_path
    # 2. Check if it's a URL (starts with http/https)
    if model_path.startswith('http://') or model_path.startswith('https://'):
        # URL - check cache first, download if needed
        cached_path = get_cached_model_path(model_path)
        if cached_path:
            return cached_path
        # Not cached - download
        return download_model(model_path, cache_dir, file_pattern)
    # Not cached - need to download
    # 3. Assume it's a HuggingFace model ID
    # Check cache first, download if needed
    cached_path = get_cached_model_path(model_path)
    if cached_path:
        return cached_path
    # Not cached - download from HF
    return download_model(model_path, cache_dir, file_pattern)
......
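
The URL branch of `get_cached_model_path()` derives a cache filename from a SHA-256 hash of the URL, but the diff elides how that hash becomes a path. The sketch below shows one plausible way to do the lookup. The helper names `url_cache_path` and `find_cached_url`, and the "hash plus original extension" filename layout, are assumptions for illustration and may not match the actual module.

```python
import hashlib
import os
from typing import Optional
from urllib.parse import urlparse


def url_cache_path(url: str, cache_dir: str) -> str:
    """Map a URL to a deterministic file path inside cache_dir."""
    # Same hashing call as in the diff: the hex digest is filesystem-safe.
    url_hash = hashlib.sha256(url.encode()).hexdigest()
    # Assumed layout: keep the URL's file extension (e.g. '.gguf'),
    # ignoring any query string.
    ext = os.path.splitext(urlparse(url).path)[1]
    return os.path.join(cache_dir, url_hash + ext)


def find_cached_url(url: str, cache_dir: str) -> Optional[str]:
    """Return the cached file path for url, or None if it is not cached."""
    path = url_cache_path(url, cache_dir)
    return path if os.path.isfile(path) else None
```

Hashing the full URL means two URLs that differ only in their query string map to different cache entries, which avoids collisions at the cost of occasionally downloading the same file twice.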