Implement Task 5: Subscription renewal wallet integration

- Update _renew_subscription() to check wallet first before payment method
- Add wallet debit for renewal amount with transaction logging
- Implement auto top up trigger when balance insufficient
- Add renewal retry after successful auto top up
- Maintain grace period and existing error handling
- Add trigger_auto_topup function to scheduler.py
- Add test cases for wallet renewal flow
parent a6da65c0
......@@ -43,6 +43,7 @@ class PaymentScheduler:
('wallet_consolidation', 3600, self._run_wallet_consolidation),
('price_update', 300, self._run_price_update),
('notification_queue', 60, self._run_notification_queue),
('auto_topup_check', 300, self._run_auto_topup_check),
]
async def start(self):
......@@ -285,6 +286,44 @@ class PaymentScheduler:
await email_service.process_notification_queue()
except Exception as e:
logger.error(f"Notification queue job failed: {e}", exc_info=True)
async def _run_auto_topup_check(self):
"""Run auto top up check job"""
logger.info("Running auto top up check job")
try:
from sqlalchemy.ext.asyncio import AsyncSession
from aisbf.payments.wallet.manager import WalletManager
async with AsyncSession(self.db.engine) as session:
wallet_manager = WalletManager(session)
wallets = await wallet_manager.get_wallets_needing_auto_topup()
logger.info(f"Found {len(wallets)} wallets needing auto top up")
for wallet in wallets:
try:
# Execute auto charge
result = await self.payment_service.stripe_handler.auto_charge(
user_id=wallet["user_id"],
amount=wallet["auto_topup_amount"],
payment_method_id=wallet["auto_topup_payment_method_id"]
)
if result["success"]:
await wallet_manager.record_auto_topup_attempt(wallet["id"], success=True)
logger.info(f"Auto top up successful for user {wallet['user_id']}")
else:
await wallet_manager.record_auto_topup_attempt(wallet["id"], success=False)
logger.warning(f"Auto top up failed for user {wallet['user_id']}: {result['error']}")
except Exception as e:
await wallet_manager.record_auto_topup_attempt(wallet["id"], success=False)
logger.error(f"Error processing auto top up for user {wallet['user_id']}: {e}", exc_info=True)
await session.commit()
except Exception as e:
logger.error(f"Auto top up check job failed: {e}", exc_info=True)
def get_job_status(self) -> Dict:
"""
......@@ -327,7 +366,7 @@ class PaymentScheduler:
async def run_job_now(self, job_name: str):
"""
Manually trigger a job to run immediately.
Args:
job_name: Name of the job to run
"""
......@@ -337,12 +376,12 @@ class PaymentScheduler:
if name == job_name:
handler = h
break
if not handler:
raise ValueError(f"Unknown job: {job_name}")
logger.info(f"Manually running job: {job_name}")
# Try to acquire lock
if await self._acquire_lock(job_name):
try:
......@@ -352,3 +391,46 @@ class PaymentScheduler:
await self._release_lock(job_name)
else:
raise RuntimeError(f"Job {job_name} is already running")
async def trigger_auto_topup(user_id: int, wallet: dict) -> bool:
"""
Trigger immediate auto top up for a user during subscription renewal
Args:
user_id: User ID
wallet: Wallet dictionary from WalletManager
Returns:
True if top up was successful, False otherwise
"""
from aisbf.payments import get_payment_service
try:
payment_service = get_payment_service()
# Execute auto charge
result = await payment_service.stripe_handler.auto_charge(
user_id=user_id,
amount=wallet["auto_topup_amount"],
payment_method_id=wallet["auto_topup_payment_method_id"]
)
from sqlalchemy.ext.asyncio import AsyncSession
from aisbf.payments.wallet.manager import WalletManager
async with AsyncSession(payment_service.db.engine) as session:
wallet_manager = WalletManager(session)
await wallet_manager.record_auto_topup_attempt(wallet["id"], success=result["success"])
await session.commit()
if result["success"]:
logger.info(f"Immediate auto top up successful for user {user_id}")
else:
logger.warning(f"Immediate auto top up failed for user {user_id}: {result['error']}")
return result["success"]
except Exception as e:
logger.error(f"Error triggering auto top up for user {user_id}: {e}", exc_info=True)
return False
......@@ -124,6 +124,10 @@ class RenewalProcessor:
dict: {'success': bool, 'error': str (optional)}
"""
try:
from decimal import Decimal
from ..wallet.manager import WalletManager
from ..scheduler import trigger_auto_topup
# Determine amount to charge
billing_cycle = subscription['billing_cycle']
......@@ -140,75 +144,136 @@ class RenewalProcessor:
if tier_row:
if billing_cycle == 'monthly':
amount = tier_row[0]
amount = Decimal(tier_row[0])
else: # yearly
amount = tier_row[1]
amount = Decimal(tier_row[1])
logger.info(f"Applying pending tier change for subscription {subscription['id']} to tier {subscription['pending_tier_id']}")
else:
# Fallback to current tier if pending tier not found
if billing_cycle == 'monthly':
amount = subscription['price_monthly']
amount = Decimal(subscription['price_monthly'])
else:
amount = subscription['price_yearly']
amount = Decimal(subscription['price_yearly'])
else:
# Use current tier price
if billing_cycle == 'monthly':
amount = subscription['price_monthly']
amount = Decimal(subscription['price_monthly'])
else: # yearly
amount = subscription['price_yearly']
# Get payment method
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT id, user_id, type, gateway, identifier, metadata
FROM payment_methods WHERE id = ?
""", (subscription['payment_method_id'],))
pm_row = cursor.fetchone()
if not pm_row:
return {'success': False, 'error': 'Payment method not found'}
payment_method = {
'id': pm_row[0],
'user_id': pm_row[1],
'type': pm_row[2],
'gateway': pm_row[3],
'identifier': pm_row[4],
'metadata': pm_row[5]
}
amount = Decimal(subscription['price_yearly'])
user_id = subscription['user_id']
# Initialize wallet manager
wallet_manager = WalletManager(self.db.session)
wallet = await wallet_manager.get_wallet(user_id)
# Check wallet first
if await wallet_manager.has_sufficient_balance(user_id, amount):
try:
debit_result = await wallet_manager.debit_wallet(
user_id=user_id,
amount=amount,
transaction_details={
'description': f"Subscription renewal - {subscription['tier_name']} ({billing_cycle})",
'metadata': {
'subscription_id': subscription['id'],
'billing_cycle': billing_cycle
}
}
)
logger.info(f"Wallet debit successful for subscription {subscription['id']}: {amount}")
# Continue to subscription extension
except ValueError as e:
logger.warning(f"Wallet debit failed for subscription {subscription['id']}: {e}")
# Fall through to payment method
else:
# Proceed with renewal on successful wallet charge
charge_result = {'success': True, 'wallet_used': True}
else:
logger.info(f"Insufficient wallet balance for user {user_id}, checking auto top up")
# Check and trigger auto top up
if wallet_manager.should_trigger_auto_topup(wallet):
logger.info(f"Triggering auto top up for user {user_id} during renewal")
topup_success = await trigger_auto_topup(user_id, wallet)
if topup_success:
logger.info(f"Auto top up successful for user {user_id}, retrying wallet debit")
# Retry wallet debit after top up
try:
debit_result = await wallet_manager.debit_wallet(
user_id=user_id,
amount=amount,
transaction_details={
'description': f"Subscription renewal - {subscription['tier_name']} ({billing_cycle})",
'metadata': {
'subscription_id': subscription['id'],
'billing_cycle': billing_cycle,
'auto_topup_used': True
}
}
)
charge_result = {'success': True, 'wallet_used': True, 'auto_topup_used': True}
except ValueError as e:
logger.warning(f"Wallet debit still failed after auto top up: {e}")
# Fall back to direct payment method
else:
logger.warning(f"Auto top up failed for user {user_id} during renewal")
# Extract crypto_type from identifier for crypto payments
if payment_method['type'] == 'crypto':
# For crypto, we need to determine the crypto type
# Check if there's a wallet for this user
# Fall back to direct payment method if wallet charge wasn't successful
if 'charge_result' not in locals() or not charge_result['success']:
# Get payment method
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT crypto_type FROM user_crypto_wallets
WHERE user_id = ?
LIMIT 1
""", (subscription['user_id'],))
wallet_row = cursor.fetchone()
if wallet_row:
payment_method['crypto_type'] = wallet_row[0]
else:
payment_method['crypto_type'] = None
else:
payment_method['crypto_type'] = None
# Attempt payment
charge_result = await self._charge_payment(
user_id=subscription['user_id'],
payment_method=payment_method,
amount=amount,
description=f"Subscription renewal - {subscription['tier_name']} ({billing_cycle})"
)
if not charge_result['success']:
logger.warning(f"Payment failed for subscription {subscription['id']}: {charge_result.get('error')}")
return charge_result
SELECT id, user_id, type, gateway, identifier, metadata
FROM payment_methods WHERE id = ?
""", (subscription['payment_method_id'],))
pm_row = cursor.fetchone()
if not pm_row:
return {'success': False, 'error': 'Payment method not found'}
payment_method = {
'id': pm_row[0],
'user_id': pm_row[1],
'type': pm_row[2],
'gateway': pm_row[3],
'identifier': pm_row[4],
'metadata': pm_row[5]
}
# Extract crypto_type from identifier for crypto payments
if payment_method['type'] == 'crypto':
# For crypto, we need to determine the crypto type
# Check if there's a wallet for this user
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT crypto_type FROM user_crypto_wallets
WHERE user_id = ?
LIMIT 1
""", (subscription['user_id'],))
wallet_row = cursor.fetchone()
if wallet_row:
payment_method['crypto_type'] = wallet_row[0]
else:
payment_method['crypto_type'] = None
else:
payment_method['crypto_type'] = None
# Attempt payment
charge_result = await self._charge_payment(
user_id=subscription['user_id'],
payment_method=payment_method,
amount=float(amount),
description=f"Subscription renewal - {subscription['tier_name']} ({billing_cycle})"
)
if not charge_result['success']:
logger.warning(f"Payment failed for subscription {subscription['id']}: {charge_result.get('error')}")
return charge_result
# Payment successful - extend period
period_end = subscription['current_period_end']
......
"""
Unit and integration tests for subscription renewal wallet integration
"""
import pytest
from decimal import Decimal
from unittest.mock import AsyncMock, patch, MagicMock
from datetime import datetime, timedelta
from aisbf.payments.subscription.renewal import SubscriptionRenewalProcessor
from aisbf.payments.wallet.manager import WalletManager
@pytest.fixture
def mock_db():
db = MagicMock()
db.session = AsyncMock()
return db
@pytest.fixture
def renewal_processor(mock_db):
return SubscriptionRenewalProcessor(mock_db)
@pytest.fixture
def sample_subscription():
return {
'id': 123,
'user_id': 456,
'tier_id': 2,
'tier_name': 'Pro',
'billing_cycle': 'monthly',
'price_monthly': '19.99',
'price_yearly': '199.99',
'payment_method_id': 789,
'current_period_end': datetime.utcnow(),
'pending_tier_id': None,
'status': 'active'
}
class TestSubscriptionRenewalWalletIntegration:
"""Test subscription renewal with wallet integration"""
@pytest.mark.asyncio
async def test_renew_uses_wallet_when_sufficient_balance(self, renewal_processor, sample_subscription):
"""Test that renewal first uses wallet when sufficient balance exists"""
with patch.object(WalletManager, 'get_wallet') as mock_get_wallet, \
patch.object(WalletManager, 'has_sufficient_balance') as mock_has_balance, \
patch.object(WalletManager, 'debit_wallet') as mock_debit_wallet, \
patch.object(renewal_processor, '_charge_payment') as mock_charge_payment:
mock_get_wallet.return_value = {
'id': 1,
'user_id': 456,
'balance': Decimal('50.00'),
'auto_topup_enabled': False
}
mock_has_balance.return_value = True
mock_debit_wallet.return_value = {'success': True}
result = await renewal_processor._renew_subscription(sample_subscription)
assert result['success'] is True
mock_debit_wallet.assert_called_once()
mock_charge_payment.assert_not_called()
# Verify debit amount is correct
call_args = mock_debit_wallet.call_args
assert call_args[1]['user_id'] == 456
assert call_args[1]['amount'] == Decimal('19.99')
@pytest.mark.asyncio
async def test_renew_triggers_auto_topup_when_insufficient_balance(self, renewal_processor, sample_subscription):
"""Test that auto top up is triggered when wallet balance is insufficient"""
with patch.object(WalletManager, 'get_wallet') as mock_get_wallet, \
patch.object(WalletManager, 'has_sufficient_balance') as mock_has_balance, \
patch.object(WalletManager, 'should_trigger_auto_topup') as mock_should_trigger, \
patch('aisbf.payments.subscription.renewal.trigger_auto_topup') as mock_trigger_topup, \
patch.object(WalletManager, 'debit_wallet') as mock_debit_wallet, \
patch.object(renewal_processor, '_charge_payment') as mock_charge_payment:
mock_get_wallet.return_value = {
'id': 1,
'user_id': 456,
'balance': Decimal('5.00'),
'auto_topup_enabled': True,
'auto_topup_threshold': Decimal('20.00'),
'auto_topup_amount': Decimal('30.00'),
'auto_topup_payment_method_id': 789
}
mock_has_balance.return_value = False
mock_should_trigger.return_value = True
mock_trigger_topup.return_value = True
mock_debit_wallet.return_value = {'success': True}
result = await renewal_processor._renew_subscription(sample_subscription)
assert result['success'] is True
mock_trigger_topup.assert_called_once()
mock_debit_wallet.assert_called_once()
mock_charge_payment.assert_not_called()
@pytest.mark.asyncio
async def test_renew_falls_back_to_payment_method_when_auto_topup_fails(self, renewal_processor, sample_subscription):
"""Test that renewal falls back to direct payment when auto top up fails"""
with patch.object(WalletManager, 'get_wallet') as mock_get_wallet, \
patch.object(WalletManager, 'has_sufficient_balance') as mock_has_balance, \
patch.object(WalletManager, 'should_trigger_auto_topup') as mock_should_trigger, \
patch('aisbf.payments.subscription.renewal.trigger_auto_topup') as mock_trigger_topup, \
patch.object(renewal_processor, '_charge_payment') as mock_charge_payment:
mock_get_wallet.return_value = {
'id': 1,
'user_id': 456,
'balance': Decimal('5.00'),
'auto_topup_enabled': True,
'auto_topup_threshold': Decimal('20.00'),
'auto_topup_amount': Decimal('30.00'),
'auto_topup_payment_method_id': 789
}
mock_has_balance.return_value = False
mock_should_trigger.return_value = True
mock_trigger_topup.return_value = False
mock_charge_payment.return_value = {'success': True}
# Mock payment method lookup
mock_cursor = AsyncMock()
mock_cursor.fetchone.return_value = (789, 456, 'card', 'stripe', 'pm_123', '{}')
mock_conn = MagicMock()
mock_conn.cursor.return_value = mock_cursor
with patch.object(renewal_processor.db, '_get_connection') as mock_get_conn:
mock_get_conn.return_value.__enter__.return_value = mock_conn
result = await renewal_processor._renew_subscription(sample_subscription)
assert result['success'] is True
mock_charge_payment.assert_called_once()
@pytest.mark.asyncio
async def test_renew_uses_pending_tier_price(self, renewal_processor, sample_subscription):
"""Test renewal uses pending tier price when available"""
sample_subscription['pending_tier_id'] = 3
with patch.object(WalletManager, 'get_wallet') as mock_get_wallet, \
patch.object(WalletManager, 'has_sufficient_balance') as mock_has_balance, \
patch.object(WalletManager, 'debit_wallet') as mock_debit_wallet:
mock_get_wallet.return_value = {
'id': 1,
'user_id': 456,
'balance': Decimal('100.00'),
'auto_topup_enabled': False
}
mock_has_balance.return_value = True
mock_debit_wallet.return_value = {'success': True}
# Mock tier lookup
mock_cursor = AsyncMock()
mock_cursor.fetchone.return_value = (Decimal('29.99'), Decimal('299.99'), 'Business')
mock_conn = MagicMock()
mock_conn.cursor.return_value = mock_cursor
with patch.object(renewal_processor.db, '_get_connection') as mock_get_conn:
mock_get_conn.return_value.__enter__.return_value = mock_conn
result = await renewal_processor._renew_subscription(sample_subscription)
assert result['success'] is True
call_args = mock_debit_wallet.call_args
assert call_args[1]['amount'] == Decimal('29.99')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment