#!/usr/bin/env python3
"""
OpenAI-compatible API server for HuggingFace models (NVIDIA) and GGUF models (Vulkan).
Supports CUDA (NVIDIA) and Vulkan (AMD) GPU backends, memory-aware model loading,
streaming, and tool calling.
"""

import argparse
import asyncio
import json
import os
import re
import sys
import time
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, List, Optional, Union

import psutil
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, validator, field_validator
from pydantic_core import PydanticCustomError
from threading import Thread


# =============================================================================
# Backend Detection and Imports
# =============================================================================

def detect_available_backends():
    """Report which inference backends look usable in the current environment.

    Returns:
        dict: always ``{'cpu': True, ...}``. ``'nvidia': True`` is added when
        PyTorch imports and ``torch.cuda.is_available()`` is true, and
        ``'vulkan': True`` is added when ``llama_cpp`` can be imported
        (NOTE(review): importability alone does not prove the wheel was built
        with Vulkan support).
    """
    available = {'cpu': True}

    # PyTorch with a usable CUDA (or ROCm-as-CUDA) device.
    try:
        import torch
        cuda_ok = torch.cuda.is_available()
    except ImportError:
        cuda_ok = False
    if cuda_ok:
        available['nvidia'] = True

    # llama-cpp-python; the import itself is the probe.
    try:
        import llama_cpp  # noqa: F401
    except ImportError:
        llama_ok = False
    else:
        llama_ok = True
    if llama_ok:
        available['vulkan'] = True

    return available


# =============================================================================
# Flash Attention Detection (for NVIDIA backend)
# =============================================================================

def check_flash_attn_availability() -> bool:
    """Return True if the ``flash_attn`` package can be imported, else False."""
    try:
        import flash_attn  # noqa: F401  (import is the availability probe)
    except ImportError:
        return False
    else:
        return True


# =============================================================================
# Pydantic Models for API
# =============================================================================

class ToolFunction(BaseModel):
    """Function definition inside an OpenAI-style ``tools`` entry."""
    name: str
    description: Optional[str] = Field(default=None)
    # JSON-Schema-like parameter description, passed through as a plain dict.
    parameters: Optional[Dict] = Field(default=None)


class Tool(BaseModel):
    """One entry of a request's ``tools`` list; ``type`` defaults to ``"function"``."""
    type: str = Field(default="function")
    function: ToolFunction


class ChatMessage(BaseModel):
    """One message of a chat-completion conversation.

    ``content`` is normalised to a plain string (or ``None``) before field
    validation, so code consuming messages only has to handle text.
    """
    role: str
    content: Optional[Union[str, List[Dict]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Dict]] = None
    tool_call_id: Optional[str] = None

    @field_validator('content', mode='before')
    @classmethod
    def convert_content_array_to_string(cls, v):
        """Convert multipart content array to string for compatibility.

        - ``None`` and ``str`` values pass through unchanged.
        - A list (multipart format, e.g. from KiloCode:
          ``[{"type": "text", "text": "..."}, ...]``) is flattened, and the parts
          are joined with newlines:
          - text parts contribute their text (a ``null`` text is skipped);
          - other dict parts (``image_url`` etc.) become ``"[<type> content]"``;
          - non-dict items are ``str()``-ed.
        - Any other value is converted with ``str()``.
        """
        if v is None or isinstance(v, str):
            return v
        if not isinstance(v, list):
            return str(v)

        parts = []
        for item in v:
            if not isinstance(item, dict):
                parts.append(str(item))
            elif item.get('type') == 'text' and 'text' in item:
                text = item['text']
                # Request bodies are untrusted: a null or non-string "text" would
                # otherwise make str.join raise TypeError, which pydantic does not
                # convert into a validation error (so the client would get a 500).
                if text is not None:
                    parts.append(str(text))
            else:
                # Non-text parts (image_url, ...) are replaced by a placeholder.
                parts.append(f"[{item.get('type', 'unknown')} content]")
        return '\n'.join(parts)


class ChatCompletionRequest(BaseModel):
    """Request body for an OpenAI-compatible chat completion.

    Unknown fields are accepted (``extra="allow"``) so clients that send newer or
    vendor-specific parameters are not rejected with HTTP 422.
    """
    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    model_config = {"extra": "allow"}

    model: str
    messages: List[ChatMessage]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = "auto"
    # Optional OpenAI parameters accepted for client compatibility.
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    response_format: Optional[Dict] = None
    user: Optional[str] = None


class CompletionRequest(BaseModel):
    """Request body for an OpenAI-compatible (legacy) text completion.

    Unknown fields are accepted (``extra="allow"``) so clients that send newer or
    vendor-specific parameters are not rejected with HTTP 422.
    """
    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    model_config = {"extra": "allow"}

    model: str
    prompt: Union[str, List[str]]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    # Optional OpenAI parameters accepted for client compatibility.
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = None
    user: Optional[str] = None


class ModelInfo(BaseModel):
    """OpenAI-style ``model`` object describing one served model."""
    id: str
    object: str = Field(default="model")
    # Unix timestamp (seconds) taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = Field(default="huggingface")


class ModelList(BaseModel):
    """OpenAI-style ``list`` envelope wrapping ``ModelInfo`` entries."""
    object: str = Field(default="list")
    data: List[ModelInfo]


# =============================================================================
# Content Filtering Utility
# =============================================================================

def filter_malformed_content(text: str) -> str:
    """Strip stray SEARCH/REPLACE diff markers and chat-template artifacts from *text*.

    Removes:
      - ``<<<<<<< SEARCH ... =======`` spans and ``>>>>>>> REPLACE`` markers;
      - ``<<<<<<< SEARCH :start_line:N ...`` fragments;
      - the literal ``<button>Stop Generation</button>``;
      - ``<|assistant|>`` / ``</|assistant|>`` tokens.

    When anything was removed, runs of 3+ newlines are collapsed to a single blank
    line. When nothing matched, *text* is returned unchanged, so intentional blank
    lines (e.g. PEP 8's two blank lines between top-level Python definitions) survive.

    Args:
        text: model output; falsy values (``""``, ``None``) are returned as-is.

    Returns:
        The filtered string.
    """
    if not text:
        return text

    filtered = text

    # SEARCH half of a SEARCH/REPLACE block, up to and including the separator.
    filtered = re.sub(r'<<<<<<<\s+SEARCH.*?=======', '', filtered, flags=re.DOTALL)
    # Separator through REPLACE marker, for separators the previous pass left behind.
    # NOTE(review): a bare "=======" line (e.g. a Markdown setext heading underline)
    # followed later by a REPLACE marker will also drop the text in between.
    filtered = re.sub(r'=======.*?>>>>>>>\s+REPLACE', '', filtered, flags=re.DOTALL)
    filtered = re.sub(r'>>>>>>>\s+REPLACE', '', filtered)

    # Other malformed patterns seen in outputs.
    filtered = re.sub(r'<<<<<<<\s+SEARCH\s*:start_line:\d+[^<]*', '', filtered, flags=re.DOTALL)
    filtered = re.sub(r'<button>Stop Generation</button>', '', filtered)
    filtered = re.sub(r'\<\|assistant\|\>', '', filtered)
    filtered = re.sub(r'\</\|assistant\|\>', '', filtered)

    if filtered == text:
        # Nothing was removed: keep the content byte-for-byte.
        return text

    # Clean up the excessive newlines that removals leave behind. Single newlines and
    # other whitespace are kept - they may be meaningful content.
    return re.sub(r'\n{3,}', '\n\n', filtered)


# =============================================================================
# Tool Parsing
# =============================================================================

class ToolCallParser:
    """Parse model outputs to extract OpenAI-style tool calls.

    Every extracted call has the shape
    ``{"id": "call_<16 hex>", "type": "function",
    "function": {"name": <str>, "arguments": <str>}}``.
    """

    def __init__(self, tokenizer=None):
        # Kept for callers; the parsing logic in this class does not use it.
        self.tokenizer = tokenizer

    @staticmethod
    def _make_tool_call(name, arguments: str) -> Dict:
        """Build one tool-call entry with a fresh random ``call_`` id."""
        return {
            "id": f"call_{uuid.uuid4().hex[:16]}",
            "type": "function",
            "function": {
                "name": name,
                "arguments": arguments,
            },
        }

    @staticmethod
    def _serialize_arguments(arguments) -> str:
        """Return tool arguments as the string the OpenAI schema expects.

        Strings are passed through unchanged: they are either already-encoded JSON
        or a raw fallback, and re-encoding them would double-quote them. Anything
        else (dict, list, number, ...) is JSON-encoded.
        """
        if isinstance(arguments, str):
            return arguments
        return json.dumps(arguments)

    @staticmethod
    def _find_json_object_with_key(text: str, key: str) -> Optional[Dict]:
        """Return the first JSON object embedded in *text* that has *key* at top level.

        ``json.JSONDecoder.raw_decode`` is tried at every ``{`` in turn, so nested
        objects such as ``{"function_call": {"name": ...}}`` decode correctly.
        Worst case is quadratic in ``len(text)``.

        Returns:
            The decoded dict, or None if no such object is found.
        """
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                obj, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                obj = None
            if isinstance(obj, dict) and key in obj:
                return obj
            start = text.find('{', start + 1)
        return None

    def _parse_nested_xml_tool(self, xml_content: str) -> Optional[Dict]:
        """Parse the nested XML tool format ``<name>...</name><arguments>...</arguments>``.

        The arguments body is decoded as JSON when possible; otherwise it is
        converted with ``_xml_to_dict``. That returns the raw string when no tag
        pairs are present, which serves as the raw-text fallback.

        Returns:
            ``{"name": str, "arguments": <JSON value | dict | str>}``, or None when
            either the ``<name>`` or the ``<arguments>`` element is missing.
        """
        name_match = re.search(r'<name>\s*(.*?)\s*</name>', xml_content, re.DOTALL)
        if not name_match:
            return None
        args_match = re.search(r'<arguments>\s*(.*?)\s*</arguments>', xml_content, re.DOTALL)
        if not args_match:
            return None

        tool_name = name_match.group(1).strip()
        args_content = args_match.group(1).strip()

        try:
            return {"name": tool_name, "arguments": json.loads(args_content)}
        except json.JSONDecodeError:
            pass

        try:
            arguments = self._xml_to_dict(args_content)
        except RecursionError:
            # Pathologically deep nesting: fall back to the raw text.
            arguments = args_content
        return {"name": tool_name, "arguments": arguments}

    def _xml_to_dict(self, xml_content: str) -> Union[Dict, str]:
        """Convert simple nested XML into a dictionary.

        Each ``<tag>content</tag>`` pair found at this level becomes a key. Content
        that itself contains a tag is converted recursively; other content is kept as
        a (whitespace-trimmed) string. Tags repeated at the same level, such as
        several ``<file>`` siblings, are collected into a list in document order
        instead of overwriting one another.

        Returns:
            The dictionary, or *xml_content* unchanged when no tag pairs are found.
        """
        grouped: Dict[str, list] = {}
        for tag, content in re.findall(r'<(\w+)>\s*(.*?)\s*</\1>', xml_content, re.DOTALL):
            if re.search(r'<\w+>', content):
                value = self._xml_to_dict(content)
            else:
                value = content
            grouped.setdefault(tag, []).append(value)

        if not grouped:
            return xml_content
        return {
            tag: values[0] if len(values) == 1 else values
            for tag, values in grouped.items()
        }

    def _filter_malformed_content(self, text: str) -> str:
        """Filter out malformed SEARCH/REPLACE blocks - delegates to standalone function."""
        return filter_malformed_content(text)

    def _parse_tagged_call(self, body: str) -> Optional[Dict]:
        """Turn the inside of one ``<tool>``/``<function>`` tag into a tool-call entry.

        JSON (``{"name": ..., "arguments": ...}``) is tried first, then the nested
        XML format. Returns None when neither yields a name and arguments.
        """
        try:
            data = json.loads(body.strip())
        except json.JSONDecodeError:
            data = None
        # json.loads may also return a str/list/number; only a dict can carry the
        # fields, and indexing the others by "name" would raise TypeError.
        if isinstance(data, dict) and 'name' in data and 'arguments' in data:
            return self._make_tool_call(data["name"], self._serialize_arguments(data["arguments"]))

        try:
            data = self._parse_nested_xml_tool(body)
        except Exception:
            # Best effort: an unparseable tag is skipped rather than failing the request.
            return None
        if data and 'name' in data and 'arguments' in data:
            return self._make_tool_call(data["name"], self._serialize_arguments(data["arguments"]))
        return None

    def extract_tool_calls(self, text: str, available_tools: Optional[List["Tool"]] = None) -> Optional[List[Dict]]:
        """Extract tool calls from model output.

        The text is first passed through ``filter_malformed_content``. Recognised forms:

        1. ``<tool>`` / ``<function>`` tags containing JSON, or nested
           ``<name>``/``<arguments>`` XML;
        2. a JSON object with a ``"function_call"`` key (first one only);
        3. ``<TOOL_NAME>...</TOOL_NAME>`` for each tool in *available_tools*. The
           body is JSON-normalised when it parses; otherwise it is used verbatim as
           the arguments string.

        Args:
            text: raw model output.
            available_tools: tools offered in the request; ``None`` is treated as empty.

        Returns:
            A list of tool-call dicts, or None when nothing was found.
        """
        text = self._filter_malformed_content(text)
        tool_calls: List[Dict] = []

        # Format 1: <tool> or <function> tags (JSON, then nested XML).
        for body in re.findall(r'<(?:tool|function)>(.*?)</(?:tool|function)>', text, re.DOTALL):
            call = self._parse_tagged_call(body)
            if call is not None:
                tool_calls.append(call)

        # Format 2: JSON object with a function_call key.
        if "function_call" in text:
            data = self._find_json_object_with_key(text, "function_call")
            fc = data["function_call"] if data is not None else None
            if isinstance(fc, dict):
                tool_calls.append(self._make_tool_call(
                    fc.get("name", ""),
                    self._serialize_arguments(fc.get("arguments", "{}")),
                ))

        # Format 3: <tool_name>...</tool_name> for each declared tool. Names are
        # escaped so regex metacharacters in a tool name cannot break the pattern.
        for tool in available_tools or []:
            tool_name = tool.function.name
            escaped = re.escape(tool_name)
            for body in re.findall(rf'<{escaped}>(.*?)</{escaped}>', text, re.DOTALL):
                raw = body.strip()
                try:
                    arguments = json.dumps(json.loads(raw))
                except json.JSONDecodeError:
                    # Use the raw text as the arguments string.
                    arguments = raw
                tool_calls.append(self._make_tool_call(tool_name, arguments))

        return tool_calls or None


def format_tools_for_prompt(tools: List[Tool], messages: List[ChatMessage]) -> List[ChatMessage]:
    """Inject tool descriptions and calling instructions into the system prompt.

    Each tool is described by name plus, when present, its description and its
    parameters (pretty-printed JSON). Calling instructions follow the list.

    The combined text is prepended to the first ``system`` message, separated by a
    blank line. If there is no system message, a new one is inserted at position 0.
    The input list is not mutated; a new list is returned. When *tools* is empty,
    *messages* itself is returned.
    """
    if not tools:
        return messages

    descriptions = []
    for tool in tools:
        spec = tool.function
        lines = [f"Tool: {spec.name}"]
        if spec.description:
            lines.append(f"Description: {spec.description}")
        if spec.parameters:
            lines.append(f"Parameters: {json.dumps(spec.parameters, indent=2)}")
        descriptions.append("\n".join(lines))

    tools_text = "".join([
        "You have access to the following tools:\n\n",
        "\n\n".join(descriptions),
        "\n\nIMPORTANT: When you need to use a tool, you MUST format your response EXACTLY as:\n",
        '<tool>{"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}</tool>',
        "\n\nRules:\n",
        "1. The content inside <tool> tags must be valid JSON\n",
        "2. Do NOT use nested XML tags like <name> or <arguments> - use JSON format only\n",
        "3. The 'name' field must match one of the available tool names exactly\n",
        "4. The 'arguments' field must be a JSON object with the required parameters\n",
        "\nExample:\n",
        'User: Read the file example.txt\n',
        'Assistant: <tool>{"name": "read_file", "arguments": {"files": [{"path": "example.txt"}]}}</tool>',
    ])

    updated = list(messages)
    system_index = next(
        (idx for idx, message in enumerate(updated) if message.role == "system"),
        None,
    )
    if system_index is None:
        updated.insert(0, ChatMessage(role="system", content=tools_text))
    else:
        existing = updated[system_index].content or ''
        updated[system_index] = ChatMessage(
            role="system",
            content=f"{tools_text}\n\n{existing}",
        )
    return updated


# =============================================================================
# Abstract Model Backend
# =============================================================================

class ModelBackend(ABC):
    """Abstract interface implemented by every inference backend.

    A backend holds one loaded model:
      - ``load_model`` prepares it;
      - ``format_messages`` renders a chat history into a prompt string;
      - ``generate`` / ``generate_stream`` produce text from a prompt;
      - ``get_model_name`` reports what is loaded;
      - ``cleanup`` releases resources.
    """

    @abstractmethod
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load *model_name*; backend-specific options arrive through ``kwargs``."""

    @abstractmethod
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Return the complete (non-streaming) generation for *prompt*."""

    @abstractmethod
    def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                        temperature: float = 0.7, top_p: float = 1.0,
                        stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Produce the generation for *prompt* incrementally.

        NOTE(review): annotated as an ``AsyncGenerator`` but declared with plain
        ``def``; confirm whether subclasses return sync or async generators.
        """

    @abstractmethod
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Render *messages* into a single prompt string for this backend's model."""

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the name of the currently loaded model."""

    @abstractmethod
    def cleanup(self) -> None:
        """Release any resources held by the backend."""


# =============================================================================
# NVIDIA/HuggingFace Backend
# =============================================================================

class NvidiaBackend(ModelBackend):
    """Backend for NVIDIA GPUs using HuggingFace Transformers."""
    
    def __init__(self):
        """Create an unloaded backend; all model state starts as None/False."""
        self.model = None
        self.tokenizer = None
        self.model_name = None
        self.device = None  # "cuda" or "cpu" once _detect_device has run
        self.use_flash_attn = False
        self.flash_attn_available = False
        # CPU RAM budget in GB read by _get_gpu_memory_map. load_model sets it from
        # its manual_ram_gb kwarg. Initialising it here means that method falls back
        # to auto-detection instead of raising AttributeError before a load.
        self._pending_ram_gb = None
        
    def check_flash_attn_support(self) -> None:
        """Record whether flash-attn is installed and report the effective setting.

        Always updates ``self.flash_attn_available``. If Flash Attention was
        requested (``self.use_flash_attn``) but the package is missing, prints
        install guidance and turns ``self.use_flash_attn`` off.
        """
        self.flash_attn_available = check_flash_attn_availability()
        if not self.use_flash_attn:
            return
        if self.flash_attn_available:
            print("Flash Attention 2: Available and enabled")
            return
        for message in (
            "Warning: Flash Attention 2 requested but not installed",
            "Install with: pip install flash-attn --no-build-isolation",
            "Falling back to standard attention",
        ):
            print(message)
        self.use_flash_attn = False
    
    def _detect_device(self) -> str:
        """Return ``"cuda"`` when PyTorch sees a GPU, otherwise ``"cpu"``.

        ROCm builds of PyTorch also expose GPUs through ``torch.cuda``, so both
        CUDA and HIP report ``"cuda"``. Only the printed message differs.
        """
        import torch
        if not torch.cuda.is_available():
            print("No GPU detected, using CPU")
            return "cpu"

        hip_version = getattr(torch.version, 'hip', None)
        if hip_version is not None:
            print(f"ROCm/HIP detected: {hip_version}")
        else:
            print(f"CUDA detected: {torch.version.cuda}")
        return "cuda"
    
    def _get_available_vram(self) -> int:
        """Return total VRAM in bytes, summed over all visible CUDA devices.

        Note that this is total device capacity (``total_memory``), not currently
        free memory. Returns 0 when no GPU is available or the query fails; a
        failure also prints a warning.
        """
        import torch
        if not torch.cuda.is_available():
            return 0

        try:
            return sum(
                torch.cuda.get_device_properties(idx).total_memory
                for idx in range(torch.cuda.device_count())
            )
        except Exception as e:
            print(f"Warning: Could not detect VRAM: {e}")
            return 0
    
    def _estimate_model_size(self, model_name: str) -> Optional[int]:
        """Roughly estimate the model's weight size in bytes from its HF config.

        Uses an explicit parameter count from the config (``num_parameters`` or
        ``n_params``) when present. Otherwise it approximates
        ``vocab_size * hidden_size + num_hidden_layers * 4 * hidden_size**2``,
        with ``vocab_size`` defaulting to 50000 when absent. The count is
        multiplied by 2 bytes, i.e. 16-bit weights are assumed.

        Returns:
            Estimated size in bytes, or None when the config cannot be loaded or
            exposes none of the fields above. A failure also prints a warning.
        """
        from transformers import AutoConfig
        try:
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

            if hasattr(config, 'num_parameters'):
                num_params = config.num_parameters
            elif hasattr(config, 'n_params'):
                num_params = config.n_params
            elif hasattr(config, 'num_hidden_layers') and hasattr(config, 'hidden_size'):
                layers = config.num_hidden_layers
                hidden = config.hidden_size
                vocab_size = getattr(config, 'vocab_size', 50000)
                # Embedding matrix plus 4*h^2 per layer (presumably the attention
                # projections). NOTE(review): MLP weights are not counted, so this
                # likely underestimates real models.
                num_params = (vocab_size * hidden) + (layers * 4 * hidden * hidden)
            else:
                return None

            # Assume float16 (2 bytes per parameter)
            return num_params * 2
        except Exception as e:
            print(f"Warning: Could not estimate model size: {e}")
            return None
    
    def _get_gpu_memory_map(self) -> Dict:
        """Build an Accelerate ``max_memory`` map: 93% of each GPU, then CPU RAM.

        Returns:
            Dict mapping each CUDA device index to 93% of its total memory in bytes,
            plus ``'cpu'`` mapped to either:
              - the user-specified budget (``_pending_ram_gb`` GB, decimal), or
              - currently available system RAM minus 4 GB (floored at 0).
            No ``'disk'`` entry is produced here; disk offload is configured
            separately through ``offload_folder``.
        """
        import torch
        max_memory = {}

        # GPU memory: 93% of each device's total VRAM.
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                total_vram = torch.cuda.get_device_properties(i).total_memory
                # Leave 7% headroom for CUDA context / allocator overhead.
                usable_vram = int(total_vram * 0.93)
                max_memory[i] = usable_vram
                print(f"  GPU {i}: {total_vram / 1e9:.1f}GB total, {usable_vram / 1e9:.1f}GB usable")

        # CPU memory: manual limit if one was given, otherwise auto-detect. The
        # attribute is normally set by load_model; getattr avoids an AttributeError
        # if this runs first.
        manual_ram_gb = getattr(self, '_pending_ram_gb', None)
        if manual_ram_gb:
            max_memory['cpu'] = int(manual_ram_gb * 1e9)
            print(f"  CPU: {manual_ram_gb}GB (user specified)")
        else:
            import psutil
            available_ram = psutil.virtual_memory().available
            usable_ram = max(0, available_ram - int(4e9))  # Leave 4GB for OS
            max_memory['cpu'] = usable_ram
            print(f"  CPU: {usable_ram / 1e9:.1f}GB (auto-detected, 4GB reserved for system)")

        return max_memory
    
    @staticmethod
    def _is_oom_error(error: BaseException) -> bool:
        """Heuristically decide whether *error* is an out-of-memory failure.

        Returns True for:
          - ``torch.cuda.OutOfMemoryError``;
          - messages containing "out of memory" (the phrase PyTorch's CUDA/HIP
            allocator errors use);
          - cuBLAS ``ALLOC_FAILED`` statuses;
          - "oom" as a whole word only, so a model name such as "bloom" in an
            unrelated error is not misread as OOM.

        Other CUDA errors (device-side asserts, device mismatches, ...) are not OOM
        and must propagate instead of triggering a lower-VRAM retry.
        """
        import torch
        oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
        if oom_type is not None and isinstance(error, oom_type):
            return True
        message = str(error).lower()
        return (
            "out of memory" in message
            or "alloc_failed" in message
            or re.search(r"\boom\b", message) is not None
        )

    @staticmethod
    def _load_causal_lm(model_name: str, kwargs: dict, device: str):
        """Call ``from_pretrained``; move to CPU explicitly when no device_map is used."""
        from transformers import AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
        if device == "cpu" and kwargs.get('device_map') is None:
            model = model.to(device)
        return model

    def _try_load_model(self, model_name: str, load_kwargs: dict, device: str) -> Optional[object]:
        """Try to load the model, returning None on OOM so the caller can retry smaller.

        If loading raises a quantization-related TypeError and *load_kwargs*
        contains ``load_in_4bit``/``load_in_8bit``, the load is retried once
        without those keys.

        Args:
            model_name: model id or path for ``AutoModelForCausalLM.from_pretrained``.
            load_kwargs: keyword arguments for ``from_pretrained`` (not mutated).
            device: ``"cuda"`` or ``"cpu"``.

        Returns:
            The loaded model, or None when loading (or the retry) ran out of memory.

        Raises:
            RuntimeError: re-raised when it is not an out-of-memory error.
            TypeError: re-raised when unrelated to quantization kwargs. If the
                retry without quantization also raises TypeError, the original
                error is re-raised.
        """
        import torch

        try:
            return self._load_causal_lm(model_name, load_kwargs, device)
        except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
            if self._is_oom_error(e):
                return None
            raise
        except TypeError as e:
            error_msg = str(e).lower()
            quantization_error = (
                "load_in_4bit" in error_msg
                or "load_in_8bit" in error_msg
                or "unexpected keyword argument" in error_msg
            )
            has_quant_args = 'load_in_4bit' in load_kwargs or 'load_in_8bit' in load_kwargs
            if not (quantization_error and has_quant_args):
                raise

            print("Warning: Model does not support bitsandbytes quantization (load_in_4bit/load_in_8bit)")
            print("Retrying without quantization...")
            retry_kwargs = {
                key: value for key, value in load_kwargs.items()
                if key not in ('load_in_4bit', 'load_in_8bit')
            }
            try:
                model = self._load_causal_lm(model_name, retry_kwargs, device)
            except (RuntimeError, torch.cuda.OutOfMemoryError) as e2:
                if self._is_oom_error(e2):
                    return None
                raise
            except TypeError:
                # Still failing with TypeError: surface the original error.
                raise e
            print("Model loaded successfully without quantization")
            return model
    
    def _is_moe_model(self, model_name: str) -> bool:
        """Guess from the name alone whether *model_name* is a Mixture-of-Experts model.

        MoE models get more conservative VRAM limits elsewhere. The check is a
        case-insensitive substring match against a fixed list of markers.
        """
        lowered = model_name.lower()
        for marker in ('moe', 'mixtral', 'qwen3_5_moe', 'qwen3.5_moe', 'expert', 'a3b'):
            if marker in lowered:
                return True
        return False
    
    def _get_vram_percentages_for_strategy(self, strategy: str, is_moe: bool, total_vram_gb: float) -> list:
        """Return the descending VRAM-fraction ladder to try for *strategy*.

        ``conservative``, ``balanced``, ``aggressive`` and ``sequential`` each have
        a fixed ladder, with a lower-starting variant for MoE models. Any other
        strategy value behaves as ``auto``, which picks a ladder from
        *total_vram_gb* (<3 GB, <=8 GB, larger) and, for large GPUs, *is_moe*.
        Every ladder ends with 0.0. A fresh list is returned on each call.
        """
        # strategy -> (banner, MoE ladder, dense ladder)
        fixed_ladders = {
            "conservative": (
                "  Using conservative offload strategy - minimal VRAM usage for maximum stability",
                [0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0],
                [0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0],
            ),
            "balanced": (
                "  Using balanced offload strategy - good performance with reasonable stability",
                [0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0],
                [0.85, 0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0],
            ),
            "aggressive": (
                "  Using aggressive offload strategy - maximize VRAM usage for performance",
                [0.85, 0.80, 0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0],
                [0.95, 0.90, 0.85, 0.80, 0.75, 0.70, 0.65, 0.50, 0.40, 0.30, 0.20, 0.0],
            ),
            # Fine-grained steps (2% increments near the top) for precise memory management.
            "sequential": (
                "  Using sequential offload strategy - fine-grained incremental VRAM reduction",
                [0.80, 0.78, 0.76, 0.74, 0.72, 0.70, 0.68, 0.66, 0.64, 0.62, 0.60, 0.55, 0.50,
                 0.45, 0.40, 0.35, 0.30, 0.25, 0.20, 0.0],
                [0.93, 0.91, 0.89, 0.87, 0.85, 0.83, 0.81, 0.79, 0.77, 0.75, 0.73, 0.71, 0.69,
                 0.67, 0.65, 0.60, 0.55, 0.50, 0.45, 0.40, 0.35, 0.30, 0.20, 0.0],
            ),
        }

        if strategy in fixed_ladders:
            banner, moe_ladder, dense_ladder = fixed_ladders[strategy]
            print(banner)
            return moe_ladder if is_moe else dense_ladder

        # "auto": size the starting point by total VRAM.
        if total_vram_gb < 3:
            print(f"  Detected small GPU ({total_vram_gb:.1f}GB), using aggressive VRAM usage (99% start)")
            return [0.99, 0.95, 0.90, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
        if total_vram_gb <= 8:
            print(f"  Detected medium GPU ({total_vram_gb:.1f}GB), using high VRAM usage (96% start)")
            return [0.96, 0.90, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
        if is_moe:
            print(f"  Detected large GPU ({total_vram_gb:.1f}GB), using MoE-safe VRAM usage (80% start)")
            return [0.80, 0.75, 0.70, 0.65, 0.60, 0.50, 0.40, 0.30, 0.20, 0.0]
        print(f"  Detected large GPU ({total_vram_gb:.1f}GB), using conservative VRAM usage (93% start)")
        return [0.93, 0.85, 0.75, 0.65, 0.50, 0.35, 0.20, 0.0]
    
    def _get_vram_percentages_for_gpu(self, model_name: str = "", strategy: str = "auto", max_gpu_percent: Optional[float] = None) -> list:
        """Get the descending VRAM-fraction ladder to try when loading *model_name*.

        Args:
            model_name: used only to guess whether the model is MoE.
            strategy: offload strategy name passed to ``_get_vram_percentages_for_strategy``.
            max_gpu_percent: optional user cap in percent. It is clamped to 5-100
                and overrides *strategy*. Its ladder starts at the cap and steps
                down by 5 points above 30%, by 3 points above 15%, and by 2 points
                below that, ending with 0.0.

        Returns:
            List of fractions (0.0-1.0), ending with 0.0. Returns ``[0.0]`` when
            CUDA is unavailable.
        """
        import torch

        if not torch.cuda.is_available():
            return [0.0]  # CPU only

        if max_gpu_percent is not None:
            # Clamp to valid range (5-100%)
            max_pct = max(0.05, min(1.0, max_gpu_percent / 100.0))
            print(f"  Using custom max GPU percent: {max_pct*100:.0f}%")
            # Always try the requested cap itself first, even at the 5% floor.
            steps = []
            current = max_pct
            while True:
                steps.append(current)
                if current > 0.3:
                    step = 0.05
                elif current > 0.15:
                    step = 0.03
                else:
                    step = 0.02
                # Round to cancel binary floating-point drift, so the threshold
                # checks above act on the intended decimal values.
                current = round(current - step, 4)
                if current <= 0.05:
                    break
            steps.append(0.0)
            return steps

        # Total VRAM summed across ALL visible GPUs (not just the first).
        # NOTE(review): the size buckets in the strategy helper therefore see the
        # aggregate, e.g. 2x4GB counts as an 8GB "medium" GPU - confirm intended.
        total_vram_gb = 0
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            total_vram_gb += props.total_memory / 1e9

        # MoE models need more headroom during generation.
        is_moe = self._is_moe_model(model_name)
        if is_moe:
            print(f"  Detected MoE model, using extra conservative VRAM limits for generation headroom")

        return self._get_vram_percentages_for_strategy(strategy, is_moe, total_vram_gb)
    
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load the model using HuggingFace Transformers with automatic OOM handling."""
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        
        offload_dir = kwargs.get('offload_dir')
        load_in_4bit = kwargs.get('load_in_4bit', False)
        load_in_8bit = kwargs.get('load_in_8bit', False)
        manual_ram_gb = kwargs.get('manual_ram_gb')
        flash_attn = kwargs.get('flash_attn', False)
        offload_strategy = kwargs.get('offload_strategy', 'auto')
        max_gpu_percent = kwargs.get('max_gpu_percent', None)
        
        # Store RAM limit for use in _get_gpu_memory_map
        self._pending_ram_gb = manual_ram_gb
        
        print(f"Loading HuggingFace model: {model_name}")
        
        self.use_flash_attn = flash_attn
        self.check_flash_attn_support()
        
        self.device = self._detect_device()
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side="left"
        )
        
        # Set pad token if not present
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Prepare model loading arguments
        load_kwargs = {'trust_remote_code': True}
        
        # Check if model supports quantization
        if load_in_4bit or load_in_8bit:
            # Qwen3.5-A3B/MoE models don't support bitsandbytes quantization
            if 'qwen3.5' in model_name.lower() and ('a3b' in model_name.lower() or 'moe' in model_name.lower()):
                print(f"Warning: {model_name} does not support bitsandbytes quantization (load_in_4bit/load_in_8bit)")
                print("Quantization disabled for this model")
            else:
                try:
                    import bitsandbytes as bnb
                    print(f"Using {4 if load_in_4bit else 8}-bit quantization")
                    load_kwargs['load_in_4bit'] = load_in_4bit
                    load_kwargs['load_in_8bit'] = load_in_8bit
                except ImportError:
                    print("Warning: bitsandbytes not installed. Quantization disabled.")
        
        # Set dtype
        if self.device == "cuda":
            load_kwargs['torch_dtype'] = torch.float16
        else:
            load_kwargs['torch_dtype'] = torch.float32
        
        # Add offload folder if specified (disk offloading is last resort)
        if offload_dir:
            os.makedirs(offload_dir, exist_ok=True)
            load_kwargs['offload_folder'] = offload_dir
        
        # Add Flash Attention 2 if enabled
        if self.use_flash_attn and self.flash_attn_available:
            load_kwargs['attn_implementation'] = "flash_attention_2"
            print("Using Flash Attention 2")
        
        # Try loading with automatic fallback on OOM
        model = None
        vram_percentages = self._get_vram_percentages_for_gpu(model_name, offload_strategy, max_gpu_percent)
        first_vram_pct = vram_percentages[0] if vram_percentages else 0.93
        
        for vram_pct in vram_percentages:
            if self.device != "cuda":
                # CPU-only mode
                load_kwargs['device_map'] = None
                print("Loading model in CPU-only mode...")
                model = self._try_load_model(model_name, load_kwargs, self.device)
                if model is not None:
                    break
            
            # GPU mode with varying VRAM limits
            if vram_pct > 0:
                max_memory = self._get_gpu_memory_map_with_limit(vram_pct)
                load_kwargs['max_memory'] = max_memory
                load_kwargs['device_map'] = 'auto'
                print(f"\nTrying with GPU limit: {vram_pct*100:.0f}% VRAM")
                if offload_dir:
                    print(f"  Disk offload directory: {offload_dir}")
                
                model = self._try_load_model(model_name, load_kwargs, self.device)
                
                if model is not None:
                    print(f"  ✓ Model loaded successfully with {vram_pct*100:.0f}% GPU VRAM limit")
                    if vram_pct < first_vram_pct:
                        print(f"  (Reduced from {first_vram_pct*100:.0f}% due to memory constraints)")
                    break
                else:
                    print(f"  ✗ Out of memory with {vram_pct*100:.0f}% GPU VRAM, trying lower limit...")
                    # Clear CUDA cache before retry
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            else:
                # Last resort: CPU-only mode with offloading
                print("\nFalling back to CPU-only mode (no GPU layers)...")
                load_kwargs['max_memory'] = {0: 0, 'cpu': int((manual_ram_gb or 48) * 1e9)}
                load_kwargs['device_map'] = 'auto'
                model = self._try_load_model(model_name, load_kwargs, "cpu")
                if model is not None:
                    print("  ✓ Model loaded successfully on CPU")
                    break
        
        if model is None:
            raise RuntimeError("Failed to load model: Out of memory even with minimum GPU usage")
        
        self.model = model
        self.model.eval()
        self.model_name = model_name
        
        print(f"\nModel loaded successfully")
        print(f"Model device: {next(self.model.parameters()).device}")
    
    def _get_gpu_memory_map_with_limit(self, vram_fraction: float) -> Dict:
        """Get max_memory dict with specified VRAM fraction limit."""
        import torch
        max_memory = {}
        
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                total_vram = props.total_memory
                usable_vram = int(total_vram * vram_fraction)
                max_memory[i] = usable_vram
        
        # CPU memory
        manual_ram_gb = getattr(self, '_pending_ram_gb', None)
        if manual_ram_gb:
            max_memory['cpu'] = int(manual_ram_gb * 1e9)
        else:
            import psutil
            available_ram = psutil.virtual_memory().available
            usable_ram = max(0, available_ram - int(4e9))
            max_memory['cpu'] = usable_ram
        
        return max_memory
    
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Render chat messages as a plain-text transcript prompt.

        Each message becomes a ``"Role: content"`` paragraph. Paragraphs are
        joined by a blank line, and a trailing ``"Assistant:"`` cue asks the
        model to respond.

        - Assistant tool calls are appended to that turn as
          ``<tool>{"name": ..., "arguments": ...}</tool>`` markup.
        - Tool results render as ``"Tool (<name>): <content>"``.
        - Messages with any other role are skipped.
        - ``None`` content renders as empty text rather than the string "None".
        """

        def tool_call_markup(tc) -> Optional[str]:
            # Emit valid JSON. The name is JSON-escaped. Arguments are normally
            # already a JSON string and are embedded verbatim; a dict is
            # encoded rather than interpolated through Python's repr.
            if not isinstance(tc, dict) or not tc.get("function"):
                return None
            func = tc["function"]
            arguments = func.get("arguments")
            if arguments is None:
                arguments = "{}"
            elif not isinstance(arguments, str):
                arguments = json.dumps(arguments)
            name = json.dumps(func.get("name") or "")
            return f'<tool>{{"name": {name}, "arguments": {arguments}}}</tool>'

        formatted = []
        for msg in messages:
            content = msg.content or ""
            if msg.role == "system":
                formatted.append(f"System: {content}")
            elif msg.role == "user":
                formatted.append(f"User: {content}")
            elif msg.role == "assistant":
                for tc in msg.tool_calls or []:
                    markup = tool_call_markup(tc)
                    if markup:
                        content += "\n" + markup
                formatted.append(f"Assistant: {content}")
            elif msg.role == "tool":
                formatted.append(f"Tool ({msg.name}): {content}")

        formatted.append("Assistant:")
        return "\n\n".join(formatted)
    
    def _validate_params(self, temperature: float, top_p: float) -> tuple:
        """Validate generation parameters."""
        if temperature <= 0:
            temperature = 1.0
            do_sample = False
        else:
            temperature = max(0.01, min(temperature, 2.0))
            do_sample = True
        top_p = max(0.0, min(top_p, 1.0))
        return temperature, top_p, do_sample
    
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt`` (non-streaming).

        Args:
            prompt: Fully formatted prompt text.
            max_tokens: Maximum number of new tokens; defaults to 512.
            temperature: Sampling temperature. ``<= 0`` selects greedy decoding
                (see ``_validate_params``).
            top_p: Nucleus-sampling mass, clamped to [0, 1].
            stop: Optional stop sequences. Generation halts once one appears in
                the generated text, and the result is cut just before it.

        Returns:
            The decoded completion, without the prompt. On an out-of-memory
            error the call is retried once with half the token budget. If the
            retry also fails, a bracketed error message is returned instead.

        Raises:
            RuntimeError: Failures from ``model.generate`` that are not
                out-of-memory conditions.
        """
        import torch
        from transformers import (LogitsProcessor, LogitsProcessorList,
                                  StoppingCriteria, StoppingCriteriaList)

        class InvalidLogitsProcessor(LogitsProcessor):
            """Replace NaN/+inf logits with finite values so sampling cannot crash.

            ``-inf`` is deliberately left alone: it is how tokens are masked.
            Turning it into a large positive value would force masked tokens.
            """

            def __call__(self, input_ids, scores):
                # Keep the substitute inside the dtype's range (1e9 overflows fp16).
                big = min(1e9, torch.finfo(scores.dtype).max)
                scores = torch.where(torch.isnan(scores), torch.tensor(-big, dtype=scores.dtype, device=scores.device), scores)
                scores = torch.where(torch.isposinf(scores), torch.tensor(big, dtype=scores.dtype, device=scores.device), scores)
                return scores

        tokenizer = self.tokenizer
        inputs = tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]
        stop_sequences = [s for s in (stop or []) if s]

        class StopOnSequence(StoppingCriteria):
            """Stop once a stop sequence appears in the recently *generated* tokens.

            Only tokens after the prompt are decoded, so a stop string that
            occurs in the prompt cannot end generation immediately.
            """

            def __call__(self, input_ids, scores, **kwargs):
                start = max(prompt_len, input_ids.shape[1] - 20)
                tail = tokenizer.decode(input_ids[0][start:], skip_special_tokens=True)
                return any(seq in tail for seq in stop_sequences)

        if max_tokens is None:
            max_tokens = 512

        temperature, top_p, do_sample = self._validate_params(temperature, top_p)

        def run(new_tokens: int) -> str:
            """Run one generate() call; return the completion cut at the first stop."""
            extra = {}
            if stop_sequences:
                extra["stopping_criteria"] = StoppingCriteriaList([StopOnSequence()])
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=new_tokens,
                    temperature=temperature if do_sample else None,
                    top_p=top_p if do_sample else None,
                    do_sample=do_sample,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
                    **extra,
                )
            text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
            # OpenAI semantics: the stop sequence itself is not returned.
            hits = [i for i in (text.find(s) for s in stop_sequences) if i != -1]
            return text[:min(hits)] if hits else text

        try:
            return run(max_tokens)
        except RuntimeError as e:
            # torch.cuda.OutOfMemoryError subclasses RuntimeError and its message
            # contains "out of memory". Other CUDA errors are not memory
            # problems, so they propagate.
            if "out of memory" not in str(e).lower():
                raise

        # Retry outside the except block. While the exception is alive, its
        # traceback pins the failed attempt's tensors, and empty_cache() cannot
        # free them.
        print("Warning: CUDA OOM during generation. Clearing cache and retrying with reduced tokens...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        try:
            return run(max(1, max_tokens // 2))
        except Exception as e2:
            print(f"Error: Generation failed even with reduced tokens: {e2}")
            return "[Error: Out of memory during generation. Try reducing --max-gpu-percent or using a smaller model.]"
    
    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Stream a completion for ``prompt`` piece by piece.

        ``model.generate`` runs on a worker thread that feeds a
        ``TextIteratorStreamer``. Each piece is fetched through the event
        loop's default executor, so waiting for tokens does not block the loop.

        Args:
            prompt: Fully formatted prompt text.
            max_tokens: Maximum number of new tokens; defaults to 512.
            temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
            top_p: Nucleus-sampling mass, clamped to [0, 1].
            stop: Optional stop sequences. Only generated text is checked.
                Output is cut just before the first match, and generation stops.

        Yields:
            Non-empty text fragments. When stop sequences are given, a short
            tail is held back until it can no longer start a stop sequence. If
            generation fails, a bracketed warning/error fragment is yielded last.
        """
        import threading

        import torch
        from transformers import (LogitsProcessor, LogitsProcessorList,
                                  StoppingCriteria, StoppingCriteriaList,
                                  TextIteratorStreamer)

        class InvalidLogitsProcessor(LogitsProcessor):
            """Replace NaN/+inf logits with finite values so sampling cannot crash.

            ``-inf`` is deliberately left alone: it is how tokens are masked.
            Turning it into a large positive value would force masked tokens.
            """

            def __call__(self, input_ids, scores):
                # Keep the substitute inside the dtype's range (1e9 overflows fp16).
                big = min(1e9, torch.finfo(scores.dtype).max)
                scores = torch.where(torch.isnan(scores), torch.tensor(-big, dtype=scores.dtype, device=scores.device), scores)
                scores = torch.where(torch.isposinf(scores), torch.tensor(big, dtype=scores.dtype, device=scores.device), scores)
                return scores

        tokenizer = self.tokenizer
        inputs = tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        if max_tokens is None:
            max_tokens = 512

        temperature, top_p, do_sample = self._validate_params(temperature, top_p)
        stop_sequences = [s for s in (stop or []) if s]

        # Set when a stop sequence is seen or the consumer goes away, so the
        # worker thread stops instead of generating up to max_tokens.
        cancelled = threading.Event()

        class StopGeneration(StoppingCriteria):
            """Stop on cancellation or on a stop sequence in the *generated* tokens."""

            def __call__(self, input_ids, scores, **kwargs):
                if cancelled.is_set():
                    return True
                if not stop_sequences:
                    return False
                # Start at prompt_len so stop strings inside the prompt are ignored.
                start = max(prompt_len, input_ids.shape[1] - 20)
                tail = tokenizer.decode(input_ids[0][start:], skip_special_tokens=True)
                return any(seq in tail for seq in stop_sequences)

        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_tokens,
            # Same as generate(): sampling knobs only apply when sampling.
            "temperature": temperature if do_sample else None,
            "top_p": top_p if do_sample else None,
            "do_sample": do_sample,
            "streamer": streamer,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "logits_processor": LogitsProcessorList([InvalidLogitsProcessor()]),
            "stopping_criteria": StoppingCriteriaList([StopGeneration()]),
        }

        generation_error = None

        def generate_with_error_handling():
            nonlocal generation_error
            try:
                self.model.generate(**generation_kwargs)
            except Exception as e:
                generation_error = "oom" if "out of memory" in str(e).lower() else str(e)
            if generation_error == "oom":
                print("Warning: CUDA OOM during streaming generation. Clearing cache...")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            if generation_error is not None:
                # generate() ends the streamer only when it finishes normally.
                # Without this the consumer would wait on the queue forever.
                streamer.end()

        loop = asyncio.get_running_loop()
        thread = Thread(target=generate_with_error_handling)
        thread.start()

        end_of_stream = object()
        holdback = max((len(s) for s in stop_sequences), default=1) - 1
        pending = ""
        try:
            while True:
                # The streamer blocks on a queue; wait for it off the event loop.
                text = await loop.run_in_executor(None, next, streamer, end_of_stream)
                if text is end_of_stream:
                    break
                if not text:
                    continue
                pending += text
                hits = [i for i in (pending.find(s) for s in stop_sequences) if i != -1]
                if hits:
                    # Keep only the text before the stop sequence, then stop the worker.
                    pending = pending[:min(hits)]
                    cancelled.set()
                    break
                # Hold back a tail that might still be the start of a stop sequence.
                ready = len(pending) - holdback
                if ready > 0:
                    yield pending[:ready]
                    pending = pending[ready:]
            if pending:
                yield pending
        except Exception as e:
            print(f"Error during stream iteration: {e}")
        finally:
            # Runs on normal completion and when the client disconnects mid-stream.
            cancelled.set()

        await loop.run_in_executor(None, thread.join)

        if generation_error == "oom":
            yield "\n[Warning: Generation stopped due to out-of-memory. Try reducing --max-gpu-percent.]"
        elif generation_error:
            yield f"\n[Error during generation: {generation_error}]"
    
    def get_model_name(self) -> str:
        """Return the loaded model's name, or ``"unknown"`` when none is set."""
        name = self.model_name
        return name if name else "unknown"
    
    def cleanup(self) -> None:
        """Drop the model and tokenizer and release cached CUDA memory.

        Safe to call more than once. The tokenizer is released even when the
        model was already gone. ``gc.collect()`` runs before
        ``torch.cuda.empty_cache()`` so tensors kept alive only by reference
        cycles are freed before the allocator cache is returned.
        """
        import gc

        import torch

        if self.model is None and getattr(self, "tokenizer", None) is None:
            return
        # Rebinding to None drops this object's references.
        self.model = None
        self.tokenizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# =============================================================================
# Vulkan Backend (llama-cpp-python)
# =============================================================================

class VulkanBackend(ModelBackend):
    """Backend for Vulkan (AMD GPUs) using llama-cpp-python with GGUF models."""
    
    def __init__(self):
        """Set default load options; the model itself is loaded by ``load_model``."""
        self.model, self.model_name = None, None
        self.n_gpu_layers, self.n_ctx = -1, 2048  # -1 offloads every layer to the GPU
        self.verbose, self.main_gpu = True, 0      # main_gpu 0 = first GPU
        self.chat_template = None  # detected template label, filled in below / after load
        self._detect_chat_template()
    
    def _detect_chat_template(self):
        """Prepare chat-template detection before any model is loaded.

        This only checks that llama_cpp's chat-format module can be imported.
        On success the label is ``"unknown"`` until the post-load detection
        runs; on failure it is ``None``.
        """
        try:
            from llama_cpp.llama_chat_format import ChatFormatterResponse  # import probe only
        except Exception as e:
            print(f"DEBUG: Could not initialize chat template detection: {e}")
            self.chat_template = None
            return
        self.chat_template = "unknown"
        print("DEBUG: Chat template detection will happen after model load")
    
    def _finalize_chat_template_detection(self):
        """Label the loaded model's chat template and check that it renders.

        The label comes from the GGUF ``tokenizer.chat_template`` metadata and
        is informational only. The decisive step is a probe: a one-token
        ``create_chat_completion`` call.

        - Probe succeeds: the label (or ``"default"``) is kept.
        - Jinja failure: ``"jinja_fallback"``.
        - Any other failure: ``"unknown"``.

        Both failure labels make the chat methods format prompts manually.
        """
        try:
            label = self._label_chat_template()

            test_messages = [{"role": "user", "content": "test"}]
            try:
                self.model.create_chat_completion(messages=test_messages, max_tokens=1)
                self.chat_template = label or "default"
                print(f"DEBUG: Chat template detected via test: {self.chat_template}")
            except Exception as e:
                error_str = str(e).lower()
                # jinja2 error messages often omit the word "jinja"; check the
                # exception's module as well.
                error_module = getattr(type(e), "__module__", "") or ""
                if 'jinja' in error_str or error_module.startswith('jinja2'):
                    self.chat_template = "jinja_fallback"
                    print("DEBUG: Chat template detected: jinja_fallback (will use manual formatting)")
                else:
                    self.chat_template = "unknown"
                    print(f"DEBUG: Chat template detection failed: {e}")
        except Exception as e:
            self.chat_template = "unknown"
            print(f"DEBUG: Final chat template detection error: {e}")

    def _label_chat_template(self):
        """Return a coarse name for the GGUF chat template, or None if unavailable."""
        try:
            # llama-cpp-python exposes GGUF metadata as a dict on ``Llama.metadata``.
            # ``Llama.tokenizer`` is a method, so it never carried a chat_template.
            metadata = getattr(self.model, 'metadata', None) or {}
            template = metadata.get('tokenizer.chat_template')
        except Exception as e:
            print(f"DEBUG: Could not read chat template metadata: {e}")
            return None
        if not template:
            return None

        lowered = str(template).lower()
        if 'qwen' in lowered:
            label = "qwen"
        elif 'phi' in lowered:
            label = "phi"
        elif 'llama3' in lowered or 'llama-3' in lowered or '<|start_header_id|>' in lowered:
            label = "llama3"
        elif 'chatml' in lowered or '<|im_start|>' in lowered:
            label = "chatml"
        else:
            label = "default"
        print(f"DEBUG: Detected chat template: {label}")
        return label
        
    def list_vulkan_devices(self):
        """Print ``vulkaninfo --summary`` output so users can choose a GPU index.

        Best effort: a missing, failing or hung ``vulkaninfo`` is ignored. The
        timeout keeps a broken driver from stalling model loading here.
        """
        import subprocess

        try:
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True,
                                    text=True, timeout=10)
        except Exception:
            return
        if result.returncode == 0:
            print("\nAvailable Vulkan devices:")
            print(result.stdout)
    
    def count_vulkan_devices(self):
        """Count the Vulkan GPUs that llama.cpp will use, according to ``vulkaninfo``.

        Device sections whose text mentions ``llvmpipe`` or ``CPU`` (software
        rasterizers) are excluded, because llama.cpp filters those devices out.

        Returns:
            The number of non-CPU GPU sections (at least 1) when
            ``vulkaninfo --summary`` succeeds. Otherwise 2, a guess that suits
            common NVIDIA + AMD setups.
        """
        import re
        import subprocess

        try:
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True,
                                    text=True, timeout=10)
        except Exception:
            return 2  # Default to 2 (common for NVIDIA + AMD setups)
        if result.returncode != 0:
            return 2

        # Each device section starts with a header line such as "GPU0:"; the
        # text before the first header is not a device and is dropped.
        sections = re.split(r'^GPU\d+:', result.stdout, flags=re.MULTILINE)[1:]
        non_cpu_gpus = sum(
            1 for section in sections
            if 'llvmpipe' not in section and 'CPU' not in section
        )
        return max(non_cpu_gpus, 1)
    
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load a GGUF model with llama-cpp-python (Vulkan build).

        Args:
            model_name: One of:
                - a path to a local ``.gguf`` file;
                - a HuggingFace repo id ``"org/model"`` (a Q4_K_M file is
                  preferred when present);
                - ``"org/model/<path>.gguf"``, naming a specific file in the repo.
            **kwargs: Optional settings:
                - ``n_gpu_layers`` (default -1 = all layers);
                - ``n_ctx`` (default 2048);
                - ``verbose`` (default True);
                - ``main_gpu`` (default 0);
                - ``single_gpu`` (default False; puts every layer on
                  ``main_gpu`` via ``tensor_split``).

        Raises:
            Errors from downloading or from the ``Llama`` constructor are
            re-raised after troubleshooting hints are printed.
        """
        from llama_cpp import Llama

        n_gpu_layers = kwargs.get('n_gpu_layers', -1)
        n_ctx = kwargs.get('n_ctx', 2048)
        verbose = kwargs.get('verbose', True)
        main_gpu = kwargs.get('main_gpu', 0)
        single_gpu = kwargs.get('single_gpu', False)

        # Record the settings actually used. Previously only main_gpu was stored,
        # which left the __init__ defaults in place for the others.
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.verbose = verbose
        self.main_gpu = main_gpu

        model_path = self._resolve_model_path(model_name)

        print(f"Loading GGUF model with Vulkan support...")
        print(f"  Model path: {model_path}")
        print(f"  GPU layers: {n_gpu_layers} (-1 = all layers)")
        print(f"  Context size: {n_ctx}")
        print(f"  GPU device: {main_gpu}")

        # List available devices for user reference
        self.list_vulkan_devices()

        # tensor_split is indexed by llama.cpp device number. Its length must
        # match the devices llama.cpp sees, i.e. GPUs without CPU rasterizers.
        num_devices = self.count_vulkan_devices()
        print(f"DEBUG: Detected {num_devices} Vulkan GPU devices")

        # NOTE(review): this variable is set after llama_cpp has been imported.
        # Whether ggml still reads it at this point, and still honours this
        # name, is unverified.
        if main_gpu >= 0:
            os.environ['GGML_VULKAN_DEVICE'] = str(main_gpu)
            print(f"DEBUG: Set GGML_VULKAN_DEVICE={main_gpu}")

        tensor_split = None
        if single_gpu:
            # Weight 1.0 on main_gpu and 0.0 elsewhere keeps every layer on one GPU.
            # The lower bound guards against negative indices wrapping to the end.
            if 0 <= main_gpu < num_devices:
                tensor_split = [0.0] * num_devices
                tensor_split[main_gpu] = 1.0
                print(f"  Single GPU mode: Setting tensor_split for GPU {main_gpu}: {tensor_split}")
            else:
                print(f"Warning: main_gpu={main_gpu} is not a valid index for {num_devices} detected devices, ignoring single_gpu")

        try:
            llama_kwargs = {
                'model_path': model_path,
                'n_gpu_layers': n_gpu_layers,
                'n_ctx': n_ctx,
                'verbose': verbose,
                'main_gpu': main_gpu,
            }
            if tensor_split:
                llama_kwargs['tensor_split'] = tensor_split

            self.model = Llama(**llama_kwargs)
            self.model_name = model_name
            print("\nModel loaded successfully with Vulkan!")

            # Detect the chat template after model load
            self._finalize_chat_template_detection()
            print(f"DEBUG: Chat template: {self.chat_template}")
        except Exception as e:
            print(f"Error loading model with Vulkan: {e}")
            print("Make sure Vulkan drivers are installed:")
            print("  Debian/Ubuntu: sudo apt install libvulkan-dev vulkan-tools")
            print("  Fedora: sudo dnf install vulkan-loader-devel vulkan-tools")
            raise

    def _resolve_model_path(self, model_name: str) -> str:
        """Return a local GGUF path for ``model_name``, downloading it from the Hub if needed."""
        if os.path.isfile(model_name):
            print(f"Loading local GGUF model: {model_name}")
            return model_name

        print(f"Attempting to download GGUF model: {model_name}")
        try:
            from huggingface_hub import hf_hub_download, list_repo_files

            # Accepted forms: "org/model" or "org/model/<path>.gguf".
            parts = model_name.split('/')
            if len(parts) < 2:
                raise ValueError(f"Invalid model name format: {model_name}")
            repo_id = f"{parts[0]}/{parts[1]}"

            if len(parts) >= 3 and parts[-1].endswith('.gguf'):
                filename = '/'.join(parts[2:])
            else:
                gguf_files = [f for f in list_repo_files(repo_id) if f.endswith('.gguf')]
                if not gguf_files:
                    raise ValueError(f"No GGUF files found in {repo_id}")
                # Prefer Q4_K_M quantization as a size/quality balance.
                preferred = [f for f in gguf_files if 'q4_k_m' in f.lower()]
                filename = (preferred or gguf_files)[0]
                print(f"Selected GGUF file: {filename}")

            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Downloaded to: {model_path}")
            return model_path
        except Exception as e:
            print(f"Error downloading model: {e}")
            print("Please provide a local path to a .gguf file")
            raise
    
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string suitable for chat models.

        Messages are converted to OpenAI-style dicts: role, content, tool
        calls, and name/tool-call id for tool results.

        - If the loaded model object provides a callable ``apply_chat_template``,
          it is used.
        - Otherwise, or if that call fails, the prompt is built by
          ``_manual_format_messages``. That is the same ChatML formatter the
          chat methods use, so tool calls and tool results are kept.
        """
        chat_messages = []
        for msg in messages:
            chat_msg = {"role": msg.role, "content": msg.content or ""}
            if msg.tool_calls:
                chat_msg["tool_calls"] = msg.tool_calls
            if msg.role == "tool":
                # NOTE(review): tool_call_id is read defensively; confirm ChatMessage defines it.
                chat_msg["name"] = getattr(msg, "name", None) or ""
                chat_msg["tool_call_id"] = getattr(msg, "tool_call_id", None) or ""
            chat_messages.append(chat_msg)

        # llama_cpp.Llama does not normally expose apply_chat_template, so only
        # call it when present instead of failing (and warning) on every request.
        apply_template = getattr(self.model, "apply_chat_template", None)
        if callable(apply_template):
            try:
                return apply_template(chat_messages, tokenize=False, add_generation_prompt=True)
            except Exception as e:
                print(f"Warning: apply_chat_template failed ({e}), using fallback formatting")

        return self._manual_format_messages(chat_messages)
    
    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Complete ``prompt`` in one call to the llama.cpp model.

        ``max_tokens`` defaults to 512 and a missing ``stop`` list is sent as
        ``[]``. Returns the text of the first choice.
        """
        completion = self.model(
            prompt,
            max_tokens=512 if max_tokens is None else max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop if stop else [],
        )
        first_choice = completion["choices"][0]
        return first_choice["text"]
    
    def generate_chat(self, messages: List[Dict], max_tokens: Optional[int] = None,
                      temperature: float = 0.7, top_p: float = 1.0,
                      stop: Optional[List[str]] = None, tools: Optional[List] = None) -> str:
        """Return a chat completion for OpenAI-style message dicts.

        ``create_chat_completion`` is used when the detected template is
        usable and no ``tools`` are given. Otherwise the messages are formatted
        by ``_manual_format_messages`` and completed through ``generate``. The
        same happens if the llama.cpp chat call fails or returns blank content.
        """
        limit = 512 if max_tokens is None else max_tokens

        def complete_manually() -> str:
            # Plain completion over the hand-built ChatML prompt.
            return self.generate(self._manual_format_messages(messages), limit, temperature, top_p, stop)

        # Jinja chat templates often fail on tool messages, so any request that
        # carries tools skips straight to manual formatting.
        if tools is not None or self.chat_template in ("unknown", "jinja_fallback", None):
            print(f"DEBUG: Using manual message formatting (template: {self.chat_template}, tools: {tools is not None})")
            return complete_manually()

        try:
            response = self.model.create_chat_completion(
                messages=messages,
                max_tokens=limit,
                temperature=temperature,
                top_p=top_p,
                stop=stop or [],
                tools=tools,
            )
            text = response["choices"][0]["message"].get("content", "")
            print(f"DEBUG: generate_chat returned content length: {len(text) if text else 0}")
            if text and text.strip():
                return text
            print(f"DEBUG: Empty content from create_chat_completion, using fallback")
            raise Exception("Empty response from create_chat_completion")
        except Exception as err:
            print(f"Warning: create_chat_completion failed ({err}), falling back to text generation")
        return complete_manually()
    
    async def generate_chat_stream(self, messages: List[Dict], max_tokens: Optional[int] = None,
                                    temperature: float = 0.7, top_p: float = 1.0,
                                    stop: Optional[List[str]] = None, tools: Optional[List] = None) -> AsyncGenerator[str, None]:
        """Stream a chat completion for OpenAI-style message dicts.

        Two paths:

        - Manual: used when the chat template is unusable (``"unknown"``,
          ``"jinja_fallback"`` or ``None``) or ``tools`` are supplied. Messages
          are formatted with ``_manual_format_messages`` and streamed through
          ``generate_stream``.
        - llama.cpp chat: otherwise ``create_chat_completion(stream=True)`` is
          drained in the default executor, then its content deltas are yielded.

        If the llama.cpp path raises or produces no content, the manual path
        is used instead. This happens only when nothing has been yielded yet,
        so the client never receives duplicated text.
        """
        if max_tokens is None:
            max_tokens = 512

        total_content = ""
        chunk_count = 0

        # Always use manual formatting when tools are present, since Jinja
        # templates often fail with tool messages.
        use_manual = self.chat_template in ("unknown", "jinja_fallback", None) or tools is not None

        if use_manual:
            print(f"DEBUG: Using manual message formatting for streaming (template: {self.chat_template}, tools: {tools is not None})")
            prompt = self._manual_format_messages(messages)
            async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                yield chunk
            return

        # Drain the llama.cpp stream on a worker thread, then yield from the
        # collected list; this avoids iterating a generator across threads.
        def collect_chunks():
            """Collect all chunks from the stream."""
            print(f"DEBUG: generate_chat_stream: Calling create_chat_completion with tools={tools}")
            stream = self.model.create_chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop=stop or [],
                tools=tools,
                stream=True,
            )
            print(f"DEBUG: generate_chat_stream: Got stream object: {type(stream)}")
            chunks = list(stream)
            print(f"DEBUG: generate_chat_stream: Collected {len(chunks)} chunks")
            return chunks

        try:
            loop = asyncio.get_running_loop()
            chunks = await loop.run_in_executor(None, collect_chunks)

            for chunk in chunks:
                chunk_count += 1
                print(f"DEBUG: generate_chat_stream: Processing chunk {chunk_count}: {repr(chunk)}")
                delta = chunk["choices"][0].get("delta", {})
                content = delta.get("content", "")

                # Reasoning output (e.g. Qwen3 ``<think>`` blocks) arrives as
                # ordinary content and is passed through unchanged.
                if content:
                    total_content += content
                    yield content

                # Give other tasks on the event loop a turn.
                await asyncio.sleep(0)

            print(f"DEBUG: generate_chat_stream yielded {chunk_count} chunks, total content length: {len(total_content)}")
            if not total_content.strip():
                print(f"DEBUG: Empty stream from create_chat_completion, using fallback")
                raise Exception("Empty stream response")
        except Exception as e:
            print(f"DEBUG: generate_chat_stream exception: {type(e).__name__}: {e}")
            import traceback
            traceback.print_exc()
            # Decide on "has anything been sent?" rather than on chunk count.
            # A successful stream always has role/finish chunks, so counting
            # chunks made the empty-response fallback unreachable.
            if not total_content:
                print(f"Warning: create_chat_completion stream failed ({e}), falling back to text generation")
                prompt = self._manual_format_messages(messages)
                async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                    yield chunk
            else:
                print(f"DEBUG: Stream completed with {chunk_count} chunks")
    
    def _manual_format_messages(self, messages: List[Dict]) -> str:
        """Format OpenAI-style message dicts as a ChatML prompt by hand.

        This is the fallback used when llama.cpp's chat template cannot be used.

        - Each system/user/assistant message becomes a
          ``<|im_start|>role ... <|im_end|>`` block.
        - Assistant tool calls are appended as ``<tool>{json}</tool>`` lines.
        - Tool results get a ``tool (tool_call_id=..., name=...)`` header.
        - Messages with other roles are skipped.
        - The prompt ends with an open assistant turn.
        """

        def tool_call_markup(tc) -> Optional[str]:
            # Emit valid JSON. The name is JSON-escaped. Arguments are normally
            # already a JSON string and are embedded verbatim; a dict is
            # encoded rather than interpolated through Python's repr.
            if not isinstance(tc, dict) or not tc.get("function"):
                return None
            func = tc["function"]
            arguments = func.get("arguments")
            if arguments is None:
                arguments = "{}"
            elif not isinstance(arguments, str):
                arguments = json.dumps(arguments)
            name = json.dumps(func.get("name") or "")
            return f'<tool>{{"name": {name}, "arguments": {arguments}}}</tool>'

        formatted = []
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "") or ""

            if role == "system":
                formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                for tc in msg.get("tool_calls") or []:
                    markup = tool_call_markup(tc)
                    if markup:
                        content = content + "\n" + markup if content else markup
                formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            elif role == "tool":
                tool_call_id = msg.get("tool_call_id", "") or ""
                name = msg.get("name", "") or ""
                formatted.append(f"<|im_start|>tool (tool_call_id={tool_call_id}, name={name})\n{content}<|im_end|>")

        formatted.append("<|im_start|>assistant\n")
        return "\n".join(formatted)
    
    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Yield the llama.cpp completion of ``prompt`` as it streams.

        ``max_tokens`` defaults to 512. Empty text fragments are skipped.
        Iteration over the llama.cpp stream happens directly in this coroutine.
        """
        token_budget = max_tokens if max_tokens is not None else 512
        pieces = self.model(
            prompt,
            max_tokens=token_budget,
            temperature=temperature,
            top_p=top_p,
            stop=stop or [],
            stream=True,
        )
        for piece in pieces:
            fragment = piece["choices"][0].get("text", "")
            if not fragment:
                continue
            yield fragment
    
    def get_model_name(self) -> str:
        """Return the name given to ``load_model``, or ``"unknown"`` if none is set."""
        if self.model_name:
            return self.model_name
        return "unknown"
    
    def cleanup(self) -> None:
        """Release the loaded llama.cpp model.

        If the installed llama-cpp-python provides ``Llama.close()``, it is
        called so native resources are freed now rather than whenever the
        object is garbage-collected. Safe to call when no model is loaded.
        """
        model, self.model = self.model, None
        if model is None:
            return
        close = getattr(model, "close", None)
        if callable(close):
            try:
                close()
            except Exception as e:
                print(f"Warning: error while closing model: {e}")
        del model


# =============================================================================
# Model Manager
# =============================================================================

class ModelManager:
    """Manages the loaded model and tokenizer.

    Holds at most one inference backend (``NvidiaBackend`` or ``VulkanBackend``)
    and forwards prompt formatting and generation calls to it. ``backend`` is
    ``None`` until :meth:`load_model` succeeds; the generation entry points raise
    ``RuntimeError("No model loaded")`` while it is ``None``.
    """

    # Backend identifiers accepted by load_model().
    _BACKEND_CHOICES = ("auto", "nvidia", "vulkan")

    def __init__(self):
        self.backend: Optional[ModelBackend] = None
        self.backend_type: Optional[str] = None
        self.tool_parser = ToolCallParser()

    def load_model(self, model_name: str, backend_type: str = "auto", **kwargs):
        """
        Load the model with the specified backend.

        Any previously loaded backend is cleaned up before the new one is
        loaded, so two models are never resident at the same time. Manager
        state (``backend``/``backend_type``) is committed only after the new
        backend loads successfully; if loading fails, ``backend`` is ``None``.

        Args:
            model_name: Model name or path
            backend_type: 'nvidia', 'vulkan', or 'auto' to detect
            **kwargs: Additional arguments for the specific backend

        Raises:
            ValueError: ``backend_type`` is not one of 'auto', 'nvidia', 'vulkan'
                (raised before any state is touched).
            RuntimeError: the requested backend (or, for 'auto', any GPU
                backend) is not available.
            Exception: anything raised by the backend's own ``load_model`` is
                re-raised after a best-effort cleanup of the failed backend.
        """
        # Reject unknown names before probing hardware or discarding the
        # currently loaded model.
        if backend_type not in self._BACKEND_CHOICES:
            raise ValueError(f"Unknown backend: {backend_type}")

        available = detect_available_backends()

        # Determine backend
        if backend_type == "auto":
            if available.get('nvidia'):
                backend_type = "nvidia"
                print("Auto-detected NVIDIA backend")
            elif available.get('vulkan'):
                backend_type = "vulkan"
                print("Auto-detected Vulkan backend")
            else:
                print("No GPU backend detected. For NVIDIA, install PyTorch with CUDA.")
                print("For Vulkan, install llama-cpp-python with Vulkan support.")
                raise RuntimeError("No suitable backend found")

        if backend_type == "nvidia":
            if not available.get('nvidia'):
                raise RuntimeError("NVIDIA backend requested but PyTorch/CUDA not available")
            backend_cls = NvidiaBackend
        else:  # "vulkan" -- the only remaining valid choice
            if not available.get('vulkan'):
                raise RuntimeError("Vulkan backend requested but llama-cpp-python not available")
            backend_cls = VulkanBackend

        # Free the previous model first so memory for both models is never
        # needed at once; from here on, a failure leaves no model loaded.
        self.cleanup()
        self.backend_type = None

        new_backend = backend_cls()
        try:
            new_backend.load_model(model_name, **kwargs)
        except Exception:
            # Best-effort release of whatever the failed load allocated; the
            # original load error is what the caller needs to see.
            try:
                new_backend.cleanup()
            except Exception as cleanup_err:
                print(f"Warning: cleanup after failed model load also failed: {cleanup_err}")
            raise

        self.backend = new_backend
        self.backend_type = backend_type
        self.tool_parser = ToolCallParser()

    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string via the active backend's ``format_messages``.

        Raises:
            RuntimeError: no model is loaded.
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        return self.backend.format_messages(messages)

    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate text non-streaming from a pre-formatted prompt.

        Raises:
            RuntimeError: no model is loaded.
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)

    def generate_chat(self, messages: List[Dict], max_tokens: Optional[int] = None,
                      temperature: float = 0.7, top_p: float = 1.0,
                      stop: Optional[List[str]] = None, tools: Optional[List] = None) -> str:
        """Generate a chat completion non-streaming.

        Backends that define ``generate_chat`` receive the message dicts and
        ``tools`` directly. Otherwise the dicts are converted to ``ChatMessage``,
        formatted with :meth:`format_messages`, and passed to ``generate``;
        ``tools`` is not forwarded on that path.

        Raises:
            RuntimeError: no model is loaded.
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        # Use generate_chat if available (Vulkan backend), otherwise format and use generate
        if hasattr(self.backend, 'generate_chat'):
            return self.backend.generate_chat(messages, max_tokens, temperature, top_p, stop, tools)
        # Fallback for NVIDIA backend
        prompt = self.format_messages([ChatMessage(**m) for m in messages])
        return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)

    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Stream text pieces for a pre-formatted prompt from the active backend.

        Raises:
            RuntimeError: no model is loaded (raised on first iteration).
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
            yield chunk

    async def generate_chat_stream(self, messages: List[Dict], max_tokens: Optional[int] = None,
                                    temperature: float = 0.7, top_p: float = 1.0,
                                    stop: Optional[List[str]] = None, tools: Optional[List] = None) -> AsyncGenerator[str, None]:
        """Stream a chat completion.

        Uses the backend's ``generate_chat_stream`` when it exists; otherwise
        formats the messages and falls back to ``generate_stream`` (``tools``
        is not forwarded on that path).

        Raises:
            RuntimeError: no model is loaded (raised on first iteration).
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        # Use generate_chat_stream if available (Vulkan backend), otherwise format and use generate_stream
        if hasattr(self.backend, 'generate_chat_stream'):
            async for chunk in self.backend.generate_chat_stream(messages, max_tokens, temperature, top_p, stop, tools):
                yield chunk
        else:
            # Fallback for NVIDIA backend
            prompt = self.format_messages([ChatMessage(**m) for m in messages])
            async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
                yield chunk

    @property
    def model_name(self) -> str:
        """Name reported by the active backend, or ``"unknown"`` when none is loaded."""
        if self.backend is None:
            return "unknown"
        return self.backend.get_model_name()

    @property
    def model(self):
        """The active backend object itself (``None`` when no model is loaded)."""
        return self.backend

    @property
    def tokenizer(self):
        """The NVIDIA backend's tokenizer; ``None`` for any other backend or none loaded."""
        # Only NVIDIA backend has a tokenizer
        if isinstance(self.backend, NvidiaBackend):
            return self.backend.tokenizer
        return None

    def cleanup(self):
        """Clean up the active backend (if any) and mark the manager as having no model."""
        if self.backend is not None:
            self.backend.cleanup()
            self.backend = None


# Global model manager
# Process-wide instance: the HTTP handlers check `model_manager.backend` and call
# its generate* methods, and the lifespan handler calls its cleanup() on shutdown.
model_manager = ModelManager()

# Global system prompt (set via --system-prompt flag)
# None = don't inject, True = use default, string = use custom text
# Assigned by main() from the parsed --system-prompt argument.
global_system_prompt: Union[bool, str, None] = None


# =============================================================================
# FastAPI Application
# =============================================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup/shutdown.

    Startup does nothing here. On shutdown the loaded model is released via
    ``model_manager.cleanup()``. The cleanup sits in ``finally`` so it also runs
    when the lifespan is exited with an exception or cancellation, not only on
    a clean shutdown.
    """
    try:
        # Startup: nothing to do; the application serves requests while suspended here.
        yield
    finally:
        # Shutdown
        model_manager.cleanup()


# The ASGI application. Routes and middleware below are registered on it, and
# `lifespan` runs model cleanup when the application shuts down.
app = FastAPI(
    title="OpenAI-Compatible API",
    description="OpenAI-compatible API supporting NVIDIA (CUDA) and Vulkan backends",
    version="2.0.0",
    lifespan=lifespan,
)

# Add request logging middleware for debugging

# Only these endpoints get their request bodies and error responses logged in detail.
_LOGGED_PATHS = frozenset({"/v1/chat/completions", "/v1/completions"})

# Header names (lower-case) whose values carry credentials and must never be printed.
_SENSITIVE_HEADERS = frozenset({
    "authorization",
    "proxy-authorization",
    "cookie",
    "x-api-key",
    "api-key",
})


def _redact_headers(request: Request) -> Dict[str, str]:
    """Return the request headers as a dict with credential-bearing values masked."""
    return {
        name: ("<redacted>" if name.lower() in _SENSITIVE_HEADERS else value)
        for name, value in request.headers.items()
    }


def _log_request_body(request: Request, body: bytes) -> None:
    """Print request metadata, the raw body (truncated when large), and an outline of its JSON."""
    # errors="replace": a non-UTF-8 body is still logged instead of aborting logging.
    body_str = body.decode('utf-8', errors='replace')
    print(f"\n{'='*60}")
    print("=== INCOMING REQUEST ===")
    print(f"{'='*60}")
    print(f"Path: {request.url.path}")
    print(f"Method: {request.method}")
    print(f"Headers: {_redact_headers(request)}")
    print(f"\n--- RAW BODY ({len(body)} bytes) ---")
    if len(body_str) > 2000:
        # Very large bodies: show only the first and last 1000 characters.
        print(f"{body_str[:1000]}...\n... [truncated {len(body_str)-2000} chars] ...\n...{body_str[-1000:]}")
    else:
        print(body_str)
    print("--- END RAW BODY ---")

    # Try to parse as JSON to see whether the client sent something valid.
    try:
        parsed = json.loads(body_str)
    except json.JSONDecodeError as e:
        print(f"\n*** JSON Parse Error: {e} ***")
        print(f"Error at position: char {e.pos}, line {e.lineno}, column {e.colno}")
        return

    try:
        print("\n--- PARSED JSON STRUCTURE ---")
        if not isinstance(parsed, dict):
            print(f"Top-level JSON type: {type(parsed).__name__}")
        else:
            print(f"Keys: {list(parsed.keys())}")
            if 'messages' in parsed and isinstance(parsed['messages'], list):
                print(f"Number of messages: {len(parsed['messages'])}")
                for i, msg in enumerate(parsed['messages']):
                    if not isinstance(msg, dict):
                        print(f"  [{i}] <non-object message: {type(msg).__name__}>")
                        continue
                    role = msg.get('role', 'unknown')
                    content_preview = str(msg.get('content', ''))[:50].replace('\n', ' ')
                    print(f"  [{i}] {role}: {content_preview}...")
        print("--- END PARSED JSON ---")
    except Exception as e:
        print(f"\n*** Error analyzing JSON: {e} ***")


def _log_error_detail(error_body: str) -> None:
    """Pretty-print the ``detail`` field of a JSON error body (e.g. request validation errors)."""
    try:
        error_json = json.loads(error_body)
        if not isinstance(error_json, dict) or 'detail' not in error_json:
            return
        print("\n*** VALIDATION ERROR DETAILS ***")
        detail = error_json['detail']
        if isinstance(detail, list):
            for err in detail:
                if isinstance(err, dict):
                    print(f"  - Location: {err.get('loc', [])}")
                    print(f"    Message: {err.get('msg', 'Unknown error')}")
                    print(f"    Type: {err.get('type', 'unknown')}")
        else:
            print(f"  Detail: {detail}")
        print("*** END ERROR DETAILS ***")
    except Exception as parse_err:
        print(f"Could not parse error details: {parse_err}")


async def _drain_error_response(response):
    """Read and log an error response's body, returning an equivalent replayable response.

    The body iterator can be consumed only once, so after logging a new
    ``Response`` with the same bytes, status, headers and media type is returned.
    If reading fails, the original response object is returned unchanged.
    """
    from starlette.responses import Response

    chunks = []
    try:
        async for chunk in response.body_iterator:
            chunks.append(chunk.encode('utf-8') if isinstance(chunk, str) else chunk)
    except Exception as read_err:
        # NOTE: if this fails part-way the original iterator is already partly
        # consumed; there is nothing better to return than the original object.
        print(f"Could not read error response body: {read_err}")
        return response
    response_body = b"".join(chunks)

    error_body = response_body.decode('utf-8', errors='replace')
    print(f"Error Response Body: {error_body}")
    _log_error_detail(error_body)

    return Response(
        content=response_body,
        status_code=response.status_code,
        headers=dict(response.headers),
        media_type=response.media_type,
    )


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log all incoming requests for debugging.

    For paths in ``_LOGGED_PATHS`` the request body is read, logged (with
    credential headers redacted) and re-attached to the request so the route
    handler can still parse it; the response status is logged, and for
    status >= 400 the response body as well. For every path, an exception
    raised downstream is printed with its traceback and re-raised.
    """
    is_logged = request.url.path in _LOGGED_PATHS

    if is_logged:
        body: Optional[bytes] = None
        try:
            body = await request.body()
        except Exception as e:
            # Handle ClientDisconnect and other read failures: nothing to log or replay.
            print(f"Error logging request: {e}")

        if body is not None:
            try:
                _log_request_body(request, body)
            except Exception as e:
                # Logging must never prevent the request from being served.
                print(f"Error logging request: {e}")

            # Re-attach the consumed body so downstream handlers can read it again.
            async def receive():
                return {"type": "http.request", "body": body}
            request = Request(request.scope, receive, request._send)

    try:
        response = await call_next(request)
    except Exception as e:
        print("\n*** EXCEPTION DURING REQUEST PROCESSING ***")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        import traceback
        traceback.print_exc()
        print("*** END EXCEPTION ***\n")
        raise

    if not is_logged:
        return response

    print("\n--- RESPONSE ---")
    print(f"Status Code: {response.status_code}")
    if response.status_code >= 400:
        response = await _drain_error_response(response)
    print("--- END RESPONSE ---")
    print(f"{'='*60}\n")
    return response


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Report the currently loaded model (or ``"unknown"``) as a one-entry OpenAI model list."""
    current_name = model_manager.model_name
    return ModelList(data=[ModelInfo(id=current_name if current_name else "unknown")])


def _content_to_text(content):
    """Flatten an OpenAI message ``content`` value for the chat template.

    ``None`` becomes ``""``. A multipart list such as
    ``[{"type": "text", "text": "..."}]`` is joined with newlines, with
    non-text parts replaced by a ``[<type> content]`` placeholder and non-dict
    items converted with ``str``. Any other value is returned unchanged.
    """
    if content is None:
        return ""
    if not isinstance(content, list):
        return content
    parts = []
    for item in content:
        if not isinstance(item, dict):
            parts.append(str(item))
        elif item.get('type') == 'text' and 'text' in item:
            parts.append(item['text'])
        else:
            parts.append(f"[{item.get('type', 'unknown')} content]")
    return '\n'.join(parts)


def _message_to_dict(msg) -> Dict:
    """Convert a ChatMessage into the plain dict shape passed to the backends.

    ``content`` is always present (the llama_cpp template expects it);
    ``tool_calls``, ``name`` and ``tool_call_id`` are included only when set.
    """
    msg_dict = {"role": msg.role, "content": _content_to_text(msg.content)}
    if msg.tool_calls:
        # tool_calls are passed through as provided (expected: dicts with 'id', 'type', 'function').
        msg_dict["tool_calls"] = msg.tool_calls
    if msg.name:
        msg_dict["name"] = msg.name
    if msg.tool_call_id:
        msg_dict["tool_call_id"] = msg.tool_call_id
    return msg_dict


def _tools_to_dicts(tools) -> List[Dict]:
    """Convert request Tool models into plain OpenAI-style tool dicts."""
    return [
        {
            "type": tool.type,
            "function": {
                "name": tool.function.name,
                "description": tool.function.description,
                "parameters": tool.function.parameters,
            },
        }
        for tool in tools
    ]


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Chat completions endpoint with streaming and tool support.

    Returns 503 when no model is loaded. When ``--system-prompt`` was given and
    the conversation has no system message, one is prepended. Tool definitions,
    when present, are injected via ``format_tools_for_prompt`` and also passed
    to the backend as plain dicts.
    """
    if model_manager.backend is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Inject system prompt if --system-prompt flag was provided and none is present.
    messages = request.messages
    if global_system_prompt is not None and not any(msg.role == "system" for msg in messages):
        # True means "use the default text"; anything else is custom text.
        system_text = (
            "You are a helpful assistant."
            if global_system_prompt is True
            else str(global_system_prompt)
        )
        messages = [ChatMessage(role="system", content=system_text)] + list(messages)

    # Format messages with tools if provided
    if request.tools:
        messages = format_tools_for_prompt(request.tools, messages)

    # Normalize stop: a single string becomes a one-element list.
    if not request.stop:
        stop_sequences = []
    elif isinstance(request.stop, str):
        stop_sequences = [request.stop]
    else:
        stop_sequences = request.stop

    messages_dict = [_message_to_dict(msg) for msg in messages]

    # Debug: print first few messages to see their structure
    print(f"DEBUG: messages_dict[0] keys: {list(messages_dict[0].keys()) if messages_dict else 'empty'}")
    if len(messages_dict) > 1:
        print(f"DEBUG: messages_dict[1] keys: {list(messages_dict[1].keys())}")

    tools_dict = _tools_to_dicts(request.tools) if request.tools else None

    generation_args = (
        messages_dict,
        request.model,
        request.max_tokens,
        request.temperature,
        request.top_p,
        stop_sequences,
        tools_dict,
    )
    if request.stream:
        return StreamingResponse(
            stream_chat_response(*generation_args),
            media_type="text/event-stream",
        )
    return await generate_chat_response(*generation_args)

async def stream_chat_response(
    messages: List[Dict],
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Dict]],
) -> AsyncGenerator[str, None]:
    """Stream a chat completion as OpenAI-style server-sent events.

    Each generated piece that survives ``filter_malformed_content`` is sent as
    a ``chat.completion.chunk`` with ``delta.content``. After generation, if
    ``tools`` were supplied and the full text parses into tool calls, a final
    chunk with ``delta.tool_calls`` (each entry carrying an ``index``) and
    ``finish_reason="tool_calls"`` is sent. Otherwise a final empty-delta chunk
    with ``finish_reason="stop"`` is sent. Every chunk carries the full chunk
    envelope (id/object/created/model). The stream always ends with
    ``data: [DONE]``. A generation error is reported to the client as a content
    chunk that ends the stream.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def sse_chunk(delta: Dict, finish_reason: Optional[str]) -> str:
        """Serialize one chat.completion.chunk as an SSE ``data:`` line."""
        data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(data)}\n\n"

    print(f"DEBUG: stream_chat_response started, stream=True, tools={tools is not None}")

    try:
        chunk_count = 0
        text_parts: List[str] = []
        # Use generate_chat_stream for proper chat template handling
        async for chunk in model_manager.generate_chat_stream(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            tools=tools,
        ):
            chunk_count += 1
            # Filter malformed content from each chunk
            filtered_chunk = filter_malformed_content(chunk)
            if not filtered_chunk:
                print(f"DEBUG: filtered_chunk was empty (original chunk: {repr(chunk[:50])})")
                continue

            text_parts.append(filtered_chunk)
            yield sse_chunk({"content": filtered_chunk}, None)
            # Give the event loop a chance to flush the chunk to the client.
            await asyncio.sleep(0)

        generated_text = "".join(text_parts)
        print(f"DEBUG: stream_chat_response completed, {chunk_count} chunks, generated_text length: {len(generated_text)}")
        if not generated_text.strip():
            print("DEBUG: Warning - no content generated!")

        # Check for tool calls in the complete output.
        tool_calls = None
        if tools:
            # Convert tools back to Tool objects for parsing
            tool_objects = [
                Tool(
                    type=t.get("type", "function"),
                    function=ToolFunction(
                        name=t["function"]["name"],
                        description=t["function"].get("description"),
                        parameters=t["function"].get("parameters"),
                    ),
                )
                for t in tools
            ]
            tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tool_objects)

        if tool_calls:
            # OpenAI streaming deltas require an ``index`` on every tool-call entry;
            # an index already present on the entry is kept.
            tool_call_deltas = [
                {"index": i, **tc} if isinstance(tc, dict) else tc
                for i, tc in enumerate(tool_calls)
            ]
            yield sse_chunk({"tool_calls": tool_call_deltas}, "tool_calls")
        else:
            yield sse_chunk({}, "stop")

        yield "data: [DONE]\n\n"
    except Exception as e:
        print(f"Error during streaming generation: {e}")
        yield sse_chunk({"content": f"\n[Generation error: {str(e)}]"}, "stop")
        yield "data: [DONE]\n\n"


async def generate_chat_response(
    messages: List[Dict],
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Dict]],
) -> Dict:
    """Generate non-streaming chat completion response.

    Runs the chat generation, strips malformed content, and, when ``tools``
    are given and tool calls are parsed from the output, returns them with
    ``content=None`` and ``finish_reason="tool_calls"``. Usage counts come from
    the backend tokenizer when one is available (NVIDIA). Otherwise they are a
    whitespace word-count estimate. Prompt counts cover the raw message
    contents only, not the chat template.

    Raises:
        HTTPException: status 500 wrapping any error raised during generation,
            tool-call parsing, or token counting.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        # Use generate_chat for proper chat template handling
        generated_text = model_manager.generate_chat(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            tools=tools,
        )

        # Filter out malformed content from generated text
        generated_text = filter_malformed_content(generated_text)

        response_message = {
            "role": "assistant",
            "content": generated_text,
        }
        finish_reason = "stop"

        # Check for tool calls
        if tools:
            # Convert tools back to Tool objects for parsing
            tool_objects = [
                Tool(
                    type=t.get("type", "function"),
                    function=ToolFunction(
                        name=t["function"]["name"],
                        description=t["function"].get("description"),
                        parameters=t["function"].get("parameters"),
                    ),
                )
                for t in tools
            ]
            tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tool_objects)
            if tool_calls:
                response_message["tool_calls"] = tool_calls
                response_message["content"] = None
                finish_reason = "tool_calls"

        # Token counts: real tokenizer when available (same policy as
        # generate_completion_response), otherwise a rough word count.
        prompt_text = "\n".join(m.get("content") or "" for m in messages)
        tokenizer = model_manager.tokenizer
        if tokenizer is not None:
            prompt_tokens = len(tokenizer.encode(prompt_text))
            completion_tokens = len(tokenizer.encode(generated_text)) if generated_text else 0
        else:
            prompt_tokens = len(prompt_text.split())
            completion_tokens = len(generated_text.split()) if generated_text else 0

        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "message": response_message,
                "finish_reason": finish_reason,
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e


@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    """Text completions endpoint.

    Returns 503 when no model is loaded and 400 when ``prompt`` is an empty
    list. Only the first prompt is served; additional prompts in a list are
    ignored (a warning is printed).
    """
    if model_manager.backend is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt]
    if not prompts:
        # Previously an IndexError surfaced as an opaque 500.
        raise HTTPException(status_code=400, detail="prompt must not be empty")
    if len(prompts) > 1:
        print(f"Warning: {len(prompts) - 1} additional prompt(s) ignored; batched prompts are not supported")
    prompt = prompts[0]

    stop_sequences = []
    if request.stop:
        stop_sequences = [request.stop] if isinstance(request.stop, str) else request.stop

    generation_args = (
        prompt,
        request.model,
        request.max_tokens,
        request.temperature,
        request.top_p,
        stop_sequences,
    )
    if request.stream:
        return StreamingResponse(
            stream_completion_response(*generation_args),
            media_type="text/event-stream",
        )
    return await generate_completion_response(*generation_args)


async def stream_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
) -> AsyncGenerator[str, None]:
    """Stream a text completion as OpenAI-style server-sent events.

    Each generated piece is sent as a ``text_completion`` chunk. The stream
    ends with a final empty-text chunk with ``finish_reason="stop"``, then
    ``data: [DONE]``. Every chunk carries the full envelope
    (id/object/created/model/index/logprobs). A generation error is reported to
    the client as a text chunk, matching the chat streaming endpoint, instead
    of looking like a normal stop.
    """
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def sse_chunk(text: str, finish_reason: Optional[str]) -> str:
        """Serialize one text_completion chunk as an SSE ``data:`` line."""
        data = {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "text": text,
                "index": 0,
                "logprobs": None,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(data)}\n\n"

    try:
        async for chunk in model_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        ):
            yield sse_chunk(chunk, None)

        yield sse_chunk("", "stop")
    except Exception as e:
        print(f"Error during streaming completion: {e}")
        yield sse_chunk(f"\n[Generation error: {str(e)}]", "stop")
    yield "data: [DONE]\n\n"


async def generate_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
) -> Dict:
    """Generate non-streaming completion response.

    Usage counts come from the backend tokenizer when one is available
    (NVIDIA). Otherwise they are a whitespace word-count estimate.
    ``finish_reason`` is always reported as ``"stop"``.

    Raises:
        HTTPException: status 500 wrapping any error raised during generation
            or token counting.
    """
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        generated_text = model_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )

        # Calculate token counts with the real tokenizer when available.
        tokenizer = model_manager.tokenizer
        if tokenizer is not None:
            prompt_tokens = len(tokenizer.encode(prompt))
            completion_tokens = len(tokenizer.encode(generated_text))
        else:
            prompt_tokens = len(prompt.split())
            completion_tokens = len(generated_text.split())

        return {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "text": generated_text,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except Exception as e:
        print(f"Error during completion: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e


# =============================================================================
# Main Entry Point
# =============================================================================

def parse_args(argv: Optional[List[str]] = None):
    """Parse command line arguments.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` (status 2) when ``--load-in-4bit`` and
    ``--load-in-8bit`` are combined, when ``--max-gpu-percent`` is outside
    0-100, or when ``--ram`` is negative.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI-compatible API server supporting NVIDIA (CUDA) and Vulkan backends"
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model name or path. For NVIDIA: HuggingFace model. For Vulkan: GGUF file path or HF repo",
    )
    parser.add_argument(
        "--backend",
        type=str,
        choices=["auto", "nvidia", "vulkan"],
        default="auto",
        help="Backend to use: auto (detect), nvidia (CUDA), or vulkan (AMD GPUs)",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to (default: 8000)",
    )
    parser.add_argument(
        "--offload-dir",
        type=str,
        default="./offload",
        help="Directory for disk offload (NVIDIA backend only, default: ./offload)",
    )
    # Only one bitsandbytes quantization mode can be applied at a time.
    quantization = parser.add_mutually_exclusive_group()
    quantization.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit precision (NVIDIA backend only, requires bitsandbytes)",
    )
    quantization.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit precision (NVIDIA backend only, requires bitsandbytes)",
    )
    parser.add_argument(
        "--ram",
        type=float,
        default=None,
        help="Maximum CPU RAM to use for model offloading in GB (NVIDIA backend only). Auto-detected if not specified. Disk offloading only occurs after this limit is exceeded.",
    )
    parser.add_argument(
        "--flash-attn",
        action="store_true",
        help="Use Flash Attention 2 (NVIDIA backend only, requires flash-attn package)",
    )
    parser.add_argument(
        "--offload-strategy",
        type=str,
        choices=["auto", "conservative", "balanced", "aggressive", "sequential"],
        default="auto",
        help="Offload strategy for NVIDIA backend (default: auto)",
    )
    parser.add_argument(
        "--max-gpu-percent",
        type=float,
        default=None,
        help="Maximum GPU VRAM to use as percentage (0-100). Overrides offload-strategy. Lower values offload more to CPU/RAM (default: None = use offload-strategy)",
    )
    parser.add_argument(
        "--n-gpu-layers",
        type=int,
        default=-1,
        help="Number of layers to offload to GPU (Vulkan backend only, default: -1 = all layers)",
    )
    parser.add_argument(
        "--n-ctx",
        type=int,
        default=2048,
        help="Context window size (Vulkan backend only, default: 2048)",
    )
    parser.add_argument(
        "--vulkan-device",
        type=int,
        default=0,
        help="Vulkan GPU device ID to use (Vulkan backend only, default: 0). Use --vulkan-list-devices to see available devices",
    )
    parser.add_argument(
        "--vulkan-single-gpu",
        action="store_true",
        help="Force Vulkan to use only the specified GPU device (prevents layer distribution across multiple GPUs)",
    )
    parser.add_argument(
        "--vulkan-list-devices",
        action="store_true",
        help="List available Vulkan GPU devices and exit",
    )
    parser.add_argument(
        "--system-prompt",
        nargs="?",
        const=True,
        default=None,
        help="Inject a system prompt at the beginning of conversations. Use without a value for a default prompt, or provide custom text.",
    )
    args = parser.parse_args(argv)

    # Range checks that argparse's type= conversion cannot express.
    if args.max_gpu_percent is not None and not 0 <= args.max_gpu_percent <= 100:
        parser.error("--max-gpu-percent must be between 0 and 100")
    if args.ram is not None and args.ram < 0:
        parser.error("--ram must be a non-negative number of GB")
    return args


def _list_vulkan_devices() -> None:
    """Print a summary of Vulkan devices using the external ``vulkaninfo`` tool.

    Best-effort: every failure is reported to the user and nothing is raised.
    """
    import subprocess

    print("\nListing Vulkan devices...")
    try:
        result = subprocess.run(
            ["vulkaninfo", "--summary"], capture_output=True, text=True
        )
    except FileNotFoundError:
        # The binary is not on PATH, so the install hint is the useful message.
        print("Could not run vulkaninfo. Make sure vulkan-tools is installed.")
        return
    except Exception as e:
        print(f"Error listing devices: {e}")
        return

    if result.returncode == 0:
        print(result.stdout)
        return
    # vulkaninfo exists but failed (e.g. no driver/ICD); show why instead of
    # wrongly claiming the tool is missing.
    print(f"vulkaninfo failed (exit code {result.returncode}). "
          "Make sure Vulkan drivers are installed.")
    details = (result.stderr or "").strip()
    if details:
        print(details)


def _prompt_for_model_name() -> str:
    """Interactively ask for a model name and return it (stripped).

    Exits the process with status 1 if no name is entered, stdin is closed
    (EOF, e.g. when not attached to a terminal), or the prompt is interrupted.
    """
    print("No model specified. Please enter a model name.")
    print("")
    print("For NVIDIA backend (HuggingFace models):")
    print("  - microsoft/DialoGPT-medium")
    print("  - meta-llama/Llama-2-7b-chat-hf (requires auth)")
    print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    print("")
    print("For Vulkan backend (GGUF models):")
    print("  - Local path: ./phi-3-mini-4k-instruct-q4_k_m.gguf")
    print("  - HuggingFace: microsoft/Phi-3-mini-4k-instruct-gguf")
    print("")
    try:
        model_name = input("Enter model name: ").strip()
    except (EOFError, KeyboardInterrupt):
        # No usable input available; finish the prompt line and fall through
        # to the "required" error below instead of dumping a traceback.
        print("")
        model_name = ""

    if not model_name:
        print("Error: Model name is required")
        sys.exit(1)
    return model_name


def _print_available_backends(available: Dict[str, bool]) -> None:
    """Print a ``[✓]``/``[✗]`` line for each backend.

    ``detect_available_backends`` only adds keys for backends it found, so the
    known backend names are listed explicitly; otherwise a missing backend
    would never be shown as ``✗``. Any extra keys are appended after them.
    """
    names = ["cpu", "nvidia", "vulkan"]
    names += [name for name in available if name not in names]
    print("\nAvailable backends:")
    for name in names:
        status = "✓" if available.get(name, False) else "✗"
        print(f"  [{status}] {name}")
    print("")


def _print_load_troubleshooting(error: Exception, backend: str, model_name: str) -> None:
    """Report a model-loading failure together with backend-specific hints."""
    print(f"\nError loading model: {error}")
    error_str = str(error).lower()
    print("\nTroubleshooting:")
    if backend == "vulkan":
        print("  - For Vulkan, ensure you have Vulkan drivers installed")
        print("  - Make sure you're using a GGUF format model")
        print("  - Run build.sh with 'vulkan' argument first")
        return

    print("  - For NVIDIA, ensure PyTorch with CUDA is installed")
    print("  - Run build.sh with 'nvidia' argument first")
    if "tokenizer" in error_str or "sentencepiece" in error_str or "tiktoken" in error_str:
        print("  - Tokenizer error: ensure sentencepiece and tiktoken are installed")
        print("    pip install sentencepiece tiktoken tokenizers")
    # A GGUF model name with a non-Vulkan backend is a common mistake.
    if "gguf" in model_name.lower():
        print(f"\n  *** IMPORTANT: '{model_name}' appears to be a GGUF model ***")
        print("  GGUF models are NOT compatible with the NVIDIA backend.")
        print("  Use --backend vulkan instead, or choose a HuggingFace Transformers model.")
        print("\n  Example Vulkan command:")
        print(f"    coderai --backend vulkan --model {model_name}")


def main():
    """Main entry point: parse CLI args, load the requested model, start the API server.

    Exits with status 0 after ``--vulkan-list-devices`` and with status 1 when
    no model name is provided or the model fails to load.
    """
    global global_system_prompt

    # Optional: set the process name if the ``procname`` module is installed.
    try:
        import procname
        procname.setprocname("coderai")
    except ImportError:
        pass
    args = parse_args()

    # Set global system prompt from --system-prompt flag
    global_system_prompt = args.system_prompt

    if args.vulkan_list_devices:
        _list_vulkan_devices()
        sys.exit(0)

    # Get model name from args or prompt interactively
    model_name = args.model
    if model_name is None:
        model_name = _prompt_for_model_name()

    _print_available_backends(detect_available_backends())

    # Backend-specific options are passed through together; each backend
    # consumes the ones relevant to it.
    load_kwargs = {
        'offload_dir': args.offload_dir,
        'load_in_4bit': args.load_in_4bit,
        'load_in_8bit': args.load_in_8bit,
        'manual_ram_gb': args.ram,
        'flash_attn': args.flash_attn,
        'offload_strategy': args.offload_strategy,
        'max_gpu_percent': args.max_gpu_percent,
        'n_gpu_layers': args.n_gpu_layers,
        'n_ctx': args.n_ctx,
        'main_gpu': args.vulkan_device,
        'single_gpu': args.vulkan_single_gpu,
    }

    try:
        model_manager.load_model(
            model_name=model_name,
            backend_type=args.backend,
            **load_kwargs
        )
    except Exception as e:
        _print_load_troubleshooting(e, args.backend, model_name)
        sys.exit(1)

    # Start the server (uvicorn imported lazily so CLI-only paths don't need it)
    import uvicorn
    print(f"\nStarting server on http://{args.host}:{args.port}")
    print(f"API documentation available at http://{args.host}:{args.port}/docs")
    print(f"Using backend: {model_manager.backend_type}")
    uvicorn.run(app, host=args.host, port=args.port)


# Run the CLI entry point only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
