Commit 00c73c80 authored by Your Name's avatar Your Name

feat(payments): implement tier upgrades with proration

- Add upgrade_subscription() method to SubscriptionManager
- Calculate prorated charge based on unused time in billing period
- Formula: new_price - (old_price × unused_fraction)
- Update subscription tier immediately while keeping same period end date
- Charge prorated amount via existing payment method
- Add comprehensive tests for proration calculations
- Test edge cases: no subscription, invalid tier, accurate proration math
parent 394dabc4
......@@ -205,3 +205,167 @@ class SubscriptionManager:
except Exception as e:
logger.error(f"Error charging crypto wallet: {e}")
return {'success': False, 'error': str(e)}
async def upgrade_subscription(self, user_id: int, new_tier_id: int) -> dict:
"""Upgrade subscription to higher tier with prorated credit"""
try:
# Get current subscription
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT s.id, s.user_id, s.tier_id, s.payment_method_id, s.status,
s.billing_cycle, s.current_period_start, s.current_period_end,
t.price_monthly, t.price_yearly, t.name as tier_name
FROM subscriptions s
JOIN account_tiers t ON s.tier_id = t.id
WHERE s.user_id = ? AND s.status = 'active'
""", (user_id,))
sub_row = cursor.fetchone()
if not sub_row:
return {'success': False, 'error': 'No active subscription'}
# Convert row to dict
subscription = {
'id': sub_row[0],
'user_id': sub_row[1],
'tier_id': sub_row[2],
'payment_method_id': sub_row[3],
'status': sub_row[4],
'billing_cycle': sub_row[5],
'current_period_start': sub_row[6],
'current_period_end': sub_row[7],
'price_monthly': sub_row[8],
'price_yearly': sub_row[9],
'tier_name': sub_row[10]
}
# Get new tier
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT id, name, price_monthly, price_yearly FROM account_tiers WHERE id = ?",
(new_tier_id,)
)
tier_row = cursor.fetchone()
if not tier_row:
return {'success': False, 'error': 'Invalid tier'}
new_tier = {
'id': tier_row[0],
'name': tier_row[1],
'price_monthly': tier_row[2],
'price_yearly': tier_row[3]
}
# Calculate prorated amount
billing_cycle = subscription['billing_cycle']
if billing_cycle == 'monthly':
old_price = subscription['price_monthly']
new_price = new_tier['price_monthly']
else: # yearly
old_price = subscription['price_yearly']
new_price = new_tier['price_yearly']
if old_price is None or new_price is None:
return {'success': False, 'error': f'Invalid prices: old={old_price}, new={new_price}'}
# Calculate unused portion of current period
now = datetime.now(datetime.UTC if hasattr(datetime, 'UTC') else None).replace(tzinfo=None)
# Parse datetime strings if needed
period_start = subscription['current_period_start']
period_end = subscription['current_period_end']
if isinstance(period_start, str):
period_start = datetime.fromisoformat(period_start)
if isinstance(period_end, str):
period_end = datetime.fromisoformat(period_end)
total_period_seconds = (period_end - period_start).total_seconds()
remaining_seconds = (period_end - now).total_seconds()
if remaining_seconds <= 0:
# Period already ended, charge full amount
prorated_amount = new_price
unused_fraction = 0
else:
# Calculate unused portion
unused_fraction = remaining_seconds / total_period_seconds
unused_credit = old_price * unused_fraction
# New charge = new_price - unused_credit
prorated_amount = new_price - unused_credit
# Ensure non-negative
prorated_amount = max(0, prorated_amount)
logger.info(f"Upgrade proration: old=${old_price}, new=${new_price}, "
f"unused={unused_fraction:.2%}, charge=${prorated_amount:.2f}")
# Get payment method
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM payment_methods WHERE id = ?",
(subscription['payment_method_id'],)
)
pm_row = cursor.fetchone()
if not pm_row:
return {'success': False, 'error': 'Payment method not found'}
payment_method = {
'id': pm_row[0],
'user_id': pm_row[1],
'type': pm_row[2],
'gateway': pm_row[3],
'crypto_type': pm_row[4] if len(pm_row) > 4 else None
}
# Charge prorated amount
if prorated_amount > 0:
charge_result = await self._charge_payment(
user_id=user_id,
payment_method=payment_method,
amount=prorated_amount,
description=f"Upgrade to {new_tier['name']} (prorated)"
)
if not charge_result['success']:
return charge_result
# Update subscription
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
UPDATE subscriptions
SET tier_id = ?
WHERE id = ?
""", (new_tier_id, subscription['id']))
conn.commit()
# Update user tier
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"UPDATE users SET tier_id = ? WHERE id = ?",
(new_tier_id, user_id)
)
conn.commit()
logger.info(f"Upgraded subscription {subscription['id']} to tier {new_tier_id}")
return {
'success': True,
'charged_amount': prorated_amount,
'next_billing_date': period_end.isoformat()
}
except Exception as e:
logger.error(f"Error upgrading subscription: {e}")
import traceback
logger.error(traceback.format_exc())
return {'success': False, 'error': str(e)}
......@@ -86,3 +86,171 @@ async def test_create_subscription(db_manager):
print(f"Result: {result}")
assert result['success'] == True, f"Expected success but got: {result}"
assert 'subscription_id' in result
@pytest.mark.anyio
async def test_upgrade_subscription_with_proration(db_manager):
"""Test subscription upgrade with prorated charges"""
# Mock handlers
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
return {'success': True, 'transaction_id': 'test_tx'}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, # PayPal handler
None, # Crypto wallet manager
None # Price service
)
# Create initial subscription
result = await manager.create_subscription(
user_id=1,
tier_id=db_manager._test_pro_tier_id,
payment_method_id=1,
billing_cycle='monthly'
)
assert result['success'] == True
# Add Premium tier
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO account_tiers (name, price_monthly, price_yearly, is_default)
VALUES ('Premium', 20.00, 200.00, 0)
""")
conn.commit()
premium_tier_id = cursor.lastrowid
# Upgrade to Premium
result = await manager.upgrade_subscription(user_id=1, new_tier_id=premium_tier_id)
assert result['success'] == True
assert 'charged_amount' in result
# Should charge less than full $20 due to proration
assert result['charged_amount'] < 20.00
assert result['charged_amount'] > 0
# Verify subscription was updated
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT tier_id FROM subscriptions WHERE user_id = 1 AND status = 'active'
""")
row = cursor.fetchone()
assert row[0] == premium_tier_id
# Verify user tier was updated
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT tier_id FROM users WHERE id = 1")
row = cursor.fetchone()
assert row[0] == premium_tier_id
@pytest.mark.anyio
async def test_upgrade_subscription_proration_calculation(db_manager):
"""Test that proration calculation is accurate"""
from datetime import datetime, timedelta
# Mock handlers
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
return {'success': True, 'transaction_id': 'test_tx'}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, None, None
)
# Create initial subscription
result = await manager.create_subscription(
user_id=1,
tier_id=db_manager._test_pro_tier_id,
payment_method_id=1,
billing_cycle='monthly'
)
assert result['success'] == True
# Manually set subscription to be halfway through billing period
with db_manager._get_connection() as conn:
cursor = conn.cursor()
now = datetime.now(datetime.UTC if hasattr(datetime, 'UTC') else None).replace(tzinfo=None)
period_start = now - timedelta(days=15)
period_end = now + timedelta(days=15)
cursor.execute("""
UPDATE subscriptions
SET current_period_start = ?, current_period_end = ?
WHERE user_id = 1
""", (period_start, period_end))
conn.commit()
# Add Premium tier ($20/month)
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO account_tiers (name, price_monthly, price_yearly, is_default)
VALUES ('Premium', 20.00, 200.00, 0)
""")
conn.commit()
premium_tier_id = cursor.lastrowid
# Upgrade to Premium
result = await manager.upgrade_subscription(user_id=1, new_tier_id=premium_tier_id)
assert result['success'] == True
# At halfway point: new_price - (old_price * 0.5) = 20 - (10 * 0.5) = 15
# Allow small floating point variance
assert 14.9 < result['charged_amount'] < 15.1
@pytest.mark.anyio
async def test_upgrade_subscription_no_active_subscription(db_manager):
"""Test upgrade fails when no active subscription exists"""
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
return {'success': True, 'transaction_id': 'test_tx'}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, None, None
)
# Try to upgrade without creating subscription first
result = await manager.upgrade_subscription(user_id=1, new_tier_id=2)
assert result['success'] == False
assert 'No active subscription' in result['error']
@pytest.mark.anyio
async def test_upgrade_subscription_invalid_tier(db_manager):
"""Test upgrade fails with invalid tier ID"""
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
return {'success': True, 'transaction_id': 'test_tx'}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, None, None
)
# Create initial subscription
result = await manager.create_subscription(
user_id=1,
tier_id=db_manager._test_pro_tier_id,
payment_method_id=1,
billing_cycle='monthly'
)
assert result['success'] == True
# Try to upgrade to non-existent tier
result = await manager.upgrade_subscription(user_id=1, new_tier_id=99999)
assert result['success'] == False
assert 'Invalid tier' in result['error']
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment