#!/usr/bin/env python3
"""
OpenAI-compatible API server for HuggingFace models (NVIDIA) and GGUF models (Vulkan).
Supports CUDA (NVIDIA) and Vulkan (AMD) GPU backends, memory-aware model loading,
streaming, and tool calling.
"""

import argparse
import asyncio
import hashlib
import json
import os
import pathlib
import re
import sys
import time
import uuid
import warnings
import requests
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, List, Optional, Union

import psutil
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, FileResponse, JSONResponse
from pydantic import BaseModel, Field, validator, field_validator, ConfigDict
from pydantic_core import PydanticCustomError
from threading import Thread

# Import codai module for enhanced tool call parsing
from codai.models import ModelParserDispatcher
from codai.models.parsers import OpenAIFormatter
# --- Module-level mutable state -------------------------------------------
# Per-model semaphores for request concurrency control.
# NOTE(review): presumably keyed by model name with a semaphore as value —
# the code that fills this dict is not in this section; confirm.
model_semaphores: dict = {}
# Track load mode globally; the initial mode is "ondemand".
load_mode = {"mode": "ondemand"}  # Track load mode globally
# Track --X-1 flags, one boolean each for model/image/audio/tts; all start False.
queue_flags = {"model_1": False, "image_1": False, "audio_1": False, "tts_1": False}  # Track --X-1 flags
# =============================================================================
# Model Cache Directory
# =============================================================================

# =============================================================================
# Model Capability Detection
# =============================================================================

from typing import Set, Optional
from dataclasses import dataclass, field

@dataclass
class ModelCapabilities:
    """Set of boolean capability flags describing what a model can do.

    All flags default to False. ``str()`` renders the enabled flags as a
    comma-separated list of short labels, or "none" when nothing is enabled.
    """
    text_generation: bool = False  # LLM/chat completion
    image_to_text: bool = False  # Image understanding (captioning, VQA)
    image_generation: bool = False  # Text-to-image (Stable Diffusion)
    speech_to_text: bool = False  # Audio transcription
    text_to_speech: bool = False  # Speech synthesis

    def __str__(self):
        # (flag, label) pairs in the fixed order the labels are reported.
        labelled_flags = (
            (self.text_generation, "text"),
            (self.image_to_text, "image-to-text"),
            (self.image_generation, "image"),
            (self.speech_to_text, "speech-to-text"),
            (self.text_to_speech, "text-to-speech"),
        )
        enabled = [label for flag, label in labelled_flags if flag]
        if not enabled:
            return "none"
        return ", ".join(enabled)

def detect_model_capabilities(model_name: str) -> ModelCapabilities:
    """
    Guess a model's capabilities from substrings of its name.

    The name is lower-cased and tested against keyword groups in a fixed
    order; the first group that matches decides the result. An empty or
    missing name yields a ModelCapabilities with every flag False. Names that
    match no group are treated as text-generation models.

    This is a heuristic - actual capabilities may vary.
    """
    detected = ModelCapabilities()
    if not model_name:
        return detected

    lowered = model_name.lower()

    def mentions(*keywords):
        # True when any keyword occurs anywhere in the lower-cased name.
        return any(keyword in lowered for keyword in keywords)

    if mentions('stable-diffusion', 'sd15', 'sdxl', 'sd-xl', 'turbo', 'playground'):
        # Image generation models (Stable Diffusion, SDXL, etc.) - treated as dedicated.
        detected.image_generation = True
    elif mentions('vision', 'vl-', '-vl', 'llava', 'qwen2-vl', 'qwen-vl', 'phi-4-mini', 'pixtral', 'clip'):
        # Vision models understand images and are also LLMs.
        detected.image_to_text = True
        detected.text_generation = True
    elif mentions('kokoro', 'tts', 'speech', 'voice'):
        # Text-to-speech models.
        detected.text_to_speech = True
    elif mentions('whisper', 'faster-whisper', 'distil-whisper'):
        # Whisper-family speech-to-text models.
        detected.speech_to_text = True
    else:
        # GGUF files and everything else: assume a text-generation model.
        detected.text_generation = True

    return detected

def get_model_cache_dir() -> str:
    """Get or create the model cache directory.

    The directory is ``<cache home>/coderai/models`` where the cache home is
    ``$XDG_CACHE_HOME`` when that variable is set to a non-empty value and
    ``~/.cache`` otherwise. Missing directories are created.

    Returns:
        The path of the (now existing) cache directory.
    """
    # An empty XDG_CACHE_HOME counts as unset; using it verbatim would produce
    # a path relative to the current working directory.
    cache_home = os.environ.get('XDG_CACHE_HOME') or os.path.expanduser('~/.cache')
    cache_dir = os.path.join(cache_home, 'coderai', 'models')
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
    return cache_dir

def get_all_cache_dirs() -> dict:
    """Get all model cache directories that currently exist.

    Returns:
        A dict mapping a cache label to its path. Possible labels are
        ``'coderai'`` (``<cache home>/coderai/models``), ``'huggingface'``
        (``<cache home>/huggingface``) and ``'diffusers'``
        (``~/.cache/diffusers``); a label is present only when its directory
        exists. The cache home is ``$XDG_CACHE_HOME`` when set to a non-empty
        value, otherwise ``~/.cache``.
    """
    # An empty XDG_CACHE_HOME counts as unset; using it verbatim would make
    # the lookups below relative to the current working directory.
    cache_home = os.environ.get('XDG_CACHE_HOME') or os.path.expanduser('~/.cache')

    candidates = (
        # Coderai GGUF cache
        ('coderai', os.path.join(cache_home, 'coderai', 'models')),
        # HuggingFace cache (for .safetensors, PyTorch models, etc.)
        ('huggingface', os.path.join(cache_home, 'huggingface')),
        # Local diffusers cache: always looked up under ~/.cache, not the XDG home
        ('diffusers', os.path.expanduser('~/.cache/diffusers')),
    )
    return {label: path for label, path in candidates if os.path.exists(path)}

def get_cached_model_path(url: str) -> Optional[str]:
    """Look up a model URL in the local cache.

    The expected cache entry is ``{sha256(url)}_{basename}`` inside the
    directory returned by get_model_cache_dir(); the basename is taken from
    the URL with its query string removed and falls back to "model.gguf" when
    empty (this is the naming the original comment attributes to load_model).

    Returns:
        The cached file path when it exists (after printing a notice),
        otherwise None.
    """
    cache_root = get_model_cache_dir()
    digest = hashlib.sha256(url.encode()).hexdigest()
    # Drop the query string before extracting the file name.
    basename = os.path.basename(url.split('?')[0]) or "model.gguf"
    candidate = os.path.join(cache_root, f"{digest}_{basename}")

    if not os.path.exists(candidate):
        return None
    print(f"Using cached model: {candidate}")
    return candidate
def is_huggingface_model_id(path: str) -> bool:
    """Return True when *path* looks like a Hugging Face model ID such as 'Qwen/Qwen3-4B-Instruct-2507-Q3_K_S'.

    The check is purely syntactic: the string must contain a '/' and must not
    start with 'http://' or 'https://'.
    """
    if '/' not in path:
        return False
    return not path.startswith(('http://', 'https://'))
def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str = '.gguf') -> Optional[str]:
    """Download one file of a Hugging Face repository into *cache_dir*.

    The repository file list is fetched and the first file whose name ends
    with *file_pattern* is downloaded via ``hf_hub_download``.

    Returns:
        The local path reported by ``hf_hub_download``, or None when no file
        matches or when anything fails (the error is printed, including a
        missing ``huggingface_hub`` package).
    """
    try:
        from huggingface_hub import hf_hub_download, list_repo_files

        print(f"Downloading from Hugging Face: {model_id}")
        repo_files = list_repo_files(model_id)

        # First repository file carrying the requested suffix, if any.
        chosen = next((name for name in repo_files if name.endswith(file_pattern)), None)
        if chosen is None:
            print(f"No {file_pattern} files found in {model_id}")
            return None

        print(f"Downloading: {chosen}")
        local_path = hf_hub_download(repo_id=model_id, filename=chosen, cache_dir=cache_dir)
        print(f"Downloaded model to: {local_path}")
        return local_path
    except Exception as e:
        print(f"Error downloading from Hugging Face: {e}")
        return None
def download_model(url: str, cache_dir: str, timeout=(30, 300)) -> str:
    """Download a model from URL with progress reporting. Returns cached path.

    The cache entry is named ``{sha256(url)}_{filename}`` inside *cache_dir*.
    The extension is inferred from substrings of the URL ('gguf', then 'bin',
    then 'ggml', defaulting to '.bin'); when the URL's file name does not end
    with that extension the name ``model{ext}`` is used instead.

    The payload is streamed to ``<cache path>.part`` and renamed onto the
    final path only after the transfer finished, so an interrupted download
    never leaves a truncated file that a later call would mistake for a
    valid cache entry.

    Args:
        url: Source URL of the model file.
        cache_dir: Existing directory that holds cached models.
        timeout: ``(connect, read)`` timeout in seconds handed to
            ``requests.get`` so a stalled server cannot block forever.

    Returns:
        Path of the cached model file.

    Raises:
        requests.RequestException: On connection errors, timeouts or a
            non-2xx HTTP status.
        OSError: When the cache file cannot be written.
    """
    url_path = url.split('?')[0]
    filename = os.path.basename(url_path)

    # Determine file extension from substrings of the URL.
    url_lower = url.lower()
    if 'gguf' in url_lower:
        ext = '.gguf'
    elif 'bin' in url_lower:
        ext = '.bin'
    elif 'ggml' in url_lower:
        ext = '.ggml'
    else:
        ext = '.bin'

    if not filename.endswith(ext):
        filename = f"model{ext}"

    # Create safe filename in cache
    url_hash = hashlib.sha256(url.encode()).hexdigest()
    model_path = os.path.join(cache_dir, f"{url_hash}_{filename}")

    # Check if already cached
    if os.path.exists(model_path):
        print(f"Using cached model: {model_path}")
        return model_path

    # Function-scope import as in the original; only a cache miss needs it.
    import requests

    print(f"Downloading model: {url}")
    partial_path = f"{model_path}.part"
    total_mb = 0
    try:
        # The context manager releases the connection even if the body is not fully read.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))
            total_mb = total_size / (1024 * 1024) if total_size > 0 else 0

            downloaded = 0
            start_time = time.time()

            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192 * 1024):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        elapsed = time.time() - start_time
                        speed = downloaded / elapsed if elapsed > 0 else 0
                        speed_mb = speed / (1024 * 1024)
                        dl_mb = downloaded / (1024 * 1024)
                        print(f"Downloaded: {percent:.1f}% ({dl_mb:.1f}/{total_mb:.1f} MB) at {speed_mb:.1f} MB/s", end='\r')

        # Publish the finished file under its final name in one step.
        os.replace(partial_path, model_path)
    except BaseException:
        # Covers Ctrl-C as well: drop the partial file, then re-raise unchanged.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise

    print()  # New line after progress
    print(f"Downloaded and cached to: {model_path}")
    if total_mb > 0:
        print(f"File size: {total_mb:.1f} MB")

    return model_path
# =============================================================================
# Backend Detection and Imports
# =============================================================================

def detect_available_backends():
    """Report which compute backends can be used in this environment.

    Returns:
        A dict that always contains ``'cpu': True``, plus ``'nvidia': True``
        when PyTorch imports and reports CUDA as available, and
        ``'vulkan': True`` when ``llama_cpp`` (llama-cpp-python) imports.
    """
    found = {'cpu': True}

    # PyTorch/CUDA: both the import and the CUDA probe sit inside the guard.
    try:
        import torch
        cuda_ready = torch.cuda.is_available()
    except ImportError:
        cuda_ready = False
    if cuda_ready:
        found['nvidia'] = True

    # llama-cpp-python (Vulkan): being importable is the only criterion.
    try:
        import llama_cpp
    except ImportError:
        pass
    else:
        found['vulkan'] = True

    return found
# =============================================================================
# Flash Attention Detection (for NVIDIA backend)
# =============================================================================

def check_flash_attn_availability() -> bool:
    """Return True when the ``flash_attn`` package can be imported, else False."""
    try:
        import flash_attn
    except ImportError:
        return False
    return True
# =============================================================================
# Pydantic Models for API
# =============================================================================

# Function descriptor nested inside a Tool entry of a chat completion request.
class ToolFunction(BaseModel):
    # Function name; ToolCallParser.extract_tool_calls looks for <name>...</name> tags using it.
    name: str
    # Optional free-text description of the function.
    description: Optional[str] = None
    # Optional parameter description — presumably a JSON Schema object; confirm against clients.
    parameters: Optional[Dict] = None
# One tool offered to the model in a chat completion request.
class Tool(BaseModel):
    # Tool kind; defaults to "function".
    type: str = "function"
    # The function this tool exposes (name, description, parameters).
    function: ToolFunction
# One message of a chat conversation. List-valued (multipart) content is
# flattened to a single string before field validation.
class ChatMessage(BaseModel):
    role: str
    content: Optional[Union[str, List[Dict]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Dict]] = None
    tool_call_id: Optional[str] = None

    @field_validator('content', mode='before')
    @classmethod
    def convert_content_array_to_string(cls, v):
        """Flatten multipart content into one newline-joined string.

        None and plain strings pass through untouched. A list (e.g. the
        KiloCode format ``[{"type": "text", "text": "..."}, ...]``) is
        rendered part by part and joined with newlines; any other value is
        converted with ``str()``.
        """
        if v is None or isinstance(v, str):
            return v
        if not isinstance(v, list):
            return str(v)

        def render(part):
            # Non-dict entries are stringified as they are.
            if not isinstance(part, dict):
                return str(part)
            # Text parts contribute their text verbatim.
            if part.get('type') == 'text' and 'text' in part:
                return part['text']
            # Other content types (image_url, etc.) become a placeholder.
            return f"[{part.get('type', 'unknown')} content]"

        return '\n'.join([render(part) for part in v])
# Request body for a chat completion. Unknown fields are accepted rather than
# rejected (extra="allow").
class ChatCompletionRequest(BaseModel):
    model: str
    # Conversation so far; multipart content is flattened by ChatMessage's validator.
    messages: List[ChatMessage]
    # Sampling controls and their defaults.
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    # When True the caller asks for a streamed response.
    stream: bool = False
    # Stop sequence(s): a single string or a list of strings.
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repeat_penalty: float = 1.0
    # Tools offered to the model and the tool selection policy (defaults to "auto").
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = "auto"
    # Extra fields that clients may send but we ignore
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    response_format: Optional[Dict] = None
    user: Optional[str] = None
    
    model_config = ConfigDict(extra="allow")  # Allow extra fields to prevent 422 errors
# Request body for a plain (non-chat) text completion. Unknown fields are
# accepted rather than rejected (extra="allow").
class CompletionRequest(BaseModel):
    model: str
    # Prompt text: a single string or a list of strings.
    prompt: Union[str, List[str]]
    # Sampling controls and their defaults.
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    # When True the caller asks for a streamed response.
    stream: bool = False
    # Stop sequence(s): a single string or a list of strings.
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repeat_penalty: float = 1.0
    # Extra fields that clients may send but we ignore
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = None
    user: Optional[str] = None
    
    model_config = ConfigDict(extra="allow")  # Allow extra fields to prevent 422 errors
# Description of a single model, used as an element of ModelList.
class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    # Defaults to the Unix time (whole seconds) at which the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "huggingface"
# Container for a list of ModelInfo entries; `object` is fixed to "list" by default.
class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelInfo]
# =============================================================================
# Audio Transcription Models
# =============================================================================

# Request parameters for audio transcription. Unknown fields are accepted
# (extra="allow").
class TranscriptionRequest(BaseModel):
    model: str
    # Audio input: raw bytes and/or a path.
    # NOTE(review): which one takes precedence is decided by the handler, not visible here.
    file: Optional[bytes] = None
    file_path: Optional[str] = None
    language: Optional[str] = None
    prompt: Optional[str] = None
    # Output format name; defaults to "json".
    response_format: Optional[str] = "json"
    temperature: Optional[float] = 0.0
    timestamp_granularities: Optional[List[str]] = None
    
    model_config = ConfigDict(extra="allow")
# Transcription result; only `text` is required, extra fields are allowed.
class TranscriptionResponse(BaseModel):
    text: str
    model_config = ConfigDict(extra="allow")
# =============================================================================
# Image Generation Models
# =============================================================================

# Request parameters for text-to-image generation. Unknown fields are
# accepted (extra="allow").
class ImageGenerationRequest(BaseModel):
    model: str
    prompt: str
    # Number of images requested.
    n: int = 1
    # Image size as a string; default "1024x1024" (presumably "<width>x<height>" — confirm in the handler).
    size: Optional[str] = "1024x1024"
    steps: Optional[int] = None  # Number of inference steps (overrides quality-based default)
    guidance_scale: Optional[float] = None  # CFG scale (overrides quality-based default)
    quality: Optional[str] = "standard"
    style: Optional[str] = None
    # How results are returned; defaults to "url".
    response_format: Optional[str] = "url"
    seed: Optional[int] = None
    user: Optional[str] = None
    
    model_config = ConfigDict(extra="allow")
# Image generation result: a creation timestamp plus one dict per image.
# NOTE(review): the keys of each dict in `data` are set by the handler, not visible here.
class ImageGenerationResponse(BaseModel):
    created: int
    data: List[Dict]
    model_config = ConfigDict(extra="allow")
# =============================================================================
# Content Filtering Utility
# =============================================================================

def filter_malformed_content(text: str) -> str:
    """Strip leaked SEARCH/REPLACE diff markers and chat artefacts from model output.

    Each pattern below is removed in order, then runs of three or more
    newlines are collapsed to two. Falsy input (empty string, None) is
    returned unchanged. Remaining whitespace is left alone because it may be
    valid content.
    """
    if not text:
        return text

    # (pattern, flags) pairs, applied strictly in this order.
    removals = (
        # Git-style diff markers and SEARCH/REPLACE patterns
        (r'<<<<<<<\s+SEARCH.*?=======', re.DOTALL),
        (r'=======.*?>>>>>>>\s+REPLACE', re.DOTALL),
        (r'>>>>>>>\s+REPLACE', 0),
        # Common malformed patterns seen in outputs
        (r'<<<<<<<\s+SEARCH\s*:start_line:\d+[^<]*', re.DOTALL),
        (r'<button>Stop Generation</button>', 0),
        (r'\<\|assistant\|\>', 0),
        (r'\</\|assistant\|\>', 0),
    )

    cleaned = text
    for pattern, flags in removals:
        cleaned = re.sub(pattern, '', cleaned, flags=flags)

    # Clean up excessive newlines left from removal
    return re.sub(r'\n{3,}', '\n\n', cleaned)
# =============================================================================
# Tool Parsing
# =============================================================================

class ToolCallParser:
    """Parse model outputs to extract tool calls.

    Recognised shapes (details on the individual methods):
      * Qwen ``<tool_call>`` blocks (JSON or ``<function=...>`` markup); tried
        first when the model name contains "qwen";
      * ``<tool>`` / ``<function>`` tags wrapping JSON or nested XML;
      * JSON objects with a ``"function_call"`` member;
      * ``<tool_name>...</tool_name>`` tags for each tool offered in the request.

    Every extracted call has the form
    ``{"id": "call_<hex>", "type": "function",
    "function": {"name": <str>, "arguments": <str>}}``.
    """

    # Tag names that strip_tool_calls_from_content removes in addition to
    # <tool> and <function>.
    _STRIPPED_TOOL_TAGS = (
        'read', 'write', 'exec', 'browser', 'message', 'web_search', 'web_fetch',
        'memory_search', 'memory_get', 'sessions_list', 'sessions_send', 'tts', 'canvas', 'nodes',
        'read_file', 'write_file', 'process', 'agents_list', 'sessions_history', 'sessions_spawn',
        'subagents', 'session_status',
    )

    def __init__(self, tokenizer=None, model_name: str = None):
        # The tokenizer is stored but not used by any method of this class.
        self.tokenizer = tokenizer
        self.model_name = model_name

    def set_model_name(self, model_name: str):
        """Set the model name for model-specific parsing."""
        self.model_name = model_name

    def _is_qwen_model(self) -> bool:
        """Check if the current model is a Qwen model (name contains "qwen", any case)."""
        if not self.model_name:
            return False
        return 'qwen' in self.model_name.lower()

    @staticmethod
    def _arguments_to_json(arguments) -> str:
        """Return *arguments* as the string stored in ``function.arguments``.

        Strings pass through unchanged (they may already hold encoded JSON or
        raw text); every other value is JSON-encoded, so lists, numbers,
        booleans and None never end up as a Python repr.
        """
        if isinstance(arguments, str):
            return arguments
        return json.dumps(arguments)

    @staticmethod
    def _build_call(name, arguments: str, id_length: int = 16) -> Dict:
        """Build one tool-call entry with a random ``call_<hex>`` id."""
        return {
            "id": f"call_{uuid.uuid4().hex[:id_length]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        }

    @staticmethod
    def _iter_function_calls(text: str):
        """Yield the dict stored under "function_call" of each JSON object found in *text*.

        Decoding is attempted at every '{'. Once an object with a dict-valued
        "function_call" member has been decoded, scanning resumes after its
        end so its nested objects are not reported again.
        """
        decoder = json.JSONDecoder()
        resume_at = 0
        position = text.find('{')
        while position != -1:
            if position >= resume_at:
                try:
                    candidate, end = decoder.raw_decode(text, position)
                except (json.JSONDecodeError, RecursionError):
                    candidate = None
                if isinstance(candidate, dict) and isinstance(candidate.get("function_call"), dict):
                    resume_at = end
                    yield candidate["function_call"]
            position = text.find('{', position + 1)

    def _parse_qwen_tool_calls(self, text: str) -> Optional[List[Dict]]:
        """
        Parse tool calls from Qwen model output.

        Supports:
        1. Instruct-style: <tool_call>{...}</tool_call> with JSON inside
        2. Coder-style: <tool_call><function=name><parameter=key>value</parameter></function></tool_call>

        Coder-style parsing only runs when no Instruct-style call was found.
        Returns OpenAI-compatible tool_calls format, or None when nothing matched.
        """
        # Clean the text first - remove special tokens and thinking tags if present
        clean_text = re.sub(r'<｜.*?｜>', '', text)
        clean_text = re.sub(r'<think>.*?</think>', '', clean_text, flags=re.DOTALL)
        clean_text = re.sub(r'<think>', '', clean_text)
        clean_text = re.sub(r'</think>', '', clean_text)
        # Normalise whitespace just inside the <tool_call> wrapper
        clean_text = re.sub(r'<tool_call>\s*', '<tool_call>', clean_text)
        clean_text = re.sub(r'\s*</tool_call>', '</tool_call>', clean_text)

        tool_calls = []

        # 1. Instruct-style: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
        for match in re.findall(r'<tool_call>\s*(\{.*?\})\s*</tool_call>', clean_text, re.DOTALL):
            try:
                data = json.loads(match.strip())
            except json.JSONDecodeError:
                continue
            if isinstance(data, dict) and 'name' in data and 'arguments' in data:
                tool_calls.append(
                    self._build_call(data["name"], self._arguments_to_json(data["arguments"]), id_length=8)
                )

        # 2. Coder-style (XML tags), only if no Instruct calls were found
        if not tool_calls:
            coder_blocks = re.findall(r'<tool_call>\s*(.*?)\s*</tool_call>', clean_text, re.DOTALL)

            # Some Coder models might skip the outer <tool_call> wrapper and just use <function>
            if not coder_blocks:
                coder_blocks = re.findall(r'(<function=.*?</function>)', clean_text, re.DOTALL)

            for block in coder_blocks:
                func_name_match = re.search(r'<function=([^>]+)>', block)
                if not func_name_match:
                    continue
                arguments = {}
                # All parameters within this specific function block
                for raw_key, raw_value in re.findall(r'<parameter=([^>]+)>(.*?)</parameter>', block, re.DOTALL):
                    value = raw_value.strip()
                    # Try to parse as JSON, otherwise use raw string
                    try:
                        arguments[raw_key.strip()] = json.loads(value)
                    except (json.JSONDecodeError, TypeError):
                        arguments[raw_key.strip()] = value
                tool_calls.append(
                    self._build_call(func_name_match.group(1).strip(), json.dumps(arguments), id_length=8)
                )

        return tool_calls if tool_calls else None

    def _parse_nested_xml_tool(self, xml_content: str) -> Optional[Dict]:
        """Parse nested XML tool format like <tool><name>...</name><arguments>...</arguments></tool>.

        Returns ``{"name": ..., "arguments": ...}`` where arguments is the
        decoded JSON value, the dict produced by _xml_to_dict, or the raw
        string as a last resort. Returns None when <name> or <arguments> is
        missing or parsing fails unexpectedly.
        """
        try:
            name_match = re.search(r'<name>\s*(.*?)\s*</name>', xml_content, re.DOTALL)
            if not name_match:
                return None
            tool_name = name_match.group(1).strip()

            args_match = re.search(r'<arguments>\s*(.*?)\s*</arguments>', xml_content, re.DOTALL)
            if not args_match:
                return None
            args_content = args_match.group(1).strip()

            # Try JSON first, then nested XML (e.g. <files><path>...</path></files>)
            try:
                arguments = json.loads(args_content)
            except json.JSONDecodeError:
                try:
                    arguments = self._xml_to_dict(args_content)
                except Exception:
                    # Return raw string as fallback
                    arguments = args_content
            return {"name": tool_name, "arguments": arguments}
        except Exception:
            # Best effort on untrusted model output: treat any failure as "no tool call".
            return None

    def _xml_to_dict(self, xml_content: str) -> Union[Dict, str]:
        """Convert simple nested XML to a dictionary.

        Each ``<tag>...</tag>`` pair becomes a key; content that itself
        contains tags is converted recursively. When the same tag occurs
        several times among siblings its values are collected into a list
        instead of the last one silently replacing the others. Text without
        any tag pair is returned unchanged as a string.
        """
        result = {}
        for tag, content in re.findall(r'<(\w+)>\s*(.*?)\s*</\1>', xml_content, re.DOTALL):
            value = self._xml_to_dict(content) if re.search(r'<\w+>', content) else content
            if tag not in result:
                result[tag] = value
            elif isinstance(result[tag], list):
                # Converted values are only ever dicts or strings, so a list
                # here can only be an accumulator created below.
                result[tag].append(value)
            else:
                result[tag] = [result[tag], value]

        return result if result else xml_content

    def _filter_malformed_content(self, text: str) -> str:
        """Filter out malformed SEARCH/REPLACE blocks - delegates to standalone function."""
        # Only filter if --reply-filters is set for text models (generic)
        if check_reply_filter('malformed', 'text'):
            return filter_malformed_content(text)
        return text

    def extract_tool_calls(self, text: str, available_tools: List["Tool"]) -> Optional[List[Dict]]:
        """Extract tool calls from model output.

        Args:
            text: Raw model output.
            available_tools: Tools offered in the request; each must expose
                ``.function.name``. None is treated as an empty list.

        Returns:
            A list of tool-call dicts (duplicates with the same name and
            arguments are reported once), or None when nothing was found.
        """
        if not text:
            return None

        # First filter out malformed content
        text = self._filter_malformed_content(text)

        # For Qwen models, try Qwen-specific parsing first
        if self._is_qwen_model():
            qwen_tool_calls = self._parse_qwen_tool_calls(text)
            if qwen_tool_calls:
                return qwen_tool_calls

        tool_calls = []
        seen_signatures = set()  # (name, arguments) pairs already recorded

        def record(name, arguments, signature_args):
            signature = (name, signature_args)
            if signature not in seen_signatures:
                seen_signatures.add(signature)
                tool_calls.append(self._build_call(name, arguments))

        # Format 1: <tool> or <function> tags with JSON content
        for match in re.findall(r'<(?:tool|function)>(.*?)</(?:tool|function)>', text, re.DOTALL):
            try:
                tool_data = json.loads(match.strip())
            except json.JSONDecodeError:
                tool_data = None
            # Valid JSON is not necessarily an object ("5", "[1]"), and a
            # non-string name could not be used as a signature key.
            if (isinstance(tool_data, dict)
                    and isinstance(tool_data.get('name'), str)
                    and 'arguments' in tool_data):
                arguments = tool_data['arguments']
                record(tool_data['name'], json.dumps(arguments), json.dumps(arguments, sort_keys=True))
                continue

            # Format 1b: nested XML (<name>/<arguments>) inside the tag
            nested = self._parse_nested_xml_tool(match)
            if nested:
                args_str = self._arguments_to_json(nested['arguments'])
                record(nested['name'], args_str, args_str)

        # Format 2: JSON object with a "function_call" member
        if "function_call" in text:
            for function_call in self._iter_function_calls(text):
                name = function_call.get("name")
                if isinstance(name, str) and name:
                    args_str = self._arguments_to_json(function_call.get("arguments", "{}"))
                    record(name, args_str, args_str)

        # Format 3: <tool_name>...</tool_name> for every offered tool
        for tool in available_tools or []:
            tool_name = tool.function.name
            # Tool names are data, not patterns: escape regex metacharacters.
            tag = re.escape(tool_name)
            for match in re.findall(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL):
                body = match.strip()
                try:
                    args = json.loads(body)
                except json.JSONDecodeError:
                    # Use the raw text as arguments
                    record(tool_name, body, body)
                else:
                    record(tool_name, json.dumps(args), json.dumps(args, sort_keys=True))

        return tool_calls if tool_calls else None

    def strip_tool_calls_from_content(self, text: str) -> str:
        """Remove tool call markup from text after extracting tool calls.

        Removes ``<tool>...</tool>``, ``<function>...</function>`` and the
        tags listed in _STRIPPED_TOOL_TAGS (non-greedy, across newlines), then
        collapses runs of three or more newlines to two. Other whitespace is
        kept because spaces and newlines are valid content.
        """
        if not text:
            return text

        for tag in ('tool', 'function') + self._STRIPPED_TOOL_TAGS:
            text = re.sub(rf'<{tag}>[\s\S]*?</{tag}>', '', text)

        # Clean up excessive newlines left from removal
        return re.sub(r'\n{3,}', '\n\n', text)

# =============================================================================
# Model Parser Dispatcher Wrapper - integrates model_parser module
# =============================================================================

class ModelParserAdapter:
    """Adapter exposing ModelParserDispatcher through the ToolCallParser interface.

    The adapter keeps a dispatcher configured for the current model name and
    tool schema, converts the request's tool objects into the plain-dict schema
    the dispatcher receives, and strips tool-call markup from model output.
    """

    # Tag names whose whole <tag>...</tag> span counts as tool-call markup.
    # Each name appears once; order follows first appearance in the legacy list.
    _TOOL_TAGS = (
        'tool', 'function',
        'read', 'write', 'exec', 'browser', 'message', 'web_search', 'web_fetch',
        'memory_search', 'memory_get', 'sessions_list', 'sessions_send', 'tts', 'canvas', 'nodes',
        'read_file', 'write_file', 'process', 'agents_list', 'sessions_history', 'sessions_spawn',
        'subagents', 'session_status',
    )

    # Compiled once. Every pattern is lazy (.*?) and uses DOTALL, so a match
    # spans newlines but stops at the first closing tag. The Qwen-style
    # <tool=name> forms run first, then one pass per plain tag.
    _MARKUP_PATTERNS = tuple(
        [
            re.compile(r'<tool=[^>]+>.*?</tool_call>', re.DOTALL),
            re.compile(r'<tool=[^>]+>.*?</tool>', re.DOTALL),
        ]
        + [
            re.compile(rf'<{re.escape(tag)}>.*?</{re.escape(tag)}>', re.DOTALL)
            for tag in _TOOL_TAGS
        ]
    )

    # Runs of three or more newlines left behind by removals.
    _BLANK_RUN = re.compile(r'\n{3,}')

    def __init__(self, model_name: Optional[str] = None, tools_schema: Optional[Dict] = None):
        """Create the adapter and its dispatcher.

        Args:
            model_name: Model identifier forwarded to the dispatcher.
            tools_schema: Initial tool schema dict; an empty dict is used when
                omitted or falsy.
        """
        self._model_name = model_name
        self._tools_schema = tools_schema or {}
        self._dispatcher = ModelParserDispatcher(model_name=model_name, tools_schema=self._tools_schema)

    def set_model_name(self, model_name: str) -> None:
        """Switch to another model by rebuilding the dispatcher with the current tools."""
        self._model_name = model_name
        self._dispatcher = ModelParserDispatcher(model_name=model_name, tools_schema=self._tools_schema)

    def extract_tool_calls(self, text: str, available_tools: List[Tool]) -> Optional[List[Dict]]:
        """Extract tool calls from model output using the dispatcher.

        Args:
            text: Raw model output.
            available_tools: Tools offered in the request. Entries without a
                truthy ``function`` attribute are ignored; ``None`` is treated
                as an empty list.

        Returns:
            The dispatcher's tool calls, each guaranteed to carry an ``id`` and
            a ``type`` key, or ``None`` when the text is empty or nothing was
            parsed.
        """
        if not text:
            return None

        # Convert the tool objects to the {name: {description, parameters}} shape.
        tools_dict = {}
        for tool in available_tools or []:
            func = getattr(tool, 'function', None)
            if func:
                tools_dict[func.name] = {
                    'description': func.description or '',
                    'parameters': func.parameters or {},
                }

        # Only push the schema to the dispatcher when it actually changed.
        if tools_dict != self._tools_schema:
            self._tools_schema = tools_dict
            self._dispatcher.set_tools(tools_dict)

        tool_calls = self._dispatcher.parse(text)
        if not tool_calls:
            return None

        # Fill in the fields the parser may have left out.
        for call in tool_calls:
            if 'id' not in call:
                call['id'] = f"call_{uuid.uuid4().hex[:16]}"
            if 'type' not in call:
                call['type'] = 'function'
        return tool_calls

    def strip_tool_calls_from_content(self, text: str) -> str:
        """Remove tool-call markup from text, leaving the remaining content.

        Leading and trailing whitespace is kept; only runs of three or more
        newlines are collapsed to two.
        """
        if not text:
            return text

        for pattern in self._MARKUP_PATTERNS:
            text = pattern.sub('', text)

        return self._BLANK_RUN.sub('\n\n', text)

def format_tools_for_prompt(tools: List[Tool], messages: List[ChatMessage]) -> List[ChatMessage]:
    """Return a copy of messages whose system prompt describes the given tools.

    With no tools the input list is returned as-is. Otherwise the tool
    instructions are placed in front of the first system message's content,
    or inserted as a new leading system message when there is none.
    """
    if not tools:
        return messages

    def describe(tool) -> str:
        # One paragraph per tool: name, then optional description and schema.
        func = tool.function
        lines = [f"Tool: {func.name}"]
        if func.description:
            lines.append(f"Description: {func.description}")
        if func.parameters:
            lines.append(f"Parameters: {json.dumps(func.parameters, indent=2)}")
        return "\n".join(lines)

    tools_text = "".join([
        "You have access to the following tools:\n\n",
        "\n\n".join(describe(tool) for tool in tools),
        "\n\nIMPORTANT: When you need to use a tool, you MUST format your response EXACTLY as:\n",
        '<tool>{"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}</tool>',
        "\n\nRules:\n",
        "1. The content inside <tool> tags must be valid JSON\n",
        "2. Do NOT use nested XML tags like <name> or <arguments> - use JSON format only\n",
        "3. The 'name' field must match one of the available tool names exactly\n",
        "4. The 'arguments' field must be a JSON object with the required parameters\n",
        "\nExample:\n",
        'User: Read the file example.txt\n',
        'Assistant: <tool>{"name": "read_file", "arguments": {"files": [{"path": "example.txt"}]}}</tool>',
    ])

    # Work on a shallow copy so the caller's list is left untouched.
    augmented = list(messages)
    system_index = next(
        (index for index, message in enumerate(augmented) if message.role == "system"),
        None,
    )

    if system_index is None:
        augmented.insert(0, ChatMessage(role="system", content=tools_text))
    else:
        existing = augmented[system_index]
        augmented[system_index] = ChatMessage(
            role="system",
            content=f"{tools_text}\n\n{existing.content or ''}"
        )

    return augmented
# =============================================================================
# Abstract Model Backend
# =============================================================================

class ModelBackend(ABC):
    """Interface every model backend in this server implements."""

    @abstractmethod
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load the model named model_name; extra options arrive as keyword arguments."""

    @abstractmethod
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Produce the full completion for prompt and return it as one string."""

    @abstractmethod
    def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                        temperature: float = 0.7, top_p: float = 1.0,
                        stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Produce the completion for prompt incrementally, yielding text pieces."""

    @abstractmethod
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Turn a list of chat messages into a single prompt string."""

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the name of the loaded model."""

    @abstractmethod
    def cleanup(self) -> None:
        """Release whatever resources the backend holds."""
# =============================================================================
# NVIDIA/HuggingFace Backend
# =============================================================================

class NvidiaBackend(ModelBackend):
    """Backend for NVIDIA GPUs using HuggingFace Transformers."""
    
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = None
        self.device = None
        self.use_flash_attn = False
        self.flash_attn_available = False
        
    def check_flash_attn_support(self) -> None:
        """Record Flash Attention availability and report it if it was requested.

        When Flash Attention was requested but is unavailable, a warning with
        install instructions is printed and the request flag is cleared.
        """
        available = check_flash_attn_availability()
        self.flash_attn_available = available

        # Nothing to report unless the user asked for Flash Attention.
        if not self.use_flash_attn:
            return

        if available:
            print("Flash Attention 2: Available and enabled")
            return

        for line in (
            "Warning: Flash Attention 2 requested but not installed",
            "Install with: pip install flash-attn --no-build-isolation",
            "Falling back to standard attention",
        ):
            print(line)
        self.use_flash_attn = False
    
    def _detect_device(self) -> str:
        """Auto-detect available GPU or fall back to CPU."""
        import torch
        if torch.cuda.is_available():
            # Check for ROCm (HIP)
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                print(f"ROCm/HIP detected: {torch.version.hip}")
                return "cuda"
            else:
                print(f"CUDA detected: {torch.version.cuda}")
                return "cuda"
        else:
            print("No GPU detected, using CPU")
            return "cpu"
    
    def _get_available_vram(self) -> int:
        """Get available VRAM in bytes. Returns 0 if no GPU available."""
        import torch
        if not torch.cuda.is_available():
            return 0
        
        try:
            total_vram = 0
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                total_vram += props.total_memory
            return total_vram
        except Exception as e:
            print(f"Warning: Could not detect VRAM: {e}")
            return 0
    
    def _estimate_model_size(self, model_name: str) -> Optional[int]:
        """Estimate model size in bytes from config."""
        from transformers import AutoConfig
        try:
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
            
            # Get model parameters from config
            if hasattr(config, 'num_parameters'):
                num_params = config.num_parameters
            elif hasattr(config, 'n_params'):
                num_params = config.n_params
            elif hasattr(config, 'num_hidden_layers') and hasattr(config, 'hidden_size'):
                layers = config.num_hidden_layers
                hidden = config.hidden_size
                vocab_size = getattr(config, 'vocab_size', 50000)
                num_params = (vocab_size * hidden_size) + (layers * 4 * hidden * hidden)
            else:
                return None
            
            # Assume float16 (2 bytes per parameter)
            return num_params * 2
        except Exception as e:
            print(f"Warning: Could not estimate model size: {e}")
            return None
    
    def _get_gpu_memory_map(self) -> Dict:
        """Get max_memory dict for Accelerate with 93% GPU limit, then CPU, then disk."""
        import torch
        max_memory = {}
        
        # GPU memory: 93% of available VRAM per GPU
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                total_vram = props.total_memory
                # Leave 7% headroom for CUDA overhead (changed from 0.1% to 7%)
                usable_vram = int(total_vram * 0.93)
                max_memory[i] = usable_vram
                print(f"  GPU {i}: {total_vram / 1e9:.1f}GB total, {usable_vram / 1e9:.1f}GB usable")
        
        # CPU memory: use manual limit or auto-detect
        manual_ram_gb = self._pending_ram_gb
        if manual_ram_gb:
            # Convert GB to bytes
            max_memory['cpu'] = int(manual_ram_gb * 1e9)
            print(f"  CPU: {manual_ram_gb}GB (user specified)")
        else:
            # Auto-detect available system RAM, leave 4GB for system
            import psutil
            available_ram = psutil.virtual_memory().available
            usable_ram = max(0, available_ram - int(4e9))  # Leave 4GB for OS
            max_memory['cpu'] = usable_ram
            print(f"  CPU: {usable_ram / 1e9:.1f}GB (auto-detected, 4GB reserved for system)")
        
        return max_memory
    
    def _try_load_model(self, model_name: str, load_kwargs: dict, device: str) -> Optional[any]:
        """Attempt one model load and translate memory failures into ``None``.

        A RuntimeError whose message mentions "out of memory", "cuda" or "oom"
        yields ``None``; other RuntimeErrors propagate. A TypeError that points
        at the quantization keywords triggers a single retry with
        ``load_in_4bit``/``load_in_8bit`` removed; any other TypeError
        propagates, and a second TypeError re-raises the first one.
        """
        import torch
        from transformers import AutoModelForCausalLM

        quant_keys = ('load_in_4bit', 'load_in_8bit')

        def looks_like_oom(exc) -> bool:
            # Message-based heuristic; note that any mention of "cuda" counts.
            message = str(exc).lower()
            return "out of memory" in message or "cuda" in message or "oom" in message

        def load_with(kwargs):
            loaded = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
            # Without a device_map the CPU placement is done by hand.
            if device == "cpu" and kwargs.get('device_map') is None:
                loaded = loaded.to(device)
            return loaded

        try:
            return load_with(load_kwargs)
        except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
            if looks_like_oom(e):
                return None
            raise
        except TypeError as e:
            error_msg = str(e).lower()
            mentions_quantization = (
                "load_in_4bit" in error_msg
                or "load_in_8bit" in error_msg
                or "unexpected keyword argument" in error_msg
            )
            has_quant_args = 'load_in_4bit' in load_kwargs or 'load_in_8bit' in load_kwargs
            if not (mentions_quantization and has_quant_args):
                # Not something dropping quantization can fix.
                raise

            print(f"Warning: Model does not support bitsandbytes quantization (load_in_4bit/load_in_8bit)")
            print("Retrying without quantization...")
            retry_kwargs = {key: value for key, value in load_kwargs.items() if key not in quant_keys}
            try:
                model = load_with(retry_kwargs)
                print("Model loaded successfully without quantization")
                return model
            except (RuntimeError, torch.cuda.OutOfMemoryError) as e2:
                if looks_like_oom(e2):
                    return None
                raise
            except TypeError:
                # Still a TypeError: surface the original failure.
                raise e
    
    def _is_moe_model(self, model_name: str) -> bool:
        """Check if model is a MoE (Mixture of Experts) model which needs more VRAM headroom."""
        moe_indicators = ['moe', 'mixtral', 'qwen3_5_moe', 'qwen3.5_moe', 'expert', 'a3b']
        model_name_lower = model_name.lower()
        return any(indicator in model_name_lower for indicator in moe_indicators)
    
    def _get_vram_percentages_for_strategy(self, strategy: str, is_moe: bool, total_vram_gb: float) -> list:
        """Get VRAM percentage steps based on offload strategy."""
        if strategy == "conservative":
            print(f"  Using conservative offload strategy - minimal VRAM usage for maximum stability")
            if is_moe:
                return [0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
            return [0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0]
        elif strategy == "balanced":
            print(f"  Using balanced offload strategy - good performance with reasonable stability")
            if is_moe:
                return [0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
            return [0.85, 0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0]
        elif strategy == "aggressive":
            print(f"  Using aggressive offload strategy - maximize VRAM usage for performance")
            if is_moe:
                return [0.85, 0.80, 0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
            return [0.95, 0.90, 0.85, 0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0]
        elif strategy == "sequential":
            print(f"  Using sequential offload strategy - fine-grained incremental VRAM reduction")
            # Fine-grained steps with 2% increments for precise memory management
            if is_moe:
                return [0.80, 0.78, 0.76, 0.74, 0.72, 0.70, 0.68, 0.66, 0.64, 0.62, 0.60, 0.55, 0.50, 0.45, 0.40, 0.35, 0.30, 0.25, 0.20, 0.0]
            return [0.93, 0.91, 0.89, 0.87, 0.85, 0.83, 0.81, 0.79, 0.77, 0.75, 0.73, 0.71, 0.69, 0.67, 0.65, 0.60, 0.55, 0.50, 0.45, 0.40, 0.35, 0.30, 0.20, 0.0]
        else:  # auto
            if total_vram_gb < 3:
                print(f"  Detected small GPU ({total_vram_gb:.1f}GB), using aggressive VRAM usage (99% start)")
                return [0.99, 0.95, 0.90, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
            elif total_vram_gb <= 8:
                print(f"  Detected medium GPU ({total_vram_gb:.1f}GB), using high VRAM usage (96% start)")
                return [0.96, 0.90, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
            else:
                if is_moe:
                    print(f"  Detected large GPU ({total_vram_gb:.1f}GB), using MoE-safe VRAM usage (80% start)")
                    return [0.80, 0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
                else:
                    print(f"  Detected large GPU ({total_vram_gb:.1f}GB), using conservative VRAM usage (93% start)")
                    return [0.93, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
    
    def _get_vram_percentages_for_gpu(self, model_name: str = "", strategy: str = "auto", max_gpu_percent: float = None) -> list:
        """Get VRAM percentage steps based on GPU memory size, model type, and offload strategy."""
        import torch
        
        if not torch.cuda.is_available():
            return [0.0]  # CPU only
        
        # If max_gpu_percent is specified, use it to create custom percentage steps
        if max_gpu_percent is not None:
            # Clamp to valid range (5-100%)
            max_pct = max(0.05, min(1.0, max_gpu_percent / 100.0))
            print(f"  Using custom max GPU percent: {max_pct*100:.0f}%")
            # Create a descending series from max_pct down to 0
            steps = []
            current = max_pct
            while current > 0.05:
                steps.append(current)
                # Reduce by 5% each step, or smaller steps near the end
                if current > 0.3:
                    current -= 0.05
                elif current > 0.15:
                    current -= 0.03
                else:
                    current -= 0.02
            steps.append(0.0)
            return steps
        
        # Get total VRAM of the first GPU
        total_vram_gb = 0
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            total_vram_gb += props.total_memory / 1e9
        
        # Check if MoE model (needs more headroom for generation)
        is_moe = self._is_moe_model(model_name)
        if is_moe:
            print(f"  Detected MoE model, using extra conservative VRAM limits for generation headroom")
        
        return self._get_vram_percentages_for_strategy(strategy, is_moe, total_vram_gb)
    
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load a HuggingFace causal LM, lowering the GPU memory cap after each OOM.

        The tokenizer is loaded first, then the model is attempted once per
        entry of the VRAM ladder returned by _get_vram_percentages_for_gpu().
        On success the model, tokenizer and model name are stored on the
        instance and the model is put in eval mode.

        Recognised keyword arguments (all optional):
            offload_dir: Directory passed to Transformers as ``offload_folder``;
                created if it does not exist.
            load_in_4bit / load_in_8bit: Request bitsandbytes quantization.
                Ignored (with a warning) for Qwen3.5 A3B/MoE model names or
                when bitsandbytes cannot be imported.
            manual_ram_gb: CPU RAM budget in GB for the memory map.
            flash_attn: Request Flash Attention 2.
            offload_strategy: Strategy name for the VRAM ladder (default 'auto').
            max_gpu_percent: Custom starting GPU percentage for the ladder.

        Raises:
            RuntimeError: If no attempt in the ladder produced a model.
        """
        import torch
        # AutoModelForCausalLM is not referenced below; the actual load happens
        # in _try_load_model().
        from transformers import AutoModelForCausalLM, AutoTokenizer
        
        offload_dir = kwargs.get('offload_dir')
        load_in_4bit = kwargs.get('load_in_4bit', False)
        load_in_8bit = kwargs.get('load_in_8bit', False)
        manual_ram_gb = kwargs.get('manual_ram_gb')
        flash_attn = kwargs.get('flash_attn', False)
        offload_strategy = kwargs.get('offload_strategy', 'auto')
        max_gpu_percent = kwargs.get('max_gpu_percent', None)
        
        # Store RAM limit for use in _get_gpu_memory_map
        # (and _get_gpu_memory_map_with_limit, which the loop below calls)
        self._pending_ram_gb = manual_ram_gb
        
        print(f"Loading HuggingFace model: {model_name}")
        
        # check_flash_attn_support() clears use_flash_attn again when the
        # package is not available.
        self.use_flash_attn = flash_attn
        self.check_flash_attn_support()
        
        self.device = self._detect_device()
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side="left"
        )
        
        # Set pad token if not present
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Prepare model loading arguments
        load_kwargs = {'trust_remote_code': True}
        
        # Check if model supports quantization
        if load_in_4bit or load_in_8bit:
            # Qwen3.5-A3B/MoE models don't support bitsandbytes quantization
            if 'qwen3.5' in model_name.lower() and ('a3b' in model_name.lower() or 'moe' in model_name.lower()):
                print(f"Warning: {model_name} does not support bitsandbytes quantization (load_in_4bit/load_in_8bit)")
                print("Quantization disabled for this model")
            else:
                try:
                    # Imported only to confirm bitsandbytes is installed.
                    import bitsandbytes as bnb
                    print(f"Using {4 if load_in_4bit else 8}-bit quantization")
                    # Both flags are forwarded, so one of them may be False.
                    load_kwargs['load_in_4bit'] = load_in_4bit
                    load_kwargs['load_in_8bit'] = load_in_8bit
                except ImportError:
                    print("Warning: bitsandbytes not installed. Quantization disabled.")
        
        # Set dtype: half precision on GPU, full precision on CPU
        if self.device == "cuda":
            load_kwargs['torch_dtype'] = torch.float16
        else:
            load_kwargs['torch_dtype'] = torch.float32
        
        # Add offload folder if specified (disk offloading is last resort)
        if offload_dir:
            os.makedirs(offload_dir, exist_ok=True)
            load_kwargs['offload_folder'] = offload_dir
        
        # Add Flash Attention 2 if enabled
        if self.use_flash_attn and self.flash_attn_available:
            load_kwargs['attn_implementation'] = "flash_attention_2"
            print("Using Flash Attention 2")
        
        # Try loading with automatic fallback on OOM
        model = None
        vram_percentages = self._get_vram_percentages_for_gpu(model_name, offload_strategy, max_gpu_percent)
        # Remembered only to report when a lower limit ended up being used.
        first_vram_pct = vram_percentages[0] if vram_percentages else 0.93
        
        for vram_pct in vram_percentages:
            if self.device != "cuda":
                # CPU-only mode
                load_kwargs['device_map'] = None
                print("Loading model in CPU-only mode...")
                model = self._try_load_model(model_name, load_kwargs, self.device)
                if model is not None:
                    break
                # If this attempt returned None, execution continues into the
                # branches below with the same vram_pct.
            
            # GPU mode with varying VRAM limits
            if vram_pct > 0:
                max_memory = self._get_gpu_memory_map_with_limit(vram_pct)
                load_kwargs['max_memory'] = max_memory
                load_kwargs['device_map'] = 'auto'
                print(f"\nTrying with GPU limit: {vram_pct*100:.0f}% VRAM")
                if offload_dir:
                    print(f"  Disk offload directory: {offload_dir}")
                
                model = self._try_load_model(model_name, load_kwargs, self.device)
                
                if model is not None:
                    print(f"  ✓ Model loaded successfully with {vram_pct*100:.0f}% GPU VRAM limit")
                    if vram_pct < first_vram_pct:
                        print(f"  (Reduced from {first_vram_pct*100:.0f}% due to memory constraints)")
                    break
                else:
                    print(f"  ✗ Out of memory with {vram_pct*100:.0f}% GPU VRAM, trying lower limit...")
                    # Clear CUDA cache before retry
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            else:
                # Last resort: CPU-only mode with offloading
                print("\nFalling back to CPU-only mode (no GPU layers)...")
                # GPU 0 gets a zero budget; the CPU budget is the manual limit
                # or 48 GB when none was given.
                load_kwargs['max_memory'] = {0: 0, 'cpu': int((manual_ram_gb or 48) * 1e9)}
                load_kwargs['device_map'] = 'auto'
                model = self._try_load_model(model_name, load_kwargs, "cpu")
                if model is not None:
                    print("  ✓ Model loaded successfully on CPU")
                    break
        
        if model is None:
            raise RuntimeError("Failed to load model: Out of memory even with minimum GPU usage")
        
        self.model = model
        self.model.eval()
        self.model_name = model_name
        
        print(f"\nModel loaded successfully")
        # Reports the device of the first parameter only.
        print(f"Model device: {next(self.model.parameters()).device}")
        
        # Show model capabilities
        caps = detect_model_capabilities(model_name)
        print(f"Model capabilities: {caps}")
    
    def _get_gpu_memory_map_with_limit(self, vram_fraction: float) -> Dict:
        """Get max_memory dict with specified VRAM fraction limit."""
        import torch
        max_memory = {}
        
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                total_vram = props.total_memory
                usable_vram = int(total_vram * vram_fraction)
                max_memory[i] = usable_vram
        
        # CPU memory
        manual_ram_gb = getattr(self, '_pending_ram_gb', None)
        if manual_ram_gb:
            max_memory['cpu'] = int(manual_ram_gb * 1e9)
        else:
            import psutil
            available_ram = psutil.virtual_memory().available
            usable_ram = max(0, available_ram - int(4e9))
            max_memory['cpu'] = usable_ram
        
        return max_memory
    
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string."""
        formatted = []
        
        for msg in messages:
            if msg.role == "system":
                formatted.append(f"System: {msg.content}")
            elif msg.role == "user":
                formatted.append(f"User: {msg.content}")
            elif msg.role == "assistant":
                content = msg.content or ""
                if msg.tool_calls:
                    for tc in msg.tool_calls:
                        if tc.get("function"):
                            func = tc["function"]
                            content += f'\n<tool>{{"name": "{func.get("name", "")}", "arguments": {func.get("arguments", "{}")}}}</tool>'
                formatted.append(f"Assistant: {content}")
            elif msg.role == "tool":
                formatted.append(f"Tool ({msg.name}): {msg.content}")
        
        formatted.append("Assistant:")
        return "\n\n".join(formatted)
    
    def _validate_params(self, temperature: float, top_p: float) -> tuple:
        """Validate generation parameters."""
        if temperature <= 0:
            temperature = 1.0
            do_sample = False
        else:
            temperature = max(0.01, min(temperature, 2.0))
            do_sample = True
        top_p = max(0.0, min(top_p, 1.0))
        return temperature, top_p, do_sample
    
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate a completion for prompt and return it as one string.

        Args:
            prompt: Fully formatted prompt text.
            max_tokens: Maximum number of new tokens (512 when None).
            temperature: Sampling temperature; 0 or below selects greedy decoding.
            top_p: Nucleus sampling threshold.
            stop: Optional stop sequences. The returned text is cut just before
                the earliest occurrence of any of them.

        Returns:
            The decoded completion, or a bracketed error message when generation
            runs out of memory twice.

        Raises:
            RuntimeError: Re-raised when the failure does not look like an
                out-of-memory error.
        """
        import torch
        from transformers import LogitsProcessor, LogitsProcessorList

        class InvalidLogitsProcessor(LogitsProcessor):
            """Replace NaN/inf logits with finite values representable in the dtype."""

            def __call__(self, input_ids, scores):
                # +/-1e9 is not representable in float16, so clamp to the dtype's
                # range. -inf (used by other processors to mask tokens) must map
                # to a very negative value, never to a positive one.
                limits = torch.finfo(scores.dtype)
                high = min(1e9, limits.max)
                low = max(-1e9, limits.min)
                return torch.nan_to_num(scores, nan=low, posinf=high, neginf=low)

        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        if max_tokens is None:
            max_tokens = 512

        temperature, top_p, do_sample = self._validate_params(temperature, top_p)

        # Accept a bare string as well as a list; drop empty entries, which
        # would otherwise match at position 0 and erase the whole output.
        if isinstance(stop, str):
            stop_sequences = [stop] if stop else []
        else:
            stop_sequences = [seq for seq in (stop or []) if seq]

        def run_generation(token_budget: int) -> str:
            # Single place for the generate call, shared by both attempts.
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=token_budget,
                    temperature=temperature if do_sample else None,
                    top_p=top_p if do_sample else None,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
                )

            # Drop the prompt tokens; only the continuation is returned.
            text = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

            hits = [pos for pos in (text.find(seq) for seq in stop_sequences) if pos != -1]
            if hits:
                text = text[:min(hits)]
            return text

        try:
            return run_generation(max_tokens)
        except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
            error_msg = str(e).lower()
            # Message-based heuristic: any mention of "cuda" is treated as OOM.
            if "out of memory" in error_msg or "cuda" in error_msg or "oom" in error_msg:
                print(f"Warning: CUDA OOM during generation. Clearing cache and retrying with reduced tokens...")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                # Retry with half the tokens
                try:
                    return run_generation(max(1, max_tokens // 2))
                except Exception as e2:
                    print(f"Error: Generation failed even with reduced tokens: {e2}")
                    return "[Error: Out of memory during generation. Try reducing --max-gpu-percent or using a smaller model.]"
            raise
    
    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Generate a completion for prompt, yielding text pieces as they are decoded.

        Generation runs in a worker thread that feeds a TextIteratorStreamer;
        this coroutine reads the streamer through the default executor so the
        event loop stays responsive between tokens.

        Args:
            prompt: Fully formatted prompt text.
            max_tokens: Maximum number of new tokens (512 when None).
            temperature: Sampling temperature; 0 or below selects greedy decoding.
            top_p: Nucleus sampling threshold.
            stop: Optional stop sequences; generation halts once one of them
                appears in the recently generated text.

        Yields:
            Decoded text pieces, followed by a bracketed warning or error line
            if generation failed.
        """
        import torch
        from transformers import TextIteratorStreamer, LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList

        class InvalidLogitsProcessor(LogitsProcessor):
            """Replace NaN/inf logits with finite values representable in the dtype."""

            def __call__(self, input_ids, scores):
                # +/-1e9 is not representable in float16, so clamp to the dtype's
                # range. -inf (used by other processors to mask tokens) must map
                # to a very negative value, never to a positive one.
                limits = torch.finfo(scores.dtype)
                high = min(1e9, limits.max)
                low = max(-1e9, limits.min)
                return torch.nan_to_num(scores, nan=low, posinf=high, neginf=low)

        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        if max_tokens is None:
            max_tokens = 512

        temperature, top_p, do_sample = self._validate_params(temperature, top_p)

        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_tokens,
            # Sampling knobs are only passed when sampling, as in generate().
            "temperature": temperature if do_sample else None,
            "top_p": top_p if do_sample else None,
            "do_sample": do_sample,
            "streamer": streamer,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "logits_processor": LogitsProcessorList([InvalidLogitsProcessor()]),
        }

        # Handle stop sequences. Accept a bare string as well as a list and drop
        # empty entries, which would match immediately.
        if isinstance(stop, str):
            stop_sequences = [stop] if stop else []
        else:
            stop_sequences = [seq for seq in (stop or []) if seq]

        if stop_sequences:
            class StopOnSequence(StoppingCriteria):
                def __init__(self, stop_sequences, tokenizer, prompt_length):
                    self.stop_sequences = stop_sequences
                    self.tokenizer = tokenizer
                    self.prompt_length = prompt_length

                def __call__(self, input_ids, scores, **kwargs):
                    # Look only at generated tokens; the prompt itself may
                    # contain a stop sequence (e.g. "User:").
                    generated = input_ids[0][self.prompt_length:]
                    decoded = self.tokenizer.decode(generated[-20:], skip_special_tokens=True)
                    return any(seq in decoded for seq in self.stop_sequences)

            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([
                StopOnSequence(stop_sequences, self.tokenizer, prompt_length)
            ])

        # Run generation in a separate thread with OOM handling
        generation_error = None

        def generate_with_error_handling():
            nonlocal generation_error
            try:
                self.model.generate(**generation_kwargs)
            except Exception as e:
                error_msg = str(e).lower()
                is_memory_error = isinstance(e, (RuntimeError, torch.cuda.OutOfMemoryError)) and (
                    "out of memory" in error_msg or "cuda" in error_msg or "oom" in error_msg
                )
                if is_memory_error:
                    generation_error = "oom"
                    print(f"Warning: CUDA OOM during streaming generation. Clearing cache...")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                else:
                    generation_error = str(e) or type(e).__name__
                # generate() failed before signalling end-of-stream; without
                # this the consumer below would wait on the queue forever.
                try:
                    streamer.end()
                except Exception:
                    streamer.text_queue.put(streamer.stop_signal)

        thread = Thread(target=generate_with_error_handling)
        thread.start()

        loop = asyncio.get_running_loop()
        pieces = iter(streamer)
        exhausted = object()

        try:
            while True:
                # The streamer blocks on a queue; wait for it off the event loop.
                # next() with a default avoids raising StopIteration inside a future.
                text = await loop.run_in_executor(None, next, pieces, exhausted)
                if text is exhausted:
                    break
                yield text
        except Exception as e:
            print(f"Error during stream iteration: {e}")

        await loop.run_in_executor(None, thread.join)

        # Check if there was an OOM error
        if generation_error == "oom":
            yield "\n[Warning: Generation stopped due to out-of-memory. Try reducing --max-gpu-percent.]"
        elif generation_error:
            yield f"\n[Error during generation: {generation_error}]"
    
    def get_model_name(self) -> str:
        return self.model_name or "unknown"
    
    def cleanup(self) -> None:
        import torch
        if self.model is not None:
            del self.model
            del self.tokenizer
            self.model = None
            self.tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
# =============================================================================
# Vulkan Backend (llama-cpp-python)
# =============================================================================

class VulkanBackend(ModelBackend):
    """Backend for Vulkan (AMD GPUs) using llama-cpp-python with GGUF models."""
    
    def __init__(self, original_backend: str = None):
        self.model = None
        self.model_name = None
        self.n_gpu_layers = -1  # Offload all layers to GPU by default
        self.n_ctx = 2048
        self.verbose = True
        self.main_gpu = 0  # Default to first GPU
        self.chat_template = None  # Detected chat template name
        self.hf_tokenizer = None  # HuggingFace tokenizer for apply_chat_template
        self.force_cuda = original_backend in ("nvidia", "cuda")  # Force CUDA if original was nvidia
        if self.force_cuda:
            print("DEBUG: GGUF model will use CUDA backend (forced by --backend nvidia)")
        self._detect_chat_template()
    
    def _detect_chat_template(self):
        """Detect the chat template used by the model."""
        try:
            # Try to get the chat template from the model
            # llama.cpp models have a chat_template attribute
            from llama_cpp.llama_chat_format import ChatFormatterResponse
            # We'll detect it when the model is loaded
            self.chat_template = "unknown"
            print("DEBUG: Chat template detection will happen after model load")
        except Exception as e:
            print(f"DEBUG: Could not initialize chat template detection: {e}")
            self.chat_template = None
    
    def _load_huggingface_tokenizer(self, template_name: str = None):
        """Load a HuggingFace tokenizer for apply_chat_template support.

        Lookup order for a ``.gguf`` model:

        1. a ``tokenizer_config.json`` next to the GGUF file,
        2. hub ids derived from the file name (full name first, then
           progressively shorter dash-separated prefixes),
        3. no tokenizer: fall back to a template *name* only, either the one
           the caller specified or one guessed from the model family.

        Any other ``model_name`` is passed to ``AutoTokenizer.from_pretrained``
        as-is.

        On success ``self.hf_tokenizer`` is set. ``self.chat_template`` is set
        to ``template_name`` when given, otherwise to a marker describing where
        the tokenizer came from (``"hf_local"``, ``"hf_hub"``, ``"hf"``) or to
        the guessed family name. If ``transformers`` is missing or an
        unexpected error occurs, ``self.chat_template`` is reset to ``None``.

        Args:
            template_name: Optional specific template to use (e.g., 'llama3',
                'chatml'). If None, will auto-detect. A specified name is
                never replaced by a guessed one.

        NOTE(review): every ``from_pretrained`` call here passes
        ``trust_remote_code=True``, including for hub ids that are merely
        guessed from a file name. That allows code from the resolved repo to
        run; confirm this is intended.
        """
        if self.hf_tokenizer is not None:
            return  # Already loaded

        model_path = getattr(self, 'model_name', None)
        if not model_path:
            print("DEBUG: No model name available for HuggingFace tokenizer")
            return

        # If model_path is a URL, try to get the cached local path first
        if model_path.startswith(('http://', 'https://')):
            cached_path = get_cached_model_path(model_path)
            if cached_path and os.path.exists(cached_path):
                model_path = cached_path
                print(f"DEBUG: Using cached model path for HF tokenizer: {model_path}")

        try:
            from transformers import AutoTokenizer

            if template_name:
                # Record the requested name up front; the tokenizer is still
                # loaded below so apply_chat_template can be used.
                self.chat_template = template_name
                print(f"DEBUG: Using specified chat template: {template_name}")

            if not model_path.endswith('.gguf'):
                # Not a GGUF file, try to load directly
                self.hf_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
                print(f"DEBUG: Loaded HuggingFace tokenizer from: {model_path}")
                if not template_name:
                    self.chat_template = "hf"
                return

            model_dir = os.path.dirname(model_path)
            model_file = os.path.basename(model_path)

            # 1) Tokenizer files shipped alongside the GGUF file
            if os.path.exists(os.path.join(model_dir, 'tokenizer_config.json')):
                self.hf_tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
                print(f"DEBUG: Loaded HuggingFace tokenizer from local: {model_dir}")
                if not template_name:
                    self.chat_template = "hf_local"
                return

            # 2) Hub ids inferred from the file name
            model_base = self._gguf_model_base(model_file)
            for model_id in self._hub_id_candidates(model_base):
                try:
                    self.hf_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
                except Exception as fallback_err:
                    print(f"DEBUG: Could not load tokenizer from hub ({model_id}): {fallback_err}")
                    continue
                print(f"DEBUG: Loaded HuggingFace tokenizer from hub: {model_id}")
                if not template_name:
                    self.chat_template = "hf_hub"
                return

            # 3) No tokenizer available: settle on a template name only.
            if template_name:
                # The caller's explicit choice wins over any family guess,
                # matching what the successful paths above do.
                print(f"DEBUG: No HuggingFace tokenizer found, keeping specified chat template: {template_name}")
                return

            family_template = self._family_template(model_base)
            if family_template:
                self.chat_template = family_template
                print(f"DEBUG: Using known template '{family_template}' for model family detection")
                return

            # Unknown family: self.chat_template keeps whatever value it had.
            print(f"Warning: Could not load HuggingFace tokenizer for any variant of '{model_base}'")
            print(f"Warning: Will not use apply_chat_template - model will use manual formatting")

        except ImportError as e:
            print(f"DEBUG: transformers not installed, cannot use HuggingFace chat template: {e}")
            self.chat_template = None
        except Exception as e:
            print(f"DEBUG: Failed to load HuggingFace tokenizer: {e}")
            self.chat_template = None

    @staticmethod
    def _gguf_model_base(model_file: str) -> str:
        """Reduce a GGUF file name to a bare model name.

        Removes the ``.gguf`` extension, a leading ``<64 hex chars>_`` cache
        prefix, and known quantization tags (matched case-insensitively).
        """
        model_base = model_file.replace('.gguf', '')

        # Cached downloads are stored as "<sha256 hexdigest>_<filename>".
        if re.match(r'[0-9a-fA-F]{64}_', model_base):
            model_base = model_base[65:]  # Skip hash + underscore

        quant_tags = ['_q4_k_m', '_q4_k', '_q5_k_m', '_q5_k', '_q8_0', '_f16',
                      '_q4_0', '_q3_k_m', '_q2_k']
        # Longest alternative first, so "_q5_k_m" is removed whole instead of
        # "_q5_k" matching and leaving a dangling "_m".
        tag_pattern = '|'.join(
            re.escape(tag) for tag in sorted(quant_tags, key=len, reverse=True)
        )
        return re.sub(tag_pattern, '', model_base, flags=re.IGNORECASE)

    @staticmethod
    def _hub_id_candidates(model_base: str) -> List[str]:
        """Return hub ids to try: the full name, then shorter dash-prefixes.

        E.g. ``"a-b-c"`` yields ``["a-b-c", "a-b", "a"]``. Empty and duplicate
        entries are skipped so no id is looked up twice.
        """
        parts = model_base.split('-')
        candidates = [model_base]
        for end in range(len(parts) - 1, 0, -1):
            shorter_name = '-'.join(parts[:end])
            if shorter_name and shorter_name not in candidates:
                candidates.append(shorter_name)
        return candidates

    @staticmethod
    def _family_template(model_base: str) -> Optional[str]:
        """Guess a chat template name from the model family in *model_base*.

        Returns None when the name matches no known family.
        """
        name = model_base.lower()
        if 'qwen' in name:
            if 'qwen3' in name:  # also covers 'qwen3.5'
                return 'qwen3'
            if 'qwen2' in name:
                return 'qwen2'
            return 'qwen'
        if 'llama' in name:
            return 'llama3'
        if 'phi' in name:
            return 'phi'
        if 'mistral' in name or 'mixtral' in name:
            return 'mistral'
        return None
    
    def _finalize_chat_template_detection(self):
        """Settle ``self.chat_template`` once the model has been loaded.

        If ``check_hf_chat_template`` says a HuggingFace template should be
        used, tokenizer loading is delegated to ``_load_huggingface_tokenizer``.
        Otherwise the template is classified from the model tokenizer's
        ``chat_template`` attribute when one is present, and as a last resort
        a one-token chat completion is run to see whether the built-in
        formatting works.
        """
        model_name = getattr(self, 'model_name', None) or "unknown"

        # Names prefixed with "image:" are treated as image models, all others as text.
        model_type = "image" if model_name.startswith("image:") else "text"

        should_use, template_name = check_hf_chat_template(model_type, model_name)
        if should_use:
            self._load_huggingface_tokenizer(template_name)
            return

        try:
            # NOTE(review): in llama-cpp-python `Llama.tokenizer` appears to be
            # a method rather than an attribute, in which case this branch
            # never finds a `chat_template` - confirm against the installed version.
            tokenizer = getattr(self.model, 'tokenizer', None)
            template = getattr(tokenizer, 'chat_template', None) if tokenizer else None
            if template:
                lowered = str(template).lower()
                # Checked in order; the first family whose marker occurs wins.
                family_markers = (
                    (("qwen",), "qwen"),
                    (("phi",), "phi"),
                    (("llama3", "llama-3"), "llama3"),
                    (("chatml",), "chatml"),
                )
                detected = "default"
                for markers, family in family_markers:
                    if any(marker in lowered for marker in markers):
                        detected = family
                        break
                self.chat_template = detected
                print(f"DEBUG: Detected chat template: {self.chat_template}")
                return

            # Probe with a minimal request to see what format works
            probe_messages = [{"role": "user", "content": "test"}]
            try:
                self.model.create_chat_completion(messages=probe_messages, max_tokens=1)
            except Exception as e:
                if 'jinja' in str(e).lower():
                    # Jinja template issue - fall back to manual formatting
                    self.chat_template = "jinja_fallback"
                    print("DEBUG: Chat template detected: jinja_fallback (will use manual formatting)")
                else:
                    self.chat_template = "unknown"
                    print(f"DEBUG: Chat template detection failed: {e}")
            else:
                self.chat_template = "default"
                print("DEBUG: Chat template detected via test: default")
        except Exception as e:
            self.chat_template = "unknown"
            print(f"DEBUG: Final chat template detection error: {e}")
        
        
    def list_vulkan_devices(self):
        """Print the ``vulkaninfo --summary`` report when the tool succeeds.

        Purely informational: any failure (tool missing, non-zero exit, ...)
        is ignored and nothing is printed.
        """
        try:
            import subprocess
            summary = subprocess.run(
                ['vulkaninfo', '--summary'],
                capture_output=True,
                text=True,
            )
            if summary.returncode != 0:
                return
            print("\nAvailable Vulkan devices:")
            print(summary.stdout)
        except Exception:
            # Best effort only; the listing is a convenience for the user.
            pass
    
    def count_vulkan_devices(self):
        """Count the number of Vulkan GPU devices available."""
        # llama.cpp filters out some devices (like CPU llvmpipe), so we need to
        # count only the devices that llama.cpp will actually use
        try:
            import subprocess
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True)
            if result.returncode == 0:
                # Count GPU0, GPU1, etc. entries (these are actual GPUs, not CPU)
                import re
                gpu_matches = re.findall(r'^GPU\d+:', result.stdout, re.MULTILINE)
                # Filter out CPU devices (llvmpipe)
                non_cpu_gpus = 0
                for i, match in enumerate(gpu_matches):
                    # Check if this GPU is a CPU device
                    gpu_section = result.stdout.split(match)[1].split('\nGPU')[0] if i < len(gpu_matches) - 1 else result.stdout.split(match)[1]
                    if 'llvmpipe' not in gpu_section and 'CPU' not in gpu_section:
                        non_cpu_gpus += 1
                return max(non_cpu_gpus, 1)
        except:
            pass
        return 2  # Default to 2 (common for NVIDIA + AMD setups)
    
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load a GGUF model using llama-cpp-python.

        Args:
            model_name: One of
                - a local file path to a ``.gguf`` file,
                - a HuggingFace model ID (``"org/model"`` or
                  ``"org/model/filename.gguf"``),
                - a full http(s) URL to a GGUF file (downloaded into the
                  model cache on first use).
            **kwargs: Optional settings read here:
                n_gpu_layers (default -1 = all layers), n_ctx (default 2048),
                verbose (default True), main_gpu (default 0) and
                single_gpu (default False; pins all layers to ``main_gpu``
                through ``tensor_split``).

        Raises:
            Exception: whatever the download or ``Llama`` construction raised,
                re-raised after diagnostics have been printed.
        """
        from llama_cpp import Llama

        n_gpu_layers = kwargs.get('n_gpu_layers', -1)
        n_ctx = kwargs.get('n_ctx', 2048)
        verbose = kwargs.get('verbose', True)
        main_gpu = kwargs.get('main_gpu', 0)
        single_gpu = kwargs.get('single_gpu', False)
        self.main_gpu = main_gpu

        is_url = model_name.startswith(('http://', 'https://'))

        model_path = None
        if is_url:
            # Check cache first
            cached_path = get_cached_model_path(model_name)
            if cached_path:
                model_path = cached_path
                print(f"Using cached model: {model_path}")
            else:
                print(f"Downloading model from URL: {model_name}")
                try:
                    model_path = self._download_gguf_to_cache(model_name)
                    print(f"\nDownloaded and cached to: {model_path}")
                    print(f"File size: {os.path.getsize(model_path) / 1e9:.2f} GB")
                except Exception as e:
                    print(f"Error downloading model: {e}")
                    raise
        elif os.path.isfile(model_name):
            model_path = model_name
            print(f"Loading local GGUF model: {model_path}")
        else:
            # Try to download from HuggingFace Hub
            print(f"Attempting to download GGUF model: {model_name}")
            try:
                from huggingface_hub import hf_hub_download, list_repo_files

                # Parse model name (format: "org/model" or "org/model/filename.gguf")
                parts = model_name.split('/')
                if len(parts) < 2:
                    raise ValueError(f"Invalid model name format: {model_name}")
                repo_id = f"{parts[0]}/{parts[1]}"

                if len(parts) >= 3 and parts[-1].endswith('.gguf'):
                    # Specific file provided
                    filename = '/'.join(parts[2:])
                else:
                    # Find GGUF files in the repo
                    gguf_files = [f for f in list_repo_files(repo_id) if f.endswith('.gguf')]
                    if not gguf_files:
                        raise ValueError(f"No GGUF files found in {repo_id}")
                    # Prefer Q4_K_M quantized models for good balance
                    preferred = [f for f in gguf_files if 'q4_k_m' in f.lower()]
                    filename = preferred[0] if preferred else gguf_files[0]
                    print(f"Selected GGUF file: {filename}")

                model_path = hf_hub_download(repo_id=repo_id, filename=filename)
                print(f"Downloaded to: {model_path}")
            except Exception as e:
                print(f"Error downloading model: {e}")
                print("Please provide a local path to a .gguf file")
                raise

        print("Loading GGUF model with Vulkan support...")
        print(f"  Model path: {model_path}")
        print(f"  GPU layers: {n_gpu_layers} (-1 = all layers)")
        print(f"  Context size: {n_ctx}")
        print(f"  GPU device: {main_gpu}")

        # List available devices for user reference
        self.list_vulkan_devices()

        # Device count as seen through vulkaninfo (CPU/llvmpipe excluded);
        # count_vulkan_devices falls back to 2 when it cannot be determined.
        num_devices = self.count_vulkan_devices()
        print(f"DEBUG: Detected {num_devices} Vulkan GPU devices")

        tensor_split = None
        if single_gpu:
            if main_gpu < num_devices:
                # tensor_split: index = GPU device, value = weight (0.0 = don't use)
                tensor_split = [0.0] * num_devices
                tensor_split[main_gpu] = 1.0
                print(f"  Single GPU mode: Setting tensor_split for GPU {main_gpu}: {tensor_split}")
            else:
                print(f"Warning: main_gpu={main_gpu} exceeds detected devices ({num_devices}), ignoring single_gpu")

        llama_kwargs = {
            'model_path': model_path,
            'n_gpu_layers': n_gpu_layers,
            'n_ctx': n_ctx,
            'verbose': verbose,
            'main_gpu': main_gpu,
        }
        if tensor_split:
            llama_kwargs['tensor_split'] = tensor_split

        backend_name = "CUDA" if self.force_cuda else "Vulkan"
        # Set once Llama() has returned; an error raised after that point
        # cannot be blamed on a corrupted model file.
        model_constructed = False

        try:
            # Note: Vulkan environment variables should be set from the
            # launching script if needed to force CUDA. Here we just ensure
            # CUDA_VISIBLE_DEVICES is set.
            if self.force_cuda:
                print("DEBUG: Forcing CUDA backend for llama-cpp-python...")
                if 'CUDA_VISIBLE_DEVICES' not in os.environ:
                    # Use all available CUDA devices
                    import subprocess
                    try:
                        result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True)
                        if result.returncode == 0:
                            gpu_count = len([l for l in result.stdout.split('\n') if 'GPU' in l])
                            # An empty value would hide every GPU from CUDA,
                            # so only set the variable when GPUs were listed.
                            if gpu_count > 0:
                                os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in range(gpu_count))
                                print(f"DEBUG: Set CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
                    except Exception as e:
                        print(f"Warning: Could not detect GPU count: {e}")
                        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
                print(f"  CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'not set')}")

            self.model = Llama(**llama_kwargs)
            model_constructed = True
            self.model_name = model_name
            print(f"\nModel loaded successfully with {backend_name}!")

            # Show model capabilities
            caps = detect_model_capabilities(model_name)
            print(f"Model capabilities: {caps}")

            # Detect the chat template after model load
            self._finalize_chat_template_detection()
            print(f"DEBUG: Chat template: {self.chat_template}")
        except Exception as e:
            # Check if this might be a corrupted cache file
            cache_dir = get_model_cache_dir()
            is_cached = model_path and model_path.startswith(cache_dir) and os.path.exists(model_path)

            # NOTE(review): these indicators are broad ('file', 'open', 'read');
            # a load failure unrelated to corruption can match as well and
            # trigger a re-download - confirm this is acceptable.
            corruption_indicators = ['invalid', 'corrupt', 'magic', 'header', 'file', 'open', 'read']
            error_str = str(e).lower()
            looks_corrupt = any(indicator in error_str for indicator in corruption_indicators)

            # Only discard the cached file when it can be fetched again (the
            # model came from a URL) and the failure happened while loading it.
            if is_cached and is_url and not model_constructed and looks_corrupt:
                print(f"WARNING: Cached model appears corrupted: {e}")
                print("Deleting corrupted cache and re-downloading...")
                try:
                    os.remove(model_path)
                    print(f"Deleted: {model_path}")

                    print(f"Re-downloading model from URL: {model_name}")
                    model_path = self._download_gguf_to_cache(model_name)
                    print(f"\nRe-downloaded and cached to: {model_path}")
                    print(f"File size: {os.path.getsize(model_path) / 1e9:.2f} GB")

                    # Retry loading with new model_path
                    llama_kwargs['model_path'] = model_path
                    self.model = Llama(**llama_kwargs)
                    self.model_name = model_name
                    print(f"\nModel loaded successfully with {backend_name} after re-download!")
                    self._finalize_chat_template_detection()
                    print(f"DEBUG: Chat template: {self.chat_template}")
                    return
                except Exception as redownload_error:
                    print(f"Failed to re-download model: {redownload_error}")
                    # Fall through to regular error handling

            print(f"Error loading model with {backend_name}: {e}")
            if self.force_cuda:
                print("Make sure CUDA is available:")
                print("  - Install llama-cpp-python with CUDA support: pip install llama-cpp-python[cuda]")
                print("  - Ensure NVIDIA drivers are installed")
                print("  - Check nvidia-smi output")
            else:
                print("Make sure Vulkan drivers are installed:")
                print("  Debian/Ubuntu: sudo apt install libvulkan-dev vulkan-tools")
                print("  Fedora: sudo dnf install vulkan-loader-devel vulkan-tools")
            raise

    def _download_gguf_to_cache(self, url: str) -> str:
        """Stream the GGUF file at *url* into the model cache; return its path.

        The cached name is ``<sha256 of url>_<file name>``. Data is written to
        a temporary ``.partial_*`` file in the same directory and moved into
        place only after the whole body was received, so an interrupted
        download never leaves a truncated file under the final cache name.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            OSError: if the file cannot be written or moved.
        """
        import hashlib
        import requests

        cache_dir = get_model_cache_dir()

        # Extract filename from URL (query parameters are not part of it)
        filename = os.path.basename(url.split('?')[0])
        if not filename.endswith('.gguf'):
            filename = "model.gguf"

        # Use a hash of the URL to avoid special-character issues in the name
        url_hash = hashlib.sha256(url.encode()).hexdigest()
        cached_filename = f"{url_hash}_{filename}"
        model_path = os.path.join(cache_dir, cached_filename)
        partial_path = os.path.join(cache_dir, f".partial_{cached_filename}")

        try:
            with requests.get(url, stream=True) as response:
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                downloaded = 0

                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192 * 1024):  # 8MB chunks
                        if chunk:
                            f.write(chunk)
                            downloaded += len(chunk)
                            if total_size > 0:
                                percent = (downloaded / total_size) * 100
                                print(f"Downloaded: {percent:.1f}%", end='\r')

            os.replace(partial_path, model_path)
        finally:
            # After a successful os.replace the partial file is gone; anything
            # still present here is an incomplete download.
            if os.path.exists(partial_path):
                os.remove(partial_path)

        return model_path
    
    def _format_messages_hf(self, messages: List[ChatMessage]) -> str:
        """Format messages using HuggingFace transformers apply_chat_template."""
        if self.hf_tokenizer is None:
            return self._manual_format_messages([{"role": m.role, "content": m.content or ""} for m in messages])
        
        # Convert messages to the format expected by transformers
        chat_messages = []
        for msg in messages:
            chat_msg = {"role": msg.role}
            # Ensure content is never None
            if msg.content is not None:
                chat_msg["content"] = msg.content
            else:
                chat_msg["content"] = ""
            if msg.tool_calls:
                chat_msg["tool_calls"] = msg.tool_calls
            chat_messages.append(chat_msg)
        
        try:
            # Use HuggingFace's apply_chat_template
            prompt = self.hf_tokenizer.apply_chat_template(
                chat_messages,
                tokenize=False,
                add_generation_prompt=True
            )
            return prompt
        except Exception as e:
            print(f"Warning: HF apply_chat_template failed ({e}), using manual formatting")
            return self._manual_format_messages(chat_messages)
    
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format chat messages into a single prompt string.

        Resolution order:
          1. If a HuggingFace tokenizer is attached (``self.hf_tokenizer``),
             delegate to ``_format_messages_hf``.
          2. Otherwise call ``self.model.apply_chat_template``.
          3. If that raises, fall back to ``_manual_format_messages``, which
             emits ``<|im_start|>``/``<|im_end|>`` markup.

        Args:
            messages: Chat messages exposing ``role``, ``content`` and
                ``tool_calls`` attributes.

        Returns:
            The prompt text, ending with an open assistant turn.
        """
        if self.hf_tokenizer is not None:
            return self._format_messages_hf(messages)

        # Normalise to plain dicts. Content is forced to "" when missing
        # because Jinja chat templates fail on None.
        chat_messages = []
        for msg in messages:
            chat_msg = {
                "role": msg.role,
                "content": msg.content if msg.content is not None else "",
            }
            if msg.tool_calls:
                chat_msg["tool_calls"] = msg.tool_calls
            chat_messages.append(chat_msg)

        try:
            return self.model.apply_chat_template(
                chat_messages, tokenize=False, add_generation_prompt=True
            )
        except Exception as e:
            print(f"Warning: apply_chat_template failed ({e}), using fallback formatting")
            # Use the shared manual formatter (same as the HF path) on the
            # normalised dicts: None content can no longer leak into the prompt
            # as the text "None", and tool messages / tool calls are rendered
            # instead of being silently dropped.
            return self._manual_format_messages(chat_messages)
    
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Run one blocking text completion through the llama-cpp model.

        Args:
            prompt: Fully formatted prompt text.
            max_tokens: Token budget for the completion; 512 when ``None``.
            temperature: Sampling temperature forwarded to the model.
            top_p: Nucleus-sampling threshold forwarded to the model.
            stop: Stop sequences; an empty list is sent when omitted.

        Returns:
            The ``text`` field of the first returned choice.
        """
        token_budget = 512 if max_tokens is None else max_tokens
        stop_sequences = stop or []

        completion = self.model(
            prompt,
            max_tokens=token_budget,
            temperature=temperature,
            top_p=top_p,
            stop=stop_sequences,
        )

        first_choice = completion["choices"][0]
        return first_choice["text"]
    
    def generate_chat(self, messages: List[Dict], max_tokens: Optional[int] = None,
                      temperature: float = 0.7, top_p: float = 1.0,
                      stop: Optional[List[str]] = None, tools: Optional[List] = None,
                      response_format: Optional[Dict] = None) -> str:
        """Produce a complete (non-streamed) chat reply via llama-cpp.

        Strategy, in order:
          1. With a HuggingFace tokenizer attached, render the prompt through
             its chat template and run plain text generation.
          2. When the detected template (or the presence of ``tools``) calls
             for it, render the prompt with ``_manual_format_messages``.
          3. Otherwise call ``create_chat_completion``; an error or an
             empty/whitespace-only reply falls back to manual formatting.

        Args:
            messages: Chat messages as dicts; they are copied, never mutated.
            max_tokens: Token budget; 512 when ``None``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            stop: Optional stop sequences.
            tools: Optional tool definitions; their presence forces manual
                formatting unless the HuggingFace path succeeds first.
            response_format: ``{"type": "json_object"}``/``{"type": "json"}``
                or the bare strings ``"json_object"``/``"json"`` enable JSON
                mode for ``create_chat_completion``; anything else is ignored.

        Returns:
            The generated assistant text.
        """
        token_budget = 512 if max_tokens is None else max_tokens

        # Collapse every accepted spelling of JSON mode into the one form that
        # is forwarded to create_chat_completion.
        response_format_param = None
        if response_format:
            if isinstance(response_format, dict):
                requested_type = response_format.get('type', '')
            elif isinstance(response_format, str):
                requested_type = response_format
            else:
                requested_type = None
            if requested_type in ('json_object', 'json'):
                response_format_param = {"type": "json_object"}

        def _sanitize(raw_msg):
            # Work on a copy; Jinja templates fail on None content, so missing
            # or None content becomes "" and non-strings are stringified.
            entry = dict(raw_msg)
            body = entry.get("content")
            if body is None:
                entry["content"] = ""
            elif not isinstance(body, str):
                entry["content"] = str(body)
            return entry

        messages = [_sanitize(raw_msg) for raw_msg in messages]

        # "default" means llama.cpp found a template embedded in the GGUF and
        # can drive create_chat_completion itself. Every other situation --
        # unknown/fallback/absent template, any tools at all, or a named
        # template without a HuggingFace tokenizer -- is formatted manually.
        template = self.chat_template
        tokenizer = self.hf_tokenizer
        needs_manual = (
            template in ("unknown", "jinja_fallback", None)
            or tools is not None
            or (template is not None and template != "default" and tokenizer is None)
        )

        if tokenizer is not None:
            # Preferred path: let the HuggingFace tokenizer render the prompt.
            try:
                hf_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                print(f"DEBUG: Using HuggingFace chat template")
                return self.generate(hf_prompt, token_budget, temperature, top_p, stop)
            except Exception as e:
                print(f"Warning: HF apply_chat_template failed ({e}), falling back")

        if needs_manual:
            print(f"DEBUG: Using manual message formatting (template: {self.chat_template}, tools: {tools is not None})")
            manual_prompt = self._manual_format_messages(messages)
            return self.generate(manual_prompt, token_budget, temperature, top_p, stop)

        try:
            completion = self.model.create_chat_completion(
                messages=messages,
                max_tokens=token_budget,
                temperature=temperature,
                top_p=top_p,
                stop=stop or [],
                tools=tools,
                response_format=response_format_param,
            )
            reply = completion["choices"][0]["message"].get("content", "")
            print(f"DEBUG: generate_chat returned content length: {len(reply) if reply else 0}")
            if not reply or not reply.strip():
                # Treat a blank reply like a failure so the handler below
                # retries with manual formatting.
                print(f"DEBUG: Empty content from create_chat_completion, using fallback")
                raise Exception("Empty response from create_chat_completion")
            return reply
        except Exception as e:
            print(f"Warning: create_chat_completion failed ({e}), falling back to text generation")
            fallback_prompt = self._manual_format_messages(messages)
            return self.generate(fallback_prompt, token_budget, temperature, top_p, stop)
    
    async def generate_chat_stream(self, messages: List[Dict], max_tokens: Optional[int] = None,
                                    temperature: float = 0.7, top_p: float = 1.0,
                                    stop: Optional[List[str]] = None, tools: Optional[List] = None,
                                    response_format: Optional[Dict] = None) -> AsyncGenerator[str, None]:
        """Stream a chat reply from llama-cpp as text fragments.

        Strategy, in order:
          1. With a HuggingFace tokenizer attached, render the prompt through
             its chat template and stream via ``generate_stream``.
          2. When the detected template (or the presence of ``tools``) calls
             for it, render with ``_manual_format_messages`` and stream via
             ``generate_stream``.
          3. Otherwise run ``create_chat_completion(stream=True)`` in the
             default executor, collect every chunk there, then yield the
             content deltas. If that fails before any chunk was processed
             (including an empty stream), fall back to manual formatting.

        Args:
            messages: Chat messages as dicts; they are copied, never mutated.
            max_tokens: Token budget; 512 when ``None``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            stop: Optional stop sequences.
            tools: Optional tool definitions; their presence forces manual
                formatting unless the HuggingFace path succeeds first.
            response_format: ``{"type": "json_object"}``/``{"type": "json"}``
                or the bare strings ``"json_object"``/``"json"`` enable JSON
                mode for ``create_chat_completion``; anything else is ignored.

        Yields:
            Non-empty text fragments of the assistant reply.
        """
        if max_tokens is None:
            max_tokens = 512

        # Collapse every accepted spelling of JSON mode into the one form that
        # is forwarded to create_chat_completion.
        response_format_param = None
        if response_format:
            if isinstance(response_format, dict):
                response_format_type = response_format.get('type', '')
                if response_format_type == 'json_object' or response_format_type == 'json':
                    response_format_param = {"type": "json_object"}
            elif isinstance(response_format, str):
                if response_format == 'json_object' or response_format == 'json':
                    response_format_param = {"type": "json_object"}

        total_content = ""
        chunk_count = 0

        # CRITICAL: Ensure NO message has None content - Jinja templates fail on None.
        # This is a safety check in case messages bypass the main endpoint validation.
        cleaned_messages = []
        for msg in messages:
            cleaned_msg = dict(msg)  # Copy so the caller's dicts are untouched
            content_value = cleaned_msg.get("content")
            if content_value is None:
                cleaned_msg["content"] = ""
            elif not isinstance(content_value, str):
                cleaned_msg["content"] = str(content_value)
            cleaned_messages.append(cleaned_msg)
        messages = cleaned_messages

        # "default" means llama.cpp detected an embedded template and can drive
        # create_chat_completion itself. Manual formatting is used for
        # unknown/fallback/absent templates, whenever tools are present (Jinja
        # templates often fail with tool messages), and for a named template
        # that has no HuggingFace tokenizer to render it.
        use_manual = (
            self.chat_template in ("unknown", "jinja_fallback", None)
            or tools is not None
            or (self.chat_template is not None and self.chat_template != "default" and self.hf_tokenizer is None)
        )
        use_hf = self.hf_tokenizer is not None

        if use_hf:
            # Use HuggingFace tokenizer for chat template
            try:
                prompt = self.hf_tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                print(f"DEBUG: Using HuggingFace chat template for streaming")
                async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                    yield chunk
                return
            except Exception as e:
                print(f"Warning: HF apply_chat_template failed ({e}), falling back")

        if use_manual:
            print(f"DEBUG: Using manual message formatting for streaming (template: {self.chat_template}, tools: {tools is not None})")
            prompt = self._manual_format_messages(messages)
            async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                yield chunk
            return

        # Collect all chunks synchronously then yield them.
        # This avoids issues with generators across thread boundaries.
        def collect_chunks():
            """Drain the llama-cpp stream into a list (runs in a worker thread)."""
            print(f"DEBUG: generate_chat_stream: Calling create_chat_completion with tools={tools}")
            stream = self.model.create_chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop=stop or [],
                tools=tools,
                stream=True,
                response_format=response_format_param,
            )
            print(f"DEBUG: generate_chat_stream: Got stream object: {type(stream)}")
            chunks = list(stream)
            print(f"DEBUG: generate_chat_stream: Collected {len(chunks)} chunks")
            return chunks

        try:
            # We are inside a coroutine, so a loop is guaranteed to be running;
            # get_running_loop() is the documented way to reach it.
            loop = asyncio.get_running_loop()
            chunks = await loop.run_in_executor(None, collect_chunks)

            for chunk in chunks:
                chunk_count += 1
                print(f"DEBUG: generate_chat_stream: Processing chunk {chunk_count}: {repr(chunk)}")
                delta = chunk["choices"][0].get("delta", {})
                content = delta.get("content", "")

                # Content is passed through unchanged (including any reasoning
                # markup the model emits); only empty/None deltas are skipped.
                if content:
                    total_content += content
                    yield content

                # Small yield to allow other async tasks
                await asyncio.sleep(0)

            print(f"DEBUG: generate_chat_stream yielded {chunk_count} chunks, total content length: {len(total_content)}")
            if chunk_count == 0 or not total_content.strip():
                print(f"DEBUG: Empty stream from create_chat_completion, using fallback")
                raise Exception("Empty stream response")
        except Exception as e:
            print(f"DEBUG: generate_chat_stream exception: {type(e).__name__}: {e}")
            import traceback
            traceback.print_exc()
            if chunk_count == 0:
                # Nothing was processed yet, so it is safe to start over with
                # manually formatted text generation.
                print(f"Warning: create_chat_completion stream failed ({e}), falling back to text generation")
                prompt = self._manual_format_messages(messages)
                async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                    yield chunk
            else:
                # Chunks were already consumed; restarting could duplicate
                # output, so the stream simply ends here.
                print(f"DEBUG: Stream completed with {chunk_count} chunks")
    
    def _manual_format_messages(self, messages: List[Dict]) -> str:
        """Manual fallback for formatting messages when create_chat_completion fails."""
        formatted = []
        for msg in messages:
            role = msg.get("role", "")
            # CRITICAL: Ensure content is never None - Jinja templates fail on None
            # Also ensure content key exists
            content = msg.get("content")
            if content is None:
                content = ""
            elif not isinstance(content, str):
                content = str(content)
            
            if role == "system":
                formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                # Handle tool_calls if present
                tool_calls = msg.get("tool_calls", [])
                if tool_calls:
                    for tc in tool_calls:
                        if isinstance(tc, dict) and "function" in tc:
                            func = tc["function"]
                            tc_str = f'<tool>{{"name": "{func.get("name", "")}", "arguments": {func.get("arguments", "{}")}}}</tool>'
                            content = content + "\n" + tc_str if content else tc_str
                formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            elif role == "tool":
                # Tool result messages
                tool_call_id = msg.get("tool_call_id", "")
                name = msg.get("name", "")
                formatted.append(f"<|im_start|>tool (tool_call_id={tool_call_id}, name={name})\n{content}<|im_end|>")
        
        formatted.append("<|im_start|>assistant\n")
        return "\n".join(formatted)
    
    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Stream a text completion from the llama-cpp model.

        Args:
            prompt: Fully formatted prompt text.
            max_tokens: Token budget; 512 when ``None``.
            temperature: Sampling temperature forwarded to the model.
            top_p: Nucleus-sampling threshold forwarded to the model.
            stop: Stop sequences; an empty list is sent when omitted.

        Yields:
            Each non-empty ``text`` fragment of the first choice, in order.
        """
        token_budget = 512 if max_tokens is None else max_tokens

        token_stream = self.model(
            prompt,
            max_tokens=token_budget,
            temperature=temperature,
            top_p=top_p,
            stop=stop or [],
            stream=True,
        )

        for piece in token_stream:
            fragment = piece["choices"][0].get("text", "")
            if not fragment:
                # Chunks without text carry nothing to emit.
                continue
            yield fragment
    
    def get_model_name(self) -> str:
        """Return the stored model name, or ``"unknown"`` when it is unset or empty."""
        name = self.model_name
        if not name:
            return "unknown"
        return name
    
    def cleanup(self) -> None:
        """Drop the loaded model, leaving ``self.model`` set to ``None``.

        Vulkan environment variables are not touched here; they should be
        managed by the launching script.
        """
        if self.model is None:
            return
        del self.model
        self.model = None
# =============================================================================
# Model Manager
# =============================================================================

class ModelManager:
    """Owns a single model backend and forwards generation calls to it.

    ``backend`` is ``None`` until ``load_model`` succeeds; every generation
    method raises ``RuntimeError("No model loaded")`` while it is ``None``.
    ``tool_parser`` is a ``ModelParserAdapter`` that is rebuilt with the model
    name each time a model is loaded.
    """

    def __init__(self):
        self.backend: Optional[ModelBackend] = None
        self.backend_type: Optional[str] = None
        self.tool_parser = ModelParserAdapter()

    def _aggressive_vram_cleanup(self, model_manager):
        """
        Aggressively cleanup VRAM when switching between different model types.
        This is more thorough than a simple cleanup() call.

        Steps: try to move the manager's / backend's model to CPU, delete the
        backend's ``pipeline``/``vae``/``text_encoder``/``tokenizer``
        attributes when set, run three garbage-collection passes, empty the
        CUDA cache when CUDA is available, sleep two seconds, and finally call
        ``model_manager.cleanup()``. Every step is best-effort: failures are
        reported with a warning and never raised.

        Args:
            model_manager: Object to clean up; its ``model``, ``backend`` and
                ``cleanup`` attributes are used when present.
        """
        import gc
        import time

        try:
            import torch

            # First, try to move model to CPU if it has a model attribute
            if hasattr(model_manager, 'model') and model_manager.model is not None:
                model = model_manager.model

                # If it's a diffusers pipeline, try to move to CPU first
                if hasattr(model, 'to'):
                    try:
                        model.to('cpu')
                    except Exception as e:
                        print(f"Warning: could not move model to CPU: {e}")

                # Drops only this local reference; the owner still holds the
                # object until its own cleanup() runs below.
                del model

            # Also handle backend directly if it's different
            if hasattr(model_manager, 'backend') and model_manager.backend is not None:
                backend = model_manager.backend

                if hasattr(backend, 'model') and backend.model is not None:
                    model = backend.model
                    if hasattr(model, 'to'):
                        try:
                            model.to('cpu')
                        except Exception as e:
                            print(f"Warning: could not move backend model to CPU: {e}")
                    del model

                # These deletions remove the attributes from the backend itself.
                if hasattr(backend, 'pipeline') and backend.pipeline is not None:
                    del backend.pipeline

                if hasattr(backend, 'vae') and backend.vae is not None:
                    del backend.vae

                if hasattr(backend, 'text_encoder') and backend.text_encoder is not None:
                    del backend.text_encoder

                if hasattr(backend, 'tokenizer') and backend.tokenizer is not None:
                    del backend.tokenizer

            # Force multiple rounds of garbage collection
            for _ in range(3):
                gc.collect()

            # Clear PyTorch cache
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

            # Add delay to allow Vulkan to release memory
            time.sleep(2)

        except Exception as e:
            print(f"Warning during aggressive VRAM cleanup: {e}")
        finally:
            # Try to cleanup the model manager itself. `except Exception`
            # (not a bare except) so Ctrl-C / SystemExit still propagate.
            try:
                if hasattr(model_manager, 'cleanup'):
                    model_manager.cleanup()
            except Exception as e:
                print(f"Warning: model manager cleanup failed: {e}")

    def load_model(self, model_name: str, backend_type: str = "auto", **kwargs):
        """
        Load the model with the specified backend.

        Args:
            model_name: Model name or path. A name ending in ``.gguf`` or
                containing ``gguf`` (case-insensitive) is treated as GGUF.
            backend_type: 'nvidia', 'vulkan', or 'auto' to detect (NVIDIA is
                preferred over Vulkan). 'cuda' is accepted only for GGUF
                models. For GGUF models 'nvidia'/'cuda' are rerouted to the
                llama-cpp-python ("vulkan") backend, which receives the
                originally requested backend as ``original_backend``.
            **kwargs: Additional arguments for the specific backend

        Raises:
            RuntimeError: No backend could be auto-detected, or the requested
                backend is not available.
            ValueError: ``backend_type`` is not recognised.
        """
        available = detect_available_backends()

        # Check if model is a GGUF file
        is_gguf = model_name.endswith('.gguf') or 'gguf' in model_name.lower()

        # Determine backend
        if backend_type == "auto":
            if available.get('nvidia'):
                backend_type = "nvidia"
                print("Auto-detected NVIDIA backend")
            elif available.get('vulkan'):
                backend_type = "vulkan"
                print("Auto-detected Vulkan backend")
            else:
                print("No GPU backend detected. For NVIDIA, install PyTorch with CUDA.")
                print("For Vulkan, install llama-cpp-python with Vulkan support.")
                raise RuntimeError("No suitable backend found")

        # If GGUF file and backend is nvidia/cuda, use llama-cpp-python with CUDA backend
        original_backend = None
        if is_gguf and backend_type in ("nvidia", "cuda"):
            original_backend = backend_type
            print(f"GGUF model detected, using llama-cpp-python (original backend: {original_backend})")
            backend_type = "vulkan"  # Use llama-cpp-python for GGUF

        self.backend_type = backend_type

        # Create appropriate backend
        # NOTE(review): a non-GGUF model with backend_type='cuda' reaches the
        # ValueError branch below -- confirm whether 'cuda' should alias 'nvidia'.
        if backend_type == "nvidia":
            if not available.get('nvidia'):
                raise RuntimeError("NVIDIA backend requested but PyTorch/CUDA not available")
            self.backend = NvidiaBackend()
        elif backend_type == "vulkan":
            if not available.get('vulkan'):
                raise RuntimeError("Vulkan backend requested but llama-cpp-python not available")
            self.backend = VulkanBackend(original_backend=original_backend)
        else:
            raise ValueError(f"Unknown backend: {backend_type}")

        # Load the model
        self.backend.load_model(model_name, **kwargs)
        self.tool_parser = ModelParserAdapter(model_name=model_name)

    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string.

        Raises:
            RuntimeError: No model is loaded.
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        return self.backend.format_messages(messages)

    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate text non-streaming.

        Raises:
            RuntimeError: No model is loaded.
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)

    def generate_chat(self, messages: List[Dict], max_tokens: Optional[int] = None,
                      temperature: float = 0.7, top_p: float = 1.0,
                      stop: Optional[List[str]] = None, tools: Optional[List] = None,
                      response_format: Optional[Dict] = None) -> str:
        """Generate chat completion non-streaming.

        Backends that implement ``generate_chat`` receive the messages, tools
        and response format directly. Otherwise the messages are converted to
        ``ChatMessage`` objects, formatted into a prompt, and passed to the
        backend's ``generate`` (``tools`` and ``response_format`` are not
        forwarded on that path).

        Raises:
            RuntimeError: No model is loaded.
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        # Use generate_chat if available (Vulkan backend), otherwise format and use generate
        if hasattr(self.backend, 'generate_chat'):
            return self.backend.generate_chat(messages, max_tokens, temperature, top_p, stop, tools, response_format)
        else:
            # Fallback for NVIDIA backend
            prompt = self.format_messages([ChatMessage(**m) for m in messages])
            return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)

    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Generate text in streaming fashion.

        Raises:
            RuntimeError: No model is loaded (raised on first iteration).
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
            yield chunk

    async def generate_chat_stream(self, messages: List[Dict], max_tokens: Optional[int] = None,
                                    temperature: float = 0.7, top_p: float = 1.0,
                                    stop: Optional[List[str]] = None, tools: Optional[List] = None,
                                    response_format: Optional[Dict] = None) -> AsyncGenerator[str, None]:
        """Generate chat completion streaming.

        Mirrors ``generate_chat``: backends implementing
        ``generate_chat_stream`` are used directly; otherwise the messages are
        formatted into a prompt and streamed via the backend's
        ``generate_stream`` (``tools`` and ``response_format`` are not
        forwarded on that path).

        Raises:
            RuntimeError: No model is loaded (raised on first iteration).
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        # Use generate_chat_stream if available (Vulkan backend), otherwise format and use generate_stream
        if hasattr(self.backend, 'generate_chat_stream'):
            async for chunk in self.backend.generate_chat_stream(messages, max_tokens, temperature, top_p, stop, tools, response_format):
                yield chunk
        else:
            # Fallback for NVIDIA backend
            prompt = self.format_messages([ChatMessage(**m) for m in messages])
            async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                yield chunk

    @property
    def model_name(self) -> str:
        """Name reported by the backend, or ``"unknown"`` when nothing is loaded."""
        if self.backend is None:
            return "unknown"
        return self.backend.get_model_name()

    @property
    def model(self):
        """The active backend object itself (not its inner model), or ``None``."""
        if self.backend is None:
            return None
        return self.backend

    @property
    def tokenizer(self):
        """The backend's tokenizer for ``NvidiaBackend`` instances, else ``None``."""
        # Only NVIDIA backend has a tokenizer
        if isinstance(self.backend, NvidiaBackend):
            return self.backend.tokenizer
        return None

    def cleanup(self):
        """Clean up and detach the backend, if one is loaded."""
        if self.backend is not None:
            self.backend.cleanup()
            self.backend = None
# =============================================================================
# Whisper Server Manager - manages whisper-server subprocess
# =============================================================================

import subprocess
import signal
import requests
import time
import threading
class WhisperServerManager:
    """Manages whisper-server subprocess for audio transcription with model swapping support.

    The server is bound to 127.0.0.1 and reached over HTTP (``/health`` for
    readiness, ``/inference`` for transcription). ``start``/``stop`` are
    serialised through ``self.lock``.
    """

    def __init__(self, server_path: Optional[str] = None, port: int = 8744):
        """Store configuration and pick a free local port.

        Args:
            server_path: Path to the whisper-server executable; ``start``
                refuses to run while it is unset.
            port: Preferred port. If it is busy, the next free port among the
                following 99 is used; if none is free the preferred port is kept.
        """
        self.server_path = server_path
        self.port = port
        self.process = None
        self.current_model = None
        self.base_url = f"http://127.0.0.1:{port}"
        self.lock = threading.Lock()
        self._health_check_thread = None
        self._running = False

        # Check if port is available
        if not self._is_port_available(port):
            # Try to find an available port
            for new_port in range(port + 1, port + 100):
                if self._is_port_available(new_port):
                    self.port = new_port
                    self.base_url = f"http://127.0.0.1:{new_port}"
                    print(f"Port {port} in use, using port {new_port} instead")
                    break

    def _is_port_available(self, port: int) -> bool:
        """Check if a port is available by trying to bind it on 127.0.0.1."""
        import socket
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('127.0.0.1', port))
                return True
        except OSError:
            return False

    def is_running(self) -> bool:
        """Check if whisper-server is running (process started and not exited)."""
        if self.process is None:
            return False
        return self.process.poll() is None

    def start(self, model_path: Optional[str] = None, gpu_device: int = 0) -> str:
        """Start whisper-server with the specified model.

        Any server that is already running is stopped first. HTTP(S) model
        paths are resolved through the model cache and downloaded when absent.

        Args:
            model_path: Local path or http(s) URL of the model; omitted from
                the command line when falsy.
            gpu_device: Device index passed to the server via ``-dev``.

        Returns:
            The actual (local) model path on success, or an empty string on failure.
        """
        with self.lock:
            # Stop existing server if running. self.lock is a non-reentrant
            # Lock that we already hold, so the lock-free helper must be used
            # here -- calling stop() would deadlock.
            if self.is_running():
                self._stop_locked()

            if not self.server_path:
                print("Error: whisper-server path not set")
                return ""

            # Handle URL models - download if needed
            actual_model_path = model_path
            if model_path and (model_path.startswith('http://') or model_path.startswith('https://')):
                # Check cache first
                cached_path = get_cached_model_path(model_path)
                if cached_path:
                    actual_model_path = cached_path
                    print(f"Using cached model: {actual_model_path}")
                else:
                    # Download the model
                    cache_dir = get_model_cache_dir()
                    print(f"Downloading model: {model_path}")
                    actual_model_path = download_model(model_path, cache_dir)
                    print(f"Downloaded model to: {actual_model_path}")

            # Build command
            cmd = [self.server_path]

            if actual_model_path:
                cmd.extend(["-m", actual_model_path])

            # Add GPU device
            cmd.extend(["-dev", str(gpu_device)])

            # Add --convert flag to convert audio to 16kHz mono on the server side
            cmd.append("--convert")

            # Add host and port
            cmd.extend(["--host", "127.0.0.1"])
            cmd.extend(["--port", str(self.port)])

            print(f"Starting whisper-server: {' '.join(cmd)}")
            print(f"DEBUG: Full whisper-server command: {' '.join(cmd)}")

            try:
                # NOTE(review): stdout/stderr are piped but nothing visible in
                # this class reads them; a chatty server can block once the
                # pipe buffer fills. Consider DEVNULL or a drain thread if no
                # other code consumes these pipes.
                self.process = subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    # Restore default SIGTERM handling in the child so that
                    # terminate() in stop() works.
                    preexec_fn=lambda: signal.signal(signal.SIGTERM, signal.SIG_DFL)
                )
                self.current_model = actual_model_path

                # Wait for server to be ready
                if self._wait_for_server(30):
                    print(f"whisper-server started on {self.base_url}")
                    self._running = True
                    return actual_model_path
                else:
                    print("Error: whisper-server failed to start")
                    self._stop_locked()
                    return ""
            except Exception as e:
                print(f"Error starting whisper-server: {e}")
                # Do not leave a half-started process behind.
                self._stop_locked()
                return ""

    def _stop_locked(self) -> None:
        """Terminate the subprocess and reset state; caller must hold ``self.lock``."""
        self._running = False
        if self.process:
            try:
                self.process.terminate()
                try:
                    self.process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    # Graceful shutdown timed out; force-kill and reap.
                    self.process.kill()
                    self.process.wait()
            except Exception as e:
                print(f"Error stopping whisper-server: {e}")
            self.process = None
            self.current_model = None

    def stop(self):
        """Stop whisper-server."""
        with self.lock:
            self._stop_locked()

    def restart(self, model_path: Optional[str] = None, gpu_device: int = 0) -> str:
        """Restart whisper-server with a new model.

        Returns:
            Whatever ``start`` returns: the actual model path (truthy) on
            success, or an empty string (falsy) on failure.
        """
        print(f"Restarting whisper-server with model: {model_path}")
        return self.start(model_path, gpu_device)

    def transcribe(self, audio_data: bytes, language: Optional[str] = None, prompt: Optional[str] = None) -> dict:
        """Send transcription request to whisper-server.

        Args:
            audio_data: Raw audio bytes, uploaded as ``audio.wav``.
            language: Optional language hint sent as a form field.
            prompt: Optional prompt sent as a form field.

        Returns:
            The server's JSON response on HTTP 200. Otherwise a dict with an
            ``"error"`` key (plus ``"detail"`` for non-200 responses); errors
            are never raised.
        """
        if not self.is_running():
            return {"error": "whisper-server not running"}

        try:
            files = {"file": ("audio.wav", audio_data, "audio/wav")}
            data = {}
            if language:
                data["language"] = language
            if prompt:
                data["prompt"] = prompt

            print(f"DEBUG: Sending POST to {self.base_url}/inference with data={data}, file_size={len(audio_data)}")
            response = requests.post(
                f"{self.base_url}/inference",
                files=files,
                data=data,
                timeout=300
            )
            print(f"DEBUG: whisper-server response status={response.status_code}, body={response.text[:500] if response.text else 'empty'}")

            if response.status_code == 200:
                return response.json()
            else:
                return {"error": f"Server error: {response.status_code}", "detail": response.text}
        except Exception as e:
            print(f"DEBUG: whisper-server exception: {e}")
            return {"error": str(e)}

    def _wait_for_server(self, timeout: int = 30) -> bool:
        """Wait for whisper-server to be ready.

        Polls ``/health`` every 0.5 s until it answers 200, the subprocess
        exits, or ``timeout`` seconds elapse.

        Returns:
            True once the health check succeeds, False otherwise.
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            # A process that has already exited will never become healthy;
            # stop polling instead of waiting out the full timeout.
            if self.process is not None and self.process.poll() is not None:
                return False
            try:
                response = requests.get(f"{self.base_url}/health", timeout=2)
                if response.status_code == 200:
                    return True
            except Exception:
                # Connection errors are expected while the server boots.
                pass
            time.sleep(0.5)
        return False

    def get_status(self) -> dict:
        """Get whisper-server status (running flag, current model, base URL)."""
        return {
            "running": self.is_running(),
            "model": self.current_model,
            "url": self.base_url
        }
# =============================================================================
# Multi-Model Manager (supports audio transcription and image generation)
# =============================================================================

class MultiModelManager:
    """
    Manages multiple models: main text model, audio transcription, and image generation.
    Supports dynamic switching based on request model name.
    
    Modes:
    - ondemand (default): Load models on-demand (swap models in VRAM when request changes)
    - loadall: Pre-load all models in VRAM at startup
    - loadswap: Keep all models in memory (CPU RAM), swap active model to VRAM
    """
    
    def __init__(self):
        """Create an empty manager: no models registered or loaded, mode "ondemand"."""
        # --- loaded managers and per-model settings -------------------------
        # Loaded model managers, keyed by model key (e.g. "audio:<name>").
        self.models: Dict[str, ModelManager] = {}
        # Settings recorded when a model is registered via the set_* methods.
        self.config: Dict[str, Dict] = {}
        # Backend type per model name (needed for on-demand loading).
        self.model_backend_types: Dict[str, str] = {}
        # alias -> actual model name
        self.model_aliases: Dict[str, str] = {}

        # --- registered model names, grouped by kind ------------------------
        self.default_model: Optional[str] = None
        self.audio_models: List[str] = []
        self.image_models: List[str] = []
        # Vision = image/video to text.
        self.vision_models: List[str] = []
        self.tts_model: Optional[str] = None

        # --- runtime state --------------------------------------------------
        # Key of the model selected by the most recent request.
        self.current_model_key: Optional[str] = None
        # One of "ondemand", "loadall", "loadswap".
        self.load_mode: str = "ondemand"
        # Which model is currently in VRAM.
        self.active_in_vram: Optional[str] = None

        # --- collaborators --------------------------------------------------
        self.tool_parser = ModelParserAdapter()
        self.whisper_server: Optional[WhisperServerManager] = None
    
    def _aggressive_vram_cleanup(self, model_manager):
        """
        Aggressively cleanup VRAM when switching between different model types.
        This is more thorough than a simple cleanup() call.
        """
        import gc
        import time
        
        try:
            import torch
            
            # First, try to move model to CPU if it has a model attribute
            if hasattr(model_manager, 'model') and model_manager.model is not None:
                model = model_manager.model
                
                # If it's a diffusers pipeline, try to move to CPU first
                if hasattr(model, 'to'):
                    try:
                        model.to('cpu')
                    except:
                        pass
                
                # Delete the model
                del model
            
            # Also handle backend directly if it's different
            if hasattr(model_manager, 'backend') and model_manager.backend is not None:
                backend = model_manager.backend
                
                if hasattr(backend, 'model') and backend.model is not None:
                    model = backend.model
                    if hasattr(model, 'to'):
                        try:
                            model.to('cpu')
                        except:
                            pass
                    del model
                
                if hasattr(backend, 'pipeline') and backend.pipeline is not None:
                    del backend.pipeline
                
                if hasattr(backend, 'vae') and backend.vae is not None:
                    del backend.vae
                
                if hasattr(backend, 'text_encoder') and backend.text_encoder is not None:
                    del backend.text_encoder
                
                if hasattr(backend, 'tokenizer') and backend.tokenizer is not None:
                    del backend.tokenizer
            
            # Force multiple rounds of garbage collection
            for _ in range(3):
                gc.collect()
            
            # Clear PyTorch cache
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            
            # Add delay to allow Vulkan to release memory
            time.sleep(2)
            
        except Exception as e:
            print(f"Warning during aggressive VRAM cleanup: {e}")
        finally:
            # Try to cleanup the model manager itself
            try:
                if hasattr(model_manager, 'cleanup'):
                    model_manager.cleanup()
            except:
                pass
        # Load mode settings
        self.load_mode: str = "ondemand"  # "ondemand", "loadall", "loadswap"
        self.active_in_vram: Optional[str] = None  # Which model is currently in VRAM
        # Model aliases: alias -> actual model name mapping
        self.model_aliases: Dict[str, str] = {}
        # Whisper server manager
        self.whisper_server: Optional[WhisperServerManager] = None
        # Track backend type for each model (needed for on-demand loading)
        self.model_backend_types: Dict[str, str] = {}
        
    @property
    def audio_model(self) -> Optional[str]:
        """Get the first/default audio model."""
        return self.audio_models[0] if self.audio_models else None
    
    @property
    def image_model(self) -> Optional[str]:
        """Get the first/default image model."""
        return self.image_models[0] if self.image_models else None
    
    @property
    def vision_model(self) -> Optional[str]:
        """Get the first/default vision model."""
        return self.vision_models[0] if self.vision_models else None
        
    def set_load_mode(self, mode: str):
        """Set the load mode: 'ondemand', 'loadall', or 'loadswap'."""
        self.load_mode = mode
    
    def set_default_model(self, model_name: str, config: Dict = None, backend_type: str = "auto"):
        """Set the default/main text model."""
        self.default_model = model_name
        self.config[model_name] = config or {}
        self.model_backend_types[model_name] = backend_type
    
    def set_audio_model(self, model_name: str, config: Dict = None):
        """Add an audio transcription model."""
        if model_name not in self.audio_models:
            self.audio_models.append(model_name)
        self.config[f"audio:{model_name}"] = config or {}
    
    def set_tts_model(self, model_name: str, config: Dict = None):
        """Set the text-to-speech model."""
        self.tts_model = model_name
        self.config[f"tts:{model_name}"] = config or {}
    
    def set_image_model(self, model_name: str, config: Dict = None):
        """Add an image generation model."""
        if model_name not in self.image_models:
            self.image_models.append(model_name)
        self.config[f"image:{model_name}"] = config or {}
    
    def set_vision_model(self, model_name: str, config: Dict = None):
        """Add a vision (image/video to text) model."""
        if model_name not in self.vision_models:
            self.vision_models.append(model_name)
        self.config[f"vision:{model_name}"] = config or {}
    
    def set_model_alias(self, alias: str, model_name: str):
        """Register an alias for a model."""
        self.model_aliases[alias] = model_name
    
    def get_model_for_request(self, requested_model: str) -> Optional[ModelManager]:
        """
        Get the appropriate model manager for a request based on model name.
        
        Model name conventions:
        - "default", empty, or matches default model -> use main model
        - "audio" -> use first/default audio model
        - "audio:modelname" -> use specific audio model
        - "image" -> use first/default image model  
        - "vision:modelname" or "image:modelname" -> use specific image model
        - "tts" -> use TTS model
        - "tts:modelname" -> use specific TTS model
        - Custom aliases -> resolve to actual model name
        - Otherwise match by model ID in multi_model_manager.models
        
        In ondemand mode with multiple text models:
        - If requested model is different from currently loaded model, 
          unload current and load new model on-demand (respecting --backend)
        """
        # Import global_args inside function to ensure it's available
        global global_args
        
        # Resolve custom aliases first
        if requested_model in self.model_aliases:
            requested_model = self.model_aliases[requested_model]
        
        # Handle empty or "default" model names
        if not requested_model or requested_model == "default":
            if self.default_model and self.default_model in self.models:
                self.current_model_key = self.default_model
                return self.models[self.default_model]
            # Model not loaded - check if it's in config (registered but unloaded)
            if self.default_model and self.default_model in self.config:
                # Need to reload the default model - cleanup image models first
                for key in list(self.models.keys()):
                    if key.startswith("image:"):
                        model_to_cleanup = self.models.get(key)
                        if model_to_cleanup is not None:
                            print(f"Unloading image model '{key}' from VRAM to reload text model")
                            self._aggressive_vram_cleanup(model_to_cleanup)
                        del self.models[key]
                
                # Add delay to allow VRAM to be freed
                import time
                time.sleep(2)
                
                # Now try to reload the default model
                try:
                    from llama_cpp import Llama
                    backend = self.config[self.default_model].get('backend_type', 'auto')
                    model_path = self.default_model
                    # Check if model_path is a URL and try to get cached path
                    if model_path.startswith('http://') or model_path.startswith('https://'):
                        cached_path = get_cached_model_path(model_path)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached model path: {model_path}")
                        else:
                            print(f"Warning: Model URL not cached, cannot reload: {model_path}")
                            return None
                    load_kwargs = self.config[self.default_model].copy()
                    load_kwargs.pop('backend_type', None)
                    print(f"Reloading default model: {model_path}")
                    llm = Llama(model_path=model_path, **load_kwargs)
                    self.models[self.default_model] = ModelManager(llm, backend=backend)
                    self.current_model_key = self.default_model
                    return self.models[self.default_model]
                except Exception as e:
                    print(f"Error reloading default model: {e}")
            return None
        
        # Handle "audio" alias - use first/default audio model
        if requested_model == "audio":
            if self.audio_models:
                first_audio = self.audio_models[0]
                key = f"audio:{first_audio}"
                if key in self.models:
                    self.current_model_key = key
                    return self.models[key]
                # Try to load on demand
                return None
            return None
        
        # Handle "image" alias - use first/default image model
        if requested_model == "image":
            if self.image_models:
                first_image = self.image_models[0]
                key = f"image:{first_image}"
                if key in self.models:
                    self.current_model_key = key
                    return self.models[key]
                # Try to load on demand
                return None
            return None
        
        # Handle "tts" alias
        if requested_model == "tts":
            if self.tts_model:
                key = f"tts:{self.tts_model}"
                if key in self.models:
                    self.current_model_key = key
                    return self.models[key]
                return None
            return None
        
        # Check for specialized models with prefix
        if requested_model.startswith("audio:"):
            audio_name = requested_model[6:]  # Remove "audio:" prefix
            key = f"audio:{audio_name}"
            if key in self.models:
                self.current_model_key = key
                return self.models[key]
            elif audio_name in self.audio_models:
                # Try loading audio model on demand
                return None
            return None
        
        if requested_model.startswith("tts:"):
            tts_name = requested_model[4:]  # Remove "tts:" prefix
            key = f"tts:{tts_name}"
            if key in self.models:
                self.current_model_key = key
                return self.models[key]
            elif self.tts_model and tts_name == self.tts_model:
                return None
            return None
        
        # Handle both "vision:" and "image:" prefixes
        if requested_model.startswith("vision:") or requested_model.startswith("image:"):
            # Extract the model name (remove either prefix)
            if requested_model.startswith("vision:"):
                image_name = requested_model[7:]  # Remove "vision:" prefix
            else:
                image_name = requested_model[6:]  # Remove "image:" prefix
            key = f"image:{image_name}"
            if key in self.models:
                self.current_model_key = key
                return self.models[key]
            elif image_name in self.image_models:
                # Try loading image model on demand
                return None
            return None
        
        # Check if it's the default model
        if self.default_model and (requested_model == self.default_model or 
                                    requested_model.endswith(self.default_model.split("/")[-1])):
            self.current_model_key = self.default_model
            return self.models.get(self.default_model)
        
        # Check if any loaded model matches
        for key, model in self.models.items():
            if requested_model in key or key.endswith(requested_model.split("/")[-1]):
                self.current_model_key = key
                return model
        
        # === ON-DEMAND MODEL SWITCHING FOR TEXT MODELS ===
        # If we're in ondemand mode and the requested model is in config but not loaded,
        # we should try to load it on-demand (swap from current model)
        # Only for text models (not audio/image/tts which have their own handling)
        
        # First, cleanup any image models to free VRAM for text model
        for key in list(self.models.keys()):
            if key.startswith("image:"):
                model_to_cleanup = self.models.get(key)
                if model_to_cleanup is not None:
                    print(f"Unloading image model '{key}' from VRAM to make room for text model")
                    self._aggressive_vram_cleanup(model_to_cleanup)
                    del self.models[key]
        
        # Force garbage collection and clear GPU cache
        import gc
        gc.collect()
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except:
            pass
        
        # Add a longer delay to allow VRAM to be freed (Vulkan needs more time)
        import time
        time.sleep(2)
        
        # Check if requested model is already loaded - if so, reuse it
        if requested_model in self.models:
            self.current_model_key = requested_model
            return self.models[requested_model]
        
        # Check if requested model is in our config (means it was registered but not loaded)
        if self.load_mode == "ondemand" and requested_model in self.config:
            # This is a text model that's registered but not loaded
            # We need to swap: unload current model and load this one
            
            # Always cleanup any loaded model (unless it's the same model we're about to load)
            for key in list(self.models.keys()):
                if key != requested_model:
                    model_to_cleanup = self.models.get(key)
                    if model_to_cleanup is not None:
                        print(f"Unloading '{key}' from VRAM to load '{requested_model}'")
                        try:
                            if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
                                model_to_cleanup.cleanup()
                            elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
                                # Handle ModelManager objects
                                if hasattr(model_to_cleanup.model, 'cleanup'):
                                    model_to_cleanup.model.cleanup()
                        except Exception as e:
                            print(f"Warning during cleanup of '{key}': {e}")
                        del self.models[key]
            
            # Load the new model on-demand
            print(f"ON-DEMAND SWAP: Loading model '{requested_model}' into VRAM")
            
            # Get the backend type for this model
            backend_type = getattr(self, 'model_backend_types', {}).get(requested_model, "auto")
            
            # Get config for this model
            model_config = self.config.get(requested_model, {})
            
            effective_backend = backend_type
            if effective_backend == "auto" and global_args:
                effective_backend = getattr(global_args, 'backend', 'auto')
            
            try:
                # Create new model manager and load the model
                new_manager = ModelManager()
                new_manager.load_model(
                    model_name=requested_model,
                    backend_type=effective_backend,
                    **model_config
                )
                self.models[requested_model] = new_manager
                self.current_model_key = requested_model
                self.active_in_vram = requested_model
                print(f"ON-DEMAND SWAP: Successfully loaded model '{requested_model}' with backend '{effective_backend}'")
                return new_manager
            except Exception as e:
                print(f"ON-DEMAND SWAP: Failed to load model '{requested_model}': {e}")
                # Try to restore the previous model if we had one
                return None
        
        # Also check if the model matches by short name (e.g., "Phi-3" matches "microsoft/Phi-3-mini-4k-instruct")
        if self.load_mode == "ondemand":
            # First, cleanup any image models to free VRAM for text model
            for key in list(self.models.keys()):
                if key.startswith("image:"):
                    model_to_cleanup = self.models.get(key)
                    if model_to_cleanup is not None:
                        print(f"Unloading image model '{key}' from VRAM to make room for text model")
                        try:
                            if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
                                model_to_cleanup.cleanup()
                            elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
                                if hasattr(model_to_cleanup.model, 'cleanup'):
                                    model_to_cleanup.model.cleanup()
                        except Exception as e:
                            print(f"Warning during cleanup of '{key}': {e}")
                        del self.models[key]
            
            for model_name in self.config.keys():
                # Only check text models (not audio:, image:, tts: prefixes)
                if ":" not in model_name:
                    short_name = model_name.split("/")[-1] if "/" in model_name else model_name
                    if requested_model.lower() in short_name.lower() or short_name.lower() in requested_model.lower():
                        # Found a matching model in config, try to load it
                        if model_name not in self.models:
                            # Always cleanup any loaded model (unless it's the same model we're about to load)
                            for key in list(self.models.keys()):
                                if key != model_name:
                                    model_to_cleanup = self.models.get(key)
                                    if model_to_cleanup is not None:
                                        print(f"Unloading '{key}' from VRAM to load '{model_name}'")
                                        try:
                                            if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
                                                model_to_cleanup.cleanup()
                                            elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
                                                if hasattr(model_to_cleanup.model, 'cleanup'):
                                                    model_to_cleanup.model.cleanup()
                                        except Exception as e:
                                            print(f"Warning during cleanup of '{key}': {e}")
                                        del self.models[key]
                            
                            # Load the new model on-demand
                            print(f"ON-DEMAND SWAP: Loading model '{model_name}' into VRAM")
                            
                            # Get the backend type for this model
                            backend_type = getattr(self, 'model_backend_types', {}).get(model_name, "auto")
                            
                            # Get config for this model
                            model_config = self.config.get(model_name, {})
                            
                            effective_backend = backend_type
                            if effective_backend == "auto" and global_args:
                                effective_backend = getattr(global_args, 'backend', 'auto')
                            
                            try:
                                new_manager = ModelManager()
                                new_manager.load_model(
                                    model_name=model_name,
                                    backend_type=effective_backend,
                                    **model_config
                                )
                                self.models[model_name] = new_manager
                                self.current_model_key = model_name
                                self.active_in_vram = model_name
                                print(f"ON-DEMAND SWAP: Successfully loaded model '{model_name}' with backend '{effective_backend}'")
                                return new_manager
                            except Exception as e:
                                print(f"ON-DEMAND SWAP: Failed to load model '{model_name}': {e}")
                                return None
        
        return None
    
    def add_model(self, key: str, manager: ModelManager):
        """Add a model manager for a specific key."""
        self.models[key] = manager
    
    def get_model(self, key: str) -> Optional[ModelManager]:
        """Get a model manager by key."""
        return self.models.get(key)
    
    def get_current_model(self) -> Optional[ModelManager]:
        """Get the currently active model."""
        if self.current_model_key:
            return self.models.get(self.current_model_key)
        if self.default_model:
            return self.models.get(self.default_model)
        return None
    
    def list_models(self) -> List[ModelInfo]:
        """
        Build the list of model ids advertised to clients.

        Order: default model (short name, full name, "default"), audio
        models, TTS model, image models, any other loaded models, then custom
        aliases. URL-valued ids are never listed as such; an image model
        given as a URL is shown by the file name of its cached path when one
        exists. Falls back to a single "default" entry when nothing else is
        available.
        """
        def _is_url(value: str) -> bool:
            # URLs are download sources, not model identifiers.
            return value.startswith(("http://", "https://"))

        ids: List[str] = []
        default = self.default_model

        # Default model: short name first, then full name and the alias.
        if default and not _is_url(default):
            short_name = default.split("/")[-1]
            if short_name != default:
                ids.append(short_name)
            ids.append(default)
            ids.append("default")

        # Audio: bare alias for the first model, then every model.
        if self.audio_models:
            ids.append("audio")
            ids.extend(f"audio:{audio_id}" for audio_id in self.audio_models)

        # TTS: bare alias plus the specific model.
        if self.tts_model:
            ids.append("tts")
            ids.append(f"tts:{self.tts_model}")

        # Image: bare alias for the first model, then every model.
        if self.image_models:
            ids.append("image")
            for image_id in self.image_models:
                display_id = image_id
                if _is_url(image_id):
                    cached_path = get_cached_model_path(image_id)
                    if cached_path:
                        # Use the filename from the cached path for display
                        display_id = os.path.basename(cached_path)
                ids.append(f"image:{display_id}")

        # Loaded models not covered by the categories above.
        for key in self.models:
            if key == default or key.startswith(("audio:", "image:", "tts:")):
                continue
            if default and key == default.split("/")[-1]:
                continue
            if _is_url(key):
                continue
            ids.append(key)

        # Custom model aliases
        ids.extend(self.model_aliases)

        if not ids:
            return [ModelInfo(id="default")]
        return [ModelInfo(id=model_id) for model_id in ids]
    
    def swap_model_to_vram(self, model_key: str) -> bool:
        """
        Swap a model from CPU RAM to GPU VRAM.
        Returns True if successful, False otherwise.
        Only applies in loadswap mode.
        """
        if self.load_mode != "loadswap":
            return True  # No swapping needed in other modes
        
        if self.active_in_vram == model_key:
            return True  # Already in VRAM
        
        # If another model is in VRAM, swap it to CPU first
        if self.active_in_vram and self.active_in_vram in self.models:
            self._swap_model_to_cpu(self.active_in_vram)
        
        # Swap the requested model to VRAM
        success = self._swap_model_to_vram(model_key)
        if success:
            self.active_in_vram = model_key
        return success
    
    def _swap_model_to_vram(self, model_key: str) -> bool:
        """Internal method to swap a specific model to VRAM."""
        # This would need backend-specific implementation
        # For now, we assume the model is already in memory
        print(f"SWAP: Moving model {model_key} to VRAM")
        return True
    
    def _swap_model_to_cpu(self, model_key: str) -> bool:
        """Internal method to swap a specific model to CPU RAM."""
        # This would need backend-specific implementation
        # For now, we just track that it's in CPU
        print(f"SWAP: Moving model {model_key} to CPU RAM")
        return True
    
    def get_active_model_key(self) -> Optional[str]:
        """Get the currently active model key."""
        return self.current_model_key or self.default_model
    
    def cleanup(self):
        """Cleanup all models."""
        for key, model in self.models.items():
            try:
                if hasattr(model, 'cleanup') and callable(getattr(model, 'cleanup')):
                    model.cleanup()
                elif hasattr(model, 'model') and model.model is not None:
                    # Handle ModelManager objects
                    if hasattr(model.model, 'cleanup'):
                        model.model.cleanup()
                    elif hasattr(model.model, 'model'):
                        # Nested model (e.g., StableDiffusion wrapper)
                        if model.model.model is not None:
                            del model.model.model
                # Remove from dict
                del self.models[key]
            except Exception as e:
                print(f"Warning: Failed to cleanup model '{key}': {e}")
        self.models.clear()
# Global multi-model manager: the process-wide MultiModelManager instance,
# created at import time with no models registered.
multi_model_manager = MultiModelManager()
# Global model manager (for backward compatibility)
model_manager = ModelManager()

# Global args for access in endpoints. None until assigned; read via getattr()
# for attributes such as 'reply_filters' and 'backend'.
# NOTE(review): presumably set to the parsed CLI arguments at startup - confirm.
global_args = None


def check_reply_filter(filter_type: str, model_type: str = "text", model_name: str = None) -> bool:
    """
    Check if a specific reply filter is enabled.

    Reads the configured filter specs from ``global_args.reply_filters`` and
    evaluates each with ``check_single_filter``.

    Args:
        filter_type: The filter to check ('malformed', 'tool_calls', 'all')
        model_type: The model type ('text', 'image', etc.)
        model_name: The specific model name (optional) - e.g., 'llama-3.1', 'sd-xl'

    Returns:
        True if the filter should be applied, False otherwise.

    Syntax:
        # Apply to all models
        --reply-filters all

        # Apply to all text or all image models
        --reply-filters text:malformed
        --reply-filters image:tool_calls

        # Apply to specific model
        --reply-filters text:llama-3.1:malformed
        --reply-filters image:sd-xl:tool_calls

        # Comma-separated for multiple filters on same target
        --reply-filters text:malformed,tool_calls
        --reply-filters text:llama-3.1:malformed,tool_calls

    In a comma-separated spec, a part without its own "type:" target inherits
    the target of the closest preceding part that has one, so
    ``text:malformed,tool_calls`` means ``text:malformed`` plus
    ``text:tool_calls``. Parts before any target apply to all models.
    """
    reply_filters = getattr(global_args, 'reply_filters', []) or []

    for filter_spec in reply_filters:
        # Target ("text" or "text:model") carried over to later comma parts.
        target = ""
        for part in filter_spec.split(','):
            part = part.strip()
            if not part:
                continue
            if ':' in part:
                # Everything before the last ':' is the target of this part.
                target = part.rpartition(':')[0]
                candidate = part
            elif target:
                candidate = f"{target}:{part}"
            else:
                candidate = part
            if check_single_filter(candidate, filter_type, model_type, model_name):
                return True

    return False


def check_single_filter(filter_spec: str, filter_type: str, model_type: str, model_name: str = None) -> bool:
    """
    Check a single filter specification against the model.
    """
    # Handle model-specific filters: "text:malformed", "text:model_name:malformed", "image:sd-xl:tool_calls"
    if ':' in filter_spec:
        parts = filter_spec.split(':')
        spec_model_type = parts[0]
        
        # Check if this filter spec matches our model type
        if spec_model_type != model_type and spec_model_type != '*':
            return False
        
        # If there's a model name in the spec, check for exact match or wildcard
        if len(parts) > 2:
            # Format: text:model_name:filter or image:model_name:filter
            spec_model_name = parts[1]
            spec_filter = parts[2]
            
            # Check model name matches (wildcard support)
            if spec_model_name != '*' and spec_model_name != model_name:
                return False
            
            # Check filter matches
            return spec_filter == 'all' or spec_filter == filter_type
        else:
            # Format: text:malformed (no specific model, applies to all of this type)
            spec_filter = parts[1]
            return spec_filter == 'all' or spec_filter == filter_type
    else:
        # Simple filter: "malformed" or "all" - applies to all models
        return filter_spec == 'all' or filter_spec == filter_type


def check_hf_chat_template(model_type: str = "text", model_name: str = None) -> tuple:
    """
    Check if HuggingFace chat template should be used for the model.
    Returns a tuple (should_use, template_name) where template_name is the template to use or None for auto-detect.

    The specs are read from ``global_args.hf_chat_template``; the first spec
    that matches the model decides the result.

    Args:
        model_type: The model type ('text', 'image', etc.)
        model_name: The specific model name (optional)

    Returns:
        Tuple of (should_use: bool, template_name: str or None)
        template_name is None means auto-detect from tokenizer

    Examples:
        # Auto-detect and apply to all models
        --hf-chat-template auto

        # Apply to all text models with auto-detect
        --hf-chat-template text

        # Apply to specific model with auto-detect
        --hf-chat-template text:llama-3.1

        # Apply to specific model with specific template
        --hf-chat-template "llama-3.1:llama3"
        --hf-chat-template "phi-3:chatml"
    """
    hf_chat_template = getattr(global_args, 'hf_chat_template', []) or []

    # If empty list, HF chat template is not enabled
    if not hf_chat_template:
        return (False, None)

    for spec in hf_chat_template:
        # 'auto' (or an empty spec) enables auto-detection for every model
        if spec == 'auto' or spec == '':
            return (True, None)

        parts = spec.split(':')

        if len(parts) == 1:
            # Just a type, a wildcard, or a (partial) model name
            spec_val = parts[0]
            if spec_val == model_type or spec_val == '*':
                return (True, None)
            # Loose substring match against the model name, in either direction
            if model_name and (spec_val in model_name or model_name in spec_val):
                return (True, None)
        elif len(parts) == 2:
            first, second = parts

            if first in ('text', 'image', '*'):
                # Format: "type:model_name" - auto-detect the template
                if first == model_type or first == '*':
                    if second == model_name or second == '*':
                        return (True, None)
            else:
                # Format: "model_name:template" - the FIRST segment names the
                # model and the SECOND is the template to return.
                if model_name and (first in model_name or model_name in first):
                    return (True, second)
        elif len(parts) == 3:
            # Format: "type:model_name:template"
            spec_type, spec_model, spec_template = parts

            if spec_type == model_type or spec_type == '*':
                if spec_model == model_name or spec_model == '*':
                    return (True, spec_template)

    return (False, None)

# Global system prompt (set via --system-prompt flag)
# None = don't inject, True = use default, string = use custom text
global_system_prompt = None

# Global debug flag: when truthy, the request-logging middleware prints the
# full request body instead of a truncated preview
global_debug = False
# Directory the /v1/files/{filename} endpoint serves from; while it is
# None/empty that endpoint answers 404 "File path not configured"
global_file_path = None

# =============================================================================
# Queue Manager for Model Loading Notifications
# =============================================================================

class QueueManager:
    """
    Bookkeeping for requests that are waiting on a model.

    Tracks which request ids are queued (and since when) and which request is
    currently being processed. Every accessor takes the instance's asyncio
    lock before touching the shared state.
    """

    def __init__(self):
        self.lock = asyncio.Lock()
        # request_id -> time.time() at which the request was queued; dict
        # insertion order doubles as queue order
        self.waiting_requests: Dict[str, float] = {}
        self.current_request_id: Optional[str] = None
        self.model_name: Optional[str] = None
        self.model_loading: bool = False

    async def add_waiting(self, request_id: str) -> None:
        """Queue a request, stamping it with the current time."""
        async with self.lock:
            queued_at = time.time()
            self.waiting_requests[request_id] = queued_at

    async def remove_waiting(self, request_id: str) -> None:
        """Drop a request from the queue; unknown ids are ignored."""
        async with self.lock:
            self.waiting_requests.pop(request_id, None)

    async def start_processing(self, request_id: str, model_name: str = None) -> None:
        """Move a request out of the queue and record it as the active one."""
        async with self.lock:
            self.waiting_requests.pop(request_id, None)
            self.current_request_id, self.model_name = request_id, model_name

    async def finish_processing(self) -> None:
        """Clear the active request (the recorded model name is kept)."""
        async with self.lock:
            self.current_request_id = None

    async def is_waiting(self, request_id: str) -> bool:
        """Tell whether the request is still queued."""
        async with self.lock:
            queued = request_id in self.waiting_requests
        return queued

    async def get_wait_time(self, request_id: str) -> float:
        """Seconds the request has spent in the queue, or 0.0 if not queued."""
        async with self.lock:
            queued_at = self.waiting_requests.get(request_id)
            if queued_at is None:
                return 0.0
            return time.time() - queued_at

    async def get_queue_position(self, request_id: str) -> int:
        """1-based position of the request in the queue, or 0 if not queued."""
        async with self.lock:
            for position, queued_id in enumerate(self.waiting_requests, start=1):
                if queued_id == request_id:
                    return position
            return 0
# Global queue manager (module-level QueueManager instance)
queue_manager = QueueManager()
# =============================================================================
# FastAPI Application
# =============================================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: nothing to do on startup, release models on shutdown."""
    # Startup: no work; control returns here once the application shuts down
    yield

    # Shutdown: clean up both managers, multi-model manager first
    for manager in (multi_model_manager, model_manager):
        manager.cleanup()

    # Stop whisper-server if running
    whisper = multi_model_manager.whisper_server
    if whisper:
        whisper.stop()
# FastAPI application object; the middleware and endpoints below register
# themselves on it through its decorators, and `lifespan` handles shutdown cleanup.
app = FastAPI(
    title="OpenAI-Compatible API",
    description="OpenAI-compatible API supporting NVIDIA (CUDA) and Vulkan backends",
    version="2.0.0",
    lifespan=lifespan,
)

# Add request logging middleware for debugging
# Only requests to these paths are logged by the middleware.
_LOGGED_REQUEST_PATHS = ("/v1/chat/completions", "/v1/completions")


def _log_request_dump(request: Request, body: bytes, body_str: str) -> None:
    """Print method, path, headers and body of a request.

    With ``global_debug`` set the body is printed in full; otherwise bodies
    longer than 2000 characters are reduced to their first and last 1000.
    """
    if global_debug:
        print(f"\n{'='*80}")
        print(f"=== FULL DEBUG REQUEST ===")
        print(f"{'='*80}")
        print(f"Path: {request.url.path}")
        print(f"Method: {request.method}")
        print(f"Headers: {dict(request.headers)}")
        print(f"\n--- FULL BODY ({len(body)} bytes) ---")
        print(body_str)
        print(f"--- END FULL BODY ---")
        print(f"{'='*80}\n")
    else:
        print(f"\n{'='*60}")
        print(f"=== INCOMING REQUEST ===")
        print(f"{'='*60}")
        print(f"Path: {request.url.path}")
        print(f"Method: {request.method}")
        print(f"Headers: {dict(request.headers)}")
        print(f"\n--- RAW BODY ({len(body)} bytes) ---")
        # Print body with truncation for very large bodies
        if len(body_str) > 2000:
            print(f"{body_str[:1000]}...\n... [truncated {len(body_str)-2000} chars] ...\n...{body_str[-1000:]}")
        else:
            print(body_str)
        print(f"--- END RAW BODY ---")


def _log_request_json_summary(body_str: str) -> None:
    """Print the top-level keys and a one-line preview per message, or the parse error."""
    try:
        parsed = json.loads(body_str)
        print(f"\n--- PARSED JSON STRUCTURE ---")
        print(f"Keys: {list(parsed.keys())}")
        if 'messages' in parsed and isinstance(parsed['messages'], list):
            print(f"Number of messages: {len(parsed['messages'])}")
            for i, msg in enumerate(parsed['messages']):
                role = msg.get('role', 'unknown')
                content_preview = str(msg.get('content', ''))[:50].replace('\n', ' ')
                print(f"  [{i}] {role}: {content_preview}...")
        print(f"--- END PARSED JSON ---")
    except json.JSONDecodeError as e:
        print(f"\n*** JSON Parse Error: {e} ***")
        print(f"Error at position: char {e.pos}, line {e.lineno}, column {e.colno}")
    except Exception as e:
        # e.g. the body is valid JSON but not an object of the expected shape
        print(f"\n*** Error analyzing JSON: {e} ***")


def _log_error_response_details(error_body: str) -> None:
    """Pretty-print the 'detail' entry of a JSON error body, if there is one."""
    try:
        error_json = json.loads(error_body)
        if 'detail' in error_json:
            print(f"\n*** VALIDATION ERROR DETAILS ***")
            detail = error_json['detail']
            if isinstance(detail, list):
                for err in detail:
                    if isinstance(err, dict):
                        loc = err.get('loc', [])
                        msg = err.get('msg', 'Unknown error')
                        err_type = err.get('type', 'unknown')
                        print(f"  - Location: {loc}")
                        print(f"    Message: {msg}")
                        print(f"    Type: {err_type}")
            else:
                print(f"  Detail: {detail}")
            print(f"*** END ERROR DETAILS ***")
    except Exception as parse_err:
        print(f"Could not parse error details: {parse_err}")


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log all incoming requests for debugging.

    Only the completion endpoints are logged. The request body is read here
    and the request is then rebuilt with a ``receive`` callable that replays
    the same bytes to the downstream handler. For responses with status >= 400
    the body is drained for logging and an equivalent response is returned.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response, or a rebuilt copy of it for error responses.

    Raises:
        Any exception from downstream processing, re-raised after it is logged.
    """
    should_log = request.url.path in _LOGGED_REQUEST_PATHS

    if should_log:
        body = b""
        try:
            body = await request.body()
        except Exception as e:
            # Handle ClientDisconnect and other exceptions gracefully;
            # continue with empty body if we couldn't read it
            print(f"Error logging request: {e}")
            body = b""
        else:
            # A logging failure must never discard a body that was read
            # successfully, so logging has its own handler and undecodable
            # bytes are replaced rather than raising.
            try:
                body_str = body.decode('utf-8', errors='replace')
                _log_request_dump(request, body, body_str)
                _log_request_json_summary(body_str)
            except Exception as e:
                print(f"Error logging request: {e}")

        # Re-create request with body for downstream handlers (only if we successfully read it)
        if body:
            async def receive():
                return {"type": "http.request", "body": body}
            request = Request(request.scope, receive, request._send)

    try:
        response = await call_next(request)
        if should_log:
            print(f"\n--- RESPONSE ---")
            print(f"Status Code: {response.status_code}")

            # For error responses, read and log the body, then create a new response
            if response.status_code >= 400:
                try:
                    chunks = []
                    async for chunk in response.body_iterator:
                        chunks.append(chunk)
                    response_body = b"".join(chunks)
                except Exception as read_err:
                    print(f"Could not read error response body: {read_err}")
                else:
                    # Lossy decode is for the log only; the client still gets
                    # the original bytes below.
                    error_body = response_body.decode('utf-8', errors='replace')
                    print(f"Error Response Body: {error_body}")
                    _log_error_response_details(error_body)

                    print(f"--- END RESPONSE ---")
                    print(f"{'='*60}\n")

                    # The original iterator is exhausted, so hand back a new
                    # response carrying the same body, status and headers.
                    from starlette.responses import Response
                    return Response(
                        content=response_body,
                        status_code=response.status_code,
                        headers=dict(response.headers),
                        media_type=response.media_type
                    )

            print(f"--- END RESPONSE ---")
            print(f"{'='*60}\n")
        return response
    except Exception as e:
        print(f"\n*** EXCEPTION DURING REQUEST PROCESSING ***")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        import traceback
        traceback.print_exc()
        print(f"*** END EXCEPTION ***\n")
        raise
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return every model known to the multi-model manager, wrapped in a ModelList."""
    return ModelList(data=multi_model_manager.list_models())

# =============================================================================
# Static File Serving Endpoint
# =============================================================================

@app.get("/v1/files/{filename}")
async def get_file(filename: str):
    """Serve generated files (images, audio) from the file path directory.

    Args:
        filename: Name of the file, taken from the URL, relative to
            ``global_file_path``.

    Returns:
        A FileResponse for the requested file.

    Raises:
        HTTPException: 404 when no file directory is configured, when the name
            resolves outside that directory, or when it is not an existing
            regular file.
    """
    if not global_file_path:
        raise HTTPException(status_code=404, detail="File path not configured")

    # filename comes straight from the client: normalise it and make sure the
    # result is still inside the configured directory ("..", absolute paths).
    base_dir = os.path.abspath(global_file_path)
    file_path = os.path.abspath(os.path.join(base_dir, filename))
    try:
        inside_base = os.path.commonpath([base_dir, file_path]) == base_dir
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives)
        inside_base = False

    # isfile, not exists: a directory cannot be sent as a file response
    if not inside_base or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(file_path)

# =============================================================================
# Audio Transcription Endpoint
# =============================================================================

from fastapi import UploadFile, File, Form

@app.post("/v1/audio/transcriptions")
async def create_transcription(
    model: str = Form(...),
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    prompt: Optional[str] = Form(None),
    response_format: Optional[str] = Form("json"),
    temperature: Optional[float] = Form(0.0),
):
    """
    Audio transcription endpoint (OpenAI-compatible).
    
    Supports:
    - OpenAI's whisper-1 model (via OpenAI API)
    - Local faster-whisper models (when --audio-model is specified)
    - whisper.cpp server (when --whisper-server is specified)
    """
    # Check if whisper-server is available FIRST (before checking audio_model)
    print(f"DEBUG: Audio request - whisper_server available: {multi_model_manager.whisper_server is not None}, running: {multi_model_manager.whisper_server.is_running() if multi_model_manager.whisper_server else 'N/A'}")
    if multi_model_manager.whisper_server and multi_model_manager.whisper_server.is_running():
        # Use whisper-server - read file and send to server
        file_content = await file.read()
        print(f"DEBUG: whisper-server transcription request - file_size={len(file_content)}, language={language}, prompt={prompt}")
        result = multi_model_manager.whisper_server.transcribe(
            file_content,
            language=language,
            prompt=prompt
        )
        print(f"DEBUG: whisper-server transcription result: {result}")
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])
        # Convert whisper-server response to OpenAI format
        text = result.get("text", "")
        return {
            "text": text
        }
    
    audio_model = multi_model_manager.audio_model
    
    # DEBUG: Print audio model status
    print(f"DEBUG: audio_model check - audio_models list: {multi_model_manager.audio_models}, audio_model: {audio_model}, whisper_server: {multi_model_manager.whisper_server}")
    
    # If no audio model configured, return an error
    if not audio_model:
        raise HTTPException(
            status_code=400,
            detail="Audio transcription not configured. Use --audio-model or --whisper-server to specify a model."
        )
    
    # Determine model to use - always use the configured audio model
    # The model parameter from the request is ignored in favor of the configured transcription model
    actual_model = audio_model  # This is the configured transcription model from --audio-model
    model_to_use = actual_model
    
    print(f"DEBUG: Transcription request model: {model}, using configured model: {actual_model}")
    
    # Check if Vulkan is available for whispercpp
    whisper_vulkan_available = False
    try:
        # Check if whispercpp is installed and has Vulkan support
        import whispercpp
        # Try to detect Vulkan support by checking if we can list devices
        # whispercpp doesn't have a direct Vulkan check, but we can verify by environment
        if os.environ.get('VK_DEVICE_SELECT_DEVICE'):
            whisper_vulkan_available = True
            print(f"Whisper Vulkan: Using configured Vulkan device")
        elif os.path.exists('/dev/dri'):  # Linux DRM devices exist = AMD/Intel GPU
            whisper_vulkan_available = True
            print(f"Whisper Vulkan: Auto-detected GPU")
    except ImportError:
        pass
    
    # Read file content
    file_content = await file.read()
    
    # Write to temp file
    import tempfile
    
    with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file.filename}") as tmp:
        tmp.write(file_content)
        tmp_path = tmp.name
    
    try:
        # Check if model is a GGUF file - faster-whisper doesn't support GGUF format
        is_gguf_model = model_to_use.endswith('.gguf') or 'gguf' in model_to_use.lower()
        
        if is_gguf_model:
            # Skip faster-whisper for GGUF files - go directly to whispercpp
            print("Detected GGUF model - using whispercpp backend")
            faster_whisper_failed = True
        else:
            # Try faster-whisper first
            faster_whisper_failed = False
            try:
                from faster_whisper import WhisperModel
                
                # Determine compute type based on GPU availability
                import torch
                if torch.cuda.is_available():
                    compute_type = "float16"
                else:
                    compute_type = "int8"
                
                # Try to load the model (lazy loading)
                model_key = f"audio:{model_to_use}"
                whisper_model = multi_model_manager.get_model(model_key)
                
                if whisper_model is None:
                    print(f"Loading faster-whisper model: {model_to_use}")
                    
                    # Check if model_to_use is a URL - download it (with caching)
                    model_path = None
                    if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                        # Check cache first
                        cached_path = get_cached_model_path(model_to_use)
                        if cached_path:
                            model_to_use = cached_path
                            print(f"Using cached model: {model_to_use}")
                        else:
                            print(f"Downloading model from URL: {model_to_use}")
                            try:
                                import requests
                                import hashlib
                                
                                # Get cache directory
                                cache_dir = get_model_cache_dir()
                                
                                # Extract filename from URL
                                url_path = model_to_use.split('?')[0]
                                filename = os.path.basename(url_path)
                                
                                if not filename.endswith('.bin') and not filename.endswith('.ggml'):
                                    filename = "whisper-model.bin"
                                
                                # Create safe filename in cache
                                url_hash = hashlib.sha256(model_to_use.encode()).hexdigest()
                                cached_filename = f"{url_hash}_{filename}"
                                model_path = os.path.join(cache_dir, cached_filename)
                                
                                # Download to cache
                                response = requests.get(model_to_use, stream=True)
                                response.raise_for_status()
                                
                                total_size = int(response.headers.get('content-length', 0))
                                downloaded = 0
                                
                                with open(model_path, 'wb') as f:
                                    for chunk in response.iter_content(chunk_size=8192*1024):
                                        if chunk:
                                            f.write(chunk)
                                            downloaded += len(chunk)
                                            if total_size > 0:
                                                percent = (downloaded / total_size) * 100
                                                print(f"Downloaded: {percent:.1f}%", end='\r')
                                
                                print(f"\nDownloaded and cached to: {model_path}")
                                model_to_use = model_path
                                
                            except Exception as e:
                                print(f"Error downloading model: {e}")
                                raise
                    
                    whisper_model = WhisperModel(
                        model_to_use,
                        device="cpu",  # faster-whisper CUDA doesn't work with AMD/Vulkan
                        compute_type=compute_type
                    )
                    # Store in multi_model_manager
                    multi_model_manager.add_model(model_key, whisper_model)
                
                # Run transcription
                segments, info = whisper_model.transcribe(
                    tmp_path,
                    language=language,
                    initial_prompt=prompt,
                    temperature=temperature or 0.0,
                )
                
                # Collect all segments
                text_parts = []
                for segment in segments:
                    text_parts.append(segment.text.strip())
                
                full_text = " ".join(text_parts)
                
                return {"text": full_text}
            
            except ImportError:
                # faster-whisper not available, will try whispercpp below
                faster_whisper_failed = True
            except Exception as e:
                # faster-whisper failed for some other reason
                print(f"Warning: faster-whisper failed to load model: {e}")
                faster_whisper_failed = True
        
        # If faster-whisper failed (not installed or couldn't load), try whispercpp
        if faster_whisper_failed:
            try:
                import whispercpp
                
                # Try to load the model (lazy loading)
                model_key = f"audio:{model_to_use}"
                whisper_model = multi_model_manager.get_model(model_key)
                
                if whisper_model is None:
                    print(f"Loading whispercpp model: {model_to_use}")
                    if whisper_vulkan_available:
                        print(f"  -> Using Vulkan GPU acceleration (device {whisper_vulkan_device})")
                    
                    # Check if model_to_use is a URL - download it (with caching)
                    model_path = None
                    if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                        # Check cache first
                        cached_path = get_cached_model_path(model_to_use)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached model: {model_path}")
                        else:
                            print(f"Downloading model from URL: {model_to_use}")
                            try:
                                import requests
                                import hashlib
                                
                                # Get cache directory
                                cache_dir = get_model_cache_dir()
                                
                                # Extract filename from URL
                                url_path = model_to_use.split('?')[0]
                                filename = os.path.basename(url_path)
                                
                                if not filename.endswith('.gguf'):
                                    filename = "whisper-model.gguf"
                                
                                # Create safe filename in cache
                                url_hash = hashlib.sha256(model_to_use.encode()).hexdigest()
                                cached_filename = f"{url_hash}_{filename}"
                                model_path = os.path.join(cache_dir, cached_filename)
                                
                                # Download to cache
                                response = requests.get(model_to_use, stream=True)
                                response.raise_for_status()
                                
                                total_size = int(response.headers.get('content-length', 0))
                                downloaded = 0
                                
                                with open(model_path, 'wb') as f:
                                    for chunk in response.iter_content(chunk_size=8192*1024):
                                        if chunk:
                                            f.write(chunk)
                                            downloaded += len(chunk)
                                            if total_size > 0:
                                                percent = (downloaded / total_size) * 100
                                                print(f"Downloaded: {percent:.1f}%", end='\r')
                                
                                print(f"\nDownloaded and cached to: {model_path}")
                                model_to_use = model_path
                                
                            except Exception as e:
                                print(f"Error downloading model: {e}")
                                raise
                    
                    # whispercpp needs a local file path
                    if not model_path:
                        model_path = model_to_use if os.path.isfile(model_to_use) else None
                    
                    if not model_path or not os.path.isfile(model_path):
                        raise HTTPException(
                            status_code=400,
                            detail="whispercpp requires a local GGUF file path. Cannot use URLs directly."
                        )
                    
                    # Load the whispercpp model
                    # Note: whispercpp uses model files directly, not paths like Llama
                    # whispercpp only supports:
                    # 1. Built-in model names (tiny, base, small, medium, large-v1, large)
                    # 2. Pre-converted GGUF files in whisper.cpp format (NOT HuggingFace GGUF)
                    try:
                        whisper_model = whispercpp.Whisper.from_pretrained(model_path)
                    except Exception as e:
                        error_msg = str(e).lower()
                        if 'not a valid preconverted model' in error_msg:
                            # This is expected for HuggingFace GGUF files
                            print(f"Warning: whispercpp does not support HuggingFace GGUF format")
                            print("whispercpp only supports its own pre-converted models or built-in names.")
                            print("For Vulkan audio transcription, please either:")
                            print("  1. Install PyTorch + faster-whisper: pip install torch faster-whisper")
                            print("  2. Use a built-in whispercpp model: --audio-model base")
                            raise HTTPException(
                                status_code=400,
                                detail="whispercpp does not support HuggingFace GGUF Whisper models. Use --audio-model with a built-in name (tiny/base/small/medium/large-v1/large) or install faster-whisper with PyTorch."
                            )
                        else:
                            raise
                    
                    # Store in multi_model_manager
                    multi_model_manager.add_model(model_key, whisper_model)
                
                # Run transcription
                # whispercpp returns text directly
                result = whisper_model.transcribe(tmp_path)
                
                # Collect all segments
                text_parts = []
                for segment in result:
                    text_parts.append(str(segment).strip())
                
                full_text = " ".join(text_parts) if text_parts else ""
                
                return {"text": full_text}
            
            except ImportError as e:
                # Check if it's a specific error about whispercpp not working
                error_msg = str(e).lower()
                if 'invalid elf' in error_msg or 'mach-o' in error_msg:
                    # whispercpp library failed to load - architecture mismatch
                    print(f"Warning: whispercpp library failed to load: {e}")
                    print("This usually means whispercpp was installed for a different OS/architecture.")
                    print("Try reinstalling: pip install whispercpp --force-reinstall")
                    print("Audio model will load on-demand when transcription is requested.")
                else:
                    # Neither faster-whisper nor whispercpp available
                    print(f"Warning: No audio transcription library available: {e}")
                    print("Options:")
                    print("  1. Install PyTorch + faster-whisper: pip install torch faster-whisper")
                    print("  2. Use a built-in whispercpp model: --audio-model base")
                    print("  3. Use --whisper-cpp to specify whisper.cpp CLI path")
                    print("Audio model will load on-demand when transcription is requested.")
                
                # Try whisper.cpp CLI as fallback if specified
                whisper_cpp_path = getattr(global_args, 'whisper_cpp', None)
                if whisper_cpp_path and os.path.isfile(whisper_cpp_path):
                    print(f"Using whisper.cpp CLI: {whisper_cpp_path}")
                    try:
                        import subprocess
                        
                        # Determine the model path - check if it's already cached or needs downloading
                        model_path = None
                        
                        # First check if it's a local file
                        if os.path.isfile(model_to_use):
                            model_path = model_to_use
                            print(f"DEBUG: Using local model file: {model_path}")
                        else:
                            # Check cache for downloaded model
                            cached = get_cached_model_path(model_to_use)
                            if cached and os.path.isfile(cached):
                                model_path = cached
                                print(f"DEBUG: Using cached model: {model_path}")
                            else:
                                # Download the model if not cached
                                print(f"DEBUG: Model not cached, downloading: {model_to_use}")
                                cache_dir = get_model_cache_dir()
                                model_path = download_model(model_to_use, cache_dir)
                                print(f"DEBUG: Downloaded model to: {model_path}")
                        
                        print(f"DEBUG: Whisper model: {model_to_use}")
                        print(f"DEBUG: Whisper model path (resolved): {model_path}")
                        
                        # Build whisper.cpp CLI command
                        # Usage: whisper-cli [options] file0 file1 ...
                        # Options:
                        #   -m, --model FNAME    model path
                        #   -f, --file FNAME    input audio file
                        #   -of, --output-file  output file (without extension)
                        #   -dev, --device N   GPU device ID
                        #   -otxt, --output-txt output as text file
                        cmd = [whisper_cpp_path]
                        if model_path:
                            cmd.extend(["-m", model_path])
                        cmd.extend(["-f", tmp_path])
                        cmd.extend(["-otxt"])  # Output as text
                        
                        # Add Vulkan device if specified
                        audio_vulkan_device = getattr(global_args, 'audio_vulkan_device', 0)
                        if audio_vulkan_device is not None:
                            cmd.extend(["-dev", str(audio_vulkan_device)])
                        
                        print(f"DEBUG: Running whisper.cpp command: {' '.join(cmd)}")
                        
                        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
                        
                        if result.returncode == 0:
                            # Read output - whisper.cpp -otxt outputs to stdout or a file
                            # With -otxt flag, it outputs the transcribed text
                            if result.stdout:
                                full_text = result.stdout
                            else:
                                # Try to read from a file with same name as input but .txt extension
                                output_txt = tmp_path + ".txt"
                                if os.path.exists(output_txt):
                                    with open(output_txt, 'r') as f:
                                        full_text = f.read()
                                    os.unlink(output_txt)
                                else:
                                    full_text = ""
                            return {"text": full_text}
                        else:
                            print(f"whisper.cpp CLI error: {result.stderr}")
                    except Exception as subprocess_error:
                        print(f"whisper.cpp CLI subprocess error: {subprocess_error}")
                
                # Return error response
                raise HTTPException(
                    status_code=501,
                    detail="Audio transcription not available. Install faster-whisper (requires PyTorch) or use --whisper-cpp to specify whisper.cpp CLI path."
                )
        
    finally:
        # Cleanup temp file
        os.unlink(tmp_path)
# =============================================================================
# Image Generation Endpoint
# =============================================================================

# Accessor for the global load_mode tracker (the tracker itself is set in main())
def get_load_mode():
    """Return the current load mode, or "ondemand" when none has been recorded."""
    try:
        return load_mode["mode"]
    except KeyError:
        # No "mode" entry present: fall back to the on-demand default.
        return "ondemand"

# Helper function to get CFG scale for image generation
def get_cfg_scale():
    """Get CFG scale for image generation, probing Vulkan VRAM on the default.

    Reads ``image_cfg_scale`` from ``global_args`` (1.0 when the attribute is
    absent). Any value other than 1.0 is returned untouched. When the value is
    1.0 and a Vulkan backend may be in use, ``vulkaninfo -J`` is run with a
    5 second timeout and the size of each device-local memory heap is logged.

    NOTE(review): every path reached from the 1.0 case also returns 1.0, so
    the probe currently influences only the DEBUG output, not the result.

    Returns:
        The configured CFG scale, or 1.0 on the low-VRAM and failed-probe
        paths.
    """
    cfg_scale = getattr(global_args, 'image_cfg_scale', 1.0)

    # A non-default value is treated as an explicit choice; skip detection.
    if cfg_scale != 1.0:
        return cfg_scale

    backend = getattr(global_args, 'backend', 'auto')
    image_backend = getattr(global_args, 'image_backend', 'auto')

    # Vulkan is assumed when requested globally, requested for images, or when
    # both selectors are left on 'auto'.
    use_vulkan = (
        backend == 'vulkan'
        or image_backend == 'vulkan'
        or (image_backend == 'auto' and backend == 'auto')
    )
    if not use_vulkan:
        return cfg_scale

    try:
        import json
        import subprocess

        result = subprocess.run(['vulkaninfo', '-J'], capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            data = json.loads(result.stdout)
            for dev in data.get('devices', []):
                for heap in dev.get('deviceMemoryHeap', [{}]):
                    # A heap without a 'flags' mapping cannot be identified as
                    # device-local. Skip it rather than calling .get() on a
                    # non-dict value, which used to abort the whole probe.
                    flags = heap.get('flags')
                    if not isinstance(flags, dict) or not flags.get('deviceLocal', False):
                        continue
                    vram_mb = heap.get('size', 0) / (1024 * 1024)
                    print(f"DEBUG: Detected VRAM: {vram_mb:.0f} MB")
                    if vram_mb < 16000:  # Less than 16GB
                        print(f"DEBUG: VRAM < 16GB, using cfg_scale=1.0 for better performance")
                        return 1.0
                    # Only the first device-local heap of each device is inspected.
                    break
    except Exception as e:
        # Best-effort probe: a missing vulkaninfo binary, a timeout or
        # unparsable output all end up here.
        print(f"DEBUG: Could not detect VRAM: {e}")
        # Default to 1.0 for Vulkan if detection fails
        return 1.0

    return cfg_scale

# Helper function to save generated images and return response dict
def _resolve_image_base_url(http_request=None):
    """Build the base URL under which saved images are exposed.

    When ``global_args.url`` is set to anything other than 'auto' it is used
    as-is with trailing slashes stripped. In 'auto' mode the URL is derived
    from the request's ``Host`` header (or the client address when the header
    is empty), the ``port`` attribute of ``global_args`` (default 8000) and
    whether ``https`` or ``pubkey`` is set on ``global_args``. Without a
    request object the fixed fallback ``http://127.0.0.1:8000`` is returned.
    """
    url_setting = getattr(global_args, 'url', 'auto') if global_args else 'auto'
    if url_setting != 'auto':
        # Use explicitly provided URL (strip trailing slash if present)
        return url_setting.rstrip('/')
    if not http_request:
        return "http://127.0.0.1:8000"

    # The Host header is what the client used to reach the server.
    client_host = http_request.headers.get('host', '')
    if not client_host:
        # Fallback to client IP if no Host header
        client_host = http_request.client.host if http_request.client else '127.0.0.1'

    # Drop any port from the host so the configured port can be appended.
    if client_host.startswith('['):
        # Bracketed IPv6 literal, e.g. "[::1]:8000" -> "[::1]".
        closing = client_host.find(']')
        if closing != -1:
            client_host = client_host[:closing + 1]
    elif client_host.count(':') == 1:
        # "hostname:port" or "ipv4:port".
        client_host = client_host.split(':', 1)[0]
    elif ':' in client_host:
        # Bare IPv6 literal: bracket it so "host:port" stays unambiguous.
        client_host = f"[{client_host}]"

    use_https = getattr(global_args, 'https', False) or getattr(global_args, 'pubkey', None)
    protocol = "https" if use_https else "http"
    port = getattr(global_args, 'port', 8000)
    return f"{protocol}://{client_host}:{port}"


def save_image_response(img, request_format="base64", http_request=None):
    """
    Save image to file path if configured, return response dict.

    If --file-path is set and request_format is url (not base64), return only URL.
    If --file-path is set and request_format is base64, return both URL and base64.
    If --file-path is not set, return base64 as usual.

    Args:
        img: A PIL image, or any object accepted by ``PIL.Image.fromarray``.
        request_format: "base64" to include the ``b64_json`` field even when
            the image is also written to disk.
        http_request: Optional request object used to derive the public URL
            when the ``url`` setting is 'auto'.

    Returns:
        A dict with ``url`` and/or ``b64_json`` keys.
    """
    import base64
    import io
    import os
    import uuid
    from PIL import Image

    # Convert to PIL Image if needed
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)

    # Encode once; the same PNG bytes serve both the file and the base64 field.
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    png_bytes = buffered.getvalue()

    if not global_file_path:
        # No file-path, return base64 as usual
        return {"b64_json": base64.b64encode(png_bytes).decode('utf-8')}

    os.makedirs(global_file_path, exist_ok=True)
    # Generate unique filename
    filename = f"{uuid.uuid4().hex}.png"
    with open(os.path.join(global_file_path, filename), 'wb') as out_file:
        out_file.write(png_bytes)

    result = {"url": f"{_resolve_image_base_url(http_request)}/v1/files/{filename}"}

    # If client explicitly requested base64, include it.
    # Otherwise, only return URL when file-path is set.
    if request_format == "base64":
        result["b64_json"] = base64.b64encode(png_bytes).decode('utf-8')

    return result

@app.post("/v1/images/generations")
async def create_image_generation(request: ImageGenerationRequest, http_request: Request = None):
    """
    Image generation endpoint (OpenAI-compatible).
    
    Supports:
    - Stable Diffusion via stable-diffusion-cpp-python (sd.cpp)
    - Stable Diffusion XL (via local inference with diffusers)
    - Other diffusers models
    """
    # Get or create semaphore for this model
    model_key = f"image:{request.model}" if request.model else "image"
    mode = get_load_mode()
    
    # Check if --image-1 is set (no queue, return 409 if busy)
    use_1_mode = queue_flags.get("image_1", False)
    
    # In loadall mode, allow 1 concurrent request per model
    # In ondemand mode, serialize all requests (use global semaphore)
    if mode == "loadall":
        if model_key not in model_semaphores:
            model_semaphores[model_key] = asyncio.Semaphore(1)
        semaphore = model_semaphores[model_key]
    else:
        # Use a global semaphore for ondemand mode
        if "global_image" not in model_semaphores:
            model_semaphores["global_image"] = asyncio.Semaphore(1)
        semaphore = model_semaphores["global_image"]
    
    # Try to acquire semaphore without blocking
    if use_1_mode:
        acquired = semaphore.locked()
        if acquired:
            raise HTTPException(
                status_code=409,
                detail="Image model is busy. Try again later."
            )
    
    async with semaphore:
        image_model = multi_model_manager.image_model
    
    # If no image model configured, try to use main --model as fallback
    if not image_model:
        # Try to get the main model from args
        main_model = getattr(global_args, 'model', None)
        if main_model and isinstance(main_model, list) and len(main_model) > 0:
            image_model = main_model[0]
        elif main_model:
            image_model = main_model
        
        # Check if main model is a GGUF file - can't use for image generation
        if image_model and ('.gguf' in image_model.lower() or 'gguf' in image_model.lower()):
            print(f"Note: Main model is a GGUF file (for text), not suitable for image generation")
            image_model = None  # Can't use GGUF for images
    
    # If still no image model configured, return an error
    if not image_model:
        raise HTTPException(
            status_code=400,
            detail="Image generation not configured. Use --image-model to specify a model."
        )
    
    # Determine model to use
    # Priority: 1) model specified in request, 2) default image model from --image-model
    model_to_use = request.model
    if not model_to_use or model_to_use == "image":
        # No model specified in request, use default
        model_to_use = image_model
    elif model_to_use.startswith("image:"):
        # Legacy format - strip prefix and use default
        model_to_use = image_model
    else:
        # Check if model_to_use is a valid model (URL, file, or known model)
        # If not, fallback to the configured image model to avoid HF resolution errors
        if image_model:
            is_url = model_to_use.startswith('http://') or model_to_use.startswith('https://')
            is_file = os.path.isfile(model_to_use) if model_to_use else False
            if not is_url and not is_file:
                # Unknown model name - use default instead of trying to resolve as HF
                print(f"Warning: Unknown model '{model_to_use}' in image generation request, using configured --image-model")
                model_to_use = image_model
    
    # Track errors for proper fallback chain
    diffusers_error = None
    sd_cpp_error = None
    
    # Parse size (e.g., "1024x1024")
    width, height = 1024, 1024
    if request.size:
        parts = request.size.split("x")
        if len(parts) == 2:
            try:
                width = int(parts[0])
                height = int(parts[1])
            except ValueError:
                pass
    
    # Try diffusers first (torch-based, best quality for NVIDIA)
    # Skip if it's a GGUF model (those need stable-diffusion-cpp)
    # First, cleanup any other models to free VRAM
    for key in list(multi_model_manager.models.keys()):
        # Skip image models
        if key.startswith("image:"):
            continue
        # Unload any other model (text, audio, etc.) to free VRAM
        model_to_cleanup = multi_model_manager.models.get(key)
        if model_to_cleanup is not None:
            print(f"Unloading '{key}' from VRAM to make room for diffusers image model")
            try:
                if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
                    model_to_cleanup.cleanup()
                elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
                    if hasattr(model_to_cleanup.model, 'cleanup'):
                        model_to_cleanup.model.cleanup()
            except Exception as e:
                print(f"Warning during cleanup of '{key}': {e}")
            del multi_model_manager.models[key]
    
    try:
        import torch
        from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
        
        # Check if this is a GGUF model - skip diffusers for those
        is_gguf_model = (model_to_use.endswith('.gguf') or 'gguf' in model_to_use.lower() or
                        (model_to_use.startswith('http') and '.gguf' in model_to_use))
        
        if is_gguf_model:
            print(f"GGUF model detected ({model_to_use}), skipping diffusers, using stable-diffusion-cpp...")
            raise Exception("GGUF model - use stable-diffusion-cpp instead")
        
        # Determine model key
        model_key = f"image:{model_to_use}"
        pipeline = multi_model_manager.get_model(model_key)
        
        if pipeline is None:
            print(f"Loading Stable Diffusion model: {model_to_use}")
            
            # Determine precision from CLI argument
            precision = getattr(global_args, 'image_precision', 'f32') or 'f32'
            precision_map = {
                'bf16': torch.bfloat16,
                'f32': torch.float32,
                'f16': torch.float16,
                'f8': torch.float8_e4m3fn,
            }
            torch_dtype = precision_map.get(precision, torch.float32)
            print(f"Using precision: {precision} ({torch_dtype})")
            
            # Check if CPU offload is requested via CLI
            use_sequential_offload = getattr(global_args, 'image_cpu_offload', False)
            
            # Track loading attempts for OOM handling
            load_attempt = 0
            max_attempts = 3
            pipeline = None
            
            while pipeline is None and load_attempt < max_attempts:
                try:
                    load_attempt += 1
                    print(f"Loading attempt {load_attempt}/{max_attempts}...")
                    
                    # Try to load as Stable Diffusion XL first
                    try:
                        pipeline = StableDiffusionXLPipeline.from_pretrained(
                            model_to_use,
                            torch_dtype=torch_dtype,
                            use_safetensors=True,
                        )
                    except Exception:
                        # Try generic diffusion pipeline
                        pipeline = DiffusionPipeline.from_pretrained(
                            model_to_use,
                            torch_dtype=torch_dtype,
                            use_safetensors=True,
                        )
                    
                    # Apply memory optimizations based on attempt
                    if torch.cuda.is_available():
                        if load_attempt >= 2:
                            # Second attempt: enable attention slicing
                            print("Enabling attention slicing for lower VRAM usage...")
                            pipeline.enable_attention_slicing()
                        
                        if load_attempt >= 3 or use_sequential_offload:
                            # Third attempt or offload requested: enable sequential CPU offload
                            print("Enabling sequential CPU offload for lower VRAM usage...")
                            pipeline.enable_sequential_cpu_offload()
                        else:
                            # First attempt: try regular GPU
                            pipeline = pipeline.to("cuda")
                    else:
                        pipeline = pipeline.to("cpu")
                    
                except Exception as load_error:
                    error_msg = str(load_error).lower()
                    is_oom = any(x in error_msg for x in ['out of memory', 'oom', 'cuda error', 'cudamalloc'])
                    
                    if is_oom and load_attempt < max_attempts:
                        print(f"OOM during model loading: {load_error}")
                        print(f"Retrying with more aggressive memory optimization...")
                        pipeline = None  # Reset for retry
                        # Clear CUDA cache
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                    else:
                        raise load_error
            
            # Enable VAE tiling if requested (for lower VRAM usage)
            if getattr(global_args, 'vae_tiling', False):
                print("Enabling VAE tiling for lower VRAM usage...")
                try:
                    pipeline.enable_vae_tiling()
                except Exception as e:
                    print(f"Warning: Could not enable VAE tiling: {e}")
            
            multi_model_manager.add_model(model_key, pipeline)
        
        # Get timestamp BEFORE calling diffusers (to avoid scope issues)
        import time as time_module
        timestamp = int(time_module.time())
        
        # Generate images
        # Use request seed if provided, otherwise use CLI default seed
        seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None)
        generator = None
        if seed is not None:
            generator = torch.Generator(device=pipeline.device).manual_seed(seed)
        
        # Quality: "standard" or "hd"
        quality = request.quality or "standard"
        
        # Use request parameters if provided, otherwise fall back to quality-based defaults
        num_steps = request.steps if request.steps else (30 if quality == "standard" else 50)
        cfg_scale = request.guidance_scale if request.guidance_scale else (
            getattr(global_args, 'image_cfg_scale', 7.5) if quality == "standard" else 9.0
        )
        
        # Generate
        result = pipeline(
            prompt=request.prompt,
            negative_prompt=None,
            num_images_per_prompt=request.n,
            height=height,
            width=width,
            generator=generator,
            guidance_scale=cfg_scale,
            num_inference_steps=num_steps,
        )
        
        # Extract images
        images = []
        try:
            result_images = result.images
        except Exception as img_err:
            print(f"Warning: Could not access result.images: {img_err}")
            # Try alternative: result might have 'image' or 'output'
            result_images = getattr(result, 'image', None) or getattr(result, 'output', None)
            if result_images is None:
                raise Exception(f"Could not extract images from diffusers result: {img_err}")
        
        for img in result_images:
            # Convert to base64
            import base64
            import io
            import numpy as np
            
            # Handle NaN/Inf values in image data - convert to valid values
            if isinstance(img, np.ndarray):
                # Replace NaN and Inf with valid values
                img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)
                # Clip to valid range [0, 1]
                img = np.clip(img, 0.0, 1.0)
            
            # Use helper function to save and get response
            img_data = save_image_response(img, request.response_format, http_request)
            images.append(img_data)
        
        return {
            "created": timestamp,
            "data": images
        }
        
    except ImportError as e:
        # diffusers/torch not installed - record error and try sd.cpp
        diffusers_error = str(e)
        print(f"diffusers not available: {diffusers_error}, trying stable-diffusion-cpp-python...")
    except Exception as e:
        # Other error with diffusers - record and try sd.cpp
        import traceback
        diffusers_error = str(e)
        print(f"diffusers error: {diffusers_error}")
        print(f"Traceback: {traceback.format_exc()}")
        print(f"Trying stable-diffusion-cpp-python...")
    
    # Try stable-diffusion-cpp-python (sd.cpp) as fallback
    # First, check all available image models to find one loaded via sd.cpp
    # Always check for cached models - allows dynamically loaded models to be reused across requests
    sd_model = None
    for key in multi_model_manager.models:
        if key.startswith("image:"):
            potential_model = multi_model_manager.get_model(key)
            if potential_model is not None:
                # Check if it's a stable-diffusion-cpp model
                try:
                    from stable_diffusion_cpp import StableDiffusion
                    if isinstance(potential_model, StableDiffusion):
                        sd_model = potential_model
                        print(f"Found cached stable-diffusion-cpp model with key: {key}")
                        break
                except ImportError:
                    pass
    
    # If no cached image model found, need to load one - first cleanup any existing models
    if sd_model is None:
        # Check if there's a text model loaded and unload it to free VRAM
        # Cleanup ALL models except the one we're about to load
        for key in list(multi_model_manager.models.keys()):
            # Skip the image model we'll be loading (if we find it later)
            # For now, cleanup all other models
            if key.startswith("image:"):
                continue
            # Unload any other model (text, audio, etc.) to free VRAM
            model_to_cleanup = multi_model_manager.models.get(key)
            if model_to_cleanup is not None:
                print(f"Unloading '{key}' from VRAM to make room for image model")
                try:
                    if hasattr(model_to_cleanup, 'cleanup') and callable(getattr(model_to_cleanup, 'cleanup')):
                        model_to_cleanup.cleanup()
                    elif hasattr(model_to_cleanup, 'model') and model_to_cleanup.model is not None:
                        if hasattr(model_to_cleanup.model, 'cleanup'):
                            model_to_cleanup.model.cleanup()
                except Exception as e:
                    print(f"Warning during cleanup of '{key}': {e}")
                del multi_model_manager.models[key]
    
    if sd_model is not None:
        # Check if it's a stable-diffusion-cpp model (has generate method from sd.cpp)
        try:
            from stable_diffusion_cpp import StableDiffusion
            if isinstance(sd_model, StableDiffusion):
                print(f"Using stable-diffusion-cpp-python for image generation")
                # Use sd.cpp for generation
                # Parse size
                width, height = 512, 512
                if request.size:
                    parts = request.size.split("x")
                    if len(parts) == 2:
                        try:
                            width = int(parts[0])
                            height = int(parts[1])
                        except ValueError:
                            pass
                
                # Use default steps for Z-Image Turbo (very fast)
                steps = 4  # Default for fast generation
                
                # Generate images using sd.cpp (run in thread to not block event loop)
                # Use request seed if provided, otherwise use CLI default seed
                seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None)
                
                result = await asyncio.to_thread(
                    sd_model.generate_image,
                    prompt=request.prompt,
                    negative_prompt='',
                    width=width,
                    height=height,
                    cfg_scale=get_cfg_scale(),
                    sample_steps=steps,
                    seed=seed if seed is not None else 42,
                    batch_count=request.n if request.n else 1,
                )
                
                # Small delay to let Vulkan driver settle after generation
                import time
                time.sleep(0.1)
                
                # Convert results to response format
                images = []
                import base64
                import io
                from PIL import Image
                
                for img in result:
                    # Use helper function to save and get response
                    img_data = save_image_response(img, http_request=http_request)
                    images.append(img_data)
                
                return {
                    "created": int(time.time()),
                    "data": images
                }
        except ImportError as e:
            # stable-diffusion-cpp not available
            sd_cpp_error = str(e)
            print(f"stable-diffusion-cpp-python not available: {sd_cpp_error}")
        except Exception as e:
            print(f"sd.cpp generation error: {e}")
            sd_cpp_error = str(e)
    else:
        # No sd.cpp model pre-loaded, try to load dynamically
        print("No pre-loaded sd.cpp model found, trying to load...")
        try:
            from stable_diffusion_cpp import StableDiffusion
            
            # Check if model_to_use is a URL and get cached path
            # Also handle HuggingFace model IDs that need to be resolved
            model_path = None
            if model_to_use.startswith('http://') or model_to_use.startswith('https://'):
                cached_path = get_cached_model_path(model_to_use)
                if cached_path:
                    model_path = cached_path
                    print(f"Using cached model: {model_path}")
                else:
                    # Not cached - download it
                    print(f"Downloading model: {model_to_use}")
                    cache_dir = get_model_cache_dir()
                    model_path = download_model(model_to_use, cache_dir)
                    print(f"Downloaded to: {model_path}")
            elif os.path.isfile(model_to_use):
                model_path = model_to_use
            else:
                # Try to resolve as HuggingFace model ID
                print(f"Trying to resolve as HuggingFace model ID: {model_to_use}")
                try:
                    from huggingface_hub import hf_hub_download, list_repo_files
                    
                    # Parse model name (format: "org/model" or "org/model/filename.gguf")
                    parts = model_to_use.split('/')
                    if len(parts) >= 2:
                        repo_id = f"{parts[0]}/{parts[1]}"
                        
                        # First check if there's a cached GGUF file for this model
                        # Try common GGUF file patterns
                        files = list_repo_files(repo_id)
                        gguf_files = [f for f in files if f.endswith('.gguf')]
                        
                        if gguf_files:
                            # Try to find a cached version first
                            for gguf_file in gguf_files:
                                # Construct potential URL and check cache
                                potential_url = f"https://huggingface.co/{repo_id}/resolve/main/{gguf_file}"
                                cached = get_cached_model_path(potential_url)
                                if cached:
                                    model_path = cached
                                    print(f"Using cached GGUF model: {model_path}")
                                    break
                            
                            # If not cached, download the first GGUF file
                            if not model_path:
                                print(f"Downloading GGUF model from HF: {gguf_files[0]}")
                                model_path = hf_hub_download(repo_id=repo_id, filename=gguf_files[0])
                                print(f"Downloaded to: {model_path}")
                except Exception as e:
                    print(f"Could not resolve as HuggingFace model: {e}")
            
            if model_path is None:
                print("Warning: Could not resolve sd.cpp model path")
                sd_cpp_error = "Could not resolve model path"
            else:
                # Load sd.cpp model
                # Determine backend to use based on CLI args
                backend = getattr(global_args, 'backend', 'auto')
                image_backend = getattr(global_args, 'image_backend', 'auto')
                
                # Use CUDA only if explicitly requested via --backend nvidia or --image-backend nvidia
                use_cuda = (backend == 'nvidia' or backend == 'cuda' or 
                           image_backend == 'nvidia' or image_backend == 'cuda')
                
                if use_cuda:
                    print(f"Using CUDA backend for sd.cpp image generation")
                else:
                    print(f"Using Vulkan backend for sd.cpp image generation")
                
                # Build kwargs for stable-diffusion-cpp with CLI args
                sd_kwargs = {'diffusion_model_path': model_path}
                
                # Add VAE path from CLI args if provided
                vae_path = getattr(global_args, 'vae_path', None)
                if vae_path:
                    # Check if it's a URL and download if needed
                    if vae_path.startswith('http://') or vae_path.startswith('https://'):
                        cached = get_cached_model_path(vae_path)
                        if cached:
                            sd_kwargs['vae_path'] = cached
                            print(f"Using cached VAE model: {cached}")
                        else:
                            cache_dir = get_model_cache_dir()
                            sd_kwargs['vae_path'] = download_model(vae_path, cache_dir)
                    else:
                        sd_kwargs['vae_path'] = vae_path
                
                # Add LLM/CLIP path from CLI args if provided
                llm_path = getattr(global_args, 'llm_path', None)
                if llm_path:
                    if llm_path.startswith('http://') or llm_path.startswith('https://'):
                        cached = get_cached_model_path(llm_path)
                        if cached:
                            sd_kwargs['llm_path'] = cached
                            print(f"Using cached LLM model: {cached}")
                        else:
                            cache_dir = get_model_cache_dir()
                            sd_kwargs['llm_path'] = download_model(llm_path, cache_dir)
                    else:
                        sd_kwargs['llm_path'] = llm_path
                
                # Add T5XXL path from CLI args if provided
                t5xxl_path = getattr(global_args, 't5xxl_path', None)
                if t5xxl_path:
                    if t5xxl_path.startswith('http://') or t5xxl_path.startswith('https://'):
                        cached = get_cached_model_path(t5xxl_path)
                        if cached:
                            sd_kwargs['t5xxl_path'] = cached
                            print(f"Using cached T5XXL model: {cached}")
                        else:
                            cache_dir = get_model_cache_dir()
                            sd_kwargs['t5xxl_path'] = download_model(t5xxl_path, cache_dir)
                    else:
                        sd_kwargs['t5xxl_path'] = t5xxl_path
                
                # Add clip_on_cpu if specified
                if getattr(global_args, 'clip_on_cpu', False):
                    sd_kwargs['keep_clip_on_cpu'] = True
                    print(f"DEBUG: Running CLIP on CPU to save VRAM (keep_clip_on_cpu=True)")
                
                # Use all available CPU cores
                import psutil
                sd_kwargs['n_threads'] = psutil.cpu_count()
                
                sd_model = StableDiffusion(**sd_kwargs)
                
                # Cache the model for reuse on subsequent requests
                cache_key = f"image:{model_path}"
                multi_model_manager.add_model(cache_key, sd_model)
                print(f"Using stable-diffusion-cpp-python for image generation")
                
                # Generate images
                width, height = 512, 512
                if request.size:
                    parts = request.size.split("x")
                    if len(parts) == 2:
                        try:
                            width = int(parts[0])
                            height = int(parts[1])
                        except ValueError:
                            pass
                
                steps = 4
                
                # Use request seed if provided, otherwise use CLI default seed
                seed = request.seed if request.seed is not None else getattr(global_args, 'image_seed', None)
                
                result = await asyncio.to_thread(
                    sd_model.generate_image,
                    prompt=request.prompt,
                    negative_prompt='',
                    width=width,
                    height=height,
                    cfg_scale=get_cfg_scale(),
                    sample_steps=steps,
                    seed=seed if seed is not None else 42,
                    batch_count=request.n if request.n else 1,
                )
                
                # Small delay to let Vulkan driver settle after generation
                import time
                time.sleep(0.1)
                
                # Convert results to response format
                images = []
                import base64
                import io
                from PIL import Image
                
                for img in result:
                    # Use helper function to save and get response
                    img_data = save_image_response(img, http_request=http_request)
                    images.append(img_data)
                
                return {
                    "created": int(time.time()),
                    "data": images
                }
        except ImportError as e:
            sd_cpp_error = str(e)
            print(f"stable-diffusion-cpp-python not available: {sd_cpp_error}")
        except Exception as e:
            sd_cpp_error = str(e)
            print(f"sd.cpp error: {sd_cpp_error}")
    
    # Both backends failed - return error with installation instructions
    raise HTTPException(
        status_code=400,
        detail=f"Model '{model_to_use}' does not support image generation"
    )
# =============================================================================
# Text-to-Speech Endpoint
# =============================================================================

class TTSRequest(BaseModel):
    """Request body for the text-to-speech endpoint (POST /v1/audio/speech)."""
    # Requested model name; create_speech substitutes the configured TTS
    # model when this value starts with "tts:".
    model: str
    # Text to synthesize.
    input: str
    # Voice identifier passed to the TTS backend; "af_sarah" is the default.
    voice: Optional[str] = "af_sarah"
    # Requested output format. NOTE(review): create_speech does not appear to
    # read this field - confirm whether it is honoured anywhere else.
    response_format: Optional[str] = "mp3"
    # Speed value forwarded to the TTS backend.
    speed: Optional[float] = 1.0
    
    # Accept (rather than reject) fields that are not declared above.
    model_config = ConfigDict(extra="allow")
class TTSResponse(BaseModel):
    """Response shape for text-to-speech: a single base64-encoded audio payload."""
    audio: str  # base64 encoded audio
    # Accept (rather than reject) fields that are not declared above.
    model_config = ConfigDict(extra="allow")
@app.post("/v1/audio/speech")
async def create_speech(request: TTSRequest):
    """
    Text-to-speech endpoint (OpenAI-compatible).
    
    Supports:
    - Kokoro TTS models (when --tts-model is specified)
    
    The Kokoro model is loaded on first use and stored in
    ``multi_model_manager`` under the key ``tts:<model>``. When the model is
    given as an http(s) URL it is downloaded into the model cache directory.
    The download is written to a temporary file and moved into place only
    after it completes, so an interrupted transfer never leaves a truncated
    file at the final cache path.
    
    Args:
        request: Parsed request body (model, input text, voice, speed).
    
    Returns:
        A dict ``{"audio": <base64 string>}``.
    
    Raises:
        HTTPException: 400 when no TTS model is configured, 501 when an
            ImportError occurs (kokoro not installed), 500 for any other
            failure during download, loading or generation.
    """
    tts_model = multi_model_manager.tts_model
    
    # If no TTS model configured, return an error
    if not tts_model:
        raise HTTPException(
            status_code=400,
            detail="TTS not configured. Use --tts-model to specify a model."
        )
    
    # Determine model to use: a "tts:"-prefixed name maps to the configured model
    model_to_use = request.model
    if model_to_use.startswith("tts:"):
        model_to_use = tts_model
    
    # Try to use kokoro if available
    try:
        from kokoro import Kokoro
        
        # Determine model key
        model_key = f"tts:{model_to_use}"
        kokoro_model = multi_model_manager.get_model(model_key)
        
        if kokoro_model is None:
            print(f"Loading Kokoro TTS model: {model_to_use}")
            
            # Check if model_to_use is a URL - download it (with caching)
            if model_to_use.startswith(('http://', 'https://')):
                # Check cache first
                cached_path = get_cached_model_path(model_to_use)
                if cached_path:
                    model_path = cached_path
                    print(f"Using cached model: {model_path}")
                else:
                    print(f"Downloading model from URL: {model_to_use}")
                    
                    # Get cache directory
                    cache_dir = get_model_cache_dir()
                    
                    # Extract filename from URL (query string dropped)
                    url_path = model_to_use.split('?')[0]
                    filename = os.path.basename(url_path)
                    
                    if not filename.endswith(('.pt', '.bin')):
                        filename = "kokoro-model.pt"
                    
                    # Create safe filename in cache
                    url_hash = hashlib.sha256(model_to_use.encode()).hexdigest()
                    cached_filename = f"{url_hash}_{filename}"
                    model_path = os.path.join(cache_dir, cached_filename)
                    
                    # Unique temp name so concurrent downloads of the same URL
                    # cannot clobber each other's partial data.
                    tmp_path = f"{model_path}.{uuid.uuid4().hex}.part"
                    
                    # NOTE(review): this download runs synchronously inside an
                    # async handler, so the event loop is blocked for its
                    # whole duration.
                    try:
                        # The timeout keeps a stalled server from hanging the
                        # request forever; the context manager releases the
                        # connection even when the transfer fails.
                        with requests.get(model_to_use, stream=True, timeout=(15, 300)) as response:
                            response.raise_for_status()
                            
                            total_size = int(response.headers.get('content-length', 0))
                            downloaded = 0
                            
                            with open(tmp_path, 'wb') as f:
                                for chunk in response.iter_content(chunk_size=8192*1024):
                                    if chunk:
                                        f.write(chunk)
                                        downloaded += len(chunk)
                                        if total_size > 0:
                                            percent = (downloaded / total_size) * 100
                                            print(f"Downloaded: {percent:.1f}%", end='\r')
                        
                        # Publish the file only once it is complete.
                        os.replace(tmp_path, model_path)
                        print(f"\nDownloaded and cached to: {model_path}")
                        
                    except Exception as e:
                        print(f"Error downloading model: {e}")
                        # Best-effort removal of the partial file; the
                        # download error is the one that must propagate.
                        try:
                            os.remove(tmp_path)
                        except OSError:
                            pass
                        raise
            else:
                # Use local path or model name
                model_path = model_to_use
            
            # Load the Kokoro model
            kokoro_model = Kokoro(model_path)
            multi_model_manager.add_model(model_key, kokoro_model)
        
        # Generate speech (falsy voice/speed fall back to the defaults)
        voice = request.voice or "af_sarah"
        speed = request.speed or 1.0
        
        audio_bytes = kokoro_model.generate(request.input, voice=voice, speed=speed)
        
        # Convert to base64
        import base64
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
        
        return {
            "audio": audio_base64
        }
        
    except ImportError as e:
        # kokoro not installed
        raise HTTPException(
            status_code=501,
            detail=f"TTS not available. Install kokoro: pip install kokoro. Error: {str(e)}"
        )
    except Exception as e:
        print(f"TTS error: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, http_request: Request = None):
    """Chat completions endpoint with streaming and tool support.
    
    Two execution paths exist:
    
    * ``--parser litellm``: the request is forwarded through the LiteLLM
      backend. Streaming requests return a server-sent-event stream;
      non-streaming requests return the backend response as JSON together
      with the rate-limit headers computed from its ``usage`` section.
    * otherwise: the request is served by the locally loaded model via
      ``stream_chat_response`` / ``generate_chat_response``.
    
    Args:
        request: Parsed chat completion request body.
        http_request: The raw HTTP request, used to read the Authorization
            and Host headers on the LiteLLM path. May be None.
    
    Returns:
        A StreamingResponse for ``stream=True``; otherwise a JSONResponse
        (LiteLLM path, including a 500 error payload on backend failure) or
        the result of ``generate_chat_response`` (local path).
    
    Raises:
        HTTPException: 500 when LiteLLM is requested but not installed,
            503 when no local model is available for the request.
    """
    
    # Check if we should use litellm backend
    parser_type = getattr(global_args, 'parser', 'auto') if global_args else 'auto'
    
    if parser_type == 'litellm':
        # Use LiteLLM backend
        from codai.openai.litellm import get_litellm_backend, LITELLM_AVAILABLE
        
        if not LITELLM_AVAILABLE:
            raise HTTPException(
                status_code=500,
                detail="LiteLLM is not installed. Run: pip install litellm"
            )
        
        # Check for API key in request - litellm requires an API key
        # If not provided, use a fake key to allow the request to proceed
        api_key = None
        
        # Try to get API key from request body
        if hasattr(request, 'api_key') and request.api_key:
            api_key = request.api_key
        
        # If no API key in body, try to get from Authorization header
        if not api_key:
            auth_header = http_request.headers.get('Authorization', '') if http_request else ''
            if auth_header.startswith('Bearer '):
                api_key = auth_header[7:]  # Extract token after 'Bearer '
        
        # If still no API key, use a fake key to allow litellm to proceed
        # litellm will then fail with the actual provider error if needed
        if not api_key:
            api_key = "fake-key-for-local-testing"
            print("DEBUG: No API key provided, using fake key for litellm")
        
        # Determine the base URL for litellm to connect to
        # Use the server's host and port for local connections
        api_base = None
        
        # Check if model starts with 'ollama:' - use local Ollama
        if request.model and request.model.startswith('ollama:'):
            # Get the host from the request headers
            client_host = "127.0.0.1"
            if http_request:
                host_header = http_request.headers.get('host', '')
                if host_header:
                    # Strip port if present
                    if ':' in host_header:
                        client_host = host_header.split(':')[0]
                        if client_host.replace('.', '').isdigit():
                            # It's an IP, keep it
                            pass
                        else:
                            # It's a hostname, use localhost
                            client_host = "127.0.0.1"
                    else:
                        client_host = host_header
            
            # Get port from global_args or use default
            port = getattr(global_args, 'port', 11434) if global_args else 11434
            api_base = f"http://{client_host}:{port}/v1"
            print(f"DEBUG: Using api_base for Ollama: {api_base}")
        else:
            # For non-Ollama models, use the server's own URL as base
            # This allows LiteLLM to make requests to the local server
            if http_request:
                # Get the host from the request headers
                host_header = http_request.headers.get('host', '')
                if host_header:
                    # Strip port if present to reconstruct clean URL
                    if ':' in host_header:
                        host_parts = host_header.split(':')
                        client_host = host_parts[0]
                        # Keep the port from the request for consistency
                        server_port = host_parts[1]
                    else:
                        client_host = host_header
                        server_port = str(getattr(global_args, 'port', 6745))
                else:
                    # Fallback to client host if no Host header
                    client_host = http_request.client.host if http_request.client else "127.0.0.1"
                    server_port = str(getattr(global_args, 'port', 6745))
            else:
                # Fallback if no http_request
                client_host = "127.0.0.1"
                server_port = str(getattr(global_args, 'port', 6745))
            
            # Determine protocol (http or https)
            use_https = getattr(global_args, 'https', False) or getattr(global_args, 'pubkey', None)
            protocol = "https" if use_https else "http"
            api_base = f"{protocol}://{client_host}:{server_port}/v1"
            print(f"DEBUG: Using api_base for local server: {api_base}")
        
        # Get or create litellm backend
        litellm_backend = get_litellm_backend(
            model=request.model,
            api_key=api_key,
            api_base=api_base,
            context_window=8192,  # Default, can be made configurable
            model_manager=multi_model_manager  # Pass for alias resolution
        )
        
        # Get the tool_parser from multi_model_manager for model-specific parsing
        tool_parser = multi_model_manager.tool_parser if hasattr(multi_model_manager, 'tool_parser') else None
        
        # Convert messages to dict format
        messages_dict = []
        for msg in request.messages:
            msg_dict = {"role": msg.role, "content": msg.content or ""}
            if hasattr(msg, 'tool_calls') and msg.tool_calls:
                msg_dict["tool_calls"] = msg.tool_calls
            if hasattr(msg, 'tool_call_id') and msg.tool_call_id:
                msg_dict["tool_call_id"] = msg.tool_call_id
            messages_dict.append(msg_dict)
        
        # Prepare tools if provided
        tools_dict = None
        if request.tools:
            tools_dict = request.tools
        
        # Qwen-family models get an extra pass that extracts tool calls from text
        is_qwen = 'qwen' in (request.model or '').lower()
        
        # Generate response
        try:
            if request.stream:
                # Streaming response
                
                async def generate():
                    try:
                        async for chunk in await litellm_backend.chat_completion(
                            messages=messages_dict,
                            model=request.model,
                            temperature=request.temperature,
                            top_p=request.top_p,
                            max_tokens=request.max_tokens,
                            stop=request.stop,
                            tools=tools_dict,
                            tool_choice=request.tool_choice,
                            stream=True,
                            tool_parser=tool_parser,
                        ):
                            # NOTE(review): response headers are already sent
                            # once streaming starts, so this value is unused;
                            # the call is kept in case it does bookkeeping.
                            headers = {}
                            if 'usage' in chunk:
                                headers = litellm_backend.get_rate_limit_headers(
                                    prompt_tokens=chunk.get('usage', {}).get('prompt_tokens', 0),
                                    completion_tokens=chunk.get('usage', {}).get('completion_tokens', 0)
                                )
                            
                            # Handle Qwen tool calls if model is Qwen family
                            if is_qwen:
                                # Chunks may carry an empty "choices" list;
                                # guard so they do not raise IndexError and
                                # abort the stream.
                                choices = chunk.get('choices') or []
                                delta = choices[0].get('delta') if choices else None
                                if delta:
                                    content = delta.get('content', '')
                                    tool_calls = delta.get('tool_calls', [])
                                    
                                    if not tool_calls and content:
                                        # Try to parse tool calls from content
                                        tool_calls = litellm_backend.parse_qwen_tool_calls(content)
                                        if tool_calls:
                                            # Strip tool tags from content
                                            delta['content'] = litellm_backend.strip_tool_tags(content)
                                            delta['tool_calls'] = tool_calls
                            
                            yield f"data: {json.dumps(chunk)}\n\n"
                        
                        yield "data: [DONE]\n\n"
                    except Exception as e:
                        yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'internal_error'}})}\n\n"
                
                return StreamingResponse(generate(), media_type="text/event-stream")
            else:
                # Non-streaming response
                response = await litellm_backend.chat_completion(
                    messages=messages_dict,
                    model=request.model,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    max_tokens=request.max_tokens,
                    stop=request.stop,
                    tools=tools_dict,
                    tool_choice=request.tool_choice,
                    stream=False,
                    tool_parser=tool_parser,
                )
                
                # Handle Qwen tool calls (skipped when "choices" is missing or empty)
                if is_qwen and response.get('choices'):
                    msg = response['choices'][0].get('message', {})
                    content = msg.get('content', '')
                    tool_calls = msg.get('tool_calls', [])
                    
                    if not tool_calls and content:
                        tool_calls = litellm_backend.parse_qwen_tool_calls(content)
                        if tool_calls:
                            msg['content'] = litellm_backend.strip_tool_tags(content)
                            msg['tool_calls'] = tool_calls
                            response['choices'][0]['message'] = msg
                
                # Add rate limit headers
                headers = {}
                if 'usage' in response:
                    headers = litellm_backend.get_rate_limit_headers(
                        prompt_tokens=response.get('usage', {}).get('prompt_tokens', 0),
                        completion_tokens=response.get('usage', {}).get('completion_tokens', 0)
                    )
                
                # Return the LiteLLM result. Without this return the request
                # would fall through into the local-model path below.
                # Header values are stringified because HTTP headers must be
                # text. Assumes `response` is a JSON-serializable mapping.
                return JSONResponse(
                    content=response,
                    headers={str(k): str(v) for k, v in (headers or {}).items()},
                )
                
        except Exception as e:
            # Handle litellm errors
            error_response = {
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": 500
                }
            }
            return JSONResponse(content=error_response, status_code=500)
    
    # Continue with original implementation for 'auto' parser
    # Get the model for this request
    requested_model = request.model
    
    # Try to get the appropriate model
    mm = multi_model_manager.get_model_for_request(requested_model)
    
    if mm is None:
        # Model not loaded - try to use default
        if model_manager.backend is not None:
            # Fallback to legacy model_manager
            current_manager = model_manager
        else:
            raise HTTPException(status_code=503, detail="Model not loaded")
    else:
        current_manager = mm
    
    # Inject system prompt if --system-prompt flag was provided
    messages = request.messages
    if global_system_prompt is not None:
        # Check if there's already a system message
        has_system = any(msg.role == "system" for msg in messages)
        if not has_system:
            # Use default or custom system prompt
            if global_system_prompt is True:
                # Default system prompt
                system_text = "You are a helpful assistant."
            else:
                # Custom system prompt provided as argument
                system_text = str(global_system_prompt)
            # Insert system message at the beginning
            messages = [ChatMessage(role="system", content=system_text)] + list(messages)
    
    # Format messages with tools if provided
    if request.tools:
        messages = format_tools_for_prompt(request.tools, messages)
    
    # Get the tool_parser from the current manager
    tool_parser = current_manager.tool_parser if hasattr(current_manager, 'tool_parser') else ModelParserAdapter()
    
    # Prepare stop sequences (a single string is normalised to a one-item list)
    stop_sequences = []
    if request.stop:
        if isinstance(request.stop, str):
            stop_sequences = [request.stop]
        else:
            stop_sequences = request.stop
    
    # Convert messages to dict format for chat completion
    messages_dict = []
    for msg in messages:
        msg_dict = {"role": msg.role}
        # Always include content key - llama_cpp template expects it
        # Convert content to string if it's a list (multipart content)
        content = msg.content
        if content is None:
            content = ""
        elif isinstance(content, list):
            # Handle multipart content array format: [{"type": "text", "text": "..."}]
            parts = []
            for item in content:
                if isinstance(item, dict):
                    if item.get('type') == 'text' and 'text' in item:
                        parts.append(item['text'])
                    else:
                        # Non-text parts are replaced by a textual placeholder
                        parts.append(f"[{item.get('type', 'unknown')} content]")
                else:
                    parts.append(str(item))
            content = '\n'.join(parts)
        # Ensure content is never None - convert to string
        msg_dict["content"] = str(content) if content is not None else ""
        # Handle tool_calls - convert to proper format if present
        if msg.tool_calls:
            # tool_calls should be a list of dicts with 'id', 'type', 'function' keys
            msg_dict["tool_calls"] = msg.tool_calls
        if msg.name:
            msg_dict["name"] = msg.name
        if msg.tool_call_id:
            msg_dict["tool_call_id"] = msg.tool_call_id
        messages_dict.append(msg_dict)
    
    # Final safety check: ensure NO message has None content before passing to llama_cpp
    # Also ensure content key always exists (not just None check)
    for i, m in enumerate(messages_dict):
        # Handle missing content key entirely
        if "content" not in m:
            messages_dict[i]["content"] = ""
        # Handle None content
        elif m.get("content") is None:
            messages_dict[i]["content"] = ""
        # Handle content that's not a string (shouldn't happen but be safe)
        elif not isinstance(m["content"], str):
            messages_dict[i]["content"] = str(m["content"])
    
    # Debug: print first few messages to see their structure
    print(f"DEBUG: messages_dict[0] keys: {list(messages_dict[0].keys()) if messages_dict else 'empty'}")
    if len(messages_dict) > 1:
        print(f"DEBUG: messages_dict[1] keys: {list(messages_dict[1].keys()) if len(messages_dict) > 1 else 'empty'}")
    
    # Convert tools to dict format if present
    tools_dict = None
    if request.tools:
        tools_dict = []
        for tool in request.tools:
            tools_dict.append({
                "type": tool.type,
                "function": {
                    "name": tool.function.name,
                    "description": tool.function.description,
                    "parameters": tool.function.parameters
                }
            })
    
    if request.stream:
        return StreamingResponse(
            stream_chat_response(
                messages_dict,
                request.model,
                request.max_tokens,
                request.temperature,
                request.top_p,
                stop_sequences,
                tools_dict,
                current_manager,
                tool_parser,
                request.response_format,
            ),
            media_type="text/event-stream",
        )
    else:
        return await generate_chat_response(
            messages_dict,
            request.model,
            request.max_tokens,
            request.temperature,
            request.top_p,
            stop_sequences,
            tools_dict,
            current_manager,
            tool_parser,
            request.response_format,
        )

async def stream_chat_response(
    messages: List[Dict],
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Dict]],
    current_manager: ModelManager,
    tool_parser: ToolCallParser,
    response_format: Optional[Dict] = None,
) -> AsyncGenerator[str, None]:
    """Stream a chat completion as server-sent events, with queue notifications.

    Every yielded string is one SSE ``data:`` line carrying a
    ``chat.completion.chunk`` JSON payload; the stream always ends with
    ``data: [DONE]``.

    Sequence of events:
      1. If no loaded model is detected on ``current_manager``, the request is
         registered with ``queue_manager`` and "waiting" chunks are emitted
         (one initial message, then five updates two seconds apart).
      2. ``queue_manager.start_processing`` is called and a "starting" chunk
         is emitted.
      3. Each chunk from ``current_manager.generate_chat_stream`` is forwarded,
         after the optional 'malformed' and 'tool_calls' reply filters.
      4. If ``tools`` were supplied, the accumulated text is parsed for tool
         calls. The final chunk carries either the tool calls or a usage
         summary.

    Args:
        messages: Chat messages as dicts; ``content`` is used for the usage
            estimate.
        model_name: Echoed in every chunk and passed to the reply filters.
        max_tokens, temperature, top_p, stop, response_format: Forwarded
            unchanged to ``generate_chat_stream``.
        tools: Tool definitions (dicts or objects with ``type``/``function``),
            or None.
        current_manager: Object providing ``generate_chat_stream``.
        tool_parser: Provides ``strip_tool_calls_from_content`` and
            ``extract_tool_calls``.

    Yields:
        SSE-formatted strings.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    request_id = f"req-{uuid.uuid4().hex[:8]}"

    generated_text = ""
    print(f"DEBUG: stream_chat_response started, stream=True, tools={tools is not None}")

    def sse(choice: Dict, **extra) -> str:
        """Wrap one choice in the shared chunk envelope and format it as an SSE line."""
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [choice],
        }
        payload.update(extra)
        return f"data: {json.dumps(payload)}\n\n"

    def estimate_usage() -> Dict[str, int]:
        """Whitespace-split token estimate (no tokenizer is consulted here)."""
        # ``or ""`` / str() keep the estimate from raising on None or
        # non-string content after the whole reply has already been streamed.
        prompt_text = "\n".join(str(m.get("content") or "") for m in messages)
        prompt_tokens = len(prompt_text.split())
        completion_tokens = len(generated_text.split()) if generated_text else 0
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

    # A model counts as loaded when the manager's backend holds a model, or,
    # for managers without a backend, when the manager itself holds one.
    backend = getattr(current_manager, "backend", None)
    if backend is not None:
        model_loaded = getattr(backend, "model", None) is not None
    else:
        model_loaded = getattr(current_manager, "model", None) is not None

    # If model not loaded, add to queue and send waiting notifications
    if not model_loaded:
        await queue_manager.add_waiting(request_id)
        wait_interval = 2.0  # Send waiting update every 2 seconds

        # NOTE(review): if the client disconnects during this waiting phase the
        # request stays registered via add_waiting; no removal call is visible
        # from here - confirm whether queue_manager needs an explicit cleanup.
        yield sse(
            {
                "index": 0,
                "delta": {"content": "Waiting for model to load..."},
                "finish_reason": None,
            },
            x_queue_info={
                "status": "waiting",
                "message": "Model is loading, please wait...",
            },
        )

        # The loop does not poll a loading status: it sends a fixed number of
        # updates and then proceeds regardless.
        max_wait_updates = 5
        for _ in range(max_wait_updates):
            await asyncio.sleep(wait_interval)
            wait_time = await queue_manager.get_wait_time(request_id)
            queue_pos = await queue_manager.get_queue_position(request_id)
            yield sse(
                {
                    "index": 0,
                    "delta": {"content": ""},
                    "finish_reason": None,
                },
                x_queue_info={
                    "status": "waiting",
                    "message": f"Waiting for model... ({int(wait_time)}s)",
                    "queue_position": queue_pos,
                    "wait_time_seconds": int(wait_time),
                },
            )

    # Mark as starting processing
    await queue_manager.start_processing(request_id, model_name)

    # Everything after start_processing runs under try/finally so that
    # finish_processing is reached even if the consumer goes away at the very
    # first yield below.
    try:
        # Send "Model starting" message
        yield sse(
            {
                "index": 0,
                "delta": {"content": ""},
                "finish_reason": None,
            },
            x_queue_info={
                "status": "starting",
                "message": "Model starting",
            },
        )

        chunk_count = 0
        # Use generate_chat_stream for proper chat template handling
        async for chunk in current_manager.generate_chat_stream(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            tools=tools,
            response_format=response_format,
        ):
            chunk_count += 1
            # Filter malformed content from each chunk (only if --reply-filters is set)
            filtered_chunk = chunk
            if check_reply_filter('malformed', 'text', model_name):
                filtered_chunk = filter_malformed_content(filtered_chunk)

            # Filter out tool call format - only if --reply-filters is set
            if check_reply_filter('tool_calls', 'text', model_name):
                filtered_chunk = tool_parser.strip_tool_calls_from_content(filtered_chunk)

            # Pass through all content including whitespace - it's essential for message composition
            generated_text += filtered_chunk

            yield sse({
                "index": 0,
                "delta": {"content": filtered_chunk},
                "finish_reason": None,
            })
            # Hand control back to the event loop between chunks
            await asyncio.sleep(0)

        print(f"DEBUG: stream_chat_response completed, {chunk_count} chunks, generated_text length: {len(generated_text)}")
        if not generated_text.strip():
            print("DEBUG: Warning - no content generated!")

        # In debug mode, dump the full generated text
        if global_debug:
            print(f"\n{'='*80}")
            print("=== FULL GENERATED TEXT (DEBUG) ===")
            print(f"{'='*80}")
            # Show both raw (actual) content and escaped representation
            print("--- RAW CONTENT (actual newlines shown as lines) ---")
            print(generated_text)
            print("--- END RAW CONTENT ---")
            print("--- ESCAPED CONTENT (repr() - shows \\n for newlines) ---")
            print(repr(generated_text))
            print("--- END ESCAPED CONTENT ---")
            print(f"{'='*80}\n")

        # Check for tool calls in complete output (for API response format)
        if tools:
            # Convert tools back to Tool objects for parsing
            tool_objects = []
            for t in tools:
                try:
                    # Handle both dict and pydantic model formats
                    if isinstance(t, dict):
                        func_data = t.get("function", {})
                        tool_func = ToolFunction(
                            name=func_data.get("name", ""),
                            description=func_data.get("description"),
                            parameters=func_data.get("parameters")
                        )
                        tool_type = t.get("type", "function")
                    else:
                        # Pydantic model
                        tool_func = ToolFunction(
                            name=t.function.name if hasattr(t.function, 'name') else str(t.function),
                            description=t.function.description if hasattr(t.function, 'description') else None,
                            parameters=t.function.parameters if hasattr(t.function, 'parameters') else None
                        )
                        tool_type = t.type
                    tool_objects.append(Tool(type=tool_type, function=tool_func))
                except Exception as e:
                    # A tool that cannot be converted is skipped, not fatal
                    print(f"DEBUG: Error converting tool: {e}, tool type: {type(t)}")
                    continue
            try:
                tool_calls = tool_parser.extract_tool_calls(generated_text, tool_objects)
            except Exception as e:
                print(f"DEBUG: Error extracting tool calls: {e}")
                tool_calls = None
            if tool_calls:
                # In debug mode, dump tool calls
                if global_debug:
                    print(f"\n{'='*80}")
                    print("=== EXTRACTED TOOL CALLS (DEBUG) ===")
                    print(f"{'='*80}")
                    print(json.dumps(tool_calls, indent=2))
                    print(f"{'='*80}\n")
                # Content chunks were already streamed (with tool-call markup
                # stripped only when that reply filter is enabled); just send
                # the tool_calls chunk.
                yield sse({
                    "index": 0,
                    "delta": {"tool_calls": tool_calls},
                    "finish_reason": "tool_calls",
                    "logprobs": None,
                    "native_finish_reason": "tool_calls",
                })
            else:
                # Use OpenAIFormatter for final chunk sanitization
                formatter = OpenAIFormatter(model_name)
                final_chunk = formatter.format_litellm_chunk("", is_final=True, usage=estimate_usage())
                yield f"data: {json.dumps(final_chunk)}\n\n"
        else:
            # Build complete final chunk with all OpenAI fields
            yield sse(
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "logprobs": None,
                    "native_finish_reason": "stop",
                },
                usage={
                    **estimate_usage(),
                    "prompt_tokens_details": {
                        "cached_tokens": 0,
                        "audio_tokens": 0,
                    },
                    "completion_tokens_details": {
                        "reasoning_tokens": 0,
                        "audio_tokens": 0,
                    },
                },
                provider={
                    "provider_name": "coderai",
                    "provider_id": "coderai",
                },
                system_fingerprint=None,
            )

        yield "data: [DONE]\n\n"
    except Exception as e:
        # Headers are already sent, so report the failure in-band and close
        # the stream normally.
        print(f"Error during streaming generation: {e}")
        yield sse({
            "index": 0,
            "delta": {"content": f"\n[Generation error: {str(e)}]"},
            "finish_reason": "stop",
        })
        yield "data: [DONE]\n\n"
    finally:
        # Always clean up queue state
        await queue_manager.finish_processing()
async def generate_chat_response(
    messages: List[Dict],
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Dict]],
    current_manager: ModelManager,
    tool_parser: ToolCallParser,
    response_format: Optional[Dict] = None,
) -> Dict:
    """Generate a non-streaming chat completion response.

    Calls ``current_manager.generate_chat``, optionally applies the
    'malformed' reply filter, extracts tool calls when ``tools`` were
    supplied, and returns the result of ``OpenAIFormatter.format_litellm_full``.

    Args:
        messages: Chat messages as dicts; ``content`` is used for the usage
            estimate.
        model_name: Passed to the reply filters and to ``OpenAIFormatter``.
        max_tokens, temperature, top_p, stop, response_format: Forwarded
            unchanged to ``generate_chat``.
        tools: Tool definitions (dicts or objects with ``type``/``function``),
            or None.
        current_manager: Object providing ``generate_chat``.
        tool_parser: Provides ``extract_tool_calls`` and
            ``strip_tool_calls_from_content``.

    Returns:
        The response built by ``OpenAIFormatter.format_litellm_full``.

    Raises:
        HTTPException: status 500 wrapping any exception raised during
            generation or post-processing.
    """
    try:
        # Use generate_chat for proper chat template handling
        generated_text = current_manager.generate_chat(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            tools=tools,
            response_format=response_format,
        )

        # Filter out malformed content from generated text (only if --reply-filters is set)
        if check_reply_filter('malformed', 'text', model_name):
            generated_text = filter_malformed_content(generated_text)

        content = generated_text
        tool_calls = None

        # Check for tool calls
        if tools:
            # Convert tools back to Tool objects for parsing
            tool_objects = []
            for t in tools:
                try:
                    # Handle both dict and pydantic model formats
                    if isinstance(t, dict):
                        func_data = t.get("function", {})
                        tool_func = ToolFunction(
                            name=func_data.get("name", ""),
                            description=func_data.get("description"),
                            parameters=func_data.get("parameters")
                        )
                        tool_type = t.get("type", "function")
                    else:
                        # Pydantic model
                        tool_func = ToolFunction(
                            name=t.function.name if hasattr(t.function, 'name') else str(t.function),
                            description=t.function.description if hasattr(t.function, 'description') else None,
                            parameters=t.function.parameters if hasattr(t.function, 'parameters') else None
                        )
                        tool_type = t.type
                    tool_objects.append(Tool(type=tool_type, function=tool_func))
                except Exception as e:
                    # A tool that cannot be converted is skipped, not fatal
                    print(f"DEBUG: Error converting tool: {e}, tool type: {type(t)}")
                    continue
            try:
                extracted = tool_parser.extract_tool_calls(generated_text, tool_objects)
            except Exception as e:
                print(f"DEBUG: Error extracting tool calls: {e}")
                extracted = None
            if extracted:
                # Strip tool call format from content so user doesn't see raw tags (only if --reply-filters is set)
                if check_reply_filter('tool_calls', 'text', model_name):
                    visible_text = tool_parser.strip_tool_calls_from_content(generated_text)
                else:
                    visible_text = generated_text
                # Whitespace-only content is reported as None next to tool calls
                content = visible_text if visible_text.strip() else None
                tool_calls = extracted

        # Rough whitespace-based token estimate; no tokenizer is consulted here.
        # ``or ""`` / str() keep a None or non-string content from turning a
        # successful generation into a 500.
        prompt_text = "\n".join(str(m.get("content") or "") for m in messages)
        prompt_tokens = len(prompt_text.split())
        completion_tokens = len(generated_text.split()) if generated_text else 0

        # Use OpenAIFormatter for final sanitization
        formatter = OpenAIFormatter(model_name)
        return formatter.format_litellm_full(
            text=content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            tool_calls=tool_calls,
        )
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e
@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    """Text completions endpoint.

    Resolves the manager for ``request.model`` via ``multi_model_manager``,
    falling back to the legacy ``model_manager`` when it has a backend, then
    dispatches to the streaming or non-streaming completion helper.

    Only the first prompt is processed when ``request.prompt`` is a list.

    Raises:
        HTTPException: 503 when no manager is available for the request,
            400 when ``request.prompt`` is an empty list.
    """
    current_manager = multi_model_manager.get_model_for_request(request.model)
    if current_manager is None:
        # Model not loaded - fall back to the legacy model_manager if usable
        if model_manager.backend is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        current_manager = model_manager

    prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt]
    if not prompts:
        # An empty list would otherwise fail at prompts[0] with an unhandled
        # IndexError and surface as a 500.
        raise HTTPException(status_code=400, detail="prompt must not be empty")

    if not request.stop:
        stop_sequences = []
    elif isinstance(request.stop, str):
        stop_sequences = [request.stop]
    else:
        stop_sequences = request.stop

    if request.stream:
        return StreamingResponse(
            stream_completion_response(
                prompts[0],
                request.model,
                request.max_tokens,
                request.temperature,
                request.top_p,
                stop_sequences,
                current_manager,
            ),
            media_type="text/event-stream",
        )

    return await generate_completion_response(
        prompts[0],
        request.model,
        request.max_tokens,
        request.temperature,
        request.top_p,
        stop_sequences,
        current_manager,
    )
async def stream_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    current_manager: "ModelManager",
) -> AsyncGenerator[str, None]:
    """Stream a text completion as server-sent events.

    Each chunk from ``current_manager.generate_stream`` is emitted as a
    ``text_completion`` payload. The stream always ends with a chunk whose
    ``finish_reason`` is ``"stop"`` followed by ``data: [DONE]``, including
    when generation raises (the error is printed, not sent to the client).

    Args:
        prompt: Prompt text forwarded to ``generate_stream``.
        model_name: Echoed in every chunk.
        max_tokens, temperature, top_p, stop: Forwarded unchanged to
            ``generate_stream``.
        current_manager: Object providing an async ``generate_stream``.

    Yields:
        SSE-formatted strings.
    """
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def sse(text: str, finish_reason: Optional[str]) -> str:
        """Format one chunk, in the shared envelope, as an SSE data line."""
        payload = {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "text": text,
                "index": 0,
                "logprobs": None,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(payload)}\n\n"

    try:
        async for chunk in current_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        ):
            yield sse(chunk, None)
    except Exception as e:
        # Headers are already sent; log and fall through to a normal close.
        print(f"Error during streaming completion: {e}")

    # The terminating chunk carries the same id/object/created/model envelope
    # as the content chunks so clients can parse every event uniformly.
    yield sse("", "stop")
    yield "data: [DONE]\n\n"
async def generate_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    current_manager: "ModelManager",
) -> Dict:
    """Generate a non-streaming text completion response.

    Args:
        prompt: Prompt text forwarded to ``current_manager.generate``.
        model_name: Echoed in the response.
        max_tokens, temperature, top_p, stop: Forwarded unchanged to
            ``generate``.
        current_manager: Object providing ``generate``; its ``tokenizer``
            attribute, when present and truthy, is used for token counts.

    Returns:
        A ``text_completion`` dict with one choice and a ``usage`` section.

    Raises:
        HTTPException: status 500 wrapping any exception raised during
            generation or token counting.
    """
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        generated_text = current_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )

        # Managers without a ``tokenizer`` attribute use the same whitespace
        # estimate as managers whose tokenizer is unset, instead of failing.
        tokenizer = getattr(current_manager, "tokenizer", None)
        if tokenizer:
            prompt_tokens = len(tokenizer.encode(prompt))
            completion_tokens = len(tokenizer.encode(generated_text))
        else:
            prompt_tokens = len(prompt.split())
            completion_tokens = len(generated_text.split())

        return {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "text": generated_text,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except Exception as e:
        print(f"Error during completion: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e
# =============================================================================
# Main Entry Point
# =============================================================================

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Argument list to parse. When None (the default), argparse reads
            ``sys.argv[1:]``, so existing ``parse_args()`` callers are
            unaffected. Passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI-compatible API server supporting NVIDIA (CUDA) and Vulkan backends"
    )
    add = parser.add_argument
    backend_choices = ["auto", "nvidia", "vulkan", "opencl"]

    # --- Text model selection and backends ---
    add("--model", type=str, action="append", default=None,
        help="Model name, path, or URL for text-to-text LLM. Can be specified multiple times for multiple models.")
    add("--model-alias", type=str, action="append", default=None, dest="model_aliases",
        nargs=2, metavar=("ALIAS", "MODEL"),
        help="Register an alias for a model. Format: --model-alias <alias_name> <actual_model>")
    add("--backend", type=str, choices=backend_choices, default="auto",
        help="Backend to use: auto (detect), nvidia (CUDA), vulkan (AMD), or opencl")
    add("--image-backend", type=str, choices=backend_choices, default="auto",
        help="Image generation backend: auto, nvidia (CUDA), vulkan (AMD), or opencl")
    add("--audio-backend", type=str, choices=backend_choices, default="auto",
        help="Audio transcription backend: auto, nvidia (CUDA), vulkan (AMD), or opencl")
    add("--tts-backend", type=str, choices=backend_choices, default="auto",
        help="TTS backend: auto, nvidia (CUDA), vulkan (AMD), or opencl")

    # --- Network / HTTPS ---
    add("--host", type=str, default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)")
    add("--port", type=int, default=8000,
        help="Port to bind to (default: 8000)")
    add("--url", type=str, default="auto",
        help="Base URL for media downloads: 'auto' (use request IP) or explicit URL (e.g., http://myserver:8000)")
    add("--https", action="store_true",
        help="Enable HTTPS with auto-generated certificate")
    add("--privkey", type=str, default=None,
        help="Path to HTTPS private key file")
    add("--pubkey", type=str, default=None,
        help="Path to HTTPS certificate file")

    # --- NVIDIA loading / offload options ---
    add("--offload-dir", type=str, default="./offload",
        help="Directory for disk offload (NVIDIA backend only, default: ./offload)")
    add("--load-in-4bit", action="store_true",
        help="Load model in 4-bit precision (NVIDIA backend only, requires bitsandbytes)")
    add("--load-in-8bit", action="store_true",
        help="Load model in 8-bit precision (NVIDIA backend only, requires bitsandbytes)")
    add("--ram", type=float, default=None,
        help="Maximum CPU RAM to use for model offloading in GB (NVIDIA backend only). Auto-detected if not specified. Disk offloading only occurs after this limit is exceeded.")
    add("--flash-attn", action="store_true",
        help="Use Flash Attention 2 (NVIDIA backend only, requires flash-attn package)")
    add("--offload-strategy", type=str,
        choices=["auto", "conservative", "balanced", "aggressive", "sequential"], default="auto",
        help="Offload strategy for NVIDIA backend (default: auto)")
    add("--max-gpu-percent", type=float, default=None,
        help="Maximum GPU VRAM to use as percentage (0-100). Overrides offload-strategy. Lower values offload more to CPU/RAM (default: None = use offload-strategy)")

    # --- Vulkan options ---
    add("--n-gpu-layers", type=int, default=-1,
        help="Number of layers to offload to GPU (Vulkan backend only, default: -1 = all layers)")
    add("--n-ctx", type=int, action="append", default=None,
        help="Context window size (Vulkan backend). Can be specified multiple times, one per --model.")
    add("--vulkan-device", type=int, default=0,
        help="Vulkan GPU device ID to use (Vulkan backend only, default: 0). Use --vulkan-list-devices to see available devices")
    add("--vulkan-single-gpu", action="store_true",
        help="Force Vulkan to use only the specified GPU device (prevents layer distribution across multiple GPUs)")
    add("--vulkan-list-devices", action="store_true",
        help="List available Vulkan GPU devices and exit")

    # --- Prompt handling ---
    add("--hf-chat-template", action="append", default=[],
        help="Use HuggingFace apply_chat_template. Examples: --hf-chat-template auto (all models), --hf-chat-template text (all text), --hf-chat-template mymodel:llama3 (specific model with template). Can be repeated.")
    # With nargs="?" and const=True, a bare --system-prompt yields True and
    # "--system-prompt TEXT" yields the text.
    add("--system-prompt", nargs="?", const=True, default=None,
        help="Inject a system prompt at the beginning of conversations. Use without a value for a default prompt, or provide custom text.")

    # --- Multi-model arguments ---
    # NOTE(review): the help text says this flag is repeatable, but without
    # action="append" only the last occurrence is kept - confirm the intent.
    add("--tts-model", type=str, default=None,
        help="Model for text-to-speech (e.g., kokoro, or path/URL to Kokoro model). Can be specified multiple times.")
    add("--audio-model", type=str, action="append", default=None,
        help="Model for audio transcription (e.g., whisper-1, base, or path to faster-whisper model). Can be specified multiple times for multiple models.")
    add("--audio-1", action="store_true",
        help="Disable request queue for audio models - return 409 if model is busy")
    add("--image-model", type=str, action="append", default=None,
        help="Model for image generation (e.g., stable-diffusion-xl-base-1.0). Can be specified multiple times for multiple models.")
    add("--vision-model", type=str, action="append", default=None,
        help="Model for image/video-to-text (e.g., llava-1.5, LLaVA). Supports vulkan and cuda backends.")
    add("--image-1", action="store_true",
        help="Disable request queue for image models - return 409 if model is busy")

    # --- Image generation options ---
    add("--llm-path", type=str, default=None,
        help="Path to CLIP LLM model for image generation (stable-diffusion-cpp-python).")
    add("--vae-path", type=str, default=None,
        help="Path to VAE model for image generation (stable-diffusion-cpp-python).")
    add("--image-sample-method", type=str, default="res_multistep",
        help="Sample method for image generation (default: res_multistep for Z-Image Turbo).")
    add("--image-steps", type=int, default=4,
        help="Number of inference steps for image generation (default: 4 for Z-Image Turbo).")
    add("--image-width", type=int, default=512,
        help="Image width for generation (default: 512).")
    add("--image-height", type=int, default=512,
        help="Image height for generation (default: 512).")
    add("--image-cfg-scale", type=float, default=1.0,
        help="CFG scale for image generation (default: 1.0 for Z-Image Turbo).")
    add("--image-precision", type=str, default="f32", choices=["bf16", "f32", "f16", "f8"],
        help="Model precision for image generation (default: f32). bf16 recommended for modern GPUs.")
    add("--image-cpu-offload", action="store_true",
        help="Enable sequential CPU offload for image models (lower VRAM usage).")
    add("--image-seed", type=int, default=None,
        help="Default seed for image generation (default: random).")
    add("--vae-tiling", action="store_true",
        help="Enable VAE tiling for lower VRAM usage (sd.cpp only).")
    add("--clip-on-cpu", action="store_true",
        help="Run CLIP on CPU to save VRAM (sd.cpp only).")

    # --- Load modes ---
    add("--loadall", action="store_true",
        help="Pre-load all models (main, audio, image) at startup instead of on-demand")
    add("--loadswap", action="store_true",
        help="Keep all models loaded, swapping active model between VRAM and RAM (only active model in VRAM)")
    add("--nopreload", action="store_true",
        help="Disable model preloading. Models will load on first request instead of at startup")

    # --- Audio / whisper options ---
    add("--audio-ctx", type=int, action="append", default=None,
        help="Audio model context size in milliseconds. Can be specified multiple times, one per --audio-model.")
    add("--audio-offload", type=float, default=None,
        help="Audio model GPU offload percentage (0-100). If not set, uses CPU")
    add("--audio-vulkan-device", type=int, default=0,
        help="Vulkan GPU device ID to use for Whisper audio transcription (default: 0). Only used when using Vulkan backend.")
    add("--image-vulkan-device", type=int, default=None,
        help="Vulkan GPU device ID to use for image generation models (default: same as --vulkan-device). Use --vulkan-list-devices to see available devices")
    add("--whisper-cpp", type=str, default=None,
        help="Path to whisper.cpp CLI executable (e.g., ~/whisper.cpp/build/bin/whisper-cli). Uses Vulkan if available.")
    add("--whisper-server", type=str, default=None,
        help="Path to whisper.cpp server executable (e.g., ~/whisper.cpp/build/bin/whisper-server). Keeps model loaded in VRAM.")
    add("--whisper-server-port", type=int, default=8744,
        help="Port for whisper-server (default: 8744).")
    add("--image-ctx", type=int, action="append", default=None,
        help="Image model context size. Can be specified multiple times, one per --image-model.")
    add("--image-offload", type=float, default=None,
        help="Vision model GPU offload percentage (0-100). If not set, loads fully on GPU")

    # --- Model cache maintenance ---
    add("--list-cached-models", action="store_true",
        help="List all cached models in the model cache directory")
    add("--remove-all-models", action="store_true",
        help="Remove all cached models from the model cache directory")
    add("--remove-model", type=str, default=None,
        help="Remove a specific cached model by name or hash (partial match)")

    # --- Reply processing, debugging, output ---
    add("--reply-filters", action="append", default=[],
        help="Enable filtering of model replies. Use --reply-filters malformed,tool_calls or --reply-filters text:malformed --reply-filters image:tool_calls for model-specific filters.")
    add("--debug", action="store_true",
        help="Enable debug mode - dumps full request/response to stdout for troubleshooting")
    add("--file-path", type=str, default=None,
        help="Path to store generated files (images, audio). If specified, files will be saved here and served over web.")
    add("--parser", type=str, default="auto", choices=["auto", "litellm"],
        help="Tool call parser to use: 'auto' for internal parser, 'litellm' for LiteLLM's parser. Default: auto")

    return parser.parse_args(argv)
def main():
    """Main entry point."""
    global global_system_prompt, model_manager, multi_model_manager, global_debug, global_args, global_file_path
    
    # Suppress unraisable exceptions from LlamaModel.__del__
    import sys
    original_unraisablehook = sys.unraisablehook
    def suppress_llama_del_errors(unraisable):
        """Unraisable-exception hook that drops one known-noisy error.

        An ``AttributeError`` whose message mentions 'sampler' and whose
        originating object's repr mentions 'LlamaModel' is ignored. Every
        other unraisable exception is passed on to the hook that was
        installed before this one.
        """
        error = unraisable.exc_value
        # Checks stay in the original order and short-circuit, so repr()/str()
        # are only evaluated when the preceding test has already matched.
        is_known_noise = (
            isinstance(error, AttributeError)
            and 'LlamaModel' in repr(unraisable.object)
            and 'sampler' in str(error)
        )
        if not is_known_noise:
            original_unraisablehook(unraisable)
    sys.unraisablehook = suppress_llama_del_errors
    
    # Optional: set process name if procname is available
    try:
        import procname
        procname.setprocname("coderai")
    except ImportError:
        pass
    args = parse_args()
    
    # Store args globally for access in endpoints
    global_args = args
    
    # Set global system prompt from --system-prompt flag
    global_system_prompt = args.system_prompt
    
    # Set global debug flag
    global_debug = args.debug
    # Set global file path for storing generated files
    global_file_path = args.file_path
    if global_debug:
        # Print the full command line that was used to invoke coderai
        import shlex
        cmd_line = ' '.join(shlex.quote(arg) for arg in sys.argv)
        print(f"\n{'='*80}")
        print(f"=== COMMAND LINE: {cmd_line}")
        print(f"{'='*80}\n")
        print("DEBUG MODE ENABLED - Full requests and replies will be dumped to stdout")
    
    # Handle --vulkan-list-devices
    if args.vulkan_list_devices:
        print("\nListing Vulkan devices...")
        try:
            import subprocess
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True)
            if result.returncode == 0:
                print(result.stdout)
            else:
                print("Could not run vulkaninfo. Make sure vulkan-tools is installed.")
        except Exception as e:
            print(f"Error listing devices: {e}")
        sys.exit(0)
    
    # Handle --list-cached-models
    if args.list_cached_models:
        print("\n=== Listing Cached Models ===")
        
        caches = get_all_cache_dirs()
        if not caches:
            print("No model cache directories found.")
            sys.exit(0)
        
        all_files = []
        for cache_name, cache_dir in caches.items():
            print(f"\n--- {cache_name.upper()} Cache ({cache_dir}) ---")
            if not os.path.exists(cache_dir):
                print(f"  (directory does not exist)")
                continue
                
            files = os.listdir(cache_dir)
            if not files:
                print(f"  No cached files.")
                continue
            
            # For diffusers and huggingface, show directory structure
            if cache_name in ('diffusers', 'huggingface'):
                for root, dirs, files in os.walk(cache_dir):
                    for f in files:
                        filepath = os.path.join(root, f)
                        rel_path = os.path.relpath(filepath, cache_dir)
                        size = os.path.getsize(filepath)
                        all_files.append((cache_name, rel_path, size))
            else:
                for f in files:
                    filepath = os.path.join(cache_dir, f)
                    if os.path.isfile(filepath):
                        size = os.path.getsize(filepath)
                        all_files.append((cache_name, f, size))
        
        if not all_files:
            print("\nNo cached models found.")
            sys.exit(0)
        
        # Calculate totals
        total_size = sum(size for _, _, size in all_files)
        
        print(f"\n=== Summary ===")
        print(f"Total: {len(all_files)} files, {total_size / (1024*1024*1024):.2f} GB")
        print("\nCache locations:")
        for cache_name, cache_dir in caches.items():
            print(f"  {cache_name}: {cache_dir}")
        
        sys.exit(0)
    
    # Handle --remove-all-models
    if args.remove_all_models:
        print("\n=== Removing All Cached Models ===")
        
        import shutil
        caches = get_all_cache_dirs()
        
        if not caches:
            print("No cache directories found.")
            sys.exit(0)
        
        total_removed = 0
        for cache_name, cache_dir in caches.items():
            if not os.path.exists(cache_dir):
                continue
                
            files = os.listdir(cache_dir)
            if not files:
                continue
            
            print(f"\nRemoving from {cache_name} cache ({cache_dir})...")
            print(f"  Found {len(files)} file(s). Deleting...")
            
            # For diffusers, remove entire directory tree
            if cache_name == 'diffusers':
                for item in os.listdir(cache_dir):
                    item_path = os.path.join(cache_dir, item)
                    if os.path.isdir(item_path):
                        shutil.rmtree(item_path)
                    else:
                        os.remove(item_path)
                    print(f"  Deleted: {item}")
                    total_removed += 1
            else:
                for f in files:
                    filepath = os.path.join(cache_dir, f)
                    os.remove(filepath)
                    print(f"  Deleted: {f}")
                    total_removed += 1
        
        print(f"\n=== Removed {total_removed} item(s) from all caches ===")
        sys.exit(0)
    
    # Handle --remove-model
    if args.remove_model:
        print(f"\n=== Removing Cached Model Matching: {args.remove_model} ===")
        
        import shutil
        caches = get_all_cache_dirs()
        
        if not caches:
            print("No cache directories found.")
            sys.exit(0)
        
        all_matching = []
        for cache_name, cache_dir in caches.items():
            if not os.path.exists(cache_dir):
                continue
            
            # For diffusers and huggingface, search recursively
            if cache_name in ('diffusers', 'huggingface'):
                for root, dirs, files in os.walk(cache_dir):
                    for f in files:
                        if args.remove_model.lower() in f.lower():
                            filepath = os.path.join(root, f)
                            rel_path = os.path.relpath(filepath, cache_dir)
                            size = os.path.getsize(filepath)
                            all_matching.append((cache_name, rel_path, filepath, size))
            else:
                files = os.listdir(cache_dir)
                for f in files:
                    if args.remove_model.lower() in f.lower():
                        filepath = os.path.join(cache_dir, f)
                        if os.path.isfile(filepath):
                            size = os.path.getsize(filepath)
                            all_matching.append((cache_name, f, filepath, size))
        
        if not all_matching:
            print(f"No cached models found matching: {args.remove_model}")
            print(f"\nUse --list-cached-models to see available models.")
            sys.exit(0)
        
        print(f"\nFound {len(all_matching)} matching file(s):")
        for cache_name, filename, filepath, size in all_matching:
            print(f"  [{cache_name}] {filename} ({size / (1024*1024):.1f} MB)")
        
        # Confirm before deleting
        print(f"\nDeleting {len(all_matching)} file(s)...")
        for cache_name, filename, filepath, size in all_matching:
            try:
                os.remove(filepath)
                print(f"  Deleted: [{cache_name}] {filename}")
            except Exception as e:
                print(f"  Failed to delete {filename}: {e}")
        
        print(f"\nRemoved {len(all_matching)} cached model file(s).")
        sys.exit(0)
    
    # Get model names from args - support multiple models
    model_names = args.model if args.model else []
    
    # Helper function to get config value by index with fallback
    def get_ctx_by_index(ctx_list, index, default):
        """Return ``ctx_list[index]`` when that slot exists, else ``default``.

        ``ctx_list`` may be ``None`` or empty (the option was never given) or
        shorter than the number of models; in all of those cases the caller's
        ``default`` is returned instead.
        """
        # Guard clause: nothing configured, or no entry for this model index.
        if not ctx_list or index >= len(ctx_list):
            return default
        return ctx_list[index]
    
    # Validate: must have at least one model specified
    audio_models = args.audio_model if args.audio_model else []
    image_models = args.image_model if args.image_model else []
    vision_models = args.vision_model if args.vision_model else []
    
    if not model_names and not audio_models and not image_models and not vision_models and args.tts_model is None:
        print("Error: At least one of --model, --audio-model, --image-model, --vision-model, or --tts-model must be specified.")
        print("")
        print("For NVIDIA backend (HuggingFace models):")
        print("  - microsoft/DialoGPT-medium")
        print("  - meta-llama/Llama-2-7b-chat-hf (requires auth)")
        print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        print("  - Use multiple --model flags for multiple models")
        print("")
        print("For Vulkan backend (GGUF models):")
        print("  - Local path: ./phi-3-mini-4k-instruct-q4_k_m.gguf")
        print("  - HuggingFace: microsoft/Phi-3-mini-4k-instruct-gguf")
        print("  - URL: https://huggingface.co/.../model.gguf")
        print("")
        print("For audio transcription:")
        print("  - --audio-model base")
        print("")
        print("For text-to-speech:")
        print("  - --tts-model kokoro")
        print("")
        print("For image generation:")
        print("  - --image-model stabilityai/stable-diffusion-xl-base-1.0")
        sys.exit(1)
    
    # Print loaded models info
    if model_names:
        print(f"\nText model(s): {model_names}")
        if len(model_names) > 1:
            # Load mode will be determined below
            print(f"Multiple models configured - load mode will be set based on --loadall/--loadswap flags")
    
    # Detect available backends
    available = detect_available_backends()
    
    # If user explicitly requests nvidia/cuda backend with a GGUF model, 
    # remove vulkan from available since we'll use CUDA instead
    if model_names:
        first_model = model_names[0]
        is_gguf_model = first_model.endswith('.gguf') or 'gguf' in first_model.lower()
        if is_gguf_model and args.backend in ('nvidia', 'cuda'):
            # When using nvidia/cuda backend with GGUF, vulkan uses CUDA, so remove it
            if 'vulkan' in available:
                del available['vulkan']
    
    print("\nAvailable backends:")
    for name, available_flag in available.items():
        status = "✓" if available_flag else "✗"
        print(f"  [{status}] {name}")
    print("")
    
    # Load the main model (only if specified)
    if model_names:
        # Enable verbose mode when debug is set (for better troubleshooting output from llama-cpp)
        verbose = args.debug if hasattr(args, 'debug') else False
        
        load_kwargs = {
            'offload_dir': args.offload_dir,
            'load_in_4bit': args.load_in_4bit,
            'load_in_8bit': args.load_in_8bit,
            'manual_ram_gb': args.ram,
            'flash_attn': args.flash_attn,
            'offload_strategy': args.offload_strategy,
            'max_gpu_percent': args.max_gpu_percent,
            'n_gpu_layers': args.n_gpu_layers,
            'n_ctx': get_ctx_by_index(args.n_ctx, 0, 2048),
            'main_gpu': args.vulkan_device,
            'single_gpu': args.vulkan_single_gpu,
            'verbose': verbose,
        }
        
        # Load the first model
        first_model_name = model_names[0]
        try:
            model_manager.load_model(
                model_name=first_model_name,
                backend_type=args.backend,
                **load_kwargs
            )
            # Register with multi_model_manager
            multi_model_manager.set_default_model(first_model_name, load_kwargs, args.backend)
            multi_model_manager.add_model(first_model_name, model_manager)
            print(f"\nMain text model loaded: {first_model_name}")
        except Exception as e:
            print(f"\nError loading model: {e}")
            error_str = str(e).lower()
            print("\nTroubleshooting:")
            if args.backend == "vulkan":
                print("  - For Vulkan, ensure you have Vulkan drivers installed")
                print("  - Make sure you're using a GGUF format model")
                print("  - Run build.sh with 'vulkan' argument first")
            else:
                print("  - For NVIDIA, ensure PyTorch with CUDA is installed")
                print("  - Run build.sh with 'nvidia' argument first")
                if "tokenizer" in error_str or "sentencepiece" in error_str or "tiktoken" in error_str:
                    print("  - Tokenizer error: ensure sentencepiece and tiktoken are installed")
                    print("    pip install sentencepiece tiktoken tokenizers")
                # Check if trying to load GGUF model with NVIDIA backend
                if "gguf" in first_model_name.lower():
                    print(f"\n  *** IMPORTANT: '{first_model_name}' appears to be a GGUF model ***")
                    print("  GGUF models are NOT compatible with the NVIDIA backend.")
                    print("  Use --backend vulkan instead, or choose a HuggingFace Transformers model.")
                    print("\n  Example Vulkan command:")
                    print(f"    coderai --backend vulkan --model {first_model_name}")
            sys.exit(1)
    else:
        print("\nNo main text model specified (--model). Running with audio/image/TTS models only.")
    
    # Determine load mode BEFORE setting up other models
    load_mode = "ondemand"
    if args.loadall:
        load_mode = "loadall"
    elif args.loadswap:
        load_mode = "loadswap"
    
    # Set load mode in multi_model_manager
    multi_model_manager.set_load_mode(load_mode)
    
    # Load models based on mode and count
    if len(model_names) > 1:
        # Multiple models - handle based on load mode
        print(f"\n=== Multiple Models Mode: {load_mode} ===")
        
        if load_mode == "loadall":
            # Load all models into VRAM
            # Skip first model if it's already loaded (at lines 4274-4281)
            start_index = 1 if model_names[0] in multi_model_manager.models else 0
            for i in range(start_index, len(model_names)):
                model_name = model_names[i]
                print(f"\nLoading model {i+1}/{len(model_names)}: {model_name}")
                try:
                    manager = ModelManager()
                    manager.load_model(
                        model_name=model_name,
                        backend_type=args.backend,
                        **load_kwargs
                    )
                    multi_model_manager.add_model(model_name, manager)
                    print(f"Loaded: {model_name}")
                except Exception as e:
                    print(f"Error loading {model_name}: {e}")
        
        elif load_mode == "loadswap":
            # First model in VRAM, others in RAM
            # Skip first model if it's already loaded (at lines 4274-4281)
            start_index = 1 if model_names[0] in multi_model_manager.models else 0
            for i in range(start_index, len(model_names)):
                model_name = model_names[i]
                # In loadswap, all additional models go to RAM (CPU-only)
                print(f"\nLoading model {i+1}/{len(model_names)}: {model_name} (RAM)")
                try:
                    manager = ModelManager()
                    # Modify kwargs for CPU-only loading
                    swap_kwargs = load_kwargs.copy()
                    swap_kwargs['n_gpu_layers'] = 0  # Force CPU only for swap mode
                    manager.load_model(
                        model_name=model_name,
                        backend_type=args.backend,
                        **swap_kwargs
                    )
                    multi_model_manager.add_model(model_name, manager)
                    print(f"Loaded: {model_name} (RAM)")
                except Exception as e:
                    print(f"Error loading {model_name}: {e}")
        
        else:  # ondemand
            # First model already loaded at lines 4274-4281
            # Just register other models but don't load them
            print(f"\nFirst model already loaded: {model_names[0]}")
            
            # Register other models but don't load them
            for model_name in model_names[1:]:
                multi_model_manager.set_default_model(model_name, load_kwargs, args.backend)
            
            print(f"Other models will load on-demand: {model_names[1:]}")
    # Model is already loaded at lines 4274-4281
    
    # Determine load mode BEFORE setting up other models
    load_mode = "ondemand"
    if args.loadall:
        load_mode = "loadall"
    elif args.loadswap:
        load_mode = "loadswap"
    
    # Set load mode in multi_model_manager
    multi_model_manager.set_load_mode(load_mode)
    
    # Pre-load models based on mode
    print(f"DEBUG: load_mode at line 4710 = '{load_mode}', backend = {args.backend}")
    if load_mode in ("loadall", "loadswap"):
        # Load all models into VRAM (or RAM for CUDA loadswap)
        mode_name = "Load All" if load_mode == "loadall" else "Load Swap"
        print(f"\n=== {mode_name} Mode ===")
        
        # Load main text model first
        if model_names:
            print(f"Pre-loading main text model: {model_names[0]}")
        
        # Load image model (first one only in loadall mode currently)
        print(f"DEBUG: image_models check at line 4718: {image_models}, backend = {args.backend}")
        # Only preload image model if loadall or loadswap mode is set
        if image_models and not getattr(args, 'nopreload', False) and load_mode in ("loadall", "loadswap"):
            print(f"Pre-loading image model: {image_models[0]}")
            
            # Get the original model name
            original_model_name = image_models[0]
            
            # Check if it's a URL first (before any processing)
            is_url = original_model_name.startswith('http://') or original_model_name.startswith('https://')
            
            # Strip query parameters from URL if present
            model_name = original_model_name
            if '?' in model_name:
                model_name = model_name.split('?')[0]
            
            # Check if the image model is a GGUF model
            is_gguf = model_name.endswith('.gguf') or 'gguf' in model_name.lower()
            
            if is_gguf:
                # GGUF for image - use stable-diffusion-cpp-python
                print(f"Detected GGUF image model, loading with llama.cpp...")
                try:
                    from llama_cpp import Llama
                    
                    # Download GGUF model if needed (similar to VulkanBackend)
                    model_path = None
                    
                    # Check if it's a URL - download directly
                    if is_url:
                        print(f"Image model is a URL: {original_model_name}")
                        cached_path = get_cached_model_path(original_model_name)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached GGUF model: {model_path}")
                        else:
                            print(f"Downloading GGUF model: {original_model_name}")
                            cache_dir = get_model_cache_dir()
                            model_path = download_model(original_model_name, cache_dir)
                    elif os.path.isfile(model_name):
                        # Local file
                        model_path = model_name
                        print(f"Loading local GGUF model: {model_path}")
                    else:
                        # Try to download from HuggingFace Hub
                        print(f"Trying to resolve as HuggingFace model: {model_name}")
                        try:
                            from huggingface_hub import hf_hub_download, list_repo_files
                            parts = model_name.split('/')
                            if len(parts) >= 2:
                                repo_id = f"{parts[0]}/{parts[1]}"
                                print(f"Looking for GGUF files in repo: {repo_id}")
                                files = list_repo_files(repo_id)
                                gguf_files = [f for f in files if f.endswith('.gguf')]
                                if not gguf_files:
                                    raise ValueError(f"No GGUF files found in {repo_id}")
                                filename = gguf_files[0]
                                model_path = hf_hub_download(repo_id=repo_id, filename=filename)
                                print(f"Downloaded GGUF model to: {model_path}")
                        except Exception as e:
                            print(f"Could not resolve GGUF model path: {e}")
                            print(f"Image model will load on first request")
                            model_path = None
                    
                    if model_path and os.path.isfile(model_path):
                        # Use the cached path for the model key
                        model_key = f"image:{model_path}"
                        
                        # Load with llama.cpp
                        n_gpu_layers = -1  # Load all layers to GPU
                        n_ctx = 2048
                        
                        print(f"Loading GGUF model from: {model_path}")
                        file_size = os.path.getsize(model_path)
                        print(f"GGUF model file size: {file_size / (1024*1024):.1f} MB")
                        
                        # Verify it's a valid GGUF file (check magic bytes)
                        with open(model_path, 'rb') as f:
                            magic = f.read(8)
                            print(f"File magic bytes: {magic}")
                            if not magic.startswith(b'GGUF'):
                                print(f"ERROR: File is NOT a valid GGUF! Expected 'GGUF', got: {magic}")
                                print(f"This means the download returned an HTML error page instead of the model.")
                                print(f"The URL must be a DIRECT download link (ends with .gguf, not a model page)")
                                print(f"Example: https://huggingface.co/owner/repo/resolve/main/model.gguf")
                                print(f"Image model will load on first request")
                            else:
                                # Valid GGUF, try to load
                                try:
                                    llama_model = Llama(
                                        model_path=model_path,
                                        n_gpu_layers=n_gpu_layers,
                                        n_ctx=n_ctx,
                                        verbose=True,
                                    )
                                    multi_model_manager.add_model(model_key, llama_model)
                                    print(f"GGUF image model loaded successfully: {original_model_name}")
                                except Exception as llama_error:
                                    print(f"llama.cpp load error: {llama_error}")
                                    print(f"Trying stable-diffusion-cpp-python fallback...")
                                    # Try stable-diffusion-cpp-python as fallback
                                    try:
                                        from stable_diffusion_cpp import StableDiffusion
                                        
                                        # Initialize model_key to None first so Python knows it exists
                                        model_key = None
                                        
                                        # Define model_key for this scope
                                        model_key = f"image:{model_path}"
                                        print(f"Loading with sd.cpp: {model_path}")
                                        # For models like Z-Image-Turbo/Flux, use diffusion_model_path
                                        # Look for additional model files in same directory
                                        model_dir = os.path.dirname(model_path)
                                        model_name = os.path.basename(model_path)
                                        
                                        # Try to find additional model files
                                        clip_l_path = None
                                        t5xxl_path = None
                                        vae_path = None
                                        
                                        # Use CLI arguments if provided, download and cache if URL
                                        if args.llm_path:
                                            # Check if it's a Hugging Face model ID, URL, or local path
                                            if is_huggingface_model_id(args.llm_path):
                                                # Download from Hugging Face
                                                print(f"Attempting to download LLM model from Hugging Face: {args.llm_path}")
                                                cache_dir = get_model_cache_dir()
                                                clip_l_path = download_huggingface_model(args.llm_path, cache_dir, '.gguf')
                                                if clip_l_path:
                                                    print(f"Downloaded LLM model to: {clip_l_path}")
                                                else:
                                                    print(f"Warning: Failed to download LLM model from Hugging Face, will try as local path")
                                            elif args.llm_path.startswith('http://') or args.llm_path.startswith('https://'):
                                                cached = get_cached_model_path(args.llm_path)
                                                if cached:
                                                    clip_l_path = cached
                                                    print(f"Using cached LLM model: {clip_l_path}")
                                                else:
                                                    cache_dir = get_model_cache_dir()
                                                    clip_l_path = download_model(args.llm_path, cache_dir)
                                                    print(f"Downloaded LLM model to: {clip_l_path}")
                                            else:
                                                clip_l_path = args.llm_path
                                        if args.vae_path:
                                            # Check if it's a URL and download if needed
                                            if args.vae_path.startswith('http://') or args.vae_path.startswith('https://'):
                                                cached = get_cached_model_path(args.vae_path)
                                                if cached:
                                                    vae_path = cached
                                                    print(f"Using cached VAE model: {vae_path}")
                                                else:
                                                    cache_dir = get_model_cache_dir()
                                                    vae_path = download_model(args.vae_path, cache_dir)
                                                    print(f"Downloaded VAE model to: {vae_path}")
                                            else:
                                                vae_path = args.vae_path
                                        
                                        # Look for common file patterns only if CLI args not provided
                                        if not args.llm_path or not args.vae_path:
                                            for f in os.listdir(model_dir) if os.path.exists(model_dir) else []:
                                                if not args.llm_path and 'clip_l' in f.lower() and f.endswith(('.safetensors', '.bin')):
                                                    clip_l_path = os.path.join(model_dir, f)
                                                elif 't5xxl' in f.lower() and f.endswith(('.safetensors', '.bin')):
                                                    t5xxl_path = os.path.join(model_dir, f)
                                                elif not args.vae_path and f.endswith('.safetensors') and 'ae' in f.lower():
                                                    vae_path = os.path.join(model_dir, f)
                                        
                                        # Build kwargs based on available files
                                        sd_kwargs = {'diffusion_model_path': model_path}
                                        
                                        if clip_l_path:
                                            sd_kwargs['llm_path'] = clip_l_path
                                            print(f"DEBUG: Adding llm_path to sd_kwargs: {clip_l_path}")
                                        else:
                                            print(f"DEBUG: clip_l_path is None or empty, not adding to sd_kwargs")
                                            print(f"DEBUG: args.llm_path = {args.llm_path}")
                                        if args.vae_path:
                                            sd_kwargs['vae_path'] = vae_path
                                        elif vae_path:
                                            sd_kwargs['vae_path'] = vae_path
                                        if t5xxl_path:
                                            sd_kwargs['t5xxl_path'] = t5xxl_path
                                        
                                        # Add sd.cpp-specific options from CLI args
                                        if getattr(global_args, 'vae_tiling', False):
                                            # VAE tiling is handled internally in newer builds
                                            print(f"DEBUG: VAE tiling is handled internally in stable-diffusion-cpp-python")
                                        if getattr(global_args, 'clip_on_cpu', False):
                                            sd_kwargs['keep_clip_on_cpu'] = True
                                            print(f"DEBUG: Running CLIP on CPU to save VRAM (keep_clip_on_cpu=True)")
                                        
                                        # Use all available CPU cores for processing
                                        import psutil
                                        sd_kwargs['n_threads'] = psutil.cpu_count()
                                        print(f"DEBUG: Using {psutil.cpu_count()} CPU cores for sd.cpp")

                                        # Add generation parameters from CLI args
                                        # sd_kwargs['sample_method'] = args.image_sample_method  # Not valid for __init__
                                        # sd_kwargs['steps'] = args.image_steps  # Not valid for __init__
                                        
                                        sd_model = StableDiffusion(**sd_kwargs)
                                        multi_model_manager.add_model(model_key, sd_model)
                                        # Add alias for "image" 
                                        multi_model_manager.add_model("image", sd_model)
                                        
                                        print(f"Image model loaded successfully via sd.cpp: {original_model_name}")
                                    except ImportError as sd_error:
                                        print(f"stable-diffusion-cpp-python not installed: {sd_error}")
                                        print(f"Image model will load on first request")
                                    except Exception as sd_error:
                                        print(f"sd.cpp load error: {sd_error}")
                                        print(f"Image model will load on first request")
                    else:
                        print(f"Could not load GGUF image model: no valid model path")
                        
                except ImportError as e:
                    print(f"Warning: llama_cpp not installed: {e}")
                    print(f"Image model will load on first request")
                except Exception as e:
                    print(f"Warning: Failed to pre-load GGUF image model: {e}")
                    print(f"Image model will load on first request")
            else:
                # Pre-load a diffusers pipeline (Stable Diffusion). Any failure is
                # only reported; the model is then left to load on first request.
                try:
                    import torch
                    from diffusers import StableDiffusionXLPipeline, DiffusionPipeline

                    # The diffusers path keys the model by its name (model_path is only set in the GGUF branch)
                    model_key = f"image:{model_name}"
                    print(f"Loading diffusers pipeline: {model_name}")

                    # Both loaders receive exactly the same options
                    pretrained_kwargs = {
                        'torch_dtype': torch.float32,
                        'use_safetensors': True,
                    }

                    # SDXL is attempted first; any failure falls back to the generic loader
                    try:
                        pipeline = StableDiffusionXLPipeline.from_pretrained(model_name, **pretrained_kwargs)
                    except Exception as sdxl_error:
                        print(f"SDXL failed, trying generic pipeline: {sdxl_error}")
                        pipeline = DiffusionPipeline.from_pretrained(model_name, **pretrained_kwargs)

                    # Place the pipeline on CUDA when available, otherwise on the CPU;
                    # attention slicing is only enabled on the CUDA path
                    cuda_ready = torch.cuda.is_available()
                    pipeline = pipeline.to("cuda" if cuda_ready else "cpu")
                    if cuda_ready:
                        pipeline.enable_attention_slicing()

                    # Register under the specific key and under the generic "image" alias
                    for registry_key in (model_key, "image"):
                        multi_model_manager.add_model(registry_key, pipeline)

                    print(f"Image model loaded successfully: {model_name}")

                except ImportError as e:
                    print(f"Warning: diffusers not installed, image model will load on first request: {e}")
                except Exception as e:
                    print(f"Warning: Failed to pre-load image model: {e}")
                    print(f"  Image model will load on first request")
        
        # Load audio model
        print(f"DEBUG: audio_models check at line 4970: {audio_models}")
        if audio_models:
            print(f"Pre-loading audio model: {audio_models[0]}")
        
        # Load TTS model
        if args.tts_model:
            print(f"Pre-loading TTS model: {args.tts_model}")
            
    elif load_mode == "loadswap":
        # Load models in order: model > image > audio > TTS, keep active in VRAM
        # For Vulkan backend, load all models to VRAM like loadall (VRAM is not limited like CUDA)
        print("\n=== Load Swap Mode ===")
        
        # This section only prints the load plan; nothing is loaded here.
        # The wording depends on the backend: Vulkan announces every model as a
        # pre-load, any other backend announces the text model in VRAM and the
        # remaining models in RAM.
        if args.backend == "vulkan":
            text_label = "Pre-loading main text model"
            image_label = "Pre-loading image model"
            audio_label = "Pre-loading audio model"
            tts_label = "Pre-loading TTS model"
        else:
            text_label = "Main text model will be in VRAM"
            image_label = "Image model in RAM"
            audio_label = "Audio model in RAM"
            tts_label = "TTS model in RAM"

        if model_names:
            print(f"{text_label}: {model_names[0]}")
        # The image model is only announced when pre-loading is not disabled
        # and the load mode is loadall or loadswap
        announce_image = (
            image_models
            and not getattr(args, 'nopreload', False)
            and load_mode in ("loadall", "loadswap")
        )
        if announce_image:
            print(f"{image_label}: {image_models[0]}")
        if audio_models:
            print(f"{audio_label}: {audio_models[0]}")
        if args.tts_model:
            print(f"{tts_label}: {args.tts_model}")
        
    else:
        # No flags: only one model gets loaded (the main text model if specified)
        print("\n=== On-Demand Mode ===")
        print("Models will load on first request")
    
    # Set up audio model if specified (with pre-loading if in loadall/loadswap mode)
    print(f"DEBUG: models in manager before audio setup: {list(multi_model_manager.models.keys())}")
    if audio_models:
        print(f"\nAudio transcription model(s): {audio_models}")
        
        # Set up Vulkan device for Whisper if using Vulkan backend
        if hasattr(args, 'audio_vulkan_device') and args.audio_vulkan_device is not None:
            print(f"  Using Vulkan device: {args.audio_vulkan_device}")
        
        # Register all audio models
        print(f"DEBUG: Registering audio models: {audio_models}")
        for idx, audio_m in enumerate(audio_models):
            multi_model_manager.set_audio_model(audio_m, {
                'ctx': get_ctx_by_index(args.audio_ctx, idx, 0),
                'offload': args.audio_offload,
            })
        print(f"DEBUG: After registration, audio_models in manager: {multi_model_manager.audio_models}")
        
        # Pre-load first audio model at startup if:
        # - Using loadall or loadswap mode, OR
        # - No main model is specified (only audio model configured)
        print(f"DEBUG: load_mode at line 5015 = '{load_mode}', model_names = {model_names}, audio_models = {audio_models}")
        should_preload = load_mode in ("loadall", "loadswap") or (not model_names and audio_models)
        print(f"DEBUG: should_preload = {should_preload}")
        
        # Initialize whisper-server if specified
        if args.whisper_server:
            print(f"\nWhisper server: {args.whisper_server}")
            print(f"  Port: {args.whisper_server_port}")
            # Reuse the manager already attached to multi_model_manager, or
            # create a new one and attach it
            whisper_server_mgr = multi_model_manager.whisper_server
            if whisper_server_mgr is not None:
                print("Whisper server already running, using existing instance")
            else:
                whisper_server_mgr = WhisperServerManager(
                    server_path=args.whisper_server,
                    port=args.whisper_server_port,
                )
                multi_model_manager.whisper_server = whisper_server_mgr

            # Start whisper-server if we should preload or if it's the only audio option
            print(f"DEBUG: whisper-server start check - audio_models={audio_models}, should_preload={should_preload}, whisper_cpp={args.whisper_cpp}")
            start_now = audio_models and (should_preload or not args.whisper_cpp)
            if start_now:
                model_to_use = audio_models[0]
                # A missing or None device setting falls back to device 0
                gpu_device = getattr(args, 'audio_vulkan_device', 0) or 0
                print(f"DEBUG: Starting whisper-server with gpu_device={gpu_device}")
                actual_model_path = whisper_server_mgr.start(model_path=model_to_use, gpu_device=gpu_device)
                if not actual_model_path:
                    print("Warning: Failed to start whisper-server, falling back to other backends")
                else:
                    # start() may return a different path than the one requested (e.g. for a URL);
                    # keep the manager's first audio entry pointing at the actual path
                    registered = multi_model_manager.audio_models
                    if model_to_use != actual_model_path and registered and registered[0] == model_to_use:
                        multi_model_manager.audio_models[0] = actual_model_path
                    print(f"Whisper server started with model: {actual_model_path}")
        elif should_preload:
            print(f"Pre-loading audio model... {audio_models[0]}")
            
            # Use first audio model for pre-loading
            model_to_use = audio_models[0]
            is_gguf_model = model_to_use.endswith('.gguf') or 'gguf' in model_to_use.lower()
            
            # faster-whisper does not support GGUF files, so those models skip it
            # entirely and go straight to the whispercpp fallback. Every other
            # model tries faster-whisper first. `faster_whisper_failed` tells the
            # code that follows whether the whispercpp fallback is still needed.
            if is_gguf_model:
                print("Detected GGUF model - using whispercpp backend")
                faster_whisper_failed = True
            else:
                faster_whisper_failed = False
                try:
                    # torch is only used for the CUDA-availability note below, but a
                    # missing torch still routes to whispercpp, as does a missing
                    # faster_whisper
                    from faster_whisper import WhisperModel
                    import torch

                    model_path = None
                    # What is handed to WhisperModel: the configured name, or the
                    # local file when the model was given as a URL. model_to_use
                    # itself is left untouched for the code that follows.
                    load_target = model_to_use

                    # Check if model is a URL - handle caching
                    if model_to_use.startswith(('http://', 'https://')):
                        cached_path = get_cached_model_path(model_to_use)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached model: {model_path}")
                        else:
                            # Download with progress
                            cache_dir = get_model_cache_dir()
                            model_path = download_model(model_to_use, cache_dir)
                        # Load from the local file for both a cache hit and a fresh download
                        load_target = model_path

                    # Always run on the CPU with int8: faster-whisper CUDA doesn't
                    # work with AMD/Vulkan GPUs
                    compute_type = "int8"
                    whisper_model = WhisperModel(
                        load_target,
                        device="cpu",
                        compute_type=compute_type
                    )

                    # Store in multi_model_manager under the configured model name
                    model_key = f"audio:{audio_models[0]}"
                    multi_model_manager.add_model(model_key, whisper_model)
                    print(f"Audio model loaded successfully (faster-whisper)")

                    # Warn if using CPU (no CUDA available)
                    if not torch.cuda.is_available():
                        print("Note: faster-whisper is running on CPU (no CUDA GPU detected)")
                        print("      For GPU acceleration, use NVIDIA GPU with CUDA or wait for Vulkan support.")

                except ImportError:
                    # faster-whisper (or torch) not available, will try whispercpp below
                    faster_whisper_failed = True
                except Exception as e:
                    # faster-whisper failed for some other reason (e.g., unsupported model format)
                    print(f"Warning: faster-whisper failed to load model: {e}")
                    faster_whisper_failed = True
            
            # If faster-whisper failed (not installed or couldn't load), try whispercpp.
            # Every failure in this section is only reported; the audio model is then
            # left to load on-demand.
            if faster_whisper_failed:
                # Start from a clean slate: a path left over from the
                # faster-whisper attempt is not reused
                model_path = None

                # Decide whether to report Vulkan GPU use for whispercpp: an explicit
                # VK_DEVICE_SELECT_DEVICE wins, otherwise the presence of /dev/dri
                # is taken as a GPU being available
                whisper_vulkan_available = False
                whisper_vulkan_device = os.environ.get('VK_DEVICE_SELECT_DEVICE', '0')
                try:
                    import whispercpp
                    if os.environ.get('VK_DEVICE_SELECT_DEVICE'):
                        whisper_vulkan_available = True
                        print(f"Whisper Vulkan: Will use GPU device {whisper_vulkan_device}")
                    elif os.path.exists('/dev/dri'):
                        whisper_vulkan_available = True
                        print(f"Whisper Vulkan: Auto-detected GPU, will use device {whisper_vulkan_device}")
                except ImportError as e:
                    print(f"Debug: whispercpp import failed: {e}")

                try:
                    import whispercpp

                    model_to_use = audio_models[0]
                    model_path = None

                    # Check if model is a URL - handle caching
                    if model_to_use.startswith(('http://', 'https://')):
                        cached_path = get_cached_model_path(model_to_use)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached model: {model_path}")
                        else:
                            # Download with progress
                            cache_dir = get_model_cache_dir()
                            model_path = download_model(model_to_use, cache_dir)
                            model_to_use = model_path

                    # whispercpp needs a local file or a built-in model name
                    # whispercpp supports: tiny, base, small, medium, large-v1, large (built-in)
                    # or pre-converted GGUF files (NOT HuggingFace GGUF format)
                    if not model_path:
                        # Not a URL: a local file, a built-in name and a bare model
                        # name are all handed to whispercpp unchanged
                        model_path = model_to_use

                    if not model_path:
                        # Nothing to load (e.g. the download produced no path)
                        print("Warning: Could not pre-load audio model with whispercpp: no model path available")
                        print("Audio model will load on-demand when transcription is requested.")
                    elif model_path != model_to_use and not os.path.isfile(model_path):
                        # Only reachable for a URL whose cached path does not exist on disk
                        print(f"Warning: whispercpp requires a local GGUF file or built-in model name, not: {model_to_use}")
                        print("For Vulkan audio transcription, use a built-in model name (tiny/base/small/medium/large-v1/large)")
                        print("or install faster-whisper with PyTorch for HuggingFace GGUF support.")
                        print("Audio model will load on-demand when transcription is requested.")
                    else:
                        # Load the whispercpp model
                        try:
                            whisper_model = whispercpp.Whisper.from_pretrained(model_path)

                            # Store in multi_model_manager under the configured model name
                            model_key = f"audio:{audio_models[0]}"
                            multi_model_manager.add_model(model_key, whisper_model)
                            print(f"Audio model loaded successfully (whispercpp)")
                            if whisper_vulkan_available:
                                print(f"  -> Using Vulkan GPU acceleration (device {whisper_vulkan_device})")
                        except Exception as e:
                            # The unsupported-format error gets dedicated guidance;
                            # anything else is reported verbatim
                            error_msg = str(e).lower()
                            if 'not a valid preconverted model' in error_msg:
                                print(f"Warning: whispercpp does not support this model format")
                                print("whispercpp only supports built-in model names or pre-converted GGUF files.")
                                print("For Vulkan audio transcription, please either:")
                                print("  1. Install PyTorch + faster-whisper: pip install torch faster-whisper")
                                print("  2. Use a built-in whispercpp model: --audio-model base")
                                print("Audio model will load on-demand when transcription is requested.")
                            else:
                                print(f"Warning: Could not pre-load audio model with whispercpp: {e}")
                                print("Audio model will load on-demand when transcription is requested.")
                except ImportError as e:
                    # Neither faster-whisper nor whispercpp available
                    print(f"Warning: No audio transcription library available: {e}")
                    print("Options:")
                    print("  1. Install PyTorch + faster-whisper: pip install torch faster-whisper")
                    print("  2. Use a built-in whispercpp model: --audio-model base")
                    print("Audio model will load on-demand when transcription is requested.")
                except Exception as e:
                    print(f"Warning: Could not pre-load audio model with whispercpp: {e}")
                    print("Audio model will load on-demand when transcription is requested.")
    
    # Register the TTS model when one was given on the command line
    if args.tts_model:
        print(f"\nText-to-speech model: {args.tts_model}")
        multi_model_manager.set_tts_model(args.tts_model, {})

        # When TTS is the only configured model, this is only announced here;
        # no loading happens in this block
        other_models_configured = model_names or audio_models or image_models
        if not other_models_configured:
            print(f"Pre-loading TTS model...")
            # TTS models load on-demand, but we can pre-download if needed
    
    # Set up image model if specified
    if image_models:
        print(f"\nImage generation model(s): {image_models}")
        multi_model_manager.set_image_model(image_models[0], {
            'ctx': get_ctx_by_index(args.image_ctx, 0, 0),
            'offload': args.image_offload,
            'llm_path': args.llm_path,
            'vae_path': args.vae_path,
            'sample_method': args.image_sample_method,
            'steps': args.image_steps,
            'width': args.image_width,
            'height': args.image_height,
            'cfg_scale': args.image_cfg_scale,
        })
        # Register all image models
        for idx, img_m in enumerate(image_models[1:], start=1):
            multi_model_manager.set_image_model(img_m, {
                'ctx': get_ctx_by_index(args.image_ctx, idx, 0),
                'offload': args.image_offload,
            })
        
        # Pre-load image model if configured and in loadall/loadswap mode
        if image_models and not getattr(args, 'nopreload', False) and load_mode in ("loadall", "loadswap"):
            print(f"Pre-loading image model...")
            
            # Get the original model name
            original_model_name = image_models[0]
            
            # Check if it's a URL first (before any processing)
            is_url = original_model_name.startswith('http://') or original_model_name.startswith('https://')
            
            # Strip query parameters from URL if present
            model_name = original_model_name
            if '?' in model_name:
                model_name = model_name.split('?')[0]
            
            # Check if the image model is a GGUF model
            is_gguf = model_name.endswith('.gguf') or 'gguf' in model_name.lower()
            
            if is_gguf:
                # GGUF for image - use stable-diffusion-cpp-python
                print(f"Detected GGUF image model, loading with llama.cpp...")
                try:
                    from llama_cpp import Llama
                    from llama_cpp import Llama
                    
                    # Download GGUF model if needed
                    model_path = None
                    
                    # Check if it's a URL - download directly
                    if is_url:
                        print(f"Image model is a URL: {original_model_name}")
                        cached_path = get_cached_model_path(original_model_name)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached GGUF model: {model_path}")
                        else:
                            print(f"Downloading GGUF model: {original_model_name}")
                            cache_dir = get_model_cache_dir()
                            model_path = download_model(original_model_name, cache_dir)
                    elif os.path.isfile(model_name):
                        model_path = model_name
                        print(f"Loading local GGUF model: {model_name}")
                    else:
                        # Try to download from HuggingFace Hub
                        print(f"Trying to resolve as HuggingFace model: {model_name}")
                        try:
                            from huggingface_hub import hf_hub_download, list_repo_files
                            parts = model_name.split('/')
                            if len(parts) >= 2:
                                repo_id = f"{parts[0]}/{parts[1]}"
                                print(f"Looking for GGUF files in repo: {repo_id}")
                                files = list_repo_files(repo_id)
                                gguf_files = [f for f in files if f.endswith('.gguf')]
                                if not gguf_files:
                                    raise ValueError(f"No GGUF files found in {repo_id}")
                                filename = gguf_files[0]
                                model_path = hf_hub_download(repo_id=repo_id, filename=filename)
                                print(f"Downloaded GGUF model to: {model_path}")
                        except Exception as e:
                            print(f"Could not resolve GGUF model path: {e}")
                            model_path = None
                    
                    if model_path and os.path.isfile(model_path):
                        n_gpu_layers = -1
                        n_ctx = 2048
                        
                        print(f"Loading GGUF model from: {model_path}")
                        file_size = os.path.getsize(model_path)
                        print(f"GGUF model file size: {file_size / (1024*1024):.1f} MB")
                        
                        # Verify it's a valid GGUF file (check magic bytes)
                        with open(model_path, 'rb') as f:
                            magic = f.read(8)
                            print(f"File magic bytes: {magic}")
                            if not magic.startswith(b'GGUF'):
                                print(f"ERROR: File is NOT a valid GGUF! Expected 'GGUF', got: {magic}")
                                print(f"The URL must be a DIRECT download link (ends with .gguf)")
                                print(f"Image model will load on first request")
                            else:
                                try:
                                    llama_model = Llama(
                                        model_path=model_path,
                                        n_gpu_layers=n_gpu_layers,
                                        n_ctx=n_ctx,
                                        verbose=True,
                                    )
                                    multi_model_manager.add_model(model_key, llama_model)
                                    print(f"GGUF image model loaded successfully: {original_model_name}")
                                except Exception as llama_error:
                                    print(f"llama.cpp load error: {llama_error}")
                                    print(f"Trying stable-diffusion-cpp-python fallback...")
                                    # Try stable-diffusion-cpp-python as fallback
                                    try:
                                        from stable_diffusion_cpp import StableDiffusion
                                        
                                        # Initialize model_key to avoid unbound variable error
                                        model_key = None
                                        
                                        print(f"Loading with sd.cpp: {model_path}")
                                        # For models like Z-Image-Turbo/Flux, use diffusion_model_path
                                        # Look for additional model files in same directory
                                        model_dir = os.path.dirname(model_path)
                                        model_name = os.path.basename(model_path)
                                        
                                        # Try to find additional model files
                                        clip_l_path = None
                                        t5xxl_path = None
                                        vae_path = None
                                        
                                        # Use CLI arguments if provided, download and cache if URL
                                        if args.llm_path:
                                            # Check if it's a Hugging Face model ID, URL, or local path
                                            if is_huggingface_model_id(args.llm_path):
                                                # Download from Hugging Face
                                                print(f"Attempting to download LLM model from Hugging Face: {args.llm_path}")
                                                cache_dir = get_model_cache_dir()
                                                clip_l_path = download_huggingface_model(args.llm_path, cache_dir, '.gguf')
                                                if clip_l_path:
                                                    print(f"Downloaded LLM model to: {clip_l_path}")
                                                else:
                                                    print(f"Warning: Failed to download LLM model from Hugging Face, will try as local path")
                                            elif args.llm_path.startswith('http://') or args.llm_path.startswith('https://'):
                                                cached = get_cached_model_path(args.llm_path)
                                                if cached:
                                                    clip_l_path = cached
                                                    print(f"Using cached LLM model: {clip_l_path}")
                                                else:
                                                    cache_dir = get_model_cache_dir()
                                                    clip_l_path = download_model(args.llm_path, cache_dir)
                                                    print(f"Downloaded LLM model to: {clip_l_path}")
                                            else:
                                                clip_l_path = args.llm_path
                                        if args.vae_path:
                                            # Check if it's a URL and download if needed
                                            if args.vae_path.startswith('http://') or args.vae_path.startswith('https://'):
                                                cached = get_cached_model_path(args.vae_path)
                                                if cached:
                                                    vae_path = cached
                                                    print(f"Using cached VAE model: {vae_path}")
                                                else:
                                                    cache_dir = get_model_cache_dir()
                                                    vae_path = download_model(args.vae_path, cache_dir)
                                                    print(f"Downloaded VAE model to: {vae_path}")
                                            else:
                                                vae_path = args.vae_path
                                        
                                        # Look for common file patterns only if CLI args not provided
                                        if not args.llm_path or not args.vae_path:
                                            for f in os.listdir(model_dir) if os.path.exists(model_dir) else []:
                                                if not args.llm_path and 'clip_l' in f.lower() and f.endswith(('.safetensors', '.bin')):
                                                    clip_l_path = os.path.join(model_dir, f)
                                                elif 't5xxl' in f.lower() and f.endswith(('.safetensors', '.bin')):
                                                    t5xxl_path = os.path.join(model_dir, f)
                                                elif not args.vae_path and f.endswith('.safetensors') and 'ae' in f.lower():
                                                    vae_path = os.path.join(model_dir, f)
                                        
                                        # Build kwargs based on available files
                                        sd_kwargs = {'diffusion_model_path': model_path}
                                        
                                        if clip_l_path:
                                            sd_kwargs['llm_path'] = clip_l_path
                                            print(f"DEBUG: Adding llm_path to sd_kwargs: {clip_l_path}")
                                        else:
                                            print(f"DEBUG: clip_l_path is None or empty, not adding to sd_kwargs")
                                            print(f"DEBUG: args.llm_path = {args.llm_path}")
                                        if args.vae_path:
                                            sd_kwargs['vae_path'] = vae_path
                                        elif vae_path:
                                            sd_kwargs['vae_path'] = vae_path
                                        if t5xxl_path:
                                            sd_kwargs['t5xxl_path'] = t5xxl_path
                                        
                                        # Add sd.cpp-specific options from CLI args
                                        if getattr(global_args, 'vae_tiling', False):
                                            # VAE tiling is handled internally in newer builds
                                            print(f"DEBUG: VAE tiling is handled internally in stable-diffusion-cpp-python")
                                        if getattr(global_args, 'clip_on_cpu', False):
                                            sd_kwargs['keep_clip_on_cpu'] = True
                                            print(f"DEBUG: Running CLIP on CPU to save VRAM (keep_clip_on_cpu=True)")
                                        
                                        # Use all available CPU cores for processing
                                        import psutil
                                        sd_kwargs['n_threads'] = psutil.cpu_count()
                                        print(f"DEBUG: Using {psutil.cpu_count()} CPU cores for sd.cpp")

                                        # Add generation parameters from CLI args
                                        # sd_kwargs['sample_method'] = args.image_sample_method  # Not valid for __init__
                                        # sd_kwargs['steps'] = args.image_steps  # Not valid for __init__
                                        
                                        # Define model_key for adding to manager
                                        model_key = f"image:{model_path}"
                                        sd_model = StableDiffusion(**sd_kwargs)
                                        multi_model_manager.add_model(model_key, sd_model)
                                        # Add alias for "image" 
                                        multi_model_manager.add_model("image", sd_model)
                                        
                                        print(f"Image model loaded successfully via sd.cpp: {original_model_name}")
                                    except ImportError as sd_error:
                                        print(f"stable-diffusion-cpp-python not installed: {sd_error}")
                                        print(f"Image model will load on first request")
                                    except Exception as sd_error:
                                        print(f"sd.cpp load error: {sd_error}")
                                        print(f"Image model will load on first request")
                    else:
                        print(f"Could not load GGUF image model: no valid model path")
                        
                except ImportError as e:
                    print(f"Warning: llama_cpp not installed: {e}")
                except Exception as e:
                    print(f"Warning: Failed to pre-load GGUF image model: {e}")
                try:
                    from llama_cpp import Llama
                    from llama_cpp import Llama
                    
                    # Download GGUF model if needed (similar to VulkanBackend)
                    model_path = None
                    if model_name.startswith('http://') or model_name.startswith('https://'):
                        cached_path = get_cached_model_path(model_name)
                        if cached_path:
                            model_path = cached_path
                            print(f"Using cached GGUF model: {model_path}")
                        else:
                            print(f"Downloading GGUF model: {model_name}")
                            cache_dir = get_model_cache_dir()
                            model_path = download_model(model_name, cache_dir)
                    elif os.path.isfile(model_name):
                        model_path = model_name
                        print(f"Loading local GGUF model: {model_path}")
                    else:
                        # Try to download from HuggingFace Hub
                        try:
                            from huggingface_hub import hf_hub_download, list_repo_files
                            parts = model_name.split('/')
                            if len(parts) >= 2:
                                repo_id = f"{parts[0]}/{parts[1]}"
                                files = list_repo_files(repo_id)
                                gguf_files = [f for f in files if f.endswith('.gguf')]
                                if not gguf_files:
                                    raise ValueError(f"No GGUF files found in {repo_id}")
                                filename = gguf_files[0]
                                model_path = hf_hub_download(repo_id=repo_id, filename=filename)
                                print(f"Downloaded GGUF model to: {model_path}")
                        except Exception as e:
                            print(f"Could not resolve GGUF model path: {e}")
                            model_path = None
                    
                    if model_path and os.path.isfile(model_path):
                        # Use the cached path for the model key
                        model_key = f"image:{model_path}"
                        
                        # Load with llama.cpp
                        n_gpu_layers = -1  # Load all layers to GPU
                        n_ctx = 2048
                        
                        llama_model = Llama(
                            model_path=model_path,
                            n_gpu_layers=n_gpu_layers,
                            n_ctx=n_ctx,
                            verbose=False,
                        )
                        multi_model_manager.add_model(model_key, llama_model)
                        print(f"GGUF image model loaded successfully: {model_name}")
                    else:
                        print(f"Could not load GGUF image model: no valid model path")
                        
                except ImportError as e:
                    print(f"Warning: llama_cpp not installed: {e}")
                except Exception as e:
                    print(f"Warning: Failed to pre-load GGUF image model: {e}")
            else:
                # Load diffusers image model (Stable Diffusion)
                try:
                    import torch
                    from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
                    
                    # Use model name directly for diffusers (model_path is only set in GGUF branch)
                    model_key = f"image:{model_name}"
                    print(f"Loading diffusers pipeline: {model_name}")
                    
                    # Try to load as Stable Diffusion XL first
                    try:
                        pipeline = StableDiffusionXLPipeline.from_pretrained(
                            model_name,
                            torch_dtype=torch.float32,
                            use_safetensors=True,
                        )
                    except Exception as e:
                        print(f"SDXL failed, trying generic pipeline: {e}")
                        pipeline = DiffusionPipeline.from_pretrained(
                            model_name,
                            torch_dtype=torch.float32,
                            use_safetensors=True,
                        )
                    
                    if torch.cuda.is_available():
                        pipeline = pipeline.to("cuda")
                        pipeline.enable_attention_slicing()
                    else:
                        pipeline = pipeline.to("cpu")
                    
                    multi_model_manager.add_model(model_key, pipeline)
                    # Add alias for "image"
                    multi_model_manager.add_model("image", pipeline)
                    
                    print(f"Image model loaded successfully: {model_name}")
                    
                except ImportError as e:
                    print(f"Warning: diffusers not installed: {e}")
                except Exception as e:
                    print(f"Warning: Failed to pre-load image model: {e}")
    
    # Register model aliases if specified
    if args.model_aliases:
        print(f"\nRegistering model aliases:")
        for alias, model in args.model_aliases:
            multi_model_manager.set_model_alias(alias, model)
            print(f"  {alias} -> {model}")
    
    # Start the server
    import uvicorn
    print(f"\nStarting server on http://{args.host}:{args.port}")
    print(f"API documentation available at http://{args.host}:{args.port}/docs")
    if model_manager.backend is not None:
        # Show actual backend being used
        actual_backend = model_manager.backend_type
        if hasattr(model_manager.backend, 'force_cuda') and model_manager.backend.force_cuda:
            actual_backend = "cuda (via llama-cpp-python)"
        print(f"Using backend: {actual_backend}")
    
    # Print available models
    models = multi_model_manager.list_models()
    print(f"Available models: {[m.id for m in models]}")
    
    # Run server with or without HTTPS
    if args.https:
        import ssl
        
        # Determine SSL context
        ssl_keyfile = None
        ssl_certfile = None
        
        if args.privkey and args.pubkey:
            # Use provided certificates
            ssl_keyfile = args.privkey
            ssl_certfile = args.pubkey
            print(f"Using HTTPS with custom certificates: {args.pubkey}")
        else:
            # Auto-generate self-signed certificate
            print("Generating self-signed HTTPS certificate...")
            import subprocess
            try:
                # Generate self-signed cert
                cert_path = "./cert.pem"
                key_path = "./key.pem"
                subprocess.run([
                    "openssl", "req", "-x509", "-newkey", "rsa:4096",
                    "-keyout", key_path, "-out", cert_path,
                    "-days", "365", "-nodes",
                    "-subj", "/CN=localhost"
                ], check=True, capture_output=True)
                ssl_keyfile = key_path
                ssl_certfile = cert_path
                print(f"Generated self-signed certificate: {cert_path}")
            except Exception as e:
                print(f"Warning: Could not generate certificate: {e}")
                print("Falling back to HTTP...")
                uvicorn.run(app, host=args.host, port=args.port)
                return
        
        # Run with HTTPS
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.load_cert_chain(ssl_certfile, ssl_keyfile)
        uvicorn.run(app, host=args.host, port=args.port, ssl=ssl_context)
    else:
        uvicorn.run(app, host=args.host, port=args.port)
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
