"""Tests for Stripe Checkout Link billing integration.

Covers:
- Checkout URL generation for pro/team tiers
- Invalid tier handling
- Missing checkout URL handling
- Billing status endpoint
- Webhook handling (checkout.session.completed, subscription events)
- Webhook signature verification (invalid signatures are rejected)
- Tier determination from payment amounts
"""

import hashlib
import hmac
import json
import time
import uuid

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.database import Base
from app.dependencies import get_db
from app.auth import create_access_token, hash_password
from app.main import app
from app.models.saas_models import Organization, OrganizationMember, User


# ── Test database setup ──────────────────────────────────────────────────────

TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
TestSessionLocal = async_sessionmaker(
    test_engine, class_=AsyncSession, expire_on_commit=False
)


@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database():
    """Create all tables once for the test session and drop them afterwards."""
    # Imported for their side effect of registering model tables on
    # Base.metadata before create_all runs.
    import app.models.saas_models  # noqa: F401
    import app.models.models  # noqa: F401

    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


@pytest_asyncio.fixture
async def db_session():
    """Provide a database session for a single test.

    Uncommitted work is rolled back at teardown; anything committed during the
    test is not undone here.
    """
    async with TestSessionLocal() as session:
        yield session
        await session.rollback()


@pytest_asyncio.fixture
async def client(db_session: AsyncSession):
    """Provide an HTTP client whose ``get_db`` dependency yields ``db_session``.

    Because endpoints and the test share one session, tests can query rows the
    endpoints wrote without needing a commit.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
    app.dependency_overrides.clear()


async def _create_user_and_org(db: AsyncSession, tier: str = "free"):
    """Create a user, an organization on ``tier`` owned by that user, and a token.

    Rows are flushed (not committed) on ``db``. A single random suffix is used
    for both the email and the slug; a clock-derived suffix could repeat when
    two calls land within the same timer tick and collide on unique columns.

    Args:
        db: Session to add the rows to.
        tier: Tier assigned to the new organization.

    Returns:
        ``(user, org, token)`` where ``token`` is an access token for ``user``.
    """
    suffix = uuid.uuid4().hex[:8]

    user = User(
        email=f"test-{tier}-{suffix}@example.com",
        password_hash=hash_password("testpass123"),
    )
    db.add(user)
    await db.flush()  # populate user.id before it is referenced below

    org = Organization(
        name=f"Test Org {tier}",
        slug=f"test-org-{tier}-{suffix}",
        tier=tier,
    )
    db.add(org)
    await db.flush()  # populate org.id before it is referenced below

    membership = OrganizationMember(
        organization_id=org.id,
        user_id=user.id,
        role="owner",
    )
    db.add(membership)
    await db.flush()

    token = create_access_token(user.id)
    return user, org, token


# ── Checkout endpoint tests ──────────────────────────────────────────────────


class TestCheckoutEndpoint:
    """Tests for GET /api/v1/billing/checkout/{tier}"""

    @pytest.mark.asyncio
    async def test_checkout_pro_returns_redirect_url(self, client, db_session, monkeypatch):
        """Pro checkout should return a redirect URL with org ID and email."""
        from app.config import settings

        # monkeypatch restores the original value at teardown, even on failure.
        monkeypatch.setattr(
            settings, "stripe_pro_checkout_url", "https://buy.stripe.com/test_pro"
        )
        user, org, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/pro",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert "redirect_url" in data
        assert data["tier"] == "pro"
        assert data["organization_id"] == org.id
        assert "client_reference_id" in data["redirect_url"]
        assert "prefilled_email" in data["redirect_url"]
        assert data["redirect_url"].startswith("https://buy.stripe.com/test_pro?")

    @pytest.mark.asyncio
    async def test_checkout_team_returns_redirect_url(self, client, db_session, monkeypatch):
        """Team checkout should return a redirect URL."""
        from app.config import settings

        monkeypatch.setattr(
            settings, "stripe_team_checkout_url", "https://buy.stripe.com/test_team"
        )
        user, org, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/team",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "team"
        assert data["organization_id"] == org.id

    @pytest.mark.asyncio
    async def test_checkout_invalid_tier(self, client, db_session):
        """Invalid tier should return 400."""
        user, org, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/enterprise",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 400
        assert "Invalid tier" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_checkout_free_tier_invalid(self, client, db_session):
        """Free tier checkout should return 400 (can't upgrade to free)."""
        user, org, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/free",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 400

    @pytest.mark.asyncio
    async def test_checkout_missing_url_returns_503(self, client, db_session, monkeypatch):
        """If checkout URL is not configured, return 503."""
        from app.config import settings

        monkeypatch.setattr(settings, "stripe_pro_checkout_url", "")  # empty = not configured
        user, org, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/pro",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 503
        assert "not configured" in response.json()["detail"].lower()

    @pytest.mark.asyncio
    async def test_checkout_requires_auth(self, client, db_session):
        """Checkout endpoint requires authentication."""
        response = await client.get("/api/v1/billing/checkout/pro")
        assert response.status_code == 401


# ── Billing status endpoint tests ─────────────────────────────────────────────


class TestBillingStatusEndpoint:
    """Tests for GET /api/v1/billing/status"""

    @pytest.mark.asyncio
    async def test_billing_status_free_org(self, client, db_session, monkeypatch):
        """Free org should see pro and team checkout URLs."""
        from app.config import settings

        monkeypatch.setattr(settings, "stripe_pro_checkout_url", "https://buy.stripe.com/pro")
        monkeypatch.setattr(settings, "stripe_team_checkout_url", "https://buy.stripe.com/team")
        user, org, token = await _create_user_and_org(db_session, tier="free")

        response = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "free"
        assert "tier_info" in data
        assert "pro" in data["checkout_urls"]
        assert "team" in data["checkout_urls"]

    @pytest.mark.asyncio
    async def test_billing_status_pro_org(self, client, db_session, monkeypatch):
        """Pro org should only see team upgrade URL."""
        from app.config import settings

        monkeypatch.setattr(settings, "stripe_team_checkout_url", "https://buy.stripe.com/team")
        user, org, token = await _create_user_and_org(db_session, tier="pro")

        response = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "pro"
        assert "pro" not in data["checkout_urls"]
        assert "team" in data["checkout_urls"]

    @pytest.mark.asyncio
    async def test_billing_status_team_org(self, client, db_session):
        """Team org should see no upgrade URLs (already at top tier)."""
        user, org, token = await _create_user_and_org(db_session, tier="team")

        response = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "team"
        assert data["checkout_urls"] == {}

    @pytest.mark.asyncio
    async def test_billing_status_shows_tier_limits(self, client, db_session):
        """Billing status should include tier limits from tier_limits module."""
        user, org, token = await _create_user_and_org(db_session, tier="free")

        response = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert "tier_info" in data
        assert data["tier_info"]["tier"] == "free"
        assert "limits" in data["tier_info"]

    @pytest.mark.asyncio
    async def test_billing_status_requires_auth(self, client, db_session):
        """Billing status endpoint requires authentication."""
        response = await client.get("/api/v1/billing/status")
        assert response.status_code == 401


# ── Success/cancel callback tests ────────────────────────────────────────────


class TestBillingCallbacks:
    """Tests for GET /api/v1/billing/success and /cancel"""

    @pytest.mark.asyncio
    async def test_success_callback(self, client, db_session):
        """Success callback should return success message."""
        response = await client.get(
            "/api/v1/billing/success?session_id=cs_test_123"
        )
        assert response.status_code == 200
        data = response.json()
        assert "activated" in data["message"].lower() or "upgraded" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_success_callback_without_session_id(self, client, db_session):
        """Success callback works even without session_id (webhook already processed)."""
        response = await client.get("/api/v1/billing/success")
        assert response.status_code == 200

    @pytest.mark.asyncio
    async def test_cancel_callback(self, client, db_session):
        """Cancel callback should return cancellation message."""
        response = await client.get("/api/v1/billing/cancel")
        assert response.status_code == 200
        data = response.json()
        assert "canceled" in data["message"].lower()


# ── Webhook tests ─────────────────────────────────────────────────────────────
class TestWebhookHandler:
    """Tests for POST /api/v1/billing/webhook"""

    @staticmethod
    def _disable_signature_verification(monkeypatch) -> None:
        """Clear the Stripe webhook secret and API key for the current test.

        With both empty the endpoint is expected to skip signature checks, so
        tests can post plain JSON events. Clearing both in every webhook test
        keeps them independent of whatever the environment configures.
        ``monkeypatch`` restores the original values at teardown, even if the
        request or an assertion fails.
        """
        from app.config import settings

        monkeypatch.setattr(settings, "stripe_webhook_secret", "")
        monkeypatch.setattr(settings, "stripe_api_key", "")

    @staticmethod
    async def _reload_org(db_session: AsyncSession, org_id):
        """Fetch the organization with ``org_id`` through the test session."""
        result = await db_session.execute(
            select(Organization).where(Organization.id == org_id)
        )
        return result.scalar_one()

    def _make_checkout_session_event(
        self, org_id: str, amount_total: int = 900, customer_id: str = "cus_test123"
    ) -> dict:
        """Create a mock checkout.session.completed event."""
        return {
            "id": "evt_test_123",
            "object": "event",
            "type": "checkout.session.completed",
            "data": {
                "object": {
                    "id": "cs_test_123",
                    "object": "checkout.session",
                    "client_reference_id": org_id,
                    "customer": customer_id,
                    "amount_total": amount_total,
                    "line_items": {"data": []},
                }
            },
        }

    def _make_subscription_event(
        self, customer_id: str, event_type: str, amount: int = 900
    ) -> dict:
        """Create a mock subscription event."""
        return {
            "id": "evt_test_sub_123",
            "object": "event",
            "type": event_type,
            "data": {
                "object": {
                    "id": "sub_test_123",
                    "object": "subscription",
                    "customer": customer_id,
                    "metadata": {},
                    "items": {
                        "data": [
                            {
                                "plan": {"amount": amount},
                            }
                        ]
                    },
                }
            },
        }

    @pytest.mark.asyncio
    async def test_webhook_checkout_session_pro(self, client, db_session, monkeypatch):
        """checkout.session.completed with $9 amount should set tier to pro."""
        self._disable_signature_verification(monkeypatch)
        user, org, token = await _create_user_and_org(db_session, tier="free")
        event = self._make_checkout_session_event(org.id, amount_total=900)

        response = await client.post("/api/v1/billing/webhook", json=event)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ok"
        assert data["event_type"] == "checkout.session.completed"

        # Verify org was upgraded
        updated_org = await self._reload_org(db_session, org.id)
        assert updated_org.tier == "pro"
        assert updated_org.stripe_customer_id == "cus_test123"

    @pytest.mark.asyncio
    async def test_webhook_checkout_session_team(self, client, db_session, monkeypatch):
        """checkout.session.completed with $29 amount should set tier to team."""
        self._disable_signature_verification(monkeypatch)
        user, org, token = await _create_user_and_org(db_session, tier="free")
        event = self._make_checkout_session_event(org.id, amount_total=2900)

        response = await client.post("/api/v1/billing/webhook", json=event)

        assert response.status_code == 200
        updated_org = await self._reload_org(db_session, org.id)
        assert updated_org.tier == "team"

    @pytest.mark.asyncio
    async def test_webhook_subscription_deleted_downgrades_to_free(
        self, client, db_session, monkeypatch
    ):
        """customer.subscription.deleted should downgrade org to free."""
        self._disable_signature_verification(monkeypatch)
        user, org, token = await _create_user_and_org(db_session, tier="pro")
        # Link the org to the Stripe customer the event refers to.
        org.stripe_customer_id = "cus_delete_test"
        await db_session.flush()

        event = self._make_subscription_event(
            customer_id="cus_delete_test",
            event_type="customer.subscription.deleted",
        )
        response = await client.post("/api/v1/billing/webhook", json=event)

        assert response.status_code == 200
        updated_org = await self._reload_org(db_session, org.id)
        assert updated_org.tier == "free"

    @pytest.mark.asyncio
    async def test_webhook_subscription_updated(self, client, db_session, monkeypatch):
        """customer.subscription.updated should update org tier based on amount."""
        self._disable_signature_verification(monkeypatch)
        user, org, token = await _create_user_and_org(db_session, tier="pro")
        org.stripe_customer_id = "cus_update_test"
        await db_session.flush()

        event = self._make_subscription_event(
            customer_id="cus_update_test",
            event_type="customer.subscription.updated",
            amount=2900,  # Upgrade to team
        )
        response = await client.post("/api/v1/billing/webhook", json=event)

        assert response.status_code == 200
        updated_org = await self._reload_org(db_session, org.id)
        assert updated_org.tier == "team"

    @pytest.mark.asyncio
    async def test_webhook_ignores_unhandled_events(self, client, db_session, monkeypatch):
        """Unhandled event types should be ignored."""
        self._disable_signature_verification(monkeypatch)
        event = {
            "id": "evt_test_ignore",
            "type": "account.updated",
            "data": {"object": {}},
        }

        response = await client.post("/api/v1/billing/webhook", json=event)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ignored"

    @pytest.mark.asyncio
    async def test_webhook_invalid_signature_rejected(self, client, db_session, monkeypatch):
        """Webhook with invalid signature should be rejected."""
        from app.config import settings

        # A webhook secret with no API key forces manual signature verification.
        monkeypatch.setattr(settings, "stripe_webhook_secret", "whsec_test_secret_key_12345")
        monkeypatch.setattr(settings, "stripe_api_key", "")
        event = self._make_checkout_session_event("fake_org_id")
        # Use a current timestamp so that, if the handler also enforces a
        # timestamp tolerance, the rejection is still due to the bad v1 value.
        signature_header = f"t={int(time.time())},v1=invalid_signature_here"

        response = await client.post(
            "/api/v1/billing/webhook",
            json=event,
            headers={"stripe-signature": signature_header},
        )

        assert response.status_code == 400
        assert "signature" in response.json()["detail"].lower()

    @pytest.mark.asyncio
    async def test_webhook_checkout_missing_org_id(self, client, db_session, monkeypatch):
        """Webhook without client_reference_id should return error."""
        self._disable_signature_verification(monkeypatch)
        event = {
            "id": "evt_test_no_org",
            "type": "checkout.session.completed",
            "data": {
                "object": {
                    "id": "cs_test_no_org",
                    "client_reference_id": None,
                    "customer": "cus_123",
                    "amount_total": 900,
                }
            },
        }

        response = await client.post("/api/v1/billing/webhook", json=event)

        assert response.status_code == 200
        assert response.json()["status"] == "error"

    @pytest.mark.asyncio
    async def test_webhook_invoice_payment_succeeded(self, client, db_session, monkeypatch):
        """invoice.payment_succeeded should acknowledge without downgrading."""
        self._disable_signature_verification(monkeypatch)
        user, org, token = await _create_user_and_org(db_session, tier="pro")
        org.stripe_customer_id = "cus_invoice_test"
        await db_session.flush()

        event = {
            "id": "evt_invoice_paid",
            "type": "invoice.payment_succeeded",
            "data": {
                "object": {
                    "id": "in_test_123",
                    "customer": "cus_invoice_test",
                }
            },
        }
        response = await client.post("/api/v1/billing/webhook", json=event)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ok"

        # Verify tier wasn't changed
        updated_org = await self._reload_org(db_session, org.id)
        assert updated_org.tier == "pro"

    @pytest.mark.asyncio
    async def test_webhook_invoice_payment_failed(self, client, db_session, monkeypatch):
        """invoice.payment_failed should not downgrade (Stripe retries)."""
        self._disable_signature_verification(monkeypatch)
        user, org, token = await _create_user_and_org(db_session, tier="pro")
        org.stripe_customer_id = "cus_fail_test"
        await db_session.flush()

        event = {
            "id": "evt_invoice_failed",
            "type": "invoice.payment_failed",
            "data": {
                "object": {
                    "id": "in_test_fail",
                    "customer": "cus_fail_test",
                }
            },
        }
        response = await client.post("/api/v1/billing/webhook", json=event)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ok"

        # Verify tier wasn't changed (payment will be retried)
        updated_org = await self._reload_org(db_session, org.id)
        assert updated_org.tier == "pro"


# ── Tier determination helper tests ───────────────────────────────────────────


class TestTierDetermination:
    """Tests for _determine_tier_from_amount helper."""

    def test_pro_from_amount(self):
        """Amounts at the pro price (and below the team price) map to pro."""
        from app.api.billing import _determine_tier_from_amount

        assert _determine_tier_from_amount(900) == "pro"
        assert _determine_tier_from_amount(1500) == "pro"

    def test_team_from_amount(self):
        """Amounts at or above the team price map to team."""
        from app.api.billing import _determine_tier_from_amount

        assert _determine_tier_from_amount(2900) == "team"
        assert _determine_tier_from_amount(5000) == "team"
        assert _determine_tier_from_amount(9900) == "team"

    def test_free_from_small_amount(self):
        """Amounts below the pro price map to free."""
        from app.api.billing import _determine_tier_from_amount

        assert _determine_tier_from_amount(0) == "free"
        assert _determine_tier_from_amount(100) == "free"


# ── Integration: checkout → webhook → tier update ────────────────────────────
class TestBillingIntegration:
    """End-to-end style tests for the billing flow."""

    @pytest.mark.asyncio
    async def test_free_to_pro_upgrade_flow(self, client, db_session, monkeypatch):
        """Complete flow: free org → checkout → webhook → pro tier."""
        from urllib.parse import parse_qs, urlsplit

        from app.config import settings

        # monkeypatch restores these at teardown even if an assertion below
        # fails. Restoring by hand after the last assertion would leak the test
        # values into every later test whenever this one failed.
        monkeypatch.setattr(settings, "stripe_webhook_secret", "")
        monkeypatch.setattr(settings, "stripe_api_key", "")
        monkeypatch.setattr(
            settings, "stripe_pro_checkout_url", "https://buy.stripe.com/test_pro_link"
        )
        monkeypatch.setattr(
            settings, "stripe_team_checkout_url", "https://buy.stripe.com/test_team_link"
        )

        user, org, token = await _create_user_and_org(db_session, tier="free")
        auth_headers = {"Authorization": f"Bearer {token}"}

        # Step 1: Get checkout URL
        checkout_resp = await client.get(
            "/api/v1/billing/checkout/pro",
            headers=auth_headers,
        )
        assert checkout_resp.status_code == 200
        redirect_url = checkout_resp.json()["redirect_url"]
        assert redirect_url.startswith("https://buy.stripe.com/test_pro_link")
        # Compare the decoded query value exactly instead of a substring match.
        query = parse_qs(urlsplit(redirect_url).query)
        assert query.get("client_reference_id") == [str(org.id)]

        # Step 2: Simulate Stripe webhook
        webhook_event = {
            "id": "evt_integration_test",
            "type": "checkout.session.completed",
            "data": {
                "object": {
                    "id": "cs_integration_test",
                    "client_reference_id": org.id,
                    "customer": "cus_integration_test",
                    "amount_total": 900,
                    "line_items": {"data": []},
                }
            },
        }
        webhook_resp = await client.post(
            "/api/v1/billing/webhook",
            json=webhook_event,
        )
        assert webhook_resp.status_code == 200

        # Step 3: Verify org is now pro
        status_resp = await client.get(
            "/api/v1/billing/status",
            headers=auth_headers,
        )
        assert status_resp.status_code == 200
        status_data = status_resp.json()
        assert status_data["tier"] == "pro"

        # Step 4: Verify billing status shows team upgrade available
        assert "team" in status_data["checkout_urls"]

    @pytest.mark.asyncio
    async def test_pro_to_free_downgrade_via_webhook(self, client, db_session, monkeypatch):
        """Complete flow: pro org → subscription deleted → free tier."""
        from app.config import settings

        monkeypatch.setattr(settings, "stripe_webhook_secret", "")
        monkeypatch.setattr(settings, "stripe_api_key", "")

        user, org, token = await _create_user_and_org(db_session, tier="pro")
        org.stripe_customer_id = "cus_downgrade_test"
        await db_session.flush()

        # Simulate subscription cancellation
        event = {
            "id": "evt_cancel_test",
            "type": "customer.subscription.deleted",
            "data": {
                "object": {
                    "id": "sub_cancel_test",
                    "customer": "cus_downgrade_test",
                }
            },
        }
        response = await client.post(
            "/api/v1/billing/webhook",
            json=event,
        )
        assert response.status_code == 200

        # Verify org is now free
        result = await db_session.execute(
            select(Organization).where(Organization.id == org.id)
        )
        updated_org = result.scalar_one()
        assert updated_org.tier == "free"

        # Verify billing status reflects the downgrade
        status_resp = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert status_resp.status_code == 200
        assert status_resp.json()["tier"] == "free"

    @pytest.mark.asyncio
    async def test_checkout_url_includes_prefilled_email(self, client, db_session, monkeypatch):
        """Checkout URL should include the user's email as prefilled_email."""
        from urllib.parse import parse_qs, urlsplit

        from app.config import settings

        monkeypatch.setattr(
            settings, "stripe_pro_checkout_url", "https://buy.stripe.com/test_pro"
        )
        user, org, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/pro",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        query = parse_qs(urlsplit(response.json()["redirect_url"]).query)
        # parse_qs percent-decodes values, so this matches whether or not the
        # endpoint encoded "@" as "%40".
        assert query.get("prefilled_email") == [user.email]