Commit 7788ce85 authored by Your Name

Fix architecture: Proper separation of Model Manager and Cache responsibilities

- **Model Manager**: Central coordinator for model lifecycle, alias resolution, loading/unloading
- **Cache Module**: Handles downloading, caching, and storage of models
- **API Modules**: Request models from Model Manager (not directly from cache)

Key changes:
- Removed resolve_and_load_model() from cache - moved logic to Model Manager
- Model Manager now downloads/caches models at startup when registered
- API modules use multi_model_manager.load_model() instead of cache functions
- Proper separation: Cache=storage, Manager=lifecycle coordination, APIs=requests

This fixes the incorrect direct API-to-cache coupling and establishes proper architectural boundaries.
parent de4d544f
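To make the new boundaries concrete before reading the diffs, here is a minimal, self-contained sketch of the call flow this commit establishes: APIs ask the manager, the manager coordinates, and only the manager touches the cache. Everything in it (stub names, the in-memory dict cache) is hypothetical illustration, not the codai implementation:

```python
from typing import Optional

# --- cache layer: storage only (download + cache lookup) ---
_CACHE: dict[str, str] = {}  # hypothetical in-memory stand-in for disk cache

def get_cached_model_path(model_id: str) -> Optional[str]:
    """Return the cached file path for model_id, or None if not cached."""
    return _CACHE.get(model_id)

def download_model(model_id: str) -> str:
    """Pretend to download a model and record it in the cache."""
    path = f"/models/{model_id.replace('/', '_')}.gguf"
    _CACHE[model_id] = path
    return path

# --- manager layer: lifecycle coordination (resolution, load/unload) ---
class ModelManager:
    def load_model(self, model_id: str) -> str:
        """Resolve a model to a concrete path, downloading on a cache miss."""
        return get_cached_model_path(model_id) or download_model(model_id)

# --- API layer: always asks the manager, never the cache directly ---
manager = ModelManager()

def create_image_generation(model_id: str) -> str:
    return manager.load_model(model_id)  # no direct API-to-cache coupling

if __name__ == "__main__":
    print(create_image_generation("stabilityai/sd-turbo"))  # triggers download
    print(create_image_generation("stabilityai/sd-turbo"))  # cache hit
```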
@@ -13,7 +13,6 @@ from PIL import Image
 # Import from codai modules
 from codai.models.manager import multi_model_manager
-from codai.models.cache import resolve_and_load_model
 from codai.pydantic.imagerequest import ImageGenerationRequest
 from codai.api.state import get_load_mode
@@ -654,9 +653,11 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
     try:
         from stable_diffusion_cpp import StableDiffusion
-        # Use centralized model resolution
-        model_path = resolve_and_load_model(model_to_use, model_type='image')
+        # Use model manager to resolve and load the model
+        model_path = multi_model_manager.load_model(model_to_use)
+        # For diffusers models, model_path will be the identifier string
+        # For GGUF models, it will be the file path
         if model_path is not None and not os.path.isfile(model_path):
             # This is a diffusers model identifier (not a file path)
             # Skip sd.cpp and let diffusers handle it
...
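The `os.path.isfile(model_path)` check in the image endpoint above is the dispatch point between the two image backends. A hypothetical sketch of that branch, with placeholder return strings standing in for the real sd.cpp and diffusers code paths:

```python
import os

def dispatch_image_backend(model_path: str) -> str:
    """Hypothetical illustration of the branch above: route by path type.

    load_model() returns a real file path for GGUF models and a bare
    identifier string for diffusers models; os.path.isfile() tells them apart.
    """
    if model_path is not None and not os.path.isfile(model_path):
        # Not a file on disk: treat it as a diffusers model identifier
        return f"diffusers pipeline would load '{model_path}'"
    # A concrete model file on disk: hand the path to sd.cpp
    return f"stable_diffusion_cpp would load '{model_path}'"

print(dispatch_image_backend("stabilityai/sdxl-turbo"))  # identifier branch
print(dispatch_image_backend(__file__))                  # file-path branch
```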
@@ -198,6 +198,8 @@ def load_model(model_path: str, cache_dir: Optional[str] = None, file_pattern: s
     return download_model(model_path, cache_dir, file_pattern)

 def download_model(model_path: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
     """
     Download a model from URL or HuggingFace. Works with both URLs and model IDs.
...
@@ -13,7 +13,7 @@ from codai.models.parser import ModelParserAdapter
 from codai.backends import detect_available_backends
 from codai.backends.cuda import NvidiaBackend
 from codai.backends.vulkan import VulkanBackend
-from codai.models.cache import load_model
+from codai.models.cache import get_cached_model_path, download_model, get_model_cache_dir, load_model
 from codai.models.utils import FuzzyToolBreaker
 from codai.pydantic.textrequest import ModelInfo
@@ -483,10 +483,17 @@ class MultiModelManager:
         self.load_mode = mode

     def set_default_model(self, model_name: str, config: Dict = None, backend_type: str = "auto"):
-        """Set the default/main text model."""
+        """Set the default/main text model and download/cache it if needed."""
         self.default_model = model_name
         self.config[model_name] = config or {}
         self.model_backend_types[model_name] = backend_type
+
+        # Download/cache the model at startup if it's a URL or HF ID
+        resolved_model = self.load_model(model_name)
+        if resolved_model != model_name:
+            # Model was downloaded/cached, update the stored name
+            self.default_model = resolved_model
+            print(f"Model '{model_name}' cached as: {resolved_model}")

     def _load_default_model(self):
         """Load the default model on demand."""
@@ -586,27 +593,58 @@ class MultiModelManager:
         return None

     def set_audio_model(self, model_name: str, config: Dict = None):
-        """Add an audio transcription model."""
+        """Add an audio transcription model and download/cache it if needed."""
         if model_name not in self.audio_models:
             self.audio_models.append(model_name)
         self.config[f"audio:{model_name}"] = config or {}
+
+        # Download/cache the model at startup if it's a URL or HF ID
+        resolved_model = self.load_model(model_name)
+        if resolved_model != model_name:
+            # Model was downloaded/cached, update the stored name
+            idx = self.audio_models.index(model_name)
+            self.audio_models[idx] = resolved_model
+            self.config[f"audio:{resolved_model}"] = self.config.pop(f"audio:{model_name}")
+            print(f"Audio model '{model_name}' cached as: {resolved_model}")

     def set_tts_model(self, model_name: str, config: Dict = None):
-        """Set the text-to-speech model."""
+        """Set the text-to-speech model and download/cache it if needed."""
         self.tts_model = model_name
         self.config[f"tts:{model_name}"] = config or {}
+
+        # Download/cache the model at startup if it's a URL or HF ID
+        resolved_model = self.load_model(model_name)
+        if resolved_model != model_name:
+            # Model was downloaded/cached, update the stored name
+            self.tts_model = resolved_model
+            self.config[f"tts:{resolved_model}"] = self.config.pop(f"tts:{model_name}")
+            print(f"TTS model '{model_name}' cached as: {resolved_model}")

     def set_image_model(self, model_name: str, config: Dict = None):
-        """Add an image generation model."""
+        """Add an image generation model and download/cache it if needed."""
         if model_name not in self.image_models:
             self.image_models.append(model_name)
         self.config[f"image:{model_name}"] = config or {}
+
+        # For image models, we don't download at startup since they may be large
+        # and handled by different backends (diffusers vs sd.cpp)
+        # The download will happen when the model is first requested
+        print(f"Registered image model: {model_name}")

     def set_vision_model(self, model_name: str, config: Dict = None):
-        """Add a vision model."""
+        """Add a vision model and download/cache it if needed."""
         if model_name not in self.vision_models:
             self.vision_models.append(model_name)
         self.config[f"vision:{model_name}"] = config or {}
+
+        # Download/cache the model at startup if it's a URL or HF ID
+        resolved_model = self.load_model(model_name)
+        if resolved_model != model_name:
+            # Model was downloaded/cached, update the stored name
+            idx = self.vision_models.index(model_name)
+            self.vision_models[idx] = resolved_model
+            self.config[f"vision:{resolved_model}"] = self.config.pop(f"vision:{model_name}")
+            print(f"Vision model '{model_name}' cached as: {resolved_model}")

     def set_model_alias(self, alias: str, model_name: str):
         """Register an alias for a model."""
@@ -792,6 +830,60 @@ class MultiModelManager:
         # Otherwise return the first loaded model (there should only be one in ondemand mode)
         return list(self.models.keys())[0] if self.models else None

+    def get_cached_model_path(self, model_path: str) -> Optional[str]:
+        """
+        Check if a model is already cached.
+        This is a proxy method to the cache module function.
+        Returns the cached path if the model is cached, None otherwise.
+        """
+        return get_cached_model_path(model_path)
+
+    def get_model_cache_dir(self) -> str:
+        """
+        Get the model cache directory.
+        This is a proxy method to the cache module function.
+        Returns the path to the model cache directory.
+        """
+        return get_model_cache_dir()
+
+    def load_model(self, model_path: str, cache_dir: Optional[str] = None, file_pattern: str = '.gguf') -> Optional[str]:
+        """
+        Load a model with intelligent caching and resolution.
+        Handles local files, URLs, and HuggingFace model IDs.
+        Returns the resolved model path or identifier.
+        """
+        from codai.models.cache import is_huggingface_model_id
+
+        # 1. Check if it's a local file
+        if os.path.isfile(model_path):
+            print(f"Using local model: {model_path}")
+            return model_path
+
+        # 2. Check if it's a URL
+        if model_path.startswith('http://') or model_path.startswith('https://'):
+            print(f"Loading model from URL: {model_path}")
+            return load_model(model_path, cache_dir, file_pattern)
+
+        # 3. Check if it's a HuggingFace model ID
+        if is_huggingface_model_id(model_path):
+            # For diffusers models (most image models), return the identifier
+            # The actual loading will be handled by the specific backend (diffusers, sd.cpp, etc.)
+            print(f"Using HuggingFace model: {model_path}")
+            return model_path
+
+        # 4. Try as a generic model identifier with caching
+        print(f"Resolving model: {model_path}")
+        cached_path = get_cached_model_path(model_path)
+        if cached_path:
+            print(f"Using cached model: {cached_path}")
+            return cached_path
+
+        # 5. Try to download it
+        return load_model(model_path, cache_dir, file_pattern)

     def unload_all_models(self):
         """
...
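For reference, the resolution ladder in the new `load_model()` behaves roughly as follows. The model names are placeholders, and this assumes a codai installation exposing these modules:

```python
# Illustrative calls against the new manager API.
from codai.models.manager import multi_model_manager

# 1. An existing local file path is returned unchanged.
multi_model_manager.load_model("/models/llama.gguf")

# 2. A URL is downloaded via the cache module; the cached path is returned.
multi_model_manager.load_model("https://example.com/model.gguf")

# 3. A HuggingFace ID is returned as-is for the backend to load lazily.
multi_model_manager.load_model("stabilityai/sdxl-turbo")

# 4./5. Anything else: cache lookup first, then a download attempt.
multi_model_manager.load_model("some-model-id")
```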