"""
Authentication manager for web dashboard
"""

import hashlib
import hmac
import logging
import secrets
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Tuple

from flask import Flask, request, session
from flask_login import UserMixin
from flask_jwt_extended import create_access_token, decode_token
import jwt

from ..database.manager import DatabaseManager
from ..database.models import User, APIToken

logger = logging.getLogger(__name__)


class AuthenticatedUser(UserMixin):
    """User class for Flask-Login.

    A lightweight, session-facing snapshot of a user record. The
    ``is_authenticated``, ``is_active`` and ``is_anonymous`` flags are
    inherited from ``UserMixin``, which reports True, True and False.
    """

    def __init__(self, user_id: int, username: str, email: str, is_admin: bool = False):
        """Store the identifying fields of an authenticated user.

        Args:
            user_id: Primary key of the user; exposed as ``self.id``.
            username: Login name.
            email: Email address.
            is_admin: Whether the user has admin privileges.
        """
        # is_authenticated / is_active / is_anonymous are deliberately NOT
        # assigned here: UserMixin exposes them as read-only properties, and
        # assigning to them on an instance raises AttributeError.
        self.id = user_id
        self.username = username
        self.email = email
        self.is_admin = is_admin

    def get_id(self):
        """Return the user id as a string, as Flask-Login expects."""
        return str(self.id)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict of the user's public fields."""
        return {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'is_admin': self.is_admin
        }


class AuthManager:
    """Handles authentication, JWT tokens, and user management.

    Passwords are stored as ``"<salt>:<sha256 hex>"`` strings. Both JWTs and
    opaque API tokens are persisted through ``db_manager`` as SHA-256 hashes,
    so they can be listed, revoked and expired server-side.
    """

    def __init__(self, db_manager: DatabaseManager, app: Flask):
        """Bind the manager to a database layer and a Flask app.

        Args:
            db_manager: Persistence layer for users and API tokens.
            app: Flask application. ``JWT_SECRET_KEY`` and
                ``PERMANENT_SESSION_LIFETIME`` are read from its config; a
                missing key raises ``KeyError``.
        """
        self.db_manager = db_manager
        self.app = app
        self.jwt_secret = app.config['JWT_SECRET_KEY']
        # Kept for callers; not read by any method of this class.
        self.session_timeout = app.config['PERMANENT_SESSION_LIFETIME']

        logger.info("AuthManager initialized")

    def hash_password(self, password: str) -> str:
        """Hash a password with a fresh random salt.

        Returns:
            ``"<32 hex-char salt>:<sha256 hex digest of password + salt>"``.
        """
        # NOTE(review): a single SHA-256 round is cheap to brute-force if the
        # user table leaks. Consider a deliberately slow KDF
        # (hashlib.pbkdf2_hmac / hashlib.scrypt) behind a versioned prefix,
        # with verify_password still accepting this legacy format. The
        # storage column size must be checked first.
        salt = secrets.token_hex(16)
        password_hash = hashlib.sha256((password + salt).encode()).hexdigest()
        return f"{salt}:{password_hash}"

    def verify_password(self, password: str, stored_hash: str) -> bool:
        """Check ``password`` against a value produced by :meth:`hash_password`.

        Returns:
            True on a match. False on a mismatch, a malformed or None
            ``stored_hash``, or a non-string ``password``. Never raises.
        """
        try:
            salt, password_hash = stored_hash.split(':', 1)
            expected_hash = hashlib.sha256((password + salt).encode()).hexdigest()
            # Constant-time comparison, so response timing does not reveal
            # how much of the digest matched.
            return hmac.compare_digest(password_hash.encode(), expected_hash.encode())
        except (ValueError, AttributeError, TypeError):
            return False

    def authenticate_user(self, username: str, password: str) -> Optional[AuthenticatedUser]:
        """Authenticate a user with username and password.

        On success, stamps ``last_login`` on the user record and saves it.

        Returns:
            An ``AuthenticatedUser``, or None if the user is unknown, the
            password is wrong, or any error occurs (the error is logged).
        """
        try:
            user = self.db_manager.get_user_by_username(username)
            if not user:
                logger.warning(f"Authentication failed: user not found - {username}")
                return None

            if not self.verify_password(password, user.password_hash):
                logger.warning(f"Authentication failed: invalid password - {username}")
                return None

            # Update last login
            user.last_login = datetime.utcnow()
            self.db_manager.save_user(user)

            authenticated_user = AuthenticatedUser(
                user_id=user.id,
                username=user.username,
                email=user.email,
                is_admin=user.is_admin
            )

            logger.info(f"User authenticated successfully: {username}")
            return authenticated_user

        except Exception as e:
            logger.error(f"Authentication error: {e}")
            return None

    def create_user(self, username: str, email: str, password: str,
                    is_admin: bool = False) -> Optional[User]:
        """Create and persist a new user.

        Returns:
            Whatever ``db_manager.save_user`` returns for the new user, or
            None if the username or email is already taken, or on any error
            (the error is logged).
        """
        try:
            # Check if user already exists
            existing_user = self.db_manager.get_user_by_username(username)
            if existing_user:
                logger.warning(f"User creation failed: username already exists - {username}")
                return None

            existing_email = self.db_manager.get_user_by_email(email)
            if existing_email:
                logger.warning(f"User creation failed: email already exists - {email}")
                return None

            # Create new user
            password_hash = self.hash_password(password)
            user = User(
                username=username,
                email=email,
                password_hash=password_hash,
                is_admin=is_admin,
                created_at=datetime.utcnow()
            )

            saved_user = self.db_manager.save_user(user)
            logger.info(f"User created successfully: {username}")
            return saved_user

        except Exception as e:
            logger.error(f"User creation error: {e}")
            return None

    def change_password(self, user_id: int, old_password: str, new_password: str) -> bool:
        """Replace a user's password after verifying the old one.

        Returns:
            True if the password was changed. False if the user is unknown,
            the old password is wrong, or any error occurs (logged).
        """
        try:
            user = self.db_manager.get_user_by_id(user_id)
            if not user:
                logger.warning(f"Password change failed: user not found - {user_id}")
                return False

            if not self.verify_password(old_password, user.password_hash):
                logger.warning(f"Password change failed: invalid old password - {user_id}")
                return False

            # Update password
            user.password_hash = self.hash_password(new_password)
            user.updated_at = datetime.utcnow()
            self.db_manager.save_user(user)

            logger.info(f"Password changed successfully: {user.username}")
            return True

        except Exception as e:
            logger.error(f"Password change error: {e}")
            return False

    def create_jwt_token(self, user_id: int, expires_hours: int = 24) -> Optional[str]:
        """Issue an HS256 JWT for a user and record its hash as an APIToken.

        Args:
            user_id: Id of the user the token is issued for.
            expires_hours: Lifetime of the token, in hours.

        Returns:
            The encoded token string, or None if the user is unknown or on
            any error (logged).
        """
        try:
            user = self.db_manager.get_user_by_id(user_id)
            if not user:
                return None

            # One timestamp for every field, so the JWT 'exp' claim and the
            # stored expires_at agree exactly.
            now = datetime.utcnow()
            expires_at = now + timedelta(hours=expires_hours)

            # Token payload
            payload = {
                'user_id': user.id,
                'username': user.username,
                'is_admin': user.is_admin,
                'exp': expires_at,
                'iat': now,
                'sub': str(user.id),
                # Random token id. PyJWT encodes exp/iat as whole seconds, so
                # without it two tokens issued to the same user within one
                # second would be byte-identical and share one stored hash.
                'jti': secrets.token_hex(16),
            }

            # Create token
            token = jwt.encode(payload, self.jwt_secret, algorithm='HS256')
            if isinstance(token, bytes):  # PyJWT < 2.0 returns bytes
                token = token.decode('utf-8')

            # Store token in database
            api_token = APIToken(
                user_id=user.id,
                token_name=f"Web Token {now.strftime('%Y-%m-%d %H:%M')}",
                token_hash=self._hash_token(token),
                expires_at=expires_at,
                created_at=now
            )

            self.db_manager.save_api_token(api_token)

            logger.info(f"JWT token created for user: {user.username}")
            return token

        except Exception as e:
            logger.error(f"JWT token creation error: {e}")
            return None

    def verify_jwt_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Validate a JWT's signature and expiry, plus its database record.

        On success, updates the record's ``last_used``.

        Returns:
            The decoded claims, or None if the token is invalid, expired,
            revoked, or unknown to the database. Claims such as ``is_admin``
            reflect the user as of issue time.
        """
        try:
            # Decode token
            payload = jwt.decode(token, self.jwt_secret, algorithms=['HS256'])

            # Check if token is in database and not revoked
            token_hash = self._hash_token(token)
            api_token = self.db_manager.get_api_token_by_hash(token_hash)

            if not api_token or api_token.revoked:
                logger.warning("JWT token verification failed: token revoked or not found")
                return None

            # Check expiration
            if api_token.expires_at < datetime.utcnow():
                logger.warning("JWT token verification failed: token expired")
                return None

            # Update last used
            api_token.last_used = datetime.utcnow()
            self.db_manager.save_api_token(api_token)

            return payload

        except jwt.ExpiredSignatureError:
            logger.warning("JWT token verification failed: signature expired")
            return None
        except jwt.InvalidTokenError as e:
            logger.warning(f"JWT token verification failed: invalid token - {e}")
            return None
        except Exception as e:
            logger.error(f"JWT token verification error: {e}")
            return None

    def revoke_jwt_token(self, token: str) -> bool:
        """Mark the stored record for ``token`` as revoked.

        Returns:
            True if a record was found and revoked. False otherwise, or on
            error (logged).
        """
        try:
            token_hash = self._hash_token(token)
            api_token = self.db_manager.get_api_token_by_hash(token_hash)

            if api_token:
                api_token.revoked = True
                api_token.updated_at = datetime.utcnow()
                self.db_manager.save_api_token(api_token)

                logger.info(f"JWT token revoked: {api_token.token_name}")
                return True

            return False

        except Exception as e:
            logger.error(f"JWT token revocation error: {e}")
            return False

    def create_api_token(self, user_id: int, token_name: str,
                        expires_hours: int = 8760) -> Optional[Tuple[str, APIToken]]:  # 1 year default
        """Create a long-lived opaque API token.

        Only the token's SHA-256 hash is stored, so the plaintext returned
        here is the only copy.

        Returns:
            ``(plaintext_token, saved_record)``, or None if the user is
            unknown or on error (logged).
        """
        try:
            user = self.db_manager.get_user_by_id(user_id)
            if not user:
                return None

            # Generate secure token
            token = secrets.token_urlsafe(32)
            token_hash = self._hash_token(token)

            # Create API token record
            api_token = APIToken(
                user_id=user.id,
                token_name=token_name,
                token_hash=token_hash,
                expires_at=datetime.utcnow() + timedelta(hours=expires_hours),
                created_at=datetime.utcnow()
            )

            saved_token = self.db_manager.save_api_token(api_token)

            logger.info(f"API token created: {token_name} for user {user.username}")
            return token, saved_token

        except Exception as e:
            logger.error(f"API token creation error: {e}")
            return None

    def verify_api_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Look up a token by hash and validate it against the database.

        On success, updates the record's ``last_used``.

        Returns:
            A dict with ``user_id``, ``username``, ``is_admin`` (read fresh
            from the user record), ``token_name`` and ``token_id``. Returns
            None if the token is unknown, revoked or expired, if its user no
            longer exists, or on error.
        """
        try:
            token_hash = self._hash_token(token)
            api_token = self.db_manager.get_api_token_by_hash(token_hash)

            if not api_token or api_token.revoked:
                logger.warning("API token verification failed: token revoked or not found")
                return None

            # Check expiration
            if api_token.expires_at < datetime.utcnow():
                logger.warning("API token verification failed: token expired")
                return None

            # Get user
            user = self.db_manager.get_user_by_id(api_token.user_id)
            if not user:
                logger.warning("API token verification failed: user not found")
                return None

            # Update last used
            api_token.last_used = datetime.utcnow()
            self.db_manager.save_api_token(api_token)

            return {
                'user_id': user.id,
                'username': user.username,
                'is_admin': user.is_admin,
                'token_name': api_token.token_name,
                'token_id': api_token.id
            }

        except Exception as e:
            logger.error(f"API token verification error: {e}")
            return None

    @staticmethod
    def _isoformat(value: Optional[datetime]) -> Optional[str]:
        """Return ``value`` as an ISO-8601 string, or None when it is unset."""
        return value.isoformat() if value else None

    def list_user_tokens(self, user_id: int) -> list:
        """List all stored tokens (JWT and API) for a user as plain dicts.

        A missing timestamp is reported as None rather than discarding the
        whole list.
        """
        try:
            tokens = self.db_manager.get_user_api_tokens(user_id)
            return [
                {
                    'id': token.id,
                    'name': token.token_name,
                    'created_at': self._isoformat(token.created_at),
                    'expires_at': self._isoformat(token.expires_at),
                    'last_used': self._isoformat(token.last_used),
                    'revoked': token.revoked
                }
                for token in tokens
            ]
        except Exception as e:
            logger.error(f"Failed to list user tokens: {e}")
            return []

    def revoke_api_token_by_id(self, token_id: int, user_id: int) -> bool:
        """Revoke a token by id, only if it belongs to ``user_id``.

        Returns:
            True if it was revoked. False if it is missing, owned by another
            user, or on error (logged).
        """
        try:
            api_token = self.db_manager.get_api_token_by_id(token_id)

            if api_token and api_token.user_id == user_id:
                api_token.revoked = True
                api_token.updated_at = datetime.utcnow()
                self.db_manager.save_api_token(api_token)

                logger.info(f"API token revoked: {api_token.token_name}")
                return True

            return False

        except Exception as e:
            logger.error(f"API token revocation error: {e}")
            return False

    def cleanup_expired_tokens(self):
        """Ask the database layer to delete expired tokens; log how many went."""
        try:
            count = self.db_manager.cleanup_expired_tokens()
            # Truthiness test also tolerates a None count from the DB layer.
            if count:
                logger.info(f"Cleaned up {count} expired tokens")
        except Exception as e:
            logger.error(f"Token cleanup error: {e}")

    def _hash_token(self, token: str) -> str:
        """Return the SHA-256 hex digest used to store and look up tokens."""
        return hashlib.sha256(token.encode()).hexdigest()

    def require_auth(self, f):
        """Decorator for routes requiring a ``Bearer`` token.

        On success, sets ``request.current_user`` to the verified identity
        dict. Otherwise the route returns ``({'error': ...}, 401)``.
        """
        from functools import wraps

        @wraps(f)
        def decorated_function(*args, **kwargs):
            auth_header = request.headers.get('Authorization')

            if auth_header and auth_header.startswith('Bearer '):
                token = auth_header.split(' ', 1)[1].strip()

                # Only a three-segment (header.payload.signature) string can
                # decode as one of our HS256 JWTs. Opaque API tokens from
                # secrets.token_urlsafe contain no '.', so trying them as JWTs
                # always fails and only produces a warning log.
                if token.count('.') == 2:
                    payload = self.verify_jwt_token(token)
                    if payload:
                        request.current_user = payload
                        return f(*args, **kwargs)

                # NOTE(review): create_jwt_token also stores JWT hashes as
                # APIToken rows, so a JWT rejected above (e.g. bad signature
                # after rotating JWT_SECRET_KEY) can still be accepted here.
                # Confirm this fallback is intended.
                if token:
                    api_data = self.verify_api_token(token)
                    if api_data:
                        request.current_user = api_data
                        return f(*args, **kwargs)

            return {'error': 'Authentication required'}, 401

        return decorated_function

    def require_admin(self, f):
        """Decorator for routes requiring admin access.

        Relies on ``request.current_user`` having been set by
        ``require_auth``. Apply it beneath ``@require_auth`` so that decorator
        runs first; otherwise every request gets a 401.
        """
        from functools import wraps

        @wraps(f)
        def decorated_function(*args, **kwargs):
            if not hasattr(request, 'current_user'):
                return {'error': 'Authentication required'}, 401

            if not request.current_user.get('is_admin', False):
                return {'error': 'Admin access required'}, 403

            return f(*args, **kwargs)

        return decorated_function