Commit ba33a333 authored by Your Name's avatar Your Name

feat(payments): implement HD wallet manager with BIP32/BIP44

parent c57866bc
"""
Crypto payment module
"""
from aisbf.payments.crypto.wallet import CryptoWalletManager
__all__ = ['CryptoWalletManager']
"""
HD Wallet Manager for cryptocurrency addresses
Implements BIP32/BIP44 hierarchical deterministic wallet generation.
Each crypto type has its own encrypted master seed, from which user addresses
are deterministically derived.
"""
import logging
from mnemonic import Mnemonic
from bip32 import BIP32
from cryptography.fernet import Fernet
logger = logging.getLogger(__name__)
class CryptoWalletManager:
"""Manages HD wallets for all supported cryptocurrencies"""
# BIP44 coin types
COIN_TYPES = {
'btc': 0,
'eth': 60,
'usdt': 60, # ERC20 uses Ethereum
'usdc': 60 # ERC20 uses Ethereum
}
def __init__(self, db_manager, encryption_key: str):
self.db = db_manager
self.encryption_key = encryption_key
self.fernet = Fernet(encryption_key.encode())
# Initialize master keys on first run
self._initialize_master_keys()
def _initialize_master_keys(self):
"""Initialize master keys for all crypto types (run once on setup)"""
for crypto_type in ['btc', 'eth', 'usdt', 'usdc']:
with self.db._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
cursor.execute(
f"SELECT id FROM crypto_master_keys WHERE crypto_type = {placeholder}",
(crypto_type,)
)
existing = cursor.fetchone()
if not existing:
# Generate new BIP39 mnemonic (24 words)
mnemo = Mnemonic("english")
mnemonic = mnemo.generate(strength=256)
# Encrypt mnemonic
encrypted_seed = self.fernet.encrypt(mnemonic.encode()).decode()
# BIP44 derivation path
coin_type = self.COIN_TYPES[crypto_type]
derivation_path = f"m/44'/{coin_type}'/0'/0"
# Store in database
with self.db._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
cursor.execute(f"""
INSERT INTO crypto_master_keys
(crypto_type, encrypted_seed, encryption_key_id, derivation_path)
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder})
""", (crypto_type, encrypted_seed, 'default', derivation_path))
conn.commit()
logger.info(f"Generated master key for {crypto_type}")
def get_master_seed(self, crypto_type: str) -> str:
"""Get decrypted master seed for crypto type"""
with self.db._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
cursor.execute(
f"SELECT encrypted_seed FROM crypto_master_keys WHERE crypto_type = {placeholder}",
(crypto_type,)
)
result = cursor.fetchone()
if not result:
raise ValueError(f"No master key found for {crypto_type}")
# Decrypt seed
encrypted_seed = result[0]
mnemonic = self.fernet.decrypt(encrypted_seed.encode()).decode()
return mnemonic
def derive_address(self, crypto_type: str, index: int) -> dict:
"""Derive address from master key using BIP44 path"""
mnemonic = self.get_master_seed(crypto_type)
# Generate seed from mnemonic
mnemo = Mnemonic("english")
seed = mnemo.to_seed(mnemonic)
# BIP44 path: m/44'/coin_type'/0'/0/index
coin_type = self.COIN_TYPES[crypto_type]
path = f"m/44'/{coin_type}'/0'/0/{index}"
if crypto_type == 'btc':
return self._derive_bitcoin_address(seed, path, index)
elif crypto_type in ['eth', 'usdt', 'usdc']:
return self._derive_ethereum_address(seed, path, index)
def _derive_bitcoin_address(self, seed: bytes, path: str, index: int) -> dict:
"""Derive Bitcoin address"""
from bitcoinlib.keys import HDKey
# Create HD key from seed
hd_key = HDKey.from_seed(seed)
# Derive child key
child_key = hd_key.subkey_for_path(path)
# Get P2WPKH address (native segwit, starts with bc1)
address = child_key.address(encoding='bech32')
return {
'address': address,
'derivation_path': path,
'derivation_index': index
}
def _derive_ethereum_address(self, seed: bytes, path: str, index: int) -> dict:
"""Derive Ethereum address (also used for USDT/USDC ERC20)"""
from eth_account import Account
# Create BIP32 instance
bip32 = BIP32.from_seed(seed)
# Derive child key
child_key = bip32.get_privkey_from_path(path)
# Get Ethereum address from private key
account = Account.from_key(child_key)
address = account.address
return {
'address': address,
'derivation_path': path,
'derivation_index': index
}
async def get_or_create_user_address(self, user_id: int, crypto_type: str) -> str:
"""Get existing address or create new one for user"""
# Check if user already has address
with self.db._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
cursor.execute(f"""
SELECT address FROM user_crypto_addresses
WHERE user_id = {placeholder} AND crypto_type = {placeholder}
""", (user_id, crypto_type))
existing = cursor.fetchone()
if existing:
return existing[0]
# Get next available index
with self.db._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
cursor.execute(f"""
SELECT COALESCE(MAX(derivation_index), -1) as max_idx
FROM user_crypto_addresses
WHERE crypto_type = {placeholder}
""", (crypto_type,))
max_index = cursor.fetchone()
next_index = max_index[0] + 1
# Derive new address
address_info = self.derive_address(crypto_type, next_index)
# Store in database
with self.db._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
cursor.execute(f"""
INSERT INTO user_crypto_addresses
(user_id, crypto_type, address, derivation_path, derivation_index)
VALUES ({placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder})
""", (
user_id,
crypto_type,
address_info['address'],
address_info['derivation_path'],
address_info['derivation_index']
))
conn.commit()
# Create wallet entry
with self.db._get_connection() as conn:
cursor = conn.cursor()
placeholder = '?' if self.db.db_type == 'sqlite' else '%s'
cursor.execute(f"""
INSERT INTO user_crypto_wallets
(user_id, crypto_type, balance_crypto, balance_fiat)
VALUES ({placeholder}, {placeholder}, 0, 0)
""", (user_id, crypto_type))
conn.commit()
logger.info(f"Created {crypto_type} address for user {user_id}: {address_info['address']}")
return address_info['address']
...@@ -86,6 +86,7 @@ class PaymentMigrations: ...@@ -86,6 +86,7 @@ class PaymentMigrations:
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
crypto_type VARCHAR(20) NOT NULL, crypto_type VARCHAR(20) NOT NULL,
address VARCHAR(255) NOT NULL UNIQUE, address VARCHAR(255) NOT NULL UNIQUE,
derivation_path VARCHAR(100) NOT NULL,
derivation_index INTEGER NOT NULL, derivation_index INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT {timestamp_default}, created_at TIMESTAMP DEFAULT {timestamp_default},
FOREIGN KEY (user_id) REFERENCES users(id), FOREIGN KEY (user_id) REFERENCES users(id),
......
import pytest
from cryptography.fernet import Fernet
from aisbf.database import DatabaseManager
from aisbf.payments.migrations import PaymentMigrations
from aisbf.payments.crypto.wallet import CryptoWalletManager
@pytest.fixture
def db_manager(tmp_path):
"""Create test database"""
db_path = tmp_path / "test.db"
db_config = {
'type': 'sqlite',
'sqlite_path': str(db_path)
}
db = DatabaseManager(db_config)
migrations = PaymentMigrations(db)
migrations.run_migrations()
return db
@pytest.fixture
def encryption_key():
"""Generate test encryption key"""
return Fernet.generate_key().decode()
def test_initialize_master_keys(db_manager, encryption_key):
"""Test master key initialization"""
wallet_manager = CryptoWalletManager(db_manager, encryption_key)
# Should create keys for all crypto types
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM crypto_master_keys")
keys = cursor.fetchall()
assert len(keys) == 4 # btc, eth, usdt, usdc
# Get crypto types
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT crypto_type FROM crypto_master_keys")
crypto_types = [row[0] for row in cursor.fetchall()]
assert 'btc' in crypto_types
assert 'eth' in crypto_types
assert 'usdt' in crypto_types
assert 'usdc' in crypto_types
def test_derive_bitcoin_address(db_manager, encryption_key):
"""Test Bitcoin address derivation"""
wallet_manager = CryptoWalletManager(db_manager, encryption_key)
address_info = wallet_manager.derive_address('btc', 0)
assert address_info['address'].startswith('bc1')
assert address_info['derivation_path'] == "m/44'/0'/0'/0/0"
assert address_info['derivation_index'] == 0
def test_derive_ethereum_address(db_manager, encryption_key):
"""Test Ethereum address derivation"""
wallet_manager = CryptoWalletManager(db_manager, encryption_key)
address_info = wallet_manager.derive_address('eth', 0)
assert address_info['address'].startswith('0x')
assert len(address_info['address']) == 42
assert address_info['derivation_path'] == "m/44'/60'/0'/0/0"
@pytest.mark.anyio
async def test_get_or_create_user_address(db_manager, encryption_key):
"""Test user address creation"""
wallet_manager = CryptoWalletManager(db_manager, encryption_key)
# Create address for user 1
address1 = await wallet_manager.get_or_create_user_address(1, 'btc')
assert address1.startswith('bc1')
# Getting again should return same address
address2 = await wallet_manager.get_or_create_user_address(1, 'btc')
assert address1 == address2
# Different user should get different address
address3 = await wallet_manager.get_or_create_user_address(2, 'btc')
assert address3 != address1
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment