#!/usr/bin/env python3
"""
OpenAI-compatible API server for HuggingFace models.
Supports CUDA, ROCm GPU auto-detection, memory-aware model loading,
sequential offload (VRAM -> RAM -> Disk), streaming, and tool calling.
"""

import argparse
import asyncio
import json
import os
import re
import sys
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from threading import Thread
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import psutil
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
    LogitsProcessor,
    LogitsProcessorList,
)


# =============================================================================
# Flash Attention Detection
# =============================================================================

def check_flash_attn_availability() -> bool:
    """Return True when the flash-attn package can be imported."""
    try:
        import flash_attn  # noqa: F401  (import success is the whole check)
    except ImportError:
        return False
    return True


# =============================================================================
# Logits Processor for Numerical Stability
# =============================================================================

class InvalidLogitsProcessor(LogitsProcessor):
    """Replace NaN and +/-Inf logits with large finite values.

    Keeps sampling numerically stable: softmax/multinomial raise
    "probability tensor contains inf/nan" when fed non-finite scores.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return *scores* with every non-finite entry replaced by a finite one."""
        # NaN -> -1e9 (effectively never sampled), +Inf -> 1e9, -Inf -> -1e9.
        # Bug fix: the original replaced *both* infinities with +1e9, which
        # turned tokens masked out with -inf (a common convention used by
        # other logits processors) into near-certain sampling picks.
        scores = torch.nan_to_num(scores, nan=-1e9, posinf=1e9, neginf=-1e9)
        # Preserve the original's clamping of finite values below -1e9.
        return torch.clamp(scores, min=-1e9)


# =============================================================================
# Memory Detection and Model Sizing
# =============================================================================

def get_available_vram() -> int:
    """Return GPU memory in bytes, summed over all devices; 0 without a GPU.

    NOTE(review): this sums *total* device memory, not currently-free
    memory — the 10% safety margin applied by callers is the only headroom.
    """
    if not torch.cuda.is_available():
        return 0

    try:
        return sum(
            torch.cuda.get_device_properties(device_index).total_memory
            for device_index in range(torch.cuda.device_count())
        )
    except Exception as e:
        print(f"Warning: Could not detect VRAM: {e}")
        return 0


def get_available_ram(manual_ram_gb: Optional[float] = None) -> int:
    """Return available system RAM in bytes.

    Args:
        manual_ram_gb: When given, this value (in decimal GB, i.e. 1e9
            bytes) overrides auto-detection entirely.

    Returns:
        Available RAM in bytes; 0 when auto-detection fails.
    """
    if manual_ram_gb is None:
        # Auto-detect via psutil; any failure degrades to "no RAM known".
        try:
            mem = psutil.virtual_memory()
            print(f"Auto-detected RAM: {mem.available / 1e9:.2f} GB available")
            return mem.available
        except Exception as e:
            print(f"Warning: Could not detect RAM: {e}")
            return 0

    ram_bytes = int(manual_ram_gb * 1e9)
    print(f"Using manually specified RAM: {manual_ram_gb} GB ({ram_bytes / 1e9:.2f} GB)")
    return ram_bytes


def estimate_model_size_from_config(model_name: str) -> Optional[int]:
    """
    Estimate model size in bytes from its HuggingFace config.

    Assumes float16 weights (2 bytes per parameter), the typical loading
    format for GPU inference. Returns None if the config cannot be loaded
    or the architecture cannot be sized.
    """
    try:
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

        # Get the parameter count from the config, by decreasing reliability.
        if hasattr(config, 'num_parameters'):
            num_params = config.num_parameters
        elif hasattr(config, 'n_params'):
            num_params = config.n_params
        elif hasattr(config, 'num_hidden_layers') and hasattr(config, 'hidden_size'):
            # Rough estimate for a standard transformer:
            #   embedding:  vocab_size * hidden
            #   each layer: ~4 * hidden^2 (attention + FFN)
            layers = config.num_hidden_layers
            hidden = config.hidden_size
            vocab_size = getattr(config, 'vocab_size', 50000)

            # Bug fix: the original referenced the undefined name
            # `hidden_size` (the local is `hidden`), raising a NameError
            # that the broad `except` below silently swallowed — so this
            # path always returned None and disabled size estimation.
            num_params = (vocab_size * hidden) + (layers * 4 * hidden * hidden)
        else:
            return None

        # Assume float16 (2 bytes per parameter) for GPU loading.
        return num_params * 2
    except Exception as e:
        print(f"Warning: Could not estimate model size: {e}")
        return None


def calculate_safety_margin(memory_bytes: int) -> int:
    """Return 90% of *memory_bytes*, reserving 10% headroom."""
    headroom_factor = 0.9
    return int(memory_bytes * headroom_factor)


def determine_offload_strategy(
    model_name: str,
    available_vram: int,
    available_ram: int,
    quantization_bits: Optional[int] = None
) -> Dict[str, Any]:
    """
    Determine the best offload strategy based on available memory.

    Args:
        model_name: HuggingFace model name or path (used for size estimation)
        available_vram: total GPU memory in bytes (0 when no GPU)
        available_ram: available system RAM in bytes
        quantization_bits: 4 or 8 when quantized loading is requested

    Returns a dict with:
    - device_map: str or dict for model loading
    - offload_folder: Optional[str] for disk offload (set by the caller for
      case 3; this function only picks the placement strategy)
    - load_in_8bit: bool
    - load_in_4bit: bool
    - max_memory: Optional[dict]

    Annotation fix: the original was annotated ``Dict[str, any]`` — the
    builtin ``any`` function, not ``typing.Any``.
    """
    # Without a size estimate we can only defer placement to accelerate.
    estimated_size = estimate_model_size_from_config(model_name)

    if estimated_size is None:
        print("Could not estimate model size, using auto device_map")
        return {
            'device_map': 'auto',
            'offload_folder': None,
            'load_in_8bit': False,
            'load_in_4bit': False,
            'max_memory': None,
        }

    # Scale the fp16-based estimate down for quantized loading.
    if quantization_bits == 4:
        estimated_size = estimated_size // 4  # 4-bit = 0.5 bytes per param
    elif quantization_bits == 8:
        estimated_size = estimated_size // 2  # 8-bit = 1 byte per param

    # Add overhead for activations and runtime buffers (roughly 20%).
    required_memory = int(estimated_size * 1.2)

    print(f"Estimated model size: {estimated_size / 1e9:.2f} GB")
    print(f"Required memory (with overhead): {required_memory / 1e9:.2f} GB")
    print(f"Available VRAM: {available_vram / 1e9:.2f} GB")
    print(f"Available RAM: {available_ram / 1e9:.2f} GB")

    safe_vram = calculate_safety_margin(available_vram)
    safe_ram = calculate_safety_margin(available_ram)

    strategy = {
        'device_map': None,
        'offload_folder': None,
        'load_in_8bit': False,
        'load_in_4bit': False,
        'max_memory': None,
    }

    def _per_device_budget() -> Optional[Dict]:
        """Even VRAM split per GPU plus a CPU cap, for accelerate's max_memory."""
        if not torch.cuda.is_available():
            return None
        gpu_count = torch.cuda.device_count()
        budget = {i: safe_vram // gpu_count for i in range(gpu_count)}
        budget['cpu'] = safe_ram
        return budget

    # Case 1: Model fits entirely in VRAM.
    if required_memory <= safe_vram:
        print("Strategy: Loading fully to GPU")
        strategy['device_map'] = 'auto' if torch.cuda.device_count() > 1 else 'cuda'

    # Case 2: Model fits in VRAM + RAM combined.
    elif required_memory <= (safe_vram + safe_ram):
        print("Strategy: Using device_map='auto' for VRAM + RAM offload")
        strategy['device_map'] = 'auto'
        # max_memory helps accelerate distribute layers across GPU(s) and CPU.
        strategy['max_memory'] = _per_device_budget()

    # Case 3: Need disk offload; offload_folder is filled in by the caller
    # from the command-line argument.
    else:
        print("Strategy: VRAM + RAM + Disk offload required")
        strategy['device_map'] = 'auto'
        strategy['max_memory'] = _per_device_budget()

    return strategy


# =============================================================================
# Pydantic Models for API
# =============================================================================

class ToolFunction(BaseModel):
    """Schema of one callable function in the OpenAI tools format."""
    name: str
    # Human-readable summary; included in the prompt when present.
    description: Optional[str] = None
    # JSON-Schema-style dict describing the function's arguments.
    parameters: Optional[Dict] = None


class Tool(BaseModel):
    """Wrapper matching the OpenAI tool envelope: {"type": "function", ...}."""
    # Only "function" tools are supported; mirrors the OpenAI API shape.
    type: str = "function"
    function: ToolFunction


class ChatMessage(BaseModel):
    """One chat turn: system/user/assistant/tool, OpenAI-compatible."""
    role: str
    # May be None for assistant turns that only carry tool_calls.
    content: Optional[str] = None
    # Name of the tool for role="tool" messages.
    name: Optional[str] = None
    # Tool invocations emitted by an assistant turn.
    tool_calls: Optional[List[Dict]] = None
    # Correlates a role="tool" result with the originating call.
    tool_call_id: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible).

    NOTE(review): `n`, `presence_penalty`, and `frequency_penalty` are
    accepted for API compatibility but are not consumed by the visible
    generation path.
    """
    model: str
    messages: List[ChatMessage]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False
    # Single stop string or list of stop strings.
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = "auto"


class CompletionRequest(BaseModel):
    """Request body for the legacy text-completions endpoint."""
    model: str
    # A single prompt or a batch of prompts.
    prompt: Union[str, List[str]]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None


class ModelInfo(BaseModel):
    """One entry in the GET /v1/models listing."""
    id: str
    object: str = "model"
    # Unix timestamp; default_factory evaluates at instance creation time.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "huggingface"


class ModelList(BaseModel):
    """Response envelope for GET /v1/models."""
    object: str = "list"
    data: List[ModelInfo]


# =============================================================================
# Tool Parsing and Function Calling
# =============================================================================

class ToolCallParser:
    """Parse model outputs to extract tool calls."""
    
    def __init__(self, tokenizer):
        # Kept for future use; the current regex/JSON parsing below never
        # touches the tokenizer.
        self.tokenizer = tokenizer
    
    def extract_tool_calls(self, text: str, available_tools: List[Tool]) -> Optional[List[Dict]]:
        """Extract tool calls from model output.

        Tries three textual conventions in order, accumulating all matches:
          1. <tool>/<function> tags wrapping a JSON object with "name" and
             "arguments" keys.
          2. A JSON object containing a "function_call" key (OpenAI legacy).
          3. Per-tool tags, <tool_name>{...}</tool_name>, for each declared
             tool.

        Returns a list of OpenAI-style tool_call dicts, or None when no
        format matched.
        """
        tool_calls = []
        
        # Format 1: <tool> or <function> tags wrapping a JSON payload.
        tool_pattern = r'<(?:tool|function)>(.*?)</(?:tool|function)>'
        tool_matches = re.findall(tool_pattern, text, re.DOTALL)
        
        for match in tool_matches:
            try:
                tool_data = json.loads(match.strip())
                if 'name' in tool_data and 'arguments' in tool_data:
                    tool_calls.append({
                        "id": f"call_{uuid.uuid4().hex[:16]}",
                        "type": "function",
                        "function": {
                            "name": tool_data["name"],
                            # OpenAI carries arguments as a JSON *string*.
                            "arguments": json.dumps(tool_data["arguments"])
                        }
                    })
            except json.JSONDecodeError:
                # Tag content was not valid JSON; skip this candidate.
                pass
        
        # Format 2: JSON with function_call key.
        # NOTE(review): `[^}]*` cannot cross a nested `}`; when
        # "function_call" holds an object (the usual case) the captured
        # span is truncated, json.loads fails, and this branch silently
        # matches nothing. Confirm whether a brace-balanced parse is needed.
        try:
            if "function_call" in text:
                json_match = re.search(r'\{[^}]*"function_call"[^}]*\}', text, re.DOTALL)
                if json_match:
                    data = json.loads(json_match.group())
                    if "function_call" in data:
                        fc = data["function_call"]
                        tool_calls.append({
                            "id": f"call_{uuid.uuid4().hex[:16]}",
                            "type": "function",
                            "function": {
                                "name": fc.get("name", ""),
                                "arguments": fc.get("arguments", "{}")
                            }
                        })
        except (json.JSONDecodeError, AttributeError):
            pass
        
        # Format 3: tags named after each declared tool.
        for tool in available_tools:
            tool_name = tool.function.name
            # NOTE(review): tool_name is interpolated unescaped; names with
            # regex metacharacters would corrupt the pattern (re.escape?).
            pattern = rf'<{tool_name}>(.*?)</{tool_name}>'
            matches = re.findall(pattern, text, re.DOTALL)
            for match in matches:
                try:
                    args = json.loads(match.strip())
                    tool_calls.append({
                        "id": f"call_{uuid.uuid4().hex[:16]}",
                        "type": "function",
                        "function": {
                            "name": tool_name,
                            "arguments": json.dumps(args)
                        }
                    })
                except json.JSONDecodeError:
                    # Try to use the raw text as arguments
                    tool_calls.append({
                        "id": f"call_{uuid.uuid4().hex[:16]}",
                        "type": "function",
                        "function": {
                            "name": tool_name,
                            "arguments": match.strip()
                        }
                    })
        
        return tool_calls if tool_calls else None


def format_tools_for_prompt(tools: List[Tool], messages: List[ChatMessage]) -> List[ChatMessage]:
    """Inject tool descriptions into the system message.

    Prepends the tool catalogue to the first system message, or inserts a
    new system message when none exists. The input list is not mutated.
    """
    if not tools:
        return messages

    descriptions = []
    for tool in tools:
        fn = tool.function
        parts = [f"Tool: {fn.name}"]
        if fn.description:
            parts.append(f"Description: {fn.description}")
        if fn.parameters:
            parts.append(f"Parameters: {json.dumps(fn.parameters, indent=2)}")
        descriptions.append("\n".join(parts))

    tools_text = (
        "You have access to the following tools:\n\n"
        + "\n\n".join(descriptions)
        + "\n\nWhen you need to use a tool, format your response as:\n"
        + '<tool>{"name": "tool_name", "arguments": {...}}</tool>'
    )

    updated = list(messages)
    for idx, message in enumerate(updated):
        if message.role == "system":
            # Prepend the tool catalogue to the existing system prompt.
            updated[idx] = ChatMessage(
                role="system",
                content=f"{tools_text}\n\n{message.content or ''}"
            )
            return updated

    # No system message present: create one at the front.
    updated.insert(0, ChatMessage(role="system", content=tools_text))
    return updated


# =============================================================================
# Model Management
# =============================================================================

class ModelManager:
    """Manages the loaded model and tokenizer.

    One instance holds a single model/tokenizer pair plus the settings used
    to load them (device, offload folder, flash-attention flags). Generation
    entry points: generate() (blocking) and generate_stream() (incremental,
    driven by a background thread feeding a TextIteratorStreamer).
    """
    
    def __init__(self):
        self.model = None            # AutoModelForCausalLM, set by load_model()
        self.tokenizer = None        # AutoTokenizer, set by load_model()
        self.model_name = None       # HF model id/path of the loaded model
        self.device = None           # "cuda" or "cpu" (see detect_device())
        self.tool_parser = None      # ToolCallParser bound to the tokenizer
        self.offload_folder = None   # disk-offload directory, if configured
        self.use_flash_attn = False          # caller requested Flash Attention 2
        self.flash_attn_available = False    # flash-attn package importable
        
    def check_flash_attn_support(self) -> None:
        """Check and print Flash Attention availability status.

        Clears self.use_flash_attn when the package is missing so the
        loading code can trust the flag unconditionally.
        """
        self.flash_attn_available = check_flash_attn_availability()
        if self.use_flash_attn:
            if self.flash_attn_available:
                print("Flash Attention 2: Available and enabled")
            else:
                print("Warning: Flash Attention 2 requested but not installed")
                print("Install with: pip install flash-attn --no-build-isolation")
                print("Falling back to standard attention")
                self.use_flash_attn = False
    
    def detect_device(self) -> str:
        """Auto-detect available GPU or fall back to CPU.

        Returns "cuda" for both NVIDIA CUDA and AMD ROCm builds — ROCm's
        PyTorch exposes GPUs through the same torch.cuda API — else "cpu".
        """
        if torch.cuda.is_available():
            # Check for ROCm (HIP): torch.version.hip is set on ROCm builds.
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                print(f"ROCm/HIP detected: {torch.version.hip}")
                return "cuda"
            else:
                print(f"CUDA detected: {torch.version.cuda}")
                return "cuda"
        else:
            print("No GPU detected, using CPU")
            return "cpu"
    
    def load_model(
        self,
        model_name: str,
        offload_dir: Optional[str] = None,
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        manual_ram_gb: Optional[float] = None,
        flash_attn: bool = False,
    ) -> None:
        """
        Load the model and tokenizer from HuggingFace with memory-aware offload.
        
        Args:
            model_name: HuggingFace model name or path
            offload_dir: Directory for disk offload when model doesn't fit in VRAM+RAM
            load_in_4bit: Use 4-bit quantization (requires bitsandbytes)
            load_in_8bit: Use 8-bit quantization (requires bitsandbytes)
            manual_ram_gb: Manually specify available RAM in GB (bypasses auto-detection)
            flash_attn: Use Flash Attention 2 if available (requires flash-attn package)
        """
        print(f"Loading model: {model_name}")
        
        self.use_flash_attn = flash_attn
        self.check_flash_attn_support()
        
        self.device = self.detect_device()
        self.offload_folder = offload_dir
        
        # Create offload directory if needed
        if offload_dir:
            os.makedirs(offload_dir, exist_ok=True)
            print(f"Disk offload directory: {offload_dir}")
        
        # Detect available memory
        available_vram = get_available_vram()
        available_ram = get_available_ram(manual_ram_gb)
        
        print(f"\nMemory Detection:")
        print(f"  Available VRAM: {available_vram / 1e9:.2f} GB")
        print(f"  Available RAM: {available_ram / 1e9:.2f} GB")
        
        # Determine quantization bits
        quantization_bits = None
        if load_in_4bit:
            quantization_bits = 4
        elif load_in_8bit:
            quantization_bits = 8
        
        # Determine offload strategy
        strategy = determine_offload_strategy(
            model_name,
            available_vram,
            available_ram,
            quantization_bits
        )
        
        # Enable disk offload only when the estimate exceeds VRAM+RAM and a
        # directory was actually supplied on the command line.
        if strategy.get('offload_folder') is None and offload_dir:
            estimated_size = estimate_model_size_from_config(model_name)
            safe_vram = calculate_safety_margin(available_vram)
            safe_ram = calculate_safety_margin(available_ram)
            
            if estimated_size and estimated_size > (safe_vram + safe_ram):
                strategy['offload_folder'] = offload_dir
                print(f"Model will use disk offload at: {offload_dir}")
        
        # Load tokenizer. padding_side="left" so generation starts right
        # after the prompt in batched/padded inputs.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side="left"
        )
        
        # Set pad token if not present (many causal LMs ship without one).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Prepare model loading arguments
        load_kwargs = {
            'trust_remote_code': True,
        }
        
        # Set dtype based on device and quantization
        if load_in_4bit or load_in_8bit:
            # Check if bitsandbytes is available
            # NOTE(review): load_in_4bit/load_in_8bit kwargs are deprecated
            # in recent transformers in favor of quantization_config=
            # BitsAndBytesConfig(...) — confirm against the pinned version.
            try:
                import bitsandbytes as bnb
                print(f"Using {4 if load_in_4bit else 8}-bit quantization")
                load_kwargs['load_in_4bit'] = load_in_4bit
                load_kwargs['load_in_8bit'] = load_in_8bit
                load_kwargs['device_map'] = strategy['device_map'] or 'auto'
            except ImportError:
                print("Warning: bitsandbytes not installed. Quantization disabled.")
                print("Install with: pip install bitsandbytes")
                if self.device == "cuda":
                    load_kwargs['torch_dtype'] = torch.float16
                else:
                    load_kwargs['torch_dtype'] = torch.float32
                load_kwargs['device_map'] = strategy['device_map'] or ('auto' if self.device == 'cuda' else None)
        else:
            # fp16 on GPU, fp32 on CPU (fp16 CPU inference is poorly supported).
            if self.device == "cuda":
                load_kwargs['torch_dtype'] = torch.float16
            else:
                load_kwargs['torch_dtype'] = torch.float32
            load_kwargs['device_map'] = strategy['device_map'] or ('auto' if self.device == 'cuda' else None)
        
        # Add max_memory if specified
        if strategy.get('max_memory'):
            load_kwargs['max_memory'] = strategy['max_memory']
        
        # Add offload_folder if specified
        if strategy.get('offload_folder'):
            load_kwargs['offload_folder'] = strategy['offload_folder']
        
        # Add Flash Attention 2 configuration if enabled and available
        if self.use_flash_attn and self.flash_attn_available:
            load_kwargs['attn_implementation'] = "flash_attention_2"
            print("\nUsing Flash Attention 2 for attention implementation")
        
        print(f"\nModel loading arguments:")
        for key, value in load_kwargs.items():
            print(f"  {key}: {value}")
        
        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **load_kwargs
        )
        
        # Handle CPU case where device_map is None (from_pretrained then
        # leaves the model on the default device; move it explicitly).
        if self.device == "cpu" and load_kwargs.get('device_map') is None:
            self.model = self.model.to(self.device)
        
        self.model.eval()
        self.model_name = model_name
        self.tool_parser = ToolCallParser(self.tokenizer)
        
        # Print model device placement (only present when accelerate
        # dispatched the model across devices).
        if hasattr(self.model, 'hf_device_map'):
            print(f"\nDevice map:")
            for layer, device in self.model.hf_device_map.items():
                print(f"  {layer}: {device}")
        
        print(f"\nModel loaded successfully")
        print(f"Model device: {next(self.model.parameters()).device}")
    
    def format_messages(self, messages: List[ChatMessage]) -> str:
        """Format messages into a prompt string.

        NOTE(review): uses a fixed "System:/User:/Assistant:" template and
        ignores the tokenizer's chat_template — model-specific chat formats
        (ChatML, Llama, etc.) are not applied. Confirm this is intentional.
        """
        formatted = []
        
        for msg in messages:
            if msg.role == "system":
                formatted.append(f"System: {msg.content}")
            elif msg.role == "user":
                formatted.append(f"User: {msg.content}")
            elif msg.role == "assistant":
                content = msg.content or ""
                # Re-serialize prior tool calls in the <tool>...</tool>
                # convention so the model sees its own earlier calls.
                if msg.tool_calls:
                    for tc in msg.tool_calls:
                        if tc.get("function"):
                            func = tc["function"]
                            content += f'\n<tool>{{"name": "{func.get("name", "")}", "arguments": {func.get("arguments", "{}")}}}</tool>'
                formatted.append(f"Assistant: {content}")
            elif msg.role == "tool":
                formatted.append(f"Tool ({msg.name}): {msg.content}")
        
        # Trailing cue so the model continues as the assistant.
        formatted.append("Assistant:")
        return "\n\n".join(formatted)
    
    def _validate_generation_params(self, temperature: float, top_p: float) -> tuple:
        """Validate and clamp generation parameters for numerical stability.

        Returns:
            (temperature, top_p, do_sample) — temperature <= 0 switches to
            greedy decoding (do_sample=False) with a neutral temperature.
        """
        # Clamp temperature to avoid numerical issues
        # Temperature must be > 0 for sampling, but very small values can cause issues
        if temperature <= 0:
            temperature = 1.0
            do_sample = False
        else:
            temperature = max(0.01, min(temperature, 2.0))
            do_sample = True
        
        # Clamp top_p
        top_p = max(0.0, min(top_p, 1.0))
        
        return temperature, top_p, do_sample
    
    def generate_stream(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 0.7,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
    ) -> AsyncGenerator[str, None]:
        """Generate text in streaming fashion, yielding decoded chunks.

        NOTE(review): despite the AsyncGenerator annotation this is a plain
        synchronous generator (no async/await) — callers iterate it with a
        regular `for`. Consider Generator[str, None, None].
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        # Move tensors to wherever the (possibly dispatched) model's entry
        # device is.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        input_length = inputs["input_ids"].shape[1]
        
        # Default generation budget when the client didn't specify one.
        if max_tokens is None:
            max_tokens = 512
        
        # Validate parameters
        temperature, top_p, do_sample = self._validate_generation_params(temperature, top_p)
        
        # Streamer yields decoded text pieces produced by the generation
        # thread; skip_prompt avoids re-emitting the input.
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
            "streamer": streamer,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        
        # Add logits processor to handle NaN/Inf values
        generation_kwargs["logits_processor"] = LogitsProcessorList([
            InvalidLogitsProcessor()
        ])
        
        # Handle stop sequences
        if stop:
            class StopOnSequence(StoppingCriteria):
                def __init__(self, stop_sequences, tokenizer):
                    self.stop_sequences = stop_sequences
                    self.tokenizer = tokenizer
                
                def __call__(self, input_ids, scores, **kwargs):
                    # Only the last 20 tokens are decoded to bound cost;
                    # stop strings longer than that window could be missed.
                    decoded = self.tokenizer.decode(input_ids[0][-20:], skip_special_tokens=True)
                    return any(seq in decoded for seq in self.stop_sequences)
            
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([
                StopOnSequence(stop, self.tokenizer)
            ])
        
        # Run generation in a separate thread with error handling.
        # NOTE(review): exceptions raised inside the generation thread do
        # not propagate to this consumer loop, so the RuntimeError fallback
        # below only fires for errors raised on this thread; a thread-side
        # failure may leave the streamer blocked. Confirm desired behavior.
        generated_text = ""
        try:
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            
            for text in streamer:
                generated_text += text
                yield text
            
            thread.join()
        except RuntimeError as e:
            if "probability tensor contains" in str(e):
                print(f"Warning: Numerical error during generation: {e}")
                print("This may be due to temperature=0 or numerical instability.")
                print("Trying again with greedy decoding...")
                # Fallback to greedy decoding
                generation_kwargs["do_sample"] = False
                generation_kwargs["temperature"] = None
                generation_kwargs["top_p"] = None
                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
                thread.start()
                for text in streamer:
                    generated_text += text
                    yield text
                thread.join()
            else:
                raise
    
    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 0.7,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
    ) -> str:
        """Generate text non-streaming; returns only the completion text
        (prompt tokens are sliced off before decoding)."""
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        if max_tokens is None:
            max_tokens = 512
        
        # Validate parameters
        temperature, top_p, do_sample = self._validate_generation_params(temperature, top_p)
        
        try:
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_tokens,
                    # Sampling knobs must be None for greedy decoding to
                    # avoid transformers warnings/errors.
                    temperature=temperature if do_sample else None,
                    top_p=top_p if do_sample else None,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    stopping_criteria=self._create_stopping_criteria(stop) if stop else None,
                    logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
                )
            
            # Strip the echoed prompt tokens before decoding.
            generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        except RuntimeError as e:
            if "probability tensor contains" in str(e):
                print(f"Warning: Numerical error during generation: {e}")
                print("Retrying with greedy decoding...")
                # Fallback to greedy decoding
                with torch.no_grad():
                    outputs = self.model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_new_tokens=max_tokens,
                        do_sample=False,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        stopping_criteria=self._create_stopping_criteria(stop) if stop else None,
                        logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
                    )
                generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
                return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            else:
                raise
    
    def _create_stopping_criteria(self, stop_sequences):
        """Create stopping criteria for stop sequences.

        Returns None when no sequences were given; otherwise a
        StoppingCriteriaList that halts generation once any stop string
        appears in the decoded tail of the output.
        """
        if not stop_sequences:
            return None
        
        class StopOnSequence(StoppingCriteria):
            def __init__(self, stop_sequences, tokenizer):
                self.stop_sequences = stop_sequences
                self.tokenizer = tokenizer
            
            def __call__(self, input_ids, scores, **kwargs):
                # Only the last 20 tokens are decoded to bound cost; stop
                # strings longer than that window could be missed.
                decoded = self.tokenizer.decode(input_ids[0][-20:], skip_special_tokens=True)
                return any(seq in decoded for seq in self.stop_sequences)
        
        return StoppingCriteriaList([StopOnSequence(stop_sequences, self.tokenizer)])


# Global singleton shared by the FastAPI handlers below; populated via
# load_model() before the server starts handling requests.
model_manager = ModelManager()


# =============================================================================
# FastAPI Application
# =============================================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup/shutdown.

    Startup performs no work here (the model is loaded elsewhere before the
    server begins serving). Shutdown drops the model/tokenizer references
    and frees cached GPU memory.
    """
    # Startup
    yield
    # Shutdown
    if model_manager.model is not None:
        del model_manager.model
        del model_manager.tokenizer
        # Idiom fix: the original used a conditional *expression* purely for
        # its side effect; a plain if-statement states the intent.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# FastAPI application exposing the OpenAI-compatible endpoints defined below.
app = FastAPI(
    title="OpenAI-Compatible API",
    description="OpenAI-compatible API for HuggingFace models with memory-aware loading",
    version="1.0.0",
    lifespan=lifespan,
)


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """List available models."""
    models = []
    if model_manager.model_name:
        models.append(ModelInfo(id=model_manager.model_name))
    else:
        models.append(ModelInfo(id="unknown"))
    return ModelList(data=models)


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Chat completions endpoint with streaming and tool support."""
    if model_manager.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Inject tool definitions into the message list when the client sent any.
    messages = request.messages
    if request.tools:
        messages = format_tools_for_prompt(request.tools, messages)

    # Render the messages into a flat prompt string via the model's template.
    prompt = model_manager.format_messages(messages)

    # Normalize `stop` (None | str | list[str]) into a list of strings.
    stop_sequences = []
    if request.stop:
        stop_sequences = [request.stop] if isinstance(request.stop, str) else request.stop

    # Both branches share the same generation arguments.
    gen_args = (
        prompt,
        request.model,
        request.max_tokens,
        request.temperature,
        request.top_p,
        stop_sequences,
        request.tools,
    )

    if request.stream:
        return StreamingResponse(
            stream_chat_response(*gen_args),
            media_type="text/event-stream",
        )
    return await generate_chat_response(*gen_args)


async def stream_chat_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Tool]],
) -> AsyncGenerator[str, None]:
    """Stream a chat completion as OpenAI-style server-sent events.

    Yields `data: {json}` chunks carrying incremental content deltas, then a
    final chunk with the finish_reason ("tool_calls" when tool calls were
    detected in the complete output, "stop" otherwise), then the
    `data: [DONE]` sentinel.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def sse_chunk(delta: Dict, finish_reason: Optional[str]) -> str:
        """Build one SSE event with the full chunk envelope.

        Every chunk — including the final finish_reason chunk — carries
        id/object/created/model and an indexed choice, as the OpenAI
        streaming format requires; the previous code emitted a bare final
        chunk missing those fields, which breaks spec-compliant clients.
        """
        data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(data)}\n\n"

    generated_text = ""

    try:
        for chunk in model_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        ):
            generated_text += chunk
            yield sse_chunk({"content": chunk}, None)

        # Inspect the complete output for tool calls before closing the stream.
        tool_calls = None
        if tools:
            tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tools)

        if tool_calls:
            yield sse_chunk({"tool_calls": tool_calls}, "tool_calls")
        else:
            yield sse_chunk({}, "stop")

        yield "data: [DONE]\n\n"
    except Exception as e:
        print(f"Error during streaming generation: {e}")
        # Surface the error to the client in-band, then terminate the stream
        # cleanly so the connection does not hang open.
        yield sse_chunk({"content": f"\n[Generation error: {str(e)}]"}, "stop")
        yield "data: [DONE]\n\n"

async def generate_chat_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
    tools: Optional[List[Tool]],
) -> Dict:
    """Generate a non-streaming chat completion response.

    Returns an OpenAI-style `chat.completion` dict. When tools are provided
    and tool calls are detected in the output, the message carries
    `tool_calls` with `content` set to None and finish_reason "tool_calls".

    Raises:
        HTTPException: 500 when generation fails.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        generated_text = model_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )

        response_message = {
            "role": "assistant",
            "content": generated_text,
        }

        finish_reason = "stop"

        # Check for tool calls; per the OpenAI schema, content is null when
        # the assistant responds with tool calls.
        if tools:
            tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tools)
            if tool_calls:
                response_message["tool_calls"] = tool_calls
                response_message["content"] = None
                finish_reason = "tool_calls"

        # Tokenize each text exactly once for usage accounting (the original
        # encoded both strings twice: once per field and again for the total).
        prompt_tokens = len(model_manager.tokenizer.encode(prompt))
        completion_tokens = len(model_manager.tokenizer.encode(generated_text))

        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "message": response_message,
                "finish_reason": finish_reason,
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")


@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    """Text completions endpoint."""
    if model_manager.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Normalize the prompt to a single string. NOTE(review): when a list of
    # prompts is sent, only the first is served — the rest are silently
    # ignored; confirm whether batch completions should be supported.
    prompt = request.prompt[0] if isinstance(request.prompt, list) else request.prompt

    # Normalize `stop` (None | str | list[str]) into a list of strings.
    stop_sequences = []
    if request.stop:
        if isinstance(request.stop, str):
            stop_sequences = [request.stop]
        else:
            stop_sequences = request.stop

    gen_args = (
        prompt,
        request.model,
        request.max_tokens,
        request.temperature,
        request.top_p,
        stop_sequences,
    )

    if request.stream:
        return StreamingResponse(
            stream_completion_response(*gen_args),
            media_type="text/event-stream",
        )
    return await generate_completion_response(*gen_args)


async def stream_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
) -> AsyncGenerator[str, None]:
    """Stream a text completion as OpenAI-style server-sent events.

    Yields `data: {json}` chunks of generated text, a final chunk with
    finish_reason "stop", then the `data: [DONE]` sentinel.
    """
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def sse_chunk(text: str, finish_reason: Optional[str]) -> str:
        """Build one SSE event with the full text_completion envelope.

        The final chunk carries id/object/created/model too; the previous
        code emitted a bare `{'choices': ...}` final chunk, which breaks
        clients that validate the OpenAI streaming schema.
        """
        data = {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "text": text,
                "index": 0,
                "logprobs": None,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(data)}\n\n"

    try:
        for chunk in model_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        ):
            yield sse_chunk(chunk, None)

        yield sse_chunk("", "stop")
        yield "data: [DONE]\n\n"
    except Exception as e:
        print(f"Error during streaming completion: {e}")
        # Close the stream cleanly so the client does not hang on an open
        # connection after a generation failure.
        yield sse_chunk("", "stop")
        yield "data: [DONE]\n\n"


async def generate_completion_response(
    prompt: str,
    model_name: str,
    max_tokens: Optional[int],
    temperature: float,
    top_p: float,
    stop: List[str],
) -> Dict:
    """Generate a non-streaming text completion response.

    Returns an OpenAI-style `text_completion` dict with usage accounting.

    Raises:
        HTTPException: 500 when generation fails.
    """
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        generated_text = model_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )

        # Tokenize each text exactly once for usage accounting (the original
        # encoded both strings twice: once per field and again for the total).
        prompt_tokens = len(model_manager.tokenizer.encode(prompt))
        completion_tokens = len(model_manager.tokenizer.encode(generated_text))

        return {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [{
                "text": generated_text,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except Exception as e:
        print(f"Error during completion: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")


# =============================================================================
# Main Entry Point
# =============================================================================

def parse_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace with model/host/port/offload_dir/load_in_4bit/
        load_in_8bit/ram/flash_attn attributes.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI-compatible API server with memory-aware model loading"
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="HuggingFace model name or path",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to (default: 8000)",
    )
    parser.add_argument(
        "--offload-dir",
        type=str,
        default="./offload",
        help="Directory for disk offload when model doesn't fit in VRAM+RAM (default: ./offload)",
    )
    # 4-bit and 8-bit quantization are contradictory; let argparse reject the
    # combination up front instead of passing both flags to the model loader.
    quant_group = parser.add_mutually_exclusive_group()
    quant_group.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit precision (requires bitsandbytes)",
    )
    quant_group.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit precision (requires bitsandbytes)",
    )
    parser.add_argument(
        "--ram",
        type=float,
        default=None,
        help="Manually specify available RAM in GB (bypasses auto-detection)",
    )
    parser.add_argument(
        "--flash-attn",
        action="store_true",
        help="Use Flash Attention 2 for faster inference (requires flash-attn package and compatible GPU)",
    )
    return parser.parse_args()


def main():
    """Main entry point: parse args, load the model, start the server."""
    # Set a friendly process name when the optional third-party `procname`
    # package is installed; its absence must not prevent the server from
    # starting (the original unguarded import crashed without it).
    try:
        import procname
        procname.setprocname("coderai")
    except ImportError:
        pass

    args = parse_args()

    # Get model name from args or prompt interactively
    model_name = args.model
    if model_name is None:
        print("No model specified. Please enter a HuggingFace model name.")
        print("Examples:")
        print("  - microsoft/DialoGPT-medium")
        print("  - facebook/blenderbot-400M-distill")
        print("  - meta-llama/Llama-2-7b-chat-hf (requires auth)")
        print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        print("")
        model_name = input("Enter model name: ").strip()

        if not model_name:
            print("Error: Model name is required")
            sys.exit(1)

    # Load the model with memory-aware offload
    model_manager.load_model(
        model_name=model_name,
        offload_dir=args.offload_dir,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
        manual_ram_gb=args.ram,
        # argparse always defines this attribute (dest="flash_attn"), so the
        # original getattr fallback was redundant.
        flash_attn=args.flash_attn,
    )

    # Start the server
    import uvicorn
    print(f"\nStarting server on http://{args.host}:{args.port}")
    print(f"API documentation available at http://{args.host}:{args.port}/docs")
    uvicorn.run(app, host=args.host, port=args.port)


# Run the server only when executed directly, not when imported as a module.
if __name__ == "__main__":
    main()
