Commit 5e641ba2 authored by Your Name's avatar Your Name

Fix API modules to use centralized cache functions

- Updated codai/api/images.py to use cache module functions directly
- Updated codai/api/tts.py to use centralized load_model() function
- Removed proxy method calls that were causing AttributeError
- All model loading/downloading now goes through codai.models.cache
parent bff24350
...@@ -13,6 +13,7 @@ from PIL import Image ...@@ -13,6 +13,7 @@ from PIL import Image
# Import from codai modules # Import from codai modules
from codai.models.manager import multi_model_manager from codai.models.manager import multi_model_manager
from codai.models.cache import get_cached_model_path
from codai.pydantic.imagerequest import ImageGenerationRequest from codai.pydantic.imagerequest import ImageGenerationRequest
from codai.api.state import get_load_mode from codai.api.state import get_load_mode
...@@ -657,15 +658,15 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -657,15 +658,15 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Also handle HuggingFace model IDs that need to be resolved # Also handle HuggingFace model IDs that need to be resolved
model_path = None model_path = None
if model_to_use.startswith('http://') or model_to_use.startswith('https://'): if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
cached_path = multi_model_manager.get_cached_model_path(model_to_use) cached_path = get_cached_model_path(model_to_use)
if cached_path: if cached_path:
model_path = cached_path model_path = cached_path
print(f"Using cached model: {model_path}") print(f"Using cached model: {model_path}")
else: else:
# Not cached - download it # Not cached - download it
print(f"Downloading model: {model_to_use}") print(f"Downloading model: {model_to_use}")
cache_dir = multi_model_manager.get_model_cache_dir() from codai.models.cache import load_model
model_path = multi_model_manager.download_model(model_to_use, cache_dir) model_path = load_model(model_to_use)
print(f"Downloaded to: {model_path}") print(f"Downloaded to: {model_path}")
elif os.path.isfile(model_to_use): elif os.path.isfile(model_to_use):
model_path = model_to_use model_path = model_to_use
...@@ -687,7 +688,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -687,7 +688,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
for file in files: for file in files:
# Construct potential URL and check cache # Construct potential URL and check cache
potential_url = f"https://huggingface.co/{repo_id}/resolve/main/{file}" potential_url = f"https://huggingface.co/{repo_id}/resolve/main/{file}"
cached = multi_model_manager.get_cached_model_path(potential_url) cached = get_cached_model_path(potential_url)
if cached: if cached:
model_path = cached model_path = cached
print(f"Using cached model from HF repo: {model_path}") print(f"Using cached model from HF repo: {model_path}")
...@@ -706,7 +707,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -706,7 +707,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
for gguf_file in gguf_files: for gguf_file in gguf_files:
# Construct potential URL and check cache # Construct potential URL and check cache
potential_url = f"https://huggingface.co/{repo_id}/resolve/main/{gguf_file}" potential_url = f"https://huggingface.co/{repo_id}/resolve/main/{gguf_file}"
cached = multi_model_manager.get_cached_model_path(potential_url) cached = get_cached_model_path(potential_url)
if cached: if cached:
model_path = cached model_path = cached
print(f"Using cached GGUF model: {model_path}") print(f"Using cached GGUF model: {model_path}")
...@@ -735,7 +736,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -735,7 +736,7 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
print(f"Using local file: {model_path}") print(f"Using local file: {model_path}")
else: else:
# Not a local file, check if it might be a cached model under a different name # Not a local file, check if it might be a cached model under a different name
cached_path = multi_model_manager.get_cached_model_path(model_to_use) cached_path = get_cached_model_path(model_to_use)
if cached_path: if cached_path:
model_path = cached_path model_path = cached_path
print(f"Using cached model: {model_path}") print(f"Using cached model: {model_path}")
...@@ -743,8 +744,8 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -743,8 +744,8 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
# Last resort: try to download it as if it were a URL # Last resort: try to download it as if it were a URL
print(f"Attempting to download '{model_to_use}' as model URL") print(f"Attempting to download '{model_to_use}' as model URL")
try: try:
cache_dir = multi_model_manager.get_model_cache_dir() from codai.models.cache import load_model
model_path = multi_model_manager.download_model(model_to_use, cache_dir) model_path = load_model(model_to_use)
print(f"Downloaded to: {model_path}") print(f"Downloaded to: {model_path}")
except Exception as download_error: except Exception as download_error:
print(f"Download failed: {download_error}") print(f"Download failed: {download_error}")
......
...@@ -18,14 +18,14 @@ global_args = None ...@@ -18,14 +18,14 @@ global_args = None
def get_cached_model_path(url: str) -> str: def get_cached_model_path(url: str) -> str:
"""Get cached model path if available.""" """Get cached model path if available."""
from codai.models.manager import multi_model_manager from codai.models.cache import get_cached_model_path as cache_get_cached_model_path
return multi_model_manager.get_cached_model_path(url) return cache_get_cached_model_path(url)
def get_model_cache_dir() -> str: def get_model_cache_dir() -> str:
"""Get model cache directory.""" """Get model cache directory."""
from codai.models.manager import multi_model_manager from codai.models.cache import get_model_cache_dir
return multi_model_manager.get_model_cache_dir() return get_model_cache_dir()
def set_global_args(args): def set_global_args(args):
...@@ -127,53 +127,11 @@ async def create_speech(request: TTSRequest): ...@@ -127,53 +127,11 @@ async def create_speech(request: TTSRequest):
# Check if model_to_use is a URL - download it (with caching) # Check if model_to_use is a URL - download it (with caching)
model_path = None model_path = None
if model_to_use.startswith('http://') or model_to_use.startswith('https://'): if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
# Check cache first print(f"Loading model from URL: {model_to_use}")
cached_path = get_cached_model_path(model_to_use) from codai.models.cache import load_model
if cached_path: model_path = load_model(model_to_use)
model_path = cached_path if not model_path:
print(f"Using cached model: {model_path}") raise Exception(f"Failed to load model from {model_to_use}")
else:
print(f"Downloading model from URL: {model_to_use}")
try:
import requests
import hashlib
# Get cache directory
cache_dir = get_model_cache_dir()
# Extract filename from URL
url_path = model_to_use.split('?')[0]
filename = os.path.basename(url_path)
if not filename.endswith('.pt') and not filename.endswith('.bin'):
filename = "kokoro-model.pt"
# Create safe filename in cache
url_hash = hashlib.sha256(model_to_use.encode()).hexdigest()
cached_filename = f"{url_hash}_{filename}"
model_path = os.path.join(cache_dir, cached_filename)
# Download to cache
response = requests.get(model_to_use, stream=True)
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
downloaded = 0
with open(model_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192*1024):
if chunk:
f.write(chunk)
downloaded += len(chunk)
if total_size > 0:
percent = (downloaded / total_size) * 100
print(f"Downloaded: {percent:.1f}%", end='\r')
print(f"\nDownloaded and cached to: {model_path}")
except Exception as e:
print(f"Error downloading model: {e}")
raise
else: else:
# Use local path or model name # Use local path or model name
model_path = model_to_use model_path = model_to_use
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment