Commit f0143b4c authored by Your Name's avatar Your Name

Implement tier downgrades and subscription cancellation

- Add downgrade_subscription() method to SubscriptionManager
  - Sets pending_tier_id for scheduled downgrade at period end
  - No immediate charge or refund
  - User keeps current tier until period end
  - Returns scheduled downgrade date
- Add cancel_subscription() method to SubscriptionManager
  - Sets cancel_at_period_end flag
  - User retains access until period end
  - No refund issued
  - Returns cancellation date
- Add comprehensive tests for both operations
  - Test downgrade scheduling and no-charge behavior
  - Test cancellation scheduling and no-refund behavior
  - Test error cases (no subscription, invalid tier)

All 12 subscription tests passing.
parent 00c73c80
......@@ -369,3 +369,126 @@ class SubscriptionManager:
import traceback
logger.error(traceback.format_exc())
return {'success': False, 'error': str(e)}
async def downgrade_subscription(self, user_id: int, new_tier_id: int) -> dict:
"""Schedule subscription downgrade at period end (no immediate charge/refund)"""
try:
# Get current subscription
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT s.id, s.user_id, s.tier_id, s.status,
s.current_period_end
FROM subscriptions s
WHERE s.user_id = ? AND s.status = 'active'
""", (user_id,))
sub_row = cursor.fetchone()
if not sub_row:
return {'success': False, 'error': 'No active subscription'}
subscription = {
'id': sub_row[0],
'user_id': sub_row[1],
'tier_id': sub_row[2],
'status': sub_row[3],
'current_period_end': sub_row[4]
}
# Verify new tier exists
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT id, name FROM account_tiers WHERE id = ?",
(new_tier_id,)
)
tier_row = cursor.fetchone()
if not tier_row:
return {'success': False, 'error': 'Invalid tier'}
new_tier = {
'id': tier_row[0],
'name': tier_row[1]
}
# Set pending_tier_id (downgrade scheduled at period end)
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
UPDATE subscriptions
SET pending_tier_id = ?
WHERE id = ?
""", (new_tier_id, subscription['id']))
conn.commit()
# Parse period_end if needed
period_end = subscription['current_period_end']
if isinstance(period_end, str):
period_end = datetime.fromisoformat(period_end)
logger.info(f"Scheduled downgrade for subscription {subscription['id']} to tier {new_tier_id} at {period_end}")
return {
'success': True,
'downgrade_date': period_end.isoformat(),
'message': f"Downgrade to {new_tier['name']} scheduled for {period_end.date()}"
}
except Exception as e:
logger.error(f"Error scheduling downgrade: {e}")
import traceback
logger.error(traceback.format_exc())
return {'success': False, 'error': str(e)}
async def cancel_subscription(self, user_id: int) -> dict:
"""Cancel subscription at period end (no refund, access until period end)"""
try:
# Get current subscription
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT s.id, s.user_id, s.status, s.current_period_end
FROM subscriptions s
WHERE s.user_id = ? AND s.status = 'active'
""", (user_id,))
sub_row = cursor.fetchone()
if not sub_row:
return {'success': False, 'error': 'No active subscription'}
subscription = {
'id': sub_row[0],
'user_id': sub_row[1],
'status': sub_row[2],
'current_period_end': sub_row[3]
}
# Set cancel_at_period_end flag
with self.db._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
UPDATE subscriptions
SET cancel_at_period_end = 1
WHERE id = ?
""", (subscription['id'],))
conn.commit()
# Parse period_end if needed
period_end = subscription['current_period_end']
if isinstance(period_end, str):
period_end = datetime.fromisoformat(period_end)
logger.info(f"Scheduled cancellation for subscription {subscription['id']} at {period_end}")
return {
'success': True,
'cancellation_date': period_end.isoformat(),
'message': f"Subscription will be canceled on {period_end.date()}. You will retain access until then."
}
except Exception as e:
logger.error(f"Error canceling subscription: {e}")
import traceback
logger.error(traceback.format_exc())
return {'success': False, 'error': str(e)}
......@@ -254,3 +254,240 @@ async def test_upgrade_subscription_invalid_tier(db_manager):
assert result['success'] == False
assert 'Invalid tier' in result['error']
@pytest.mark.anyio
async def test_downgrade_subscription(db_manager):
"""Test subscription downgrade scheduled at period end"""
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
return {'success': True, 'transaction_id': 'test_tx'}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, None, None
)
# Create Premium tier
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO account_tiers (name, price_monthly, price_yearly, is_default)
VALUES ('Premium', 20.00, 200.00, 0)
""")
conn.commit()
premium_tier_id = cursor.lastrowid
# Create subscription at Premium tier
result = await manager.create_subscription(
user_id=1,
tier_id=premium_tier_id,
payment_method_id=1,
billing_cycle='monthly'
)
assert result['success'] == True
# Downgrade to Pro tier (scheduled at period end)
result = await manager.downgrade_subscription(user_id=1, new_tier_id=db_manager._test_pro_tier_id)
assert result['success'] == True
assert 'downgrade_date' in result
# Verify pending_tier_id is set
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT tier_id, pending_tier_id FROM subscriptions
WHERE user_id = 1 AND status = 'active'
""")
row = cursor.fetchone()
assert row[0] == premium_tier_id # Still on Premium
assert row[1] == db_manager._test_pro_tier_id # Downgrade scheduled
# Verify user still has Premium tier
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT tier_id FROM users WHERE id = 1")
row = cursor.fetchone()
assert row[0] == premium_tier_id
@pytest.mark.anyio
async def test_downgrade_subscription_no_charge(db_manager):
"""Test downgrade does not charge immediately"""
charge_called = False
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
nonlocal charge_called
charge_called = True
return {'success': True, 'transaction_id': 'test_tx'}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, None, None
)
# Create Premium tier
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO account_tiers (name, price_monthly, price_yearly, is_default)
VALUES ('Premium', 20.00, 200.00, 0)
""")
conn.commit()
premium_tier_id = cursor.lastrowid
# Create subscription
result = await manager.create_subscription(
user_id=1,
tier_id=premium_tier_id,
payment_method_id=1,
billing_cycle='monthly'
)
assert result['success'] == True
# Reset charge flag
charge_called = False
# Downgrade
result = await manager.downgrade_subscription(user_id=1, new_tier_id=db_manager._test_pro_tier_id)
assert result['success'] == True
assert charge_called == False # No charge should occur
@pytest.mark.anyio
async def test_downgrade_subscription_no_active_subscription(db_manager):
"""Test downgrade fails when no active subscription exists"""
manager = SubscriptionManager(db_manager, None, None, None, None)
result = await manager.downgrade_subscription(user_id=1, new_tier_id=db_manager._test_pro_tier_id)
assert result['success'] == False
assert 'No active subscription' in result['error']
@pytest.mark.anyio
async def test_downgrade_subscription_invalid_tier(db_manager):
"""Test downgrade fails with invalid tier ID"""
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
return {'success': True, 'transaction_id': 'test_tx'}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, None, None
)
# Create subscription
result = await manager.create_subscription(
user_id=1,
tier_id=db_manager._test_pro_tier_id,
payment_method_id=1,
billing_cycle='monthly'
)
assert result['success'] == True
# Try to downgrade to non-existent tier
result = await manager.downgrade_subscription(user_id=1, new_tier_id=99999)
assert result['success'] == False
assert 'Invalid tier' in result['error']
@pytest.mark.anyio
async def test_cancel_subscription(db_manager):
"""Test subscription cancellation at period end"""
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
return {'success': True, 'transaction_id': 'test_tx'}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, None, None
)
# Create subscription
result = await manager.create_subscription(
user_id=1,
tier_id=db_manager._test_pro_tier_id,
payment_method_id=1,
billing_cycle='monthly'
)
assert result['success'] == True
# Cancel subscription
result = await manager.cancel_subscription(user_id=1)
assert result['success'] == True
assert 'cancellation_date' in result
# Verify cancel_at_period_end is set
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT status, cancel_at_period_end FROM subscriptions
WHERE user_id = 1
""")
row = cursor.fetchone()
assert row[0] == 'active' # Still active
assert row[1] == 1 # Cancel scheduled
# Verify user still has Pro tier
with db_manager._get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT tier_id FROM users WHERE id = 1")
row = cursor.fetchone()
assert row[0] == db_manager._test_pro_tier_id
@pytest.mark.anyio
async def test_cancel_subscription_no_refund(db_manager):
"""Test cancellation does not issue refund"""
refund_called = False
class MockStripeHandler:
async def charge_subscription(self, subscription_id, amount):
return {'success': True, 'transaction_id': 'test_tx'}
async def refund_payment(self, transaction_id, amount):
nonlocal refund_called
refund_called = True
return {'success': True}
manager = SubscriptionManager(
db_manager,
MockStripeHandler(),
None, None, None
)
# Create subscription
result = await manager.create_subscription(
user_id=1,
tier_id=db_manager._test_pro_tier_id,
payment_method_id=1,
billing_cycle='monthly'
)
assert result['success'] == True
# Cancel subscription
result = await manager.cancel_subscription(user_id=1)
assert result['success'] == True
assert refund_called == False # No refund should occur
@pytest.mark.anyio
async def test_cancel_subscription_no_active_subscription(db_manager):
"""Test cancellation fails when no active subscription exists"""
manager = SubscriptionManager(db_manager, None, None, None, None)
result = await manager.cancel_subscription(user_id=1)
assert result['success'] == False
assert 'No active subscription' in result['error']
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment