#!/usr/bin/env python3
"""
OpenAI-compatible API server for HuggingFace models (NVIDIA) and GGUF models (Vulkan).
Supports CUDA (NVIDIA) and Vulkan (AMD) GPU backends, memory-aware model loading,
streaming, and tool calling.
"""

import argparse
import asyncio
import hashlib
import json
import os
import pathlib
import re
import sys
import time
import uuid
import warnings
import requests
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import psutil
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, validator, field_validator, ConfigDict
from pydantic_core import PydanticCustomError
from threading import Thread
# Per-model semaphores for request concurrency control
model_semaphores: dict = {}
load_mode = {"mode": "ondemand"}  # Track load mode globally
queue_flags = {"model_1": False, "image_1": False, "audio_1": False, "tts_1": False}  # Track --X-1 flags
# =============================================================================
# Model Cache Directory
# =============================================================================

def get_model_cache_dir() -> str:
    """Get or create the model cache directory."""
    # Use XDG_CACHE_HOME if set, otherwise fall back to ~/.cache; models go in <cache>/coderai/models
    cache_home = os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
    cache_dir = os.path.join(cache_home, 'coderai', 'models')
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
    return cache_dir

def get_cached_model_path(url: str) -> Optional[str]:
    """Check if a model URL is already cached. Returns path if cached, None otherwise."""
    cache_dir = get_model_cache_dir()
    # Create a safe filename from the URL using SHA-256 hash
    url_hash = hashlib.sha256(url.encode()).hexdigest()
    url_path = url.split('?')[0]  # Remove query params
    ext = os.path.splitext(url_path)[1] or '.gguf'
    cached_filename = f"{url_hash}{ext}"
    cached_path = os.path.join(cache_dir, cached_filename)
    
    if os.path.exists(cached_path):
        print(f"Using cached model: {cached_path}")
        return cached_path
    return None
def is_huggingface_model_id(path: str) -> bool:
    """Check if the path is a Hugging Face model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507-Q3_K_S')."""
    # Must contain / but not be a URL
    return '/' in path and not path.startswith('http://') and not path.startswith('https://')
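# Examples (illustrative): "org/some-model" -> True, "https://example.com/model.gguf" -> False.
# Note that local relative paths containing '/' also match this heuristic.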
def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str = '.gguf') -> Optional[str]:
    """Download a model from Hugging Face by model ID. Returns cached path or None on failure."""
    try:
        from huggingface_hub import hf_hub_download, list_repo_files
        
        print(f"Downloading from Hugging Face: {model_id}")
        files = list_repo_files(model_id)
        
        # Filter files by extension pattern
        matching_files = [f for f in files if f.endswith(file_pattern)]
        if not matching_files:
            print(f"No {file_pattern} files found in {model_id}")
            return None
        
        # Download the first matching file
        filename = matching_files[0]
        print(f"Downloading: {filename}")
        model_path = hf_hub_download(repo_id=model_id, filename=filename, cache_dir=cache_dir)
        print(f"Downloaded model to: {model_path}")
        return model_path
    except Exception as e:
        print(f"Error downloading from Hugging Face: {e}")
        return None
def download_model(url: str, cache_dir: str) -> str:
    """Download a model from URL with progress reporting. Returns cached path."""
    
    url_path = url.split('?')[0]
    filename = os.path.basename(url_path)
    
    # Determine file extension
    if 'gguf' in url.lower():
        ext = '.gguf'
    elif 'bin' in url.lower():
        ext = '.bin'
    elif 'ggml' in url.lower():
        ext = '.ggml'
    else:
        ext = '.bin'
    
    if not filename.endswith(ext):
        filename = f"model{ext}"
    
    # Create safe filename in cache
    url_hash = hashlib.sha256(url.encode()).hexdigest()
    cached_filename = f"{url_hash}_{filename}"
    model_path = os.path.join(cache_dir, cached_filename)
    
    # Check if already cached
    if os.path.exists(model_path):
        print(f"Using cached model: {model_path}")
        return model_path
    
    # Download
    print(f"Downloading model: {url}")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    
    total_size = int(response.headers.get('content-length', 0))
    total_mb = total_size / (1024 * 1024) if total_size > 0 else 0
    
    downloaded = 0
    start_time = time.time()
    
    with open(model_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192*1024):
            if chunk:
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = (downloaded / total_size) * 100
                    elapsed = time.time() - start_time
                    speed = downloaded / elapsed if elapsed > 0 else 0
                    speed_mb = speed / (1024 * 1024)
                    dl_mb = downloaded / (1024 * 1024)
                    print(f"Downloaded: {percent:.1f}% ({dl_mb:.1f}/{total_mb:.1f} MB) at {speed_mb:.1f} MB/s", end='\r')
    
    print()  # New line after progress
    print(f"Downloaded and cached to: {model_path}")
    if total_mb > 0:
        print(f"File size: {total_mb:.1f} MB")
    
    return model_path
# =============================================================================
# Backend Detection and Imports
# =============================================================================

def detect_available_backends():
    """Detect which backends are available."""
    backends = {'cpu': True}
    
    # Check for PyTorch/CUDA
    try:
        import torch
        if torch.cuda.is_available():
            backends['nvidia'] = True
    except ImportError:
        pass
    
    # Check for llama-cpp-python (Vulkan)
    try:
        import llama_cpp
        backends['vulkan'] = True
    except ImportError:
        pass
    
    return backends
# =============================================================================
# Flash Attention Detection (for NVIDIA backend)
# =============================================================================

def check_flash_attn_availability() -> bool:
    """Check if flash-attn is installed and available."""
    try:
        import flash_attn
        return True
    except ImportError:
        return False
# =============================================================================
# Pydantic Models for API
# =============================================================================

class ToolFunction(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict] = None
class Tool(BaseModel):
    type: str = "function"
    function: ToolFunction
class ChatMessage(BaseModel):
    role: str
    content: Optional[Union[str, List[Dict]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Dict]] = None
    tool_call_id: Optional[str] = None
    
    @field_validator('content', mode='before')
    @classmethod
    def convert_content_array_to_string(cls, v):
        """Convert multipart content array to string for compatibility."""
        if v is None:
            return None
        if isinstance(v, str):
            return v
        if isinstance(v, list):
            # Handle multipart content array format (e.g., from KiloCode)
            # Format: [{"type": "text", "text": "..."}, {"type": "text", "text": "..."}]
            parts = []
            for item in v:
                if isinstance(item, dict):
                    if item.get('type') == 'text' and 'text' in item:
                        parts.append(item['text'])
                    else:
                        # Handle other content types (image_url, etc.) by converting to placeholder
                        parts.append(f"[{item.get('type', 'unknown')} content]")
                else:
                    parts.append(str(item))
            return '\n'.join(parts)
        return str(v)
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = "auto"
    # Extra fields that clients may send but we ignore
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    response_format: Optional[Dict] = None
    user: Optional[str] = None
    
    model_config = ConfigDict(extra="allow")  # Allow extra fields to prevent 422 errors
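# Illustrative request body accepted by ChatCompletionRequest (example values, not from the source):
# {"model": "some-model", "messages": [{"role": "user", "content": "Hello"}],
#  "stream": true, "max_tokens": 256}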
class CompletionRequest(BaseModel):
    model: str
    prompt: Union[str, List[str]]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    # Extra fields that clients may send but we ignore
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = None
    user: Optional[str] = None
    
    model_config = ConfigDict(extra="allow")  # Allow extra fields to prevent 422 errors
class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "huggingface"
class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelInfo]
# =============================================================================
# Audio Transcription Models
# =============================================================================

class TranscriptionRequest(BaseModel):
    model: str
    file: Optional[bytes] = None
    file_path: Optional[str] = None
    language: Optional[str] = None
    prompt: Optional[str] = None
    response_format: Optional[str] = "json"
    temperature: Optional[float] = 0.0
    timestamp_granularities: Optional[List[str]] = None
    
    model_config = ConfigDict(extra="allow")
class TranscriptionResponse(BaseModel):
    text: str
    model_config = ConfigDict(extra="allow")
# =============================================================================
# Image Generation Models
# =============================================================================

class ImageGenerationRequest(BaseModel):
    model: str
    prompt: str
    n: int = 1
    size: Optional[str] = "1024x1024"
    quality: Optional[str] = "standard"
    style: Optional[str] = None
    response_format: Optional[str] = "url"
    seed: Optional[int] = None
    user: Optional[str] = None
    
    model_config = ConfigDict(extra="allow")
class ImageGenerationResponse(BaseModel):
    created: int
    data: List[Dict]
    model_config = ConfigDict(extra="allow")
# =============================================================================
# Content Filtering Utility
# =============================================================================

def filter_malformed_content(text: str) -> str:
    """Filter out malformed SEARCH/REPLACE blocks that the model might output as content."""
    if not text:
        return text
    
    # Remove diff-like blocks that shouldn't be in the output
    filtered = text
    
    # Remove git-style diff markers and SEARCH/REPLACE patterns
    filtered = re.sub(r'<<<<<<<\s+SEARCH.*?=======', '', filtered, flags=re.DOTALL)
    filtered = re.sub(r'=======.*?>>>>>>>\s+REPLACE', '', filtered, flags=re.DOTALL)
    filtered = re.sub(r'>>>>>>>\s+REPLACE', '', filtered)
    
    # Also remove common malformed patterns seen in outputs
    filtered = re.sub(r'<<<<<<<\s+SEARCH\s*:start_line:\d+[^<]*', '', filtered, flags=re.DOTALL)
    filtered = re.sub(r'<button>Stop Generation</button>', '', filtered)
    filtered = re.sub(r'\<\|assistant\|\>', '', filtered)
    filtered = re.sub(r'\</\|assistant\|\>', '', filtered)
    
    # Clean up excessive newlines left from removal
    filtered = re.sub(r'\n{3,}', '\n\n', filtered)
    
    # Don't strip single newlines or whitespace - they might be valid content
    return filtered
# =============================================================================
# Tool Parsing
# =============================================================================

class ToolCallParser:
    """Parse model outputs to extract tool calls."""
    
    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer
    
    def _parse_nested_xml_tool(self, xml_content: str) -> Optional[Dict]:
        """Parse nested XML tool format like <tool><name>...</name><arguments>...</arguments></tool>."""
        try:
            # Extract name
            name_match = re.search(r'<name>\s*(.*?)\s*</name>', xml_content, re.DOTALL)
            if not name_match:
                return None
            tool_name = name_match.group(1).strip()
            
            # Extract arguments - handle both JSON and nested XML
            args_match = re.search(r'<arguments>\s*(.*?)\s*</arguments>', xml_content, re.DOTALL)
            if not args_match:
                return None
            args_content = args_match.group(1).strip()
            
            # Try to parse as JSON first
            try:
                args_dict = json.loads(args_content)
                return {"name": tool_name, "arguments": args_dict}
            except json.JSONDecodeError:
                # Try to parse as nested XML (e.g., <files><path>...</path></files>)
                try:
                    args_dict = self._xml_to_dict(args_content)
                    return {"name": tool_name, "arguments": args_dict}
                except Exception:
                    # Return raw string as fallback
                    return {"name": tool_name, "arguments": args_content}
        except Exception:
            return None
    
    def _xml_to_dict(self, xml_content: str) -> Dict:
        """Convert simple nested XML to dictionary."""
        result = {}
        # Find all top-level tags
        pattern = r'<(\w+)>\s*(.*?)\s*</\1>'
        matches = re.findall(pattern, xml_content, re.DOTALL)
        
        for tag, content in matches:
            # Check if content has nested tags
            if re.search(r'<\w+>', content):
                # Recursively parse nested content
                try:
                    result[tag] = self._xml_to_dict(content)
                except Exception:
                    # If recursive parsing fails, check for array-like content
                    items = re.findall(r'<(\w+)>\s*(.*?)\s*</\1>', content, re.DOTALL)
                    if items and all(item[0] == items[0][0] for item in items):
                        # Array of items with same tag
                        result[tag] = []
                        for _, item_content in items:
                            if re.search(r'<\w+>', item_content):
                                result[tag].append(self._xml_to_dict(item_content))
                            else:
                                result[tag].append(item_content)
                    else:
                        result[tag] = content
            else:
                result[tag] = content
        
        return result if result else xml_content
    
    def _filter_malformed_content(self, text: str) -> str:
        """Filter out malformed SEARCH/REPLACE blocks - delegates to standalone function."""
        return filter_malformed_content(text)
    
    def extract_tool_calls(self, text: str, available_tools: List[Tool]) -> Optional[List[Dict]]:
        """Extract tool calls from model output."""
        # First filter out malformed content
        text = self._filter_malformed_content(text)
        
        tool_calls = []
        seen_signatures = set()  # Track seen tool calls to avoid duplicates
        
        # Look for function calls in various formats
        # Format 1: <tool> or <function> tags with JSON content
        tool_pattern = r'<(?:tool|function)>(.*?)</(?:tool|function)>'
        tool_matches = re.findall(tool_pattern, text, re.DOTALL)
        
        for match in tool_matches:
            # Try JSON format first
            try:
                tool_data = json.loads(match.strip())
                if 'name' in tool_data and 'arguments' in tool_data:
                    # Create a signature to detect duplicates
                    sig = (tool_data['name'], json.dumps(tool_data['arguments'], sort_keys=True))
                    if sig not in seen_signatures:
                        seen_signatures.add(sig)
                        tool_calls.append({
                            "id": f"call_{uuid.uuid4().hex[:16]}",
                            "type": "function",
                            "function": {
                                "name": tool_data["name"],
                                "arguments": json.dumps(tool_data["arguments"])
                            }
                        })
                    continue
            except json.JSONDecodeError:
                pass
            
            # Try nested XML format (Format 1b)
            try:
                tool_data = self._parse_nested_xml_tool(match)
                if tool_data and 'name' in tool_data and 'arguments' in tool_data:
                    args_str = json.dumps(tool_data["arguments"]) if isinstance(tool_data["arguments"], dict) else str(tool_data["arguments"])
                    sig = (tool_data['name'], args_str)
                    if sig not in seen_signatures:
                        seen_signatures.add(sig)
                        tool_calls.append({
                            "id": f"call_{uuid.uuid4().hex[:16]}",
                            "type": "function",
                            "function": {
                                "name": tool_data["name"],
                                "arguments": args_str
                            }
                        })
            except Exception:
                pass
        
        # Format 2: JSON with function_call key
        try:
            if "function_call" in text:
                json_match = re.search(r'\{[^}]*"function_call"[^}]*\}', text, re.DOTALL)
                if json_match:
                    data = json.loads(json_match.group())
                    if "function_call" in data:
                        fc = data["function_call"]
                        sig = (fc.get("name", ""), fc.get("arguments", "{}"))
                        if sig not in seen_signatures:
                            seen_signatures.add(sig)
                            tool_calls.append({
                                "id": f"call_{uuid.uuid4().hex[:16]}",
                                "type": "function",
                                "function": {
                                    "name": fc.get("name", ""),
                                    "arguments": fc.get("arguments", "{}")
                                }
                            })
        except (json.JSONDecodeError, AttributeError):
            pass
        
        # Format 3: Tool-specific XML tags wrapping JSON arguments, e.g. <tool_name>{...}</tool_name>
        for tool in available_tools:
            tool_name = tool.function.name
            pattern = rf'<{re.escape(tool_name)}>(.*?)</{re.escape(tool_name)}>'
            matches = re.findall(pattern, text, re.DOTALL)
            for match in matches:
                try:
                    args = json.loads(match.strip())
                    sig = (tool_name, json.dumps(args, sort_keys=True))
                    if sig not in seen_signatures:
                        seen_signatures.add(sig)
                        tool_calls.append({
                            "id": f"call_{uuid.uuid4().hex[:16]}",
                            "type": "function",
                            "function": {
                                "name": tool_name,
                                "arguments": json.dumps(args)
                            }
                        })
                except json.JSONDecodeError:
                    # Try to use the raw text as arguments
                    sig = (tool_name, match.strip())
                    if sig not in seen_signatures:
                        seen_signatures.add(sig)
                        tool_calls.append({
                            "id": f"call_{uuid.uuid4().hex[:16]}",
                            "type": "function",
                            "function": {
                                "name": tool_name,
                                "arguments": match.strip()
                            }
                        })
        
        return tool_calls if tool_calls else None

    def strip_tool_calls_from_content(self, text: str) -> str:
        """Remove tool call format from text after extracting tool calls."""
        if not text:
            return text
        
        # Remove <tool>...</tool> and <function>...</function> blocks (including any JSON
        # payload) using non-greedy matching with re.DOTALL so matches span newlines.
        text = re.sub(r'<tool>.*?</tool>', '', text, flags=re.DOTALL)
        text = re.sub(r'<function>.*?</function>', '', text, flags=re.DOTALL)
        
        # Also remove common tool-name tags like <read_file>...</read_file>, <exec>...</exec>
        for tool_name in ['read', 'write', 'read_file', 'write_file', 'exec', 'process', 'browser',
                          'message', 'web_search', 'web_fetch', 'memory_search', 'memory_get',
                          'sessions_list', 'sessions_send', 'sessions_history', 'sessions_spawn',
                          'session_status', 'subagents', 'agents_list', 'tts', 'canvas', 'nodes']:
            text = re.sub(rf'<{tool_name}>[\s\S]*?</{tool_name}>', '', text)
        
        # Clean up excessive newlines left from removal
        text = re.sub(r'\n{3,}', '\n\n', text)
        
        return text.strip()
def format_tools_for_prompt(tools: List[Tool], messages: List[ChatMessage]) -> List[ChatMessage]:
    """Format tools into the system message or add a tool description."""
    if not tools:
        return messages
    
    tool_descriptions = []
    for tool in tools:
        func = tool.function
        desc = f"Tool: {func.name}"
        if func.description:
            desc += f"\nDescription: {func.description}"
        if func.parameters:
            desc += f"\nParameters: {json.dumps(func.parameters, indent=2)}"
        tool_descriptions.append(desc)
    
    tools_text = "You have access to the following tools:\n\n" + "\n\n".join(tool_descriptions)
    tools_text += "\n\nIMPORTANT: When you need to use a tool, you MUST format your response EXACTLY as:\n"
    tools_text += '<tool>{"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}</tool>'
    tools_text += "\n\nRules:\n"
    tools_text += "1. The content inside <tool> tags must be valid JSON\n"
    tools_text += "2. Do NOT use nested XML tags like <name> or <arguments> - use JSON format only\n"
    tools_text += "3. The 'name' field must match one of the available tool names exactly\n"
    tools_text += "4. The 'arguments' field must be a JSON object with the required parameters\n"
    tools_text += "\nExample:\n"
    tools_text += 'User: Read the file example.txt\n'
    tools_text += 'Assistant: <tool>{"name": "read_file", "arguments": {"files": [{"path": "example.txt"}]}}</tool>'
    
    # Add or prepend to system message
    new_messages = list(messages)
    system_found = False
    
    for i, msg in enumerate(new_messages):
        if msg.role == "system":
            new_messages[i] = ChatMessage(
                role="system",
                content=f"{tools_text}\n\n{msg.content or ''}"
            )
            system_found = True
            break
    
    if not system_found:
        new_messages.insert(0, ChatMessage(role="system", content=tools_text))
    
    return new_messages
# =============================================================================
# Abstract Model Backend
# =============================================================================

class ModelBackend(ABC):
    """Abstract base class for model backends."""
    
    @abstractmethod
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load the model."""
        pass
    
    @abstractmethod
    def generate(self, prompt: str, max_tokens: Optional[int] = None, 
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate text non-streaming."""
        pass
    
    @abstractmethod
    def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                        temperature: float = 0.7, top_p: float = 1.0,
                        stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Generate text in streaming fashion."""
        pass
    
    @abstractmethod
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string."""
        pass
    
    @abstractmethod
    def get_model_name(self) -> str:
        """Return the loaded model name."""
        pass
    
    @abstractmethod
    def cleanup(self) -> None:
        """Cleanup resources."""
        pass
# =============================================================================
# NVIDIA/HuggingFace Backend
# =============================================================================

class NvidiaBackend(ModelBackend):
    """Backend for NVIDIA GPUs using HuggingFace Transformers."""
    
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = None
        self.device = None
        self.use_flash_attn = False
        self.flash_attn_available = False
        
    def check_flash_attn_support(self) -> None:
        """Check and print Flash Attention availability status."""
        self.flash_attn_available = check_flash_attn_availability()
        if self.use_flash_attn:
            if self.flash_attn_available:
                print("Flash Attention 2: Available and enabled")
            else:
                print("Warning: Flash Attention 2 requested but not installed")
                print("Install with: pip install flash-attn --no-build-isolation")
                print("Falling back to standard attention")
                self.use_flash_attn = False
    
    def _detect_device(self) -> str:
        """Auto-detect available GPU or fall back to CPU."""
        import torch
        if torch.cuda.is_available():
            # Check for ROCm (HIP)
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                print(f"ROCm/HIP detected: {torch.version.hip}")
                return "cuda"
            else:
                print(f"CUDA detected: {torch.version.cuda}")
                return "cuda"
        else:
            print("No GPU detected, using CPU")
            return "cpu"
    
    def _get_available_vram(self) -> int:
        """Get available VRAM in bytes. Returns 0 if no GPU available."""
        import torch
        if not torch.cuda.is_available():
            return 0
        
        try:
            total_vram = 0
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                total_vram += props.total_memory
            return total_vram
        except Exception as e:
            print(f"Warning: Could not detect VRAM: {e}")
            return 0
    
    def _estimate_model_size(self, model_name: str) -> Optional[int]:
        """Estimate model size in bytes from config."""
        from transformers import AutoConfig
        try:
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
            
            # Get model parameters from config
            if hasattr(config, 'num_parameters'):
                num_params = config.num_parameters
            elif hasattr(config, 'n_params'):
                num_params = config.n_params
            elif hasattr(config, 'num_hidden_layers') and hasattr(config, 'hidden_size'):
                layers = config.num_hidden_layers
                hidden = config.hidden_size
                vocab_size = getattr(config, 'vocab_size', 50000)
                num_params = (vocab_size * hidden) + (layers * 4 * hidden * hidden)
            else:
                return None
            
            # Assume float16 (2 bytes per parameter)
            return num_params * 2
        except Exception as e:
            print(f"Warning: Could not estimate model size: {e}")
            return None
    
    def _get_gpu_memory_map(self) -> Dict:
        """Get max_memory dict for Accelerate with 93% GPU limit, then CPU, then disk."""
        import torch
        max_memory = {}
        
        # GPU memory: 93% of available VRAM per GPU
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                total_vram = props.total_memory
                # Leave 7% headroom for CUDA overhead
                usable_vram = int(total_vram * 0.93)
                max_memory[i] = usable_vram
                print(f"  GPU {i}: {total_vram / 1e9:.1f}GB total, {usable_vram / 1e9:.1f}GB usable")
        
        # CPU memory: use manual limit or auto-detect
        manual_ram_gb = getattr(self, '_pending_ram_gb', None)
        if manual_ram_gb:
            # Convert GB to bytes
            max_memory['cpu'] = int(manual_ram_gb * 1e9)
            print(f"  CPU: {manual_ram_gb}GB (user specified)")
        else:
            # Auto-detect available system RAM, leave 4GB for system
            import psutil
            available_ram = psutil.virtual_memory().available
            usable_ram = max(0, available_ram - int(4e9))  # Leave 4GB for OS
            max_memory['cpu'] = usable_ram
            print(f"  CPU: {usable_ram / 1e9:.1f}GB (auto-detected, 4GB reserved for system)")
        
        return max_memory
    
    def _try_load_model(self, model_name: str, load_kwargs: dict, device: str) -> Optional[Any]:
        """Try to load the model with the given settings. Returns None on OOM; if the model
        rejects bitsandbytes quantization arguments, retries once without them."""
        import torch
        from transformers import AutoModelForCausalLM
        
        try:
            model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
            if device == "cpu" and load_kwargs.get('device_map') is None:
                model = model.to(device)
            return model
        except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
            error_msg = str(e).lower()
            if "out of memory" in error_msg or "cuda" in error_msg or "oom" in error_msg:
                return None
            raise
        except TypeError as e:
            error_msg = str(e).lower()
            # Check if the error is about unsupported quantization arguments
            if "load_in_4bit" in error_msg or "load_in_8bit" in error_msg or "unexpected keyword argument" in error_msg:
                # Check if we have quantization args that need to be removed
                if 'load_in_4bit' in load_kwargs or 'load_in_8bit' in load_kwargs:
                    print(f"Warning: Model does not support bitsandbytes quantization (load_in_4bit/load_in_8bit)")
                    print("Retrying without quantization...")
                    # Create a copy of load_kwargs without quantization args
                    retry_kwargs = load_kwargs.copy()
                    retry_kwargs.pop('load_in_4bit', None)
                    retry_kwargs.pop('load_in_8bit', None)
                    # Retry loading without quantization
                    try:
                        model = AutoModelForCausalLM.from_pretrained(model_name, **retry_kwargs)
                        if device == "cpu" and retry_kwargs.get('device_map') is None:
                            model = model.to(device)
                        print("Model loaded successfully without quantization")
                        return model
                    except (RuntimeError, torch.cuda.OutOfMemoryError) as e2:
                        error_msg2 = str(e2).lower()
                        if "out of memory" in error_msg2 or "cuda" in error_msg2 or "oom" in error_msg2:
                            return None
                        raise
                    except TypeError:
                        # If it still fails with TypeError, re-raise the original error
                        raise e
            # Re-raise if not a quantization-related error
            raise
    
    def _is_moe_model(self, model_name: str) -> bool:
        """Check if model is a MoE (Mixture of Experts) model which needs more VRAM headroom."""
        moe_indicators = ['moe', 'mixtral', 'qwen3_5_moe', 'qwen3.5_moe', 'expert', 'a3b']
        model_name_lower = model_name.lower()
        return any(indicator in model_name_lower for indicator in moe_indicators)
    
    def _get_vram_percentages_for_strategy(self, strategy: str, is_moe: bool, total_vram_gb: float) -> list:
        """Get VRAM percentage steps based on offload strategy."""
        if strategy == "conservative":
            print(f"  Using conservative offload strategy - minimal VRAM usage for maximum stability")
            if is_moe:
                return [0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
            return [0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0]
        elif strategy == "balanced":
            print(f"  Using balanced offload strategy - good performance with reasonable stability")
            if is_moe:
                return [0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
            return [0.85, 0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0]
        elif strategy == "aggressive":
            print(f"  Using aggressive offload strategy - maximize VRAM usage for performance")
            if is_moe:
                return [0.85, 0.80, 0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
            return [0.95, 0.90, 0.85, 0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0]
        elif strategy == "sequential":
            print(f"  Using sequential offload strategy - fine-grained incremental VRAM reduction")
            # Fine-grained steps with 2% increments for precise memory management
            if is_moe:
                return [0.80, 0.78, 0.76, 0.74, 0.72, 0.70, 0.68, 0.66, 0.64, 0.62, 0.60, 0.55, 0.50, 0.45, 0.40, 0.35, 0.30, 0.25, 0.20, 0.0]
            return [0.93, 0.91, 0.89, 0.87, 0.85, 0.83, 0.81, 0.79, 0.77, 0.75, 0.73, 0.71, 0.69, 0.67, 0.65, 0.60, 0.55, 0.50, 0.45, 0.40, 0.35, 0.30, 0.20, 0.0]
        else:  # auto
            if total_vram_gb < 3:
                print(f"  Detected small GPU ({total_vram_gb:.1f}GB), using aggressive VRAM usage (99% start)")
                return [0.99, 0.95, 0.90, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
            elif total_vram_gb <= 8:
                print(f"  Detected medium GPU ({total_vram_gb:.1f}GB), using high VRAM usage (96% start)")
                return [0.96, 0.90, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
            else:
                if is_moe:
                    print(f"  Detected large GPU ({total_vram_gb:.1f}GB), using MoE-safe VRAM usage (80% start)")
                    return [0.80, 0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
                else:
                    print(f"  Detected large GPU ({total_vram_gb:.1f}GB), using conservative VRAM usage (93% start)")
                    return [0.93, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
    
    def _get_vram_percentages_for_gpu(self, model_name: str = "", strategy: str = "auto", max_gpu_percent: Optional[float] = None) -> list:
        """Get VRAM percentage steps based on GPU memory size, model type, and offload strategy."""
        import torch
        
        if not torch.cuda.is_available():
            return [0.0]  # CPU only
        
        # If max_gpu_percent is specified, use it to create custom percentage steps
        if max_gpu_percent is not None:
            # Clamp to valid range (5-100%)
            max_pct = max(0.05, min(1.0, max_gpu_percent / 100.0))
            print(f"  Using custom max GPU percent: {max_pct*100:.0f}%")
            # Create a descending series from max_pct down to 0
            steps = []
            current = max_pct
            while current > 0.05:
                steps.append(current)
                # Reduce by 5% each step, or smaller steps near the end
                if current > 0.3:
                    current -= 0.05
                elif current > 0.15:
                    current -= 0.03
                else:
                    current -= 0.02
            steps.append(0.0)
            return steps
        
        # Get total VRAM of the first GPU
        total_vram_gb = 0
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            total_vram_gb += props.total_memory / 1e9
        
        # Check if MoE model (needs more headroom for generation)
        is_moe = self._is_moe_model(model_name)
        if is_moe:
            print(f"  Detected MoE model, using extra conservative VRAM limits for generation headroom")
        
        return self._get_vram_percentages_for_strategy(strategy, is_moe, total_vram_gb)
    
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load the model using HuggingFace Transformers with automatic OOM handling."""
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        
        offload_dir = kwargs.get('offload_dir')
        load_in_4bit = kwargs.get('load_in_4bit', False)
        load_in_8bit = kwargs.get('load_in_8bit', False)
        manual_ram_gb = kwargs.get('manual_ram_gb')
        flash_attn = kwargs.get('flash_attn', False)
        offload_strategy = kwargs.get('offload_strategy', 'auto')
        max_gpu_percent = kwargs.get('max_gpu_percent', None)
        
        # Store RAM limit for use in _get_gpu_memory_map
        self._pending_ram_gb = manual_ram_gb
        
        print(f"Loading HuggingFace model: {model_name}")
        
        self.use_flash_attn = flash_attn
        self.check_flash_attn_support()
        
        self.device = self._detect_device()
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side="left"
        )
        
        # Set pad token if not present
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Prepare model loading arguments
        load_kwargs = {'trust_remote_code': True}
        
        # Check if model supports quantization
        if load_in_4bit or load_in_8bit:
            # Qwen3.5-A3B/MoE models don't support bitsandbytes quantization
            if 'qwen3.5' in model_name.lower() and ('a3b' in model_name.lower() or 'moe' in model_name.lower()):
                print(f"Warning: {model_name} does not support bitsandbytes quantization (load_in_4bit/load_in_8bit)")
                print("Quantization disabled for this model")
            else:
                try:
                    import bitsandbytes as bnb
                    print(f"Using {4 if load_in_4bit else 8}-bit quantization")
                    load_kwargs['load_in_4bit'] = load_in_4bit
                    load_kwargs['load_in_8bit'] = load_in_8bit
                except ImportError:
                    print("Warning: bitsandbytes not installed. Quantization disabled.")
        
        # Set dtype
        if self.device == "cuda":
            load_kwargs['torch_dtype'] = torch.float16
        else:
            load_kwargs['torch_dtype'] = torch.float32
        
        # Add offload folder if specified (disk offloading is last resort)
        if offload_dir:
            os.makedirs(offload_dir, exist_ok=True)
            load_kwargs['offload_folder'] = offload_dir
        
        # Add Flash Attention 2 if enabled
        if self.use_flash_attn and self.flash_attn_available:
            load_kwargs['attn_implementation'] = "flash_attention_2"
            print("Using Flash Attention 2")
        
        # Try loading with automatic fallback on OOM
        model = None
        vram_percentages = self._get_vram_percentages_for_gpu(model_name, offload_strategy, max_gpu_percent)
        first_vram_pct = vram_percentages[0] if vram_percentages else 0.93
        
        for vram_pct in vram_percentages:
            if self.device != "cuda":
                # CPU-only mode
                load_kwargs['device_map'] = None
                print("Loading model in CPU-only mode...")
                model = self._try_load_model(model_name, load_kwargs, self.device)
                if model is not None:
                    break
                continue  # No GPU offload steps to try on non-CUDA devices
            
            # GPU mode with varying VRAM limits
            if vram_pct > 0:
                max_memory = self._get_gpu_memory_map_with_limit(vram_pct)
                load_kwargs['max_memory'] = max_memory
                load_kwargs['device_map'] = 'auto'
                print(f"\nTrying with GPU limit: {vram_pct*100:.0f}% VRAM")
                if offload_dir:
                    print(f"  Disk offload directory: {offload_dir}")
                
                model = self._try_load_model(model_name, load_kwargs, self.device)
                
                if model is not None:
                    print(f"  ✓ Model loaded successfully with {vram_pct*100:.0f}% GPU VRAM limit")
                    if vram_pct < first_vram_pct:
                        print(f"  (Reduced from {first_vram_pct*100:.0f}% due to memory constraints)")
                    break
                else:
                    print(f"  ✗ Out of memory with {vram_pct*100:.0f}% GPU VRAM, trying lower limit...")
                    # Clear CUDA cache before retry
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            else:
                # Last resort: CPU-only mode with offloading
                print("\nFalling back to CPU-only mode (no GPU layers)...")
                load_kwargs['max_memory'] = {0: 0, 'cpu': int((manual_ram_gb or 48) * 1e9)}
                load_kwargs['device_map'] = 'auto'
                model = self._try_load_model(model_name, load_kwargs, "cpu")
                if model is not None:
                    print("  ✓ Model loaded successfully on CPU")
                    break
        
        if model is None:
            raise RuntimeError("Failed to load model: Out of memory even with minimum GPU usage")
        
        self.model = model
        self.model.eval()
        self.model_name = model_name
        
        print(f"\nModel loaded successfully")
        print(f"Model device: {next(self.model.parameters()).device}")
    
    def _get_gpu_memory_map_with_limit(self, vram_fraction: float) -> Dict:
        """Get max_memory dict with specified VRAM fraction limit."""
        import torch
        max_memory = {}
        
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                total_vram = props.total_memory
                usable_vram = int(total_vram * vram_fraction)
                max_memory[i] = usable_vram
        
        # CPU memory
        manual_ram_gb = getattr(self, '_pending_ram_gb', None)
        if manual_ram_gb:
            max_memory['cpu'] = int(manual_ram_gb * 1e9)
        else:
            import psutil
            available_ram = psutil.virtual_memory().available
            usable_ram = max(0, available_ram - int(4e9))
            max_memory['cpu'] = usable_ram
        
        return max_memory
    
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string."""
        formatted = []
        
        for msg in messages:
            if msg.role == "system":
                formatted.append(f"System: {msg.content}")
            elif msg.role == "user":
                formatted.append(f"User: {msg.content}")
            elif msg.role == "assistant":
                content = msg.content or ""
                if msg.tool_calls:
                    for tc in msg.tool_calls:
                        if tc.get("function"):
                            func = tc["function"]
                            content += f'\n<tool>{{"name": "{func.get("name", "")}", "arguments": {func.get("arguments", "{}")}}}</tool>'
                formatted.append(f"Assistant: {content}")
            elif msg.role == "tool":
                formatted.append(f"Tool ({msg.name}): {msg.content}")
        
        formatted.append("Assistant:")
        return "\n\n".join(formatted)
    
    def _validate_params(self, temperature: float, top_p: float) -> tuple:
        """Validate generation parameters."""
        if temperature <= 0:
            temperature = 1.0
            do_sample = False
        else:
            temperature = max(0.01, min(temperature, 2.0))
            do_sample = True
        top_p = max(0.0, min(top_p, 1.0))
        return temperature, top_p, do_sample
    
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate text non-streaming."""
        import torch
        from transformers import LogitsProcessor, LogitsProcessorList
        
        class InvalidLogitsProcessor(LogitsProcessor):
            """Sanitize NaN/inf logits so sampling does not crash."""
            def __call__(self, input_ids, scores):
                # nan -> -1e9, +inf -> 1e9, -inf -> -1e9 (preserves masking semantics)
                return torch.nan_to_num(scores, nan=-1e9, posinf=1e9, neginf=-1e9)
        
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        if max_tokens is None:
            max_tokens = 512
        
        temperature, top_p, do_sample = self._validate_params(temperature, top_p)
        
        # Try generation with OOM handling
        try:
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_tokens,
                    temperature=temperature if do_sample else None,
                    top_p=top_p if do_sample else None,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
                )
            
            generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
            error_msg = str(e).lower()
            if "out of memory" in error_msg or "cuda" in error_msg or "oom" in error_msg:
                print(f"Warning: CUDA OOM during generation. Clearing cache and retrying with reduced tokens...")
                torch.cuda.empty_cache()
                # Retry with half the tokens
                try:
                    with torch.no_grad():
                        outputs = self.model.generate(
                            input_ids=inputs["input_ids"],
                            attention_mask=inputs["attention_mask"],
                            max_new_tokens=max(1, max_tokens // 2),
                            temperature=temperature if do_sample else None,
                            top_p=top_p if do_sample else None,
                            do_sample=do_sample,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id,
                            logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
                        )
                    
                    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
                    return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
                except Exception as e2:
                    print(f"Error: Generation failed even with reduced tokens: {e2}")
                    return "[Error: Out of memory during generation. Try reducing --max-gpu-percent or using a smaller model.]"
            raise
    
    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Generate text in streaming fashion."""
        import torch
        from transformers import TextIteratorStreamer, LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
        
        class InvalidLogitsProcessor(LogitsProcessor):
            """Sanitize NaN/inf logits so sampling does not crash."""
            def __call__(self, input_ids, scores):
                # nan -> -1e9, +inf -> 1e9, -inf -> -1e9 (preserves masking semantics)
                return torch.nan_to_num(scores, nan=-1e9, posinf=1e9, neginf=-1e9)
        
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        if max_tokens is None:
            max_tokens = 512
        
        temperature, top_p, do_sample = self._validate_params(temperature, top_p)
        
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
            "streamer": streamer,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "logits_processor": LogitsProcessorList([InvalidLogitsProcessor()]),
        }
        
        # Handle stop sequences
        if stop:
            class StopOnSequence(StoppingCriteria):
                def __init__(self, stop_sequences, tokenizer):
                    self.stop_sequences = stop_sequences
                    self.tokenizer = tokenizer
                
                def __call__(self, input_ids, scores, **kwargs):
                    decoded = self.tokenizer.decode(input_ids[0][-20:], skip_special_tokens=True)
                    return any(seq in decoded for seq in self.stop_sequences)
            
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([
                StopOnSequence(stop, self.tokenizer)
            ])
        
        # Run generation in a separate thread with OOM handling
        generation_error = None
        
        def generate_with_error_handling():
            nonlocal generation_error
            try:
                self.model.generate(**generation_kwargs)
            except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
                error_msg = str(e).lower()
                if "out of memory" in error_msg or "cuda" in error_msg or "oom" in error_msg:
                    generation_error = "oom"
                    print(f"Warning: CUDA OOM during streaming generation. Clearing cache...")
                    torch.cuda.empty_cache()
                else:
                    generation_error = str(e)
        
        thread = Thread(target=generate_with_error_handling)
        thread.start()
        
        try:
            for text in streamer:
                yield text
        except Exception as e:
            print(f"Error during stream iteration: {e}")
        
        thread.join()
        
        # Check if there was an OOM error
        if generation_error == "oom":
            yield "\n[Warning: Generation stopped due to out-of-memory. Try reducing --max-gpu-percent.]"
        elif generation_error:
            yield f"\n[Error during generation: {generation_error}]"
    
    def get_model_name(self) -> str:
        return self.model_name or "unknown"
    
    def cleanup(self) -> None:
        import torch
        if self.model is not None:
            del self.model
            del self.tokenizer
            self.model = None
            self.tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
# =============================================================================
# Vulkan Backend (llama-cpp-python)
# =============================================================================

class VulkanBackend(ModelBackend):
    """Backend for Vulkan (AMD GPUs) using llama-cpp-python with GGUF models."""
    
    def __init__(self):
        self.model = None
        self.model_name = None
        self.n_gpu_layers = -1  # Offload all layers to GPU by default
        self.n_ctx = 2048
        self.verbose = True
        self.main_gpu = 0  # Default to first GPU
        self.chat_template = None  # Detected chat template name
        self._detect_chat_template()
    
    def _detect_chat_template(self):
        """Detect the chat template used by the model."""
        try:
            # Try to get the chat template from the model
            # llama.cpp models have a chat_template attribute
            from llama_cpp.llama_chat_format import ChatFormatterResponse
            # We'll detect it when the model is loaded
            self.chat_template = "unknown"
            print("DEBUG: Chat template detection will happen after model load")
        except Exception as e:
            print(f"DEBUG: Could not initialize chat template detection: {e}")
            self.chat_template = None
    
    def _finalize_chat_template_detection(self):
        """Finalize chat template detection after model is loaded."""
        try:
            # Try to get the chat template name from the model's chat formatter
            if hasattr(self.model, 'tokenizer') and self.model.tokenizer:
                tokenizer = self.model.tokenizer
                # Check if there's a chat_template attribute
                if hasattr(tokenizer, 'chat_template'):
                    template = tokenizer.chat_template
                    if template:
                        # Detect common templates
                        template_str = str(template)
                        if 'qwen' in template_str.lower():
                            self.chat_template = "qwen"
                        elif 'phi' in template_str.lower():
                            self.chat_template = "phi"
                        elif 'llama3' in template_str.lower() or 'llama-3' in template_str.lower():
                            self.chat_template = "llama3"
                        elif 'chatml' in template_str.lower():
                            self.chat_template = "chatml"
                        else:
                            self.chat_template = "default"
                        print(f"DEBUG: Detected chat template: {self.chat_template}")
                        return
            
            # Try a test message to see what format works
            test_messages = [{"role": "user", "content": "test"}]
            try:
                self.model.create_chat_completion(messages=test_messages, max_tokens=1)
                self.chat_template = "default"
                print("DEBUG: Chat template detected via test: default")
            except Exception as e:
                error_str = str(e).lower()
                if 'jinja' in error_str:
                    # Jinja template issue - try without tools
                    self.chat_template = "jinja_fallback"
                    print("DEBUG: Chat template detected: jinja_fallback (will use manual formatting)")
                else:
                    self.chat_template = "unknown"
                    print(f"DEBUG: Chat template detection failed: {e}")
        except Exception as e:
            self.chat_template = "unknown"
            print(f"DEBUG: Final chat template detection error: {e}")
        
    def list_vulkan_devices(self):
        """List available Vulkan GPU devices."""
        try:
            # Try to get device info via vulkaninfo or similar
            import subprocess
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True)
            if result.returncode == 0:
                print("\nAvailable Vulkan devices:")
                print(result.stdout)
        except Exception:
            pass
    
    def count_vulkan_devices(self):
        """Count the number of Vulkan GPU devices available."""
        # llama.cpp filters out some devices (like CPU llvmpipe), so we need to
        # count only the devices that llama.cpp will actually use
        try:
            import subprocess
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True)
            if result.returncode == 0:
                # Count GPU0, GPU1, etc. entries (these are actual GPUs, not CPU)
                import re
                gpu_matches = re.findall(r'^GPU\d+:', result.stdout, re.MULTILINE)
                # Filter out CPU devices (llvmpipe)
                non_cpu_gpus = 0
                for i, match in enumerate(gpu_matches):
                    # Check if this GPU is a CPU device
                    gpu_section = result.stdout.split(match)[1].split('\nGPU')[0] if i < len(gpu_matches) - 1 else result.stdout.split(match)[1]
                    if 'llvmpipe' not in gpu_section and 'CPU' not in gpu_section:
                        non_cpu_gpus += 1
                return max(non_cpu_gpus, 1)
        except Exception:
            pass
        return 2  # Default to 2 (common for NVIDIA + AMD setups)
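
    # Illustrative sketch of the output shape this parsing expects (device names
    # are examples only, not assumptions about the host hardware):
    #
    #   GPU0:
    #       deviceName = AMD Radeon RX 7900 XT
    #       deviceType = PHYSICAL_DEVICE_TYPE_DISCRETE_GPU
    #   GPU1:
    #       deviceName = llvmpipe (LLVM 15.0.7, 256 bits)
    #       deviceType = PHYSICAL_DEVICE_TYPE_CPU
    #
    # Given that output, count_vulkan_devices() would return 1: GPU1 is skipped
    # because its section mentions llvmpipe/CPU.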
    
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load a GGUF model using llama-cpp-python."""
        from llama_cpp import Llama
        
        # model_name can be:
        # - Local file path to .gguf
        # - HuggingFace model ID (e.g., "microsoft/Phi-3-mini-4k-instruct-gguf")
        # - Full URL to a GGUF file
        
        n_gpu_layers = kwargs.get('n_gpu_layers', -1)
        n_ctx = kwargs.get('n_ctx', 2048)
        verbose = kwargs.get('verbose', True)
        main_gpu = kwargs.get('main_gpu', 0)
        self.main_gpu = main_gpu
        
        # Check if model_name is a URL - download it (with caching)
        model_path = None
        if model_name.startswith('http://') or model_name.startswith('https://'):
            # Check cache first
            cached_path = get_cached_model_path(model_name)
            if cached_path:
                model_path = cached_path
                print(f"Using cached model: {model_path}")
            else:
                print(f"Downloading model from URL: {model_name}")
                try:
                    # Get cache directory
                    cache_dir = get_model_cache_dir()
                    
                    # Extract filename from URL
                    url_path = model_name.split('?')[0]  # Remove query params
                    filename = os.path.basename(url_path)
                    
                    if not filename.endswith('.gguf'):
                        filename = "model.gguf"
                    
                    # Create safe filename in cache (use hash to avoid special char issues)
                    url_hash = hashlib.sha256(model_name.encode()).hexdigest()
                    cached_filename = f"{url_hash}_{filename}"
                    model_path = os.path.join(cache_dir, cached_filename)
                    
                    # Download to cache
                    response = requests.get(model_name, stream=True)
                    response.raise_for_status()
                    
                    total_size = int(response.headers.get('content-length', 0))
                    downloaded = 0
                    
                    with open(model_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192*1024):  # 8MB chunks
                            if chunk:
                                f.write(chunk)
                                downloaded += len(chunk)
                                if total_size > 0:
                                    percent = (downloaded / total_size) * 100
                                    print(f"Downloaded: {percent:.1f}%", end='\r')
                    
                    print(f"\nDownloaded and cached to: {model_path}")
                    print(f"File size: {os.path.getsize(model_path) / 1e9:.2f} GB")
                    
                except Exception as e:
                    print(f"Error downloading model: {e}")
                    raise
        
        # Check if model_name is a local file
        elif os.path.isfile(model_name):
            model_path = model_name
            print(f"Loading local GGUF model: {model_path}")
        else:
            # Try to download from HuggingFace Hub
            print(f"Attempting to download GGUF model: {model_name}")
            try:
                from huggingface_hub import hf_hub_download, list_repo_files
                
                # Parse model name (format: "org/model" or "org/model/filename.gguf")
                parts = model_name.split('/')
                if len(parts) >= 2:
                    repo_id = f"{parts[0]}/{parts[1]}"
                    
                    # If specific file provided
                    if len(parts) >= 3 and parts[-1].endswith('.gguf'):
                        filename = '/'.join(parts[2:])
                    else:
                        # Find GGUF files in the repo
                        files = list_repo_files(repo_id)
                        gguf_files = [f for f in files if f.endswith('.gguf')]
                        if not gguf_files:
                            raise ValueError(f"No GGUF files found in {repo_id}")
                        # Prefer Q4_K_M quantized models for good balance
                        preferred = [f for f in gguf_files if 'Q4_K_M' in f or 'q4_k_m' in f.lower()]
                        if preferred:
                            filename = preferred[0]
                        else:
                            filename = gguf_files[0]
                        print(f"Selected GGUF file: {filename}")
                    
                    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
                    print(f"Downloaded to: {model_path}")
                else:
                    raise ValueError(f"Invalid model name format: {model_name}")
            except Exception as e:
                print(f"Error downloading model: {e}")
                print("Please provide a local path to a .gguf file")
                raise
        
        print(f"Loading GGUF model with Vulkan support...")
        print(f"  Model path: {model_path}")
        print(f"  GPU layers: {n_gpu_layers} (-1 = all layers)")
        print(f"  Context size: {n_ctx}")
        print(f"  GPU device: {main_gpu}")
        
        # List available devices for user reference
        self.list_vulkan_devices()
        
        # Check if single GPU mode is requested
        single_gpu = kwargs.get('single_gpu', False)
        tensor_split = None
        
        # First, estimate how many Vulkan devices llama.cpp will actually use
        # (ggml_vulkan skips software devices such as llvmpipe)
        num_devices = 2  # Default
        
        # Try to parse vulkaninfo to get actual device count
        try:
            import subprocess
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True)
            if result.returncode == 0:
                # Count actual GPU devices (exclude llvmpipe CPU)
                import re
                lines = result.stdout.split('\n')
                gpu_count = 0
                for i, line in enumerate(lines):
                    if line.strip().startswith('GPU'):
                        # Check the next few lines for the device type and skip
                        # CPU/llvmpipe software devices
                        section = '\n'.join(lines[i:i+10])
                        if 'llvmpipe' not in section.lower() and 'CPU' not in section:
                            gpu_count += 1
                if gpu_count > 0:
                    num_devices = gpu_count
        except Exception as e:
            print(f"Warning: Could not detect Vulkan device count: {e}")
        
        print(f"DEBUG: Detected {num_devices} Vulkan GPU devices")
        
        if single_gpu:
            # Build tensor_split to force all layers onto one GPU
            # tensor_split is a list where index = GPU device, value = weight (0.0 = don't use)
            tensor_split = [0.0] * num_devices
            if main_gpu < num_devices:
                tensor_split[main_gpu] = 1.0
                print(f"  Single GPU mode: Setting tensor_split for GPU {main_gpu}: {tensor_split}")
            else:
                print(f"Warning: main_gpu={main_gpu} exceeds detected devices ({num_devices}), ignoring single_gpu")
                tensor_split = None
        
        try:
            llama_kwargs = {
                'model_path': model_path,
                'n_gpu_layers': n_gpu_layers,
                'n_ctx': n_ctx,
                'verbose': verbose,
                'main_gpu': main_gpu,
            }
            
            if tensor_split:
                llama_kwargs['tensor_split'] = tensor_split
            
            self.model = Llama(**llama_kwargs)
            self.model_name = model_name
            print("\nModel loaded successfully with Vulkan!")
            
            # Detect the chat template after model load
            self._finalize_chat_template_detection()
            print(f"DEBUG: Chat template: {self.chat_template}")
        except Exception as e:
            print(f"Error loading model with Vulkan: {e}")
            print("Make sure Vulkan drivers are installed:")
            print("  Debian/Ubuntu: sudo apt install libvulkan-dev vulkan-tools")
            print("  Fedora: sudo dnf install vulkan-loader-devel vulkan-tools")
            raise
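
    # Usage sketch (hypothetical path and illustrative values): loading a local
    # GGUF file onto the second Vulkan GPU only.
    #
    #   backend = VulkanBackend()
    #   backend.load_model(
    #       "/models/qwen3-4b-q4_k_m.gguf",  # hypothetical path
    #       n_gpu_layers=-1,   # offload every layer
    #       n_ctx=4096,        # context window
    #       main_gpu=1,        # Vulkan device index
    #       single_gpu=True,   # build a tensor_split pinning all weights to main_gpu
    #   )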
    
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string suitable for chat models.
        
        Uses llama.cpp's built-in chat template support for proper formatting.
        """
        # Convert to format expected by llama.cpp
        chat_messages = []
        for msg in messages:
            chat_msg = {"role": msg.role}
            # CRITICAL: Ensure content is never None - Jinja templates fail on None
            # Also ensure content key exists
            if msg.content is not None:
                chat_msg["content"] = msg.content
            else:
                chat_msg["content"] = ""
            if msg.tool_calls:
                chat_msg["tool_calls"] = msg.tool_calls
            chat_messages.append(chat_msg)
        
        # Use llama.cpp's apply_chat_template if available
        try:
            prompt = self.model.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
            return prompt
        except Exception as e:
            # Fallback to manual formatting if apply_chat_template fails
            print(f"Warning: apply_chat_template failed ({e}), using fallback formatting")
            formatted = []
            for msg in messages:
                if msg.role == "system":
                    formatted.append(f"<|im_start|>system\n{msg.content}<|im_end|>")
                elif msg.role == "user":
                    formatted.append(f"<|im_start|>user\n{msg.content}<|im_end|>")
                elif msg.role == "assistant":
                    content = msg.content or ""
                    formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            
            formatted.append("<|im_start|>assistant\n")
            return "\n".join(formatted)
    
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate text non-streaming using llama-cpp."""
        if max_tokens is None:
            max_tokens = 512
        
        output = self.model(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop or [],
        )
        
        return output["choices"][0]["text"]
    
    def generate_chat(self, messages: List[Dict], max_tokens: Optional[int] = None,
                      temperature: float = 0.7, top_p: float = 1.0,
                      stop: Optional[List[str]] = None, tools: Optional[List] = None) -> str:
        """Generate chat completion using llama-cpp's create_chat_completion."""
        if max_tokens is None:
            max_tokens = 512
        
        # CRITICAL: Ensure NO message has None content - Jinja templates fail on None
        # This is a safety check in case messages bypass the main endpoint validation
        cleaned_messages = []
        for msg in messages:
            cleaned_msg = dict(msg)  # Make a copy to avoid modifying original
            # Ensure content key exists and is never None
            if "content" not in cleaned_msg:
                cleaned_msg["content"] = ""
            elif cleaned_msg.get("content") is None:
                cleaned_msg["content"] = ""
            # Convert non-string content to string
            elif not isinstance(cleaned_msg["content"], str):
                cleaned_msg["content"] = str(cleaned_msg["content"])
            cleaned_messages.append(cleaned_msg)
        messages = cleaned_messages
        
        # Check if we should use manual formatting based on detected template
        # Always use manual formatting when tools are present, since Jinja templates often fail with tool messages
        use_manual = self.chat_template in ("unknown", "jinja_fallback", None) or tools is not None
        
        if use_manual:
            print(f"DEBUG: Using manual message formatting (template: {self.chat_template}, tools: {tools is not None})")
            prompt = self._manual_format_messages(messages)
            return self.generate(prompt, max_tokens, temperature, top_p, stop)
        
        try:
            response = self.model.create_chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop=stop or [],
                tools=tools,
            )
            content = response["choices"][0]["message"].get("content", "")
            print(f"DEBUG: generate_chat returned content length: {len(content) if content else 0}")
            if not content or not content.strip():
                print(f"DEBUG: Empty content from create_chat_completion, using fallback")
                raise Exception("Empty response from create_chat_completion")
            return content
        except Exception as e:
            print(f"Warning: create_chat_completion failed ({e}), falling back to text generation")
            # Fallback: format messages manually and use text generation
            prompt = self._manual_format_messages(messages)
            return self.generate(prompt, max_tokens, temperature, top_p, stop)
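
    # Example call (the tool schema below is the standard OpenAI format; the
    # function name and parameters are hypothetical). Passing `tools` forces the
    # manual ChatML path above, since Jinja templates often reject tool messages.
    #
    #   content = backend.generate_chat(
    #       messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    #       tools=[{
    #           "type": "function",
    #           "function": {
    #               "name": "get_weather",
    #               "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    #           },
    #       }],
    #   )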
    
    async def generate_chat_stream(self, messages: List[Dict], max_tokens: Optional[int] = None,
                                    temperature: float = 0.7, top_p: float = 1.0,
                                    stop: Optional[List[str]] = None, tools: Optional[List] = None) -> AsyncGenerator[str, None]:
        """Generate chat completion streaming using llama-cpp."""
        if max_tokens is None:
            max_tokens = 512
        
        total_content = ""
        chunk_count = 0
        
        # CRITICAL: Ensure NO message has None content - Jinja templates fail on None
        # This is a safety check in case messages bypass the main endpoint validation
        cleaned_messages = []
        for msg in messages:
            cleaned_msg = dict(msg)  # Make a copy to avoid modifying original
            # Ensure content key exists and is never None
            if "content" not in cleaned_msg:
                cleaned_msg["content"] = ""
            elif cleaned_msg.get("content") is None:
                cleaned_msg["content"] = ""
            # Convert non-string content to string
            elif not isinstance(cleaned_msg["content"], str):
                cleaned_msg["content"] = str(cleaned_msg["content"])
            cleaned_messages.append(cleaned_msg)
        messages = cleaned_messages
        
        # Check if we should use manual formatting based on detected template
        # Always use manual formatting when tools are present, since Jinja templates often fail with tool messages
        use_manual = self.chat_template in ("unknown", "jinja_fallback", None) or tools is not None
        
        if use_manual:
            print(f"DEBUG: Using manual message formatting for streaming (template: {self.chat_template}, tools: {tools is not None})")
            prompt = self._manual_format_messages(messages)
            async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                yield chunk
            return
        
        # Collect all chunks synchronously then yield them
        # This avoids issues with generators across thread boundaries
        def collect_chunks():
            """Collect all chunks from the stream."""
            print(f"DEBUG: generate_chat_stream: Calling create_chat_completion with tools={tools}")
            stream = self.model.create_chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop=stop or [],
                tools=tools,
                stream=True,
            )
            print(f"DEBUG: generate_chat_stream: Got stream object: {type(stream)}")
            chunks = []
            for chunk in stream:
                chunks.append(chunk)
            print(f"DEBUG: generate_chat_stream: Collected {len(chunks)} chunks")
            return chunks
        
        try:
            # Run the collection in thread pool
            loop = asyncio.get_event_loop()
            chunks = await loop.run_in_executor(None, collect_chunks)
            
            for chunk in chunks:
                chunk_count += 1
                print(f"DEBUG: generate_chat_stream: Processing chunk {chunk_count}: {repr(chunk)}")
                delta = chunk["choices"][0].get("delta", {})
                content = delta.get("content", "")
                
                # Qwen3 wraps its reasoning in `<think>` tags; the content is
                # passed through unchanged and left to the client to handle
                if content:
                    total_content += content
                    yield content
                
                # Small yield to allow other async tasks
                await asyncio.sleep(0)
            
            print(f"DEBUG: generate_chat_stream yielded {chunk_count} chunks, total content length: {len(total_content)}")
            if chunk_count == 0 or not total_content.strip():
                print(f"DEBUG: Empty stream from create_chat_completion, using fallback")
                raise Exception("Empty stream response")
        except Exception as e:
            print(f"DEBUG: generate_chat_stream exception: {type(e).__name__}: {e}")
            import traceback
            traceback.print_exc()
            if chunk_count == 0:
                print(f"Warning: create_chat_completion stream failed ({e}), falling back to text generation")
                # Fallback: format messages manually and use text generation
                prompt = self._manual_format_messages(messages)
                async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                    yield chunk
            else:
                print(f"DEBUG: Stream completed with {chunk_count} chunks")
    
    def _manual_format_messages(self, messages: List[Dict]) -> str:
        """Manual fallback for formatting messages when create_chat_completion fails."""
        formatted = []
        for msg in messages:
            role = msg.get("role", "")
            # CRITICAL: Ensure content is never None - Jinja templates fail on None
            # Also ensure content key exists
            content = msg.get("content")
            if content is None:
                content = ""
            elif not isinstance(content, str):
                content = str(content)
            
            if role == "system":
                formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                # Handle tool_calls if present
                tool_calls = msg.get("tool_calls", [])
                if tool_calls:
                    for tc in tool_calls:
                        if isinstance(tc, dict) and "function" in tc:
                            func = tc["function"]
                            tc_str = f'<tool>{{"name": "{func.get("name", "")}", "arguments": {func.get("arguments", "{}")}}}</tool>'
                            content = content + "\n" + tc_str if content else tc_str
                formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            elif role == "tool":
                # Tool result messages
                tool_call_id = msg.get("tool_call_id", "")
                name = msg.get("name", "")
                formatted.append(f"<|im_start|>tool (tool_call_id={tool_call_id}, name={name})\n{content}<|im_end|>")
        
        formatted.append("<|im_start|>assistant\n")
        return "\n".join(formatted)
    
    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Generate text in streaming fashion using llama-cpp."""
        if max_tokens is None:
            max_tokens = 512
        
        stream = self.model(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop or [],
            stream=True,
        )
        
        for chunk in stream:
            text = chunk["choices"][0].get("text", "")
            if text:
                yield text
    
    def get_model_name(self) -> str:
        return self.model_name or "unknown"
    
    def cleanup(self) -> None:
        if self.model is not None:
            del self.model
            self.model = None
# =============================================================================
# Model Manager
# =============================================================================

class ModelManager:
    """Manages the loaded model and tokenizer."""
    
    def __init__(self):
        self.backend: Optional[ModelBackend] = None
        self.backend_type: Optional[str] = None
        self.tool_parser = ToolCallParser()
        
    def load_model(self, model_name: str, backend_type: str = "auto", **kwargs):
        """
        Load the model with the specified backend.
        
        Args:
            model_name: Model name or path
            backend_type: 'nvidia', 'vulkan', or 'auto' to detect
            **kwargs: Additional arguments for the specific backend
        """
        available = detect_available_backends()
        
        # Determine backend
        if backend_type == "auto":
            if available.get('nvidia'):
                backend_type = "nvidia"
                print("Auto-detected NVIDIA backend")
            elif available.get('vulkan'):
                backend_type = "vulkan"
                print("Auto-detected Vulkan backend")
            else:
                print("No GPU backend detected. For NVIDIA, install PyTorch with CUDA.")
                print("For Vulkan, install llama-cpp-python with Vulkan support.")
                raise RuntimeError("No suitable backend found")
        
        self.backend_type = backend_type
        
        # Create appropriate backend
        if backend_type == "nvidia":
            if not available.get('nvidia'):
                raise RuntimeError("NVIDIA backend requested but PyTorch/CUDA not available")
            self.backend = NvidiaBackend()
        elif backend_type == "vulkan":
            if not available.get('vulkan'):
                raise RuntimeError("Vulkan backend requested but llama-cpp-python not available")
            self.backend = VulkanBackend()
        else:
            raise ValueError(f"Unknown backend: {backend_type}")
        
        # Load the model
        self.backend.load_model(model_name, **kwargs)
        self.tool_parser = ToolCallParser()
        
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string."""
        if self.backend is None:
            raise RuntimeError("No model loaded")
        return self.backend.format_messages(messages)
    
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate text non-streaming."""
        if self.backend is None:
            raise RuntimeError("No model loaded")
        return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)
    
    def generate_chat(self, messages: List[Dict], max_tokens: Optional[int] = None,
                      temperature: float = 0.7, top_p: float = 1.0,
                      stop: Optional[List[str]] = None, tools: Optional[List] = None) -> str:
        """Generate chat completion non-streaming."""
        if self.backend is None:
            raise RuntimeError("No model loaded")
        # Use generate_chat if available (Vulkan backend), otherwise format and use generate
        if hasattr(self.backend, 'generate_chat'):
            return self.backend.generate_chat(messages, max_tokens, temperature, top_p, stop, tools)
        else:
            # Fallback for NVIDIA backend
            prompt = self.format_messages([ChatMessage(**m) for m in messages])
            return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)
    
    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Generate text in streaming fashion."""
        if self.backend is None:
            raise RuntimeError("No model loaded")
        async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
            yield chunk
    
    async def generate_chat_stream(self, messages: List[Dict], max_tokens: Optional[int] = None,
                                    temperature: float = 0.7, top_p: float = 1.0,
                                    stop: Optional[List[str]] = None, tools: Optional[List] = None) -> AsyncGenerator[str, None]:
        """Generate chat completion streaming."""
        if self.backend is None:
            raise RuntimeError("No model loaded")
        # Use generate_chat_stream if available (Vulkan backend), otherwise format and use generate_stream
        if hasattr(self.backend, 'generate_chat_stream'):
            async for chunk in self.backend.generate_chat_stream(messages, max_tokens, temperature, top_p, stop, tools):
                yield chunk
        else:
            # Fallback for NVIDIA backend
            prompt = self.format_messages([ChatMessage(**m) for m in messages])
            async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                yield chunk
    
    @property
    def model_name(self) -> str:
        if self.backend is None:
            return "unknown"
        return self.backend.get_model_name()
    
    @property
    def model(self):
        if self.backend is None:
            return None
        return self.backend
    
    @property
    def tokenizer(self):
        # Only NVIDIA backend has a tokenizer
        if isinstance(self.backend, NvidiaBackend):
            return self.backend.tokenizer
        return None
    
    def cleanup(self):
        if self.backend is not None:
            self.backend.cleanup()
            self.backend = None
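
# Usage sketch (hypothetical model path and prompt): the manager hides backend
# differences, so calling code is the same for CUDA and Vulkan.
#
#   mgr = ModelManager()
#   mgr.load_model("/models/phi-3-mini-q4_k_m.gguf", backend_type="vulkan", n_ctx=4096)
#   reply = mgr.generate_chat([{"role": "user", "content": "Hello"}], max_tokens=64)
#   mgr.cleanup()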
# =============================================================================
# Whisper Server Manager - manages whisper-server subprocess
# =============================================================================

import subprocess
import signal
import requests
import time
import threading
class WhisperServerManager:
    """Manages whisper-server subprocess for audio transcription with model swapping support."""
    
    def __init__(self, server_path: str = None, port: int = 8744):
        self.server_path = server_path
        self.port = port
        self.process = None
        self.current_model = None
        self.base_url = f"http://127.0.0.1:{port}"
        # Re-entrant: start() calls stop() while already holding the lock
        self.lock = threading.RLock()
        self._health_check_thread = None
        self._running = False
        
        # Check if port is available
        if not self._is_port_available(port):
            # Try to find an available port
            for new_port in range(port + 1, port + 100):
                if self._is_port_available(new_port):
                    self.port = new_port
                    self.base_url = f"http://127.0.0.1:{new_port}"
                    print(f"Port {port} in use, using port {new_port} instead")
                    break
    
    def _is_port_available(self, port: int) -> bool:
        """Check if a port is available."""
        import socket
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('127.0.0.1', port))
                return True
        except OSError:
            return False
    
    def is_running(self) -> bool:
        """Check if whisper-server is running."""
        if self.process is None:
            return False
        return self.process.poll() is None
    
    def start(self, model_path: str = None, gpu_device: int = 0) -> str:
        """Start whisper-server with the specified model. Returns actual model path or empty string on failure."""
        with self.lock:
            # Stop existing server if running
            if self.is_running():
                self.stop()
            
            if not self.server_path:
                print("Error: whisper-server path not set")
                return ""
            
            # Handle URL models - download if needed
            actual_model_path = model_path
            if model_path and (model_path.startswith('http://') or model_path.startswith('https://')):
                # Check cache first
                cached_path = get_cached_model_path(model_path)
                if cached_path:
                    actual_model_path = cached_path
                    print(f"Using cached model: {actual_model_path}")
                else:
                    # Download the model
                    cache_dir = get_model_cache_dir()
                    print(f"Downloading model: {model_path}")
                    actual_model_path = download_model(model_path, cache_dir)
                    print(f"Downloaded model to: {actual_model_path}")
            
            # Build command
            cmd = [self.server_path]
            
            if actual_model_path:
                cmd.extend(["-m", actual_model_path])
            
            # Add GPU device
            cmd.extend(["-dev", str(gpu_device)])
            
            # Add --convert flag to convert audio to 16kHz mono on the server side
            cmd.append("--convert")
            
            # Add host and port
            cmd.extend(["--host", "127.0.0.1"])
            cmd.extend(["--port", str(self.port)])
            
            print(f"Starting whisper-server: {' '.join(cmd)}")
            print(f"DEBUG: Full whisper-server command: {' '.join(cmd)}")
            
            try:
                self.process = subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    preexec_fn=lambda: signal.signal(signal.SIGTERM, signal.SIG_DFL)
                )
                self.current_model = actual_model_path
                
                # Wait for server to be ready
                if self._wait_for_server(30):
                    print(f"whisper-server started on {self.base_url}")
                    self._running = True
                    return actual_model_path
                else:
                    print("Error: whisper-server failed to start")
                    self.stop()
                    return ""
            except Exception as e:
                print(f"Error starting whisper-server: {e}")
                return ""
    
    def stop(self):
        """Stop whisper-server."""
        with self.lock:
            self._running = False
            if self.process:
                try:
                    self.process.terminate()
                    try:
                        self.process.wait(timeout=5)
                    except subprocess.TimeoutExpired:
                        self.process.kill()
                        self.process.wait()
                except Exception as e:
                    print(f"Error stopping whisper-server: {e}")
                self.process = None
                self.current_model = None
    
    def restart(self, model_path: str = None, gpu_device: int = 0) -> str:
        """Restart whisper-server with a new model.

        Returns the loaded model path, or an empty string on failure (same
        contract as start()).
        """
        print(f"Restarting whisper-server with model: {model_path}")
        return self.start(model_path, gpu_device)
    
    def transcribe(self, audio_data: bytes, language: str = None, prompt: str = None) -> dict:
        """Send transcription request to whisper-server."""
        if not self.is_running():
            return {"error": "whisper-server not running"}
        
        try:
            files = {"file": ("audio.wav", audio_data, "audio/wav")}
            data = {}
            if language:
                data["language"] = language
            if prompt:
                data["prompt"] = prompt
            
            print(f"DEBUG: Sending POST to {self.base_url}/inference with data={data}, file_size={len(audio_data)}")
            response = requests.post(
                f"{self.base_url}/inference",
                files=files,
                data=data,
                timeout=300
            )
            print(f"DEBUG: whisper-server response status={response.status_code}, body={response.text[:500] if response.text else 'empty'}")
            
            if response.status_code == 200:
                return response.json()
            else:
                return {"error": f"Server error: {response.status_code}", "detail": response.text}
        except Exception as e:
            print(f"DEBUG: whisper-server exception: {e}")
            return {"error": str(e)}
    
    def _wait_for_server(self, timeout: int = 30) -> bool:
        """Wait for whisper-server to be ready."""
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                response = requests.get(f"{self.base_url}/health", timeout=2)
                if response.status_code == 200:
                    return True
            except requests.RequestException:
                pass
            time.sleep(0.5)
        return False
    
    def get_status(self) -> dict:
        """Get whisper-server status."""
        return {
            "running": self.is_running(),
            "model": self.current_model,
            "url": self.base_url
        }
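
# Usage sketch (hypothetical binary and model paths): the manager owns the
# whisper-server subprocess, so start/transcribe/stop is the whole lifecycle.
#
#   wsm = WhisperServerManager(server_path="/usr/local/bin/whisper-server", port=8744)
#   wsm.start(model_path="/models/ggml-base.en.bin", gpu_device=0)
#   with open("clip.wav", "rb") as f:
#       result = wsm.transcribe(f.read(), language="en")
#   wsm.stop()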
# =============================================================================
# Multi-Model Manager (supports audio transcription and image generation)
# =============================================================================

class MultiModelManager:
    """
    Manages multiple models: main text model, audio transcription, and image generation.
    Supports dynamic switching based on request model name.
    
    Modes:
    - ondemand (default): load each model the first time it is requested
    - loadall: pre-load all models into VRAM at startup
    - loadswap: keep all models in CPU RAM and swap the active model into VRAM
    """
    
    def __init__(self):
        self.models: Dict[str, ModelManager] = {}
        self.default_model: Optional[str] = None
        self.audio_models: List[str] = []  # List of audio model names
        self.tts_model: Optional[str] = None
        self.image_models: List[str] = []  # List of image model names
        self.tool_parser = ToolCallParser()
        self.current_model_key: Optional[str] = None
        # Configuration for each model type
        self.config: Dict[str, Dict] = {}
        # Load mode settings
        self.load_mode: str = "ondemand"  # "ondemand", "loadall", "loadswap"
        self.active_in_vram: Optional[str] = None  # Which model is currently in VRAM
        # Model aliases: alias -> actual model name mapping
        self.model_aliases: Dict[str, str] = {}
        # Whisper server manager
        self.whisper_server: Optional[WhisperServerManager] = None
        
    @property
    def audio_model(self) -> Optional[str]:
        """Get the first/default audio model."""
        return self.audio_models[0] if self.audio_models else None
    
    @property
    def image_model(self) -> Optional[str]:
        """Get the first/default image model."""
        return self.image_models[0] if self.image_models else None
        
    def set_load_mode(self, mode: str):
        """Set the load mode: 'ondemand', 'loadall', or 'loadswap'."""
        self.load_mode = mode
    
    def set_default_model(self, model_name: str, config: Dict = None):
        """Set the default/main text model."""
        self.default_model = model_name
        self.config[model_name] = config or {}
    
    def set_audio_model(self, model_name: str, config: Dict = None):
        """Add an audio transcription model."""
        if model_name not in self.audio_models:
            self.audio_models.append(model_name)
        self.config[f"audio:{model_name}"] = config or {}
    
    def set_tts_model(self, model_name: str, config: Dict = None):
        """Set the text-to-speech model."""
        self.tts_model = model_name
        self.config[f"tts:{model_name}"] = config or {}
    
    def set_image_model(self, model_name: str, config: Dict = None):
        """Add an image generation model."""
        if model_name not in self.image_models:
            self.image_models.append(model_name)
        self.config[f"image:{model_name}"] = config or {}
    
    def set_model_alias(self, alias: str, model_name: str):
        """Register an alias for a model."""
        self.model_aliases[alias] = model_name
    
    def get_model_for_request(self, requested_model: str) -> Optional[ModelManager]:
        """
        Get the appropriate model manager for a request based on model name.
        
        Model name conventions:
        - "default", empty, or matches default model -> use main model
        - "audio" -> use first/default audio model
        - "audio:modelname" -> use specific audio model
        - "vision" or "image" -> use first/default image model  
        - "vision:modelname" or "image:modelname" -> use specific image model
        - "tts" -> use TTS model
        - "tts:modelname" -> use specific TTS model
        - Custom aliases -> resolve to actual model name
        - Otherwise match by model ID in multi_model_manager.models
        """
        # Resolve custom aliases first
        if requested_model in self.model_aliases:
            requested_model = self.model_aliases[requested_model]
        
        # Handle empty or "default" model names
        if not requested_model or requested_model == "default":
            if self.default_model and self.default_model in self.models:
                self.current_model_key = self.default_model
                return self.models[self.default_model]
            return None
        
        # Handle "audio" alias - use first/default audio model
        if requested_model == "audio":
            if self.audio_models:
                first_audio = self.audio_models[0]
                key = f"audio:{first_audio}"
                if key in self.models:
                    self.current_model_key = key
                    return self.models[key]
                # Try to load on demand
                return None
            return None
        
        # Handle "vision" or "image" alias - use first/default image model
        if requested_model in ("vision", "image"):
            if self.image_models:
                first_image = self.image_models[0]
                key = f"image:{first_image}"
                if key in self.models:
                    self.current_model_key = key
                    return self.models[key]
                # Try to load on demand
                return None
            return None
        
        # Handle "tts" alias
        if requested_model == "tts":
            if self.tts_model:
                key = f"tts:{self.tts_model}"
                if key in self.models:
                    self.current_model_key = key
                    return self.models[key]
                return None
            return None
        
        # Check for specialized models with prefix
        if requested_model.startswith("audio:"):
            audio_name = requested_model[6:]  # Remove "audio:" prefix
            key = f"audio:{audio_name}"
            if key in self.models:
                self.current_model_key = key
                return self.models[key]
            elif audio_name in self.audio_models:
                # Try loading audio model on demand
                return None
            return None
        
        if requested_model.startswith("tts:"):
            tts_name = requested_model[4:]  # Remove "tts:" prefix
            key = f"tts:{tts_name}"
            if key in self.models:
                self.current_model_key = key
                return self.models[key]
            elif self.tts_model and tts_name == self.tts_model:
                return None
            return None
        
        # Handle both "vision:" and "image:" prefixes
        if requested_model.startswith("vision:") or requested_model.startswith("image:"):
            # Extract the model name (remove either prefix)
            if requested_model.startswith("vision:"):
                image_name = requested_model[7:]  # Remove "vision:" prefix
            else:
                image_name = requested_model[6:]  # Remove "image:" prefix
            key = f"image:{image_name}"
            if key in self.models:
                self.current_model_key = key
                return self.models[key]
            elif image_name in self.image_models:
                # Try loading image model on demand
                return None
            return None
        
        # Check if it's the default model
        if self.default_model and (requested_model == self.default_model or 
                                    requested_model.endswith(self.default_model.split("/")[-1])):
            self.current_model_key = self.default_model
            return self.models.get(self.default_model)
        
        # Check if any loaded model matches
        for key, model in self.models.items():
            if requested_model in key or key.endswith(requested_model.split("/")[-1]):
                self.current_model_key = key
                return model
        
        return None
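
    # Resolution examples (model names are hypothetical):
    #   ""              -> self.models[self.default_model]
    #   "audio"         -> self.models["audio:<first audio model>"]
    #   "vision:sdxl"   -> self.models["image:sdxl"]
    #   "tts:piper"     -> self.models["tts:piper"]
    #   "my-alias"      -> resolved via self.model_aliases first, then the rules above
    # A None return means the caller is expected to load the model on demand.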
    
    def add_model(self, key: str, manager: ModelManager):
        """Add a model manager for a specific key."""
        self.models[key] = manager
    
    def get_model(self, key: str) -> Optional[ModelManager]:
        """Get a model manager by key."""
        return self.models.get(key)
    
    def get_current_model(self) -> Optional[ModelManager]:
        """Get the currently active model."""
        if self.current_model_key:
            return self.models.get(self.current_model_key)
        if self.default_model:
            return self.models.get(self.default_model)
        return None
    
    def list_models(self) -> List[ModelInfo]:
        """List all available models."""
        models = []
        
        # Add default model(s)
        if self.default_model:
            model_id = self.default_model
            # Also add short name
            short_name = self.default_model.split("/")[-1] if "/" in self.default_model else self.default_model
            if short_name != self.default_model:
                models.append(ModelInfo(id=short_name))
            models.append(ModelInfo(id=model_id))
            models.append(ModelInfo(id="default"))
        
        # Add aliases for first/default models
        if self.audio_models:
            models.append(ModelInfo(id="audio"))  # Alias for first audio model
            # Add all audio models
            for audio_id in self.audio_models:
                models.append(ModelInfo(id=f"audio:{audio_id}"))
        
        # Add TTS models
        if self.tts_model:
            models.append(ModelInfo(id="tts"))  # Alias for TTS
            tts_id = f"tts:{self.tts_model}"
            models.append(ModelInfo(id=tts_id))
        
        # Add vision/image models
        if self.image_models:
            models.append(ModelInfo(id="vision"))  # Alias for first image model
            models.append(ModelInfo(id="image"))    # Alias for first image model
            # Add all image models - convert URLs to cached paths for display
            for image_id in self.image_models:
                # Check if image_id is a URL and try to get cached path
                display_id = image_id
                if image_id.startswith('http://') or image_id.startswith('https://'):
                    cached_path = get_cached_model_path(image_id)
                    if cached_path:
                        # Use the filename from cached path for display
                        display_id = os.path.basename(cached_path)
                models.append(ModelInfo(id=f"image:{display_id}"))
                models.append(ModelInfo(id=f"vision:{display_id}"))
        
        # Add loaded models that aren't in the above categories
        for key in self.models:
            # Skip if already added
            if key == self.default_model or key.startswith("audio:") or key.startswith("image:") or key.startswith("tts:"):
                continue
            # Skip short names (already added)
            if self.default_model and key == self.default_model.split("/")[-1]:
                continue
            models.append(ModelInfo(id=key))
        
        # Add custom model aliases
        for alias in self.model_aliases:
            models.append(ModelInfo(id=alias))
        
        return models if models else [ModelInfo(id="default")]
    
    def swap_model_to_vram(self, model_key: str) -> bool:
        """
        Swap a model from CPU RAM to GPU VRAM.
        Returns True if successful, False otherwise.
        Only applies in loadswap mode.
        """
        if self.load_mode != "loadswap":
            return True  # No swapping needed in other modes
        
        if self.active_in_vram == model_key:
            return True  # Already in VRAM
        
        # If another model is in VRAM, swap it to CPU first
        if self.active_in_vram and self.active_in_vram in self.models:
            self._swap_model_to_cpu(self.active_in_vram)
        
        # Swap the requested model to VRAM
        success = self._swap_model_to_vram(model_key)
        if success:
            self.active_in_vram = model_key
        return success
    
    def _swap_model_to_vram(self, model_key: str) -> bool:
        """Internal method to swap a specific model to VRAM."""
        # This would need backend-specific implementation
        # For now, we assume the model is already in memory
        print(f"SWAP: Moving model {model_key} to VRAM")
        return True
    
    def _swap_model_to_cpu(self, model_key: str) -> bool:
        """Internal method to swap a specific model to CPU RAM."""
        # This would need backend-specific implementation
        # For now, we just track that it's in CPU
        print(f"SWAP: Moving model {model_key} to CPU RAM")
        return True
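
    # A backend-specific implementation is still TODO. For the NVIDIA backend a
    # minimal sketch (assuming a HF-style torch module held in backend.model)
    # could look like:
    #
    #   import torch
    #   backend = self.models[model_key].backend
    #   backend.model.to("cpu")      # move weights out of VRAM
    #   torch.cuda.empty_cache()     # release the freed blocks
    #
    # with `.to("cuda")` for the reverse direction. GGUF/llama.cpp models cannot
    # be moved this way and would need to be reloaded instead.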
    
    def get_active_model_key(self) -> Optional[str]:
        """Get the currently active model key."""
        return self.current_model_key or self.default_model
    
    def cleanup(self):
        """Cleanup all models."""
        for model in self.models.values():
            model.cleanup()
        self.models.clear()
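
# Registration sketch (model IDs and alias are hypothetical): configuration is
# recorded up front, while the heavy loading happens later according to load_mode.
#
#   mm = MultiModelManager()
#   mm.set_load_mode("ondemand")
#   mm.set_default_model("Qwen/Qwen3-4B-Instruct-2507")
#   mm.set_audio_model("ggml-base.en.bin")
#   mm.set_model_alias("gpt-4", "Qwen/Qwen3-4B-Instruct-2507")
#   mm.add_model("Qwen/Qwen3-4B-Instruct-2507", ModelManager())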
# Global multi-model manager
multi_model_manager = MultiModelManager()
# Global model manager (for backward compatibility)
model_manager = ModelManager()

# Global args for access in endpoints
global_args = None

# Global system prompt (set via --system-prompt flag)
# None = don't inject, True = use default, string = use custom text
global_system_prompt = None

# Global debug flag
global_debug = False

# =============================================================================
# Queue Manager for Model Loading Notifications
# =============================================================================

class QueueManager:
    """
    Manages request queue for model loading notifications.
    When clients are waiting for a model to load, sends them progress updates.
    """
    
    def __init__(self):
        self.waiting_requests: Dict[str, float] = {}  # request_id -> start_time
        self.current_request_id: Optional[str] = None
        self.model_loading: bool = False
        self.model_name: Optional[str] = None
        self.lock = asyncio.Lock()
    
    async def add_waiting(self, request_id: str) -> None:
        """Add a request to the waiting queue."""
        async with self.lock:
            self.waiting_requests[request_id] = time.time()
    
    async def remove_waiting(self, request_id: str) -> None:
        """Remove a request from the waiting queue."""
        async with self.lock:
            self.waiting_requests.pop(request_id, None)
    
    async def start_processing(self, request_id: str, model_name: str = None) -> None:
        """Mark a request as now processing (model loaded)."""
        async with self.lock:
            self.waiting_requests.pop(request_id, None)
            self.current_request_id = request_id
            self.model_name = model_name
    
    async def finish_processing(self) -> None:
        """Mark current request as finished."""
        async with self.lock:
            self.current_request_id = None
    
    async def is_waiting(self, request_id: str) -> bool:
        """Check if a request is in the waiting queue."""
        async with self.lock:
            return request_id in self.waiting_requests
    
    async def get_wait_time(self, request_id: str) -> float:
        """Get how long a request has been waiting in seconds."""
        async with self.lock:
            if request_id in self.waiting_requests:
                return time.time() - self.waiting_requests[request_id]
            return 0.0
    
    async def get_queue_position(self, request_id: str) -> int:
        """Get the position of a request in the queue (1-based)."""
        async with self.lock:
            keys = list(self.waiting_requests.keys())
            try:
                return keys.index(request_id) + 1
            except ValueError:
                return 0
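
# Typical request flow (illustrative; runs inside an async endpoint with a
# request_id generated by the caller):
#
#   await queue_manager.add_waiting(request_id)
#   position = await queue_manager.get_queue_position(request_id)  # 1-based
#   # ...wait for the model to load...
#   await queue_manager.start_processing(request_id, model_name)
#   # ...run generation...
#   await queue_manager.finish_processing()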
# Global queue manager
queue_manager = QueueManager()
# =============================================================================
# FastAPI Application
# =============================================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup/shutdown."""
    # Startup
    yield
    # Shutdown
    multi_model_manager.cleanup()
    model_manager.cleanup()
    # Stop whisper-server if running
    if multi_model_manager.whisper_server:
        multi_model_manager.whisper_server.stop()
app = FastAPI(
    title="OpenAI-Compatible API",
    description="OpenAI-compatible API supporting NVIDIA (CUDA) and Vulkan backends",
    version="2.0.0",
    lifespan=lifespan,
)

# Add request logging middleware for debugging
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log all incoming requests for debugging."""
    if request.url.path in ["/v1/chat/completions", "/v1/completions"]:
        body = b""
        body_str = ""
        try:
            body = await request.body()
            body_str = body.decode('utf-8')
            
            # In debug mode, dump the full request
            if global_debug:
                print(f"\n{'='*80}")
                print(f"=== FULL DEBUG REQUEST ===")
                print(f"{'='*80}")
                print(f"Path: {request.url.path}")
                print(f"Method: {request.method}")
                print(f"Headers: {dict(request.headers)}")
                print(f"\n--- FULL BODY ({len(body)} bytes) ---")
                print(body_str)
                print(f"--- END FULL BODY ---")
                print(f"{'='*80}\n")
            else:
                print(f"\n{'='*60}")
                print(f"=== INCOMING REQUEST ===")
                print(f"{'='*60}")
                print(f"Path: {request.url.path}")
                print(f"Method: {request.method}")
                print(f"Headers: {dict(request.headers)}")
                print(f"\n--- RAW BODY ({len(body)} bytes) ---")
                # Print body with truncation for very large bodies
                if len(body_str) > 2000:
                    print(f"{body_str[:1000]}...\n... [truncated {len(body_str)-2000} chars] ...\n...{body_str[-1000:]}")
                else:
                    print(body_str)
                print(f"--- END RAW BODY ---")
            
            # Try to parse as JSON to see if it's valid
            try:
                parsed = json.loads(body_str)
                print(f"\n--- PARSED JSON STRUCTURE ---")
                print(f"Keys: {list(parsed.keys())}")
                if 'messages' in parsed and isinstance(parsed['messages'], list):
                    print(f"Number of messages: {len(parsed['messages'])}")
                    for i, msg in enumerate(parsed['messages']):
                        role = msg.get('role', 'unknown')
                        content_preview = str(msg.get('content', ''))[:50].replace('\n', ' ')
                        print(f"  [{i}] {role}: {content_preview}...")
                print(f"--- END PARSED JSON ---")
            except json.JSONDecodeError as e:
                print(f"\n*** JSON Parse Error: {e} ***")
                print(f"Error at position: char {e.pos}, line {e.lineno}, column {e.colno}")
            except Exception as e:
                print(f"\n*** Error analyzing JSON: {e} ***")
        except Exception as e:
            # Handle ClientDisconnect and other exceptions gracefully
            print(f"Error logging request: {e}")
            # Continue with empty body if we couldn't read it
            body = b""
        
        # Re-create request with body for downstream handlers (only if we successfully read it)
        if body:
            async def receive():
                return {"type": "http.request", "body": body}
            request = Request(request.scope, receive, request._send)
    
    try:
        response = await call_next(request)
        if request.url.path in ["/v1/chat/completions", "/v1/completions"]:
            print(f"\n--- RESPONSE ---")
            print(f"Status Code: {response.status_code}")
            
            # For error responses, try to read and log the body, then create a new response
            if response.status_code >= 400:
                try:
                    # Read the response body
                    response_body = b""
                    async for chunk in response.body_iterator:
                        response_body += chunk
                    
                    error_body = response_body.decode('utf-8')
                    print(f"Error Response Body: {error_body}")
                    
                    # Try to parse and pretty-print error details
                    try:
                        error_json = json.loads(error_body)
                        if 'detail' in error_json:
                            print(f"\n*** VALIDATION ERROR DETAILS ***")
                            detail = error_json['detail']
                            if isinstance(detail, list):
                                for err in detail:
                                    if isinstance(err, dict):
                                        loc = err.get('loc', [])
                                        msg = err.get('msg', 'Unknown error')
                                        err_type = err.get('type', 'unknown')
                                        print(f"  - Location: {loc}")
                                        print(f"    Message: {msg}")
                                        print(f"    Type: {err_type}")
                            else:
                                print(f"  Detail: {detail}")
                            print(f"*** END ERROR DETAILS ***")
                    except Exception as parse_err:
                        print(f"Could not parse error details: {parse_err}")
                    
                    # Create new response with the same body
                    from starlette.responses import Response
                    return Response(
                        content=response_body,
                        status_code=response.status_code,
                        headers=dict(response.headers),
                        media_type=response.media_type
                    )
                except Exception as read_err:
                    print(f"Could not read error response body: {read_err}")
            
            print(f"--- END RESPONSE ---")
            print(f"{'='*60}\n")
        return response
    except Exception as e:
        print(f"\n*** EXCEPTION DURING REQUEST PROCESSING ***")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        import traceback
        traceback.print_exc()
        print(f"*** END EXCEPTION ***\n")
        raise
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """List available models."""
    models = multi_model_manager.list_models()
    return ModelList(data=models)
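# Listing models is a plain GET; an illustrative call (host and port are assumptions):
#   curl -s http://localhost:8000/v1/models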
# =============================================================================
# Audio Transcription Endpoint
# =============================================================================

from fastapi import UploadFile, File, Form

@app.post("/v1/audio/transcriptions")
async def create_transcription(
    model: str = Form(...),
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    prompt: Optional[str] = Form(None),
    response_format: Optional[str] = Form("json"),
    temperature: Optional[float] = Form(0.0),
):
    """
    Audio transcription endpoint (OpenAI-compatible).
    
    Supports:
    - whispercpp / whisper.cpp CLI backends (built-in model names or pre-converted GGUF files)
    - Local faster-whisper models (when --audio-model is specified)
    - whisper.cpp server (when --whisper-server is specified)
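
    Example request (illustrative; host and port are assumptions, and the "model" field
    is accepted but ignored in favor of the configured transcription model):

        curl -s http://localhost:8000/v1/audio/transcriptions -F model=whisper-1 -F file=@sample.wav -F language=en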
    """
    # Check if whisper-server is available FIRST (before checking audio_model)
    print(f"DEBUG: Audio request - whisper_server available: {multi_model_manager.whisper_server is not None}, running: {multi_model_manager.whisper_server.is_running() if multi_model_manager.whisper_server else 'N/A'}")
    if multi_model_manager.whisper_server and multi_model_manager.whisper_server.is_running():
        # Use whisper-server - read file and send to server
        file_content = await file.read()
        print(f"DEBUG: whisper-server transcription request - file_size={len(file_content)}, language={language}, prompt={prompt}")
        result = multi_model_manager.whisper_server.transcribe(
            file_content,
            language=language,
            prompt=prompt
        )
        print(f"DEBUG: whisper-server transcription result: {result}")
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])
        # Convert whisper-server response to OpenAI format
        text = result.get("text", "")
        return {
            "text": text
        }
    
    audio_model = multi_model_manager.audio_model
    
    # DEBUG: Print audio model status
    print(f"DEBUG: audio_model check - audio_models list: {multi_model_manager.audio_models}, audio_model: {audio_model}, whisper_server: {multi_model_manager.whisper_server}")
    
    # If no audio model configured, return an error
    if not audio_model:
        raise HTTPException(
            status_code=400,
            detail="Audio transcription not configured. Use --audio-model or --whisper-server to specify a model."
        )
    
    # Determine model to use - always use the configured audio model
    # The model parameter from the request is ignored in favor of the configured transcription model
    actual_model = audio_model  # This is the configured transcription model from --audio-model
    model_to_use = actual_model
    
    print(f"DEBUG: Transcription request model: {model}, using configured model: {actual_model}")
    
    # Check if Vulkan is available for whispercpp
    whisper_vulkan_available = False
    try:
        # Check if whispercpp is installed and has Vulkan support
        import whispercpp
        # Try to detect Vulkan support by checking if we can list devices
        # whispercpp doesn't have a direct Vulkan check, but we can verify by environment
        if os.environ.get('VK_DEVICE_SELECT_DEVICE'):
            whisper_vulkan_available = True
            print(f"Whisper Vulkan: Using configured Vulkan device")
        elif os.path.exists('/dev/dri'):  # Linux DRM devices exist = AMD/Intel GPU
            whisper_vulkan_available = True
            print(f"Whisper Vulkan: Auto-detected GPU")
    except ImportError:
        pass
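    # VK_DEVICE_SELECT_DEVICE is read straight from the environment; an illustrative
    # launch (the script name here is assumed) would be:
    #   VK_DEVICE_SELECT_DEVICE=0 python api_server.py --audio-model base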
    
    # Read file content
    file_content = await file.read()
    
    # Write to temp file
    import tempfile
    
    with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file.filename}") as tmp:
        tmp.write(file_content)
        tmp_path = tmp.name
    
    try:
        # Check if model is a GGUF file - faster-whisper doesn't support GGUF format
        is_gguf_model = model_to_use.endswith('.gguf') or 'gguf' in model_to_use.lower()
        
        if is_gguf_model:
            # Skip faster-whisper for GGUF files - go directly to whispercpp
            print("Detected GGUF model - using whispercpp backend")
            faster_whisper_failed = True
        else:
            # Try faster-whisper first
            faster_whisper_failed = False
            try:
                from faster_whisper import WhisperModel
                
                # Determine compute type based on GPU availability
                import torch
                if torch.cuda.is_available():
                    compute_type = "float16"
                else:
                    compute_type = "int8"
                
                # Try to load the model (lazy loading)
                model_key = f"audio:{model_to_use}"
                whisper_model = multi_model_manager.get_model(model_key)
                
                if whisper_model is None:
                    print(f"Loading faster-whisper model: {model_to_use}")
                    
                    # Check if model_to_use is a URL - download it (with caching)
                    model_path = None
                    if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                        # Check cache first
                        cached_path = get_cached_model_path(model_to_use)
                        if cached_path:
                            model_to_use = cached_path
                            print(f"Using cached model: {model_to_use}")
                        else:
                            print(f"Downloading model from URL: {model_to_use}")
                            try:
                                import requests
                                import hashlib
                                
                                # Get cache directory
                                cache_dir = get_model_cache_dir()
                                
                                # Extract filename from URL
                                url_path = model_to_use.split('?')[0]
                                filename = os.path.basename(url_path)
                                
                                if not filename.endswith('.bin') and not filename.endswith('.ggml'):
                                    filename = "whisper-model.bin"
                                
                                # Create safe filename in cache
                                url_hash = hashlib.sha256(model_to_use.encode()).hexdigest()
                                cached_filename = f"{url_hash}_{filename}"
                                model_path = os.path.join(cache_dir, cached_filename)
                                
                                # Download to cache
                                response = requests.get(model_to_use, stream=True)
                                response.raise_for_status()
                                
                                total_size = int(response.headers.get('content-length', 0))
                                downloaded = 0
                                
                                with open(model_path, 'wb') as f:
                                    for chunk in response.iter_content(chunk_size=8192*1024):
                                        if chunk:
                                            f.write(chunk)
                                            downloaded += len(chunk)
                                            if total_size > 0:
                                                percent = (downloaded / total_size) * 100
                                                print(f"Downloaded: {percent:.1f}%", end='\r')
                                
                                print(f"\nDownloaded and cached to: {model_path}")
                                model_to_use = model_path
                                
                            except Exception as e:
                                print(f"Error downloading model: {e}")
                                raise
                    
                    whisper_model = WhisperModel(
                        model_to_use,
                        device="cpu",  # faster-whisper CUDA doesn't work with AMD/Vulkan
                        compute_type=compute_type
                    )
                    # Store in multi_model_manager
                    multi_model_manager.add_model(model_key, whisper_model)
                
                # Run transcription
                segments, info = whisper_model.transcribe(
                    tmp_path,
                    language=language,
                    initial_prompt=prompt,
                    temperature=temperature or 0.0,
                )
                
                # Collect all segments
                text_parts = []
                for segment in segments:
                    text_parts.append(segment.text.strip())
                
                full_text = " ".join(text_parts)
                
                return {"text": full_text}
            
            except ImportError:
                # faster-whisper not available, will try whispercpp below
                faster_whisper_failed = True
            except Exception as e:
                # faster-whisper failed for some other reason
                print(f"Warning: faster-whisper failed to load model: {e}")
                faster_whisper_failed = True
        
        # If faster-whisper failed (not installed or couldn't load), try whispercpp
        if faster_whisper_failed:
            try:
                import whispercpp
                
                # Try to load the model (lazy loading)
                model_key = f"audio:{model_to_use}"
                whisper_model = multi_model_manager.get_model(model_key)
                
                if whisper_model is None:
                    print(f"Loading whispercpp model: {model_to_use}")
                    if whisper_vulkan_available:
                        print(f"  -> Using Vulkan GPU acceleration (device {whisper_vulkan_device})")
                    
                    # Check if model_to_use is a URL - download it (with caching)
                    model_path = None
                    if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                        # Check cache first
                        cached_path = get_cached_model_path(model_to_use)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached model: {model_path}")
                        else:
                            print(f"Downloading model from URL: {model_to_use}")
                            try:
                                import requests
                                import hashlib
                                
                                # Get cache directory
                                cache_dir = get_model_cache_dir()
                                
                                # Extract filename from URL
                                url_path = model_to_use.split('?')[0]
                                filename = os.path.basename(url_path)
                                
                                if not filename.endswith('.gguf'):
                                    filename = "whisper-model.gguf"
                                
                                # Create safe filename in cache
                                url_hash = hashlib.sha256(model_to_use.encode()).hexdigest()
                                cached_filename = f"{url_hash}_{filename}"
                                model_path = os.path.join(cache_dir, cached_filename)
                                
                                # Download to cache
                                response = requests.get(model_to_use, stream=True)
                                response.raise_for_status()
                                
                                total_size = int(response.headers.get('content-length', 0))
                                downloaded = 0
                                
                                with open(model_path, 'wb') as f:
                                    for chunk in response.iter_content(chunk_size=8192*1024):
                                        if chunk:
                                            f.write(chunk)
                                            downloaded += len(chunk)
                                            if total_size > 0:
                                                percent = (downloaded / total_size) * 100
                                                print(f"Downloaded: {percent:.1f}%", end='\r')
                                
                                print(f"\nDownloaded and cached to: {model_path}")
                                model_to_use = model_path
                                
                            except Exception as e:
                                print(f"Error downloading model: {e}")
                                raise
                    
                    # whispercpp needs a local file path or one of its built-in model names
                    builtin_names = {"tiny", "base", "small", "medium", "large-v1", "large"}
                    if not model_path and (os.path.isfile(model_to_use) or model_to_use in builtin_names):
                        model_path = model_to_use
                    
                    if not model_path:
                        raise HTTPException(
                            status_code=400,
                            detail="whispercpp requires a local GGUF file path or a built-in model name (tiny/base/small/medium/large-v1/large)."
                        )
                    
                    # Load the whispercpp model
                    # Note: whispercpp uses model files directly, not paths like Llama
                    # whispercpp only supports:
                    # 1. Built-in model names (tiny, base, small, medium, large-v1, large)
                    # 2. Pre-converted GGUF files in whisper.cpp format (NOT HuggingFace GGUF)
                    try:
                        whisper_model = whispercpp.Whisper.from_pretrained(model_path)
                    except Exception as e:
                        error_msg = str(e).lower()
                        if 'not a valid preconverted model' in error_msg:
                            # This is expected for HuggingFace GGUF files
                            print(f"Warning: whispercpp does not support HuggingFace GGUF format")
                            print("whispercpp only supports its own pre-converted models or built-in names.")
                            print("For Vulkan audio transcription, please either:")
                            print("  1. Install PyTorch + faster-whisper: pip install torch faster-whisper")
                            print("  2. Use a built-in whispercpp model: --audio-model base")
                            raise HTTPException(
                                status_code=400,
                                detail="whispercpp does not support HuggingFace GGUF Whisper models. Use --audio-model with a built-in name (tiny/base/small/medium/large-v1/large) or install faster-whisper with PyTorch."
                            )
                        else:
                            raise
                    
                    # Store in multi_model_manager
                    multi_model_manager.add_model(model_key, whisper_model)
                
                # Run transcription
                # whispercpp returns text directly
                result = whisper_model.transcribe(tmp_path)
                
                # Collect all segments
                text_parts = []
                for segment in result:
                    text_parts.append(str(segment).strip())
                
                full_text = " ".join(text_parts) if text_parts else ""
                
                return {"text": full_text}
            
            except ImportError as e:
                # Check if it's a specific error about whispercpp not working
                error_msg = str(e).lower()
                if 'invalid elf' in error_msg or 'mach-o' in error_msg:
                    # whispercpp library failed to load - architecture mismatch
                    print(f"Warning: whispercpp library failed to load: {e}")
                    print("This usually means whispercpp was installed for a different OS/architecture.")
                    print("Try reinstalling: pip install whispercpp --force-reinstall")
                    print("Audio model will load on-demand when transcription is requested.")
                else:
                    # Neither faster-whisper nor whispercpp available
                    print(f"Warning: No audio transcription library available: {e}")
                    print("Options:")
                    print("  1. Install PyTorch + faster-whisper: pip install torch faster-whisper")
                    print("  2. Use a built-in whispercpp model: --audio-model base")
                    print("  3. Use --whisper-cpp to specify whisper.cpp CLI path")
                    print("Audio model will load on-demand when transcription is requested.")
                
                # Try whisper.cpp CLI as fallback if specified
                whisper_cpp_path = getattr(global_args, 'whisper_cpp', None)
                if whisper_cpp_path and os.path.isfile(whisper_cpp_path):
                    print(f"Using whisper.cpp CLI: {whisper_cpp_path}")
                    try:
                        import subprocess
                        
                        # Determine the model path - check if it's already cached or needs downloading
                        model_path = None
                        
                        # First check if it's a local file
                        if os.path.isfile(model_to_use):
                            model_path = model_to_use
                            print(f"DEBUG: Using local model file: {model_path}")
                        else:
                            # Check cache for downloaded model
                            cached = get_cached_model_path(model_to_use)
                            if cached and os.path.isfile(cached):
                                model_path = cached
                                print(f"DEBUG: Using cached model: {model_path}")
                            else:
                                # Download the model if not cached
                                print(f"DEBUG: Model not cached, downloading: {model_to_use}")
                                cache_dir = get_model_cache_dir()
                                model_path = download_model(model_to_use, cache_dir)
                                print(f"DEBUG: Downloaded model to: {model_path}")
                        
                        print(f"DEBUG: Whisper model: {model_to_use}")
                        print(f"DEBUG: Whisper model path (resolved): {model_path}")
                        
                        # Build whisper.cpp CLI command
                        # Usage: whisper-cli [options] file0 file1 ...
                        # Options:
                        #   -m, --model FNAME    model path
                        #   -f, --file FNAME    input audio file
                        #   -of, --output-file  output file (without extension)
                        #   -dev, --device N   GPU device ID
                        #   -otxt, --output-txt output as text file
                        cmd = [whisper_cpp_path]
                        if model_path:
                            cmd.extend(["-m", model_path])
                        cmd.extend(["-f", tmp_path])
                        cmd.extend(["-otxt"])  # Output as text
                        
                        # Add Vulkan device if specified
                        audio_vulkan_device = getattr(global_args, 'audio_vulkan_device', 0)
                        if audio_vulkan_device is not None:
                            cmd.extend(["-dev", str(audio_vulkan_device)])
                        
                        print(f"DEBUG: Running whisper.cpp command: {' '.join(cmd)}")
                        
                        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
                        
                        if result.returncode == 0:
                            # Read output - whisper.cpp -otxt outputs to stdout or a file
                            # With -otxt flag, it outputs the transcribed text
                            if result.stdout:
                                full_text = result.stdout
                            else:
                                # Try to read from a file with same name as input but .txt extension
                                output_txt = tmp_path + ".txt"
                                if os.path.exists(output_txt):
                                    with open(output_txt, 'r') as f:
                                        full_text = f.read()
                                    os.unlink(output_txt)
                                else:
                                    full_text = ""
                            return {"text": full_text}
                        else:
                            print(f"whisper.cpp CLI error: {result.stderr}")
                    except Exception as subprocess_error:
                        print(f"whisper.cpp CLI subprocess error: {subprocess_error}")
                
                # Return error response
                raise HTTPException(
                    status_code=501,
                    detail="Audio transcription not available. Install faster-whisper (requires PyTorch) or use --whisper-cpp to specify whisper.cpp CLI path."
                )
        
    finally:
        # Cleanup temp file
        os.unlink(tmp_path)
# =============================================================================
# Image Generation Endpoint
# =============================================================================

# Global load_mode tracker - will be set in main()
def get_load_mode():
    return load_mode.get("mode", "ondemand")

@app.post("/v1/images/generations")
async def create_image_generation(request: ImageGenerationRequest):
    """
    Image generation endpoint (OpenAI-compatible).
    
    Supports:
    - Stable Diffusion via stable-diffusion-cpp-python (sd.cpp)
    - Stable Diffusion XL (via local inference with diffusers)
    - Other diffusers models
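
    Example request (illustrative; host, port and the model name are assumptions; the
    payload fields mirror ImageGenerationRequest):

        curl -s http://localhost:8000/v1/images/generations -H 'Content-Type: application/json' -d '{"model": "image:default", "prompt": "a red fox in the snow", "size": "512x512", "n": 1, "response_format": "b64_json"}'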
    """
    # Get or create semaphore for this model
    model_key = f"image:{request.model}" if request.model else "image"
    mode = get_load_mode()
    
    # Check if --image-1 is set (no queue, return 409 if busy)
    use_1_mode = queue_flags.get("image_1", False)
    
    # In loadall mode, allow 1 concurrent request per model
    # In ondemand mode, serialize all requests (use global semaphore)
    if mode == "loadall":
        if model_key not in model_semaphores:
            model_semaphores[model_key] = asyncio.Semaphore(1)
        semaphore = model_semaphores[model_key]
    else:
        # Use a global semaphore for ondemand mode
        if "global_image" not in model_semaphores:
            model_semaphores["global_image"] = asyncio.Semaphore(1)
        semaphore = model_semaphores["global_image"]
    
    # When --image-1 is set, fail fast with 409 instead of queueing behind a busy model
    if use_1_mode and semaphore.locked():
        raise HTTPException(
            status_code=409,
            detail="Image model is busy. Try again later."
        )
    
    # Resolve the configured image model while holding the semaphore
    async with semaphore:
        image_model = multi_model_manager.image_model
    
    # If no image model configured, return an error
    if not image_model:
        raise HTTPException(
            status_code=400,
            detail="Image generation not configured. Use --image-model to specify a model."
        )
    
    # Determine model to use
    model_to_use = request.model
    if model_to_use.startswith("image:"):
        model_to_use = image_model
    
    # Track errors for proper fallback chain
    diffusers_error = None
    sd_cpp_error = None
    
    # Parse size (e.g., "1024x1024")
    width, height = 1024, 1024
    if request.size:
        parts = request.size.split("x")
        if len(parts) == 2:
            try:
                width = int(parts[0])
                height = int(parts[1])
            except ValueError:
                pass
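    # e.g. size="768x512" -> width=768, height=512; anything unparseable keeps the 1024x1024 default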
    
    # Try diffusers first (torch-based, best quality for NVIDIA)
    try:
        import torch
        from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
        
        # Determine model key
        model_key = f"image:{model_to_use}"
        pipeline = multi_model_manager.get_model(model_key)
        
        if pipeline is None:
            print(f"Loading Stable Diffusion model: {model_to_use}")
            
            # Try to load as Stable Diffusion XL first
            try:
                pipeline = StableDiffusionXLPipeline.from_pretrained(
                    model_to_use,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    use_safetensors=True,
                )
            except Exception:
                # Try generic diffusion pipeline
                pipeline = DiffusionPipeline.from_pretrained(
                    model_to_use,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    use_safetensors=True,
                )
            
            # Move to GPU if available
            if torch.cuda.is_available():
                pipeline = pipeline.to("cuda")
            else:
                pipeline = pipeline.to("cpu")
            
            # Enable attention slicing for lower memory usage
            if torch.cuda.is_available():
                pipeline.enable_attention_slicing()
            
            multi_model_manager.add_model(model_key, pipeline)
        
        # Generate images
        generator = None
        if request.seed is not None:
            generator = torch.Generator(device=pipeline.device).manual_seed(request.seed)
        
        # Quality: "standard" or "hd"
        quality = request.quality or "standard"
        
        # Generate
        result = pipeline(
            prompt=request.prompt,
            negative_prompt=None,
            num_images_per_prompt=request.n,
            height=height,
            width=width,
            generator=generator,
            guidance_scale=7.5 if quality == "standard" else 9.0,
            num_inference_steps=30 if quality == "standard" else 50,
        )
        
        # Extract images
        images = []
        for img in result.images:
            # Convert to base64
            import base64
            import io
            
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            img_bytes = buffered.getvalue()
            img_base64 = base64.b64encode(img_bytes).decode('utf-8')
            
            if request.response_format == "base64":
                images.append({"b64_json": img_base64})
            else:
                # For URL format, we'd need to save somewhere
                # For now, return base64
                images.append({"b64_json": img_base64})
        
        return {
            "created": int(time.time()),
            "data": images
        }
        
    except ImportError as e:
        # diffusers/torch not installed - record error and try sd.cpp
        diffusers_error = str(e)
        print(f"diffusers not available: {diffusers_error}, trying stable-diffusion-cpp-python...")
    except Exception as e:
        # Other error with diffusers - record and try sd.cpp
        diffusers_error = str(e)
        print(f"diffusers error: {diffusers_error}, trying stable-diffusion-cpp-python...")
    
    # Try stable-diffusion-cpp-python (sd.cpp) as fallback
    # First, check all available image models to find one loaded via sd.cpp
    sd_model = None
    for key in multi_model_manager.models:
        if key.startswith("image:"):
            potential_model = multi_model_manager.get_model(key)
            if potential_model is not None:
                # Check if it's a stable-diffusion-cpp model
                try:
                    from stable_diffusion_cpp import StableDiffusion
                    if isinstance(potential_model, StableDiffusion):
                        sd_model = potential_model
                        print(f"Found stable-diffusion-cpp model with key: {key}")
                        break
                except ImportError:
                    pass
    
    if sd_model is not None:
        # Check if it's a stable-diffusion-cpp model (has generate method from sd.cpp)
        try:
            from stable_diffusion_cpp import StableDiffusion
            if isinstance(sd_model, StableDiffusion):
                print(f"Using stable-diffusion-cpp-python for image generation")
                # Use sd.cpp for generation
                # Parse size
                width, height = 512, 512
                if request.size:
                    parts = request.size.split("x")
                    if len(parts) == 2:
                        try:
                            width = int(parts[0])
                            height = int(parts[1])
                        except ValueError:
                            pass
                
                # Use default steps for Z-Image Turbo (very fast)
                steps = 4  # Default for fast generation
                
                # Generate images using sd.cpp (run in thread to not block event loop)
                result = await asyncio.to_thread(
                    sd_model.generate_image,
                    prompt=request.prompt,
                    negative_prompt='',
                    width=width,
                    height=height,
                    cfg_scale=7.0,
                    sample_steps=steps,
                    seed=42,
                    batch_count=request.n if request.n else 1,
                )
                
                # Convert results to response format
                images = []
                import base64
                import io
                from PIL import Image
                
                for img in result:
                    # Convert to base64
                    buffered = io.BytesIO()
                    if isinstance(img, Image.Image):
                        img.save(buffered, format="PNG")
                    else:
                        # Might be numpy array
                        Image.fromarray(img).save(buffered, format="PNG")
                    img_bytes = buffered.getvalue()
                    img_base64 = base64.b64encode(img_bytes).decode('utf-8')
                    images.append({"b64_json": img_base64})
                
                return {
                    "created": int(time.time()),
                    "data": images
                }
        except ImportError as e:
            # stable-diffusion-cpp not available
            sd_cpp_error = str(e)
            print(f"stable-diffusion-cpp-python not available: {sd_cpp_error}")
        except Exception as e:
            print(f"sd.cpp generation error: {e}")
            sd_cpp_error = str(e)
    else:
        # No sd.cpp model pre-loaded, try to load dynamically
        print("No pre-loaded sd.cpp model found, trying to load...")
        try:
            from stable_diffusion_cpp import StableDiffusion
            
            # Check if model_to_use is a URL and get cached path
            model_path = None
            if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                cached_path = get_cached_model_path(model_to_use)
                if cached_path:
                    model_path = cached_path
                    print(f"Using cached model: {model_path}")
            
            if model_path is None and os.path.isfile(model_to_use):
                model_path = model_to_use
            
            if model_path is None:
                print("Warning: Could not resolve sd.cpp model path")
                sd_cpp_error = "Could not resolve model path"
            else:
                # Load sd.cpp model
                sd_model = StableDiffusion(
                    model_path=model_path,
                    vae_path=None,
                    n_threads=4,
                    n_gpu_layers=-1,  # All layers to GPU
                )
                
                print(f"Using stable-diffusion-cpp-python for image generation")
                
                # Generate images
                width, height = 512, 512
                if request.size:
                    parts = request.size.split("x")
                    if len(parts) == 2:
                        try:
                            width = int(parts[0])
                            height = int(parts[1])
                        except ValueError:
                            pass
                
                steps = 4
                
                result = await asyncio.to_thread(
                    sd_model.generate_image,
                    prompt=request.prompt,
                    negative_prompt='',
                    width=width,
                    height=height,
                    cfg_scale=7.0,
                    sample_steps=steps,
                    seed=42,
                    batch_count=request.n if request.n else 1,
                )
                
                # Convert results to response format
                images = []
                import base64
                import io
                from PIL import Image
                
                for img in result:
                    buffered = io.BytesIO()
                    if isinstance(img, Image.Image):
                        img.save(buffered, format="PNG")
                    else:
                        Image.fromarray(img).save(buffered, format="PNG")
                    img_bytes = buffered.getvalue()
                    img_base64 = base64.b64encode(img_bytes).decode('utf-8')
                    images.append({"b64_json": img_base64})
                
                return {
                    "created": int(time.time()),
                    "data": images
                }
        except ImportError as e:
            sd_cpp_error = str(e)
            print(f"stable-diffusion-cpp-python not available: {sd_cpp_error}")
        except Exception as e:
            sd_cpp_error = str(e)
            print(f"sd.cpp error: {sd_cpp_error}")
    
    # Both backends failed - return error with installation instructions
    error_details = []
    if diffusers_error:
        error_details.append(f"diffusers: {diffusers_error}")
    if sd_cpp_error:
        error_details.append(f"sd.cpp: {sd_cpp_error}")
    
    raise HTTPException(
        status_code=501,
        detail=f"Image generation not available. Tried: {', '.join(error_details)}. "
               f"Install either: pip install diffusers torch accelerate safetensors (for NVIDIA) "
               f"or: pip install stable-diffusion-cpp-python (for Vulkan/AMD)"
    )
# =============================================================================
# Text-to-Speech Endpoint
# =============================================================================

class TTSRequest(BaseModel):
    model: str
    input: str
    voice: Optional[str] = "af_sarah"
    response_format: Optional[str] = "mp3"
    speed: Optional[float] = 1.0
    
    model_config = ConfigDict(extra="allow")
class TTSResponse(BaseModel):
    audio: str  # base64 encoded audio
    model_config = ConfigDict(extra="allow")
@app.post("/v1/audio/speech")
async def create_speech(request: TTSRequest):
    """
    Text-to-speech endpoint (OpenAI-compatible).
    
    Supports:
    - Kokoro TTS models (when --tts-model is specified)
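
    Example request (illustrative; host, port and the model name are assumptions). The
    response is JSON with a base64-encoded "audio" field:

        curl -s http://localhost:8000/v1/audio/speech -H 'Content-Type: application/json' -d '{"model": "tts:kokoro", "input": "Hello there", "voice": "af_sarah", "speed": 1.0}'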
    """
    tts_model = multi_model_manager.tts_model
    
    # If no TTS model configured, return an error
    if not tts_model:
        raise HTTPException(
            status_code=400,
            detail="TTS not configured. Use --tts-model to specify a model."
        )
    
    # Determine model to use
    model_to_use = request.model
    if model_to_use.startswith("tts:"):
        model_to_use = tts_model
    
    # Try to use kokoro if available
    try:
        from kokoro import Kokoro
        
        # Determine model key
        model_key = f"tts:{model_to_use}"
        kokoro_model = multi_model_manager.get_model(model_key)
        
        if kokoro_model is None:
            print(f"Loading Kokoro TTS model: {model_to_use}")
            
            # Check if model_to_use is a URL - download it (with caching)
            model_path = None
            if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                # Check cache first
                cached_path = get_cached_model_path(model_to_use)
                if cached_path:
                    model_path = cached_path
                    print(f"Using cached model: {model_path}")
                else:
                    print(f"Downloading model from URL: {model_to_use}")
                    try:
                        import requests
                        import hashlib
                        
                        # Get cache directory
                        cache_dir = get_model_cache_dir()
                        
                        # Extract filename from URL
                        url_path = model_to_use.split('?')[0]
                        filename = os.path.basename(url_path)
                        
                        if not filename.endswith('.pt') and not filename.endswith('.bin'):
                            filename = "kokoro-model.pt"
                        
                        # Create safe filename in cache
                        url_hash = hashlib.sha256(model_to_use.encode()).hexdigest()
                        cached_filename = f"{url_hash}_{filename}"
                        model_path = os.path.join(cache_dir, cached_filename)
                        
                        # Download to cache
                        response = requests.get(model_to_use, stream=True)
                        response.raise_for_status()
                        
                        total_size = int(response.headers.get('content-length', 0))
                        downloaded = 0
                        
                        with open(model_path, 'wb') as f:
                            for chunk in response.iter_content(chunk_size=8192*1024):
                                if chunk:
                                    f.write(chunk)
                                    downloaded += len(chunk)
                                    if total_size > 0:
                                        percent = (downloaded / total_size) * 100
                                        print(f"Downloaded: {percent:.1f}%", end='\r')
                        
                        print(f"\nDownloaded and cached to: {model_path}")
                        
                    except Exception as e:
                        print(f"Error downloading model: {e}")
                        raise
            else:
                # Use local path or model name
                model_path = model_to_use
            
            # Load the Kokoro model
            kokoro_model = Kokoro(model_path if model_path else model_to_use)
            multi_model_manager.add_model(model_key, kokoro_model)
        
        # Generate speech
        voice = request.voice or "af_sarah"
        speed = request.speed or 1.0
        
        audio_bytes = kokoro_model.generate(request.input, voice=voice, speed=speed)
        
        # Convert to base64
        import base64
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
        
        return {
            "audio": audio_base64
        }
        
    except ImportError as e:
        # kokoro not installed
        raise HTTPException(
            status_code=501,
            detail=f"TTS not available. Install kokoro: pip install kokoro. Error: {str(e)}"
        )
    except Exception as e:
        print(f"TTS error: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Chat completions endpoint with streaming and tool support."""
    # Get the model for this request
    requested_model = request.model
    
    # Try to get the appropriate model
    mm = multi_model_manager.get_model_for_request(requested_model)
    
    if mm is None:
        # Model not loaded - try to use default
        if model_manager.backend is not None:
            # Fallback to legacy model_manager
            current_manager = model_manager
        else:
            raise HTTPException(status_code=503, detail="Model not loaded")
    else:
        current_manager = mm
    
    # Inject system prompt if --system-prompt flag was provided
    messages = request.messages
    if global_system_prompt is not None:
        # Check if there's already a system message
        has_system = any(msg.role == "system" for msg in messages)
        if not has_system:
            # Use default or custom system prompt
            if global_system_prompt is True:
                # Default system prompt
                system_text = "You are a helpful assistant."
            else:
                # Custom system prompt provided as argument
                system_text = str(global_system_prompt)
            # Insert system message at the beginning
            messages = [ChatMessage(role="system", content=system_text)] + list(messages)
    
    # Format messages with tools if provided
    if request.tools:
        messages = format_tools_for_prompt(request.tools, messages)
    
    # Get the tool_parser from the current manager
    tool_parser = current_manager.tool_parser if hasattr(current_manager, 'tool_parser') else ToolCallParser()
    
    # Prepare stop sequences
    stop_sequences = []
    if request.stop:
        if isinstance(request.stop, str):
            stop_sequences = [request.stop]
        else:
            stop_sequences = request.stop
    
    # Convert messages to dict format for chat completion
    messages_dict = []
    for msg in messages:
        msg_dict = {"role": msg.role}
        # Always include content key - llama_cpp template expects it
        # Convert content to string if it's a list (multipart content)
        content = msg.content
        if content is None:
            content = ""
        elif isinstance(content, list):
            # Handle multipart content array format: [{"type": "text", "text": "..."}]
            parts = []
            for item in content:
                if isinstance(item, dict):
                    if item.get('type') == 'text' and 'text' in item:
                        parts.append(item['text'])
                    else:
                        parts.append(f"[{item.get('type', 'unknown')} content]")
                else:
                    parts.append(str(item))
            content = '\n'.join(parts)
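            # e.g. [{"type": "text", "text": "hi"}, {"type": "image_url", ...}] becomes "hi\n[image_url content]"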
        # Ensure content is never None - convert to string
        msg_dict["content"] = str(content) if content is not None else ""
        # Handle tool_calls - convert to proper format if present
        if msg.tool_calls:
            # tool_calls should be a list of dicts with 'id', 'type', 'function' keys
            msg_dict["tool_calls"] = msg.tool_calls
        if msg.name:
            msg_dict["name"] = msg.name
        if msg.tool_call_id:
            msg_dict["tool_call_id"] = msg.tool_call_id
        messages_dict.append(msg_dict)
    
    # Final safety check: ensure NO message has None content before passing to llama_cpp
    # Also ensure content key always exists (not just None check)
    for i, m in enumerate(messages_dict):
        # Handle missing content key entirely
        if "content" not in m:
            messages_dict[i]["content"] = ""
        # Handle None content
        elif m.get("content") is None:
            messages_dict[i]["content"] = ""
        # Handle content that's not a string (shouldn't happen but be safe)
        elif not isinstance(m["content"], str):
            messages_dict[i]["content"] = str(m["content"])
    
    # Debug: print first few messages to see their structure
    print(f"DEBUG: messages_dict[0] keys: {list(messages_dict[0].keys()) if messages_dict else 'empty'}")
    if len(messages_dict) > 1:
        print(f"DEBUG: messages_dict[1] keys: {list(messages_dict[1].keys()) if len(messages_dict) > 1 else 'empty'}")
    
    # Convert tools to dict format if present
    tools_dict = None
    if request.tools:
        tools_dict = []
        for tool in request.tools:
            tools_dict.append({
                "type": tool.type,
                "function": {
                    "name": tool.function.name,
                    "description": tool.function.description,
                    "parameters": tool.function.parameters
                }
            })
    
    if request.stream:
        return StreamingResponse(
            stream_chat_response(
                messages_dict,
                request.model,
                request.max_tokens,
                request.temperature,
                request.top_p,
                stop_sequences,
                tools_dict,
                current_manager,
                tool_parser,
            ),
            media_type="text/event-stream",
        )
    else:
        return await generate_chat_response(
            messages_dict,
            request.model,
            request.max_tokens,
            request.temperature,
            request.top_p,
            stop_sequences,
            tools_dict,
            current_manager,
            tool_parser,
        )

async def stream_chat_response(
    messages: List[Dict],
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Dict]],
    current_manager: ModelManager,
    tool_parser: ToolCallParser,
) -> AsyncGenerator[str, None]:
    """Stream chat completion response with queue notifications."""
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    request_id = f"req-{uuid.uuid4().hex[:8]}"
    
    generated_text = ""
    print(f"DEBUG: stream_chat_response started, stream=True, tools={tools is not None}")
    
    # Check if model is loaded - if not, notify waiting clients
    # The model manager exists but backend may not be loaded yet in on-demand mode
    model_loaded = False
    if current_manager is not None:
        if hasattr(current_manager, 'backend') and current_manager.backend is not None:
            # Check if backend has the model loaded
            if hasattr(current_manager.backend, 'model') and current_manager.backend.model is not None:
                model_loaded = True
        elif hasattr(current_manager, 'model') and current_manager.model is not None:
            # Alternative check for some model managers
            model_loaded = True
    
    # If model not loaded, add to queue and send waiting notifications
    if not model_loaded:
        await queue_manager.add_waiting(request_id)
        wait_interval = 2.0  # Send waiting update every 2 seconds
        last_wait_update = time.time()
        
        # Send initial waiting message
        data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": {"content": "Waiting for model to load..."},
                "finish_reason": None,
            }],
            "x_queue_info": {
                "status": "waiting",
                "message": "Model is loading, please wait...",
            },
        }
        yield f"data: {json.dumps(data)}\n\n"
        
        # Keep sending wait updates until model is loaded
        # In a real implementation, this would check a loading status
        # For now, we'll send a few updates then proceed
        max_wait_updates = 5
        wait_count = 0
        while wait_count < max_wait_updates:
            await asyncio.sleep(wait_interval)
            wait_time = await queue_manager.get_wait_time(request_id)
            wait_count += 1
            
            queue_pos = await queue_manager.get_queue_position(request_id)
            
            data = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_name,
                "choices": [{
                    "index": 0,
                    "delta": {"content": f""},
                    "finish_reason": None,
                }],
                "x_queue_info": {
                    "status": "waiting",
                    "message": f"Waiting for model... ({int(wait_time)}s)",
                    "queue_position": queue_pos,
                    "wait_time_seconds": int(wait_time),
                },
            }
            yield f"data: {json.dumps(data)}\n\n"
    
    # Mark as starting processing
    await queue_manager.start_processing(request_id, model_name)
    
    # Send "Model starting" message
    data = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model_name,
        "choices": [{
            "index": 0,
            "delta": {"content": ""},
            "finish_reason": None,
        }],
        "x_queue_info": {
            "status": "starting",
            "message": "Model starting",
        },
    }
    yield f"data: {json.dumps(data)}\n\n"
    
    try:
        chunk_count = 0
        # Use generate_chat_stream for proper chat template handling
        async for chunk in current_manager.generate_chat_stream(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            tools=tools,
        ):
            chunk_count += 1
            # Filter malformed content from each chunk
            filtered_chunk = filter_malformed_content(chunk)
            
            # Always filter out tool call format - some may slip through even without tools
            filtered_chunk = tool_parser.strip_tool_calls_from_content(filtered_chunk)
            
            if not filtered_chunk:
                print(f"DEBUG: filtered_chunk was empty (original chunk: {repr(chunk[:50])})")
                continue
                
            generated_text += filtered_chunk
            
            data = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_name,
                "choices": [{
                    "index": 0,
                    "delta": {"content": filtered_chunk},
                    "finish_reason": None,
                }],
            }
            yield f"data: {json.dumps(data)}\n\n"
            # Explicitly flush to ensure data is sent immediately
            await asyncio.sleep(0)
        
        print(f"DEBUG: stream_chat_response completed, {chunk_count} chunks, generated_text length: {len(generated_text)}")
        if not generated_text.strip():
            print(f"DEBUG: Warning - no content generated!")
        
        # In debug mode, dump the full generated text
        if global_debug:
            print(f"\n{'='*80}")
            print(f"=== FULL GENERATED TEXT (DEBUG) ===")
            print(f"{'='*80}")
            print(repr(generated_text))
            print(f"{'='*80}\n")
        
        # Check for tool calls in complete output (for API response format)
        if tools:
            # Convert tools back to Tool objects for parsing
            tool_objects = []
            for t in tools:
                tool_func = ToolFunction(
                    name=t["function"]["name"],
                    description=t["function"].get("description"),
                    parameters=t["function"].get("parameters")
                )
                tool_objects.append(Tool(type=t.get("type", "function"), function=tool_func))
            tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
            if tool_calls:
                # In debug mode, dump tool calls
                if global_debug:
                    print(f"\n{'='*80}")
                    print(f"=== EXTRACTED TOOL CALLS (DEBUG) ===")
                    print(f"{'='*80}")
                    print(json.dumps(tool_calls, indent=2))
                    print(f"{'='*80}\n")
                # Tool calls were extracted and stripped from content during streaming
                # Just send the tool_calls chunk
                data = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_name,
                    "choices": [{
                        "index": 0,
                        "delta": {"tool_calls": tool_calls},
                        "finish_reason": "tool_calls",
                    }],
                }
                yield f"data: {json.dumps(data)}\n\n"
            else:
                yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        else:
            yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        
        yield "data: [DONE]\n\n"
    except Exception as e:
        print(f"Error during streaming generation: {e}")
        data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": {"content": f"\n[Generation error: {str(e)}]"},
                "finish_reason": "stop",
            }],
        }
        yield f"data: {json.dumps(data)}\n\n"
        yield "data: [DONE]\n\n"
    finally:
        # Always clean up queue state
        await queue_manager.finish_processing()
async def generate_chat_response(
    messages: List[Dict],
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Dict]],
    current_manager: ModelManager,
    tool_parser: ToolCallParser,
) -> Dict:
    """Generate non-streaming chat completion response."""
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    
    try:
        # Use generate_chat for proper chat template handling
        generated_text = current_manager.generate_chat(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            tools=tools,
        )
        
        # Filter out malformed content from generated text
        generated_text = filter_malformed_content(generated_text)
        
        response_message = {
            "role": "assistant",
            "content": generated_text,
        }
        
        finish_reason = "stop"
        
        # Check for tool calls
        if tools:
            # Convert tools back to Tool objects for parsing
            tool_objects = []
            for t in tools:
                tool_func = ToolFunction(
                    name=t["function"]["name"],
                    description=t["function"].get("description"),
                    parameters=t["function"].get("parameters")
                )
                tool_objects.append(Tool(type=t.get("type", "function"), function=tool_func))
            tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
            if tool_calls:
                # Strip tool call format from content so user doesn't see raw tags
                clean_content = tool_parser.strip_tool_calls_from_content(generated_text)
                response_message["content"] = clean_content if clean_content.strip() else None
                response_message["tool_calls"] = tool_calls
                finish_reason = "tool_calls"
        
        # Calculate token counts - rough estimate since we don't have direct access to tokenizer
        prompt_text = "\n".join([m.get("content", "") for m in messages])
        prompt_tokens = len(prompt_text.split())
        completion_tokens = len(generated_text.split()) if generated_text else 0
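        # Note: whitespace splitting is only a rough proxy for real token counts; if the
        # active backend exposes a tokenizer, exact counts could be computed the same way
        # generate_completion_response does below.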
        
        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "message": response_message,
                "finish_reason": finish_reason,
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    """Text completions endpoint."""
    # Get the model for this request
    requested_model = request.model
    
    # Try to get the appropriate model
    mm = multi_model_manager.get_model_for_request(requested_model)
    
    if mm is None:
        # Model not loaded - try to use default
        if model_manager.backend is not None:
            # Fallback to legacy model_manager
            current_manager = model_manager
        else:
            raise HTTPException(status_code=503, detail="Model not loaded")
    else:
        current_manager = mm
    
    prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt]
    stop_sequences = []
    if request.stop:
        stop_sequences = [request.stop] if isinstance(request.stop, str) else request.stop
    
    if request.stream:
        return StreamingResponse(
            stream_completion_response(
                prompts[0],
                request.model,
                request.max_tokens,
                request.temperature,
                request.top_p,
                stop_sequences,
                current_manager,
            ),
            media_type="text/event-stream",
        )
    else:
        return await generate_completion_response(
            prompts[0],
            request.model,
            request.max_tokens,
            request.temperature,
            request.top_p,
            stop_sequences,
            current_manager,
        )
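# Example client call against /v1/completions (a minimal sketch; host, port, and the
# model name "default" are placeholders that depend on how the server was started):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/completions",
#       json={"model": "default", "prompt": "Hello", "max_tokens": 16},
#   )
#   print(resp.json()["choices"][0]["text"])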
async def stream_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    current_manager: ModelManager,
) -> AsyncGenerator[str, None]:
    """Stream completion response."""
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    
    try:
        async for chunk in current_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        ):
            data = {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "choices": [{
                    "text": chunk,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": None,
                }],
            }
            yield f"data: {json.dumps(data)}\n\n"
        
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        print(f"Error during streaming completion: {e}")
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"
async def generate_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    current_manager: ModelManager,
) -> Dict:
    """Generate non-streaming completion response."""
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    
    try:
        generated_text = current_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )
        
        # Calculate token counts if tokenizer available
        if current_manager.tokenizer:
            prompt_tokens = len(current_manager.tokenizer.encode(prompt))
            completion_tokens = len(current_manager.tokenizer.encode(generated_text))
        else:
            prompt_tokens = len(prompt.split())
            completion_tokens = len(generated_text.split())
        
        return {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "text": generated_text,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except Exception as e:
        print(f"Error during completion: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
# =============================================================================
# Main Entry Point
# =============================================================================

def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="OpenAI-compatible API server supporting NVIDIA (CUDA) and Vulkan backends"
    )
    parser.add_argument(
        "--model",
        type=str,
        action="append",
        default=None,
        help="Model name, path, or URL for text-to-text LLM. Can be specified multiple times for multiple models.",
    )
    parser.add_argument(
        "--model-alias",
        type=str,
        action="append",
        default=None,
        dest="model_aliases",
        nargs=2,
        metavar=("ALIAS", "MODEL"),
        help="Register an alias for a model. Format: --model-alias <alias_name> <actual_model>",
    )
    parser.add_argument(
        "--backend",
        type=str,
        choices=["auto", "nvidia", "vulkan"],
        default="auto",
        help="Backend to use: auto (detect), nvidia (CUDA), or vulkan (AMD GPUs)",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to (default: 8000)",
    )
    parser.add_argument(
        "--offload-dir",
        type=str,
        default="./offload",
        help="Directory for disk offload (NVIDIA backend only, default: ./offload)",
    )
    parser.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit precision (NVIDIA backend only, requires bitsandbytes)",
    )
    parser.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit precision (NVIDIA backend only, requires bitsandbytes)",
    )
    parser.add_argument(
        "--ram",
        type=float,
        default=None,
        help="Maximum CPU RAM to use for model offloading in GB (NVIDIA backend only). Auto-detected if not specified. Disk offloading only occurs after this limit is exceeded.",
    )
    parser.add_argument(
        "--flash-attn",
        action="store_true",
        help="Use Flash Attention 2 (NVIDIA backend only, requires flash-attn package)",
    )
    parser.add_argument(
        "--offload-strategy",
        type=str,
        choices=["auto", "conservative", "balanced", "aggressive", "sequential"],
        default="auto",
        help="Offload strategy for NVIDIA backend (default: auto)",
    )
    parser.add_argument(
        "--max-gpu-percent",
        type=float,
        default=None,
        help="Maximum GPU VRAM to use as percentage (0-100). Overrides offload-strategy. Lower values offload more to CPU/RAM (default: None = use offload-strategy)",
    )
    parser.add_argument(
        "--n-gpu-layers",
        type=int,
        default=-1,
        help="Number of layers to offload to GPU (Vulkan backend only, default: -1 = all layers)",
    )
    parser.add_argument(
        "--n-ctx",
        type=int,
        default=2048,
        help="Context window size (Vulkan backend only, default: 2048)",
    )
    parser.add_argument(
        "--vulkan-device",
        type=int,
        default=0,
        help="Vulkan GPU device ID to use (Vulkan backend only, default: 0). Use --vulkan-list-devices to see available devices",
    )
    parser.add_argument(
        "--vulkan-single-gpu",
        action="store_true",
        help="Force Vulkan to use only the specified GPU device (prevents layer distribution across multiple GPUs)",
    )
    parser.add_argument(
        "--vulkan-list-devices",
        action="store_true",
        help="List available Vulkan GPU devices and exit",
    )
    parser.add_argument(
        "--system-prompt",
        nargs="?",
        const=True,
        default=None,
        help="Inject a system prompt at the beginning of conversations. Use without a value for a default prompt, or provide custom text.",
    )
    # Multi-model arguments
    parser.add_argument(
        "--tts-model",
        type=str,
        default=None,
        help="Model for text-to-speech (e.g., kokoro, or path/URL to Kokoro model). Can be specified multiple times.",
    )
    parser.add_argument(
        "--audio-model",
        type=str,
        action="append",
        default=None,
        help="Model for audio transcription (e.g., whisper-1, base, or path to faster-whisper model). Can be specified multiple times for multiple models.",
    )
    parser.add_argument(
        "--audio-1",
        action="store_true",
        help="Disable request queue for audio models - return 409 if model is busy",
    )
    parser.add_argument(
        "--image-model",
        type=str,
        action="append",
        default=None,
        help="Model for image generation (e.g., stable-diffusion-xl-base-1.0). Can be specified multiple times for multiple models.",
    )
    parser.add_argument(
        "--image-1",
        action="store_true",
        help="Disable request queue for image models - return 409 if model is busy",
    )
    parser.add_argument(
        "--llm-path",
        type=str,
        default=None,
        help="Path to CLIP LLM model for image generation (stable-diffusion-cpp-python).",
    )
    parser.add_argument(
        "--vae-path",
        type=str,
        default=None,
        help="Path to VAE model for image generation (stable-diffusion-cpp-python).",
    )
    parser.add_argument(
        "--image-sample-method",
        type=str,
        default="res_multistep",
        help="Sample method for image generation (default: res_multistep for Z-Image Turbo).",
    )
    parser.add_argument(
        "--image-steps",
        type=int,
        default=4,
        help="Number of inference steps for image generation (default: 4 for Z-Image Turbo).",
    )
    parser.add_argument(
        "--image-width",
        type=int,
        default=512,
        help="Image width for generation (default: 512).",
    )
    parser.add_argument(
        "--image-height",
        type=int,
        default=512,
        help="Image height for generation (default: 512).",
    )
    parser.add_argument(
        "--image-cfg-scale",
        type=float,
        default=1.0,
        help="CFG scale for image generation (default: 1.0 for Z-Image Turbo).",
    )
    parser.add_argument(
        "--loadall",
        action="store_true",
        help="Pre-load all models (main, audio, image) at startup instead of on-demand",
    )
    parser.add_argument(
        "--loadswap",
        action="store_true",
        help="Keep all models loaded, swapping active model between VRAM and RAM (only active model in VRAM)",
    )
    parser.add_argument(
        "--audio-ctx",
        type=int,
        default=480000,
        help="Audio model context size in milliseconds (default: 480000 = 30 seconds for Whisper)",
    )
    parser.add_argument(
        "--audio-offload",
        type=float,
        default=None,
        help="Audio model GPU offload percentage (0-100). If not set, uses CPU",
    )
    parser.add_argument(
        "--audio-vulkan-device",
        type=int,
        default=0,
        help="Vulkan GPU device ID to use for Whisper audio transcription (default: 0). Only used when using Vulkan backend.",
    )
    parser.add_argument(
        "--image-vulkan-device",
        type=int,
        default=None,
        help="Vulkan GPU device ID to use for image generation models (default: same as --vulkan-device). Use --vulkan-list-devices to see available devices",
    )

    parser.add_argument(
        "--whisper-cpp",
        type=str,
        default=None,
        help="Path to whisper.cpp CLI executable (e.g., ~/whisper.cpp/build/bin/whisper-cli). Uses Vulkan if available.",
    )
    parser.add_argument(
        "--whisper-server",
        type=str,
        default=None,
        help="Path to whisper.cpp server executable (e.g., ~/whisper.cpp/build/bin/whisper-server). Keeps model loaded in VRAM.",
    )
    parser.add_argument(
        "--whisper-server-port",
        type=int,
        default=8744,
        help="Port for whisper-server (default: 8744).",
    )
    parser.add_argument(
        "--vision-ctx",
        type=int,
        default=2048,
        help="Vision model context size (default: 2048)",
    )
    parser.add_argument(
        "--vision-offload",
        type=float,
        default=None,
        help="Vision model GPU offload percentage (0-100). If not set, loads fully on GPU",
    )
    parser.add_argument(
        "--list-cached-models",
        action="store_true",
        help="List all cached models in the model cache directory",
    )
    parser.add_argument(
        "--remove-all-models",
        action="store_true",
        help="Remove all cached models from the model cache directory",
    )
    parser.add_argument(
        "--remove-model",
        type=str,
        default=None,
        help="Remove a specific cached model by name or hash (partial match)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode - dumps full request/response to stdout for troubleshooting",
    )
    return parser.parse_args()
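# Example invocations (illustrative; adjust model names and paths for your setup):
#   coderai --backend vulkan --model ./phi-3-mini-4k-instruct-q4_k_m.gguf --port 8000
#   coderai --backend nvidia --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --load-in-4bit
#   coderai --audio-model base --tts-model kokoro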
def main():
    """Main entry point."""
    global global_system_prompt, model_manager, multi_model_manager, global_debug, global_args
    
    # Suppress unraisable exceptions from LlamaModel.__del__
    import sys
    original_unraisablehook = sys.unraisablehook
    def suppress_llama_del_errors(unraisable):
        if isinstance(unraisable.exc_value, AttributeError) and 'LlamaModel' in repr(unraisable.object) and 'sampler' in str(unraisable.exc_value):
            return  # Ignore this specific error
        original_unraisablehook(unraisable)
    sys.unraisablehook = suppress_llama_del_errors
    
    # Optional: set process name if procname is available
    try:
        import procname
        procname.setprocname("coderai")
    except ImportError:
        pass
    args = parse_args()
    
    # Store args globally for access in endpoints
    global_args = args
    
    # Set global system prompt from --system-prompt flag
    global_system_prompt = args.system_prompt
    
    # Set global debug flag
    global_debug = args.debug
    if global_debug:
        # Print the full command line that was used to invoke coderai
        import shlex
        cmd_line = ' '.join(shlex.quote(arg) for arg in sys.argv)
        print(f"\n{'='*80}")
        print(f"=== COMMAND LINE: {cmd_line}")
        print(f"{'='*80}\n")
        print("DEBUG MODE ENABLED - Full requests and replies will be dumped to stdout")
    
    # Handle --vulkan-list-devices
    if args.vulkan_list_devices:
        print("\nListing Vulkan devices...")
        try:
            import subprocess
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True)
            if result.returncode == 0:
                print(result.stdout)
            else:
                print("Could not run vulkaninfo. Make sure vulkan-tools is installed.")
        except Exception as e:
            print(f"Error listing devices: {e}")
        sys.exit(0)
    
    # Handle --list-cached-models
    if args.list_cached_models:
        print("\nListing cached models...")
        cache_dir = get_model_cache_dir()
        print(f"Cache directory: {cache_dir}")
        print("")
        
        if not os.path.exists(cache_dir):
            print("Cache directory does not exist.")
            sys.exit(0)
        
        files = os.listdir(cache_dir)
        if not files:
            print("No cached models found.")
            sys.exit(0)
        
        print(f"Found {len(files)} cached files:")
        total_size = 0
        for f in sorted(files):
            filepath = os.path.join(cache_dir, f)
            size = os.path.getsize(filepath)
            total_size += size
            size_mb = size / (1024 * 1024)
            print(f"  {f} ({size_mb:.1f} MB)")
        
        print(f"\nTotal: {len(files)} files, {total_size / (1024*1024*1024):.2f} GB")
        sys.exit(0)
    
    # Handle --remove-all-models
    if args.remove_all_models:
        print("\nRemoving all cached models...")
        cache_dir = get_model_cache_dir()
        
        if not os.path.exists(cache_dir):
            print("Cache directory does not exist.")
            sys.exit(0)
        
        files = os.listdir(cache_dir)
        if not files:
            print("No cached models to remove.")
            sys.exit(0)
        
        print(f"Found {len(files)} cached files. Deleting...")
        import shutil
        for f in files:
            filepath = os.path.join(cache_dir, f)
            os.remove(filepath)
            print(f"  Deleted: {f}")
        
        print(f"\nAll {len(files)} cached models removed.")
        sys.exit(0)
    
    # Handle --remove-model
    if args.remove_model:
        print(f"\nRemoving cached model matching: {args.remove_model}")
        cache_dir = get_model_cache_dir()
        
        if not os.path.exists(cache_dir):
            print("Cache directory does not exist.")
            sys.exit(0)
        
        files = os.listdir(cache_dir)
        # Find files that contain the search term
        matching = [f for f in files if args.remove_model.lower() in f.lower()]
        
        if not matching:
            print(f"No cached models found matching: {args.remove_model}")
            print(f"\nUse --list-cached-models to see available models.")
            sys.exit(0)
        
        print(f"Found {len(matching)} matching file(s):")
        for f in matching:
            filepath = os.path.join(cache_dir, f)
            size = os.path.getsize(filepath)
            print(f"  {f} ({size / (1024*1024):.1f} MB)")
        
        # Confirm before deleting
        print(f"\nDeleting {len(matching)} file(s)...")
        for f in matching:
            filepath = os.path.join(cache_dir, f)
            os.remove(filepath)
            print(f"  Deleted: {f}")
        
        print(f"\nRemoved {len(matching)} cached model(s).")
        sys.exit(0)
    
    # Get model names from args - support multiple models
    model_names = args.model if args.model else []
    
    # Validate: must have at least one model specified
    audio_models = args.audio_model if args.audio_model else []
    image_models = args.image_model if args.image_model else []
    
    if not model_names and not audio_models and not image_models and args.tts_model is None:
        print("Error: At least one of --model, --audio-model, --image-model, or --tts-model must be specified.")
        print("")
        print("For NVIDIA backend (HuggingFace models):")
        print("  - microsoft/DialoGPT-medium")
        print("  - meta-llama/Llama-2-7b-chat-hf (requires auth)")
        print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        print("  - Use multiple --model flags for multiple models")
        print("")
        print("For Vulkan backend (GGUF models):")
        print("  - Local path: ./phi-3-mini-4k-instruct-q4_k_m.gguf")
        print("  - HuggingFace: microsoft/Phi-3-mini-4k-instruct-gguf")
        print("  - URL: https://huggingface.co/.../model.gguf")
        print("")
        print("For audio transcription:")
        print("  - --audio-model base")
        print("")
        print("For text-to-speech:")
        print("  - --tts-model kokoro")
        print("")
        print("For image generation:")
        print("  - --image-model stabilityai/stable-diffusion-xl-base-1.0")
        sys.exit(1)
    
    # Print loaded models info
    if model_names:
        print(f"\nText model(s): {model_names}")
        if len(model_names) > 1:
            # Load mode will be determined below
            print(f"Multiple models configured - load mode will be set based on --loadall/--loadswap flags")
    
    # Detect available backends
    available = detect_available_backends()
    print("\nAvailable backends:")
    for name, available_flag in available.items():
        status = "✓" if available_flag else "✗"
        print(f"  [{status}] {name}")
    print("")
    
    # Load the main model (only if specified)
    if model_names:
        load_kwargs = {
            'offload_dir': args.offload_dir,
            'load_in_4bit': args.load_in_4bit,
            'load_in_8bit': args.load_in_8bit,
            'manual_ram_gb': args.ram,
            'flash_attn': args.flash_attn,
            'offload_strategy': args.offload_strategy,
            'max_gpu_percent': args.max_gpu_percent,
            'n_gpu_layers': args.n_gpu_layers,
            'n_ctx': args.n_ctx,
            'main_gpu': args.vulkan_device,
            'single_gpu': args.vulkan_single_gpu,
        }
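        # These kwargs are shared across backends: per the CLI help above, offload_dir,
        # load_in_4bit/8bit, manual_ram_gb, flash_attn, offload_strategy, and
        # max_gpu_percent apply to the NVIDIA backend, while n_gpu_layers, n_ctx,
        # main_gpu, and single_gpu apply to the Vulkan backend; each backend is
        # expected to ignore the keys it does not use.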
        
        # Load the first model
        first_model_name = model_names[0]
        try:
            model_manager.load_model(
                model_name=first_model_name,
                backend_type=args.backend,
                **load_kwargs
            )
            # Register with multi_model_manager
            multi_model_manager.set_default_model(first_model_name, load_kwargs)
            multi_model_manager.add_model(first_model_name, model_manager)
            print(f"\nMain text model loaded: {first_model_name}")
        except Exception as e:
            print(f"\nError loading model: {e}")
            error_str = str(e).lower()
            print("\nTroubleshooting:")
            if args.backend == "vulkan":
                print("  - For Vulkan, ensure you have Vulkan drivers installed")
                print("  - Make sure you're using a GGUF format model")
                print("  - Run build.sh with 'vulkan' argument first")
            else:
                print("  - For NVIDIA, ensure PyTorch with CUDA is installed")
                print("  - Run build.sh with 'nvidia' argument first")
                if "tokenizer" in error_str or "sentencepiece" in error_str or "tiktoken" in error_str:
                    print("  - Tokenizer error: ensure sentencepiece and tiktoken are installed")
                    print("    pip install sentencepiece tiktoken tokenizers")
                # Check if trying to load GGUF model with NVIDIA backend
                if "gguf" in first_model_name.lower():
                    print(f"\n  *** IMPORTANT: '{first_model_name}' appears to be a GGUF model ***")
                    print("  GGUF models are NOT compatible with the NVIDIA backend.")
                    print("  Use --backend vulkan instead, or choose a HuggingFace Transformers model.")
                    print("\n  Example Vulkan command:")
                    print(f"    coderai --backend vulkan --model {first_model_name}")
            sys.exit(1)
    else:
        print("\nNo main text model specified (--model). Running with audio/image/TTS models only.")
    
    # Determine load mode BEFORE setting up other models
    load_mode = "ondemand"
    if args.loadall:
        load_mode = "loadall"
    elif args.loadswap:
        load_mode = "loadswap"
    
    # Set load mode in multi_model_manager
    multi_model_manager.set_load_mode(load_mode)
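    # Load modes (see the --loadall/--loadswap CLI help):
    #   ondemand - models load on their first request (default)
    #   loadall  - pre-load all configured models at startup
    #   loadswap - keep all models loaded, swapping the active one into VRAM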
    
    # Load models based on mode and count
    if len(model_names) > 1:
        # Multiple models - handle based on load mode
        print(f"\n=== Multiple Models Mode: {load_mode} ===")
        
        if load_mode == "loadall":
            # Load all models into VRAM
            # Skip the first model if it was already loaded above
            start_index = 1 if model_names[0] in multi_model_manager.models else 0
            for i in range(start_index, len(model_names)):
                model_name = model_names[i]
                print(f"\nLoading model {i+1}/{len(model_names)}: {model_name}")
                try:
                    manager = ModelManager()
                    manager.load_model(
                        model_name=model_name,
                        backend_type=args.backend,
                        **load_kwargs
                    )
                    multi_model_manager.add_model(model_name, manager)
                    print(f"Loaded: {model_name}")
                except Exception as e:
                    print(f"Error loading {model_name}: {e}")
        
        elif load_mode == "loadswap":
            # First model in VRAM, others in RAM
            # Skip the first model if it was already loaded above
            start_index = 1 if model_names[0] in multi_model_manager.models else 0
            for i in range(start_index, len(model_names)):
                model_name = model_names[i]
                # In loadswap, all additional models go to RAM (CPU-only)
                print(f"\nLoading model {i+1}/{len(model_names)}: {model_name} (RAM)")
                try:
                    manager = ModelManager()
                    # Modify kwargs for CPU-only loading
                    swap_kwargs = load_kwargs.copy()
                    swap_kwargs['n_gpu_layers'] = 0  # Force CPU only for swap mode
                    manager.load_model(
                        model_name=model_name,
                        backend_type=args.backend,
                        **swap_kwargs
                    )
                    multi_model_manager.add_model(model_name, manager)
                    print(f"Loaded: {model_name} (RAM)")
                except Exception as e:
                    print(f"Error loading {model_name}: {e}")
        
        else:  # ondemand
            # First model was already loaded above
            print(f"\nFirst model already loaded: {model_names[0]}")
            
            # Register other models but don't load them
            for model_name in model_names[1:]:
                multi_model_manager.set_default_model(model_name, load_kwargs)
            
            print(f"Other models will load on-demand: {model_names[1:]}")
    # Single-model case: the model was already loaded above
    
    
    # Pre-load models based on mode
    print(f"DEBUG: load_mode at line 4710 = '{load_mode}', backend = {args.backend}")
    if load_mode in ("loadall", "loadswap"):
        # Load all models into VRAM (or RAM for CUDA loadswap)
        mode_name = "Load All" if load_mode == "loadall" else "Load Swap"
        print(f"\n=== {mode_name} Mode ===")
        
        # Load main text model first
        if model_names:
            print(f"Pre-loading main text model: {model_names[0]}")
        
        # Load image model (first one only in loadall mode currently)
        print(f"DEBUG: image_models check at line 4718: {image_models}, backend = {args.backend}")
        if image_models:
            print(f"Pre-loading image model: {image_models[0]}")
            
            # Get the original model name
            original_model_name = image_models[0]
            
            # Check if it's a URL first (before any processing)
            is_url = original_model_name.startswith('http://') or original_model_name.startswith('https://')
            
            # Strip query parameters from URL if present
            model_name = original_model_name
            if '?' in model_name:
                model_name = model_name.split('?')[0]
            
            # Check if the image model is a GGUF model
            is_gguf = model_name.endswith('.gguf') or 'gguf' in model_name.lower()
            
            if is_gguf:
                # Load GGUF image model using llama.cpp
                print(f"Detected GGUF image model, loading with llama.cpp...")
                try:
                    from llama_cpp import Llama
                    
                    # Download GGUF model if needed (similar to VulkanBackend)
                    model_path = None
                    
                    # Check if it's a URL - download directly
                    if is_url:
                        print(f"Image model is a URL: {original_model_name}")
                        cached_path = get_cached_model_path(original_model_name)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached GGUF model: {model_path}")
                        else:
                            print(f"Downloading GGUF model: {original_model_name}")
                            cache_dir = get_model_cache_dir()
                            model_path = download_model(original_model_name, cache_dir)
                    elif os.path.isfile(model_name):
                        # Local file
                        model_path = model_name
                        print(f"Loading local GGUF model: {model_path}")
                    else:
                        # Try to download from HuggingFace Hub
                        print(f"Trying to resolve as HuggingFace model: {model_name}")
                        try:
                            from huggingface_hub import hf_hub_download, list_repo_files
                            parts = model_name.split('/')
                            if len(parts) >= 2:
                                repo_id = f"{parts[0]}/{parts[1]}"
                                print(f"Looking for GGUF files in repo: {repo_id}")
                                files = list_repo_files(repo_id)
                                gguf_files = [f for f in files if f.endswith('.gguf')]
                                if not gguf_files:
                                    raise ValueError(f"No GGUF files found in {repo_id}")
                                filename = gguf_files[0]
                                model_path = hf_hub_download(repo_id=repo_id, filename=filename)
                                print(f"Downloaded GGUF model to: {model_path}")
                        except Exception as e:
                            print(f"Could not resolve GGUF model path: {e}")
                            print(f"Image model will load on first request")
                            model_path = None
                    
                    if model_path and os.path.isfile(model_path):
                        # Use the cached path for the model key
                        model_key = f"image:{model_path}"
                        
                        # Load with llama.cpp
                        n_gpu_layers = -1  # Load all layers to GPU
                        n_ctx = 2048
                        
                        print(f"Loading GGUF model from: {model_path}")
                        file_size = os.path.getsize(model_path)
                        print(f"GGUF model file size: {file_size / (1024*1024):.1f} MB")
                        
                        # Verify it's a valid GGUF file (check magic bytes)
                        with open(model_path, 'rb') as f:
                            magic = f.read(8)
                            print(f"File magic bytes: {magic}")
                            if not magic.startswith(b'GGUF'):
                                print(f"ERROR: File is NOT a valid GGUF! Expected 'GGUF', got: {magic}")
                                print(f"This means the download returned an HTML error page instead of the model.")
                                print(f"The URL must be a DIRECT download link (ends with .gguf, not a model page)")
                                print(f"Example: https://huggingface.co/owner/repo/resolve/main/model.gguf")
                                print(f"Image model will load on first request")
                            else:
                                # Valid GGUF, try to load
                                try:
                                    llama_model = Llama(
                                        model_path=model_path,
                                        n_gpu_layers=n_gpu_layers,
                                        n_ctx=n_ctx,
                                        verbose=True,
                                    )
                                    multi_model_manager.add_model(model_key, llama_model)
                                    print(f"GGUF image model loaded successfully: {original_model_name}")
                                except Exception as llama_error:
                                    print(f"llama.cpp load error: {llama_error}")
                                    print(f"Trying stable-diffusion-cpp-python fallback...")
                                    # Try stable-diffusion-cpp-python as fallback
                                    try:
                                        from stable_diffusion_cpp import StableDiffusion
                                        
                                        # Key used to register the image model with the multi-model manager
                                        model_key = f"image:{model_path}"
                                        print(f"Loading with sd.cpp: {model_path}")
                                        # For models like Z-Image-Turbo/Flux, use diffusion_model_path
                                        # Look for additional model files in same directory
                                        model_dir = os.path.dirname(model_path)
                                        model_name = os.path.basename(model_path)
                                        
                                        # Try to find additional model files
                                        clip_l_path = None
                                        t5xxl_path = None
                                        vae_path = None
                                        
                                        # Use CLI arguments if provided, download and cache if URL
                                        if args.llm_path:
                                            # Check if it's a Hugging Face model ID, URL, or local path
                                            if is_huggingface_model_id(args.llm_path):
                                                # Download from Hugging Face
                                                print(f"Attempting to download LLM model from Hugging Face: {args.llm_path}")
                                                cache_dir = get_model_cache_dir()
                                                clip_l_path = download_huggingface_model(args.llm_path, cache_dir, '.gguf')
                                                if clip_l_path:
                                                    print(f"Downloaded LLM model to: {clip_l_path}")
                                                else:
                                                    print(f"Warning: Failed to download LLM model from Hugging Face, will try as local path")
                                            elif args.llm_path.startswith('http://') or args.llm_path.startswith('https://'):
                                                cached = get_cached_model_path(args.llm_path)
                                                if cached:
                                                    clip_l_path = cached
                                                    print(f"Using cached LLM model: {clip_l_path}")
                                                else:
                                                    cache_dir = get_model_cache_dir()
                                                    clip_l_path = download_model(args.llm_path, cache_dir)
                                                    print(f"Downloaded LLM model to: {clip_l_path}")
                                            else:
                                                clip_l_path = args.llm_path
                                        if args.vae_path:
                                            # Check if it's a URL and download if needed
                                            if args.vae_path.startswith('http://') or args.vae_path.startswith('https://'):
                                                cached = get_cached_model_path(args.vae_path)
                                                if cached:
                                                    vae_path = cached
                                                    print(f"Using cached VAE model: {vae_path}")
                                                else:
                                                    cache_dir = get_model_cache_dir()
                                                    vae_path = download_model(args.vae_path, cache_dir)
                                                    print(f"Downloaded VAE model to: {vae_path}")
                                            else:
                                                vae_path = args.vae_path
                                        
                                        # Look for common file patterns only if CLI args not provided
                                        if not args.llm_path or not args.vae_path:
                                            for f in (os.listdir(model_dir) if os.path.exists(model_dir) else []):
                                                if not args.llm_path and 'clip_l' in f.lower() and f.endswith(('.safetensors', '.bin')):
                                                    clip_l_path = os.path.join(model_dir, f)
                                                elif 't5xxl' in f.lower() and f.endswith(('.safetensors', '.bin')):
                                                    t5xxl_path = os.path.join(model_dir, f)
                                                elif not args.vae_path and f.endswith('.safetensors') and 'ae' in f.lower():
                                                    vae_path = os.path.join(model_dir, f)
                                        
                                        # Build kwargs based on available files
                                        sd_kwargs = {'diffusion_model_path': model_path}
                                        
                                        if clip_l_path:
                                            sd_kwargs['llm_path'] = clip_l_path
                                            print(f"DEBUG: Adding llm_path to sd_kwargs: {clip_l_path}")
                                        else:
                                            print(f"DEBUG: clip_l_path is None or empty, not adding to sd_kwargs")
                                            print(f"DEBUG: args.llm_path = {args.llm_path}")
                                        if vae_path:
                                            sd_kwargs['vae_path'] = vae_path
                                        if t5xxl_path:
                                            sd_kwargs['t5xxl_path'] = t5xxl_path
                                        

                                        # Add generation parameters from CLI args
                                        # sd_kwargs['sample_method'] = args.image_sample_method  # Not valid for __init__
                                        # sd_kwargs['steps'] = args.image_steps  # Not valid for __init__
                                        
                                        sd_model = StableDiffusion(**sd_kwargs)
                                        multi_model_manager.add_model(model_key, sd_model)
                                        # Add aliases for "image" and "vision" 
                                        multi_model_manager.add_model("image", sd_model)
                                        multi_model_manager.add_model("vision", sd_model)
                                        print(f"Image model loaded successfully via sd.cpp: {original_model_name}")
                                    except ImportError as sd_error:
                                        print(f"stable-diffusion-cpp-python not installed: {sd_error}")
                                        print(f"Image model will load on first request")
                                    except Exception as sd_error:
                                        print(f"sd.cpp load error: {sd_error}")
                                        print(f"Image model will load on first request")
                    else:
                        print(f"Could not load GGUF image model: no valid model path")
                        
                except ImportError as e:
                    print(f"Warning: llama_cpp not installed: {e}")
                    print(f"Image model will load on first request")
                except Exception as e:
                    print(f"Warning: Failed to pre-load GGUF image model: {e}")
                    print(f"Image model will load on first request")
            else:
                # Load diffusers image model (Stable Diffusion)
                try:
                    import torch
                    from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
                    
                    # Use model name directly for diffusers (model_path is only set in GGUF branch)
                    model_key = f"image:{model_name}"
                    print(f"Loading diffusers pipeline: {model_name}")
                    
                    # Try to load as Stable Diffusion XL first
                    try:
                        pipeline = StableDiffusionXLPipeline.from_pretrained(
                            model_name,
                            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                            use_safetensors=True,
                        )
                    except Exception as e:
                        print(f"SDXL failed, trying generic pipeline: {e}")
                        # Try generic diffusion pipeline
                        pipeline = DiffusionPipeline.from_pretrained(
                            model_name,
                            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                            use_safetensors=True,
                        )
                    
                    # Move to GPU if available
                    if torch.cuda.is_available():
                        pipeline = pipeline.to("cuda")
                        pipeline.enable_attention_slicing()
                    else:
                        pipeline = pipeline.to("cpu")
                    
                    multi_model_manager.add_model(model_key, pipeline)
                    # Add aliases for "image" and "vision"
                    multi_model_manager.add_model("image", pipeline)
                    multi_model_manager.add_model("vision", pipeline)
                    print(f"Image model loaded successfully: {model_name}")
                    
                except ImportError as e:
                    print(f"Warning: diffusers not installed, image model will load on first request: {e}")
                except Exception as e:
                    print(f"Warning: Failed to pre-load image model: {e}")
                    print(f"  Image model will load on first request")
        
        # Load audio model
        print(f"DEBUG: audio_models check at line 4970: {audio_models}")
        if audio_models:
            print(f"Pre-loading audio model: {audio_models[0]}")
        
        # Load TTS model
        if args.tts_model:
            print(f"Pre-loading TTS model: {args.tts_model}")
            
    elif load_mode == "loadswap":
        # Load models in order: model > image > audio > TTS, keep active in VRAM
        # For Vulkan backend, load all models to VRAM like loadall (VRAM is not limited like CUDA)
        print("\n=== Load Swap Mode ===")
        
        # For Vulkan, use same preloading as loadall
        if args.backend == "vulkan":
            # Vulkan: Load all models to GPU like loadall
            if model_names:
                print(f"Pre-loading main text model: {model_names[0]}")
            if image_models:
                print(f"Pre-loading image model: {image_models[0]}")
            if audio_models:
                print(f"Pre-loading audio model: {audio_models[0]}")
            if args.tts_model:
                print(f"Pre-loading TTS model: {args.tts_model}")
        else:
            # NVIDIA/CUDA: First model in VRAM, others in RAM
            if model_names:
                print(f"Main text model will be in VRAM: {model_names[0]}")
            if image_models:
                print(f"Image model in RAM: {image_models[0]}")
            if audio_models:
                print(f"Audio model in RAM: {audio_models[0]}")
            if args.tts_model:
                print(f"TTS model in RAM: {args.tts_model}")
        
    else:
        # No flags: only one model gets loaded (the main text model if specified)
        print("\n=== On-Demand Mode ===")
        print("Models will load on first request")
    
    # Set up audio model if specified (with pre-loading if in loadall/loadswap mode)
    print(f"DEBUG: models in manager before audio setup: {list(multi_model_manager.models.keys())}")
    if audio_models:
        print(f"\nAudio transcription model(s): {audio_models}")
        
        # Set up Vulkan device for Whisper if using Vulkan backend
        if hasattr(args, 'audio_vulkan_device') and args.audio_vulkan_device is not None:
            print(f"  Using Vulkan device: {args.audio_vulkan_device}")
        
        # Register all audio models
        print(f"DEBUG: Registering audio models: {audio_models}")
        for audio_m in audio_models:
            multi_model_manager.set_audio_model(audio_m, {
                'ctx': args.audio_ctx,
                'offload': args.audio_offload,
            })
        print(f"DEBUG: After registration, audio_models in manager: {multi_model_manager.audio_models}")
        
        # Pre-load first audio model at startup if:
        # - Using loadall or loadswap mode, OR
        # - No main model is specified (only audio model configured)
        print(f"DEBUG: load_mode at line 5015 = '{load_mode}', model_names = {model_names}, audio_models = {audio_models}")
        should_preload = load_mode in ("loadall", "loadswap") or (not model_names and audio_models)
        print(f"DEBUG: should_preload = {should_preload}")
        
        # Initialize whisper-server if specified
        if args.whisper_server:
            print(f"\nWhisper server: {args.whisper_server}")
            print(f"  Port: {args.whisper_server_port}")
            # Check if whisper-server is already running
            if multi_model_manager.whisper_server is None:
                whisper_server_mgr = WhisperServerManager(
                    server_path=args.whisper_server,
                    port=args.whisper_server_port
                )
                multi_model_manager.whisper_server = whisper_server_mgr
            else:
                whisper_server_mgr = multi_model_manager.whisper_server
                print("Whisper server already running, using existing instance")
            
            # Start whisper-server if we should preload or if it's the only audio option
            print(f"DEBUG: whisper-server start check - audio_models={audio_models}, should_preload={should_preload}, whisper_cpp={args.whisper_cpp}")
            if audio_models and (should_preload or not args.whisper_cpp):
                model_to_use = audio_models[0] if audio_models else None
                gpu_device = getattr(args, 'audio_vulkan_device', 0) or 0
                print(f"DEBUG: Starting whisper-server with gpu_device={gpu_device}")
                actual_model_path = whisper_server_mgr.start(model_path=model_to_use, gpu_device=gpu_device)
                if actual_model_path:
                    # Update audio_models in multi_model_manager to store the actual path (not the URL)
                    if model_to_use != actual_model_path:
                        # Update the manager's audio_models list
                        if multi_model_manager.audio_models and multi_model_manager.audio_models[0] == model_to_use:
                            multi_model_manager.audio_models[0] = actual_model_path
                    print(f"Whisper server started with model: {actual_model_path}")
                else:
                    print("Warning: Failed to start whisper-server, falling back to other backends")
        elif should_preload:
            print(f"Pre-loading audio model... {audio_models[0]}")
            
            # Use first audio model for pre-loading
            model_to_use = audio_models[0]
            is_gguf_model = model_to_use.endswith('.gguf') or 'gguf' in model_to_use.lower()
            
            if is_gguf_model:
                # faster-whisper cannot load GGUF files, so go straight to the whispercpp fallback
                print("Detected GGUF model - using whispercpp backend")
                faster_whisper_failed = True
            else:
                faster_whisper_failed = False
            try:
                if faster_whisper_failed:
                    # GGUF was detected above; raising here routes us to the ImportError
                    # handler below, which falls through to the whispercpp backend.
                    raise ImportError("faster-whisper does not support GGUF models")
                # Try faster-whisper first (requires torch)
                from faster_whisper import WhisperModel
                import torch
                
                model_to_use = audio_models[0]
                model_path = None
                
                # Check if model is a URL - handle caching
                if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                    cached_path = get_cached_model_path(model_to_use)
                    if cached_path:
                        model_path = cached_path
                        print(f"Using cached model: {model_path}")
                    else:
                        # Download with progress
                        cache_dir = get_model_cache_dir()
                        model_path = download_model(model_to_use, cache_dir)
                        model_to_use = model_path
                
                # Determine compute type - always use CPU on Vulkan backend
                # faster-whisper CUDA doesn't work with AMD/Vulkan GPUs
                compute_type = "int8"
                
                # Load the model - always use CPU (faster-whisper CUDA doesn't work with AMD/Vulkan)
                whisper_model = WhisperModel(
                    model_to_use,
                    device="cpu",
                    compute_type=compute_type
                )
                
                # Store in multi_model_manager
                model_key = f"audio:{audio_models[0]}"
                multi_model_manager.add_model(model_key, whisper_model)
                print(f"Audio model loaded successfully (faster-whisper)")
                
                # Warn if using CPU (torch was already imported above)
                if not torch.cuda.is_available():
                    print("Note: faster-whisper is running on CPU (no CUDA GPU detected)")
                    print("      For GPU acceleration, use NVIDIA GPU with CUDA or wait for Vulkan support.")
                
            except ImportError:
                # faster-whisper not available, will try whispercpp below
                faster_whisper_failed = True
            except Exception as e:
                # faster-whisper failed for some other reason (e.g., GGUF file not supported)
                print(f"Warning: faster-whisper failed to load model: {e}")
                faster_whisper_failed = True
            
            # If faster-whisper failed (not installed or couldn't load), try whispercpp
            if faster_whisper_failed:
                # Initialize model_path; whispercpp handles GGUF files directly,
                # so no separate format check is needed here
                model_path = None
                
                # Check if Vulkan is available for whispercpp
                whisper_vulkan_available = False
                whisper_vulkan_device = os.environ.get('VK_DEVICE_SELECT_DEVICE', '0')
                try:
                    import whispercpp
                    if os.environ.get('VK_DEVICE_SELECT_DEVICE'):
                        whisper_vulkan_available = True
                        print(f"Whisper Vulkan: Will use GPU device {whisper_vulkan_device}")
                    elif os.path.exists('/dev/dri'):
                        whisper_vulkan_available = True
                        print(f"Whisper Vulkan: Auto-detected GPU, will use device {whisper_vulkan_device}")
                except ImportError as e:
                    print(f"Debug: whispercpp import failed: {e}")
                
                try:
                    import whispercpp
                    
                    model_to_use = audio_models[0]
                    model_path = None
                    
                    # Check if model is a URL - handle caching
                    if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                        cached_path = get_cached_model_path(model_to_use)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached model: {model_path}")
                        else:
                            # Download with progress
                            cache_dir = get_model_cache_dir()
                            model_path = download_model(model_to_use, cache_dir)
                            model_to_use = model_path
                    
                    # whispercpp needs a local file or a built-in model name
                    # whispercpp supports: tiny, base, small, medium, large-v1, large (built-in)
                    # or pre-converted GGUF files (NOT HuggingFace GGUF format)
                    if not model_path:
                        # Check if it's a local file
                        if os.path.isfile(model_to_use):
                            model_path = model_to_use
                        elif model_to_use in ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large']:
                            # It's a built-in model name - whispercpp will download automatically
                            model_path = model_to_use
                        else:
                            # Could be a model name without .gguf extension - try it
                            model_path = model_to_use
                    
                    if not model_path or (model_path != model_to_use and not os.path.isfile(model_path)):
                        # If model_path is not a valid built-in name, check if file exists
                        if model_path and model_path not in ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large']:
                            if not os.path.isfile(model_path):
                                print(f"Warning: whispercpp requires a local GGUF file or built-in model name, not: {model_to_use}")
                                print("For Vulkan audio transcription, use a built-in model name (tiny/base/small/medium/large-v1/large)")
                                print("or install faster-whisper with PyTorch for HuggingFace GGUF support.")
                                print("Audio model will load on-demand when transcription is requested.")
                    else:
                        # Load the whispercpp model
                        try:
                            whisper_model = whispercpp.Whisper.from_pretrained(model_path)
                            
                            # Store in multi_model_manager
                            model_key = f"audio:{audio_models[0]}"
                            multi_model_manager.add_model(model_key, whisper_model)
                            print(f"Audio model loaded successfully (whispercpp)")
                            if whisper_vulkan_available:
                                print(f"  -> Using Vulkan GPU acceleration (device {whisper_vulkan_device})")
                        except Exception as e:
                            error_msg = str(e).lower()
                            if 'not a valid preconverted model' in error_msg:
                                print(f"Warning: whispercpp does not support this model format")
                                print("whispercpp only supports built-in model names or pre-converted GGUF files.")
                                print("For Vulkan audio transcription, please either:")
                                print("  1. Install PyTorch + faster-whisper: pip install torch faster-whisper")
                                print("  2. Use a built-in whispercpp model: --audio-model base")
                                print("Audio model will load on-demand when transcription is requested.")
                            else:
                                print(f"Warning: Could not pre-load audio model with whispercpp: {e}")
                                print("Audio model will load on-demand when transcription is requested.")
                except ImportError as e:
                    # Neither faster-whisper nor whispercpp available
                    print(f"Warning: No audio transcription library available: {e}")
                    print("Options:")
                    print("  1. Install PyTorch + faster-whisper: pip install torch faster-whisper")
                    print("  2. Use a built-in whispercpp model: --audio-model base")
                    print("Audio model will load on-demand when transcription is requested.")
                except Exception as e:
                    print(f"Warning: Could not pre-load audio model with whispercpp: {e}")
                    print("Audio model will load on-demand when transcription is requested.")
    
    # Set up TTS model if specified
    if args.tts_model:
        print(f"\nText-to-speech model: {args.tts_model}")
        multi_model_manager.set_tts_model(args.tts_model, {})
        
        # Pre-load TTS model if it's the only model configured
        if not model_names and not audio_models and not image_models:
            print(f"Pre-loading TTS model...")
            # TTS models load on-demand, but we can pre-download if needed
    
    # Set up image model if specified
    if image_models:
        print(f"\nImage generation model(s): {image_models}")
        multi_model_manager.set_image_model(image_models[0], {
            'ctx': args.vision_ctx,
            'offload': args.vision_offload,
            'llm_path': args.llm_path,
            'vae_path': args.vae_path,
            'sample_method': args.image_sample_method,
            'steps': args.image_steps,
            'width': args.image_width,
            'height': args.image_height,
            'cfg_scale': args.image_cfg_scale,
        })
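        # Example invocation (hypothetical paths; flag names inferred from the args
        # attributes read above - check the argparse definitions for exact spellings):
        #   --image-model flux1-dev-Q4_0.gguf --llm-path clip_l.safetensors \
        #   --vae-path ae.safetensors --image-steps 20 --image-width 1024 --image-height 1024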
        # Register all image models
        for img_m in image_models[1:]:
            multi_model_manager.set_image_model(img_m, {
                'ctx': args.vision_ctx,
                'offload': args.vision_offload,
            })
        
        # Pre-load image model if it's configured (even with audio models)
        if image_models:
            print(f"Pre-loading image model...")
            
            # Get the original model name
            original_model_name = image_models[0]
            
            # Check if it's a URL first (before any processing)
            is_url = original_model_name.startswith('http://') or original_model_name.startswith('https://')
            
            # Strip query parameters from URL if present
            model_name = original_model_name
            if '?' in model_name:
                model_name = model_name.split('?')[0]
            
            # Check if the image model is a GGUF model
            is_gguf = model_name.endswith('.gguf') or 'gguf' in model_name.lower()
            
            if is_gguf:
                # Load GGUF image model using llama.cpp
                print(f"Detected GGUF image model, loading with llama.cpp...")
                try:
                    from llama_cpp import Llama
                    
                    # Download GGUF model if needed
                    model_path = None
                    
                    # Check if it's a URL - download directly
                    if is_url:
                        print(f"Image model is a URL: {original_model_name}")
                        cached_path = get_cached_model_path(original_model_name)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached GGUF model: {model_path}")
                        else:
                            print(f"Downloading GGUF model: {original_model_name}")
                            cache_dir = get_model_cache_dir()
                            model_path = download_model(original_model_name, cache_dir)
                    elif os.path.isfile(model_name):
                        model_path = model_name
                        print(f"Loading local GGUF model: {model_name}")
                    else:
                        # Try to download from HuggingFace Hub
                        print(f"Trying to resolve as HuggingFace model: {model_name}")
                        try:
                            from huggingface_hub import hf_hub_download, list_repo_files
                            parts = model_name.split('/')
                            if len(parts) >= 2:
                                repo_id = f"{parts[0]}/{parts[1]}"
                                print(f"Looking for GGUF files in repo: {repo_id}")
                                files = list_repo_files(repo_id)
                                gguf_files = [f for f in files if f.endswith('.gguf')]
                                if not gguf_files:
                                    raise ValueError(f"No GGUF files found in {repo_id}")
                                filename = gguf_files[0]
                                model_path = hf_hub_download(repo_id=repo_id, filename=filename)
                                print(f"Downloaded GGUF model to: {model_path}")
                        except Exception as e:
                            print(f"Could not resolve GGUF model path: {e}")
                            model_path = None
                    
                    if model_path and os.path.isfile(model_path):
                        n_gpu_layers = -1
                        n_ctx = 2048
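                        # n_gpu_layers=-1 asks llama.cpp to offload every layer to the GPU;
                        # n_ctx=2048 is a conservative context size for this probe load.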
                        
                        print(f"Loading GGUF model from: {model_path}")
                        file_size = os.path.getsize(model_path)
                        print(f"GGUF model file size: {file_size / (1024*1024):.1f} MB")
                        
                        # Verify it's a valid GGUF file (check magic bytes)
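                        # A real GGUF file starts with the 4-byte ASCII magic b'GGUF'; an HTML
                        # error page or redirect stub downloaded by mistake will not.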
                        with open(model_path, 'rb') as f:
                            magic = f.read(8)
                            print(f"File magic bytes: {magic}")
                            if not magic.startswith(b'GGUF'):
                                print(f"ERROR: File is NOT a valid GGUF! Expected 'GGUF', got: {magic}")
                                print(f"The URL must be a DIRECT download link (ends with .gguf)")
                                print(f"Image model will load on first request")
                            else:
                                try:
                                    llama_model = Llama(
                                        model_path=model_path,
                                        n_gpu_layers=n_gpu_layers,
                                        n_ctx=n_ctx,
                                        verbose=True,
                                    )
                                    # Key the loaded model by its resolved path so later lookups find it
                                    model_key = f"image:{model_path}"
                                    multi_model_manager.add_model(model_key, llama_model)
                                    print(f"GGUF image model loaded successfully: {original_model_name}")
                                except Exception as llama_error:
                                    print(f"llama.cpp load error: {llama_error}")
                                    print(f"Trying stable-diffusion-cpp-python fallback...")
                                    # Try stable-diffusion-cpp-python as fallback
                                    try:
                                        from stable_diffusion_cpp import StableDiffusion
                                        
                                        # Initialize model_key to avoid unbound variable error
                                        model_key = None
                                        
                                        print(f"Loading with sd.cpp: {model_path}")
                                        # For models like Z-Image-Turbo/Flux, use diffusion_model_path
                                        # Look for additional model files in same directory
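                                        # Flux / Z-Image style GGUF exports usually contain only the diffusion
                                        # transformer; the text encoder(s) and VAE ship as separate files, which
                                        # is why the llm/vae paths are resolved below.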
                                        model_dir = os.path.dirname(model_path)
                                        model_name = os.path.basename(model_path)
                                        
                                        # Try to find additional model files
                                        clip_l_path = None
                                        t5xxl_path = None
                                        vae_path = None
                                        
                                        # Use CLI arguments if provided, download and cache if URL
                                        if args.llm_path:
                                            # Check if it's a Hugging Face model ID, URL, or local path
                                            if is_huggingface_model_id(args.llm_path):
                                                # Download from Hugging Face
                                                print(f"Attempting to download LLM model from Hugging Face: {args.llm_path}")
                                                cache_dir = get_model_cache_dir()
                                                clip_l_path = download_huggingface_model(args.llm_path, cache_dir, '.gguf')
                                                if clip_l_path:
                                                    print(f"Downloaded LLM model to: {clip_l_path}")
                                                else:
                                                    print(f"Warning: Failed to download LLM model from Hugging Face, will try as local path")
                                            elif args.llm_path.startswith('http://') or args.llm_path.startswith('https://'):
                                                cached = get_cached_model_path(args.llm_path)
                                                if cached:
                                                    clip_l_path = cached
                                                    print(f"Using cached LLM model: {clip_l_path}")
                                                else:
                                                    cache_dir = get_model_cache_dir()
                                                    clip_l_path = download_model(args.llm_path, cache_dir)
                                                    print(f"Downloaded LLM model to: {clip_l_path}")
                                            else:
                                                clip_l_path = args.llm_path
                                        if args.vae_path:
                                            # Check if it's a URL and download if needed
                                            if args.vae_path.startswith('http://') or args.vae_path.startswith('https://'):
                                                cached = get_cached_model_path(args.vae_path)
                                                if cached:
                                                    vae_path = cached
                                                    print(f"Using cached VAE model: {vae_path}")
                                                else:
                                                    cache_dir = get_model_cache_dir()
                                                    vae_path = download_model(args.vae_path, cache_dir)
                                                    print(f"Downloaded VAE model to: {vae_path}")
                                            else:
                                                vae_path = args.vae_path
                                        
                                        # Look for common file patterns only if CLI args not provided
                                        if not args.llm_path or not args.vae_path:
                                            for f in os.listdir(model_dir) if os.path.exists(model_dir) else []:
                                                if not args.llm_path and 'clip_l' in f.lower() and f.endswith(('.safetensors', '.bin')):
                                                    clip_l_path = os.path.join(model_dir, f)
                                                elif 't5xxl' in f.lower() and f.endswith(('.safetensors', '.bin')):
                                                    t5xxl_path = os.path.join(model_dir, f)
                                                elif not args.vae_path and f.endswith('.safetensors') and 'ae' in f.lower():
                                                    vae_path = os.path.join(model_dir, f)
                                        
                                        # Build kwargs based on available files
                                        sd_kwargs = {'diffusion_model_path': model_path}
                                        
                                        if clip_l_path:
                                            sd_kwargs['llm_path'] = clip_l_path
                                            print(f"DEBUG: Adding llm_path to sd_kwargs: {clip_l_path}")
                                        else:
                                            print(f"DEBUG: clip_l_path is None or empty, not adding to sd_kwargs")
                                            print(f"DEBUG: args.llm_path = {args.llm_path}")
                                        if vae_path:
                                            sd_kwargs['vae_path'] = vae_path
                                        if t5xxl_path:
                                            sd_kwargs['t5xxl_path'] = t5xxl_path
                                        

                                        # Add generation parameters from CLI args
                                        # sd_kwargs['sample_method'] = args.image_sample_method  # Not valid for __init__
                                        # sd_kwargs['steps'] = args.image_steps  # Not valid for __init__
                                        
                                        # Define model_key for adding to manager
                                        model_key = f"image:{model_path}"
                                        sd_model = StableDiffusion(**sd_kwargs)
                                        multi_model_manager.add_model(model_key, sd_model)
                                        # Add aliases for "image" and "vision" 
                                        multi_model_manager.add_model("image", sd_model)
                                        multi_model_manager.add_model("vision", sd_model)
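                                        # The short aliases let API requests refer to this model simply as
                                        # "image" or "vision" instead of the resolved file path.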
                                        print(f"Image model loaded successfully via sd.cpp: {original_model_name}")
                                    except ImportError as sd_error:
                                        print(f"stable-diffusion-cpp-python not installed: {sd_error}")
                                        print(f"Image model will load on first request")
                                    except Exception as sd_error:
                                        print(f"sd.cpp load error: {sd_error}")
                                        print(f"Image model will load on first request")
                    else:
                        print(f"Could not load GGUF image model: no valid model path")
                        
                except ImportError as e:
                    print(f"Warning: llama_cpp not installed: {e}")
                except Exception as e:
                    print(f"Warning: Failed to pre-load GGUF image model: {e}")
            else:
                # Load diffusers image model (Stable Diffusion)
                try:
                    import torch
                    from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
                    
                    # Use model name directly for diffusers (model_path is only set in GGUF branch)
                    model_key = f"image:{model_name}"
                    print(f"Loading diffusers pipeline: {model_name}")
                    
                    # Try to load as Stable Diffusion XL first
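                    # from_pretrained accepts either a local directory or a Hugging Face
                    # repo id (e.g. stabilityai/stable-diffusion-xl-base-1.0).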
                    try:
                        pipeline = StableDiffusionXLPipeline.from_pretrained(
                            model_name,
                            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                            use_safetensors=True,
                        )
                    except Exception as e:
                        print(f"SDXL failed, trying generic pipeline: {e}")
                        pipeline = DiffusionPipeline.from_pretrained(
                            model_name,
                            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                            use_safetensors=True,
                        )
                    
                    if torch.cuda.is_available():
                        pipeline = pipeline.to("cuda")
                        pipeline.enable_attention_slicing()
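                        # Attention slicing trades a little speed for lower peak VRAM.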
                    else:
                        pipeline = pipeline.to("cpu")
                    
                    multi_model_manager.add_model(model_key, pipeline)
                    # Add aliases for "image" and "vision"
                    multi_model_manager.add_model("image", pipeline)
                    multi_model_manager.add_model("vision", pipeline)
                    print(f"Image model loaded successfully: {model_name}")
                    
                except ImportError as e:
                    print(f"Warning: diffusers not installed: {e}")
                except Exception as e:
                    print(f"Warning: Failed to pre-load image model: {e}")
    
    # Register model aliases if specified
    if args.model_aliases:
        print(f"\nRegistering model aliases:")
        for alias, model in args.model_aliases:
            multi_model_manager.set_model_alias(alias, model)
            print(f"  {alias} -> {model}")
    
    # Start the server
    import uvicorn
    print(f"\nStarting server on http://{args.host}:{args.port}")
    print(f"API documentation available at http://{args.host}:{args.port}/docs")
    if model_manager.backend is not None:
        print(f"Using backend: {model_manager.backend_type}")
    
    # Print available models
    models = multi_model_manager.list_models()
    print(f"Available models: {[m.id for m in models]}")
    
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
