feat: persist and load whisper-server audio models

parent 8fd1c5c2
...@@ -20,3 +20,5 @@ debug.log ...@@ -20,3 +20,5 @@ debug.log
# Test files # Test files
test_*.py test_*.py
!tests/
!tests/test_whisper_server_local_models.py
...@@ -1123,6 +1123,17 @@ async def api_model_load(request: Request, username: str = Depends(require_admin ...@@ -1123,6 +1123,17 @@ async def api_model_load(request: Request, username: str = Depends(require_admin
raise RuntimeError("Model failed to load") raise RuntimeError("Model failed to load")
multi_model_manager.models[result["model_key"] or path] = mm multi_model_manager.models[result["model_key"] or path] = mm
multi_model_manager.active_in_vram = result["model_key"] or path multi_model_manager.active_in_vram = result["model_key"] or path
elif model_type == "audio":
wsm = multi_model_manager.whisper_servers.get(path)
if wsm is not None:
started = wsm.start(getattr(wsm, "_model_path", None), gpu_device=getattr(wsm, "_gpu_device", 0))
if not wsm.is_running():
raise RuntimeError("whisper-server failed to start")
model_key = f"audio:{path}"
multi_model_manager.models[model_key] = wsm
multi_model_manager.active_in_vram = model_key
multi_model_manager.models_in_vram.add(model_key)
return {"success": True, "already_loaded": False, "started_model": started}
elif model_type == "image": elif model_type == "image":
from codai.api.images import _load_diffusers_pipeline, _is_gguf_model, _load_sdcpp_model from codai.api.images import _load_diffusers_pipeline, _is_gguf_model, _load_sdcpp_model
from codai.api.state import get_global_args from codai.api.state import get_global_args
...@@ -1194,6 +1205,38 @@ async def api_model_configure(request: Request, username: str = Depends(require_ ...@@ -1194,6 +1205,38 @@ async def api_model_configure(request: Request, username: str = Depends(require_
if config_manager is None: if config_manager is None:
raise HTTPException(status_code=503, detail="Config manager not initialized") raise HTTPException(status_code=503, detail="Config manager not initialized")
data = await request.json() data = await request.json()
if data.get("backend") == "whisper-server":
model_id = (data.get("model_id") or "").strip()
if not model_id:
raise HTTPException(status_code=400, detail="model_id is required")
server_path = (data.get("server_path") or "").strip()
if not server_path:
raise HTTPException(status_code=400, detail="server_path is required")
port = int(data.get("port", 8744))
if port < 1 or port > 65535:
raise HTTPException(status_code=400, detail="port must be between 1 and 65535")
gpu_device = int(data.get("gpu_device", 0))
if gpu_device < 0:
raise HTTPException(status_code=400, detail="gpu_device must be >= 0")
for existing in config_manager.models_data.get("audio_models", []):
if isinstance(existing, dict) and existing.get("id") == model_id:
raise HTTPException(status_code=409, detail=f"whisper-server model '{model_id}' already exists")
entry = {
"id": model_id,
"backend": "whisper-server",
"server_path": server_path,
"model_path": (data.get("model_path") or "").strip() or None,
"port": port,
"gpu_device": gpu_device,
"load_mode": data.get("load_mode", "on-request"),
"model_type": "audio_models",
"model_types": ["audio_models"],
}
if data.get("used_vram_gb") is not None:
entry["used_vram_gb"] = data["used_vram_gb"]
config_manager.models_data.setdefault("audio_models", []).append(entry)
config_manager.save_models()
return {"success": True}
path = data.get("path") or data.get("model_id", "") path = data.get("path") or data.get("model_id", "")
model_type = data.get("model_type", "text_models") model_type = data.get("model_type", "text_models")
# Treat legacy gguf_models as text_models (GGUF is a format, not a type) # Treat legacy gguf_models as text_models (GGUF is a format, not a type)
...@@ -1633,4 +1676,4 @@ async def api_hf_model_info(model_id: str, username: str = Depends(require_admin ...@@ -1633,4 +1676,4 @@ async def api_hf_model_info(model_id: str, username: str = Depends(require_admin
"params_label": params_label, "params_label": params_label,
"gguf_files": gguf_files, "gguf_files": gguf_files,
"file_count": len(all_files), "file_count": len(all_files),
} }
\ No newline at end of file
...@@ -58,10 +58,23 @@ async def create_transcription( ...@@ -58,10 +58,23 @@ async def create_transcription(
""" """
Audio transcription endpoint (OpenAI-compatible). Audio transcription endpoint (OpenAI-compatible).
""" """
# Check if whisper-server is available FIRST # Check if the requested model maps to a configured whisper-server instance first
if multi_model_manager.whisper_server and multi_model_manager.whisper_server.is_running(): whisper_server = multi_model_manager.whisper_servers.get(model)
if whisper_server is not None:
file_content = await file.read() file_content = await file.read()
result = multi_model_manager.whisper_server.transcribe( if not whisper_server.is_running():
whisper_server.start(
getattr(whisper_server, "_model_path", None),
gpu_device=getattr(whisper_server, "_gpu_device", 0),
)
if whisper_server.is_running():
ws_key = f"audio:{model}"
multi_model_manager.models[ws_key] = whisper_server
multi_model_manager.active_in_vram = ws_key
multi_model_manager.models_in_vram.add(ws_key)
if not whisper_server.is_running():
raise HTTPException(status_code=500, detail="whisper-server failed to start")
result = whisper_server.transcribe(
file_content, file_content,
language=language, language=language,
prompt=prompt prompt=prompt
...@@ -200,4 +213,4 @@ async def create_transcription( ...@@ -200,4 +213,4 @@ async def create_transcription(
try: try:
os.unlink(tmp_path) os.unlink(tmp_path)
except Exception: except Exception:
pass pass
\ No newline at end of file
...@@ -368,7 +368,26 @@ def main(): ...@@ -368,7 +368,26 @@ def main():
audio_models = models_config.get("audio_models", []) audio_models = models_config.get("audio_models", [])
for m in audio_models: for m in audio_models:
mid = _model_id(m) mid = _model_id(m)
if mid: if not mid:
continue
if isinstance(m, dict) and m.get("backend") == "whisper-server":
cfg = _model_cfg(m, "audio")
cfg.update({
"backend": "whisper-server",
"server_path": m.get("server_path", ""),
"model_path": m.get("model_path") or None,
"port": int(m.get("port", 8744)),
"gpu_device": int(m.get("gpu_device", 0)),
})
multi_model_manager.register_whisper_server(
model_id=mid,
server_path=m.get("server_path", ""),
model_path=m.get("model_path") or None,
port=int(m.get("port", 8744)),
gpu_device=int(m.get("gpu_device", 0)),
config=cfg,
)
else:
multi_model_manager.set_audio_model(mid, config=_model_cfg(m, "audio")) multi_model_manager.set_audio_model(mid, config=_model_cfg(m, "audio"))
# Image models # Image models
...@@ -630,4 +649,4 @@ def main(): ...@@ -630,4 +649,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
...@@ -413,6 +413,7 @@ class MultiModelManager: ...@@ -413,6 +413,7 @@ class MultiModelManager:
self.models_in_vram: set = set() # all models currently in VRAM self.models_in_vram: set = set() # all models currently in VRAM
self.model_aliases: Dict[str, str] = {} self.model_aliases: Dict[str, str] = {}
self.whisper_server: Optional[WhisperServerManager] = None self.whisper_server: Optional[WhisperServerManager] = None
self.whisper_servers: Dict[str, WhisperServerManager] = {}
self.model_backend_types: Dict[str, str] = {} self.model_backend_types: Dict[str, str] = {}
self.tool_breaker = FuzzyToolBreaker(threshold=3) # Circuit breaker for repetitive tool calls self.tool_breaker = FuzzyToolBreaker(threshold=3) # Circuit breaker for repetitive tool calls
...@@ -438,6 +439,12 @@ class MultiModelManager: ...@@ -438,6 +439,12 @@ class MultiModelManager:
self.whisper_server.stop() self.whisper_server.stop()
except Exception as e: except Exception as e:
print(f"Warning: Error cleaning up whisper server: {e}") print(f"Warning: Error cleaning up whisper server: {e}")
for manager in self.whisper_servers.values():
try:
manager.stop()
except Exception as e:
print(f"Warning: Error cleaning up whisper-server instance: {e}")
self.whisper_servers.clear()
# Clear all model lists # Clear all model lists
self.default_model = None self.default_model = None
...@@ -646,6 +653,10 @@ class MultiModelManager: ...@@ -646,6 +653,10 @@ class MultiModelManager:
self.audio_models.append(model_name) self.audio_models.append(model_name)
self.config[f"audio:{model_name}"] = config or {} self.config[f"audio:{model_name}"] = config or {}
if isinstance(config, dict) and config.get("backend") == "whisper-server":
print(f"Registered whisper-server audio model: {model_name}")
return
# Download/cache the model at startup if it's a URL or HF ID # Download/cache the model at startup if it's a URL or HF ID
resolved_model = self.load_model(model_name) resolved_model = self.load_model(model_name)
if resolved_model != model_name: if resolved_model != model_name:
...@@ -654,6 +665,21 @@ class MultiModelManager: ...@@ -654,6 +665,21 @@ class MultiModelManager:
self.audio_models[idx] = resolved_model self.audio_models[idx] = resolved_model
self.config[f"audio:{resolved_model}"] = self.config.pop(f"audio:{model_name}") self.config[f"audio:{resolved_model}"] = self.config.pop(f"audio:{model_name}")
print(f"Audio model '{model_name}' cached as: {resolved_model}") print(f"Audio model '{model_name}' cached as: {resolved_model}")
def register_whisper_server(self, model_id: str, server_path: str, model_path: Optional[str] = None,
                            port: int = 8744, gpu_device: int = 0, config: Optional[Dict] = None):
    """Register a whisper-server instance as an audio model.

    A new ``WhisperServerManager`` is created for ``server_path``/``port`` and
    stored in ``self.whisper_servers`` under ``model_id`` (replacing any entry
    already registered under that id). The server process is not started here.

    Args:
        model_id: Key under which the instance is registered and listed in
            ``self.audio_models``.
        server_path: Passed to ``WhisperServerManager`` as ``server_path``.
        model_path: Remembered on the manager as ``_model_path``; may be None.
        port: Passed to ``WhisperServerManager`` as ``port``.
        gpu_device: Remembered on the manager as ``_gpu_device``.
        config: Optional per-model settings. It is copied, so the caller's
            dict is never modified.

    Returns:
        The newly created ``WhisperServerManager``.
    """
    wsm = WhisperServerManager(server_path=server_path, port=port)
    # Stash the launch parameters on the manager itself; code that later
    # starts the server reads them back with getattr(wsm, "_model_path", ...).
    wsm._model_path = model_path
    wsm._gpu_device = gpu_device
    self.whisper_servers[model_id] = wsm

    if model_id not in self.audio_models:
        self.audio_models.append(model_id)

    # Work on a copy: setdefault() below must not leak a "load_mode" key
    # back into the dict the caller handed in.
    cfg = dict(config) if config else {}
    cfg.setdefault("load_mode", "on-request")
    self.config[f"audio:{model_id}"] = cfg

    print(f"Registered whisper-server audio model: {model_id} (server: {server_path})")
    return wsm
def set_tts_model(self, model_name: str, config: Dict = None): def set_tts_model(self, model_name: str, config: Dict = None):
"""Set the text-to-speech model and download/cache it if needed.""" """Set the text-to-speech model and download/cache it if needed."""
...@@ -1816,4 +1842,4 @@ class MultiModelManager: ...@@ -1816,4 +1842,4 @@ class MultiModelManager:
# Global singleton instances for convenience # Global singleton instances for convenience
model_manager = ModelManager() model_manager = ModelManager()
multi_model_manager = MultiModelManager() multi_model_manager = MultiModelManager()
\ No newline at end of file
from types import SimpleNamespace
from fastapi.testclient import TestClient
def test_model_configure_persists_whisper_server_audio_model(monkeypatch, tmp_path):
    """POSTing a whisper-server config must append a full entry to ``audio_models``.

    The admin dependency is overridden so the request is authorised, the
    route's ``config_manager`` is swapped for one rooted at ``tmp_path``, and
    the resulting in-memory ``models_data["audio_models"]`` list is compared
    against the expected entry.
    """
    from codai.admin import routes
    from codai.api.app import app
    from codai.config import (
        BackendConfig,
        Config,
        ConfigManager,
        ImageConfig,
        ModelsConfig,
        OffloadConfig,
        ServerConfig,
        VulkanConfig,
        WhisperConfig,
    )

    cfg = ConfigManager(str(tmp_path))
    cfg.models_data = {
        "text_models": [],
        "image_models": [],
        "audio_models": [],
        "vision_models": [],
        "tts_models": [],
        "gguf_models": [],
        "video_models": [],
        "audio_gen_models": [],
        "embedding_models": [],
        "aliases": {},
    }
    cfg.config = Config(
        version="1.0",
        server=ServerConfig(),
        backend=BackendConfig(),
        models=ModelsConfig(),
        offload=OffloadConfig(),
        vulkan=VulkanConfig(),
        image=ImageConfig(),
        whisper=WhisperConfig(),
    )
    monkeypatch.setattr(routes, "config_manager", cfg, raising=False)
    app.dependency_overrides[routes.require_admin] = lambda: "admin"
    # The override lives on the shared app object, so it must be removed even
    # when an assertion fails; otherwise later tests run with auth bypassed.
    try:
        client = TestClient(app)
        response = client.post(
            "/admin/api/model-configure",
            json={
                "model_id": "whisper-vulkan-base",
                "model_type": "audio_models",
                "backend": "whisper-server",
                "server_path": "/usr/local/bin/whisper-server",
                "model_path": "/models/ggml-base.bin",
                "port": 8744,
                "gpu_device": 0,
                "load_mode": "on-request",
                "used_vram_gb": 1.8,
            },
        )
        assert response.status_code == 200
        assert cfg.models_data["audio_models"] == [
            {
                "id": "whisper-vulkan-base",
                "backend": "whisper-server",
                "server_path": "/usr/local/bin/whisper-server",
                "model_path": "/models/ggml-base.bin",
                "port": 8744,
                "gpu_device": 0,
                "load_mode": "on-request",
                "used_vram_gb": 1.8,
                "model_type": "audio_models",
                "model_types": ["audio_models"],
            }
        ]
    finally:
        app.dependency_overrides.clear()
def test_model_configure_rejects_duplicate_whisper_server_model_id(monkeypatch, tmp_path):
    """A second whisper-server config with an existing id must be refused.

    ``audio_models`` is pre-seeded with ``whisper-vulkan-base``; posting the
    same ``model_id`` again (with different settings) has to yield a 400/409
    whose body mentions "duplicate" or "already".
    """
    from codai.admin import routes
    from codai.api.app import app
    from codai.config import (
        BackendConfig,
        Config,
        ConfigManager,
        ImageConfig,
        ModelsConfig,
        OffloadConfig,
        ServerConfig,
        VulkanConfig,
        WhisperConfig,
    )

    cfg = ConfigManager(str(tmp_path))
    cfg.models_data = {
        "text_models": [],
        "image_models": [],
        "audio_models": [
            {
                "id": "whisper-vulkan-base",
                "backend": "whisper-server",
                "server_path": "/usr/local/bin/whisper-server",
                "model_path": "/models/ggml-base.bin",
                "port": 8744,
                "gpu_device": 0,
                "load_mode": "on-request",
            }
        ],
        "vision_models": [],
        "tts_models": [],
        "gguf_models": [],
        "video_models": [],
        "audio_gen_models": [],
        "embedding_models": [],
        "aliases": {},
    }
    cfg.config = Config(
        version="1.0",
        server=ServerConfig(),
        backend=BackendConfig(),
        models=ModelsConfig(),
        offload=OffloadConfig(),
        vulkan=VulkanConfig(),
        image=ImageConfig(),
        whisper=WhisperConfig(),
    )
    monkeypatch.setattr(routes, "config_manager", cfg, raising=False)
    app.dependency_overrides[routes.require_admin] = lambda: "admin"
    # Guarantee the auth override is dropped from the shared app even if one
    # of the assertions below fails.
    try:
        client = TestClient(app)
        response = client.post(
            "/admin/api/model-configure",
            json={
                "model_id": "whisper-vulkan-base",
                "model_type": "audio_models",
                "backend": "whisper-server",
                "server_path": "/usr/local/bin/whisper-server",
                "model_path": "/models/ggml-small.bin",
                "port": 8745,
                "gpu_device": 1,
                "load_mode": "load",
            },
        )
        assert response.status_code in {400, 409}
        assert "duplicate" in response.text.lower() or "already" in response.text.lower()
    finally:
        app.dependency_overrides.clear()
def test_model_load_and_unload_manage_whisper_server_runtime(monkeypatch):
    """Loading then unloading a whisper-server model must start and stop its runtime.

    A ``SimpleNamespace`` stands in for the server manager: ``start`` records
    its arguments, ``cleanup`` flips ``stopped``. The test checks that
    ``/admin/api/model-load`` starts it and registers ``audio:<id>`` in
    ``multi_model_manager.models``, and that ``/admin/api/model-unload``
    stops it and removes that key again.
    """
    from codai.admin import routes
    from codai.api.app import app
    from codai.models.manager import multi_model_manager

    runtime = SimpleNamespace(
        started=[],
        stopped=False,
        is_running=lambda: True,
        # list.append returns None, so "or model_path" makes start() return the path.
        start=lambda model_path=None, gpu_device=0: runtime.started.append((model_path, gpu_device)) or model_path,
        cleanup=lambda: setattr(runtime, "stopped", True),
        _model_path="/models/ggml-base.bin",
        _gpu_device=0,
    )
    monkeypatch.setattr(
        routes,
        "config_manager",
        SimpleNamespace(
            models_data={
                "audio_models": [
                    {
                        "id": "whisper-vulkan-base",
                        "backend": "whisper-server",
                        "server_path": "/usr/local/bin/whisper-server",
                        "model_path": "/models/ggml-base.bin",
                        "port": 8744,
                        "gpu_device": 0,
                        "load_mode": "on-request",
                    }
                ]
            }
        ),
        raising=False,
    )
    monkeypatch.setitem(multi_model_manager.whisper_servers, "whisper-vulkan-base", runtime)
    multi_model_manager.models.clear()
    app.dependency_overrides[routes.require_admin] = lambda: "admin"
    # The app and the manager are process-wide singletons: reset them in a
    # finally block so a failing assertion cannot leak state into other tests.
    try:
        client = TestClient(app)

        load_response = client.post("/admin/api/model-load", json={"path": "whisper-vulkan-base"})
        assert load_response.status_code == 200
        assert runtime.started == [("/models/ggml-base.bin", 0)]
        assert "audio:whisper-vulkan-base" in multi_model_manager.models

        unload_response = client.post("/admin/api/model-unload", json={"path": "whisper-vulkan-base"})
        assert unload_response.status_code == 200
        assert runtime.stopped is True
        assert "audio:whisper-vulkan-base" not in multi_model_manager.models
    finally:
        app.dependency_overrides.clear()
        multi_model_manager.models.clear()
        multi_model_manager.whisper_servers.clear()
def test_transcription_requires_configured_whisper_server_model_id():
    """Transcribing with an unregistered model id must raise a 400/404 HTTPException.

    The manager's whisper-server registry, loaded models and audio model list
    are emptied for the duration of the call, then put back, so the test does
    not disturb whatever other tests left in the shared singleton.
    """
    import asyncio
    import pytest
    from fastapi import HTTPException
    from codai.api import transcriptions
    from codai.models.manager import multi_model_manager

    # Snapshot the singleton's state before wiping it.
    # NOTE(review): assumes whisper_servers/models are plain dict-like mappings — confirm.
    saved_servers = dict(multi_model_manager.whisper_servers)
    saved_models = dict(multi_model_manager.models)
    saved_audio_models = list(multi_model_manager.audio_models)

    multi_model_manager.whisper_servers.clear()
    multi_model_manager.models.clear()
    multi_model_manager.audio_models[:] = []

    class DummyUpload:
        """Minimal stand-in for an uploaded file: a filename and an async read()."""

        filename = "sample.wav"

        async def read(self):
            return b"audio"

    async def run_call():
        return await transcriptions.create_transcription(
            model="whisper-server",
            file=DummyUpload(),
            language=None,
            prompt=None,
            response_format="json",
            temperature=0.0,
        )

    try:
        with pytest.raises(HTTPException) as exc:
            asyncio.run(run_call())
        assert exc.value.status_code in {400, 404}
        assert "not configured" in str(exc.value.detail).lower() or "not available" in str(exc.value.detail).lower()
    finally:
        # Restore in place so existing references to these containers stay valid.
        multi_model_manager.whisper_servers.clear()
        multi_model_manager.whisper_servers.update(saved_servers)
        multi_model_manager.models.clear()
        multi_model_manager.models.update(saved_models)
        multi_model_manager.audio_models[:] = saved_audio_models
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment