#!/usr/bin/env python3
"""
OpenAI-compatible API server for HuggingFace models (NVIDIA) and GGUF models (Vulkan).
Supports CUDA (NVIDIA) and Vulkan (AMD) GPU backends, memory-aware model loading,
streaming, and tool calling.
"""

import argparse
import asyncio
import json
import os
import re
import sys
import time
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, List, Optional, Union

import psutil
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from threading import Thread


# =============================================================================
# Backend Detection and Imports
# =============================================================================

def detect_available_backends():
    """Detect which inference backends can be used in this environment.

    Returns:
        dict mapping backend name to ``True`` for every usable backend:

        * ``'cpu'`` is always present.
        * ``'nvidia'`` is added when PyTorch imports and
          ``torch.cuda.is_available()`` is true.
        * ``'vulkan'`` is added when ``llama_cpp`` imports.

    A package that is installed but fails to load its native library surfaces
    as ``OSError``/``RuntimeError`` rather than ``ImportError``. Such a backend
    is reported as unavailable instead of crashing the server at startup.
    """
    # Installed-but-broken native extensions raise these instead of ImportError.
    load_errors = (ImportError, OSError, RuntimeError)
    backends = {'cpu': True}

    # PyTorch with a visible CUDA (or ROCm) device.
    try:
        import torch
        if torch.cuda.is_available():
            backends['nvidia'] = True
    except load_errors:
        pass

    # llama-cpp-python (GGUF models, Vulkan build expected).
    try:
        import llama_cpp  # noqa: F401  (import probe only)
    except load_errors:
        pass
    else:
        backends['vulkan'] = True

    return backends


# =============================================================================
# Flash Attention Detection (for NVIDIA backend)
# =============================================================================

def check_flash_attn_availability() -> bool:
    """Report whether the optional ``flash_attn`` package can be imported.

    Returns:
        ``True`` if ``import flash_attn`` succeeds, ``False`` on ``ImportError``.
    """
    available = True
    try:
        import flash_attn  # noqa: F401  (import probe only)
    except ImportError:
        available = False
    return available


# =============================================================================
# Pydantic Models for API
# =============================================================================

class ToolFunction(BaseModel):
    """Function definition carried inside an OpenAI-style ``tools`` entry."""
    name: str  # identifier the model is told to use when calling the tool
    description: Optional[str] = None  # optional human-readable summary
    parameters: Optional[Dict] = None  # presumably a JSON Schema object; kept as an untyped dict


class Tool(BaseModel):
    """One entry of the request's ``tools`` list (OpenAI format)."""
    type: str = "function"  # the value is not validated; "function" is the default
    function: ToolFunction


class ChatMessage(BaseModel):
    """A single message in an OpenAI-style chat transcript."""
    role: str  # free-form; the prompt formatters handle "system", "user", "assistant", "tool"
    # NOTE(review): OpenAI also allows `content` to be a list of content parts;
    # such requests will fail validation against this plain-string field.
    content: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Dict]] = None  # assistant-issued calls, dicts with a "function" entry
    tool_call_id: Optional[str] = None  # presumably the id of the call a "tool" message answers


class ChatCompletionRequest(BaseModel):
    """Request body for an OpenAI-compatible chat completion.

    Unknown fields are accepted (``extra = "allow"``) so that newer client
    options do not trigger 422 validation errors.
    """
    model: str
    messages: List[ChatMessage]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None  # None lets the backend choose its default length
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None  # single stop string or a list of them
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = "auto"
    # Extra fields that clients may send but we ignore
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    response_format: Optional[Dict] = None
    user: Optional[str] = None
    
    class Config:
        # Pydantic v1-style config (still honoured, with a deprecation warning, by v2).
        extra = "allow"  # Allow extra fields to prevent 422 errors


class CompletionRequest(BaseModel):
    """Request body for an OpenAI-compatible (legacy) text completion.

    Unknown fields are accepted (``extra = "allow"``) so that newer client
    options do not trigger 422 validation errors.
    """
    model: str
    prompt: Union[str, List[str]]  # one prompt or a list of prompts
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None  # None lets the backend choose its default length
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None  # single stop string or a list of them
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    # Extra fields that clients may send but we ignore
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = None
    user: Optional[str] = None
    
    class Config:
        # Pydantic v1-style config (still honoured, with a deprecation warning, by v2).
        extra = "allow"  # Allow extra fields to prevent 422 errors


class ModelInfo(BaseModel):
    """One entry of the ``/models`` listing (OpenAI model object)."""
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time when the object is built
    owned_by: str = "huggingface"


class ModelList(BaseModel):
    """Envelope for a list of :class:`ModelInfo` entries (OpenAI list object)."""
    object: str = "list"
    data: List[ModelInfo]


# =============================================================================
# Tool Parsing
# =============================================================================

class ToolCallParser:
    """Extract OpenAI-style tool calls from free-form model output.

    Three output conventions are recognised; calls are returned in this order:

    1. ``<tool>{"name": ..., "arguments": ...}</tool>`` (``<function>`` tags
       also accepted). This is the format that ``format_tools_for_prompt``
       asks the model to use.
    2. A JSON object with a top-level ``"function_call"`` key whose value is
       an object ``{"name": ..., "arguments": ...}``.
    3. ``<tool_name>{...json args...}</tool_name>`` for each available tool.

    Every returned call has a fresh ``call_<16 hex>`` id and its
    ``arguments`` as a string, as the OpenAI API expects.
    """

    _TAG_PATTERN = re.compile(r'<(?:tool|function)>(.*?)</(?:tool|function)>', re.DOTALL)

    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer

    @staticmethod
    def _new_call(name, arguments: str) -> Dict:
        """Build one OpenAI ``tool_calls`` entry with a random id."""
        return {
            "id": f"call_{uuid.uuid4().hex[:16]}",
            "type": "function",
            "function": {
                "name": name,
                "arguments": arguments,
            },
        }

    @staticmethod
    def _arguments_to_string(arguments) -> str:
        """Return ``arguments`` as a string, JSON-encoding it unless it already is one."""
        if isinstance(arguments, str):
            return arguments
        return json.dumps(arguments)

    def _from_tool_tags(self, text: str) -> List[Dict]:
        """Format 1: JSON objects wrapped in ``<tool>``/``<function>`` tags."""
        calls = []
        for body in self._TAG_PATTERN.findall(text):
            try:
                data = json.loads(body.strip())
            except json.JSONDecodeError:
                continue
            # Non-object JSON (lists, strings, numbers) is not a tool call.
            if isinstance(data, dict) and 'name' in data and 'arguments' in data:
                calls.append(self._new_call(data["name"], self._arguments_to_string(data["arguments"])))
        return calls

    def _from_function_call_json(self, text: str) -> List[Dict]:
        """Format 2: JSON objects whose top-level ``function_call`` is an object.

        Uses ``JSONDecoder.raw_decode`` from each ``{`` so nested objects are
        parsed correctly (a brace-counting regex cannot do this).
        """
        if "function_call" not in text:
            return []
        decoder = json.JSONDecoder()
        calls = []
        pos = text.find("{")
        while pos != -1:
            try:
                data, end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                pos = text.find("{", pos + 1)
                continue
            fc = data.get("function_call") if isinstance(data, dict) else None
            if isinstance(fc, dict):
                calls.append(self._new_call(
                    fc.get("name", ""),
                    self._arguments_to_string(fc.get("arguments", "{}")),
                ))
            # Skip past the whole decoded value so its inner braces are not rescanned.
            pos = text.find("{", end)
        return calls

    def _from_named_tags(self, text: str, tools) -> List[Dict]:
        """Format 3: ``<tool_name>args</tool_name>`` for each available tool."""
        calls = []
        for tool in tools:
            tool_name = tool.function.name
            escaped = re.escape(tool_name)  # names are data, not regex syntax
            for body in re.findall(rf'<{escaped}>(.*?)</{escaped}>', text, re.DOTALL):
                raw = body.strip()
                try:
                    arguments = json.dumps(json.loads(raw))
                except json.JSONDecodeError:
                    # Not JSON: pass the raw text through as the arguments string.
                    arguments = raw
                calls.append(self._new_call(tool_name, arguments))
        return calls

    def extract_tool_calls(self, text: str, available_tools: "List[Tool]") -> Optional[List[Dict]]:
        """Extract tool calls from model output.

        Args:
            text: Generated text to scan.
            available_tools: Tools offered in the request; used only by
                format 3. ``None`` is treated as an empty list.

        Returns:
            A list of OpenAI ``tool_calls`` dicts, or ``None`` if none were found.
        """
        tool_calls = []
        tool_calls.extend(self._from_tool_tags(text))
        tool_calls.extend(self._from_function_call_json(text))
        tool_calls.extend(self._from_named_tags(text, available_tools or []))
        return tool_calls or None


def format_tools_for_prompt(tools: List[Tool], messages: List[ChatMessage]) -> List[ChatMessage]:
    """Inject a textual description of ``tools`` into the system prompt.

    The description is prepended to the first system message. If there is no
    system message, a new one is inserted at the front. The input list is not
    mutated; a new list is returned. With no tools, ``messages`` is returned
    unchanged (the same object).
    """
    if not tools:
        return messages

    def _describe(tool):
        fn = tool.function
        lines = [f"Tool: {fn.name}"]
        if fn.description:
            lines.append(f"Description: {fn.description}")
        if fn.parameters:
            lines.append(f"Parameters: {json.dumps(fn.parameters, indent=2)}")
        return "\n".join(lines)

    tools_text = (
        "You have access to the following tools:\n\n"
        + "\n\n".join(_describe(tool) for tool in tools)
        + "\n\nWhen you need to use a tool, format your response as:\n"
        + '<tool>{"name": "tool_name", "arguments": {...}}</tool>'
    )

    result = list(messages)
    system_idx = next((idx for idx, m in enumerate(result) if m.role == "system"), None)

    if system_idx is None:
        result.insert(0, ChatMessage(role="system", content=tools_text))
    else:
        existing = result[system_idx].content or ''
        result[system_idx] = ChatMessage(
            role="system",
            content=f"{tools_text}\n\n{existing}"
        )

    return result


# =============================================================================
# Abstract Model Backend
# =============================================================================

class ModelBackend(ABC):
    """Interface implemented by every inference backend.

    A backend owns one loaded model and provides blocking generation,
    streaming generation, chat-prompt formatting, and teardown.
    """

    @abstractmethod
    def load_model(self, model_name: str, **kwargs) -> None:
        """Load ``model_name``; backend-specific options arrive through ``kwargs``."""
        ...

    @abstractmethod
    def generate(self, prompt: str, max_tokens: Optional[int] = None, 
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Return the full completion text for ``prompt`` in a single call."""
        ...

    @abstractmethod
    def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                        temperature: float = 0.7, top_p: float = 1.0,
                        stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Produce the completion for ``prompt`` incrementally as text chunks.

        Implementations may be written as ``async def`` generators.
        """
        ...

    @abstractmethod
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Render a chat transcript into the single prompt string the model expects."""
        ...

    @abstractmethod
    def get_model_name(self) -> str:
        """Name (or path) of the currently loaded model."""
        ...

    @abstractmethod
    def cleanup(self) -> None:
        """Release the model and any resources it holds."""
        ...


# =============================================================================
# NVIDIA/HuggingFace Backend
# =============================================================================

class NvidiaBackend(ModelBackend):
    """Backend for NVIDIA GPUs using HuggingFace Transformers.

    ROCm builds of PyTorch expose their GPUs through ``torch.cuda`` too and are
    handled the same way; without a GPU the model runs on CPU. All generations
    on one instance are serialised through ``self._lock``, so concurrent
    requests never run several ``generate`` calls on the same model at once.
    """

    # Minimum number of trailing generated tokens decoded when checking stop strings.
    _STOP_WINDOW_TOKENS = 20

    def __init__(self):
        import threading

        self.model = None
        self.tokenizer = None
        self.model_name = None
        self.device = None
        self.use_flash_attn = False
        self.flash_attn_available = False
        # Held for the whole duration of every generate()/generate_stream() run.
        self._lock = threading.Lock()

    def check_flash_attn_support(self) -> None:
        """Check and print Flash Attention availability status.

        Disables ``self.use_flash_attn`` when it was requested but
        ``flash_attn`` cannot be imported.
        """
        self.flash_attn_available = check_flash_attn_availability()
        if self.use_flash_attn:
            if self.flash_attn_available:
                print("Flash Attention 2: Available and enabled")
            else:
                print("Warning: Flash Attention 2 requested but not installed")
                print("Install with: pip install flash-attn --no-build-isolation")
                print("Falling back to standard attention")
                self.use_flash_attn = False

    def _detect_device(self) -> str:
        """Return ``"cuda"`` if PyTorch sees a GPU (CUDA or ROCm), else ``"cpu"``."""
        import torch
        if not torch.cuda.is_available():
            print("No GPU detected, using CPU")
            return "cpu"
        # ROCm builds report through torch.cuda; torch.version.hip tells them apart.
        hip_version = getattr(torch.version, 'hip', None)
        if hip_version is not None:
            print(f"ROCm/HIP detected: {hip_version}")
        else:
            print(f"CUDA detected: {torch.version.cuda}")
        return "cuda"

    def _get_available_vram(self) -> int:
        """Return the combined *total* memory in bytes of all visible GPUs.

        This is total capacity, not currently free memory. Returns 0 when no
        GPU is available or the query fails.
        """
        import torch
        if not torch.cuda.is_available():
            return 0
        try:
            return sum(
                torch.cuda.get_device_properties(i).total_memory
                for i in range(torch.cuda.device_count())
            )
        except Exception as e:
            print(f"Warning: Could not detect VRAM: {e}")
            return 0

    def _estimate_model_size(self, model_name: str) -> Optional[int]:
        """Roughly estimate the fp16 weight footprint of ``model_name`` in bytes.

        An explicit ``num_parameters``/``n_params`` in the config is used when
        present. Otherwise the usual dense-transformer approximation is used:
        ``vocab * hidden + 12 * layers * hidden**2``, i.e. about 4h² for
        attention plus 8h² for the MLP per layer.

        Returns:
            Estimated size in bytes, or ``None`` if the config cannot be
            loaded or lacks the needed fields.
        """
        from transformers import AutoConfig
        try:
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

            num_params = getattr(config, 'num_parameters', None) or getattr(config, 'n_params', None)
            if num_params is None:
                layers = getattr(config, 'num_hidden_layers', None)
                hidden = getattr(config, 'hidden_size', None)
                if layers is None or hidden is None:
                    return None
                vocab_size = getattr(config, 'vocab_size', 50000)
                num_params = vocab_size * hidden + 12 * layers * hidden * hidden

            # Assume float16 (2 bytes per parameter)
            return int(num_params) * 2
        except Exception as e:
            print(f"Warning: Could not estimate model size: {e}")
            return None

    @staticmethod
    def _make_quantization_config(load_in_4bit: bool):
        """Return a ``BitsAndBytesConfig`` (4-bit if requested, else 8-bit).

        Returns ``None`` when bitsandbytes is not installed.
        """
        try:
            import bitsandbytes  # noqa: F401  (presence check)
            from transformers import BitsAndBytesConfig
        except ImportError:
            print("Warning: bitsandbytes not installed. Quantization disabled.")
            return None
        print(f"Using {4 if load_in_4bit else 8}-bit quantization")
        if load_in_4bit:
            return BitsAndBytesConfig(load_in_4bit=True)
        return BitsAndBytesConfig(load_in_8bit=True)

    def load_model(self, model_name: str, **kwargs) -> None:
        """Load a causal LM and its tokenizer with HuggingFace Transformers.

        Keyword Args:
            offload_dir: Folder passed as ``offload_folder``; created if missing.
            load_in_4bit / load_in_8bit: Quantize with bitsandbytes via
                ``quantization_config``. 4-bit wins if both are set.
            flash_attn: Request Flash Attention 2 when ``flash_attn`` is installed.
            manual_ram_gb: Accepted for caller compatibility; not used here.
        """
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        offload_dir = kwargs.get('offload_dir')
        load_in_4bit = kwargs.get('load_in_4bit', False)
        load_in_8bit = kwargs.get('load_in_8bit', False)

        print(f"Loading HuggingFace model: {model_name}")

        self.use_flash_attn = kwargs.get('flash_attn', False)
        self.check_flash_attn_support()

        self.device = self._detect_device()

        # Left padding keeps the prompt flush against the generated tokens.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side="left"
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        load_kwargs = {'trust_remote_code': True}

        quantization_config = None
        if load_in_4bit or load_in_8bit:
            quantization_config = self._make_quantization_config(load_in_4bit)

        if quantization_config is not None:
            load_kwargs['quantization_config'] = quantization_config
            load_kwargs['device_map'] = 'auto'
        else:
            on_gpu = self.device == "cuda"
            load_kwargs['torch_dtype'] = torch.float16 if on_gpu else torch.float32
            load_kwargs['device_map'] = 'auto' if on_gpu else None

        if offload_dir:
            os.makedirs(offload_dir, exist_ok=True)
            load_kwargs['offload_folder'] = offload_dir
            print(f"Disk offload directory: {offload_dir}")

        if self.use_flash_attn and self.flash_attn_available:
            load_kwargs['attn_implementation'] = "flash_attention_2"
            print("Using Flash Attention 2")

        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

        # Without a device_map, place the weights on the CPU explicitly.
        if self.device == "cpu" and load_kwargs.get('device_map') is None:
            self.model = self.model.to(self.device)

        self.model.eval()
        self.model_name = model_name

        print(f"\nModel loaded successfully")
        print(f"Model device: {next(self.model.parameters()).device}")

    @staticmethod
    def _render_tool_call(function: Dict) -> str:
        """Render an OpenAI tool call as ``<tool>{"name": ..., "arguments": ...}</tool>``.

        ``arguments`` is embedded as a JSON object when it parses as JSON and
        as a JSON string otherwise, so the result is always valid JSON.
        """
        arguments = function.get("arguments", "{}")
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                pass  # keep it as a string value
        payload = json.dumps({"name": function.get("name", ""), "arguments": arguments}, ensure_ascii=False)
        return f"<tool>{payload}</tool>"

    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages as ``Role: content`` paragraphs ending with ``Assistant:``.

        Assistant tool calls are appended in the ``<tool>...</tool>`` format
        the model is asked to produce. Missing content renders as an empty
        string.
        """
        formatted = []

        for msg in messages:
            content = msg.content or ""
            if msg.role == "system":
                formatted.append(f"System: {content}")
            elif msg.role == "user":
                formatted.append(f"User: {content}")
            elif msg.role == "assistant":
                for tc in msg.tool_calls or []:
                    if tc.get("function"):
                        content += "\n" + self._render_tool_call(tc["function"])
                formatted.append(f"Assistant: {content}")
            elif msg.role == "tool":
                formatted.append(f"Tool ({msg.name or msg.tool_call_id}): {content}")

        formatted.append("Assistant:")
        return "\n\n".join(formatted)

    def _validate_params(self, temperature: float, top_p: float) -> tuple:
        """Clamp sampling parameters; ``temperature <= 0`` selects greedy decoding.

        Returns:
            ``(temperature, top_p, do_sample)``.
        """
        if temperature <= 0:
            temperature = 1.0
            do_sample = False
        else:
            temperature = max(0.01, min(temperature, 2.0))
            do_sample = True
        top_p = max(0.0, min(top_p, 1.0))
        return temperature, top_p, do_sample

    @staticmethod
    def _normalize_stop(stop) -> List[str]:
        """Accept ``None``, a single string, or a list; drop empty strings.

        An empty stop string would match everywhere, so it is discarded.
        """
        if not stop:
            return []
        if isinstance(stop, str):
            stop = [stop]
        return [s for s in stop if s]

    @staticmethod
    def _truncate_at_stop(text: str, stop: List[str]) -> str:
        """Cut ``text`` before the earliest occurrence of any stop string."""
        hits = [pos for pos in (text.find(s) for s in stop) if pos >= 0]
        return text[:min(hits)] if hits else text

    @staticmethod
    def _make_logits_processors():
        """Return a processor list that replaces NaN/inf logits with finite values."""
        import torch
        from transformers import LogitsProcessor, LogitsProcessorList

        class InvalidLogitsProcessor(LogitsProcessor):
            def __call__(self, input_ids, scores):
                # NaN -> very unlikely, +inf -> very likely. -inf must stay very
                # unlikely: other processors use -inf to mask tokens out.
                return torch.nan_to_num(scores, nan=-1e9, posinf=1e9, neginf=-1e9)

        return LogitsProcessorList([InvalidLogitsProcessor()])

    def _make_stopping_criteria(self, stop: List[str], prompt_len: int, cancel_event=None):
        """Build stopping criteria for stop strings and/or a cancel flag.

        Returns ``None`` when neither stop strings nor a cancel flag is given.
        """
        from transformers import StoppingCriteria, StoppingCriteriaList

        criteria = []
        if stop:
            tokenizer = self.tokenizer
            # Tokens typically decode to >= 1 character, so a window at least as
            # long as the longest stop string should be enough to contain it.
            window = max(self._STOP_WINDOW_TOKENS, max(len(s) for s in stop))

            class StopOnSequence(StoppingCriteria):
                def __call__(self, input_ids, scores, **kwargs):
                    # Only inspect generated tokens, so stop strings in the prompt can't match.
                    start = max(prompt_len, input_ids.shape[-1] - window)
                    tail = tokenizer.decode(input_ids[0][start:], skip_special_tokens=True)
                    return any(seq in tail for seq in stop)

            criteria.append(StopOnSequence())

        if cancel_event is not None:
            class StopOnCancel(StoppingCriteria):
                def __call__(self, input_ids, scores, **kwargs):
                    return cancel_event.is_set()

            criteria.append(StopOnCancel())

        return StoppingCriteriaList(criteria) if criteria else None

    def _build_generation_kwargs(self, prompt: str, max_tokens: Optional[int],
                                 temperature: float, top_p: float,
                                 stop: List[str], cancel_event=None) -> Dict:
        """Tokenize ``prompt`` and assemble the kwargs shared by both generate paths."""
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        temperature, top_p, do_sample = self._validate_params(temperature, top_p)

        kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": 512 if max_tokens is None else max_tokens,
            # Sampling knobs are unused (and warned about) in greedy mode.
            "temperature": temperature if do_sample else None,
            "top_p": top_p if do_sample else None,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "logits_processor": self._make_logits_processors(),
        }
        criteria = self._make_stopping_criteria(stop, inputs["input_ids"].shape[1], cancel_event)
        if criteria is not None:
            kwargs["stopping_criteria"] = criteria
        return kwargs

    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate the full completion for ``prompt`` (blocking).

        ``max_tokens`` defaults to 512 new tokens. Generation halts once a stop
        string appears, and the returned text is cut before the earliest one.
        """
        import torch

        stop_list = self._normalize_stop(stop)
        generation_kwargs = self._build_generation_kwargs(prompt, max_tokens, temperature, top_p, stop_list)
        prompt_len = generation_kwargs["input_ids"].shape[1]

        with self._lock, torch.no_grad():
            outputs = self.model.generate(**generation_kwargs)

        text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return self._truncate_at_stop(text, stop_list)

    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Stream generated text chunks without blocking the event loop.

        Generation runs in a worker thread that pushes decoded text onto an
        ``asyncio.Queue`` via ``loop.call_soon_threadsafe``.

        Stop strings are excluded from the output: a tail as long as the
        longest stop string minus one is held back until it cannot be the
        start of one.

        If the consumer stops iterating early, a cancel flag ends generation
        at the next token. An exception raised by ``generate`` in the worker
        is re-raised here rather than leaving the consumer waiting forever.
        """
        import threading
        from transformers import TextStreamer

        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        finished = object()  # sentinel: the worker thread is done
        cancel_event = threading.Event()
        stop_list = self._normalize_stop(stop)

        class _QueueStreamer(TextStreamer):
            """TextStreamer that forwards finalized text to the event loop instead of printing."""

            def on_finalized_text(self, text: str, stream_end: bool = False):
                if text:
                    loop.call_soon_threadsafe(queue.put_nowait, text)

        generation_kwargs = self._build_generation_kwargs(
            prompt, max_tokens, temperature, top_p, stop_list, cancel_event
        )
        generation_kwargs["streamer"] = _QueueStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        def _worker():
            try:
                with self._lock:
                    # A consumer that gave up while we waited for the lock needs no output.
                    if not cancel_event.is_set():
                        self.model.generate(**generation_kwargs)
            except Exception as exc:
                loop.call_soon_threadsafe(queue.put_nowait, exc)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, finished)

        Thread(target=_worker).start()

        holdback = max((len(s) for s in stop_list), default=1) - 1
        pending = ""
        try:
            while True:
                item = await queue.get()
                if item is finished:
                    break
                if isinstance(item, Exception):
                    raise item
                pending += item
                if stop_list:
                    kept = self._truncate_at_stop(pending, stop_list)
                    if len(kept) < len(pending):
                        if kept:
                            yield kept
                        return
                emit = len(pending) - holdback
                if emit > 0:
                    yield pending[:emit]
                    pending = pending[emit:]
            if pending:
                yield pending
        finally:
            # Ends the worker's generation at its next token if we exited early.
            cancel_event.set()

    def get_model_name(self) -> str:
        """Return the loaded model name, or ``"unknown"`` before loading."""
        return self.model_name or "unknown"

    def cleanup(self) -> None:
        """Drop the model and tokenizer and return cached GPU memory to the driver."""
        import gc
        import torch
        if self.model is None:
            return
        self.model = None
        self.tokenizer = None
        # Collect reference cycles first so empty_cache can actually release the weights.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# =============================================================================
# Vulkan Backend (llama-cpp-python)
# =============================================================================

class VulkanBackend(ModelBackend):
    """Backend for Vulkan (AMD GPUs) using llama-cpp-python with GGUF models.

    A ``llama_cpp.Llama`` object keeps per-generation state (token buffer,
    KV cache) on itself, so every generation on this instance holds
    ``self._lock``.
    """

    def __init__(self):
        from threading import Lock

        self.model = None
        self.model_name = None
        self.n_gpu_layers = -1  # Offload all layers to GPU by default
        self.n_ctx = 2048
        self.verbose = True
        self.main_gpu = 0  # Default to first GPU
        # Serialises all use of self.model across request threads.
        self._lock = Lock()

    def list_vulkan_devices(self):
        """Print ``vulkaninfo --summary`` output for reference (best effort)."""
        import subprocess
        try:
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True, timeout=30)
            if result.returncode == 0:
                print("\nAvailable Vulkan devices:")
                print(result.stdout)
        except Exception:
            # Purely informational: a missing or failing vulkaninfo is not an error.
            pass

    def count_vulkan_devices(self):
        """Count the number of Vulkan GPU devices available.

        Parses the ``GPUn:`` sections of ``vulkaninfo --summary`` and skips
        software/CPU devices (llvmpipe), which llama.cpp does not use.

        Returns:
            At least 1 when vulkaninfo runs successfully; 2 when it cannot
            be run. The fallback of 2 is common for NVIDIA + AMD setups.
        """
        import subprocess
        try:
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True, timeout=30)
        except (OSError, ValueError, subprocess.SubprocessError):
            return 2
        if result.returncode != 0:
            return 2
        # Element 0 is the text before the first "GPUn:" header.
        sections = re.split(r'^GPU\d+:', result.stdout, flags=re.MULTILINE)[1:]
        real_gpus = sum(1 for section in sections if 'llvmpipe' not in section and 'CPU' not in section)
        return max(real_gpus, 1)

    def load_model(self, model_name: str, **kwargs) -> None:
        """Load a GGUF model using llama-cpp-python.

        ``model_name`` may be:

        * a local ``.gguf`` path;
        * ``org/model`` on the HuggingFace Hub (a Q4_K_M file is preferred,
          otherwise the first ``.gguf`` listed);
        * ``org/model/<file>.gguf``.

        Keyword Args:
            n_gpu_layers (-1 = all), n_ctx, verbose, main_gpu, and single_gpu.
            ``single_gpu`` builds a ``tensor_split`` that places every layer
            on ``main_gpu``.
        """
        from llama_cpp import Llama

        n_gpu_layers = kwargs.get('n_gpu_layers', -1)
        n_ctx = kwargs.get('n_ctx', 2048)
        verbose = kwargs.get('verbose', True)
        main_gpu = kwargs.get('main_gpu', 0)
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.verbose = verbose
        self.main_gpu = main_gpu

        if os.path.isfile(model_name):
            model_path = model_name
            print(f"Loading local GGUF model: {model_path}")
        else:
            print(f"Attempting to download GGUF model: {model_name}")
            try:
                from huggingface_hub import hf_hub_download, list_repo_files

                # Parse model name (format: "org/model" or "org/model/filename.gguf")
                parts = model_name.split('/')
                if len(parts) >= 2:
                    repo_id = f"{parts[0]}/{parts[1]}"

                    if len(parts) >= 3 and parts[-1].endswith('.gguf'):
                        filename = '/'.join(parts[2:])
                    else:
                        files = list_repo_files(repo_id)
                        gguf_files = [f for f in files if f.endswith('.gguf')]
                        if not gguf_files:
                            raise ValueError(f"No GGUF files found in {repo_id}")
                        # Prefer Q4_K_M quantized models for good balance
                        preferred = [f for f in gguf_files if 'Q4_K_M' in f or 'q4_k_m' in f.lower()]
                        filename = preferred[0] if preferred else gguf_files[0]
                        print(f"Selected GGUF file: {filename}")

                    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
                    print(f"Downloaded to: {model_path}")
                else:
                    raise ValueError(f"Invalid model name format: {model_name}")
            except Exception as e:
                print(f"Error downloading model: {e}")
                print("Please provide a local path to a .gguf file")
                raise

        print(f"Loading GGUF model with Vulkan support...")
        print(f"  Model path: {model_path}")
        print(f"  GPU layers: {n_gpu_layers} (-1 = all layers)")
        print(f"  Context size: {n_ctx}")
        print(f"  GPU device: {main_gpu}")

        self.list_vulkan_devices()

        single_gpu = kwargs.get('single_gpu', False)
        tensor_split = None

        if single_gpu:
            # tensor_split weights: 1.0 for the selected GPU, 0.0 for the others.
            num_devices = self.count_vulkan_devices()
            tensor_split = [0.0] * num_devices
            if main_gpu < len(tensor_split):
                tensor_split[main_gpu] = 1.0
            else:
                print(f"Warning: main_gpu={main_gpu} exceeds detected devices ({num_devices})")
                tensor_split = None

            if tensor_split:
                print(f"  Single GPU mode: Forcing all layers to GPU {main_gpu}")
                print(f"  Tensor split: {tensor_split}")

        try:
            llama_kwargs = {
                'model_path': model_path,
                'n_gpu_layers': n_gpu_layers,
                'n_ctx': n_ctx,
                'verbose': verbose,
                'main_gpu': main_gpu,
            }
            if tensor_split:
                llama_kwargs['tensor_split'] = tensor_split

            self.model = Llama(**llama_kwargs)
            self.model_name = model_name
            print("\nModel loaded successfully with Vulkan!")
        except Exception as e:
            print(f"Error loading model with Vulkan: {e}")
            print("Make sure Vulkan drivers are installed:")
            print("  Debian/Ubuntu: sudo apt install libvulkan-dev vulkan-tools")
            print("  Fedora: sudo dnf install vulkan-loader-devel vulkan-tools")
            raise

    @staticmethod
    def _render_tool_call(function: Dict) -> str:
        """Render an OpenAI tool call as ``<tool>{"name": ..., "arguments": ...}</tool>``.

        ``arguments`` is embedded as a JSON object when it parses as JSON and
        as a JSON string otherwise, so the result is always valid JSON.
        """
        arguments = function.get("arguments", "{}")
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                pass  # keep it as a string value
        payload = json.dumps({"name": function.get("name", ""), "arguments": arguments}, ensure_ascii=False)
        return f"<tool>{payload}</tool>"

    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages with ``<|system|>``/``<|user|>``/``<|assistant|>`` tags.

        Assistant tool calls are appended in the ``<tool>...</tool>`` format.
        Tool results are rendered as user turns labelled ``Tool (name): ...``,
        because this tag set has no tool role. Missing content renders as an
        empty string.
        """
        formatted = []

        for msg in messages:
            content = msg.content or ""
            if msg.role == "system":
                formatted.append(f"<|system|>\n{content}")
            elif msg.role == "user":
                formatted.append(f"<|user|>\n{content}")
            elif msg.role == "assistant":
                for tc in msg.tool_calls or []:
                    if tc.get("function"):
                        content += "\n" + self._render_tool_call(tc["function"])
                formatted.append(f"<|assistant|>\n{content}")
            elif msg.role == "tool":
                formatted.append(f"<|user|>\nTool ({msg.name or msg.tool_call_id}): {content}")

        formatted.append("<|assistant|>\n")
        return "\n".join(formatted)

    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate the full completion for ``prompt`` (blocking; default 512 tokens)."""
        if max_tokens is None:
            max_tokens = 512

        with self._lock:
            output = self.model(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop=stop or [],
            )

        return output["choices"][0]["text"]

    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Stream generated text chunks without blocking the event loop.

        llama-cpp's blocking chunk iterator is driven in a worker thread that
        pushes text onto an ``asyncio.Queue``. If the consumer stops
        iterating early, a cancel flag makes the worker stop at the next
        chunk. Worker exceptions are re-raised to the consumer.
        """
        import threading

        if max_tokens is None:
            max_tokens = 512

        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        finished = object()  # sentinel: the worker thread is done
        cancel_event = threading.Event()

        def _worker():
            try:
                with self._lock:
                    chunks = self.model(
                        prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        stop=stop or [],
                        stream=True,
                    )
                    try:
                        for chunk in chunks:
                            if cancel_event.is_set():
                                break
                            text = chunk["choices"][0].get("text", "")
                            if text:
                                loop.call_soon_threadsafe(queue.put_nowait, text)
                    finally:
                        # Finish the llama-cpp generator while we still hold the lock.
                        close = getattr(chunks, "close", None)
                        if close is not None:
                            close()
            except Exception as exc:
                loop.call_soon_threadsafe(queue.put_nowait, exc)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, finished)

        Thread(target=_worker).start()

        try:
            while True:
                item = await queue.get()
                if item is finished:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            cancel_event.set()

    def get_model_name(self) -> str:
        """Return the loaded model name, or ``"unknown"`` before loading."""
        return self.model_name or "unknown"

    def cleanup(self) -> None:
        """Release the llama-cpp model once no generation is using it."""
        with self._lock:
            if self.model is not None:
                self.model = None


# =============================================================================
# Model Manager
# =============================================================================

class ModelManager:
    """Owns the single active inference backend and routes calls to it.

    Request handlers never talk to a backend directly; they call this manager,
    which forwards to whichever backend ``load_model`` installed
    (``NvidiaBackend`` or ``VulkanBackend``).
    """

    def __init__(self):
        # Active backend, or None while no model is loaded.
        self.backend: Optional[ModelBackend] = None
        # Resolved backend name ("nvidia" / "vulkan") of the loaded model.
        self.backend_type: Optional[str] = None
        self.tool_parser = ToolCallParser()

    def load_model(self, model_name: str, backend_type: str = "auto", **kwargs):
        """
        Load the model with the specified backend.

        The backend choice is validated first, so an invalid request leaves any
        currently loaded model untouched. After validation, the previous
        backend (if any) is cleaned up before the new one is created. If the
        new backend fails to load, no backend is installed (``self.backend``
        stays ``None``) and the original exception propagates.

        Args:
            model_name: Model name or path
            backend_type: 'nvidia', 'vulkan', or 'auto' to detect
            **kwargs: Additional arguments for the specific backend

        Raises:
            RuntimeError: no usable backend was detected, or the requested
                backend's runtime is not installed.
            ValueError: ``backend_type`` is not one of the supported names.
        """
        available = detect_available_backends()

        # Determine backend
        if backend_type == "auto":
            if available.get('nvidia'):
                backend_type = "nvidia"
                print("Auto-detected NVIDIA backend")
            elif available.get('vulkan'):
                backend_type = "vulkan"
                print("Auto-detected Vulkan backend")
            else:
                print("No GPU backend detected. For NVIDIA, install PyTorch with CUDA.")
                print("For Vulkan, install llama-cpp-python with Vulkan support.")
                raise RuntimeError("No suitable backend found")

        # Pick (but do not yet construct) the backend class.
        if backend_type == "nvidia":
            if not available.get('nvidia'):
                raise RuntimeError("NVIDIA backend requested but PyTorch/CUDA not available")
            backend_cls = NvidiaBackend
        elif backend_type == "vulkan":
            if not available.get('vulkan'):
                raise RuntimeError("Vulkan backend requested but llama-cpp-python not available")
            backend_cls = VulkanBackend
        else:
            raise ValueError(f"Unknown backend: {backend_type}")

        # Clean up the old backend before the new one is created, so its
        # resources are released before the new model is allocated.
        self.cleanup()

        backend = backend_cls()
        try:
            backend.load_model(model_name, **kwargs)
        except Exception:
            # Handlers treat a non-None ``self.backend`` as ready to serve, so a
            # backend whose load failed is never installed. Its cleanup is
            # best-effort and must not mask the original load error.
            try:
                backend.cleanup()
            except Exception as cleanup_err:
                print(f"Error cleaning up after failed model load: {cleanup_err}")
            raise

        self.backend = backend
        self.backend_type = backend_type
        self.tool_parser = ToolCallParser()

    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string using the active backend.

        Raises:
            RuntimeError: no model is loaded.
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        return self.backend.format_messages(messages)

    def generate(self, prompt: str, max_tokens: Optional[int] = None,
                 temperature: float = 0.7, top_p: float = 1.0,
                 stop: Optional[List[str]] = None) -> str:
        """Generate text non-streaming via the active backend (synchronous call).

        Raises:
            RuntimeError: no model is loaded.
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)

    async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
                              temperature: float = 0.7, top_p: float = 1.0,
                              stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
        """Re-yield the text fragments streamed by the active backend.

        Raises:
            RuntimeError: no model is loaded (raised on first iteration).
        """
        if self.backend is None:
            raise RuntimeError("No model loaded")
        async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
            yield chunk

    @property
    def model_name(self) -> str:
        """Name reported by the active backend, or ``"unknown"`` when none is loaded."""
        if self.backend is None:
            return "unknown"
        return self.backend.get_model_name()

    @property
    def model(self):
        """The active backend object itself (not the raw model); None when unloaded."""
        return self.backend

    @property
    def tokenizer(self):
        """The backend's tokenizer; only ``NvidiaBackend`` exposes one, else None."""
        if isinstance(self.backend, NvidiaBackend):
            return self.backend.tokenizer
        return None

    def cleanup(self):
        """Release the active backend (if any) and reset ``backend_type``.

        The reference is cleared before the backend's own ``cleanup()`` runs,
        so the manager reports "no model loaded" even if that cleanup raises.
        """
        backend, self.backend = self.backend, None
        self.backend_type = None
        if backend is not None:
            backend.cleanup()


# Global model manager: the single process-wide instance used by the request
# handlers, loaded by ``main`` and cleaned up by the ``lifespan`` hook.
model_manager = ModelManager()


# =============================================================================
# FastAPI Application
# =============================================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup/shutdown.

    Startup does nothing here: ``main`` loads the model before uvicorn starts.
    On shutdown the global ``model_manager`` releases its backend. The cleanup
    call is not in a ``finally``, so it only runs when the context exits
    normally (not when an exception is thrown in at the ``yield``).
    """
    # Startup
    yield
    # Shutdown
    model_manager.cleanup()


# The ASGI application; ``lifespan`` releases the model at shutdown.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="OpenAI-Compatible API",
    description="OpenAI-compatible API supporting NVIDIA (CUDA) and Vulkan backends",
)

# Add request logging middleware for debugging

# Endpoints whose traffic ``log_requests`` dumps to stdout.
_LOGGED_PATHS = ("/v1/chat/completions", "/v1/completions")

# Lower-cased header names whose values are credentials and must never be printed.
_SENSITIVE_HEADERS = frozenset({
    "authorization",
    "proxy-authorization",
    "cookie",
    "x-api-key",
    "api-key",
})


def _redact_headers(headers) -> Dict[str, str]:
    """Return a plain-dict copy of *headers* with credential values masked."""
    return {
        name: "<redacted>" if name.lower() in _SENSITIVE_HEADERS else value
        for name, value in dict(headers).items()
    }


def _log_request_body(request: Request, body: bytes) -> None:
    """Print method, redacted headers, the raw body and a summary of its JSON.

    Raises whatever ``bytes.decode`` raises for a non-UTF-8 body; the caller
    treats that as a logging failure.
    """
    body_str = body.decode('utf-8')
    print(f"\n{'='*60}")
    print("=== INCOMING REQUEST ===")
    print(f"{'='*60}")
    print(f"Path: {request.url.path}")
    print(f"Method: {request.method}")
    print(f"Headers: {_redact_headers(request.headers)}")
    print(f"\n--- RAW BODY ({len(body)} bytes) ---")
    # Keep very large bodies readable: show only the first and last 1000 chars.
    if len(body_str) > 2000:
        print(f"{body_str[:1000]}...\n... [truncated {len(body_str)-2000} chars] ...\n...{body_str[-1000:]}")
    else:
        print(body_str)
    print("--- END RAW BODY ---")

    # Try to parse as JSON to see if it's valid
    try:
        parsed = json.loads(body_str)
        print("\n--- PARSED JSON STRUCTURE ---")
        print(f"Keys: {list(parsed.keys())}")
        if 'messages' in parsed and isinstance(parsed['messages'], list):
            print(f"Number of messages: {len(parsed['messages'])}")
            for i, msg in enumerate(parsed['messages']):
                role = msg.get('role', 'unknown')
                content_preview = str(msg.get('content', ''))[:50].replace('\n', ' ')
                print(f"  [{i}] {role}: {content_preview}...")
        print("--- END PARSED JSON ---")
    except json.JSONDecodeError as e:
        print(f"\n*** JSON Parse Error: {e} ***")
        print(f"Error at position: char {e.pos}, line {e.lineno}, column {e.colno}")
    except Exception as e:
        print(f"\n*** Error analyzing JSON: {e} ***")


async def _log_error_response(response):
    """Print an error response's body and return a replayable copy of it.

    A response's ``body_iterator`` can be consumed only once, so after draining
    it for logging a new ``Response`` with the same bytes, status, headers and
    media type is returned in its place. If draining itself fails, the
    original response is returned as-is (its body may be partially consumed).
    """
    from starlette.responses import Response

    try:
        response_body = b"".join([chunk async for chunk in response.body_iterator])
    except Exception as read_err:
        print(f"Could not read error response body: {read_err}")
        return response

    # Decode leniently: the body has already been drained, so a non-UTF-8
    # payload must not prevent returning the rebuilt response below.
    error_body = response_body.decode('utf-8', errors='replace')
    print(f"Error Response Body: {error_body}")

    # Try to parse and pretty-print error details
    try:
        error_json = json.loads(error_body)
        if 'detail' in error_json:
            print("\n*** VALIDATION ERROR DETAILS ***")
            detail = error_json['detail']
            if isinstance(detail, list):
                for err in detail:
                    if isinstance(err, dict):
                        print(f"  - Location: {err.get('loc', [])}")
                        print(f"    Message: {err.get('msg', 'Unknown error')}")
                        print(f"    Type: {err.get('type', 'unknown')}")
            else:
                print(f"  Detail: {detail}")
            print("*** END ERROR DETAILS ***")
    except Exception as parse_err:
        print(f"Could not parse error details: {parse_err}")

    return Response(
        content=response_body,
        status_code=response.status_code,
        headers=dict(response.headers),
        media_type=response.media_type,
    )


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Debug middleware: dump completion-endpoint traffic to stdout.

    For paths in ``_LOGGED_PATHS`` this prints the request (with credential
    headers redacted), a summary of its JSON body, the response status and,
    for 4xx/5xx responses, the response body including validation details.
    Exceptions from downstream handlers are printed with a traceback for every
    path and then re-raised.
    """
    logged = request.url.path in _LOGGED_PATHS

    if logged:
        body = await request.body()
        try:
            _log_request_body(request, body)
        except Exception as e:
            print(f"Error logging request: {e}")

        # Hand the already-read body downstream exactly once, then defer to the
        # original receive channel. Returning the body on every call would feed
        # repeated ``http.request`` messages to anything that keeps calling
        # receive() (for instance to detect a client disconnect), which then
        # spins instead of waiting.
        upstream_receive = request.receive
        body_delivered = False

        async def receive():
            nonlocal body_delivered
            if body_delivered:
                return await upstream_receive()
            body_delivered = True
            return {"type": "http.request", "body": body, "more_body": False}

        request = Request(request.scope, receive, request._send)

    try:
        response = await call_next(request)
    except Exception as e:
        print("\n*** EXCEPTION DURING REQUEST PROCESSING ***")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        import traceback
        traceback.print_exc()
        print("*** END EXCEPTION ***\n")
        raise

    if not logged:
        return response

    print("\n--- RESPONSE ---")
    print(f"Status Code: {response.status_code}")
    if response.status_code >= 400:
        response = await _log_error_response(response)
    print("--- END RESPONSE ---")
    print(f"{'='*60}\n")
    return response


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """OpenAI ``GET /v1/models``: report the single served model.

    The entry's id is the manager's model name, or ``"unknown"`` if that name
    is empty.
    """
    name = model_manager.model_name
    return ModelList(data=[ModelInfo(id=name if name else "unknown")])


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI ``POST /v1/chat/completions`` with streaming and tool support.

    Tool definitions (if any) are folded into the messages before the prompt
    is built. ``stop`` may be a single string or a list. Streams SSE events
    when ``request.stream`` is set, otherwise returns one JSON body.

    Raises:
        HTTPException: 503 when no model is loaded.
    """
    if model_manager.backend is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    messages = (
        format_tools_for_prompt(request.tools, request.messages)
        if request.tools
        else request.messages
    )
    prompt = model_manager.format_messages(messages)

    stop = request.stop
    if not stop:
        stop_sequences = []
    elif isinstance(stop, str):
        stop_sequences = [stop]
    else:
        stop_sequences = stop

    call_args = (
        prompt,
        request.model,
        request.max_tokens,
        request.temperature,
        request.top_p,
        stop_sequences,
        request.tools,
    )

    if not request.stream:
        return await generate_chat_response(*call_args)
    return StreamingResponse(
        stream_chat_response(*call_args),
        media_type="text/event-stream",
    )


async def stream_chat_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Tool]],
) -> AsyncGenerator[str, None]:
    """Stream a chat completion as OpenAI-style server-sent events.

    Emits one ``chat.completion.chunk`` per generated fragment; the first also
    carries ``"role": "assistant"``. The final chunk has ``finish_reason``
    ``"tool_calls"`` (parsed calls in ``delta``) when ``tools`` were supplied
    and the full output contains tool calls, otherwise ``"stop"`` with an
    empty ``delta``. ``data: [DONE]`` follows. Every chunk carries the same
    envelope (``id``, ``object``, ``created``, ``model``, ``index``, ``delta``)
    so clients can parse all of them with one schema.

    Generation errors are reported in-band as a final content chunk, since the
    response status cannot change once streaming has begun.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def sse_chunk(delta: Dict, finish_reason: Optional[str] = None) -> str:
        """Wrap *delta* in the shared chunk envelope as an SSE ``data:`` line."""
        data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(data)}\n\n"

    pieces: List[str] = []

    try:
        async for chunk in model_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        ):
            delta = {"content": chunk}
            if not pieces:
                # OpenAI streams announce the assistant role on the first delta.
                delta = {"role": "assistant", "content": chunk}
            pieces.append(chunk)
            yield sse_chunk(delta)

        # Tool calls can only be recognised once the whole output is known.
        tool_calls = None
        if tools:
            tool_calls = model_manager.tool_parser.extract_tool_calls("".join(pieces), tools)

        if tool_calls:
            yield sse_chunk({"tool_calls": tool_calls}, "tool_calls")
        else:
            yield sse_chunk({}, "stop")
        yield "data: [DONE]\n\n"
    except Exception as e:
        print(f"Error during streaming generation: {e}")
        yield sse_chunk({"content": f"\n[Generation error: {str(e)}]"}, "stop")
        yield "data: [DONE]\n\n"


async def generate_chat_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Tool]],
) -> Dict:
    """Build a complete (non-streamed) ``chat.completion`` response body.

    When ``tools`` were offered and the tool parser finds calls in the output,
    the assistant message carries ``tool_calls`` with ``content`` set to None
    and ``finish_reason`` is ``"tool_calls"``. Otherwise the raw text is
    returned with ``finish_reason`` ``"stop"``. Usage figures come from the
    backend tokenizer when one is exposed, else from whitespace word counts.

    NOTE(review): ``model_manager.generate`` is a plain synchronous call made
    from this coroutine, so the event loop is occupied for the whole
    generation.

    Raises:
        HTTPException: status 500 wrapping any error raised while generating,
            parsing tool calls, or counting tokens.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        text = model_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )

        calls = model_manager.tool_parser.extract_tool_calls(text, tools) if tools else None
        if calls:
            message = {"role": "assistant", "content": None, "tool_calls": calls}
            finish = "tool_calls"
        else:
            message = {"role": "assistant", "content": text}
            finish = "stop"

        tokenizer = model_manager.tokenizer
        if tokenizer:
            used_prompt = len(tokenizer.encode(prompt))
            used_completion = len(tokenizer.encode(text))
        else:
            # Only the NVIDIA backend exposes a tokenizer; otherwise fall back
            # to a rough word count.
            used_prompt = len(prompt.split())
            used_completion = len(text.split())

        choice = {"index": 0, "message": message, "finish_reason": finish}
        usage = {
            "prompt_tokens": used_prompt,
            "completion_tokens": used_completion,
            "total_tokens": used_prompt + used_completion,
        }
        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": model_name,
            "choices": [choice],
            "usage": usage,
        }
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")


@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    """OpenAI ``POST /v1/completions`` (legacy text completion).

    ``prompt`` may be a single string or a list. Only the first element of a
    list is generated; a warning is printed when more were supplied. ``stop``
    may be a single string or a list.

    Raises:
        HTTPException: 503 when no model is loaded; 400 when ``prompt`` is an
            empty list.
    """
    if model_manager.backend is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt]
    if not prompts:
        # Nothing to complete: report a client error instead of failing with
        # IndexError on ``prompts[0]`` (which would surface as a 500).
        raise HTTPException(status_code=400, detail="prompt must not be empty")
    if len(prompts) > 1:
        print(f"Warning: received {len(prompts)} prompts; only the first is processed")

    stop_sequences = []
    if request.stop:
        stop_sequences = [request.stop] if isinstance(request.stop, str) else request.stop

    if request.stream:
        return StreamingResponse(
            stream_completion_response(
                prompts[0],
                request.model,
                request.max_tokens,
                request.temperature,
                request.top_p,
                stop_sequences,
            ),
            media_type="text/event-stream",
        )
    return await generate_completion_response(
        prompts[0],
        request.model,
        request.max_tokens,
        request.temperature,
        request.top_p,
        stop_sequences,
    )


async def stream_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
) -> AsyncGenerator[str, None]:
    """Stream a text completion as OpenAI-style server-sent events.

    Each generated fragment becomes one ``text_completion`` chunk. A final
    chunk with empty ``text`` and ``finish_reason`` ``"stop"`` follows, then
    ``data: [DONE]``. Every chunk carries the full envelope (``id``,
    ``object``, ``created``, ``model``, ``choices[].text/index/logprobs``) so
    clients can parse them all with one schema.

    If generation fails mid-stream, the error text is sent as the final chunk
    rather than being dropped silently, since the response status cannot
    change once streaming has begun.
    """
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def sse_chunk(text: str, finish_reason: Optional[str] = None) -> str:
        """Wrap *text* in the shared chunk envelope as an SSE ``data:`` line."""
        data = {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "text": text,
                "index": 0,
                "logprobs": None,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(data)}\n\n"

    try:
        async for chunk in model_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        ):
            yield sse_chunk(chunk)
        yield sse_chunk("", "stop")
    except Exception as e:
        print(f"Error during streaming completion: {e}")
        yield sse_chunk(f"\n[Generation error: {str(e)}]", "stop")
    yield "data: [DONE]\n\n"


async def generate_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
) -> Dict:
    """Build a complete (non-streamed) ``text_completion`` response body.

    ``finish_reason`` is always ``"stop"``. Usage figures come from the
    backend tokenizer when one is exposed, else from whitespace word counts.

    Raises:
        HTTPException: status 500 wrapping any error raised while generating
            or counting tokens.
    """
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        output = model_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )

        tokenizer = model_manager.tokenizer
        if tokenizer:
            def count_tokens(text: str) -> int:
                return len(tokenizer.encode(text))
        else:
            def count_tokens(text: str) -> int:
                # No tokenizer exposed: approximate with a word count.
                return len(text.split())

        n_prompt = count_tokens(prompt)
        n_completion = count_tokens(output)

        return {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "text": output,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "stop",
                },
            ],
            "usage": {
                "prompt_tokens": n_prompt,
                "completion_tokens": n_completion,
                "total_tokens": n_prompt + n_completion,
            },
        }
    except Exception as e:
        print(f"Error during completion: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")


# =============================================================================
# Main Entry Point
# =============================================================================

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``--load-in-4bit`` and ``--load-in-8bit`` are
        mutually exclusive; supplying both makes argparse exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI-compatible API server supporting NVIDIA (CUDA) and Vulkan backends"
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model name or path. For NVIDIA: HuggingFace model. For Vulkan: GGUF file path or HF repo",
    )
    parser.add_argument(
        "--backend",
        type=str,
        choices=["auto", "nvidia", "vulkan"],
        default="auto",
        help="Backend to use: auto (detect), nvidia (CUDA), or vulkan (AMD GPUs)",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to (default: 8000)",
    )
    parser.add_argument(
        "--offload-dir",
        type=str,
        default="./offload",
        help="Directory for disk offload (NVIDIA backend only, default: ./offload)",
    )
    # 4-bit and 8-bit loading are alternative precisions; reject the ambiguous
    # combination at parse time instead of passing both flags to the backend.
    quantization = parser.add_mutually_exclusive_group()
    quantization.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit precision (NVIDIA backend only, requires bitsandbytes)",
    )
    quantization.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit precision (NVIDIA backend only, requires bitsandbytes)",
    )
    parser.add_argument(
        "--ram",
        type=float,
        default=None,
        help="Manually specify available RAM in GB (NVIDIA backend only)",
    )
    parser.add_argument(
        "--flash-attn",
        action="store_true",
        help="Use Flash Attention 2 (NVIDIA backend only, requires flash-attn package)",
    )
    parser.add_argument(
        "--n-gpu-layers",
        type=int,
        default=-1,
        help="Number of layers to offload to GPU (Vulkan backend only, default: -1 = all layers)",
    )
    parser.add_argument(
        "--n-ctx",
        type=int,
        default=2048,
        help="Context window size (Vulkan backend only, default: 2048)",
    )
    parser.add_argument(
        "--vulkan-device",
        type=int,
        default=0,
        help="Vulkan GPU device ID to use (Vulkan backend only, default: 0). Use --vulkan-list-devices to see available devices",
    )
    parser.add_argument(
        "--vulkan-single-gpu",
        action="store_true",
        help="Force Vulkan to use only the specified GPU device (prevents layer distribution across multiple GPUs)",
    )
    parser.add_argument(
        "--vulkan-list-devices",
        action="store_true",
        help="List available Vulkan GPU devices and exit",
    )
    return parser.parse_args(argv)


def main():
    """Command-line entry point: parse arguments, load the model, run uvicorn.

    Exits with status 0 after ``--vulkan-list-devices``, and with status 1 when
    no model name is available or the model fails to load.
    """
    # Optional: set process name if procname is available
    try:
        import procname
        procname.setprocname("coderai")
    except ImportError:
        pass
    args = parse_args()

    # Handle --vulkan-list-devices
    if args.vulkan_list_devices:
        print("\nListing Vulkan devices...")
        import subprocess
        try:
            result = subprocess.run(['vulkaninfo', '--summary'], capture_output=True, text=True)
        except FileNotFoundError:
            # The binary itself is missing: show the install hint, not an errno.
            print("Could not run vulkaninfo. Make sure vulkan-tools is installed.")
        except Exception as e:
            print(f"Error listing devices: {e}")
        else:
            if result.returncode == 0:
                print(result.stdout)
            else:
                print("Could not run vulkaninfo. Make sure vulkan-tools is installed.")
                if result.stderr:
                    print(result.stderr)
        sys.exit(0)

    # Get model name from args or prompt interactively
    model_name = args.model
    if model_name is None:
        print("No model specified. Please enter a model name.")
        print("")
        print("For NVIDIA backend (HuggingFace models):")
        print("  - microsoft/DialoGPT-medium")
        print("  - meta-llama/Llama-2-7b-chat-hf (requires auth)")
        print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        print("")
        print("For Vulkan backend (GGUF models):")
        print("  - Local path: ./phi-3-mini-4k-instruct-q4_k_m.gguf")
        print("  - HuggingFace: microsoft/Phi-3-mini-4k-instruct-gguf")
        print("")
        try:
            model_name = input("Enter model name: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin is closed/non-interactive or the user aborted: fall through
            # to the "model name is required" error instead of a traceback.
            print("")
            model_name = ""

        if not model_name:
            print("Error: Model name is required")
            sys.exit(1)

    # Detect available backends
    available = detect_available_backends()
    print("\nAvailable backends:")
    for name, available_flag in available.items():
        status = "✓" if available_flag else "✗"
        print(f"  [{status}] {name}")
    print("")

    # Load the model
    load_kwargs = {
        'offload_dir': args.offload_dir,
        'load_in_4bit': args.load_in_4bit,
        'load_in_8bit': args.load_in_8bit,
        'manual_ram_gb': args.ram,
        'flash_attn': args.flash_attn,
        'n_gpu_layers': args.n_gpu_layers,
        'n_ctx': args.n_ctx,
        'main_gpu': args.vulkan_device,
        'single_gpu': args.vulkan_single_gpu,
    }

    try:
        model_manager.load_model(
            model_name=model_name,
            backend_type=args.backend,
            **load_kwargs
        )
    except Exception as e:
        print(f"\nError loading model: {e}")
        print("\nTroubleshooting:")
        # Prefer the backend the manager resolved; when only "auto" is known
        # (e.g. detection itself failed) show the hints for both backends.
        failed_backend = model_manager.backend_type or args.backend
        if failed_backend != "nvidia":
            print("  - For Vulkan, ensure you have Vulkan drivers installed")
            print("  - Make sure you're using a GGUF format model")
            print("  - Run build.sh with 'vulkan' argument first")
        if failed_backend != "vulkan":
            print("  - For NVIDIA, ensure PyTorch with CUDA is installed")
            print("  - Run build.sh with 'nvidia' argument first")
        sys.exit(1)

    # Start the server
    import uvicorn
    print(f"\nStarting server on http://{args.host}:{args.port}")
    print(f"API documentation available at http://{args.host}:{args.port}/docs")
    print(f"Using backend: {model_manager.backend_type}")
    uvicorn.run(app, host=args.host, port=args.port)


# Start the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
