Commit c93d4a6b authored by Your Name

Centralize all model loading/downloading logic in codai.models.cache

- Added unified load_model() function as main entry point for model loading (usage sketch below)
- Updated WhisperServerManager to use centralized load_model() instead of inline logic
- Removed proxy methods from MultiModelManager - use cache module directly
- All cache functions now work seamlessly with both GGUF and HF model caches
- Improved separation of concerns: cache module handles all caching/downloading
parent 82735770
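With this change callers go through a single entry point. A minimal sketch of the intended usage, assuming the module path codai.models.cache shown in the diff below (the URL is hypothetical; the HuggingFace ID is the one from the docstring):

    from codai.models.cache import load_model

    # Hypothetical URL; returns the cached copy if present, otherwise downloads into the cache
    model_file = load_model("https://example.com/models/ggml-base.en.bin")

    # HuggingFace model ID; file_pattern selects which files to fetch (default '.gguf')
    model_file = load_model("TheBloke/Llama-2-7B-GGUF", file_pattern=".gguf")

    if model_file is None:
        print("Model could not be loaded")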
@@ -147,6 +147,29 @@ def download_huggingface_model(model_id: str, cache_dir: Optional[str] = None, f
     return None
 
 
+def load_model(model_path: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
+    """
+    Load a model - check cache first, download if needed. Works with URLs and HF model IDs.
+
+    This is the main entry point for model loading that handles caching transparently.
+
+    Args:
+        model_path: URL or HuggingFace model ID (e.g., 'TheBloke/Llama-2-7B-GGUF')
+        cache_dir: Specific cache directory (auto-detected if None)
+        file_pattern: File pattern for HF downloads (default: '.gguf')
+
+    Returns:
+        Path to cached/downloaded model, or None on failure
+    """
+    # First check if already cached
+    cached_path = get_cached_model_path(model_path)
+    if cached_path:
+        return cached_path
+
+    # Not cached - need to download
+    return download_model(model_path, cache_dir, file_pattern)
+
+
 def download_model(model_path: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
     """
     Download a model from URL or HuggingFace. Works with both URLs and model IDs.
@@ -157,7 +180,7 @@ def download_model(model_path: str, cache_dir: Optional[str] = None, file_patter
         file_pattern: File pattern to download for HF models (default: '.gguf')
 
     Returns:
-        Path to downloaded/cached model, or None on failure
+        Path to downloaded model, or None on failure
     """
     # Check if it's a HuggingFace model ID
     if is_huggingface_model_id(model_path):
@@ -200,7 +223,7 @@ def download_model(model_path: str, cache_dir: Optional[str] = None, file_patter
     cached_filename = f"{url_hash}_{filename}"
     cached_path = os.path.join(cache_dir, cached_filename)
 
-    # Check if already cached
+    # Check if already cached (double-check)
     if os.path.exists(cached_path):
         print(f"Using cached model: {cached_path}")
         return cached_path
@@ -446,3 +469,19 @@ def remove_all_cached_models() -> int:
             total_removed += 1
 
     return total_removed
+
+
+# Export all public functions
+__all__ = [
+    'get_model_cache_dir',
+    'get_all_cache_dirs',
+    'get_cached_model_path',
+    'is_huggingface_model_id',
+    'download_huggingface_model',
+    'load_model',
+    'download_model',
+    'list_cached_models',
+    'remove_cached_model',
+    'list_cached_models_info',
+    'remove_all_cached_models',
+]
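Cache housekeeping is part of the same public API. A small sketch, assuming get_model_cache_dir() and remove_all_cached_models() keep the zero-argument signatures visible in this file:

    from codai.models.cache import get_model_cache_dir, remove_all_cached_models

    print(f"Cache directory: {get_model_cache_dir()}")
    removed = remove_all_cached_models()  # returns how many cached models were removed
    print(f"Removed {removed} cached model(s)")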
@@ -13,7 +13,7 @@ from codai.models.parser import ModelParserAdapter
 from codai.backends import detect_available_backends
 from codai.backends.cuda import NvidiaBackend
 from codai.backends.vulkan import VulkanBackend
-from codai.models.cache import get_cached_model_path, download_model, get_model_cache_dir
+from codai.models.cache import load_model
 from codai.models.utils import FuzzyToolBreaker
 from codai.pydantic.textrequest import ModelInfo
@@ -258,24 +258,20 @@ class WhisperServerManager:
         with self.lock:
             if self.is_running():
                 self.stop()
 
             if not self.server_path:
                 print("Error: whisper-server path not set")
                 return ""
 
-            # Handle URL models
+            # Handle URL models - use centralized cache loading
             actual_model_path = model_path
             if model_path and (model_path.startswith('http://') or model_path.startswith('https://')):
-                cached_path = get_cached_model_path(model_path)
-                if cached_path:
-                    actual_model_path = cached_path
-                    print(f"Using cached model: {actual_model_path}")
-                else:
-                    cache_dir = get_model_cache_dir()
-                    print(f"Downloading model: {model_path}")
-                    actual_model_path = download_model(model_path, cache_dir)
-                    print(f"Downloaded model to: {actual_model_path}")
+                print(f"Loading model: {model_path}")
+                actual_model_path = load_model(model_path)
+                if not actual_model_path:
+                    print(f"Failed to load model: {model_path}")
+                    return ""
 
             cmd = [self.server_path]
             if actual_model_path:
                 cmd.extend(["-m", actual_model_path])
@@ -283,9 +279,9 @@ class WhisperServerManager:
                 cmd.append("--convert")
             cmd.extend(["--host", "127.0.0.1"])
             cmd.extend(["--port", str(self.port)])
 
             print(f"Starting whisper-server: {' '.join(cmd)}")
 
             try:
                 self.process = subprocess.Popen(
                     cmd,
@@ -294,7 +290,7 @@ class WhisperServerManager:
                     preexec_fn=lambda: signal.signal(signal.SIGTERM, signal.SIG_DFL)
                 )
                 self.current_model = actual_model_path
 
                 if self._wait_for_server(30):
                     print(f"whisper-server started on {self.base_url}")
                     return actual_model_path
@@ -796,23 +792,6 @@ class MultiModelManager:
         # Otherwise return the first loaded model (there should only be one in ondemand mode)
         return list(self.models.keys())[0] if self.models else None
 
-    def get_cached_model_path(self, url: str) -> Optional[str]:
-        """
-        Check if a model URL is already cached.
-
-        This is a proxy method to the cache module function.
-        Returns the cached path if the model is cached, None otherwise.
-        """
-        return get_cached_model_path(url)
-
-    def get_model_cache_dir(self) -> str:
-        """
-        Get the model cache directory.
-
-        This is a proxy method to the cache module function.
-        Returns the path to the model cache directory.
-        """
-        return get_model_cache_dir()
-
     def unload_all_models(self):
         """
...
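Since the proxy methods are removed, callers that previously went through MultiModelManager would presumably import the cache functions directly. A hypothetical before/after sketch ('manager' and 'url' are placeholder names):

    # before: proxy methods on a MultiModelManager instance
    cached = manager.get_cached_model_path(url)
    cache_dir = manager.get_model_cache_dir()

    # after: call codai.models.cache directly
    from codai.models.cache import get_cached_model_path, get_model_cache_dir
    cached = get_cached_model_path(url)
    cache_dir = get_model_cache_dir()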