Commit 405be1cb authored by Your Name

Refactor: Move all API endpoints to codai.api module and extract CLI to codai.cli

- Move parse_args to codai.cli
- Move main() to codai.main
- Simplify coderai to be a thin wrapper importing from codai package
- Create codai.api module with organized endpoints:
  - codai/api/app.py: FastAPI app, /v1/models, /v1/files, get_load_mode
  - codai/api/text.py: /v1/chat/completions, legacy /v1/completions
  - codai/api/images.py: /v1/images/generations
  - codai/api/transcriptions.py: /v1/audio/transcriptions
  - codai/api/tts.py: /v1/audio/speech
- coderai is now only a backward-compatible entry point
parent 1c79d9ab
# codai.api - FastAPI application module
from .app import app
__all__ = ['app']
"""
FastAPI application module for codai API.
Contains the FastAPI app initialization, lifespan, and core endpoints.
"""
from contextlib import asynccontextmanager
from typing import List
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, JSONResponse
# Import from codai modules
from codai.pydantic.textrequest import ModelList
from codai.models.manager import model_manager, multi_model_manager
# Global references to be set by coderai
# These will be imported/assigned after the app is created
# Debug flag; rebound by set_global_debug(). Defaults to off.
global_debug = False
# Base directory used by the /v1/files/{filename} endpoint; rebound by
# set_global_file_path(). While it is None that endpoint always answers 404.
global_file_path = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup/shutdown.

    Nothing is done at startup. On shutdown the multi-model manager and the
    single-model manager are cleaned up, and the whisper-server held by the
    multi-model manager (if any) is stopped.

    Every shutdown step is guarded by ``finally`` so that a failure in one
    step does not skip the remaining ones (in particular, the whisper-server
    is still stopped). The first exception raised still propagates.
    """
    try:
        # Startup: nothing to do; hand control to the application.
        yield
    finally:
        # Shutdown: same order as before, but each step always runs.
        try:
            multi_model_manager.cleanup()
        finally:
            try:
                model_manager.cleanup()
            finally:
                # Stop whisper-server if the manager holds one.
                whisper_server = multi_model_manager.whisper_server
                if whisper_server:
                    whisper_server.stop()
# Create the FastAPI app; `lifespan` (defined above) supplies the shutdown cleanup.
app = FastAPI(
title="OpenAI-Compatible API",
description="OpenAI-compatible API supporting NVIDIA (CUDA) and Vulkan backends",
version="2.0.0",
lifespan=lifespan,
)
# Import routers from submodules.
# NOTE(review): these imports sit below the `app` assignment, presumably so the
# submodules can import from this module without a circular import - confirm.
from codai.api.transcriptions import router as transcriptions_router
from codai.api.images import router as images_router
from codai.api.tts import router as tts_router
# Import the request-logging coroutine and register it as an HTTP middleware.
from codai.api.log import log_requests
app.middleware("http")(log_requests)
# Attach the audio transcription, image and TTS routers to the app.
app.include_router(transcriptions_router)
app.include_router(images_router)
app.include_router(tts_router)
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return every model known to the multi-model manager as a ModelList."""
    return ModelList(data=multi_model_manager.list_models())
@app.get("/v1/files/{filename}")
async def get_file(filename: str):
    """Serve an uploaded/generated file from the configured file directory.

    Args:
        filename: Name of the file, taken from the request path (untrusted).

    Returns:
        A FileResponse for the requested file.

    Raises:
        HTTPException: 404 when no file directory is configured, when the
            name resolves outside that directory, or when it does not refer
            to a regular file.
    """
    import os

    if global_file_path:
        base_dir = os.path.abspath(global_file_path)
        # Normalise the joined path so "..", absolute names or (on Windows)
        # backslash-separated components cannot escape the base directory.
        candidate = os.path.abspath(os.path.join(base_dir, filename))
        try:
            inside_base = os.path.commonpath([base_dir, candidate]) == base_dir
        except ValueError:
            # commonpath raises when the paths cannot be compared
            # (e.g. different drives on Windows): treat as "outside".
            inside_base = False
        # isfile (not exists): a directory cannot be sent as a file response.
        if inside_base and os.path.isfile(candidate):
            return FileResponse(candidate)
    raise HTTPException(status_code=404, detail="File not found")
def set_global_debug(debug: bool):
    """Rebind the module-level ``global_debug`` flag to *debug*."""
    global global_debug
    global_debug = debug
def set_global_file_path(path: str):
    """Rebind the module-level ``global_file_path`` to *path*."""
    global global_file_path
    global_file_path = path
# Load mode - will be set by coderai
# Holder for the current load mode. It is a dict so the value can be updated
# in place; the default mode is "ondemand".
_load_mode = {"mode": "ondemand"}


def get_load_mode():
    """Return the current load mode ("ondemand" when none has been set)."""
    return _load_mode.get("mode", "ondemand")


def set_load_mode(mode: str):
    """Set the load mode from coderai.

    The value is stored under the "mode" key of ``_load_mode``. The holder is
    updated in place rather than rebound: rebinding it to the bare string made
    the next ``get_load_mode()`` call fail with AttributeError on ``.get``.
    """
    _load_mode["mode"] = mode
This diff is collapsed.
"""
Request logging middleware for the codai API.
"""
from fastapi import Request
async def log_requests(request: Request, call_next):
    """HTTP middleware that logs chat/completion requests in debug mode.

    When the app-level debug flag is on and the request targets
    ``/v1/chat/completions`` or ``/v1/completions``, the first 500 characters
    of the request body and the response status code are printed. All other
    requests are passed straight through.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The response produced by ``call_next``.
    """
    # Imported at call time so the current value of the flag is seen.
    from codai.api.app import global_debug

    traced = bool(global_debug) and request.url.path in (
        "/v1/chat/completions",
        "/v1/completions",
    )

    if traced:
        # The body is only read when it will actually be logged; buffering and
        # decoding it for every request served no purpose outside debug mode.
        try:
            raw_body = await request.body()
            # errors="replace": a non-UTF-8 body must not break logging.
            body_str = raw_body.decode('utf-8', errors='replace')
            ellipsis = "..." if len(body_str) > 500 else ""
            print(f"DEBUG: Request body: {body_str[:500]}{ellipsis}")
        except Exception as e:
            # Logging is best-effort: never fail the request because of it.
            print(f"Error reading request body: {e}")

    response = await call_next(request)

    if traced:
        print(f"DEBUG: Response status: {response.status_code}")
    return response
This diff is collapsed.
"""
Audio transcription endpoint for the codai API.
"""
import io
import os
import tempfile
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from typing import Optional
# Import from codai modules
from codai.models.manager import multi_model_manager
# Global reference to be set by coderai
# Parsed command-line arguments; stays None until set_global_args() is called.
global_args = None


def set_global_args(args):
    """Store *args* in the module-level ``global_args``."""
    global global_args
    global_args = args
# =============================================================================
# Router and Endpoints
# =============================================================================
router = APIRouter()
def _transcribe_faster_whisper(model_name, audio_path, language, prompt, temperature):
    """Transcribe *audio_path* with faster-whisper; raises ImportError if it is missing."""
    from faster_whisper import WhisperModel

    model_key = f"audio:{model_name}"
    whisper_model = multi_model_manager.get_model(model_key)
    if whisper_model is None:
        print(f"Loading faster-whisper model: {model_name}")
        whisper_model = WhisperModel(
            model_name,
            device="cpu",  # Always use CPU - faster-whisper CUDA doesn't work with AMD
            compute_type="int8",  # always int8 on CPU
        )
        # Cache the loaded model for later requests.
        multi_model_manager.add_model(model_key, whisper_model)
        print(f"Loaded faster-whisper model: {model_name}")

    segments, _info = whisper_model.transcribe(
        audio_path,
        language=language,
        initial_prompt=prompt,
        temperature=temperature,
    )
    return "".join(segment.text for segment in segments)


def _transcribe_whispercpp(model_name, audio_path):
    """Transcribe *audio_path* with whispercpp; raises ImportError if it is missing."""
    import whispercpp

    model_key = f"audio:{model_name}"
    whisper_model = multi_model_manager.get_model(model_key)
    if whisper_model is None:
        print(f"Loading whispercpp model: {model_name}")
        # Built-in model names and model file paths were loaded through the
        # same call in both branches, so a single call covers both.
        whisper_model = whispercpp.Whisper.from_pretrained(model_name)
        multi_model_manager.add_model(model_key, whisper_model)
        print(f"Loaded whispercpp model: {model_name}")

    result = whisper_model.transcribe(audio_path)

    # The result is handled as an object with .text, a dict, or a list of
    # segments (objects or dicts); anything else yields an empty string.
    if hasattr(result, 'text'):
        return result.text
    if isinstance(result, dict):
        return result.get('text', '')
    if isinstance(result, list):
        parts = []
        for segment in result:
            if hasattr(segment, 'text'):
                parts.append(segment.text)
            elif isinstance(segment, dict):
                parts.append(segment.get('text', ''))
        return "".join(parts)
    return ""


@router.post("/v1/audio/transcriptions")
async def create_transcription(
    model: str = Form(...),
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    prompt: Optional[str] = Form(None),
    response_format: Optional[str] = Form("json"),
    temperature: Optional[float] = Form(0.0),
):
    """
    Audio transcription endpoint (OpenAI-compatible).

    Backends are tried in this order: a running whisper-server, then
    faster-whisper, then whispercpp.

    Args:
        model: Requested model; names starting with "whisper:" or "audio:"
            are replaced by the configured audio model.
        file: Uploaded audio file.
        language: Optional language hint.
        prompt: Optional initial prompt.
        response_format: Accepted for API compatibility; not used here.
        temperature: Sampling temperature (used by faster-whisper only).

    Returns:
        A dict of the form ``{"text": <transcription>}``.

    Raises:
        HTTPException: 400 if no audio backend is configured, 501 if neither
            faster-whisper nor whispercpp is installed, 500 on any other
            transcription failure.
    """
    # Check if whisper-server is available FIRST
    whisper_server = multi_model_manager.whisper_server
    if whisper_server and whisper_server.is_running():
        file_content = await file.read()
        result = whisper_server.transcribe(
            file_content,
            language=language,
            prompt=prompt
        )
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])
        return {"text": result.get("text", "")}

    audio_model = multi_model_manager.audio_model
    if not audio_model:
        raise HTTPException(
            status_code=400,
            detail="Audio transcription not configured. Use --audio-model or --whisper-server."
        )

    # Determine model to use
    model_to_use = model
    if model_to_use.startswith(("whisper:", "audio:")):
        model_to_use = audio_model

    # Read the uploaded file
    file_content = await file.read()

    # Save to temp file (needed for some backends). The upload may carry no
    # filename, in which case no suffix is used.
    suffix = os.path.splitext(file.filename or "")[1]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = tmp.name
    try:
        # Written inside the try so the file is removed even if the write fails.
        with tmp:
            tmp.write(file_content)

        # Try faster-whisper first
        try:
            text = _transcribe_faster_whisper(
                model_to_use, tmp_path, language, prompt, temperature
            )
            return {"text": text.strip()}
        except ImportError:
            pass

        # Try whispercpp as fallback
        try:
            text = _transcribe_whispercpp(model_to_use, tmp_path)
            return {"text": text.strip()}
        except ImportError:
            raise HTTPException(
                status_code=501,
                detail="Audio transcription not available. Install faster-whisper or whispercpp."
            )
    except HTTPException:
        # Let deliberate HTTP errors (the 501 above) through unchanged instead
        # of rewriting them into a generic 500 below.
        raise
    except Exception as e:
        print(f"Transcription error: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}")
    finally:
        # Clean up temp file (best effort)
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
"""
Text-to-speech endpoints for the codai API.
"""
import base64
import os
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict
# Import from codai modules
from codai.models.manager import multi_model_manager
# Global reference to be set by coderai
global_args = None
def get_cached_model_path(url: str) -> str:
    """Ask the multi-model manager for the cached path of *url*.

    The manager's result is returned unchanged; callers in this module treat
    a falsy result as "not cached".
    """
    from codai.models.manager import multi_model_manager as manager

    cached_path = manager.get_cached_model_path(url)
    return cached_path
def get_model_cache_dir() -> str:
    """Return the model cache directory reported by the multi-model manager."""
    from codai.models.manager import multi_model_manager as manager

    cache_dir = manager.get_model_cache_dir()
    return cache_dir
def set_global_args(args):
    """Store *args* in the module-level ``global_args``."""
    global global_args
    global_args = args
# =============================================================================
# Router and Endpoints
# =============================================================================
router = APIRouter()
class TTSRequest(BaseModel):
    # Request body of POST /v1/audio/speech. Unknown extra fields are accepted.
    model_config = ConfigDict(extra="allow")

    # Requested model; create_speech swaps names starting with "tts:" for the
    # configured TTS model.
    model: str
    # Text to synthesise.
    input: str
    # Voice passed to the TTS model.
    voice: str = "af_sarah"
    # Requested audio format.
    response_format: str = "mp3"
    # Speed passed to the TTS model.
    speed: float = 1.0
class TTSResponse(BaseModel):
    # Response shape for text-to-speech. Unknown extra fields are accepted.
    model_config = ConfigDict(extra="allow")

    audio: str  # base64 encoded audio
def _download_tts_model(url: str) -> str:
    """Download *url* into the model cache directory and return the local path."""
    print(f"Downloading model from URL: {url}")
    try:
        import hashlib
        import requests

        # Get cache directory
        cache_dir = get_model_cache_dir()
        os.makedirs(cache_dir, exist_ok=True)

        # Extract filename from URL (query string dropped)
        filename = os.path.basename(url.split('?')[0])
        if not filename.endswith(('.pt', '.bin')):
            filename = "kokoro-model.pt"

        # Create safe filename in cache
        url_hash = hashlib.sha256(url.encode()).hexdigest()
        model_path = os.path.join(cache_dir, f"{url_hash}_{filename}")
        # Download next to the final path and rename on success, so an
        # interrupted download never leaves a truncated file at model_path.
        partial_path = model_path + ".part"

        try:
            # (connect, read) timeouts so a stalled server cannot hang the
            # request forever; the context manager releases the connection.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                downloaded = 0
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192 * 1024):
                        if chunk:
                            f.write(chunk)
                            downloaded += len(chunk)
                            if total_size > 0:
                                percent = (downloaded / total_size) * 100
                                print(f"Downloaded: {percent:.1f}%", end='\r')
            os.replace(partial_path, model_path)
        except BaseException:
            # Remove the incomplete file (best effort) before propagating.
            try:
                os.unlink(partial_path)
            except OSError:
                pass
            raise

        print(f"\nDownloaded and cached to: {model_path}")
        return model_path
    except Exception as e:
        print(f"Error downloading model: {e}")
        raise


def _resolve_tts_model_path(model_to_use: str) -> str:
    """Return a local path (or model name) for *model_to_use*, downloading URLs once."""
    if not model_to_use.startswith(('http://', 'https://')):
        # Use local path or model name
        return model_to_use

    # Check cache first
    cached_path = get_cached_model_path(model_to_use)
    if cached_path:
        print(f"Using cached model: {cached_path}")
        return cached_path
    return _download_tts_model(model_to_use)


@router.post("/v1/audio/speech")
async def create_speech(request: TTSRequest):
    """
    Text-to-speech endpoint (OpenAI-compatible).

    Supports:
    - Kokoro TTS models (when --tts-model is specified)

    Args:
        request: The parsed TTS request.

    Returns:
        A dict ``{"audio": <base64-encoded audio>}``.

    Raises:
        HTTPException: 400 if no TTS model is configured, 501 if a required
            package cannot be imported, 500 on any other failure.
    """
    tts_model = multi_model_manager.tts_model

    # If no TTS model configured, return an error
    if not tts_model:
        raise HTTPException(
            status_code=400,
            detail="TTS not configured. Use --tts-model to specify a model."
        )

    # Determine model to use
    model_to_use = request.model
    if model_to_use.startswith("tts:"):
        model_to_use = tts_model

    # Try to use kokoro if available
    try:
        from kokoro import Kokoro

        # Determine model key
        model_key = f"tts:{model_to_use}"
        kokoro_model = multi_model_manager.get_model(model_key)
        if kokoro_model is None:
            print(f"Loading Kokoro TTS model: {model_to_use}")
            # URLs are downloaded (with caching); anything else is used as-is.
            model_path = _resolve_tts_model_path(model_to_use)
            # Load the Kokoro model and cache it for later requests
            kokoro_model = Kokoro(model_path or model_to_use)
            multi_model_manager.add_model(model_key, kokoro_model)

        # Generate speech (falsy voice/speed fall back to the defaults)
        voice = request.voice or "af_sarah"
        speed = request.speed or 1.0
        audio_bytes = kokoro_model.generate(request.input, voice=voice, speed=speed)

        # Convert to base64
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
        return {
            "audio": audio_base64
        }
    except ImportError as e:
        # kokoro (or a package needed while loading it) is not installed
        raise HTTPException(
            status_code=501,
            detail=f"TTS not available. Install kokoro: pip install kokoro. Error: {str(e)}"
        )
    except Exception as e:
        print(f"TTS error: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
This diff is collapsed.
This diff is collapsed.
# codai.models - Model parsing and templates
from .manager import (
ModelManager,
WhisperServerManager,
MultiModelManager,
model_manager,
multi_model_manager,
)
from .parser import (
ModelParserDispatcher,
BaseParser,
......@@ -16,12 +23,20 @@ from .parser import (
ToolCallParser,
ModelParserAdapter,
filter_repetition,
filter_malformed_content,
cleanup_control_tokens,
validate_json_complete,
format_tools_for_prompt,
)
from .templates import AgenticTemplateManager
__all__ = [
'ModelManager',
'WhisperServerManager',
'MultiModelManager',
'model_manager',
'multi_model_manager',
'ModelParserDispatcher',
'BaseParser',
'QwenParser',
......@@ -39,5 +54,8 @@ __all__ = [
'ModelParserAdapter',
'AgenticTemplateManager',
'filter_repetition',
'filter_malformed_content',
'cleanup_control_tokens',
'validate_json_complete',
'format_tools_for_prompt',
]
......@@ -660,3 +660,12 @@ class MultiModelManager:
models.append(ModelInfo(id=alias))
return models
# =============================================================================
# Singleton Instances
# =============================================================================
# Global singleton instances for convenience.
# Both objects are constructed when this module is first imported; other
# modules import these names rather than creating their own managers.
model_manager = ModelManager()
multi_model_manager = MultiModelManager()
......@@ -1299,6 +1299,90 @@ def filter_malformed_content(text: str) -> str:
return filtered
# =============================================================================
# Control Token Cleanup
# =============================================================================
# Pre-compiled pattern used to collapse runs of three or more newlines.
_RE_TRIPLE_NEWLINE = re.compile(r'\n{3,}')

# Control tokens to strip - defined once at module level
_CONTROL_TOKENS = [
    '<|im_end|>',
    '<|im_start|>',
    '<|endoftext|>',
    '<|end_of_text|>',
    '<|eot_id|>',
    '<|eom_id|>',
    '<|assistant|>',
    '<|model|>',
    '<|python|>',
    '<|javascript|>',
    '<|html|>',
    '\n\nassistant',
    '\nAssistant',
    'ASSISTANT',
    'Assistant',
    'assistant',
]

# Expanded token sets: every token plus its newline- and space-prefixed
# variants. The start and end sets hold the same members.
_CONTROL_TOKENS_START = set()
_CONTROL_TOKENS_END = set()
for _t in _CONTROL_TOKENS:
    _variants = (_t, '\n' + _t, ' ' + _t)
    _CONTROL_TOKENS_START.update(_variants)
    _CONTROL_TOKENS_END.update(_variants)

# Longest tokens first so a long token is tried before any shorter token;
# ties are broken alphabetically to keep the order deterministic.
_CONTROL_TOKENS_START_SORTED = sorted(_CONTROL_TOKENS_START, key=lambda t: (-len(t), t))
_CONTROL_TOKENS_END_SORTED = sorted(_CONTROL_TOKENS_END, key=lambda t: (-len(t), t))


def cleanup_control_tokens(text: str) -> str:
    """
    Clean up leading/trailing control tokens from model output.

    Removes tokens like <|im_end|>, <|im_start|>, 'assistant', etc. that might
    appear at the start or end of the response after reasoning extraction.

    Whitespace around the tokens is stripped before every match attempt, so
    tokens that are followed (or preceded) by newlines or spaces are removed
    as well, e.g. ``"Hello<|im_end|>\\n"`` becomes ``"Hello"``. Previously
    whitespace was only stripped after the token loops, which left such
    tokens in place.

    Args:
        text: Raw model output.

    Returns:
        The text without leading/trailing control tokens, with runs of three
        or more newlines collapsed to two and surrounding whitespace removed.
        Falsy input (empty string, None) is returned unchanged.
    """
    if not text:
        return text

    cleaned = text

    # Strip from start - repeat until no token is found at the start.
    # Each removal shortens the text, so the loop terminates.
    while True:
        cleaned = cleaned.lstrip()
        for token in _CONTROL_TOKENS_START_SORTED:
            if cleaned.startswith(token):
                cleaned = cleaned[len(token):]
                break  # restart from the longest token
        else:
            break  # no token matched

    # Strip from end - same approach, mirrored.
    while True:
        cleaned = cleaned.rstrip()
        for token in _CONTROL_TOKENS_END_SORTED:
            if cleaned.endswith(token):
                cleaned = cleaned[:-len(token)]
                break  # restart from the longest token
        else:
            break  # no token matched

    # Clean up any resulting triple+ newlines
    cleaned = _RE_TRIPLE_NEWLINE.sub('\n\n', cleaned)

    # Strip leading/trailing whitespace
    return cleaned.strip()
def filter_repetition(text: str, min_repeat_count: int = 3, ngram_sizes: tuple = (2, 3)) -> str:
"""
Detect and remove n-gram repetition from text.
......@@ -1478,8 +1562,18 @@ def validate_json_complete(json_str: str) -> bool:
# Tool Formatting
# =============================================================================
def format_tools_for_prompt(tools, messages):
"""Format tools into the system message or add a tool description."""
def format_tools_for_prompt(tools, messages, tools_closer_prompt: bool = False):
"""Format tools into the system message or add a tool description.
Args:
tools: List of Tool objects
messages: List of ChatMessage objects
tools_closer_prompt: If True, place tools right before the user's latest message
instead of in the system prompt (prompt distillation)
Returns:
Modified list of ChatMessage objects
"""
import json
if not tools:
......@@ -1512,19 +1606,47 @@ def format_tools_for_prompt(tools, messages):
# Add or prepend to system message
new_messages = list(messages)
system_found = False
for i, msg in enumerate(new_messages):
if msg.role == "system":
new_messages[i] = ChatMessage(
role="system",
content=f"{tools_text}\n\n{msg.content or ''}"
)
system_found = True
break
if not system_found:
new_messages.insert(0, ChatMessage(role="system", content=tools_text))
if tools_closer_prompt:
# Prompt distillation: insert tools right before the LAST user message
# Find the last user message and insert tools before it
last_user_idx = None
for i, msg in enumerate(new_messages):
if msg.role == "user":
last_user_idx = i
if last_user_idx is not None:
# Insert a tool context message before the last user message
tools_message = ChatMessage(role="system", content=f"Available tools:\n{tools_text}")
new_messages.insert(last_user_idx, tools_message)
else:
# No user message found, fall back to system message
system_found = False
for i, msg in enumerate(new_messages):
if msg.role == "system":
new_messages[i] = ChatMessage(
role="system",
content=f"{tools_text}\n\n{msg.content or ''}"
)
system_found = True
break
if not system_found:
new_messages.insert(0, ChatMessage(role="system", content=tools_text))
else:
# Traditional behavior: prepend tools to system message
system_found = False
for i, msg in enumerate(new_messages):
if msg.role == "system":
new_messages[i] = ChatMessage(
role="system",
content=f"{tools_text}\n\n{msg.content or ''}"
)
system_found = True
break
if not system_found:
new_messages.insert(0, ChatMessage(role="system", content=tools_text))
return new_messages
......
import time
import uuid
# Try to import litellm for response formatting
# Fall back to plain dicts if litellm is not available or doesn't export these
try:
from litellm import ModelResponse, ChatCompletionChunk
LITELLM_AVAILABLE = True
except ImportError:
LITELLM_AVAILABLE = False
ModelResponse = None
ChatCompletionChunk = None
class OpenAIFormatter:
    """Formatter for standardizing chat completion responses in OpenAI format.

    This class provides final sanitization of responses before sending them
    to clients. It processes the output of the internal parser and formats
    them into proper OpenAI-compatible responses.

    One instance generates one response id (``chatcmpl-<uuid4>``) that is
    shared by every full response or chunk it formats.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.id = f"chatcmpl-{uuid.uuid4()}"

    def _build_chunk(self, content, finish_reason, usage) -> dict:
        """Build one streaming chunk dict; ``usage`` is attached only when truthy."""
        chunk = {
            "id": self.id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": self.model_name,
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": content,
                        "role": "assistant",
                    },
                    "finish_reason": finish_reason,
                }
            ],
        }
        if usage:
            chunk["usage"] = usage
        return chunk

    def format_full(self, text: str, prompt_tokens: int, completion_tokens: int, tool_calls=None) -> dict:
        """Format a standard (non-streaming) response.

        Args:
            text: The generated text content
            prompt_tokens: Number of tokens in the prompt
            completion_tokens: Number of tokens in the completion
            tool_calls: Optional list of tool calls to include

        Returns:
            Dictionary representation of the response. When tool calls are
            given, the message content is None and the finish reason is
            "tool_calls"; otherwise the finish reason is "stop".
        """
        message = {
            "role": "assistant",
            "content": text if not tool_calls else None,
        }
        if tool_calls:
            message["tool_calls"] = tool_calls

        return {
            "id": self.id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": self.model_name,
            "choices": [
                {
                    "index": 0,
                    "message": message,
                    "finish_reason": "tool_calls" if tool_calls else "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "provider": {
                "provider_name": "coderai",
                "provider_id": "coderai",
            },
        }

    def format_chunk(self, delta_text: str, is_final: bool = False, usage: dict = None) -> dict:
        """Format a streaming chunk response.

        Args:
            delta_text: The incremental text content for this chunk
            is_final: Whether this is the final chunk
            usage: Optional usage information; only included when is_final is True

        Returns:
            Dictionary representation of the chunk
        """
        return self._build_chunk(
            delta_text,
            "stop" if is_final else None,
            usage if is_final else None,
        )

    def format_final_chunk(self, usage: dict = None) -> dict:
        """Format the final streaming chunk with usage information.

        The chunk carries no content (``None``) and finish reason "stop".

        Args:
            usage: Usage statistics dictionary with prompt_tokens, completion_tokens, total_tokens

        Returns:
            Dictionary representation of the final chunk
        """
        return self._build_chunk(None, "stop", usage)

    def format_litellm_full(self, text: str, prompt_tokens: int, completion_tokens: int, tool_calls=None) -> dict:
        """Format using litellm's ModelResponse if available.

        Falls back to format_full() when litellm is unavailable or when
        building the litellm object fails.

        Args:
            text: The generated text content
            prompt_tokens: Number of tokens in the prompt
            completion_tokens: Number of tokens in the completion
            tool_calls: Optional list of tool calls to include

        Returns:
            Dictionary representation of ModelResponse
        """
        if not LITELLM_AVAILABLE or ModelResponse is None:
            return self.format_full(text, prompt_tokens, completion_tokens, tool_calls)

        try:
            from litellm import Choices, Message, Usage
            return ModelResponse(
                id=self.id,
                model=self.model_name,
                object="chat.completion",
                created=int(time.time()),
                choices=[Choices(
                    finish_reason="tool_calls" if tool_calls else "stop",
                    index=0,
                    message=Message(content=text if not tool_calls else None, role="assistant", tool_calls=tool_calls)
                )],
                usage=Usage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens
                )
            ).model_dump()
        except Exception:
            # Fall back to plain dict if litellm fails
            return self.format_full(text, prompt_tokens, completion_tokens, tool_calls)

    def format_litellm_chunk(self, delta_text: str, is_final: bool = False, usage: dict = None) -> dict:
        """Format streaming chunk using litellm's ChatCompletionChunk if available.

        Falls back to format_chunk() when litellm is unavailable or when
        building the litellm object fails.

        Args:
            delta_text: The incremental text content for this chunk
            is_final: Whether this is the final chunk
            usage: Optional usage information; only included when is_final is True

        Returns:
            Dictionary representation of ChatCompletionChunk
        """
        if not LITELLM_AVAILABLE or ChatCompletionChunk is None:
            return self.format_chunk(delta_text, is_final, usage)

        try:
            from litellm import StreamingChoices, Delta, Usage
            return ChatCompletionChunk(
                id=self.id,
                model=self.model_name,
                object="chat.completion.chunk",
                created=int(time.time()),
                choices=[StreamingChoices(
                    finish_reason="stop" if is_final else None,
                    index=0,
                    delta=Delta(content=delta_text, role="assistant")
                )],
                # Same rule as format_chunk: usage only on the final chunk.
                usage=Usage(**usage) if usage and is_final else None
            ).model_dump()
        except Exception:
            # Fall back to plain dict if litellm fails
            return self.format_chunk(delta_text, is_final, usage)
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment