Commit 405be1cb authored by Your Name's avatar Your Name

Refactor: Move all API endpoints to codai.api module and extract CLI to codai.cli

- Move parse_args to codai.cli
- Move main() to codai.main
- Simplify coderai to be a thin wrapper importing from codai package
- Create codai.api module with organized endpoints:
  - codai/api/app.py: FastAPI app, /v1/models, /v1/files, get_load_mode
  - codai/api/text.py: /v1/chat/completions, legacy /v1/completions
  - codai/api/images.py: /v1/images/generations
  - codai/api/transcriptions.py: /v1/audio/transcriptions
  - codai/api/tts.py: /v1/audio/speech
- coderai is now backward compatible entry point only
parent 1c79d9ab
# codai.api - FastAPI application module
from .app import app
__all__ = ['app']
"""
FastAPI application module for codai API.
Contains the FastAPI app initialization, lifespan, and core endpoints.
"""
from contextlib import asynccontextmanager
from typing import List
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, JSONResponse
# Import from codai modules
from codai.pydantic.textrequest import ModelList
from codai.models.manager import model_manager, multi_model_manager
# Global references to be set by coderai
# These will be imported/assigned after the app is created
global_debug = False
global_file_path = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup/shutdown."""
# Startup
yield
# Shutdown
multi_model_manager.cleanup()
model_manager.cleanup()
# Stop whisper-server if running
if multi_model_manager.whisper_server:
multi_model_manager.whisper_server.stop()
# Create the FastAPI app
app = FastAPI(
title="OpenAI-Compatible API",
description="OpenAI-compatible API supporting NVIDIA (CUDA) and Vulkan backends",
version="2.0.0",
lifespan=lifespan,
)
# Import routers from submodules
from codai.api.transcriptions import router as transcriptions_router
from codai.api.images import router as images_router
from codai.api.tts import router as tts_router
# Import and add middleware
from codai.api.log import log_requests
app.middleware("http")(log_requests)
# Include routers from submodules
app.include_router(transcriptions_router)
app.include_router(images_router)
app.include_router(tts_router)
@app.get("/v1/models", response_model=ModelList)
async def list_models():
"""List available models."""
models = multi_model_manager.list_models()
return ModelList(data=models)
@app.get("/v1/files/{filename}")
async def get_file(filename: str):
"""Serve uploaded/generated files."""
if global_file_path:
import os
file_path = os.path.join(global_file_path, filename)
if os.path.exists(file_path):
return FileResponse(file_path)
raise HTTPException(status_code=404, detail="File not found")
def set_global_debug(debug: bool):
"""Set the global debug flag."""
global global_debug
global_debug = debug
def set_global_file_path(path: str):
"""Set the global file path."""
global global_file_path
global_file_path = path
# Load mode - will be set by coderai
_load_mode = {"mode": "ondemand"}
def get_load_mode():
"""Get the current load mode."""
return _load_mode.get("mode", "ondemand")
def set_load_mode(mode: str):
"""Set the load mode from coderai."""
global _load_mode
_load_mode = mode
"""
Image generation endpoints for the codai API.
"""
import asyncio
import base64
import io
import os
import uuid
from fastapi import APIRouter, HTTPException, Request
from PIL import Image
# Import from codai modules
from codai.models.manager import multi_model_manager
from codai.pydantic.imagerequest import ImageGenerationRequest
from codai.api.app import get_load_mode
# Global reference to be set by coderai
global_args = None
global_file_path = None
# Model semaphores for concurrency control (provided by coderai)
model_semaphores = {}
queue_flags = {}
# =============================================================================
# Helper Functions
# =============================================================================
def get_cfg_scale():
"""Get CFG scale for image generation. Auto-detect VRAM for Vulkan."""
global global_args
cfg_scale = getattr(global_args, 'image_cfg_scale', 1.0)
# If using Vulkan and CLI didn't specify cfg_scale (default 1.0), check VRAM
if cfg_scale == 1.0: # Only auto-detect if using default
backend = getattr(global_args, 'backend', 'auto')
image_backend = getattr(global_args, 'image_backend', 'auto')
# Check if using Vulkan (either global or image-specific)
use_vulkan = (backend == 'vulkan') or (image_backend == 'vulkan') or (image_backend == 'auto' and backend == 'auto')
if use_vulkan:
# Try to detect VRAM
try:
import subprocess
# Try vulkaninfo first
result = subprocess.run(['vulkaninfo', '-J'], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
import json
data = json.loads(result.stdout)
# Find device memory
for dev in data.get('devices', []):
mem = dev.get('deviceMemoryHeap', [{}])
for heap in mem:
if heap.get('flags', []).get('deviceLocal', False):
vram_mb = heap.get('size', 0) / (1024 * 1024)
print(f"DEBUG: Detected VRAM: {vram_mb:.0f} MB")
if vram_mb < 16000: # Less than 16GB
print(f"DEBUG: VRAM < 16GB, using cfg_scale=1.0 for better performance")
return 1.0
break
except Exception as e:
print(f"DEBUG: Could not detect VRAM: {e}")
# Default to 1.0 for Vulkan if detection fails
return 1.0
return cfg_scale
def save_image_response(img, request_format="base64", http_request=None):
"""
Save image to file path if configured, return response dict.
If --file-path is set and request_format is url (not base64), return only URL.
If --file-path is set and request_format is base64, return both URL and base64.
If --file-path is not set, return base64 as usual.
"""
global global_file_path, global_args
# Convert to PIL Image if needed
if not isinstance(img, Image.Image):
img = Image.fromarray(img)
result = {}
# Save to file path if configured
if global_file_path:
os.makedirs(global_file_path, exist_ok=True)
# Generate unique filename
filename = f"{uuid.uuid4().hex}.png"
file_path = os.path.join(global_file_path, filename)
img.save(file_path, format="PNG")
# Add URL to response
# Determine base URL based on --url argument
url_setting = getattr(global_args, 'url', 'auto') if global_args else 'auto'
if url_setting == 'auto':
# Use server host from request headers (what client used to connect)
if http_request:
# Get the Host header - this is what the client used to reach the server
client_host = http_request.headers.get('host', '')
if not client_host:
# Fallback to client IP if no Host header
client_host = http_request.client.host if http_request.client else '127.0.0.1'
# Strip port if present in Host header
if ':' in client_host and not client_host.replace(':', '').isdigit():
client_host = client_host.split(':')[0]
# Check if HTTPS is enabled
use_https = getattr(global_args, 'https', False) or getattr(global_args, 'pubkey', None)
protocol = "https" if use_https else "http"
port = getattr(global_args, 'port', 8000)
base_url = f"{protocol}://{client_host}:{port}"
else:
base_url = "http://127.0.0.1:8000"
else:
# Use explicitly provided URL (strip trailing slash if present)
base_url = url_setting.rstrip('/')
result["url"] = f"{base_url}/v1/files/{filename}"
# If client explicitly requested base64, include it
# Otherwise, only return URL when file-path is set
if request_format == "base64":
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
result["b64_json"] = img_base64
else:
# No file-path, return base64 as usual
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
result["b64_json"] = img_base64
return result
def set_global_args(args):
"""Set global args from coderai."""
global global_args
global_args = args
def set_global_file_path(path):
"""Set global file path from coderai."""
global global_file_path
global_file_path = path
def set_model_semaphores(semaphores):
"""Set model semaphores from coderai for concurrency control."""
global model_semaphores
model_semaphores = semaphores
def set_queue_flags(flags):
"""Set queue flags from coderai."""
global queue_flags
queue_flags = flags
# =============================================================================
# Router and Endpoints
# =============================================================================
router = APIRouter()
@router.post("/v1/images/generations")
async def create_image_generation(request: ImageGenerationRequest, http_request: Request = None):
"""
Image generation endpoint (OpenAI-compatible).
Supports:
- Stable Diffusion via stable-diffusion-cpp-python (sd.cpp)
- Stable Diffusion XL (via local inference with diffusers)
"""
global global_args, global_file_path, model_semaphores, queue_flags
# Get or create semaphore for this model
model_key = f"image:{request.model}" if request.model else "image"
mode = get_load_mode()
# Check if --image-1 is set (no queue, return 409 if busy)
use_1_mode = queue_flags.get("image_1", False)
# In loadall mode, allow 1 concurrent request per model
# In ondemand mode, serialize all requests (use global semaphore)
if mode == "loadall":
if model_key not in model_semaphores:
model_semaphores[model_key] = asyncio.Semaphore(1)
semaphore = model_semaphores[model_key]
else:
# Use a global semaphore for ondemand mode
if "global_image" not in model_semaphores:
model_semaphores["global_image"] = asyncio.Semaphore(1)
semaphore = model_semaphores["global_image"]
# Try to acquire semaphore without blocking
if use_1_mode:
acquired = semaphore.locked()
if acquired:
raise HTTPException(
status_code=409,
detail="Image model is busy. Try again later."
)
async with semaphore:
image_model = multi_model_manager.image_model
# If no image model configured, try to use main --model as fallback
if not image_model:
# Try to get the main model from args
main_model = getattr(global_args, 'model', None)
if main_model and isinstance(main_model, list) and len(main_model) > 0:
image_model = main_model[0]
elif main_model:
image_model = main_model
# Check if main model is a GGUF file - can't use for image generation
if image_model and ('.gguf' in image_model.lower() or 'gguf' in image_model.lower()):
print(f"Note: Main model is a GGUF file (for text), not suitable for image generation")
image_model = None # Can't use GGUF for images
# If still no image model configured, return an error
if not image_model:
raise HTTPException(
status_code=400,
detail="Image generation not configured. Use --image-model to specify a model."
)
# Determine model to use
# Priority: 1) model specified in request, 2) default image model from --image-model
model_to_use = request.model
if not model_to_use or model_to_use == "image":
# No model specified in request, use default
model_to_use = image_model
elif model_to_use.startswith("image:"):
# Legacy format - strip prefix and use default
model_to_use = image_model
else:
# Check if model_to_use is a valid model (URL, file, or known model)
# If not, fallback to the configured image model to avoid HF resolution errors
if image_model:
is_url = model_to_use.startswith('http://') or model_to_use.startswith('https://')
is_file = os.path.isfile(model_to_use) if model_to_use else False
if not is_url and not is_file:
# Unknown model name - use default instead of trying to resolve as HF
print(f"Warning: Unknown model '{model_to_use}' in image generation request, using configured --image-model")
model_to_use = image_model
# Check if model is loaded
model_key = f"image:{model_to_use}"
pipeline = multi_model_manager.get_model(model_key)
# Try to load if not cached
if pipeline is None:
# Try diffusers first
try:
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import torch
# Check if model is XL
is_xl = "xl" in model_to_use.lower() or "sdxl" in model_to_use.lower()
print(f"Loading diffusers model: {model_to_use}")
# Determine compute type
if torch.cuda.is_available():
dtype = torch.float16
else:
dtype = torch.float32
# Try to load the model
load_error = None
try:
if is_xl:
# Load SDXL
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
else:
# Load SD 1.5
pipeline = StableDiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
except Exception as load_error:
# Try with revised model resolution for custom models
print(f"Warning: First model load attempt failed: {load_error}")
print("Trying alternative loading method...")
# Try with default resolution
try:
if is_xl:
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
else:
pipeline = StableDiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
)
except Exception as retry_error:
# If it still fails, try without safety checker
print(f"Warning: Retry failed: {retry_error}, trying without safety checker...")
if is_xl:
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
safety_checker=None,
)
else:
pipeline = StableDiffusionPipeline.from_pretrained(
model_to_use,
torch_dtype=dtype,
safety_checker=None,
)
# Determine device
backend = getattr(global_args, 'backend', 'auto')
image_backend = getattr(global_args, 'image_backend', 'auto')
use_vulkan = (backend == 'vulkan') or (image_backend == 'vulkan') or (image_backend == 'auto' and backend == 'auto')
if use_vulkan and not torch.cuda.is_available():
# Vulkan/CPU mode
try:
pipeline.to("cpu")
# Enable CPU offload if available
if hasattr(pipeline, 'enable_attention_slicing'):
pipeline.enable_attention_slicing()
except Exception as e:
print(f"Warning: Could not move to CPU: {e}")
elif torch.cuda.is_available():
# CUDA mode
try:
pipeline.to("cuda")
except Exception as e:
print(f"Warning: Could not move to CUDA: {e}")
# Cache the model
multi_model_manager.add_model(model_key, pipeline)
print(f"Loaded diffusers model: {model_to_use}")
except ImportError as e:
# diffusers not installed
diffusers_error = str(e)
print(f"diffusers not available: {diffusers_error}")
except Exception as e:
import traceback
diffusers_error = str(e)
print(f"diffusers error: {diffusers_error}")
print(f"Traceback: {traceback.format_exc()}")
# Try diffusers if available
if pipeline is not None:
try:
# Determine size
width, height = 512, 512
if request.size:
parts = request.size.split("x")
if len(parts) == 2:
try:
width = int(parts[0])
height = int(parts[1])
except ValueError:
pass
# Check for nan/inf in dimensions
if width != width or width == float('inf'): # NaN or inf check
width = 512
if height != height or height == float('inf'): # NaN or inf check
height = 512
# Import torch for generation
import torch
# Ensure model is on correct device
backend = getattr(global_args, 'backend', 'auto')
image_backend = getattr(global_args, 'image_backend', 'auto')
use_vulkan = (backend == 'vulkan') or (image_backend == 'vulkan') or (image_backend == 'auto' and backend == 'auto')
if use_vulkan and not torch.cuda.is_available():
# CPU mode - try to reduce memory usage
try:
if hasattr(pipeline, 'enable_attention_slicing'):
pipeline.enable_attention_slicing(slice_size="auto")
if hasattr(pipeline, 'enable_vae_slicing'):
pipeline.enable_vae_slicing()
except Exception as e:
print(f"Warning: Could not enable memory optimizations: {e}")
elif torch.cuda.is_available():
# Try to enable memory optimizations for CUDA
try:
if hasattr(pipeline, 'enable_attention_slicing'):
pipeline.enable_attention_slicing(slice_size="auto")
if hasattr(pipeline, 'enable_vae_slicing'):
pipeline.enable_vae_slicing()
except Exception as e:
print(f"Warning: Could not enable CUDA memory optimizations: {e}")
# Get timestamp BEFORE calling diffusers (to avoid scope issues)
import time as time_module
timestamp = int(time_module.time())
# Generate images
# Use request seed if provided, otherwise use CLI default seed
seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None)
generator = None
if seed is not None:
generator = torch.Generator(device=pipeline.device).manual_seed(seed)
# Quality: "standard" or "hd"
quality = request.quality or "standard"
# Use request parameters if provided, otherwise fall back to quality-based defaults
num_steps = request.steps if request.steps else (30 if quality == "standard" else 50)
cfg_scale = request.guidance_scale if request.guidance_scale else (
getattr(global_args, 'image_cfg_scale', 7.5) if quality == "standard" else 9.0
)
# Generate
result = pipeline(
prompt=request.prompt,
negative_prompt=None,
num_images_per_prompt=request.n,
height=height,
width=width,
generator=generator,
guidance_scale=cfg_scale,
num_inference_steps=num_steps,
)
# Extract images
images = []
try:
result_images = result.images
except Exception as img_err:
print(f"Warning: Could not access result.images: {img_err}")
# Try alternative: result might have 'image' or 'output'
result_images = getattr(result, 'image', None) or getattr(result, 'output', None)
if result_images is None:
raise Exception(f"Could not extract images from diffusers result: {img_err}")
for img in result_images:
# Convert to base64
import numpy as np
# Handle NaN/Inf values in image data - convert to valid values
if isinstance(img, np.ndarray):
# Replace NaN and Inf with valid values
img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)
# Clip to valid range [0, 1]
img = np.clip(img, 0.0, 1.0)
# Use helper function to save and get response
img_data = save_image_response(img, request.response_format, http_request)
images.append(img_data)
return {
"created": timestamp,
"data": images
}
except ImportError as e:
# diffusers/torch not installed - record error and try sd.cpp
diffusers_error = str(e)
print(f"diffusers not available: {diffusers_error}, trying stable-diffusion-cpp-python...")
except Exception as e:
# Other error with diffusers - record and try sd.cpp
import traceback
diffusers_error = str(e)
print(f"diffusers error: {diffusers_error}")
print(f"Traceback: {traceback.format_exc()}")
print(f"Trying stable-diffusion-cpp-python...")
# Try stable-diffusion-cpp-python (sd.cpp) as fallback
# First, check all available image models to find one loaded via sd.cpp
# Always check for cached models - allows dynamically loaded models to be reused across requests
sd_model = None
for key in multi_model_manager.models:
if key.startswith("image:"):
potential_model = multi_model_manager.get_model(key)
if potential_model is not None:
# Check if it's a stable-diffusion-cpp model
try:
from stable_diffusion_cpp import StableDiffusion
if isinstance(potential_model, StableDiffusion):
sd_model = potential_model
print(f"Found cached stable-diffusion-cpp model with key: {key}")
break
except ImportError:
pass
# If no cached image model found, need to load one - first cleanup any existing models
if sd_model is None:
# Check if there's a text model loaded and unload it to free VRAM
# Cleanup ALL models except the one we're about to load
for key in list(multi_model_manager.models.keys()):
# Skip the image model we'll be loading (if we find it later)
# For now, cleanup all other models
if key.startswith("image:"):
continue
# Unload any other model (text, audio, etc.) to free VRAM
model_to_cleanup = multi_model_manager.models.get(key)
if model_to_cleanup is not None:
print(f"Unloading '{key}' from VRAM to make room for image model")
try:
if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
model_to_cleanup.cleanup()
elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
if hasattr(model_to_cleanup.model, 'cleanup'):
model_to_cleanup.model.cleanup()
except Exception as e:
print(f"Warning during cleanup of '{key}': {e}")
del multi_model_manager.models[key]
if sd_model is not None:
# Check if it's a stable-diffusion-cpp model (has generate method from sd.cpp)
try:
from stable_diffusion_cpp import StableDiffusion
if isinstance(sd_model, StableDiffusion):
print(f"Using stable-diffusion-cpp-python for image generation")
# Use sd.cpp for generation
# Parse size
width, height = 512, 512
if request.size:
parts = request.size.split("x")
if len(parts) == 2:
try:
width = int(parts[0])
height = int(parts[1])
except ValueError:
pass
# Use default steps for Z-Image Turbo (very fast)
steps = 4 # Default for fast generation
# Generate images using sd.cpp (run in thread to not block event loop)
# Use request seed if provided, otherwise use CLI default seed
seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None)
result = await asyncio.to_thread(
sd_model.generate_image,
prompt=request.prompt,
negative_prompt='',
width=width,
height=height,
cfg_scale=get_cfg_scale(),
sample_steps=steps,
seed=seed if seed is not None else 42,
batch_count=request.n if request.n else 1,
)
# Small delay to let Vulkan driver settle after generation
import time
time.sleep(0.1)
# Convert results to response format
images = []
for img in result:
# Use helper function to save and get response
img_data = save_image_response(img, http_request=http_request)
images.append(img_data)
return {
"created": int(time.time()),
"data": images
}
except ImportError as e:
# stable-diffusion-cpp not available
sd_cpp_error = str(e)
print(f"stable-diffusion-cpp-python not available: {sd_cpp_error}")
except Exception as e:
print(f"sd.cpp generation error: {e}")
sd_cpp_error = str(e)
else:
# No sd.cpp model pre-loaded, try to load dynamically
print("No pre-loaded sd.cpp model found, trying to load...")
try:
from stable_diffusion_cpp import StableDiffusion
# Check if model_to_use is a URL and get cached path
# Also handle HuggingFace model IDs that need to be resolved
model_path = None
if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
cached_path = multi_model_manager.get_cached_model_path(model_to_use)
if cached_path:
model_path = cached_path
print(f"Using cached model: {model_path}")
else:
# Not cached - download it
print(f"Downloading model: {model_to_use}")
cache_dir = multi_model_manager.get_model_cache_dir()
model_path = multi_model_manager.download_model(model_to_use, cache_dir)
print(f"Downloaded to: {model_path}")
elif os.path.isfile(model_to_use):
model_path = model_to_use
else:
# Try to resolve as HuggingFace model ID
print(f"Trying to resolve as HuggingFace model ID: {model_to_use}")
try:
from huggingface_hub import hf_hub_download, list_repo_files
# Parse model name (format: "org/model" or "org/model/filename.gguf")
parts = model_to_use.split('/')
if len(parts) >= 2:
repo_id = f"{parts[0]}/{parts[1]}"
# First check if there's a cached GGUF file for this model
# Try common GGUF file patterns
files = list_repo_files(repo_id)
gguf_files = [f for f in files if f.endswith('.gguf')]
if gguf_files:
# Try to find a cached version first
for gguf_file in gguf_files:
# Construct potential URL and check cache
potential_url = f"https://huggingface.co/{repo_id}/resolve/main/{gguf_file}"
cached = multi_model_manager.get_cached_model_path(potential_url)
if cached:
model_path = cached
print(f"Using cached GGUF model: {model_path}")
break
# If not cached, download the first GGUF file
if not model_path:
print(f"Downloading GGUF model from HF: {gguf_files[0]}")
model_path = hf_hub_download(repo_id=repo_id, filename=gguf_files[0])
print(f"Downloaded to: {model_path}")
except Exception as e:
print(f"Could not resolve as HuggingFace model: {e}")
if model_path is None:
print("Warning: Could not resolve sd.cpp model path")
sd_cpp_error = "Could not resolve model path"
else:
# Load sd.cpp model
# Determine backend to use based on CLI args
backend = getattr(global_args, 'backend', 'auto')
image_backend = getattr(global_args, 'image_backend', 'auto')
# Use CUDA only if explicitly requested via --backend nvidia or --image-backend nvidia
use_cuda = (backend == 'nvidia' or backend == 'cuda' or
image_backend == 'nvidia' or image_backend == 'cuda')
if use_cuda:
print(f"Using CUDA backend for sd.cpp image generation")
else:
print(f"Using Vulkan backend for sd.cpp image generation")
# Build kwargs for stable-diffusion-cpp with CLI args
sd_kwargs = {'diffusion_model_path': model_path}
# Add VAE path from CLI args if provided
vae_path = getattr(global_args, 'vae_path', None)
if vae_path:
# Check if it's a URL and download if needed
if vae_path.startswith('http://') or vae_path.startswith('https://'):
cached = multi_model_manager.get_cached_model_path(vae_path)
if cached:
sd_kwargs['vae_path'] = cached
print(f"Using cached VAE model: {cached}")
else:
cache_dir = multi_model_manager.get_model_cache_dir()
sd_kwargs['vae_path'] = multi_model_manager.download_model(vae_path, cache_dir)
else:
sd_kwargs['vae_path'] = vae_path
# Add LLM/CLIP path from CLI args if provided
llm_path = getattr(global_args, 'llm_path', None)
if llm_path:
if llm_path.startswith('http://') or llm_path.startswith('https://'):
cached = multi_model_manager.get_cached_model_path(llm_path)
if cached:
sd_kwargs['llm_path'] = cached
print(f"Using cached LLM model: {cached}")
else:
cache_dir = multi_model_manager.get_model_cache_dir()
sd_kwargs['llm_path'] = multi_model_manager.download_model(llm_path, cache_dir)
else:
sd_kwargs['llm_path'] = llm_path
# Add T5XXL path from CLI args if provided
t5xxl_path = getattr(global_args, 't5xxl_path', None)
if t5xxl_path:
if t5xxl_path.startswith('http://') or t5xxl_path.startswith('https://'):
cached = multi_model_manager.get_cached_model_path(t5xxl_path)
if cached:
sd_kwargs['t5xxl_path'] = cached
print(f"Using cached T5XXL model: {cached}")
else:
cache_dir = multi_model_manager.get_model_cache_dir()
sd_kwargs['t5xxl_path'] = multi_model_manager.download_model(t5xxl_path, cache_dir)
else:
sd_kwargs['t5xxl_path'] = t5xxl_path
# Add clip_on_cpu if specified
if getattr(global_args, 'clip_on_cpu', False):
sd_kwargs['keep_clip_on_cpu'] = True
print(f"DEBUG: Running CLIP on CPU to save VRAM (keep_clip_on_cpu=True)")
# Use all available CPU cores
import psutil
sd_kwargs['n_threads'] = psutil.cpu_count()
sd_model = StableDiffusion(**sd_kwargs)
# Cache the model for reuse on subsequent requests
cache_key = f"image:{model_path}"
multi_model_manager.add_model(cache_key, sd_model)
print(f"Using stable-diffusion-cpp-python for image generation")
# Generate images
width, height = 512, 512
if request.size:
parts = request.size.split("x")
if len(parts) == 2:
try:
width = int(parts[0])
height = int(parts[1])
except ValueError:
pass
steps = 4
# Use request seed if provided, otherwise use CLI default seed
seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None)
result = await asyncio.to_thread(
sd_model.generate_image,
prompt=request.prompt,
negative_prompt='',
width=width,
height=height,
cfg_scale=get_cfg_scale(),
sample_steps=steps,
seed=seed if seed is not None else 42,
batch_count=request.n if request.n else 1,
)
# Small delay to let Vulkan driver settle after generation
import time
time.sleep(0.1)
# Convert results to response format
images = []
for img in result:
# Use helper function to save and get response
img_data = save_image_response(img, http_request=http_request)
images.append(img_data)
return {
"created": int(time.time()),
"data": images
}
except ImportError as e:
sd_cpp_error = str(e)
print(f"stable-diffusion-cpp-python not available: {sd_cpp_error}")
except Exception as e:
sd_cpp_error = str(e)
print(f"sd.cpp error: {sd_cpp_error}")
# Both backends failed - return error with installation instructions
raise HTTPException(
status_code=400,
detail=f"Model '{model_to_use}' does not support image generation"
)
"""
Request logging middleware for the codai API.
"""
from fastapi import Request
async def log_requests(request: Request, call_next):
"""Log all incoming requests for debugging."""
# Import global debug flag from app
from codai.api.app import global_debug
if request.url.path in ["/v1/chat/completions", "/v1/completions"]:
body = b""
body_str = ""
try:
body = await request.body()
body_str = body.decode('utf-8')
# In debug mode, dump the full request
if global_debug:
print(f"DEBUG: Request body: {body_str[:500]}...")
except Exception as e:
print(f"Error reading request body: {e}")
# Call the next middleware/handler
response = await call_next(request)
# Log response status
if global_debug:
print(f"DEBUG: Response status: {response.status_code}")
return response
else:
# For non-chat endpoints, just pass through
response = await call_next(request)
return response
"""
Text generation endpoints for the codai API.
"""
import asyncio
import json
import time
import uuid
from typing import AsyncGenerator, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Request
# Import from codai modules
from codai.models.manager import ModelManager, WhisperServerManager, MultiModelManager, model_manager, multi_model_manager
from codai.queue.manager import QueueManager, queue_manager
from codai.pydantic.textrequest import ChatCompletionRequest, ToolFunction, Tool
from codai.models.parser import filter_malformed_content, filter_repetition, OpenAIFormatter, ModelParserAdapter
# Global reference to be set by coderai
global_args = None
global_debug = False
global_system_prompt = None
global_tools_closer_prompt = False
grammar_guided_gen = False
# =============================================================================
# Helper Functions
# =============================================================================
def set_global_args(args):
"""Set global args from coderai."""
global global_args
global_args = args
def set_global_debug(debug: bool):
"""Set the global debug flag."""
global global_debug
global_debug = debug
def set_global_system_prompt(prompt):
"""Set the global system prompt."""
global global_system_prompt
global_system_prompt = prompt
def set_global_tools_closer_prompt(tools_closer: bool):
"""Set the global tools-closer-prompt flag."""
global global_tools_closer_prompt
global_tools_closer_prompt = tools_closer
def set_grammar_guided_gen(enabled: bool):
"""Set the grammar-guided generation flag."""
global grammar_guided_gen
grammar_guided_gen = enabled
# =============================================================================
# Router and Endpoints
# =============================================================================
router = APIRouter()
@router.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, http_request: Request = None):
"""Chat completions endpoint with streaming and tool support."""
# Check if we should use litellm backend
parser_type = getattr(global_args, 'parser', 'auto') if global_args else 'auto'
if parser_type == 'litellm':
# Use LiteLLM backend
from codai.openai.litellm import get_litellm_backend, LITELLM_AVAILABLE
if not LITELLM_AVAILABLE:
raise HTTPException(
status_code=500,
detail="LiteLLM is not installed. Run: pip install litellm"
)
# Check for API key in request - litellm requires an API key
# If not provided, use a fake key to allow the request to proceed
api_key = None
# Try to get API key from request body
if hasattr(request, 'api_key') and request.api_key:
api_key = request.api_key
# If no API key in body, try to get from Authorization header
if not api_key:
auth_header = http_request.headers.get('Authorization', '') if http_request else ''
if auth_header.startswith('Bearer '):
api_key = auth_header[7:] # Extract token after 'Bearer '
# If still no API key, use a fake key to allow litellm to proceed
# litellm will then fail with the actual provider error if needed
if not api_key:
api_key = "fake-key-for-local-testing"
print("DEBUG: No API key provided, using fake key for litellm")
# Determine the base URL for litellm to connect to
# Use the server's host and port for local connections
api_base = None
# Check if model starts with 'ollama:' - use local Ollama
if request.model and request.model.startswith('ollama:'):
# Get the host from the request headers
client_host = "127.0.0.1"
if http_request:
host_header = http_request.headers.get('host', '')
if host_header:
# Strip port if present
if ':' in host_header:
client_host = host_header.split(':')[0]
if client_host.replace('.', '').isdigit():
# It's an IP, keep it
pass
else:
# It's a hostname, use localhost
client_host = "127.0.0.1"
else:
client_host = host_header
# Get port from global_args or use default
port = getattr(global_args, 'port', 11434) if global_args else 11434
api_base = f"http://{client_host}:{port}/v1"
print(f"DEBUG: Using api_base for Ollama: {api_base}")
else:
# For non-Ollama models, use the server's own URL as base
# This allows LiteLLM to make requests to the local server
if http_request:
# Get the host from the request headers
host_header = http_request.headers.get('host', '')
if host_header:
# Strip port if present to reconstruct clean URL
if ':' in host_header:
client_host = host_header.split(':')[0]
# Keep the port from the request for consistency
server_port = host_header.split(':')[1] if len(host_header.split(':')) > 1 else str(getattr(global_args, 'port', 6745))
else:
client_host = host_header
server_port = str(getattr(global_args, 'port', 6745))
else:
# Fallback to client host if no Host header
client_host = http_request.client.host if http_request.client else "127.0.0.1"
server_port = str(getattr(global_args, 'port', 6745))
else:
# Fallback if no http_request
client_host = "127.0.0.1"
server_port = str(getattr(global_args, 'port', 6745))
# Determine protocol (http or https)
use_https = getattr(global_args, 'https', False) or getattr(global_args, 'pubkey', None)
protocol = "https" if use_https else "http"
api_base = f"{protocol}://{client_host}:{server_port}/v1"
print(f"DEBUG: Using api_base for local server: {api_base}")
# Get or create litellm backend
litellm_backend = get_litellm_backend(
model=request.model,
api_key=api_key,
api_base=api_base,
context_window=8192, # Default, can be made configurable
model_manager=multi_model_manager # Pass for alias resolution
)
# Get the tool_parser from multi_model_manager for model-specific parsing
tool_parser = multi_model_manager.tool_parser if hasattr(multi_model_manager, 'tool_parser') else None
# Convert messages to dict format
messages_dict = []
for msg in request.messages:
msg_dict = {"role": msg.role, "content": msg.content or ""}
if hasattr(msg, 'tool_calls') and msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if hasattr(msg, 'tool_call_id') and msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
# Prepare tools if provided
tools_dict = None
if request.tools:
tools_dict = request.tools
# Generate response
try:
if request.stream:
# Streaming response
async def generate():
try:
async for chunk in await litellm_backend.chat_completion(
messages=messages_dict,
model=request.model,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
stop=request.stop,
tools=tools_dict,
tool_choice=request.tool_choice,
stream=True,
tool_parser=tool_parser,
):
# Add rate limit headers
headers = {}
if 'usage' in chunk:
headers = litellm_backend.get_rate_limit_headers(
prompt_tokens=chunk.get('usage', {}).get('prompt_tokens', 0),
completion_tokens=chunk.get('usage', {}).get('completion_tokens', 0)
)
# Handle Qwen tool calls if model is Qwen family
if 'qwen' in request.model.lower():
content = chunk.get('choices', [{}])[0].get('delta', {}).get('content', '')
tool_calls = chunk.get('choices', [{}])[0].get('delta', {}).get('tool_calls', [])
if not tool_calls and content:
# Try to parse tool calls from content
tool_calls = litellm_backend.parse_qwen_tool_calls(content)
if tool_calls:
# Strip tool tags from content
content = litellm_backend.strip_tool_tags(content)
chunk['choices'][0]['delta']['content'] = content
chunk['choices'][0]['delta']['tool_calls'] = tool_calls
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'internal_error'}})}\n\n"
from fastapi.responses import StreamingResponse
return StreamingResponse(generate(), media_type="text/event-stream")
else:
# Non-streaming response
response = await litellm_backend.chat_completion(
messages=messages_dict,
model=request.model,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
stop=request.stop,
tools=tools_dict,
tool_choice=request.tool_choice,
stream=False,
tool_parser=tool_parser,
)
# Handle Qwen tool calls
if 'qwen' in request.model.lower() and 'choices' in response:
msg = response['choices'][0].get('message', {})
content = msg.get('content', '')
tool_calls = msg.get('tool_calls', [])
if not tool_calls and content:
tool_calls = litellm_backend.parse_qwen_tool_calls(content)
if tool_calls:
msg['content'] = litellm_backend.strip_tool_tags(content)
msg['tool_calls'] = tool_calls
response['choices'][0]['message'] = msg
# Add rate limit headers
headers = {}
if 'usage' in response:
headers = litellm_backend.get_rate_limit_headers(
prompt_tokens=response.get('usage', {}).get('prompt_tokens', 0),
completion_tokens=response.get('usage', {}).get('completion_tokens', 0)
)
from fastapi.responses import JSONResponse
return JSONResponse(content=response, headers=headers)
except Exception as e:
# Handle litellm errors
error_response = {
"error": {
"message": str(e),
"type": "internal_error",
"code": 500
}
}
from fastapi.responses import JSONResponse
return JSONResponse(content=error_response, status_code=500)
# Continue with original implementation for 'auto' parser
# Get the model for this request
requested_model = request.model
# Try to get the appropriate model
mm = multi_model_manager.get_model_for_request(requested_model)
if mm is None:
# Model not loaded - try to use default
if model_manager.backend is not None:
# Fallback to legacy model_manager
current_manager = model_manager
else:
raise HTTPException(status_code=503, detail="Model not loaded")
else:
current_manager = mm
# Inject system prompt if --system-prompt flag was provided
messages = request.messages
if global_system_prompt is not None:
# Get the custom system prompt text
if global_system_prompt is True:
# Default system prompt
system_addon = "You are a helpful assistant."
else:
# Custom system prompt provided as argument
system_addon = str(global_system_prompt)
# Check if there's already a system message
system_found = False
for i, msg in enumerate(messages):
if msg.role == "system":
# Chain the custom system prompt at the START of existing system message
from codai.pydantic.textrequest import ChatMessage
messages[i] = ChatMessage(role="system", content=system_addon + "\n\n" + msg.content)
system_found = True
break
if not system_found:
# No existing system message, use the custom one
from codai.pydantic.textrequest import ChatMessage
messages = [ChatMessage(role="system", content=system_addon)] + list(messages)
# Enable thinking/reasoning mode if requested via API parameter OR CLI flag
force_reasoning_args = getattr(global_args, 'force_reasoning', None) if global_args else None
enable_thinking_api = getattr(request, 'enable_thinking', False)
# Parse force_reasoning: can be list (from CLI) or string (legacy)
if isinstance(force_reasoning_args, str):
# Legacy: convert string to list
if force_reasoning_args == "both":
force_reasoning_args = ["inject", "stop"]
elif force_reasoning_args == "stop":
force_reasoning_args = ["stop"]
elif force_reasoning_args == "inject":
force_reasoning_args = ["inject"]
elif force_reasoning_args == "all":
# 'all' enables all reasoning methods
force_reasoning_args = ["chat", "inject", "prompt", "mock", "raw", "twopass"]
else:
force_reasoning_args = []
elif not force_reasoning_args:
force_reasoning_args = []
# Combine CLI args with API param
# 'chat' from CLI enables API reasoning param
reasoning_enabled = enable_thinking_api or (len(force_reasoning_args) > 0)
# DEBUG: Print force_reasoning status when debug mode is enabled
if global_debug:
print(f"\n{'='*60}")
print(f"=== REASONING MODE DEBUG ===")
print(f"{'='*60}")
print(f"force_reasoning CLI args: {force_reasoning_args}")
print(f"enable_thinking API param: {enable_thinking_api}")
# Debug stop sequences if available
if 'raw_stop_sequences' in locals():
print(f"stop argument for chat call: {raw_stop_sequences}")
# Get model family for reasoning tokens
from codai.models.utils import get_model_family, get_reasoning_stop_tokens, get_resolved_model_name
model_family = get_model_family(request.model)
# Check if model is qwen3 and force_reasoning is enabled
is_qwen3 = 'qwen3' in model_family.lower() if model_family else False
use_qwen3_penalties = is_qwen3 and force_reasoning_args
# System prompt addon for qwen3 with force_reasoning
qwen3_system_addon = ""
if use_qwen3_penalties:
qwen3_system_addon = "\n\nCRITICAL: Do not repeat tool calls. If a tool fails with an [ERROR], do not retry the exact same parameters. Propose a different approach or ask for clarification."
if global_debug:
print(f"QWEEN3: Adding penalties and system addon for qwen3 with force_reasoning")
# Handle 'chat' - enable thinking API parameter
# Note: This only works with compatible APIs (OpenAI-like)
# We'll set it on the request if supported
if "chat" in force_reasoning_args or enable_thinking_api:
if hasattr(request, 'thinking'):
request.thinking = {"type": "enabled"}
if global_debug:
print(f"CHAT: Reasoning API param enabled")
# Handle 'inject' - system prompt injection
# Skip for 'raw' mode since it handles everything separately
if "raw" not in force_reasoning_args and "inject" in force_reasoning_args:
from codai.models.templates import AgenticTemplateManager
template_manager = AgenticTemplateManager(request.model)
# Use reasoning tag (]]) when prompt is also selected for consistency
use_reasoning_tag = "prompt" in force_reasoning_args
# Get the current system prompt if exists
system_content = None
for msg in messages:
if msg.role == "system":
system_content = msg.content
break
if system_content:
# Inject agentic instructions
system_content = template_manager.get_agent_system_prompt(system_content, use_reasoning_tag=use_reasoning_tag)
else:
system_content = template_manager.get_agent_system_prompt("You are a helpful assistant.", use_reasoning_tag=use_reasoning_tag)
# Update or add system message
from codai.pydantic.textrequest import ChatMessage
system_found = False
for i, msg in enumerate(messages):
if msg.role == "system":
messages[i] = ChatMessage(role="system", content=system_content)
system_found = True
break
if not system_found:
messages = [ChatMessage(role="system", content=system_content)] + list(messages)
if global_debug:
print(f"INJECT: System prompt injected with agentic instructions")
print(f"\n--- INJECTED SYSTEM PROMPT ---")
print(system_content)
print(f"--- END SYSTEM PROMPT ---")
# Handle 'prompt' - prompt seeding (ends with thought tag)
# Note: 'prompt' and 'raw' are mutually exclusive - raw bypasses this
if "prompt" in force_reasoning_args and "raw" not in force_reasoning_args:
from codai.models.templates import AgenticTemplateManager
template_manager = AgenticTemplateManager(request.model)
# Convert messages to the format expected by force_reasoning_prompt
user_message = ""
system_prompt = "You are a helpful assistant."
# Extract system and user messages
for msg in messages:
if msg.role == "system":
system_prompt = msg.content
elif msg.role == "user":
user_message = msg.content
# Add qwen3 system addon if applicable
if qwen3_system_addon:
system_prompt = system_prompt + qwen3_system_addon
# Get the seeded prompt (ends with thought tag)
seeded_prompt = template_manager.force_reasoning_prompt(system_prompt, user_message)
# Replace messages with the seeded prompt (as a single user message for raw completion)
from codai.pydantic.textrequest import ChatMessage
messages = [ChatMessage(role="user", content=seeded_prompt)]
if global_debug:
print(f"PROMPT: Prompt seeding applied (ends with thought tag)")
print(f"\n--- SEEDED PROMPT (last 80 chars) ---")
print(f"...{seeded_prompt[-80:]}")
print(f"--- END SEEDED PROMPT ---")
# Handle 'raw' - use template_manager.format_for_raw_completion for raw completion
# This bypasses the chat API and uses the model's native template with reasoning seed
# The template_manager.format_for_raw_completion will be called in the block below
# Prepare stop sequences
stop_sequences = []
if request.stop:
if isinstance(request.stop, str):
stop_sequences = [request.stop]
else:
stop_sequences = list(request.stop)
# Handle 'stop' - add reasoning stop tokens (also done for 'inject' and 'prompt')
# Skip for 'raw' mode since it handles stop tokens separately
if "raw" not in force_reasoning_args and ("stop" in force_reasoning_args or "inject" in force_reasoning_args or "prompt" in force_reasoning_args):
_, _, additional_stops = get_reasoning_stop_tokens(model_family)
for stop_token in additional_stops:
if stop_token not in stop_sequences:
stop_sequences.append(stop_token)
# When using prompt seeding, also add ]]> to force stopping after reasoning
if "prompt" in force_reasoning_args:
# Add common reasoning end tags based on model family
if "</think>" not in stop_sequences:
stop_sequences.append("</think>\n")
if global_debug:
print(f"STOP: Added reasoning stop tokens: {additional_stops}")
# Format messages with tools if provided - BUT SKIP for raw mode
# (raw mode handles tools separately via format_for_raw_completion)
# Get tools_closer_prompt from global args
tools_closer = getattr(global_args, 'tools_closer_prompt', False) if global_args else False
if request.tools and "raw" not in force_reasoning_args:
messages = format_tools_for_prompt(request.tools, messages, tools_closer_prompt=tools_closer)
# Get the tool_parser from the current manager
tool_parser = current_manager.tool_parser if hasattr(current_manager, 'tool_parser') else ModelParserAdapter()
# Convert messages to dict format for chat completion
messages_dict = []
for msg in messages:
msg_dict = {"role": msg.role}
# Always include content key - llama_cpp template expects it
# Convert content to string if it's a list (multipart content)
content = msg.content
if content is None:
content = ""
elif isinstance(content, list):
# Handle multipart content array format: [{"type": "text", "text": "..."}]
parts = []
for item in content:
if isinstance(item, dict):
if item.get('type') == 'text' and 'text' in item:
parts.append(item['text'])
else:
parts.append(f"[{item.get('type', 'unknown')} content]")
else:
parts.append(str(item))
content = '\n'.join(parts)
# Ensure content is never None - convert to string
msg_dict["content"] = str(content) if content is not None else ""
# Handle tool_calls - convert to proper format if present
if msg.tool_calls:
# tool_calls should be a list of dicts with 'id', 'type', 'function' keys
msg_dict["tool_calls"] = msg.tool_calls
if msg.name:
msg_dict["name"] = msg.name
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
# Final safety check: ensure NO message has None content before passing to llama_cpp
# Also ensure content key always exists (not just None check)
for i, m in enumerate(messages_dict):
# Handle missing content key entirely
if "content" not in m:
messages_dict[i]["content"] = ""
# Handle None content
elif m.get("content") is None:
messages_dict[i]["content"] = ""
# Handle content that's not a string (shouldn't happen but be safe)
elif not isinstance(m["content"], str):
messages_dict[i]["content"] = str(m["content"])
# Debug: print first few messages to see their structure
print(f"DEBUG: messages_dict[0] keys: {list(messages_dict[0].keys()) if messages_dict else 'empty'}")
if len(messages_dict) > 1:
print(f"DEBUG: messages_dict[1] keys: {list(messages_dict[1].keys()) if len(messages_dict) > 1 else 'empty'}")
# Convert tools to dict format if present
tools_dict = None
if request.tools:
tools_dict = []
for tool in request.tools:
tools_dict.append({
"type": tool.type,
"function": {
"name": tool.function.name,
"description": tool.function.description,
"parameters": tool.function.parameters
}
})
# Handle raw mode - use generate() instead of generate_chat() for raw prompt completion
# Note: These may have been set earlier in the prompt handling section
# Initialize only if not already set
if 'use_raw_mode' not in locals():
use_raw_mode = False
if 'raw_prompt_for_generation' not in locals():
raw_prompt_for_generation = None
if 'raw_stop_sequences' not in locals():
raw_stop_sequences = None
# Check if we need to set up raw mode (if not already done in prompt handling)
if "raw" in force_reasoning_args and not use_raw_mode:
# Create template_manager if not already created
if 'template_manager' not in locals():
from codai.models.templates import AgenticTemplateManager
template_manager = AgenticTemplateManager(request.model)
# Use template_manager.format_for_raw_completion which handles everything
if hasattr(template_manager, 'format_for_raw_completion'):
# Extract system and user messages
system_prompt = "You are a helpful assistant."
user_message = ""
for msg in messages:
if msg.role == "system":
system_prompt = msg.content
elif msg.role == "user":
user_message = msg.content
raw_prompt_for_generation, raw_stop_sequences = template_manager.format_for_raw_completion(
system_prompt=system_prompt,
user_message=user_message,
inject_system=True,
force_reasoning=True,
tools=request.tools, # Pass tools for family-specific formatting
tools_closer_prompt=tools_closer # Pass tools-closer-prompt flag
)
use_raw_mode = True
if global_debug:
print(f"RAW: Using template_manager.format_for_raw_completion")
print(f"RAW: Prompt ends with: ...{raw_prompt_for_generation[-80:]}")
else:
if global_debug:
print(f"RAW: template_manager.format_for_raw_completion not available")
# Get resolved model name for response (with coderai/ prefix and proper formatting)
# Use multi_model_manager to get the actual loaded models, not the individual model manager
response_model_name = get_resolved_model_name(requested_model, multi_model_manager)
print(f"DEBUG: Requested model: {requested_model}, Resolved model for response: {response_model_name}")
# Handle raw mode - two pass: first capture reasoning, then get final answer
if use_raw_mode and raw_prompt_for_generation:
if global_debug:
print(f"RAW: Starting two-pass generation")
print(f"RAW: First pass prompt: ...{raw_prompt_for_generation[-100:]}")
# Build extra params for qwen3
extra_params = {}
if use_qwen3_penalties:
extra_params = {
'repeat_penalty': 1.15,
'presence_penalty': 1.5,
'frequency_penalty': 0.5,
}
if request.stream:
# For streaming, we need to handle it differently
# First pass: generate until reasoning close tag (stream it)
async def raw_stream_generate():
import json # Local import for nested function
thought_tag, close_tag, _ = get_reasoning_stop_tokens(model_family)
reasoning_text = ""
if global_debug:
print(f"DEBUG: raw_stream_generate started, stream=True")
# Use the backend's async generate if available
if hasattr(current_manager.backend, 'generate_stream'):
async for chunk in current_manager.backend.generate_stream(
prompt=raw_prompt_for_generation,
max_tokens=request.max_tokens or 2048,
temperature=request.temperature,
top_p=request.top_p,
stop=raw_stop_sequences,
**extra_params,
):
reasoning_text += chunk
# Debug: log first pass chunks
if global_debug:
print(f"DEBUG FIRST PASS: chunk length={len(chunk)}, total reasoning so far={len(reasoning_text)}")
yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}, 'finish_reason': None}]})}\n\n"
# Check if we hit the close tag
if close_tag and close_tag in reasoning_text:
if global_debug:
print(f"DEBUG: Close tag detected in first pass, reasoning length={len(reasoning_text)}")
break
else:
# Fallback: non-streaming
if global_debug:
print(f"DEBUG: Using non-streaming fallback for first pass")
first_pass_result = current_manager.generate(
prompt=raw_prompt_for_generation,
max_tokens=request.max_tokens or 2048,
temperature=request.temperature,
top_p=request.top_p,
stop=raw_stop_sequences,
**extra_params,
)
yield f"data: {json.dumps({'choices': [{'delta': {'content': first_pass_result}, 'finish_reason': None}]})}\n\n"
# After reasoning, yield the close tag and continue with final answer
if close_tag:
yield f"data: {json.dumps({'choices': [{'delta': {'content': close_tag}, 'finish_reason': None}]})}\n\n"
# Second pass: get the rest
full_prompt = raw_prompt_for_generation + reasoning_text + (close_tag or "")
if global_debug:
print(f"DEBUG: raw_stream_generate second pass, full_prompt length: {len(full_prompt)}")
second_pass_result = current_manager.generate(
prompt=full_prompt,
max_tokens=request.max_tokens or 2048,
temperature=request.temperature,
top_p=request.top_p,
stop=stop_sequences,
**extra_params,
)
# FIX: Apply repetition filtering to both reasoning and final text
reasoning_text = filter_repetition(reasoning_text)
second_pass_result = filter_repetition(second_pass_result)
# FIX: If reasoning contains tool call tags, split at the first tool tag
# The tool call part should NOT be in reasoning - it should be left for tool extraction
tool_tag_patterns = ["<tool_call>", "<tool>", "<|tool_call|", "<function="]
earliest_tool_idx = len(reasoning_text)
earliest_tool_tag = None
for tag in tool_tag_patterns:
idx = reasoning_text.find(tag)
if idx != -1 and idx < earliest_tool_idx:
earliest_tool_idx = idx
earliest_tool_tag = tag
if earliest_tool_tag:
# Split: everything before the tool tag is reasoning, everything from the tag onwards goes to second_pass_result
tool_part = reasoning_text[earliest_tool_idx:]
reasoning_text = reasoning_text[:earliest_tool_idx].strip()
# Prepend the tool part to second_pass_result so it can be extracted as a tool call
second_pass_result = tool_part + second_pass_result
if global_debug:
print(f"DEBUG: Moved tool call from reasoning to second_pass_result: {tool_part[:100]}...")
# In debug mode, dump the full generated text (second pass result)
if global_debug:
print(f"\n{'='*80}")
print(f"=== RAW STREAM: FULL GENERATED TEXT (DEBUG) ===")
print(f"{'='*80}")
print(f"--- SECOND PASS RESULT ---")
print(second_pass_result)
print(f"--- END SECOND PASS RESULT ---")
print(f"{'='*80}\n")
# Also dump the reasoning text from first pass
print(f"\n{'='*80}")
print(f"=== RAW STREAM: REASONING TEXT (DEBUG) ===")
print(f"{'='*80}")
print(reasoning_text)
print(f"{'='*80}\n")
# Try to extract tool calls from the second pass result ONLY
# FIX: Do NOT fall back to reasoning text - tool calls should only come from final response
extracted_tool_calls = None
text_for_tool_extraction = second_pass_result
# CRITICAL: Only extract from second pass, never from reasoning
# Reasoning may contain partial/incomplete tool calls that confuse the parser
if global_debug:
print(f"DEBUG: Tool extraction - using second_pass_result only")
print(f"DEBUG: Second pass result length: {len(second_pass_result) if second_pass_result else 0}")
print(f"DEBUG: Reasoning text length: {len(reasoning_text) if reasoning_text else 0}")
if request.tools and text_for_tool_extraction:
# Convert tools for ModelParserAdapter
from codai.pydantic.textrequest import Tool, ToolFunction
from codai.models.parser import ModelParserAdapter
tools_list = []
for t in request.tools:
try:
if isinstance(t, dict):
func_data = t.get("function", {})
tool_func = ToolFunction(
name=func_data.get("name", ""),
description=func_data.get("description"),
parameters=func_data.get("parameters")
)
else:
tool_func = ToolFunction(
name=t.function.name if hasattr(t.function, 'name') else str(t.function),
description=t.function.description if hasattr(t.function, 'description') else None,
parameters=t.function.parameters if hasattr(t.function, 'parameters') else None
)
tools_list.append(Tool(type=t.get("type", "function") if isinstance(t, dict) else t.type, function=tool_func))
except Exception as e:
print(f"DEBUG: Error converting tool in raw stream: {e}")
continue
if tools_list:
adapter = ModelParserAdapter(model_name=response_model_name)
extracted_tool_calls = adapter.extract_tool_calls(text_for_tool_extraction, tools_list)
# FIX: Validate extracted tool calls have valid JSON
if extracted_tool_calls:
from codai.models.parser import validate_json_complete
validated_calls = []
for tc in extracted_tool_calls:
args = tc.get('function', {}).get('arguments', '{}')
if isinstance(args, str) and validate_json_complete(args):
validated_calls.append(tc)
elif isinstance(args, dict):
# Dict is already valid
validated_calls.append(tc)
if len(validated_calls) != len(extracted_tool_calls):
if global_debug:
print(f"DEBUG: Filtered out {len(extracted_tool_calls) - len(validated_calls)} invalid tool calls")
extracted_tool_calls = validated_calls if validated_calls else None
if global_debug and extracted_tool_calls:
print(f"\n{'='*80}")
print(f"=== RAW STREAM: EXTRACTED TOOL CALLS (DEBUG) ===")
print(f"{'='*80}")
print(json.dumps(extracted_tool_calls, indent=2))
print(f"{'='*80}\n")
elif global_debug:
print(f"DEBUG: No tool calls found in raw stream")
if extracted_tool_calls:
# Yield tool calls instead of content
yield f"data: {json.dumps({'choices': [{'delta': {'tool_calls': extracted_tool_calls}, 'finish_reason': 'tool_calls'}]})}\n\n"
else:
# No tool calls, yield the content as usual
yield f"data: {json.dumps({'choices': [{'delta': {'content': second_pass_result}, 'finish_reason': 'stop'}]})}\n\n"
yield "data: [DONE]\n\n"
from fastapi.responses import StreamingResponse
return StreamingResponse(raw_stream_generate(), media_type="text/event-stream")
# Non-streaming path (already implemented above)
# First pass: generate until reasoning close tag
first_pass_result = current_manager.generate(
prompt=raw_prompt_for_generation,
max_tokens=request.max_tokens or 2048,
temperature=request.temperature,
top_p=request.top_p,
stop=raw_stop_sequences,
**extra_params,
)
if global_debug:
print(f"RAW: First pass result: ...{first_pass_result[-200:]}")
# Dump first pass result if --dump is enabled
global_dump = getattr(global_args, 'dump', False) if global_args else False
if global_dump:
print(f"\n{'='*80}")
print(f"=== RAW MODE: FIRST PASS RESULT (DUMP) ===")
print(f"{'='*80}")
print(first_pass_result)
print(f"{'='*80}\n")
# Extract reasoning (everything up to the close tag)
thought_tag, close_tag, _ = get_reasoning_stop_tokens(model_family)
reasoning_text = ""
final_text = first_pass_result
# Define tool tags that indicate end of reasoning
tool_tags = ["<tool_call>", "<tool>", "<|tool_call|>", "<|tool|>", "<function="]
if close_tag and close_tag in first_pass_result:
# Split at close tag
parts = first_pass_result.split(close_tag, 1)
reasoning_text = parts[0]
final_text = parts[1] if len(parts) > 1 else ""
else:
# Try to find tool tags as fallback stop markers
earliest_tool_idx = len(first_pass_result)
earliest_tool_tag = None
for tag in tool_tags:
idx = first_pass_result.find(tag)
if idx != -1 and idx < earliest_tool_idx:
earliest_tool_idx = idx
earliest_tool_tag = tag
if earliest_tool_tag:
# Split at tool tag
if global_debug:
print(f"RAW: No close tag found, using tool tag '{earliest_tool_tag}' as fallback")
parts = first_pass_result.split(earliest_tool_tag, 1)
reasoning_text = parts[0]
final_text = earliest_tool_tag + (parts[1] if len(parts) > 1 else "")
if global_debug:
print(f"RAW: Extracted reasoning: {reasoning_text[:100]}...")
print(f"RAW: Final text before cleanup: {final_text[:100]}...")
# Dump extraction details if --dump is enabled
if global_dump:
print(f"\n{'='*80}")
print(f"=== RAW MODE: EXTRACTION (DUMP) ===")
print(f"{'='*80}")
print(f"Close tag used: {close_tag}")
print(f"\n--- REASONING TEXT ---")
print(reasoning_text)
print(f"\n--- FINAL TEXT (before cleanup) ---")
print(final_text)
print(f"{'='*80}\n")
# Clean up control tokens from final text
final_text = cleanup_control_tokens(final_text)
# FIX: Apply repetition filtering to reasoning and final text
reasoning_text = filter_repetition(reasoning_text)
final_text = filter_repetition(final_text)
# FIX: If reasoning contains tool call tags, split at the first tool tag
# The tool call part should NOT be in reasoning - it should be left for tool extraction in final_text
tool_tag_patterns = ["<tool_call>", "<tool>", "<|tool_call|>", "<function="]
earliest_tool_idx = len(reasoning_text)
earliest_tool_tag = None
for tag in tool_tag_patterns:
idx = reasoning_text.find(tag)
if idx != -1 and idx < earliest_tool_idx:
earliest_tool_idx = idx
earliest_tool_tag = tag
if earliest_tool_tag:
# Split: everything before the tool tag is reasoning, everything from the tag onwards goes to final_text
tool_part = reasoning_text[earliest_tool_idx:]
reasoning_text = reasoning_text[:earliest_tool_idx].strip()
# Prepend the tool part to final_text so it can be extracted as a tool call
final_text = tool_part + final_text
if global_debug:
print(f"RAW: Moved tool call from reasoning to final_text: {tool_part[:100]}...")
if global_debug:
print(f"RAW: Final text after cleanup: {final_text[:100]}...")
# If we have reasoning, continue with second pass to get more complete answer
# Build the full prompt with reasoning included
full_prompt = raw_prompt_for_generation + reasoning_text + (close_tag or "")
# Second pass: generate the rest (or just use what we have)
# For now, just return what we have + optionally continue
if final_text.strip():
# We have a complete answer after reasoning
generated_text = reasoning_text + (close_tag or "") + final_text
else:
# Need second pass to get answer
second_pass_result = current_manager.generate(
prompt=full_prompt,
max_tokens=request.max_tokens or 2048,
temperature=request.temperature,
top_p=request.top_p,
stop=stop_sequences,
**extra_params,
)
# Clean up the second pass result
second_pass_result = cleanup_control_tokens(second_pass_result)
generated_text = reasoning_text + (close_tag or "") + second_pass_result
# Additional cleanup of the full generated text
generated_text = cleanup_control_tokens(generated_text)
if global_debug:
print(f"RAW: Generated text after cleanup: {generated_text[:100]}...")
# Pass through the formatter/parser (same as regular mode)
# Pipeline: Model output -> Extract reasoning (if raw mode) -> ModelParserAdapter (extract tools) -> OpenAIFormatter (final format)
from codai.models.parser import OpenAIFormatter, ModelParserAdapter
# Convert request tools for ModelParserAdapter
tools_list = None
if request.tools:
from codai.pydantic.textrequest import Tool, ToolFunction
tools_list = []
for t in request.tools:
try:
# Handle both dict and pydantic model formats
if isinstance(t, dict):
func_data = t.get("function", {})
tool_func = ToolFunction(
name=func_data.get("name", ""),
description=func_data.get("description"),
parameters=func_data.get("parameters")
)
else:
# Pydantic model
tool_func = ToolFunction(
name=t.function.name if hasattr(t.function, 'name') else str(t.function),
description=t.function.description if hasattr(t.function, 'description') else None,
parameters=t.function.parameters if hasattr(t.function, 'parameters') else None
)
tools_list.append(Tool(type=t.get("type", "function") if isinstance(t, dict) else t.type, function=tool_func))
except Exception as e:
print(f"DEBUG: Error converting tool in raw mode: {e}, tool type: {type(t)}")
continue
# Step 1: Use ModelParserAdapter to extract tool calls from final_text (NOT generated_text which includes reasoning)
# This fixes Bug 2 and Bug 3: reasoning was appearing in both content AND reasoning fields
# because the parser was receiving the full generated_text including reasoning
extracted_tool_calls = None
clean_text = final_text # Use final_text (after reasoning) instead of generated_text (which includes reasoning)
if tools_list:
adapter = ModelParserAdapter(model_name=response_model_name)
# Extract tool calls from final_text only (after reasoning is done)
extracted_tool_calls = adapter.extract_tool_calls(final_text, tools_list)
# FIX: Validate extracted tool calls have valid JSON
if extracted_tool_calls:
from codai.models.parser import validate_json_complete
validated_calls = []
for tc in extracted_tool_calls:
args = tc.get('function', {}).get('arguments', '{}')
if isinstance(args, str) and validate_json_complete(args):
validated_calls.append(tc)
elif isinstance(args, dict):
# Dict is already valid
validated_calls.append(tc)
if len(validated_calls) != len(extracted_tool_calls):
print(f"DEBUG: Filtered out {len(extracted_tool_calls) - len(validated_calls)} invalid tool calls in non-streaming")
extracted_tool_calls = validated_calls if validated_calls else None
if extracted_tool_calls:
# Strip tool calls from the text
clean_text = adapter.strip_tool_calls_from_content(final_text)
if global_debug:
print(f"RAW: Extracted {len(extracted_tool_calls)} tool calls from final_text (after reasoning)")
# Estimate token counts
prompt_tokens = len(raw_prompt_for_generation.split())
completion_tokens = len(clean_text.split()) if clean_text else 0
# Step 2: Use OpenAIFormatter for final formatting
formatter = OpenAIFormatter(response_model_name)
try:
formatted_response = formatter.format_full(
text=clean_text,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
tool_calls=extracted_tool_calls
)
except Exception as e:
print(f"RAW: ERROR in formatter.format_full: {e}")
formatted_response = None
if global_debug:
if formatted_response and isinstance(formatted_response, dict):
try:
choices = formatted_response.get('choices', [])
if choices and len(choices) > 0:
message = choices[0].get('message', {}) if isinstance(choices[0], dict) else {}
content = message.get('content', '') if isinstance(message, dict) else ''
print(f"RAW: Passed through formatter, got: {str(content)[:100]}...")
else:
print(f"RAW: WARNING - formatter returned empty choices!")
except Exception as e:
print(f"RAW: ERROR accessing formatter response: {e}")
else:
print(f"RAW: WARNING - formatter returned None or invalid response!")
# Add mock reasoning stats if 'mock' is in force_reasoning_args
# But only if we DON'T already have real reasoning from extraction
has_real_reasoning = reasoning_text and len(reasoning_text.strip()) > 10
if force_reasoning_args and "mock" in force_reasoning_args and formatted_response and not has_real_reasoning:
# Add fake reasoning tokens to trigger VSCode plugin stats
mock_reasoning_tokens = 50
# Update usage
if "usage" in formatted_response:
formatted_response["usage"]["completion_tokens"] += mock_reasoning_tokens
formatted_response["usage"]["total_tokens"] += mock_reasoning_tokens
formatted_response["usage"]["completion_tokens_details"] = {
"reasoning_tokens": mock_reasoning_tokens
}
# Add reasoning to message if not present
if "choices" in formatted_response and formatted_response["choices"]:
choice = formatted_response["choices"][0]
if "message" in choice and "reasoning" not in choice["message"]:
choice["message"]["reasoning"] = "Processing task in optimized mode..."
elif has_real_reasoning and formatted_response:
# We have real reasoning from extraction - add it to the message
if "choices" in formatted_response and formatted_response["choices"]:
choice = formatted_response["choices"][0]
if "message" in choice:
choice["message"]["reasoning"] = reasoning_text.strip()
# Also update usage with actual reasoning tokens
if "usage" in formatted_response:
reasoning_tokens = len(reasoning_text.strip().split())
formatted_response["usage"]["completion_tokens_details"] = {
"reasoning_tokens": reasoning_tokens
}
# Dump parsed output if enabled
if global_dump:
import json
print(f"\n{'='*80}")
print(f"=== RAW MODE PARSED OUTPUT (DUMP) ===")
print(f"{'='*80}")
print(json.dumps(formatted_response, indent=2))
print(f"{'='*80}\n")
# Add rate limit headers
headers = {}
if formatted_response and 'usage' in formatted_response:
headers = current_manager.backend.get_rate_limit_headers(
prompt_tokens=formatted_response.get('usage', {}).get('prompt_tokens', 0),
completion_tokens=formatted_response.get('usage', {}).get('completion_tokens', 0)
) if hasattr(current_manager.backend, 'get_rate_limit_headers') else {}
# Ensure we have a valid response to return
if not formatted_response:
# Create a minimal fallback response
formatted_response = {
"id": f"chatcmpl-{uuid.uuid4().hex}",
"object": "chat.completion",
"created": int(time.time()),
"model": response_model_name,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": clean_text or ""
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
}
}
from fastapi.responses import JSONResponse
return JSONResponse(content=formatted_response, headers=headers)
if request.stream:
from fastapi.responses import StreamingResponse
return StreamingResponse(
stream_chat_response(
messages_dict,
response_model_name,
request.max_tokens,
request.temperature,
request.top_p,
stop_sequences,
tools_dict,
current_manager,
tool_parser,
request.response_format,
),
media_type="text/event-stream",
)
else:
return await generate_chat_response(
messages_dict,
response_model_name,
request.max_tokens,
request.temperature,
request.top_p,
stop_sequences,
tools_dict,
current_manager,
tool_parser,
request.response_format,
force_reasoning_args,
)
async def stream_chat_response(
messages: List[Dict],
model_name: str,
max_tokens: Optional[int],
temperature: float,
top_p: float,
stop: List[str],
tools: Optional[List[Dict]],
current_manager: ModelManager,
tool_parser: ToolCallParser,
response_format: Optional[Dict] = None,
) -> AsyncGenerator[str, None]:
"""Stream chat completion response with queue notifications."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time())
request_id = f"req-{uuid.uuid4().hex[:8]}"
generated_text = ""
print(f"DEBUG: stream_chat_response started, stream=True, tools={tools is not None}")
# Check if model is loaded - if not, notify waiting clients
# The model manager exists but backend may not be loaded yet in on-demand mode
model_loaded = False
if current_manager is not None:
if hasattr(current_manager, 'backend') and current_manager.backend is not None:
# Check if backend has the model loaded
if hasattr(current_manager.backend, 'model') and current_manager.backend.model is not None:
model_loaded = True
elif hasattr(current_manager, 'model') and current_manager.model is not None:
# Alternative check for some model managers
model_loaded = True
# If model not loaded, add to queue and send waiting notifications
if not model_loaded:
await queue_manager.add_waiting(request_id)
wait_interval = 2.0 # Send waiting update every 2 seconds
last_wait_update = time.time()
# Send initial waiting message
data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": "Waiting for model to load..."},
"finish_reason": None,
}],
"x_queue_info": {
"status": "waiting",
"message": "Model is loading, please wait...",
},
}
yield f"data: {json.dumps(data)}\n\n"
# Keep sending wait updates until model is loaded
# In a real implementation, this would check a loading status
# For now, we'll send a few updates then proceed
max_wait_updates = 5
wait_count = 0
while wait_count < max_wait_updates:
await asyncio.sleep(wait_interval)
wait_time = await queue_manager.get_wait_time(request_id)
wait_count += 1
queue_pos = await queue_manager.get_queue_position(request_id)
data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": f""},
"finish_reason": None,
}],
"x_queue_info": {
"status": "waiting",
"message": f"Waiting for model... ({int(wait_time)}s)",
"queue_position": queue_pos,
"wait_time_seconds": int(wait_time),
},
}
yield f"data: {json.dumps(data)}\n\n"
# Mark as starting processing
await queue_manager.start_processing(request_id, model_name)
# Send "Model starting" message
data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": ""},
"finish_reason": None,
}],
"x_queue_info": {
"status": "starting",
"message": "Model starting",
},
}
yield f"data: {json.dumps(data)}\n\n"
try:
chunk_count = 0
# Debug: Print what is being passed to the model
if global_debug:
print(f"\n{'='*80}")
print(f"=== MODEL INPUT (DEBUG) ===")
print(f"{'='*80}")
print(f"Model: {model_name}")
print(f"Max tokens: {max_tokens}")
print(f"Temperature: {temperature}")
print(f"Top P: {top_p}")
print(f"Stop sequences: {stop}")
print(f"Tools: {tools is not None}")
print(f"Response format: {response_format}")
print(f"\n--- Messages ---")
for i, msg in enumerate(messages):
role = msg.get('role', 'unknown')
content = msg.get('content', '')
if content and len(content) > 500:
content = content[:500] + "... [truncated]"
print(f"[{i}] {role}: {repr(content)}")
print(f"{'='*80}\n")
# Use generate_chat_stream for proper chat template handling
async for chunk in current_manager.generate_chat_stream(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
tools=tools,
response_format=response_format,
):
chunk_count += 1
# Always filter malformed content (regex-based, works per-chunk)
filtered_chunk = filter_malformed_content(chunk)
# NOTE: filter_repetition() and strip_tool_calls_from_content() are NOT applied
# per-chunk because they need the full accumulated text to work correctly:
# - filter_repetition() needs enough context (6+ words) to detect n-gram repetitions
# - strip_tool_calls_from_content() needs complete XML tags that span multiple chunks
# Both are applied to the complete generated_text after streaming completes.
# Pass through all content including whitespace - it's essential for message composition
generated_text += filtered_chunk
data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": filtered_chunk},
"finish_reason": None,
}],
}
yield f"data: {json.dumps(data)}\n\n"
# Explicitly flush to ensure data is sent immediately
await asyncio.sleep(0)
print(f"DEBUG: stream_chat_response completed, {chunk_count} chunks, generated_text length: {len(generated_text)}")
if not generated_text.strip():
print(f"DEBUG: Warning - no content generated!")
# In debug mode, dump the full generated text
if global_debug:
print(f"\n{'='*80}")
print(f"=== FULL GENERATED TEXT (DEBUG) ===")
print(f"{'='*80}")
# Show both raw (actual) content and escaped representation
print(f"--- RAW CONTENT (actual newlines shown as lines) ---")
print(generated_text)
print(f"--- END RAW CONTENT ---")
print(f"--- ESCAPED CONTENT (repr() - shows \\n for newlines) ---")
print(repr(generated_text))
print(f"--- END ESCAPED CONTENT ---")
print(f"{'='*80}\n")
# Check for tool calls in complete output (for API response format)
if tools:
# Convert tools back to Tool objects for parsing
from typing import cast
tool_objects = []
for t in tools:
try:
# Handle both dict and pydantic model formats
if isinstance(t, dict):
func_data = t.get("function", {})
tool_func = ToolFunction(
name=func_data.get("name", ""),
description=func_data.get("description"),
parameters=func_data.get("parameters")
)
else:
# Pydantic model
tool_func = ToolFunction(
name=t.function.name if hasattr(t.function, 'name') else str(t.function),
description=t.function.description if hasattr(t.function, 'description') else None,
parameters=t.function.parameters if hasattr(t.function, 'parameters') else None
)
tool_objects.append(Tool(type=t.get("type", "function") if isinstance(t, dict) else t.type, function=tool_func))
except Exception as e:
print(f"DEBUG: Error converting tool: {e}, tool type: {type(t)}")
continue
try:
tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
# FIX: Validate extracted tool calls have valid JSON (stream_chat_response)
if tool_calls:
from codai.models.parser import validate_json_complete
validated_calls = []
for tc in tool_calls:
args = tc.get('function', {}).get('arguments', '{}')
if isinstance(args, str) and validate_json_complete(args):
validated_calls.append(tc)
elif isinstance(args, dict):
validated_calls.append(tc)
if len(validated_calls) != len(tool_calls):
print(f"DEBUG: Filtered out {len(tool_calls) - len(validated_calls)} invalid tool calls in stream_chat_response")
tool_calls = validated_calls if validated_calls else None
except Exception as e:
print(f"DEBUG: Error extracting tool calls: {e}")
tool_calls = None
if tool_calls:
# In debug mode, dump tool calls
if global_debug:
print(f"\n{'='*80}")
print(f"=== EXTRACTED TOOL CALLS (DEBUG) ===")
print(f"{'='*80}")
print(json.dumps(tool_calls, indent=2))
print(f"{'='*80}\n")
# Tool calls were extracted and stripped from content during streaming
# Just send the tool_calls chunk
data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model_name,
"choices": [{
"index": 0,
"delta": {"tool_calls": tool_calls},
"finish_reason": "tool_calls",
"logprobs": None,
"native_finish_reason": "tool_calls",
}],
}
yield f"data: {json.dumps(data)}\n\n"
else:
# Calculate token counts for usage in final chunk
prompt_text = "\n".join([m.get("content", "") for m in messages])
prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0
# Use OpenAIFormatter for final chunk sanitization
formatter = OpenAIFormatter(model_name)
usage_details = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
final_chunk = formatter.format_litellm_chunk("", is_final=True, usage=usage_details)
yield f"data: {json.dumps(final_chunk)}\n\n"
else:
# Calculate token counts for usage in final chunk
prompt_text = "\n".join([m.get("content", "") for m in messages])
prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0
# Build complete final chunk with all OpenAI fields
final_chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model_name,
"choices": [{
"index": 0,
"finish_reason": "stop",
"logprobs": None,
"native_finish_reason": "stop",
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 0,
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
},
},
"provider": {
"provider_name": "coderai",
"provider_id": "coderai",
},
"system_fingerprint": None,
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
print(f"Error during streaming generation: {e}")
data = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": f"\n[Generation error: {str(e)}]"},
"finish_reason": "stop",
}],
}
yield f"data: {json.dumps(data)}\n\n"
yield "data: [DONE]\n\n"
finally:
# Always clean up queue state
await queue_manager.finish_processing()
async def generate_chat_response(
messages: List[Dict],
model_name: str,
max_tokens: Optional[int],
temperature: float,
top_p: float,
stop: List[str],
tools: Optional[List[Dict]],
current_manager: ModelManager,
tool_parser: ToolCallParser,
response_format: Optional[Dict] = None,
force_reasoning_args: Optional[List[str]] = None,
) -> Dict:
"""Generate non-streaming chat completion response."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time())
# Debug: Print what is being passed to the model
if global_debug:
print(f"\n{'='*80}")
print(f"=== MODEL INPUT (DEBUG) ===")
print(f"{'='*80}")
print(f"Model: {model_name}")
print(f"Max tokens: {max_tokens}")
print(f"Temperature: {temperature}")
print(f"Top P: {top_p}")
print(f"Stop sequences: {stop}")
print(f"Tools: {tools is not None}")
print(f"Response format: {response_format}")
print(f"\n--- Messages ---")
for i, msg in enumerate(messages):
role = msg.get('role', 'unknown')
content = msg.get('content', '')
if content and len(content) > 500:
content = content[:500] + "... [truncated]"
print(f"[{i}] {role}: {repr(content)}")
print(f"{'='*80}\n")
try:
# Use generate_chat for proper chat template handling
generated_text = current_manager.generate_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
tools=tools,
response_format=response_format,
)
# Always filter out malformed content
generated_text = filter_malformed_content(generated_text)
# Apply repetition filtering to prevent infinite loops
generated_text = filter_repetition(generated_text)
# Dump raw output if enabled
global_dump = getattr(global_args, 'dump', False) if global_args else False
if global_dump:
print(f"\n{'='*80}")
print(f"=== RAW MODEL OUTPUT (DUMP) ===")
print(f"{'='*80}")
print(generated_text)
print(f"{'='*80}\n")
response_message = {
"role": "assistant",
"content": generated_text,
}
finish_reason = "stop"
# Check for tool calls
if tools:
# Convert tools back to Tool objects for parsing
tool_objects = []
for t in tools:
try:
# Handle both dict and pydantic model formats
if isinstance(t, dict):
func_data = t.get("function", {})
tool_func = ToolFunction(
name=func_data.get("name", ""),
description=func_data.get("description"),
parameters=func_data.get("parameters")
)
else:
# Pydantic model
tool_func = ToolFunction(
name=t.function.name if hasattr(t.function, 'name') else str(t.function),
description=t.function.description if hasattr(t.function, 'description') else None,
parameters=t.function.parameters if hasattr(t.function, 'parameters') else None
)
tool_objects.append(Tool(type=t.get("type", "function") if isinstance(t, dict) else t.type, function=tool_func))
except Exception as e:
print(f"DEBUG: Error converting tool: {e}, tool type: {type(t)}")
continue
try:
tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
# FIX: Validate extracted tool calls have valid JSON (generate_chat_response)
if tool_calls:
from codai.models.parser import validate_json_complete
validated_calls = []
for tc in tool_calls:
args = tc.get('function', {}).get('arguments', '{}')
if isinstance(args, str) and validate_json_complete(args):
validated_calls.append(tc)
elif isinstance(args, dict):
validated_calls.append(tc)
if len(validated_calls) != len(tool_calls):
print(f"DEBUG: Filtered out {len(tool_calls) - len(validated_calls)} invalid tool calls in generate_chat_response")
tool_calls = validated_calls if validated_calls else None
except Exception as e:
print(f"DEBUG: Error extracting tool calls: {e}")
tool_calls = None
if tool_calls:
# Always strip tool call format from content
clean_content = tool_parser.strip_tool_calls_from_content(generated_text)
response_message["content"] = clean_content if clean_content.strip() else None
response_message["tool_calls"] = tool_calls
finish_reason = "tool_calls"
# Calculate token counts - rough estimate since we don't have direct access to tokenizer
prompt_text = "\n".join([m.get("content", "") for m in messages])
prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0
# Use OpenAIFormatter for final sanitization
formatter = OpenAIFormatter(model_name)
formatted_response = formatter.format_litellm_full(
text=response_message.get("content", ""),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
tool_calls=response_message.get("tool_calls")
)
# Add mock reasoning stats if 'mock' is in force_reasoning_args
# But only if we don't already have real reasoning in the response
# Check if reasoning already exists in the message
existing_reasoning = None
if "choices" in formatted_response and formatted_response["choices"]:
choice = formatted_response["choices"][0]
if "message" in choice:
existing_reasoning = choice["message"].get("reasoning")
if force_reasoning_args and "mock" in force_reasoning_args and formatted_response and not existing_reasoning:
# Add fake reasoning tokens to trigger VSCode plugin stats
mock_reasoning_tokens = 50
# Update usage
if "usage" in formatted_response:
formatted_response["usage"]["completion_tokens"] += mock_reasoning_tokens
formatted_response["usage"]["total_tokens"] += mock_reasoning_tokens
formatted_response["usage"]["completion_tokens_details"] = {
"reasoning_tokens": mock_reasoning_tokens
}
# Add reasoning to message if not present
if "choices" in formatted_response and formatted_response["choices"]:
choice = formatted_response["choices"][0]
if "message" in choice and "reasoning" not in choice["message"]:
choice["message"]["reasoning"] = "Processing task in optimized mode..."
# Dump parsed output if enabled
if global_dump:
import json
print(f"\n{'='*80}")
print(f"=== PARSED OUTPUT (DUMP) ===")
print(f"{'='*80}")
print(json.dumps(formatted_response, indent=2))
print(f"{'='*80}\n")
return formatted_response
except Exception as e:
print(f"Error during generation: {e}")
raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
# =============================================================================
# Legacy Text Completions Endpoint (/v1/completions)
# =============================================================================
# NOTE: This is a legacy endpoint for backward compatibility.
# It uses raw text completion (no chat template) instead of the modern
# /v1/chat/completions API. Consider using /v1/chat/completions instead.
# =============================================================================
from codai.pydantic.textrequest import CompletionRequest
@router.post("/v1/completions")
async def completions(request: CompletionRequest):
"""Legacy text completions endpoint (for backward compatibility)."""
# Get the model for this request
requested_model = request.model
# Try to get the appropriate model
mm = multi_model_manager.get_model_for_request(requested_model)
if mm is None:
# Model not loaded - try to use default
if model_manager.backend is not None:
# Fallback to legacy model_manager
current_manager = model_manager
else:
raise HTTPException(status_code=503, detail="Model not loaded")
else:
current_manager = mm
prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt]
stop_sequences = []
if request.stop:
stop_sequences = [request.stop] if isinstance(request.stop, str) else request.stop
if request.stream:
from fastapi.responses import StreamingResponse
return StreamingResponse(
stream_completion_response(
prompts[0],
request.model,
request.max_tokens,
request.temperature,
request.top_p,
stop_sequences,
current_manager,
),
media_type="text/event-stream",
)
else:
return await generate_completion_response(
prompts[0],
request.model,
request.max_tokens,
request.temperature,
request.top_p,
stop_sequences,
current_manager,
)
async def stream_completion_response(
prompt: str,
model_name: str,
max_tokens: Optional[int],
temperature: float,
top_p: float,
stop: List[str],
current_manager: ModelManager,
) -> AsyncGenerator[str, None]:
"""Stream legacy completion response."""
completion_id = f"cmpl-{uuid.uuid4().hex}"
created = int(time.time())
try:
async for chunk in current_manager.generate_stream(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
):
data = {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model_name,
"choices": [{
"text": chunk,
"index": 0,
"logprobs": None,
"finish_reason": None,
}],
}
yield f"data: {json.dumps(data)}\n\n"
yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
print(f"Error during streaming completion: {e}")
yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
yield "data: [DONE]\n\n"
async def generate_completion_response(
prompt: str,
model_name: str,
max_tokens: Optional[int],
temperature: float,
top_p: float,
stop: List[str],
current_manager: ModelManager,
) -> Dict:
"""Generate non-streaming legacy completion response."""
completion_id = f"cmpl-{uuid.uuid4().hex}"
created = int(time.time())
try:
generated_text = current_manager.generate(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
)
# Calculate token counts if tokenizer available
if current_manager.tokenizer:
prompt_tokens = len(current_manager.tokenizer.encode(prompt))
completion_tokens = len(current_manager.tokenizer.encode(generated_text))
else:
prompt_tokens = len(prompt.split())
completion_tokens = len(generated_text.split())
return {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model_name,
"choices": [{
"text": generated_text,
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
except Exception as e:
print(f"Error during completion: {e}")
raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
"""
Audio transcription endpoint for the codai API.
"""
import io
import os
import tempfile
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from typing import Optional
# Import from codai modules
from codai.models.manager import multi_model_manager
# Global reference to be set by coderai
global_args = None
def set_global_args(args):
"""Set global args from coderai."""
global global_args
global_args = args
# =============================================================================
# Router and Endpoints
# =============================================================================
router = APIRouter()
@router.post("/v1/audio/transcriptions")
async def create_transcription(
model: str = Form(...),
file: UploadFile = File(...),
language: Optional[str] = Form(None),
prompt: Optional[str] = Form(None),
response_format: Optional[str] = Form("json"),
temperature: Optional[float] = Form(0.0),
):
"""
Audio transcription endpoint (OpenAI-compatible).
"""
# Check if whisper-server is available FIRST
if multi_model_manager.whisper_server and multi_model_manager.whisper_server.is_running():
file_content = await file.read()
result = multi_model_manager.whisper_server.transcribe(
file_content,
language=language,
prompt=prompt
)
if "error" in result:
raise HTTPException(status_code=500, detail=result["error"])
return {"text": result.get("text", "")}
audio_model = multi_model_manager.audio_model
if not audio_model:
raise HTTPException(
status_code=400,
detail="Audio transcription not configured. Use --audio-model or --whisper-server."
)
# Determine model to use
model_to_use = model
if model_to_use.startswith("whisper:") or model_to_use.startswith("audio:"):
model_to_use = audio_model
# Read the uploaded file
file_content = await file.read()
# Save to temp file (needed for some backends)
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
tmp.write(file_content)
tmp_path = tmp.name
try:
# Try faster-whisper first
try:
from faster_whisper import WhisperModel
# Determine model key
model_key = f"audio:{model_to_use}"
whisper_model = multi_model_manager.get_model(model_key)
if whisper_model is None:
print(f"Loading faster-whisper model: {model_to_use}")
# Determine compute type - always use int8 for CPU
compute_type = "int8"
# Load the model
whisper_model = WhisperModel(
model_to_use,
device="cpu", # Always use CPU - faster-whisper CUDA doesn't work with AMD
compute_type=compute_type,
)
# Cache the model
multi_model_manager.add_model(model_key, whisper_model)
print(f"Loaded faster-whisper model: {model_to_use}")
# Run transcription
segments, info = whisper_model.transcribe(
tmp_path,
language=language,
initial_prompt=prompt,
temperature=temperature,
)
# Collect all segments
text_parts = []
for segment in segments:
text_parts.append(segment.text)
full_text = "".join(text_parts)
return {
"text": full_text.strip()
}
except ImportError:
pass
# Try whispercpp as fallback
try:
import whispercpp
# Determine model key
model_key = f"audio:{model_to_use}"
whisper_model = multi_model_manager.get_model(model_key)
if whisper_model is None:
print(f"Loading whispercpp model: {model_to_use}")
# Check if it's a built-in model name
if model_to_use in ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large']:
# It's a built-in model name
whisper_model = whispercpp.Whisper.from_pretrained(model_to_use)
else:
# It's a path to a GGUF file
whisper_model = whispercpp.Whisper.from_pretrained(model_to_use)
# Cache the model
multi_model_manager.add_model(model_key, whisper_model)
print(f"Loaded whispercpp model: {model_to_use}")
# Run transcription
result = whisper_model.transcribe(tmp_path)
# Extract text from result
text = ""
if hasattr(result, 'text'):
text = result.text
elif isinstance(result, dict):
text = result.get('text', '')
elif isinstance(result, list):
# Some versions return a list of segments
for segment in result:
if hasattr(segment, 'text'):
text += segment.text
elif isinstance(segment, dict):
text += segment.get('text', '')
return {
"text": text.strip()
}
except ImportError as e:
raise HTTPException(
status_code=501,
detail="Audio transcription not available. Install faster-whisper or whispercpp."
)
except Exception as e:
print(f"Transcription error: {e}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}")
finally:
# Clean up temp file
try:
os.unlink(tmp_path)
except Exception:
pass
"""
Text-to-speech endpoints for the codai API.
"""
import base64
import os
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict
# Import from codai modules
from codai.models.manager import multi_model_manager
# Global reference to be set by coderai
global_args = None
def get_cached_model_path(url: str) -> str:
"""Get cached model path if available."""
from codai.models.manager import multi_model_manager
return multi_model_manager.get_cached_model_path(url)
def get_model_cache_dir() -> str:
"""Get model cache directory."""
from codai.models.manager import multi_model_manager
return multi_model_manager.get_model_cache_dir()
def set_global_args(args):
"""Set global args from coderai."""
global global_args
global_args = args
# =============================================================================
# Router and Endpoints
# =============================================================================
router = APIRouter()
class TTSRequest(BaseModel):
model: str
input: str
voice: str = "af_sarah"
response_format: str = "mp3"
speed: float = 1.0
model_config = ConfigDict(extra="allow")
class TTSResponse(BaseModel):
audio: str # base64 encoded audio
model_config = ConfigDict(extra="allow")
@router.post("/v1/audio/speech")
async def create_speech(request: TTSRequest):
"""
Text-to-speech endpoint (OpenAI-compatible).
Supports:
- Kokoro TTS models (when --tts-model is specified)
"""
tts_model = multi_model_manager.tts_model
# If no TTS model configured, return an error
if not tts_model:
raise HTTPException(
status_code=400,
detail="TTS not configured. Use --tts-model to specify a model."
)
# Determine model to use
model_to_use = request.model
if model_to_use.startswith("tts:"):
model_to_use = tts_model
# Try to use kokoro if available
try:
from kokoro import Kokoro
# Determine model key
model_key = f"tts:{model_to_use}"
kokoro_model = multi_model_manager.get_model(model_key)
if kokoro_model is None:
print(f"Loading Kokoro TTS model: {model_to_use}")
# Check if model_to_use is a URL - download it (with caching)
model_path = None
if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
# Check cache first
cached_path = get_cached_model_path(model_to_use)
if cached_path:
model_path = cached_path
print(f"Using cached model: {model_path}")
else:
print(f"Downloading model from URL: {model_to_use}")
try:
import requests
import hashlib
# Get cache directory
cache_dir = get_model_cache_dir()
# Extract filename from URL
url_path = model_to_use.split('?')[0]
filename = os.path.basename(url_path)
if not filename.endswith('.pt') and not filename.endswith('.bin'):
filename = "kokoro-model.pt"
# Create safe filename in cache
url_hash = hashlib.sha256(model_to_use.encode()).hexdigest()
cached_filename = f"{url_hash}_{filename}"
model_path = os.path.join(cache_dir, cached_filename)
# Download to cache
response = requests.get(model_to_use, stream=True)
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
downloaded = 0
with open(model_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192*1024):
if chunk:
f.write(chunk)
downloaded += len(chunk)
if total_size > 0:
percent = (downloaded / total_size) * 100
print(f"Downloaded: {percent:.1f}%", end='\r')
print(f"\nDownloaded and cached to: {model_path}")
except Exception as e:
print(f"Error downloading model: {e}")
raise
else:
# Use local path or model name
model_path = model_to_use
# Load the Kokoro model
kokoro_model = Kokoro(model_path if model_path else model_to_use)
multi_model_manager.add_model(model_key, kokoro_model)
# Generate speech
voice = request.voice or "af_sarah"
speed = request.speed or 1.0
audio_bytes = kokoro_model.generate(request.input, voice=voice, speed=speed)
# Convert to base64
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
return {
"audio": audio_base64
}
except ImportError as e:
# kokoro not installed
raise HTTPException(
status_code=501,
detail=f"TTS not available. Install kokoro: pip install kokoro. Error: {str(e)}"
)
except Exception as e:
print(f"TTS error: {e}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
"""Command-line argument parsing for codai server."""
import argparse
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="OpenAI-compatible API server supporting NVIDIA (CUDA) and Vulkan backends"
)
parser.add_argument(
"--model",
type=str,
action="append",
default=None,
help="Model name, path, or URL for text-to-text LLM. Can be specified multiple times for multiple models.",
)
parser.add_argument(
"--model-alias",
type=str,
action="append",
default=None,
dest="model_aliases",
nargs=2,
metavar=("ALIAS", "MODEL"),
help="Register an alias for a model. Format: --model-alias <alias_name> <actual_model>",
)
parser.add_argument(
"--backend",
type=str,
choices=["auto", "nvidia", "vulkan", "opencl"],
default="auto",
help="Backend to use: auto (detect), nvidia (CUDA), vulkan (AMD), or opencl",
)
parser.add_argument(
"--image-backend",
type=str,
choices=["auto", "nvidia", "vulkan", "opencl"],
default="auto",
help="Image generation backend: auto, nvidia (CUDA), vulkan (AMD), or opencl",
)
parser.add_argument(
"--audio-backend",
type=str,
choices=["auto", "nvidia", "vulkan", "opencl"],
default="auto",
help="Audio transcription backend: auto, nvidia (CUDA), vulkan (AMD), or opencl",
)
parser.add_argument(
"--tts-backend",
type=str,
choices=["auto", "nvidia", "vulkan", "opencl"],
default="auto",
help="TTS backend: auto, nvidia (CUDA), vulkan (AMD), or opencl",
)
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to bind to (default: 0.0.0.0)",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to bind to (default: 8000)",
)
parser.add_argument(
"--url",
type=str,
default="auto",
help="Base URL for media downloads: 'auto' (use request IP) or explicit URL (e.g., http://myserver:8000)",
)
parser.add_argument(
"--https",
action="store_true",
help="Enable HTTPS with auto-generated certificate",
)
parser.add_argument(
"--privkey",
type=str,
default=None,
help="Path to HTTPS private key file",
)
parser.add_argument(
"--pubkey",
type=str,
default=None,
help="Path to HTTPS certificate file",
)
parser.add_argument(
"--offload-dir",
type=str,
default="./offload",
help="Directory for disk offload (NVIDIA backend only, default: ./offload)",
)
parser.add_argument(
"--load-in-4bit",
action="store_true",
help="Load model in 4-bit precision (NVIDIA backend only, requires bitsandbytes)",
)
parser.add_argument(
"--load-in-8bit",
action="store_true",
help="Load model in 8-bit precision (NVIDIA backend only, requires bitsandbytes)",
)
parser.add_argument(
"--ram",
type=float,
default=None,
help="Maximum CPU RAM to use for model offloading in GB (NVIDIA backend only). Auto-detected if not specified. Disk offloading only occurs after this limit is exceeded.",
)
parser.add_argument(
"--flash-attn",
action="store_true",
help="Use Flash Attention 2 (NVIDIA backend only, requires flash-attn package)",
)
parser.add_argument(
"--offload-strategy",
type=str,
choices=["auto", "conservative", "balanced", "aggressive", "sequential"],
default="auto",
help="Offload strategy for NVIDIA backend (default: auto)",
)
parser.add_argument(
"--max-gpu-percent",
type=float,
default=None,
help="Maximum GPU VRAM to use as percentage (0-100). Overrides offload-strategy. Lower values offload more to CPU/RAM (default: None = use offload-strategy)",
)
parser.add_argument(
"--n-gpu-layers",
type=int,
default=-1,
help="Number of layers to offload to GPU (Vulkan backend only, default: -1 = all layers)",
)
parser.add_argument(
"--n-ctx",
type=int,
action="append",
default=None,
help="Context window size (Vulkan backend). Can be specified multiple times, one per --model.",
)
parser.add_argument(
"--vulkan-device",
type=int,
default=0,
help="Vulkan GPU device ID to use (Vulkan backend only, default: 0). Use --vulkan-list-devices to see available devices",
)
parser.add_argument(
"--vulkan-single-gpu",
action="store_true",
help="Force Vulkan to use only the specified GPU device (prevents layer distribution across multiple GPUs)",
)
parser.add_argument(
"--vulkan-list-devices",
action="store_true",
help="List available Vulkan GPU devices and exit",
)
parser.add_argument(
"--hf-chat-template",
action="append",
default=[],
help="Use HuggingFace apply_chat_template. Examples: --hf-chat-template auto (all models), --hf-chat-template text (all text), --hf-chat-template mymodel:llama3 (specific model with template). Can be repeated.",
)
parser.add_argument(
"--system-prompt",
nargs="?",
const=True,
default=None,
help="Inject a system prompt at the beginning of conversations. Use without a value for a default prompt, or provide custom text.",
)
# Multi-model arguments
parser.add_argument(
"--tts-model",
type=str,
default=None,
help="Model for text-to-speech (e.g., kokoro, or path/URL to Kokoro model). Can be specified multiple times.",
)
parser.add_argument(
"--audio-model",
type=str,
action="append",
default=None,
help="Model for audio transcription (e.g., whisper-1, base, or path to faster-whisper model). Can be specified multiple times for multiple models.",
)
parser.add_argument(
"--audio-1",
action="store_true",
help="Disable request queue for audio models - return 409 if model is busy",
)
parser.add_argument(
"--image-model",
type=str,
action="append",
default=None,
help="Model for image generation (e.g., stable-diffusion-xl-base-1.0). Can be specified multiple times for multiple models.",
)
parser.add_argument(
"--vision-model",
type=str,
action="append",
default=None,
help="Model for image/video-to-text (e.g., llava-1.5, LLaVA). Supports vulkan and cuda backends.",
)
parser.add_argument(
"--image-1",
action="store_true",
help="Disable request queue for image models - return 409 if model is busy",
)
parser.add_argument(
"--llm-path",
type=str,
default=None,
help="Path to CLIP LLM model for image generation (stable-diffusion-cpp-python).",
)
parser.add_argument(
"--vae-path",
type=str,
default=None,
help="Path to VAE model for image generation (stable-diffusion-cpp-python).",
)
parser.add_argument(
"--image-sample-method",
type=str,
default="res_multistep",
help="Sample method for image generation (default: res_multistep for Z-Image Turbo).",
)
parser.add_argument(
"--image-steps",
type=int,
default=4,
help="Number of inference steps for image generation (default: 4 for Z-Image Turbo).",
)
parser.add_argument(
"--image-width",
type=int,
default=512,
help="Image width for generation (default: 512).",
)
parser.add_argument(
"--image-height",
type=int,
default=512,
help="Image height for generation (default: 512).",
)
parser.add_argument(
"--image-cfg-scale",
type=float,
default=1.0,
help="CFG scale for image generation (default: 1.0 for Z-Image Turbo).",
)
parser.add_argument(
"--image-precision",
type=str,
default="f32",
choices=["bf16", "f32", "f16", "f8"],
help="Model precision for image generation (default: f32). bf16 recommended for modern GPUs.",
)
parser.add_argument(
"--image-cpu-offload",
action="store_true",
help="Enable sequential CPU offload for image models (lower VRAM usage).",
)
parser.add_argument(
"--image-seed",
type=int,
default=None,
help="Default seed for image generation (default: random).",
)
parser.add_argument(
"--vae-tiling",
action="store_true",
help="Enable VAE tiling for lower VRAM usage (sd.cpp only).",
)
parser.add_argument(
"--clip-on-cpu",
action="store_true",
help="Run CLIP on CPU to save VRAM (sd.cpp only).",
)
parser.add_argument(
"--loadall",
action="store_true",
help="Pre-load all models (main, audio, image) at startup instead of on-demand",
)
parser.add_argument(
"--loadswap",
action="store_true",
help="Keep all models loaded, swapping active model between VRAM and RAM (only active model in VRAM)",
)
parser.add_argument(
"--nopreload",
action="store_true",
help="Disable model preloading. Models will load on first request instead of at startup",
)
parser.add_argument(
"--audio-ctx",
type=int,
action="append",
default=None,
help="Audio model context size in milliseconds. Can be specified multiple times, one per --audio-model.",
)
parser.add_argument(
"--audio-offload",
type=float,
default=None,
help="Audio model GPU offload percentage (0-100). If not set, uses CPU",
)
parser.add_argument(
"--audio-vulkan-device",
type=int,
default=0,
help="Vulkan GPU device ID to use for Whisper audio transcription (default: 0). Only used when using Vulkan backend.",
)
parser.add_argument(
"--image-vulkan-device",
type=int,
default=None,
help="Vulkan GPU device ID to use for image generation models (default: same as --vulkan-device). Use --vulkan-list-devices to see available devices",
)
parser.add_argument(
"--whisper-cpp",
type=str,
default=None,
help="Path to whisper.cpp CLI executable (e.g., ~/whisper.cpp/build/bin/whisper-cli). Uses Vulkan if available.",
)
parser.add_argument(
"--whisper-server",
type=str,
default=None,
help="Path to whisper.cpp server executable (e.g., ~/whisper.cpp/build/bin/whisper-server). Keeps model loaded in VRAM.",
)
parser.add_argument(
"--whisper-server-port",
type=int,
default=8744,
help="Port for whisper-server (default: 8744).",
)
parser.add_argument(
"--image-ctx",
type=int,
action="append",
default=None,
help="Image model context size. Can be specified multiple times, one per --image-model.",
)
parser.add_argument(
"--image-offload",
type=float,
default=None,
help="Vision model GPU offload percentage (0-100). If not set, loads fully on GPU",
)
parser.add_argument(
"--list-cached-models",
action="store_true",
help="List all cached models in the model cache directory",
)
parser.add_argument(
"--remove-all-models",
action="store_true",
help="Remove all cached models from the model cache directory",
)
parser.add_argument(
"--remove-model",
type=str,
default=None,
help="Remove a specific cached model by name or hash (partial match)",
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode - dumps full request/response to stdout for troubleshooting",
)
parser.add_argument(
"--dump",
action="store_true",
help="Dump model output: raw output, parsed output, and litellm debug info",
)
parser.add_argument(
"--file-path",
type=str,
default=None,
help="Path to store generated files (images, audio). If specified, files will be saved here and served over web.",
)
parser.add_argument(
"--parser",
type=str,
default="auto",
choices=["auto", "litellm"],
help="Tool call parser to use: 'auto' for internal parser, 'litellm' for LiteLLM's parser. Default: auto",
)
# Custom type for comma-separated reasoning options
def reasoning_choices(value):
if not value:
return []
options = [v.strip().lower() for v in value.split(',')]
valid = {'chat', 'stop', 'inject', 'prompt', 'all', 'twopass', 'mock', 'raw'}
invalid = [o for o in options if o not in valid]
if invalid:
raise argparse.ArgumentTypeError(f"Invalid choices: {invalid}. Valid options: {valid}")
# Expand 'all' to all options
if 'all' in options:
options = ['chat', 'inject', 'prompt', 'mock', 'raw', 'twopass']
return options
parser.add_argument(
"--force-reasoning",
type=reasoning_choices,
default=None,
help="Force reasoning. Options: 'chat' (API), 'stop' (tokens), 'inject' (sys prompt), 'prompt' (seeding), 'twopass' (2 calls), 'mock' (fake stats), 'raw' (raw completion), 'all' (all options). Combine: --force-reasoning chat,inject.",
)
parser.add_argument(
"--grammar-guided-gen",
"--ggg",
action="store_true",
default=False,
help="Enable grammar-guided generation to reduce model hallucinations when using tools. Uses GBNF grammar for Vulkan backend and outlines for CUDA backend.",
)
parser.add_argument(
"--tools-closer-prompt",
action="store_true",
default=False,
help="Enable prompt distillation: place tool definitions right before the user's latest request instead of in the system prompt. This can improve tool call accuracy.",
)
return parser.parse_args()
"""Main entry point for codai server."""
import sys
import os
# Import configuration from codai modules
from codai.cli import parse_args
def main():
"""Main entry point for the codai server."""
# Suppress unraisable exceptions from LlamaModel.__del__
original_unraisablehook = sys.unraisablehook
def suppress_llama_del_errors(unraisable):
if isinstance(unraisable.exc_value, AttributeError) and 'LlamaModel' in repr(unraisable.object) and 'sampler' in str(unraisable.exc_value):
return # Ignore this specific error
original_unraisablehook(unraisable)
sys.unraisablehook = suppress_llama_del_errors
# Optional: set process name if procname is available
try:
import procname
procname.setprocname("codai")
except ImportError:
pass
args = parse_args()
# Import globals from codai modules
from codai.api import app
from codai.api.app import set_global_args
from codai.models.manager import ModelManager, MultiModelManager
from codai.backends import detect_available_backends
from codai.models.cache import (
get_all_cache_dirs,
get_cached_model_path,
get_model_cache_dir,
download_model,
)
# Store args globally for access in endpoints
set_global_args(args)
# Import global setters from text module
from codai.api.text import (
set_global_debug,
set_global_system_prompt,
set_global_tools_closer_prompt,
)
from codai.api.app import set_load_mode
# Set global variables
global global_system_prompt, global_tools_closer_prompt, global_debug, global_dump, global_file_path, grammar_guided_gen
# Set global grammar-guided-gen flag
from codai.models.grammar import set_grammar_guided_gen
grammar_guided_gen = args.grammar_guided_gen
if grammar_guided_gen:
print("Grammar-guided generation enabled (--grammar-guided-gen)")
# Set global system prompt from --system-prompt flag
global_system_prompt = args.system_prompt
set_global_system_prompt(global_system_prompt)
# Set global tools-closer-prompt flag
global_tools_closer_prompt = args.tools_closer_prompt
set_global_tools_closer_prompt(global_tools_closer_prompt)
if global_tools_closer_prompt:
print("Tools closer prompt enabled (--tools-closer-prompt)")
# Set global debug flag
global_debug = args.debug
set_global_debug(global_debug)
# Set global dump flag (enables debug as well for litellm output)
global_dump = args.dump
if global_dump:
global_debug = True
set_global_debug(True)
# Set global file path for storing generated files
global_file_path = args.file_path
from codai.api.app import set_global_file_path
set_global_file_path(global_file_path)
if global_debug:
# Print the full command line that was used to invoke codai
import shlex
cmd_line = ' '.join(shlex.quote(arg) for arg in sys.argv)
print(f"\n{'='*80}")
print(f"=== COMMAND LINE: {cmd_line}")
print(f"{'='*80}\n")
print("DEBUG MODE ENABLED - Full requests and replies will be dumped to stdout")
# Handle --vulkan-list-devices
if args.vulkan_list_devices:
print("\nListing Vulkan devices...")
try:
import subprocess
result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True)
if result.returncode == 0:
print(result.stdout)
else:
print("Could not run vulkaninfo. Make sure vulkan-tools is installed.")
except Exception as e:
print(f"Error listing devices: {e}")
sys.exit(0)
# Handle --list-cached-models
if args.list_cached_models:
print("\n=== Listing Cached Models ===")
caches = get_all_cache_dirs()
if not caches:
print("No model cache directories found.")
sys.exit(0)
all_files = []
for cache_name, cache_dir in caches.items():
print(f"\n--- {cache_name.upper()} Cache ({cache_dir}) ---")
if not os.path.exists(cache_dir):
print(f" (directory does not exist)")
continue
files = os.listdir(cache_dir)
if not files:
print(f" No cached files.")
continue
# For diffusers and huggingface, show directory structure
if cache_name in ('diffusers', 'huggingface'):
for root, dirs, files in os.walk(cache_dir):
for f in files:
filepath = os.path.join(root, f)
rel_path = os.path.relpath(filepath, cache_dir)
size = os.path.getsize(filepath)
all_files.append((cache_name, rel_path, size))
else:
for f in files:
filepath = os.path.join(cache_dir, f)
if os.path.isfile(filepath):
size = os.path.getsize(filepath)
all_files.append((cache_name, f, size))
if not all_files:
print("\nNo cached models found.")
sys.exit(0)
# Calculate totals
total_size = sum(size for _, _, size in all_files)
print(f"\n=== Summary ===")
print(f"Total: {len(all_files)} files, {total_size / (1024*1024*1024):.2f} GB")
print("\nCache locations:")
for cache_name, cache_dir in caches.items():
print(f" {cache_name}: {cache_dir}")
sys.exit(0)
# Handle --remove-all-models
if args.remove_all_models:
print("\n=== Removing All Cached Models ===")
import shutil
caches = get_all_cache_dirs()
if not caches:
print("No cache directories found.")
sys.exit(0)
total_removed = 0
for cache_name, cache_dir in caches.items():
if not os.path.exists(cache_dir):
continue
files = os.listdir(cache_dir)
if not files:
continue
print(f"\nRemoving from {cache_name} cache ({cache_dir})...")
print(f" Found {len(files)} file(s). Deleting...")
# For diffusers, remove entire directory tree
if cache_name == 'diffusers':
for item in os.listdir(cache_dir):
item_path = os.path.join(cache_dir, item)
if os.path.isdir(item_path):
shutil.rmtree(item_path)
else:
os.remove(item_path)
print(f" Deleted: {item}")
total_removed += 1
else:
for f in files:
filepath = os.path.join(cache_dir, f)
os.remove(filepath)
print(f" Deleted: {f}")
total_removed += 1
print(f"\n=== Removed {total_removed} item(s) from all caches ===")
sys.exit(0)
# Handle --remove-model
if args.remove_model:
print(f"\n=== Removing Cached Model Matching: {args.remove_model} ===")
import shutil
caches = get_all_cache_dirs()
if not caches:
print("No cache directories found.")
sys.exit(0)
all_matching = []
for cache_name, cache_dir in caches.items():
if not os.path.exists(cache_dir):
continue
# For diffusers and huggingface, search recursively
if cache_name in ('diffusers', 'huggingface'):
for root, dirs, files in os.walk(cache_dir):
for f in files:
if args.remove_model.lower() in f.lower():
filepath = os.path.join(root, f)
rel_path = os.path.relpath(filepath, cache_dir)
size = os.path.getsize(filepath)
all_matching.append((cache_name, rel_path, filepath, size))
else:
files = os.listdir(cache_dir)
for f in files:
if args.remove_model.lower() in f.lower():
filepath = os.path.join(cache_dir, f)
if os.path.isfile(filepath):
size = os.path.getsize(filepath)
all_matching.append((cache_name, f, filepath, size))
if not all_matching:
print(f"No cached models found matching: {args.remove_model}")
print(f"\nUse --list-cached-models to see available models.")
sys.exit(0)
print(f"\nFound {len(all_matching)} matching file(s):")
for cache_name, filename, filepath, size in all_matching:
print(f" [{cache_name}] {filename} ({size / (1024*1024):.1f} MB)")
# Confirm before deleting
print(f"\nDeleting {len(all_matching)} file(s)...")
for cache_name, filename, filepath, size in all_matching:
try:
os.remove(filepath)
print(f" Deleted: [{cache_name}] {filename}")
except Exception as e:
print(f" Failed to delete {filename}: {e}")
print(f"\nRemoved {len(all_matching)} cached model file(s).")
sys.exit(0)
# Get model names from args - support multiple models
model_names = args.model if args.model else []
# Helper function to get config value by index with fallback
def get_ctx_by_index(ctx_list, index, default):
"""Get context value by model index, with fallback to default."""
if ctx_list and index < len(ctx_list):
return ctx_list[index]
return default
# Validate: must have at least one model specified
audio_models = args.audio_model if args.audio_model else []
image_models = args.image_model if args.image_model else []
vision_models = args.vision_model if args.vision_model else []
if not model_names and not audio_models and not image_models and not vision_models and args.tts_model is None:
print("Error: At least one of --model, --audio-model, --image-model, --vision-model, or --tts-model must be specified.")
print("")
print("For NVIDIA backend (HuggingFace models):")
print(" - microsoft/DialoGPT-medium")
print(" - meta-llama/Llama-2-7b-chat-hf (requires auth)")
print(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print(" - Use multiple --model flags for multiple models")
print("")
print("For Vulkan backend (GGUF models):")
print(" - Local path: ./phi-3-mini-4k-instruct-q4_k_m.gguf")
print(" - Or a HuggingFace model ID: TheBloke/Mistral-7B-Instruct-v0.2-GGUF")
print(" - Use multiple --model flags for multiple models")
print("")
sys.exit(1)
# Determine load mode
load_mode = None
if args.loadall:
load_mode = "loadall"
elif args.loadswap:
load_mode = "loadswap"
elif args.nopreload:
load_mode = "nopreload"
if load_mode:
set_load_mode(load_mode)
# Initialize model manager
print("\n=== Initializing Model Manager ===")
# Detect available backends
available_backends = detect_available_backends()
print(f"Available backends: {available_backends}")
# Determine which backend to use
backend = args.backend
if backend == "auto":
if "nvidia" in available_backends:
backend = "nvidia"
elif "vulkan" in available_backends:
backend = "vulkan"
elif "opencl" in available_backends:
backend = "opencl"
else:
print("Error: No supported backend detected (NVIDIA CUDA, AMD Vulkan, or OpenCL)")
sys.exit(1)
print(f"Using backend: {backend}")
# Initialize model manager based on backend
model_manager = ModelManager(backend=backend)
# Initialize multi-model manager
multi_model_manager = MultiModelManager(model_manager)
# Store references globally for API endpoints
from codai.api import app as fastapi_app
fastapi_app.state.model_manager = model_manager
fastapi_app.state.multi_model_manager = multi_model_manager
# Load main text model(s)
if model_names:
print(f"\nLoading main text model(s): {model_names}")
# Register models with multi_model_manager
for idx, model_name in enumerate(model_names):
multi_model_manager.set_model(model_name, {
'ctx': get_ctx_by_index(args.n_ctx, idx, 0),
})
# Load first model
try:
mm = multi_model_manager.get_model_for_request(model_names[0])
if mm is not None:
print(f"Model loaded successfully: {model_names[0]}")
else:
print(f"Warning: Model {model_names[0]} not loaded (will load on first request)")
except Exception as e:
print(f"Warning: Failed to load model: {e}")
print(f"Model will load on first request")
# Set up audio model if specified
if audio_models:
print(f"\nAudio transcription model(s): {audio_models}")
for idx, audio_m in enumerate(audio_models):
multi_model_manager.set_audio_model(audio_m, {
'ctx': get_ctx_by_index(args.audio_ctx, idx, 0),
'offload': args.audio_offload,
})
# Set up image model if specified
if image_models:
print(f"\nImage generation model(s): {image_models}")
for idx, img_m in enumerate(image_models):
multi_model_manager.set_image_model(img_m, {
'ctx': get_ctx_by_index(args.image_ctx, idx, 0),
'offload': args.image_offload,
'llm_path': args.llm_path,
'vae_path': args.vae_path,
'sample_method': args.image_sample_method,
'steps': args.image_steps,
'width': args.image_width,
'height': args.image_height,
'cfg_scale': args.image_cfg_scale,
})
# Set up vision model if specified
if vision_models:
print(f"\nVision model(s): {vision_models}")
for idx, vision_m in enumerate(vision_models):
multi_model_manager.set_vision_model(vision_m, {
'ctx': get_ctx_by_index(args.n_ctx, idx, 0),
'offload': args.image_offload,
})
# Set up TTS model if specified
if args.tts_model:
print(f"\nText-to-speech model: {args.tts_model}")
multi_model_manager.set_tts_model(args.tts_model, {})
# Register model aliases if specified
if args.model_aliases:
print(f"\nRegistering model aliases:")
for alias, model in args.model_aliases:
multi_model_manager.set_model_alias(alias, model)
print(f" {alias} -> {model}")
# Start the server
import uvicorn
print(f"\nStarting server on http://{args.host}:{args.port}")
print(f"API documentation available at http://{args.host}:{args.port}/docs")
if model_manager.backend is not None:
actual_backend = model_manager.backend_type
if hasattr(model_manager.backend, 'force_cuda') and model_manager.backend.force_cuda:
actual_backend = "cuda (via llama-cpp-python)"
print(f"Using backend: {actual_backend}")
# Print available models
models = multi_model_manager.list_models()
print(f"Available models: {[m.id for m in models]}")
# Run server with or without HTTPS
if args.https:
import ssl
ssl_keyfile = None
ssl_certfile = None
if args.privkey and args.pubkey:
ssl_keyfile = args.privkey
ssl_certfile = args.pubkey
print(f"Using HTTPS with custom certificates: {args.pubkey}")
else:
print("Generating self-signed HTTPS certificate...")
import subprocess
try:
cert_path = "./cert.pem"
key_path = "./key.pem"
subprocess.run([
"openssl", "req", "-x509", "-newkey", "rsa:4096",
"-keyout", key_path, "-out", cert_path,
"-days", "365", "-nodes",
"-subj", "/CN=localhost"
], check=True, capture_output=True)
ssl_keyfile = key_path
ssl_certfile = cert_path
print(f"Generated self-signed certificate: {cert_path}")
except Exception as e:
print(f"Warning: Could not generate certificate: {e}")
print("Falling back to HTTP...")
uvicorn.run(app, host=args.host, port=args.port)
return
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(ssl_certfile, ssl_keyfile)
uvicorn.run(app, host=args.host, port=args.port, ssl=ssl_context)
else:
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()
# codai.models - Model parsing and templates
from .manager import (
ModelManager,
WhisperServerManager,
MultiModelManager,
model_manager,
multi_model_manager,
)
from .parser import (
ModelParserDispatcher,
BaseParser,
......@@ -16,12 +23,20 @@ from .parser import (
ToolCallParser,
ModelParserAdapter,
filter_repetition,
filter_malformed_content,
cleanup_control_tokens,
validate_json_complete,
format_tools_for_prompt,
)
from .templates import AgenticTemplateManager
__all__ = [
'ModelManager',
'WhisperServerManager',
'MultiModelManager',
'model_manager',
'multi_model_manager',
'ModelParserDispatcher',
'BaseParser',
'QwenParser',
......@@ -39,5 +54,8 @@ __all__ = [
'ModelParserAdapter',
'AgenticTemplateManager',
'filter_repetition',
'filter_malformed_content',
'cleanup_control_tokens',
'validate_json_complete',
'format_tools_for_prompt',
]
......@@ -660,3 +660,12 @@ class MultiModelManager:
models.append(ModelInfo(id=alias))
return models
# =============================================================================
# Singleton Instances
# =============================================================================
# Global singleton instances for convenience
model_manager = ModelManager()
multi_model_manager = MultiModelManager()
......@@ -1299,6 +1299,90 @@ def filter_malformed_content(text: str) -> str:
return filtered
# =============================================================================
# Control Token Cleanup
# =============================================================================
# Pre-compiled pattern for cleanup_control_tokens newline cleanup
_RE_TRIPLE_NEWLINE = re.compile(r'\n{3,}')
# Control tokens to strip - defined once at module level
_CONTROL_TOKENS = [
'<|im_end|>',
'<|im_start|>',
'<|endoftext|>',
'<|end_of_text|>',
'<|eot_id|>',
'<|eom_id|>',
'<|assistant|>',
'<|model|>',
'<|python|>',
'<|javascript|>',
'<|html|>',
'\n\nassistant',
'\nAssistant',
'ASSISTANT',
'Assistant',
'assistant',
]
# Build expanded token set for O(1) lookup (includes \n and space prefixed variants)
_CONTROL_TOKENS_START = set()
_CONTROL_TOKENS_END = set()
for _t in _CONTROL_TOKENS:
_CONTROL_TOKENS_START.update([_t, '\n' + _t, ' ' + _t])
_CONTROL_TOKENS_END.update([_t, '\n' + _t, ' ' + _t])
# Sort by length descending so longer tokens match first (avoids partial matches)
_CONTROL_TOKENS_START_SORTED = sorted(_CONTROL_TOKENS_START, key=len, reverse=True)
_CONTROL_TOKENS_END_SORTED = sorted(_CONTROL_TOKENS_END, key=len, reverse=True)
def cleanup_control_tokens(text: str) -> str:
"""
Clean up leading/trailing control tokens from model output.
Removes tokens like <|im_end|>, <|im_start|>, 'assistant', etc. that might
appear at the start or end of the response after reasoning extraction.
Uses module-level sorted token lists for efficient matching.
Tokens are sorted by length (longest first) to prevent partial matches
like 'assistant' matching before '<|assistant|>'.
"""
if not text:
return text
cleaned = text
# Strip from start - keep trying until no more tokens at start
# Max iterations bounded by text length / min token length
changed = True
while changed:
changed = False
for token in _CONTROL_TOKENS_START_SORTED:
if cleaned.startswith(token):
cleaned = cleaned[len(token):]
changed = True
break # Restart from longest token after each removal
# Strip from end - keep trying until no more tokens at end
changed = True
while changed:
changed = False
for token in _CONTROL_TOKENS_END_SORTED:
if cleaned.endswith(token):
cleaned = cleaned[:-len(token)]
changed = True
break # Restart from longest token after each removal
# Clean up any resulting triple+ newlines (pre-compiled pattern)
cleaned = _RE_TRIPLE_NEWLINE.sub('\n\n', cleaned)
# Strip leading/trailing whitespace
cleaned = cleaned.strip()
return cleaned
def filter_repetition(text: str, min_repeat_count: int = 3, ngram_sizes: tuple = (2, 3)) -> str:
"""
Detect and remove n-gram repetition from text.
......@@ -1478,8 +1562,18 @@ def validate_json_complete(json_str: str) -> bool:
# Tool Formatting
# =============================================================================
def format_tools_for_prompt(tools, messages):
"""Format tools into the system message or add a tool description."""
def format_tools_for_prompt(tools, messages, tools_closer_prompt: bool = False):
"""Format tools into the system message or add a tool description.
Args:
tools: List of Tool objects
messages: List of ChatMessage objects
tools_closer_prompt: If True, place tools right before the user's latest message
instead of in the system prompt (prompt distillation)
Returns:
Modified list of ChatMessage objects
"""
import json
if not tools:
......@@ -1512,19 +1606,47 @@ def format_tools_for_prompt(tools, messages):
# Add or prepend to system message
new_messages = list(messages)
system_found = False
for i, msg in enumerate(new_messages):
if msg.role == "system":
new_messages[i] = ChatMessage(
role="system",
content=f"{tools_text}\n\n{msg.content or ''}"
)
system_found = True
break
if not system_found:
new_messages.insert(0, ChatMessage(role="system", content=tools_text))
if tools_closer_prompt:
# Prompt distillation: insert tools right before the LAST user message
# Find the last user message and insert tools before it
last_user_idx = None
for i, msg in enumerate(new_messages):
if msg.role == "user":
last_user_idx = i
if last_user_idx is not None:
# Insert a tool context message before the last user message
tools_message = ChatMessage(role="system", content=f"Available tools:\n{tools_text}")
new_messages.insert(last_user_idx, tools_message)
else:
# No user message found, fall back to system message
system_found = False
for i, msg in enumerate(new_messages):
if msg.role == "system":
new_messages[i] = ChatMessage(
role="system",
content=f"{tools_text}\n\n{msg.content or ''}"
)
system_found = True
break
if not system_found:
new_messages.insert(0, ChatMessage(role="system", content=tools_text))
else:
# Traditional behavior: prepend tools to system message
system_found = False
for i, msg in enumerate(new_messages):
if msg.role == "system":
new_messages[i] = ChatMessage(
role="system",
content=f"{tools_text}\n\n{msg.content or ''}"
)
system_found = True
break
if not system_found:
new_messages.insert(0, ChatMessage(role="system", content=tools_text))
return new_messages
......
import time
import uuid
# Try to import litellm for response formatting
# Fall back to plain dicts if litellm is not available or doesn't export these
try:
from litellm import ModelResponse, ChatCompletionChunk
LITELLM_AVAILABLE = True
except ImportError:
LITELLM_AVAILABLE = False
ModelResponse = None
ChatCompletionChunk = None
class OpenAIFormatter:
"""Formatter for standardizing chat completion responses in OpenAI format.
This class provides final sanitization of responses before sending them
to clients. It processes the output of the internal parser and formats
them into proper OpenAI-compatible responses.
"""
def __init__(self, model_name: str):
self.model_name = model_name
self.id = f"chatcmpl-{uuid.uuid4()}"
def format_full(self, text: str, prompt_tokens: int, completion_tokens: int, tool_calls=None) -> dict:
"""Format a standard (non-streaming) response.
Args:
text: The generated text content
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
tool_calls: Optional list of tool calls to include
Returns:
Dictionary representation of the response
"""
message = {
"role": "assistant",
"content": text if not tool_calls else None,
}
if tool_calls:
message["tool_calls"] = tool_calls
choice = {
"index": 0,
"message": message,
"finish_reason": "tool_calls" if tool_calls else "stop",
}
return {
"id": self.id,
"object": "chat.completion",
"created": int(time.time()),
"model": self.model_name,
"choices": [choice],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
"provider": {
"provider_name": "coderai",
"provider_id": "coderai",
},
}
def format_chunk(self, delta_text: str, is_final: bool = False, usage: dict = None) -> dict:
"""Format a streaming chunk response.
Args:
delta_text: The incremental text content for this chunk
is_final: Whether this is the final chunk
usage: Optional usage information (typically only sent on final chunk)
Returns:
Dictionary representation of the chunk
"""
delta = {
"content": delta_text,
"role": "assistant",
}
choice = {
"index": 0,
"delta": delta,
"finish_reason": "stop" if is_final else None,
}
chunk = {
"id": self.id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model_name,
"choices": [choice],
}
if usage and is_final:
chunk["usage"] = usage
return chunk
def format_final_chunk(self, usage: dict = None) -> dict:
"""Format the final streaming chunk with usage information.
Args:
usage: Usage statistics dictionary with prompt_tokens, completion_tokens, total_tokens
Returns:
Dictionary representation of the final chunk
"""
delta = {
"content": None,
"role": "assistant",
}
choice = {
"index": 0,
"delta": delta,
"finish_reason": "stop",
}
chunk = {
"id": self.id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model_name,
"choices": [choice],
}
if usage:
chunk["usage"] = usage
return chunk
def format_litellm_full(self, text: str, prompt_tokens: int, completion_tokens: int, tool_calls=None) -> dict:
"""Format using litellm's ModelResponse if available.
Args:
text: The generated text content
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
tool_calls: Optional list of tool calls to include
Returns:
Dictionary representation of ModelResponse
"""
if not LITELLM_AVAILABLE or ModelResponse is None:
return self.format_full(text, prompt_tokens, completion_tokens, tool_calls)
try:
from litellm import Choices, Message, Usage
return ModelResponse(
id=self.id,
model=self.model_name,
object="chat.completion",
created=int(time.time()),
choices=[Choices(
finish_reason="tool_calls" if tool_calls else "stop",
index=0,
message=Message(content=text if not tool_calls else None, role="assistant", tool_calls=tool_calls)
)],
usage=Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
).model_dump()
except Exception:
# Fall back to plain dict if litellm fails
return self.format_full(text, prompt_tokens, completion_tokens, tool_calls)
def format_litellm_chunk(self, delta_text: str, is_final: bool = False, usage: dict = None) -> dict:
"""Format streaming chunk using litellm's ChatCompletionChunk if available.
Args:
delta_text: The incremental text content for this chunk
is_final: Whether this is the final chunk
usage: Optional usage information (typically only sent on final chunk)
Returns:
Dictionary representation of ChatCompletionChunk
"""
if not LITELLM_AVAILABLE or ChatCompletionChunk is None:
return self.format_chunk(delta_text, is_final, usage)
try:
from litellm import StreamingChoices, Delta, Usage
return ChatCompletionChunk(
id=self.id,
model=self.model_name,
object="chat.completion.chunk",
created=int(time.time()),
choices=[StreamingChoices(
finish_reason="stop" if is_final else None,
index=0,
delta=Delta(content=delta_text, role="assistant")
)],
usage=Usage(**usage) if usage else None
).model_dump()
except Exception:
# Fall back to plain dict if litellm fails
return self.format_chunk(delta_text, is_final, usage)
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment