Implement remember me functionality for login

- Added remember_tokens table to database
- Added functions to create, validate, and delete remember tokens
- Modified login route to set remember cookie when checkbox is checked
- Added before_request handler to auto-login from remember cookie
- Modified logout to clear remember cookie and delete token
- Cookie expires in 30 days
parent 049cd789
......@@ -345,6 +345,30 @@ def init_db(conn) -> None:
)
''')
# Remember tokens table
if config['type'] == 'mysql':
cursor.execute('''
CREATE TABLE IF NOT EXISTS remember_tokens (
id INT AUTO_INCREMENT PRIMARY KEY,
user_id INT NOT NULL,
token VARCHAR(255) UNIQUE NOT NULL,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
''')
else:
cursor.execute('''
CREATE TABLE IF NOT EXISTS remember_tokens (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
token TEXT UNIQUE NOT NULL,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id)
)
''')
# User API tokens table
if config['type'] == 'mysql':
cursor.execute('''
......@@ -1235,4 +1259,59 @@ def delete_user_api_token(user_id: int, token_id: int) -> bool:
conn.commit()
success = cursor.rowcount > 0
conn.close()
return success
\ No newline at end of file
return success
# Remember token functions
def create_remember_token(user_id: int) -> str:
"""Create a remember token for user (expires in 30 days)."""
import secrets
import time
token = secrets.token_hex(32)
expires_at = int(time.time()) + (30 * 24 * 60 * 60) # 30 days
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute('INSERT INTO remember_tokens (user_id, token, expires_at) VALUES (?, ?, ?)',
(user_id, token, expires_at))
conn.commit()
conn.close()
return token
def validate_remember_token(token: str) -> Optional[Dict[str, Any]]:
"""Validate remember token and return user info."""
import time
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute('''
SELECT u.id, u.username, u.email, u.role, u.active
FROM remember_tokens t
JOIN users u ON t.user_id = u.id
WHERE t.token = ? AND t.expires_at > ? AND u.active = 1
''', (token, int(time.time())))
row = cursor.fetchone()
conn.close()
return dict(row) if row else None
def delete_remember_token(token: str) -> None:
"""Delete a remember token."""
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute('DELETE FROM remember_tokens WHERE token = ?', (token,))
conn.commit()
conn.close()
def delete_expired_remember_tokens() -> None:
"""Delete expired remember tokens."""
import time
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute('DELETE FROM remember_tokens WHERE expires_at <= ?', (int(time.time()),))
conn.commit()
conn.close()
\ No newline at end of file
......@@ -28,7 +28,7 @@ import argparse
from .comm import SocketCommunicator, Message
from .config import get_all_settings, get_allow_registration
from .auth import login_user, logout_user, get_current_user, register_user, confirm_email, require_auth, require_admin
from .database import get_user_tokens, update_user_tokens, get_user_queue_items, get_default_user_tokens
from .database import get_user_tokens, update_user_tokens, get_user_queue_items, get_default_user_tokens, create_remember_token, validate_remember_token, delete_remember_token
app = Flask(__name__, template_folder=os.path.join(os.path.dirname(__file__), '..', 'templates'))
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'dev-secret-key-change-in-production')
......@@ -37,6 +37,19 @@ os.makedirs('static', exist_ok=True)
# Global configuration
server_dir = None
@app.before_request
def check_remember_me():
"""Check for remember me cookie and auto-login if valid."""
if 'session_id' not in session:
remember_token = request.cookies.get('remember_token')
if remember_token:
user = validate_remember_token(remember_token)
if user:
# Create a new session for the user
session_id = session_manager.create_session(user)
session['session_id'] = session_id
# Communicator to backend (always TCP)
comm = SocketCommunicator(host='localhost', port=5001, comm_type='tcp')
......@@ -138,11 +151,22 @@ def login():
if request.method == 'POST':
username = request.form.get('username')
password = request.form.get('password')
remember = request.form.get('remember') == 'on'
session_id = login_user(username, password)
if session_id:
session['session_id'] = session_id
flash('Login successful!', 'success')
# Handle remember me
if remember:
user = get_current_user(session_id)
if user:
remember_token = create_remember_token(user['id'])
response = make_response(redirect(url_for('dashboard')))
response.set_cookie('remember_token', remember_token, max_age=30*24*60*60, httponly=True, secure=False) # secure=True in production with HTTPS
return response
return redirect(url_for('dashboard'))
else:
flash('Invalid username or password', 'error')
......@@ -176,9 +200,17 @@ def logout():
session_id = session.get('session_id')
if session_id:
logout_user(session_id)
# Clear remember token if exists
remember_token = request.cookies.get('remember_token')
if remember_token:
delete_remember_token(remember_token)
session.clear()
flash('Logged out successfully', 'success')
return redirect(url_for('index'))
response = make_response(redirect(url_for('index')))
response.delete_cookie('remember_token')
return response
@app.route('/analyze', methods=['GET', 'POST'])
@login_required
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment