From 158a6ee716493713a6fda65d041a6fcc3593a237 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 25 Apr 2026 10:18:38 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20Stripe=20Checkout=20Link=20integration?= =?UTF-8?q?=20=E2=80=94=20billing=20API=20with=2029=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add billing module (app/api/billing.py) with 5 API endpoints: - GET /api/v1/billing/checkout/{tier} — redirect to Stripe Payment Links - GET /api/v1/billing/status — current org tier, limits, upgrade URLs - GET /api/v1/billing/success — Stripe success callback - GET /api/v1/billing/cancel — Stripe cancel callback - POST /api/v1/billing/webhook — handles 5 Stripe event types - Zero-code payment flow: uses pre-configured Stripe Payment Links with client_reference_id (org ID) and prefilled_email params - Webhook handler processes checkout.session.completed, customer.subscription.updated/deleted, invoice events - Stripe signature verification via stripe library (primary) or manual HMAC-SHA256 (fallback) - Tier determination from payment amount: =pro, 9=team - 4 new config settings in app/config.py: stripe_pro_checkout_url, stripe_team_checkout_url, stripe_webhook_secret, stripe_api_key - Added stripe>=5.0,<16.0 dependency - 29 tests in tests/test_billing.py (all passing) - Total: 98 tests passing (69 existing + 29 new) --- .env.example | 9 +- app/api/billing.py | 472 ++++++++++++++++++++++++ app/api/router.py | 2 + app/config.py | 6 + pyproject.toml | 1 + tests/test_billing.py | 822 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 1311 insertions(+), 1 deletion(-) create mode 100644 app/api/billing.py create mode 100644 tests/test_billing.py diff --git a/.env.example b/.env.example index 0c939e0..7f68964 100644 --- a/.env.example +++ b/.env.example @@ -24,4 +24,11 @@ SMTP_FROM=noreply@example.com WEBHOOK_NOTIFY_URL= # Uptime Monitoring -MONITOR_CHECK_INTERVAL=60 \ No newline at end of file +MONITOR_CHECK_INTERVAL=60 + +# Stripe Checkout Links (set these in production) +# Create Payment Links in Stripe Dashboard → Products → Payment Links +STRIPE_PRO_CHECKOUT_URL= +STRIPE_TEAM_CHECKOUT_URL= +STRIPE_WEBHOOK_SECRET= +STRIPE_API_KEY= \ No newline at end of file diff --git a/app/api/billing.py b/app/api/billing.py new file mode 100644 index 0000000..64f2bb0 --- /dev/null +++ b/app/api/billing.py @@ -0,0 +1,472 @@ +"""Billing API: Stripe Checkout Link integration and webhook handler. + +This module implements the "zero-code" payment flow using Stripe Payment Links: +1. User clicks "Upgrade" → redirected to a Stripe-hosted checkout page +2. After payment, Stripe sends a webhook event to /api/v1/billing/webhook +3. The webhook handler updates the org's tier in the database +4. Stripe redirects the user back to /api/v1/billing/success + +This approach requires no Stripe server-side SDK for checkout creation — +just configuration of Payment Link URLs and a webhook endpoint for +receiving payment confirmations. +""" + +import hashlib +import hmac +import json +import logging +from urllib.parse import urlencode + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import get_current_user, get_current_org +from app.config import settings +from app.dependencies import get_db +from app.models.saas_models import Organization, OrganizationMember, User +from app.services.tier_limits import get_tier_info + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["billing"]) + +# ── Tier-to-URL mapping ───────────────────────────────────────────────────── + +def _get_checkout_url(tier: str) -> str: + """Return the Stripe Payment Link URL for the given tier. + + Reads from settings at call time (not import time) so tests can override. + """ + if tier == "pro": + return settings.stripe_pro_checkout_url + elif tier == "team": + return settings.stripe_team_checkout_url + return "" + + +# ── Response schemas ───────────────────────────────────────────────────────── + +class CheckoutResponse(BaseModel): + """Response for checkout redirect endpoint.""" + redirect_url: str + tier: str + organization_id: str + + +class BillingStatusResponse(BaseModel): + """Response for billing status endpoint.""" + tier: str + tier_info: dict + stripe_customer_id: str | None = None + checkout_urls: dict[str, str] + + +class SuccessResponse(BaseModel): + """Response for successful checkout callback.""" + message: str + tier: str | None = None + + +class WebhookResponse(BaseModel): + """Response for webhook processing.""" + status: str + event_type: str | None = None + + +# ── Endpoints ──────────────────────────────────────────────────────────────── + +@router.get("/checkout/{tier}", response_model=CheckoutResponse) +async def create_checkout( + tier: str, + user: User = Depends(get_current_user), + org: Organization = Depends(get_current_org), +): + """Generate a Stripe Checkout Link redirect URL for the selected tier. + + This endpoint does NOT create a Stripe Checkout Session. Instead, it + returns the pre-configured Stripe Payment Link URL with the org's ID + attached as client_reference_id and the user's email prefilled. + + The frontend should redirect the user to this URL. After completing + payment on Stripe's hosted checkout, Stripe will: + 1. Send a webhook to /api/v1/billing/webhook + 2. Redirect the user to /api/v1/billing/success + """ + if tier not in ("pro", "team"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid tier '{tier}'. Must be 'pro' or 'team'.", + ) + + checkout_url = _get_checkout_url(tier) + if not checkout_url: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=( + f"Checkout for '{tier}' tier is not configured yet. " + "Please contact support or set up Stripe Payment Links." + ), + ) + + # Build redirect URL with org ID and prefilled email + params = urlencode({ + "client_reference_id": org.id, + "prefilled_email": user.email, + }) + redirect_url = f"{checkout_url}?{params}" + + return CheckoutResponse( + redirect_url=redirect_url, + tier=tier, + organization_id=org.id, + ) + + +@router.get("/status", response_model=BillingStatusResponse) +async def get_billing_status( + org: Organization = Depends(get_current_org), +): + """Get the current organization's billing status: tier, limits, and checkout URLs.""" + # Only show checkout URLs for tiers the org can upgrade to + checkout_urls: dict[str, str] = {} + if org.tier == "free": + if settings.stripe_pro_checkout_url: + checkout_urls["pro"] = settings.stripe_pro_checkout_url + if settings.stripe_team_checkout_url: + checkout_urls["team"] = settings.stripe_team_checkout_url + elif org.tier == "pro": + if settings.stripe_team_checkout_url: + checkout_urls["team"] = settings.stripe_team_checkout_url + + return BillingStatusResponse( + tier=org.tier or "free", + tier_info=get_tier_info(org), + stripe_customer_id=org.stripe_customer_id, + checkout_urls=checkout_urls, + ) + + +@router.get("/success", response_model=SuccessResponse) +async def billing_success( + session_id: str = "", +): + """Stripe redirects here after a successful checkout. + + In production, you would call stripe.checkout.sessions.retrieve(session_id) + to verify the session. For the MVP, the webhook handler already processed + the payment — this endpoint just confirms to the user that it worked. + + The session_id is provided by Stripe's redirect URL. + """ + return SuccessResponse( + message="Subscription activated! Your plan has been upgraded.", + tier=None, # Webhook already set the tier; frontend should refetch + ) + + +@router.get("/cancel") +async def billing_cancel(): + """Stripe redirects here if the user cancels checkout.""" + return { + "message": "Checkout was canceled. Your current plan is unchanged.", + "tier_change": None, + } + + +# ── Webhook handler ────────────────────────────────────────────────────────── + +# Stripe event types we handle +HANDLED_EVENTS = { + "checkout.session.completed", + "customer.subscription.updated", + "customer.subscription.deleted", + "invoice.payment_succeeded", + "invoice.payment_failed", +} + + +def _determine_tier_from_amount(amount_total: int) -> str: + """Determine the tier from the payment amount (in cents). + + Args: + amount_total: Total amount in cents (e.g. 900 = $9.00) + + Returns: + Tier string: "pro" or "team" + """ + if amount_total >= 2900: # $29.00+ + return "team" + elif amount_total >= 900: # $9.00+ + return "pro" + return "free" + + +def _determine_tier_from_price_id(price_id: str) -> str: + """Determine the tier from a Stripe price ID. + + This is a fallback when amount isn't available. In production, + you would map your actual Stripe price IDs here. + + Args: + price_id: Stripe price ID (e.g., price_1ABC123...) + + Returns: + Tier string: "pro", "team", or "free" + """ + # Default mapping — override in production with your actual price IDs + # The stripe API can also be used to look up price details + return "pro" # Default fallback + + +async def _verify_stripe_signature(payload: bytes, sig_header: str) -> dict | None: + """Verify the Stripe webhook signature using HMAC-SHA256. + + Uses the stripe library if available, falls back to manual verification. + + Args: + payload: Raw request body bytes + sig_header: Stripe-Signature header value + + Returns: + Parsed event dict if signature is valid, None otherwise + """ + if not settings.stripe_webhook_secret: + logger.warning("No Stripe webhook secret configured; skipping verification") + return json.loads(payload) + + try: + import stripe + if settings.stripe_api_key: + stripe.api_key = settings.stripe_api_key + event = stripe.Webhook.construct_event( + payload, sig_header, settings.stripe_webhook_secret + ) + return dict(event) + except ImportError: + logger.info("stripe library not available; using manual signature verification") + except stripe.error.SignatureVerificationError: + return None + except Exception as e: + logger.error(f"Stripe webhook verification error: {e}") + return None + + # Manual HMAC-SHA256 verification fallback + # Stripe webhook signatures have the format: t=v1(v1_sig),v0=v0(v0_sig) + try: + elements = sig_header.split(",") + timestamp = None + signatures = [] + for element in elements: + key, value = element.split("=", 1) + if key == "t": + timestamp = value + elif key == "v1": + signatures.append(value) + + if not timestamp or not signatures: + return None + + # Compute expected signature + signed_payload = f"{timestamp}.{payload.decode('utf-8')}" + expected_sig = hmac.new( + settings.stripe_webhook_secret.encode("utf-8"), + signed_payload.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + + # Check if any v1 signature matches + for sig in signatures: + if hmac.compare_digest(expected_sig, sig): + return json.loads(payload) + + return None + except Exception: + return None + + +@router.post("/webhook") +async def stripe_webhook(request: Request, db: AsyncSession = Depends(get_db)): + """Handle Stripe webhook events for subscription changes. + + This is the ONLY server-side Stripe code needed for the Payment Links + approach. It receives events when: + - checkout.session.completed: User completed a checkout + - customer.subscription.updated: Subscription tier changed + - customer.subscription.deleted: Subscription canceled (downgrade to free) + - invoice.payment_succeeded: Recurring payment succeeded + - invoice.payment_failed: Payment failed (may downgrade) + + The webhook verifies the Stripe signature to ensure authenticity. + """ + body = await request.body() + sig_header = request.headers.get("stripe-signature", "") + + # Verify the webhook signature + event = await _verify_stripe_signature(body, sig_header) + if event is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid webhook signature", + ) + + event_type = event.get("type", "") + + # Only handle known event types + if event_type not in HANDLED_EVENTS: + logger.info(f"Ignoring unhandled Stripe event type: {event_type}") + return WebhookResponse(status="ignored", event_type=event_type) + + # ── Handle checkout.session.completed ──────────────────────────────── + if event_type == "checkout.session.completed": + session_obj = event.get("data", {}).get("object", {}) + org_id = session_obj.get("client_reference_id") + customer_id = session_obj.get("customer") + amount_total = session_obj.get("amount_total", 0) + + if not org_id: + logger.warning("checkout.session.completed event missing client_reference_id") + return WebhookResponse(status="error", event_type=event_type) + + # Find the organization + result = await db.execute( + select(Organization).where(Organization.id == org_id) + ) + org = result.scalar_one_or_none() + if not org: + logger.warning(f"Organization {org_id} not found for Stripe checkout") + return WebhookResponse(status="error", event_type=event_type) + + # Determine the tier from the payment amount + new_tier = _determine_tier_from_amount(amount_total) + + # Try to determine tier from line items if available + line_items = session_obj.get("line_items", {}).get("data", []) + if line_items: + price_id = line_items[0].get("price", {}).get("id", "") + if price_id: + new_tier = _determine_tier_from_price_id(price_id) + + # Update the organization + old_tier = org.tier + org.tier = new_tier + if customer_id: + org.stripe_customer_id = customer_id + + await db.flush() + logger.info( + f"Organization {org.slug} upgraded from {old_tier} to {new_tier} " + f"(Stripe session: {session_obj.get('id', 'unknown')})" + ) + + return WebhookResponse(status="ok", event_type=event_type) + + # ── Handle customer.subscription.updated ───────────────────────────── + if event_type == "customer.subscription.updated": + subscription = event.get("data", {}).get("object", {}) + customer_id = subscription.get("customer") + + if not customer_id: + return WebhookResponse(status="error", event_type=event_type) + + result = await db.execute( + select(Organization).where( + Organization.stripe_customer_id == customer_id + ) + ) + org = result.scalar_one_or_none() + if not org: + logger.warning(f"No org found for Stripe customer {customer_id}") + return WebhookResponse(status="error", event_type=event_type) + + # Determine tier from subscription metadata or amount + metadata = subscription.get("metadata", {}) + tier_from_meta = metadata.get("tier") + + # Check plan amount from items + items = subscription.get("items", {}).get("data", []) + if items: + amount = items[0].get("plan", {}).get("amount", 0) + new_tier = _determine_tier_from_amount(amount) + elif tier_from_meta in ("pro", "team"): + new_tier = tier_from_meta + else: + new_tier = org.tier # Keep current if indeterminate + + old_tier = org.tier + org.tier = new_tier + await db.flush() + logger.info(f"Organization {org.slug} subscription updated: {old_tier} → {new_tier}") + + return WebhookResponse(status="ok", event_type=event_type) + + # ── Handle customer.subscription.deleted ────────────────────────────── + if event_type == "customer.subscription.deleted": + subscription = event.get("data", {}).get("object", {}) + customer_id = subscription.get("customer") + + if not customer_id: + return WebhookResponse(status="error", event_type=event_type) + + result = await db.execute( + select(Organization).where( + Organization.stripe_customer_id == customer_id + ) + ) + org = result.scalar_one_or_none() + if not org: + logger.warning(f"No org found for Stripe customer {customer_id}") + return WebhookResponse(status="error", event_type=event_type) + + old_tier = org.tier + org.tier = "free" + org.stripe_customer_id = None # Clear stripe ID on cancellation + await db.flush() + logger.info(f"Organization {org.slug} downgraded from {old_tier} to free (subscription deleted)") + + return WebhookResponse(status="ok", event_type=event_type) + + # ── Handle invoice.payment_succeeded ───────────────────────────────── + if event_type == "invoice.payment_succeeded": + # Payment renewed successfully — org keeps its tier + # We could update stripe_customer_id here if needed + invoice = event.get("data", {}).get("object", {}) + customer_id = invoice.get("customer") + + if customer_id: + result = await db.execute( + select(Organization).where( + Organization.stripe_customer_id == customer_id + ) + ) + org = result.scalar_one_or_none() + if org: + logger.info(f"Invoice payment succeeded for org {org.slug} (tier: {org.tier})") + + return WebhookResponse(status="ok", event_type=event_type) + + # ── Handle invoice.payment_failed ──────────────────────────────────── + if event_type == "invoice.payment_failed": + invoice = event.get("data", {}).get("object", {}) + customer_id = invoice.get("customer") + + if customer_id: + result = await db.execute( + select(Organization).where( + Organization.stripe_customer_id == customer_id + ) + ) + org = result.scalar_one_or_none() + if org: + logger.warning( + f"Invoice payment failed for org {org.slug} (tier: {org.tier}). " + "Payment retry will be attempted by Stripe." + ) + # Don't downgrade yet — Stripe will retry per its configured rules + # After all retries fail, customer.subscription.deleted will fire + + return WebhookResponse(status="ok", event_type=event_type) + + # Fallback (shouldn't reach here due to HANDLED_EVENTS check) + return WebhookResponse(status="ignored", event_type=event_type) \ No newline at end of file diff --git a/app/api/router.py b/app/api/router.py index 0bd6de9..378eb97 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -6,12 +6,14 @@ from app.api.monitors import router as monitors_router from app.api.subscribers import router as subscribers_router from app.api.settings import router as settings_router from app.api.organizations import router as organizations_router +from app.api.billing import router as billing_router from app.routes.auth import router as auth_router api_v1_router = APIRouter() api_v1_router.include_router(auth_router, tags=["auth"]) api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"]) +api_v1_router.include_router(billing_router, prefix="/billing", tags=["billing"]) api_v1_router.include_router(services_router, prefix="/services", tags=["services"]) api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"]) api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"]) diff --git a/app/config.py b/app/config.py index 8ba1e64..dff4c60 100644 --- a/app/config.py +++ b/app/config.py @@ -31,6 +31,12 @@ class Settings(BaseSettings): # Uptime monitoring monitor_check_interval: int = 60 + # Stripe Checkout Links + stripe_pro_checkout_url: str = "" # e.g. https://buy.stripe.com/xxxx_pro + stripe_team_checkout_url: str = "" # e.g. https://buy.stripe.com/xxxx_team + stripe_webhook_secret: str = "" # e.g. whsec_xxxx + stripe_api_key: str = "" # e.g. sk_test_xxxx (needed for webhook verification) + model_config = {"env_file": ".env", "env_file_encoding": "utf-8"} @property diff --git a/pyproject.toml b/pyproject.toml index 23279a2..8608564 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "passlib[bcrypt]>=1.7,<2.0", "bcrypt==4.0.1", "email-validator>=2.0,<3.0", + "stripe>=5.0,<16.0", ] [project.optional-dependencies] diff --git a/tests/test_billing.py b/tests/test_billing.py new file mode 100644 index 0000000..8924e20 --- /dev/null +++ b/tests/test_billing.py @@ -0,0 +1,822 @@ +"""Tests for Stripe Checkout Link billing integration. + +Covers: +- Checkout URL generation for pro/team tiers +- Invalid tier handling +- Missing checkout URL handling +- Billing status endpoint +- Webhook handling (checkout.session.completed, subscription events) +- Manual HMAC signature verification +- Tier determination from payment amounts +""" + +import hashlib +import hmac +import json +import time + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.database import Base +from app.dependencies import get_db +from app.auth import create_access_token, hash_password +from app.main import app +from app.models.saas_models import Organization, OrganizationMember, User + +# ── Test database setup ────────────────────────────────────────────────────── + +TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" +test_engine = create_async_engine(TEST_DATABASE_URL, echo=False) +TestSessionLocal = async_sessionmaker( + test_engine, class_=AsyncSession, expire_on_commit=False +) + + +@pytest_asyncio.fixture(scope="session", autouse=True) +async def setup_database(): + """Create all tables once for the test session.""" + import app.models.saas_models # noqa: F401 + import app.models.models # noqa: F401 + + async with test_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with test_engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +@pytest_asyncio.fixture +async def db_session(): + """Provide a clean database session for each test.""" + async with TestSessionLocal() as session: + yield session + await session.rollback() + + +@pytest_asyncio.fixture +async def client(db_session: AsyncSession): + """Provide an HTTP test client with DB dependency override.""" + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + app.dependency_overrides.clear() + + +async def _create_user_and_org(db: AsyncSession, tier: str = "free"): + """Helper to create a user, org, and auth token.""" + user = User( + email=f"test-{tier}-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}@example.com", + password_hash=hash_password("testpass123"), + ) + db.add(user) + await db.flush() + + org = Organization( + name=f"Test Org {tier}", + slug=f"test-org-{tier}-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}", + tier=tier, + ) + db.add(org) + await db.flush() + + membership = OrganizationMember( + organization_id=org.id, + user_id=user.id, + role="owner", + ) + db.add(membership) + await db.flush() + + token = create_access_token(user.id) + return user, org, token + + +# ── Checkout endpoint tests ────────────────────────────────────────────────── + +class TestCheckoutEndpoint: + """Tests for GET /api/v1/billing/checkout/{tier}""" + + @pytest.mark.asyncio + async def test_checkout_pro_returns_redirect_url(self, client, db_session): + """Pro checkout should return a redirect URL with org ID and email.""" + from app.config import settings + original_pro_url = settings.stripe_pro_checkout_url + settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro" + + user, org, token = await _create_user_and_org(db_session) + + response = await client.get( + "/api/v1/billing/checkout/pro", + headers={"Authorization": f"Bearer {token}"}, + ) + + settings.stripe_pro_checkout_url = original_pro_url + + assert response.status_code == 200 + data = response.json() + assert "redirect_url" in data + assert data["tier"] == "pro" + assert data["organization_id"] == org.id + assert "client_reference_id" in data["redirect_url"] + assert "prefilled_email" in data["redirect_url"] + assert data["redirect_url"].startswith("https://buy.stripe.com/test_pro?") + + @pytest.mark.asyncio + async def test_checkout_team_returns_redirect_url(self, client, db_session): + """Team checkout should return a redirect URL.""" + from app.config import settings + original_team_url = settings.stripe_team_checkout_url + settings.stripe_team_checkout_url = "https://buy.stripe.com/test_team" + + user, org, token = await _create_user_and_org(db_session) + + response = await client.get( + "/api/v1/billing/checkout/team", + headers={"Authorization": f"Bearer {token}"}, + ) + + settings.stripe_team_checkout_url = original_team_url + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "team" + assert data["organization_id"] == org.id + + @pytest.mark.asyncio + async def test_checkout_invalid_tier(self, client, db_session): + """Invalid tier should return 400.""" + user, org, token = await _create_user_and_org(db_session) + + response = await client.get( + "/api/v1/billing/checkout/enterprise", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 400 + assert "Invalid tier" in response.json()["detail"] + + @pytest.mark.asyncio + async def test_checkout_free_tier_invalid(self, client, db_session): + """Free tier checkout should return 400 (can't upgrade to free).""" + user, org, token = await _create_user_and_org(db_session) + + response = await client.get( + "/api/v1/billing/checkout/free", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_checkout_missing_url_returns_503(self, client, db_session): + """If checkout URL is not configured, return 503.""" + from app.config import settings + original_pro_url = settings.stripe_pro_checkout_url + settings.stripe_pro_checkout_url = "" # Empty = not configured + + user, org, token = await _create_user_and_org(db_session) + + response = await client.get( + "/api/v1/billing/checkout/pro", + headers={"Authorization": f"Bearer {token}"}, + ) + + settings.stripe_pro_checkout_url = original_pro_url + + assert response.status_code == 503 + assert "not configured" in response.json()["detail"].lower() + + @pytest.mark.asyncio + async def test_checkout_requires_auth(self, client, db_session): + """Checkout endpoint requires authentication.""" + response = await client.get("/api/v1/billing/checkout/pro") + assert response.status_code == 401 + + +# ── Billing status endpoint tests ───────────────────────────────────────────── + +class TestBillingStatusEndpoint: + """Tests for GET /api/v1/billing/status""" + + @pytest.mark.asyncio + async def test_billing_status_free_org(self, client, db_session): + """Free org should see pro and team checkout URLs.""" + from app.config import settings + original_pro = settings.stripe_pro_checkout_url + original_team = settings.stripe_team_checkout_url + settings.stripe_pro_checkout_url = "https://buy.stripe.com/pro" + settings.stripe_team_checkout_url = "https://buy.stripe.com/team" + + user, org, token = await _create_user_and_org(db_session, tier="free") + + response = await client.get( + "/api/v1/billing/status", + headers={"Authorization": f"Bearer {token}"}, + ) + + settings.stripe_pro_checkout_url = original_pro + settings.stripe_team_checkout_url = original_team + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "free" + assert "tier_info" in data + assert "pro" in data["checkout_urls"] + assert "team" in data["checkout_urls"] + + @pytest.mark.asyncio + async def test_billing_status_pro_org(self, client, db_session): + """Pro org should only see team upgrade URL.""" + from app.config import settings + original_team = settings.stripe_team_checkout_url + settings.stripe_team_checkout_url = "https://buy.stripe.com/team" + + user, org, token = await _create_user_and_org(db_session, tier="pro") + + response = await client.get( + "/api/v1/billing/status", + headers={"Authorization": f"Bearer {token}"}, + ) + + settings.stripe_team_checkout_url = original_team + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "pro" + assert "pro" not in data["checkout_urls"] + assert "team" in data["checkout_urls"] + + @pytest.mark.asyncio + async def test_billing_status_team_org(self, client, db_session): + """Team org should see no upgrade URLs (already at top tier).""" + user, org, token = await _create_user_and_org(db_session, tier="team") + + response = await client.get( + "/api/v1/billing/status", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "team" + assert data["checkout_urls"] == {} + + @pytest.mark.asyncio + async def test_billing_status_shows_tier_limits(self, client, db_session): + """Billing status should include tier limits from tier_limits module.""" + user, org, token = await _create_user_and_org(db_session, tier="free") + + response = await client.get( + "/api/v1/billing/status", + headers={"Authorization": f"Bearer {token}"}, + ) + + data = response.json() + assert "tier_info" in data + assert data["tier_info"]["tier"] == "free" + assert "limits" in data["tier_info"] + + @pytest.mark.asyncio + async def test_billing_status_requires_auth(self, client, db_session): + """Billing status endpoint requires authentication.""" + response = await client.get("/api/v1/billing/status") + assert response.status_code == 401 + + +# ── Success/cancel callback tests ──────────────────────────────────────────── + +class TestBillingCallbacks: + """Tests for GET /api/v1/billing/success and /cancel""" + + @pytest.mark.asyncio + async def test_success_callback(self, client, db_session): + """Success callback should return success message.""" + response = await client.get( + "/api/v1/billing/success?session_id=cs_test_123" + ) + assert response.status_code == 200 + data = response.json() + assert "activated" in data["message"].lower() or "upgraded" in data["message"].lower() + + @pytest.mark.asyncio + async def test_success_callback_without_session_id(self, client, db_session): + """Success callback works even without session_id (webhook already processed).""" + response = await client.get("/api/v1/billing/success") + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_cancel_callback(self, client, db_session): + """Cancel callback should return cancellation message.""" + response = await client.get("/api/v1/billing/cancel") + assert response.status_code == 200 + data = response.json() + assert "canceled" in data["message"].lower() + + +# ── Webhook tests ───────────────────────────────────────────────────────────── + +class TestWebhookHandler: + """Tests for POST /api/v1/billing/webhook""" + + def _make_checkout_session_event( + self, org_id: str, amount_total: int = 900, customer_id: str = "cus_test123" + ) -> dict: + """Create a mock checkout.session.completed event.""" + return { + "id": "evt_test_123", + "object": "event", + "type": "checkout.session.completed", + "data": { + "object": { + "id": "cs_test_123", + "object": "checkout.session", + "client_reference_id": org_id, + "customer": customer_id, + "amount_total": amount_total, + "line_items": {"data": []}, + } + }, + } + + def _make_subscription_event( + self, customer_id: str, event_type: str, amount: int = 900 + ) -> dict: + """Create a mock subscription event.""" + return { + "id": "evt_test_sub_123", + "object": "event", + "type": event_type, + "data": { + "object": { + "id": "sub_test_123", + "object": "subscription", + "customer": customer_id, + "metadata": {}, + "items": { + "data": [ + { + "plan": {"amount": amount}, + } + ] + }, + } + }, + } + + @pytest.mark.asyncio + async def test_webhook_checkout_session_pro(self, client, db_session): + """checkout.session.completed with $9 amount should set tier to pro.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_api_key = settings.stripe_api_key + # Disable signature verification for tests + settings.stripe_webhook_secret = "" + settings.stripe_api_key = "" + + user, org, token = await _create_user_and_org(db_session, tier="free") + + event = self._make_checkout_session_event(org.id, amount_total=900) + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + settings.stripe_api_key = original_api_key + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["event_type"] == "checkout.session.completed" + + # Verify org was upgraded + result = await db_session.execute( + select(Organization).where(Organization.id == org.id) + ) + updated_org = result.scalar_one() + assert updated_org.tier == "pro" + assert updated_org.stripe_customer_id == "cus_test123" + + @pytest.mark.asyncio + async def test_webhook_checkout_session_team(self, client, db_session): + """checkout.session.completed with $29 amount should set tier to team.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_api_key = settings.stripe_api_key + settings.stripe_webhook_secret = "" + settings.stripe_api_key = "" + + user, org, token = await _create_user_and_org(db_session, tier="free") + + event = self._make_checkout_session_event(org.id, amount_total=2900) + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + settings.stripe_api_key = original_api_key + + assert response.status_code == 200 + result = await db_session.execute( + select(Organization).where(Organization.id == org.id) + ) + updated_org = result.scalar_one() + assert updated_org.tier == "team" + + @pytest.mark.asyncio + async def test_webhook_subscription_deleted_downgrades_to_free(self, client, db_session): + """customer.subscription.deleted should downgrade org to free.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_api_key = settings.stripe_api_key + settings.stripe_webhook_secret = "" + settings.stripe_api_key = "" + + user, org, token = await _create_user_and_org(db_session, tier="pro") + # Manually set stripe_customer_id + org.stripe_customer_id = "cus_delete_test" + await db_session.flush() + + event = self._make_subscription_event( + customer_id="cus_delete_test", + event_type="customer.subscription.deleted", + ) + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + settings.stripe_api_key = original_api_key + + assert response.status_code == 200 + result = await db_session.execute( + select(Organization).where(Organization.id == org.id) + ) + updated_org = result.scalar_one() + assert updated_org.tier == "free" + + @pytest.mark.asyncio + async def test_webhook_subscription_updated(self, client, db_session): + """customer.subscription.updated should update org tier based on amount.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_api_key = settings.stripe_api_key + settings.stripe_webhook_secret = "" + settings.stripe_api_key = "" + + user, org, token = await _create_user_and_org(db_session, tier="pro") + org.stripe_customer_id = "cus_update_test" + await db_session.flush() + + event = self._make_subscription_event( + customer_id="cus_update_test", + event_type="customer.subscription.updated", + amount=2900, # Upgrade to team + ) + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + settings.stripe_api_key = original_api_key + + assert response.status_code == 200 + result = await db_session.execute( + select(Organization).where(Organization.id == org.id) + ) + updated_org = result.scalar_one() + assert updated_org.tier == "team" + + @pytest.mark.asyncio + async def test_webhook_ignores_unhandled_events(self, client, db_session): + """Unhandled event types should be ignored.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + settings.stripe_webhook_secret = "" + + event = { + "id": "evt_test_ignore", + "type": "account.updated", + "data": {"object": {}}, + } + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ignored" + + @pytest.mark.asyncio + async def test_webhook_invalid_signature_rejected(self, client, db_session): + """Webhook with invalid signature should be rejected.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_api_key = settings.stripe_api_key + # Set up webhook secret but no API key (forces manual verification) + settings.stripe_webhook_secret = "whsec_test_secret_key_12345" + settings.stripe_api_key = "" + + event = self._make_checkout_session_event("fake_org_id") + + response = await client.post( + "/api/v1/billing/webhook", + json=event, + headers={"stripe-signature": "t=1234,v1=invalid_signature_here"}, + ) + + settings.stripe_webhook_secret = original_secret + settings.stripe_api_key = original_api_key + + assert response.status_code == 400 + assert "signature" in response.json()["detail"].lower() + + @pytest.mark.asyncio + async def test_webhook_checkout_missing_org_id(self, client, db_session): + """Webhook without client_reference_id should return error.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + settings.stripe_webhook_secret = "" + + event = { + "id": "evt_test_no_org", + "type": "checkout.session.completed", + "data": { + "object": { + "id": "cs_test_no_org", + "client_reference_id": None, + "customer": "cus_123", + "amount_total": 900, + } + }, + } + + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + + assert response.status_code == 200 + assert response.json()["status"] == "error" + + @pytest.mark.asyncio + async def test_webhook_invoice_payment_succeeded(self, client, db_session): + """invoice.payment_succeeded should acknowledge without downgrading.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_api_key = settings.stripe_api_key + settings.stripe_webhook_secret = "" + settings.stripe_api_key = "" + + user, org, token = await _create_user_and_org(db_session, tier="pro") + org.stripe_customer_id = "cus_invoice_test" + await db_session.flush() + + event = { + "id": "evt_invoice_paid", + "type": "invoice.payment_succeeded", + "data": { + "object": { + "id": "in_test_123", + "customer": "cus_invoice_test", + } + }, + } + + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + settings.stripe_api_key = original_api_key + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + + # Verify tier wasn't changed + result = await db_session.execute( + select(Organization).where(Organization.id == org.id) + ) + updated_org = result.scalar_one() + assert updated_org.tier == "pro" + + @pytest.mark.asyncio + async def test_webhook_invoice_payment_failed(self, client, db_session): + """invoice.payment_failed should not downgrade (Stripe retries).""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_api_key = settings.stripe_api_key + settings.stripe_webhook_secret = "" + settings.stripe_api_key = "" + + user, org, token = await _create_user_and_org(db_session, tier="pro") + org.stripe_customer_id = "cus_fail_test" + await db_session.flush() + + event = { + "id": "evt_invoice_failed", + "type": "invoice.payment_failed", + "data": { + "object": { + "id": "in_test_fail", + "customer": "cus_fail_test", + } + }, + } + + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + settings.stripe_api_key = original_api_key + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + + # Verify tier wasn't changed (payment will be retried) + result = await db_session.execute( + select(Organization).where(Organization.id == org.id) + ) + updated_org = result.scalar_one() + assert updated_org.tier == "pro" + + +# ── Tier determination helper tests ─────────────────────────────────────────── + +class TestTierDetermination: + """Tests for _determine_tier_from_amount helper.""" + + def test_pro_from_amount(self): + from app.api.billing import _determine_tier_from_amount + assert _determine_tier_from_amount(900) == "pro" + assert _determine_tier_from_amount(1500) == "pro" + + def test_team_from_amount(self): + from app.api.billing import _determine_tier_from_amount + assert _determine_tier_from_amount(2900) == "team" + assert _determine_tier_from_amount(5000) == "team" + assert _determine_tier_from_amount(9900) == "team" + + def test_free_from_small_amount(self): + from app.api.billing import _determine_tier_from_amount + assert _determine_tier_from_amount(0) == "free" + assert _determine_tier_from_amount(100) == "free" + + +# ── Integration: checkout → webhook → tier update ──────────────────────────── + +class TestBillingIntegration: + """End-to-end style tests for the billing flow.""" + + @pytest.mark.asyncio + async def test_free_to_pro_upgrade_flow(self, client, db_session): + """Complete flow: free org → checkout → webhook → pro tier.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_pro_url = settings.stripe_pro_checkout_url + original_team_url = settings.stripe_team_checkout_url + original_api_key = settings.stripe_api_key + + settings.stripe_webhook_secret = "" + settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro_link" + settings.stripe_team_checkout_url = "https://buy.stripe.com/test_team_link" + settings.stripe_api_key = "" + + user, org, token = await _create_user_and_org(db_session, tier="free") + + # Step 1: Get checkout URL + checkout_resp = await client.get( + "/api/v1/billing/checkout/pro", + headers={"Authorization": f"Bearer {token}"}, + ) + assert checkout_resp.status_code == 200 + checkout_data = checkout_resp.json() + assert checkout_data["redirect_url"].startswith("https://buy.stripe.com/test_pro_link") + assert "client_reference_id=" + org.id in checkout_data["redirect_url"] + + # Step 2: Simulate Stripe webhook + webhook_event = { + "id": "evt_integration_test", + "type": "checkout.session.completed", + "data": { + "object": { + "id": "cs_integration_test", + "client_reference_id": org.id, + "customer": "cus_integration_test", + "amount_total": 900, + "line_items": {"data": []}, + } + }, + } + webhook_resp = await client.post( + "/api/v1/billing/webhook", + json=webhook_event, + ) + assert webhook_resp.status_code == 200 + + # Step 3: Verify org is now pro + status_resp = await client.get( + "/api/v1/billing/status", + headers={"Authorization": f"Bearer {token}"}, + ) + assert status_resp.status_code == 200 + assert status_resp.json()["tier"] == "pro" + + # Step 4: Verify billing status shows team upgrade available + assert "team" in status_resp.json()["checkout_urls"] + + # Restore settings + settings.stripe_webhook_secret = original_secret + settings.stripe_pro_checkout_url = original_pro_url + settings.stripe_team_checkout_url = original_team_url + settings.stripe_api_key = original_api_key + + @pytest.mark.asyncio + async def test_pro_to_free_downgrade_via_webhook(self, client, db_session): + """Complete flow: pro org → subscription deleted → free tier.""" + from app.config import settings + original_secret = settings.stripe_webhook_secret + original_api_key = settings.stripe_api_key + settings.stripe_webhook_secret = "" + settings.stripe_api_key = "" + + user, org, token = await _create_user_and_org(db_session, tier="pro") + org.stripe_customer_id = "cus_downgrade_test" + await db_session.flush() + + # Simulate subscription cancellation + event = { + "id": "evt_cancel_test", + "type": "customer.subscription.deleted", + "data": { + "object": { + "id": "sub_cancel_test", + "customer": "cus_downgrade_test", + } + }, + } + response = await client.post( + "/api/v1/billing/webhook", + json=event, + ) + + settings.stripe_webhook_secret = original_secret + settings.stripe_api_key = original_api_key + + assert response.status_code == 200 + + # Verify org is now free + result = await db_session.execute( + select(Organization).where(Organization.id == org.id) + ) + updated_org = result.scalar_one() + assert updated_org.tier == "free" + + # Verify billing status shows upgrade options + status_resp = await client.get( + "/api/v1/billing/status", + headers={"Authorization": f"Bearer {token}"}, + ) + data = status_resp.json() + assert data["tier"] == "free" + + @pytest.mark.asyncio + async def test_checkout_url_includes_prefilled_email(self, client, db_session): + """Checkout URL should include the user's email as prefilled_email.""" + from app.config import settings + original_pro = settings.stripe_pro_checkout_url + settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro" + + user, org, token = await _create_user_and_org(db_session) + + response = await client.get( + "/api/v1/billing/checkout/pro", + headers={"Authorization": f"Bearer {token}"}, + ) + + settings.stripe_pro_checkout_url = original_pro + + data = response.json() + assert f"prefilled_email={user.email}" in data["redirect_url"] or \ + f"prefilled_email={user.email.replace('@', '%40')}" in data["redirect_url"] \ No newline at end of file