"""
JWT utilities for API token authentication
Handles token generation, validation, and API authentication
"""

import logging
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import Optional

import jwt
from flask import request, jsonify, current_app, g

from app.models import APIToken, User

logger = logging.getLogger(__name__)

class JWTError(Exception):
    """Base error for JWT handling; the other token errors in this module derive from it"""

class TokenExpiredError(JWTError):
    """Raised when a token is past its expiration"""

class TokenInvalidError(JWTError):
    """Raised when a token cannot be accepted (malformed, wrong type, unknown or inactive user)"""

def generate_jwt_token(user_id: int, token_name: str, expires_in: Optional[timedelta] = None) -> str:
    """
    Generate a signed JWT token for API access

    The payload carries 'user_id', 'token_name', 'iat', 'exp' and a 'type'
    of 'api_access'. It is signed with HS256 using the application's
    JWT_SECRET_KEY config value.

    Args:
        user_id: User ID
        token_name: Name/description of the token
        expires_in: Token lifetime (default: 1 year)

    Returns:
        JWT token string

    Raises:
        JWTError: If the token cannot be encoded (including a missing
            JWT_SECRET_KEY); the underlying error is chained as the cause
    """
    if expires_in is None:
        expires_in = timedelta(days=365)

    # Take one timezone-aware timestamp so 'iat' and 'exp' are exactly
    # expires_in apart, and to avoid the deprecated datetime.utcnow().
    issued_at = datetime.now(timezone.utc)

    payload = {
        'user_id': user_id,
        'token_name': token_name,
        'iat': issued_at,
        'exp': issued_at + expires_in,
        'type': 'api_access'
    }

    try:
        return jwt.encode(
            payload,
            current_app.config['JWT_SECRET_KEY'],
            algorithm='HS256'
        )
    except Exception as e:
        logger.error(f"Failed to generate JWT token: {str(e)}")
        # Chain the cause so the original traceback is not lost.
        raise JWTError(f"Token generation failed: {str(e)}") from e

def decode_jwt_token(token: str) -> dict:
    """
    Decode and validate a JWT token

    The signature is verified with HS256 against the application's
    JWT_SECRET_KEY, then the payload's 'type' claim must equal 'api_access'.

    Args:
        token: JWT token string

    Returns:
        Decoded token payload

    Raises:
        TokenExpiredError: If token has expired
        TokenInvalidError: If token is invalid, has the wrong type, or
            decoding fails for any other reason
    """
    try:
        payload = jwt.decode(
            token,
            current_app.config['JWT_SECRET_KEY'],
            algorithms=['HS256']
        )
    except jwt.ExpiredSignatureError as e:
        raise TokenExpiredError("Token has expired") from e
    except jwt.InvalidTokenError as e:
        raise TokenInvalidError(f"Invalid token: {str(e)}") from e
    except Exception as e:
        logger.error(f"Token decode error: {str(e)}")
        raise TokenInvalidError(f"Token validation failed: {str(e)}") from e

    # Validate token type. This is done outside the try block on purpose:
    # inside it, the TokenInvalidError raised here would be swallowed by the
    # catch-all handler above, logged as an error and re-wrapped with a
    # misleading "Token validation failed" message.
    if payload.get('type') != 'api_access':
        raise TokenInvalidError("Invalid token type")

    return payload

def _load_active_user(user_id) -> User:
    """Fetch a user by primary key, rejecting missing or inactive accounts."""
    user = User.query.get(user_id)
    if not user or not user.is_active:
        raise TokenInvalidError("User not found or inactive")
    return user


def validate_api_token(token: str) -> tuple[User, Optional[APIToken]]:
    """
    Validate an API token and return the associated user and token record

    The token is first looked up as a database token via
    APIToken.find_by_token; if no record matches it is decoded as a JWT.

    Args:
        token: API token string

    Returns:
        Tuple of (User, APIToken). The second element is None when the
        token was accepted as a JWT, since no APIToken record exists for it.

    Raises:
        TokenInvalidError: If token is invalid or user not found
        TokenExpiredError: If token has expired
    """
    try:
        # First check if it's a database token (non-JWT)
        api_token = APIToken.find_by_token(token)
        if api_token:
            if not api_token.is_valid():
                if api_token.is_expired():
                    raise TokenExpiredError("API token has expired")
                raise TokenInvalidError("API token is inactive")

            user = _load_active_user(api_token.user_id)

            # Update last used timestamp
            api_token.update_last_used(request.remote_addr)

            return user, api_token

        # If not a database token, try JWT
        payload = decode_jwt_token(token)
        user_id = payload.get('user_id')

        if not user_id:
            raise TokenInvalidError("Token missing user ID")

        # For JWT tokens, we don't have an APIToken record
        return _load_active_user(user_id), None

    except (TokenExpiredError, TokenInvalidError):
        raise
    except Exception as e:
        # Keep the details in the server log only. The exception message is
        # echoed to API clients by the decorators in this module, so it must
        # not carry internal error text (e.g. database errors).
        logger.exception("Token validation error: %s", e)
        raise TokenInvalidError("Token validation failed") from e

def extract_token_from_request() -> Optional[str]:
    """
    Extract API token from the current request

    Sources are checked in this order and the first non-empty value wins:
    the Authorization header (Bearer scheme), the X-API-Token header, then
    the 'token' query parameter.

    Returns:
        Token string or None if not found
    """
    # Check Authorization header (Bearer token). The scheme name is matched
    # case-insensitively, and an empty credential falls through to the other
    # sources instead of being returned as ''.
    auth_header = request.headers.get('Authorization')
    if auth_header:
        scheme, _, credentials = auth_header.strip().partition(' ')
        credentials = credentials.strip()
        if scheme.lower() == 'bearer' and credentials:
            return credentials

    # Check X-API-Token header
    api_token = request.headers.get('X-API-Token')
    if api_token:
        return api_token

    # Check query parameter (less secure, but sometimes needed)
    token_param = request.args.get('token')
    if token_param:
        return token_param

    return None

def require_api_token(f):
    """
    Decorator to require valid API token for route access

    On success the authenticated user and token record are stored on
    Flask's g object and the wrapped view is called. On failure a JSON
    error response is returned: 401 for a missing, expired or invalid
    token, 500 for an unexpected error during authentication.

    Exceptions raised by the wrapped view itself are NOT handled here;
    they propagate unchanged to Flask's normal error handling.

    Usage:
        @app.route('/api/data')
        @require_api_token
        def get_data():
            # Access current user via g.current_user
            # Access current token via g.current_token (may be None for JWT)
            return jsonify({'data': 'protected'})
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # Only the authentication steps live inside the try block. Calling
        # the view inside it would turn every view exception (including
        # HTTP errors raised via abort()) into an "Authentication failed"
        # response.
        try:
            # Extract token from request
            token = extract_token_from_request()
            if not token:
                return jsonify({
                    'error': 'Missing API token',
                    'message': 'API token required in Authorization header, X-API-Token header, or token parameter'
                }), 401

            # Validate token
            user, api_token = validate_api_token(token)

            # Store in Flask's g object for use in the route
            g.current_user = user
            g.current_token = api_token

            # Log API access
            logger.info(
                "API access by user %s (ID: %s) with token: %s",
                user.username,
                user.id,
                api_token.name if api_token else 'JWT',
            )

        except TokenExpiredError as e:
            return jsonify({
                'error': 'Token expired',
                'message': str(e)
            }), 401
        except TokenInvalidError as e:
            return jsonify({
                'error': 'Invalid token',
                'message': str(e)
            }), 401
        except Exception as e:
            logger.error(f"API authentication error: {str(e)}")
            return jsonify({
                'error': 'Authentication failed',
                'message': 'Internal authentication error'
            }), 500

        return f(*args, **kwargs)

    return decorated_function

def optional_api_token(f):
    """
    Decorator that allows but doesn't require API token authentication

    If a token is provided and valid, the user and token record are stored
    in g.current_user and g.current_token. Otherwise (no token, an expired
    or invalid token, or an unexpected authentication error) both are set
    to None and the request continues unauthenticated.

    The wrapped view is called exactly once; exceptions it raises propagate
    unchanged.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        user = None
        api_token = None

        # Authentication is isolated from the view call. Running the view
        # inside this try block would make the catch-all handler re-invoke
        # the view whenever the view itself raised.
        try:
            token = extract_token_from_request()
            if token:
                user, api_token = validate_api_token(token)
        except (TokenExpiredError, TokenInvalidError):
            # Token provided but invalid - continue without authentication
            pass
        except Exception as e:
            logger.error(f"Optional API authentication error: {str(e)}")

        g.current_user = user
        g.current_token = api_token

        return f(*args, **kwargs)

    return decorated_function

def get_current_api_user():
    """
    Return the user stored on Flask's g object as 'current_user'

    Returns:
        The value of g.current_user, or None when that attribute is not set
    """
    try:
        return g.current_user
    except AttributeError:
        # Attribute not set on g
        return None

def get_current_api_token():
    """
    Return the token record stored on Flask's g object as 'current_token'

    Returns:
        The value of g.current_token, or None when that attribute is not set
        (the decorators in this module also store None for JWT-authenticated
        requests)
    """
    try:
        return g.current_token
    except AttributeError:
        # Attribute not set on g
        return None