"""Billing API: Stripe Checkout Link integration and webhook handler.

This module implements the "zero-code" payment flow using Stripe Payment Links:

1. User clicks "Upgrade" → redirected to a Stripe-hosted checkout page
2. After payment, Stripe sends a webhook event to /api/v1/billing/webhook
3. The webhook handler updates the org's tier in the database
4. Stripe redirects the user back to /api/v1/billing/success

This approach requires no Stripe server-side SDK for checkout creation — just
configuration of Payment Link URLs and a webhook endpoint for receiving
payment confirmations.
"""

import hashlib
import hmac
import json
import logging
import time
from urllib.parse import urlencode

from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth import get_current_user, get_current_org
from app.config import settings
from app.dependencies import get_db
from app.models.saas_models import Organization, OrganizationMember, User
from app.services.tier_limits import get_tier_info

logger = logging.getLogger(__name__)

router = APIRouter(tags=["billing"])


# ── Tier-to-URL mapping ─────────────────────────────────────────────────────

# Paid tiers that can be purchased through a Stripe Payment Link.
_PAID_TIERS = ("pro", "team")

# Tiers an org on a given tier may upgrade to. Insertion order is preserved in
# the checkout_urls dict returned by the billing status endpoint.
_UPGRADE_PATHS: dict[str, tuple[str, ...]] = {
    "free": ("pro", "team"),
    "pro": ("team",),
}


def _get_checkout_url(tier: str) -> str:
    """Return the Stripe Payment Link URL for the given tier.

    Reads from settings at call time (not import time) so tests can override.

    Args:
        tier: Tier name; only "pro" and "team" have Payment Links.

    Returns:
        The configured URL, or "" for any other tier. The configured value
        itself may be empty/None when the link has not been set up.
    """
    if tier == "pro":
        return settings.stripe_pro_checkout_url
    elif tier == "team":
        return settings.stripe_team_checkout_url
    return ""


# ── Response schemas ─────────────────────────────────────────────────────────


class CheckoutResponse(BaseModel):
    """Response for checkout redirect endpoint."""

    redirect_url: str
    tier: str
    organization_id: str


class BillingStatusResponse(BaseModel):
    """Response for billing status endpoint."""

    tier: str
    tier_info: dict
    stripe_customer_id: str | None = None
    checkout_urls: dict[str, str]


class SuccessResponse(BaseModel):
    """Response for successful checkout callback."""

    message: str
    tier: str | None = None


class WebhookResponse(BaseModel):
    """Response for webhook processing."""

    status: str
    event_type: str | None = None


# ── Endpoints ────────────────────────────────────────────────────────────────


@router.get("/checkout/{tier}", response_model=CheckoutResponse)
async def create_checkout(
    tier: str,
    user: User = Depends(get_current_user),
    org: Organization = Depends(get_current_org),
):
    """Generate a Stripe Checkout Link redirect URL for the selected tier.

    This endpoint does NOT create a Stripe Checkout Session. Instead, it
    returns the pre-configured Stripe Payment Link URL with the org's ID
    attached as client_reference_id and the user's email prefilled. The
    frontend should redirect the user to this URL.

    After completing payment on Stripe's hosted checkout, Stripe will:
    1. Send a webhook to /api/v1/billing/webhook
    2. Redirect the user to /api/v1/billing/success

    Raises:
        HTTPException: 400 if ``tier`` is not "pro" or "team"; 503 if no
            Payment Link URL is configured for that tier.
    """
    if tier not in _PAID_TIERS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid tier '{tier}'. Must be 'pro' or 'team'.",
        )

    checkout_url = _get_checkout_url(tier)
    if not checkout_url:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=(
                f"Checkout for '{tier}' tier is not configured yet. "
                "Please contact support or set up Stripe Payment Links."
            ),
        )

    # The webhook relies on client_reference_id to find the org to upgrade.
    # str() keeps the response valid whatever type the ORM uses for ids.
    org_id = str(org.id)
    query: dict[str, str] = {"client_reference_id": org_id}
    if user.email:
        # Skip a missing email rather than sending the literal text "None".
        query["prefilled_email"] = user.email

    # A configured Payment Link URL may already carry query parameters; append
    # to them instead of starting a second "?", which would corrupt the URL.
    separator = "&" if "?" in checkout_url else "?"
    redirect_url = f"{checkout_url}{separator}{urlencode(query)}"

    return CheckoutResponse(
        redirect_url=redirect_url,
        tier=tier,
        organization_id=org_id,
    )


@router.get("/status", response_model=BillingStatusResponse)
async def get_billing_status(
    org: Organization = Depends(get_current_org),
):
    """Get the current organization's billing status: tier, limits, and checkout URLs.

    Only checkout URLs for tiers the org can upgrade to are included, and only
    when they are configured.

    NOTE(review): these are the raw Payment Link URLs without the
    client_reference_id that create_checkout appends; the webhook cannot match
    a purchase made through them to an org. Prefer /checkout/{tier} for
    redirects.
    """
    # An org with no tier recorded is treated as "free" everywhere in this
    # response, including which upgrades are offered.
    current_tier = org.tier or "free"

    checkout_urls: dict[str, str] = {}
    for target_tier in _UPGRADE_PATHS.get(current_tier, ()):
        url = _get_checkout_url(target_tier)
        if url:
            checkout_urls[target_tier] = url

    return BillingStatusResponse(
        tier=current_tier,
        tier_info=get_tier_info(org),
        stripe_customer_id=org.stripe_customer_id,
        checkout_urls=checkout_urls,
    )


@router.get("/success", response_model=SuccessResponse)
async def billing_success(
    session_id: str = "",
):
    """Stripe redirects here after a successful checkout.

    In production, you would call stripe.checkout.Session.retrieve(session_id)
    to verify the session. For the MVP, the webhook handler already processed
    the payment — this endpoint just confirms to the user that it worked.

    The session_id is provided by Stripe's redirect URL; it is currently unused.
    """
    return SuccessResponse(
        message="Subscription activated! Your plan has been upgraded.",
        tier=None,  # Webhook already set the tier; frontend should refetch
    )


@router.get("/cancel")
async def billing_cancel():
    """Stripe redirects here if the user cancels checkout."""
    return {
        "message": "Checkout was canceled. Your current plan is unchanged.",
        "tier_change": None,
    }


# ── Webhook handler ──────────────────────────────────────────────────────────

# Stripe event types we handle
HANDLED_EVENTS = {
    "checkout.session.completed",
    "customer.subscription.updated",
    "customer.subscription.deleted",
    "invoice.payment_succeeded",
    "invoice.payment_failed",
}

# Maximum age, in seconds, of a signature timestamp accepted by the manual
# verifier. Older signatures are rejected to block replayed requests (Stripe's
# official libraries default to the same 5-minute window).
STRIPE_SIGNATURE_TOLERANCE_SECONDS = 300


def _determine_tier_from_amount(amount_total: int) -> str:
    """Determine the tier from the payment amount (in cents).

    NOTE(review): the thresholds presumably assume undiscounted monthly
    prices; coupons, annual billing or per-seat quantities would change the
    amount and therefore the derived tier.

    Args:
        amount_total: Total amount in cents (e.g. 900 = $9.00)

    Returns:
        Tier string: "team" for >= 2900, "pro" for >= 900, otherwise "free".
    """
    if amount_total >= 2900:  # $29.00+
        return "team"
    elif amount_total >= 900:  # $9.00+
        return "pro"
    return "free"


def _determine_tier_from_price_id(price_id: str) -> str:
    """Determine the tier from a Stripe price ID.

    This is a fallback when amount isn't available. In production, you would
    map your actual Stripe price IDs here.

    Args:
        price_id: Stripe price ID (e.g., price_1ABC123...)

    Returns:
        Tier string. Currently always "pro" (placeholder mapping).
    """
    # Default mapping — override in production with your actual price IDs
    # The stripe API can also be used to look up price details
    return "pro"  # Default fallback


def _parse_event(payload: bytes) -> dict | None:
    """Decode a webhook body, returning None unless it is a JSON object."""
    try:
        event = json.loads(payload)
    except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
        return None
    return event if isinstance(event, dict) else None


def _verify_signature_manually(
    payload: bytes, sig_header: str, secret: str
) -> dict | None:
    """Verify a Stripe-Signature header without the stripe library.

    The header is a comma-separated list of ``key=value`` pairs, e.g.
    ``t=1492774577,v1=5257a869...,v0=6ffbb59b...``. ``t`` is the Unix
    timestamp and each ``v1`` is a hex HMAC-SHA256 of ``"{t}.{payload}"``
    keyed with the endpoint secret. Other schemes (such as ``v0``) are ignored.

    Returns:
        The parsed event dict if any v1 signature matches and the timestamp is
        within STRIPE_SIGNATURE_TOLERANCE_SECONDS; None otherwise.
    """
    timestamp = ""
    signatures: list[str] = []
    for element in sig_header.split(","):
        key, sep, value = element.strip().partition("=")
        if not sep:
            continue  # malformed element; ignore rather than abort
        if key == "t":
            timestamp = value
        elif key == "v1":
            signatures.append(value)

    if not timestamp or not signatures:
        return None

    try:
        signed_at = int(timestamp)
    except ValueError:
        return None
    # Reject stale signatures so a captured request cannot be replayed later.
    if signed_at < time.time() - STRIPE_SIGNATURE_TOLERANCE_SECONDS:
        return None

    # Sign the raw bytes directly: no decode/encode round-trip of the body.
    expected_sig = hmac.new(
        secret.encode("utf-8"),
        timestamp.encode("utf-8") + b"." + payload,
        hashlib.sha256,
    ).hexdigest().encode("ascii")

    # Compare as bytes: compare_digest raises TypeError on non-ASCII str input,
    # and header values are attacker-controlled.
    if not any(
        hmac.compare_digest(expected_sig, sig.encode("utf-8")) for sig in signatures
    ):
        return None
    return _parse_event(payload)


async def _verify_stripe_signature(payload: bytes, sig_header: str) -> dict | None:
    """Verify the Stripe webhook signature using HMAC-SHA256.

    Uses the stripe library if available, falls back to manual verification.

    Args:
        payload: Raw request body bytes
        sig_header: Stripe-Signature header value

    Returns:
        Parsed event dict if signature is valid, None otherwise. When no
        webhook secret is configured, the body is parsed WITHOUT verification
        (None only if it is not a JSON object).
    """
    if not settings.stripe_webhook_secret:
        # NOTE(review): unsigned events are accepted here, so anyone who can
        # reach the endpoint can change an org's tier. Make sure the secret is
        # configured everywhere except local development/tests.
        logger.warning("No Stripe webhook secret configured; skipping verification")
        return _parse_event(payload)

    try:
        import stripe
    except ImportError:
        logger.info("stripe library not available; using manual signature verification")
    else:
        if settings.stripe_api_key:
            stripe.api_key = settings.stripe_api_key
        try:
            stripe.Webhook.construct_event(
                payload, sig_header, settings.stripe_webhook_secret
            )
        except Exception as exc:
            # Matched by class name so this works whether the library exposes
            # the error as stripe.SignatureVerificationError or under
            # stripe.error; a bad signature is expected and not logged.
            if type(exc).__name__ != "SignatureVerificationError":
                logger.error("Stripe webhook verification error: %s", exc)
            return None
        # construct_event has verified the signature; re-parse the raw body so
        # both verification paths return plain dicts.
        return _parse_event(payload)

    return _verify_signature_manually(
        payload, sig_header, settings.stripe_webhook_secret
    )


def _event_object(event: dict) -> dict:
    """Return ``event["data"]["object"]``, or ``{}`` if any level is missing."""
    data = event.get("data")
    obj = data.get("object") if isinstance(data, dict) else None
    return obj if isinstance(obj, dict) else {}


def _first_line_item_price_id(session_obj: dict) -> str:
    """Return the price ID of a Checkout Session's first line item, or ""."""
    line_items = session_obj.get("line_items")
    items = line_items.get("data") if isinstance(line_items, dict) else None
    if not items or not isinstance(items[0], dict):
        return ""
    price = items[0].get("price")
    # The price may be an expanded object or a bare ID string.
    if isinstance(price, dict):
        return price.get("id") or ""
    return price if isinstance(price, str) else ""


def _tier_from_checkout_session(session_obj: dict) -> str:
    """Pick the purchased tier for a completed Checkout Session.

    ``amount_total`` is the primary signal. The price-ID mapping is consulted
    only when the amount is absent, so a placeholder mapping cannot override
    a real amount (e.g. turning a team purchase into "pro").
    """
    amount_total = session_obj.get("amount_total")
    if amount_total is None:
        price_id = _first_line_item_price_id(session_obj)
        if price_id:
            return _determine_tier_from_price_id(price_id)
        amount_total = 0  # no amount and no price: same result as before ("free")
    return _determine_tier_from_amount(amount_total)


def _subscription_item_amount(subscription: dict) -> int | None:
    """Return the unit amount (cents) of a subscription's first item, or None.

    Reads the legacy ``plan.amount`` first, then ``price.unit_amount``.
    """
    items_obj = subscription.get("items")
    items = items_obj.get("data") if isinstance(items_obj, dict) else None
    if not items or not isinstance(items[0], dict):
        return None
    item = items[0]
    plan = item.get("plan")
    if isinstance(plan, dict) and plan.get("amount") is not None:
        return plan["amount"]
    price = item.get("price")
    if isinstance(price, dict):
        return price.get("unit_amount")
    return None


async def _find_org_by_customer(
    db: AsyncSession, customer_id: str
) -> "Organization | None":
    """Look up the organization linked to a Stripe customer ID."""
    result = await db.execute(
        select(Organization).where(Organization.stripe_customer_id == customer_id)
    )
    return result.scalar_one_or_none()


async def _handle_checkout_completed(
    event_type: str, session_obj: dict, db: AsyncSession
) -> WebhookResponse:
    """Set the org's tier (and Stripe customer ID) after a completed checkout.

    The org is identified by the client_reference_id that create_checkout
    appends to the Payment Link URL.
    """
    org_id = session_obj.get("client_reference_id")
    customer_id = session_obj.get("customer")

    if not org_id:
        logger.warning("checkout.session.completed event missing client_reference_id")
        return WebhookResponse(status="error", event_type=event_type)

    result = await db.execute(select(Organization).where(Organization.id == org_id))
    org = result.scalar_one_or_none()
    if not org:
        logger.warning("Organization %s not found for Stripe checkout", org_id)
        return WebhookResponse(status="error", event_type=event_type)

    new_tier = _tier_from_checkout_session(session_obj)

    old_tier = org.tier
    org.tier = new_tier
    if customer_id:
        org.stripe_customer_id = customer_id
    # NOTE(review): changes are only flushed here; committing is presumably
    # done by the get_db dependency — confirm.
    await db.flush()

    logger.info(
        "Organization %s upgraded from %s to %s (Stripe session: %s)",
        org.slug,
        old_tier,
        new_tier,
        session_obj.get("id", "unknown"),
    )
    return WebhookResponse(status="ok", event_type=event_type)


async def _handle_subscription_updated(
    event_type: str, subscription: dict, db: AsyncSession
) -> WebhookResponse:
    """Re-derive the org's tier from an updated subscription.

    Priority: the first item's amount, then a "pro"/"team" ``metadata.tier``,
    otherwise the current tier is kept (missing data never downgrades).
    """
    customer_id = subscription.get("customer")
    if not customer_id:
        return WebhookResponse(status="error", event_type=event_type)

    org = await _find_org_by_customer(db, customer_id)
    if not org:
        logger.warning("No org found for Stripe customer %s", customer_id)
        return WebhookResponse(status="error", event_type=event_type)

    metadata = subscription.get("metadata")
    tier_from_meta = metadata.get("tier") if isinstance(metadata, dict) else None
    amount = _subscription_item_amount(subscription)

    if amount is not None:
        new_tier = _determine_tier_from_amount(amount)
    elif tier_from_meta in _PAID_TIERS:
        new_tier = tier_from_meta
    else:
        new_tier = org.tier  # Keep current if indeterminate

    old_tier = org.tier
    org.tier = new_tier
    await db.flush()

    logger.info("Organization %s subscription updated: %s → %s", org.slug, old_tier, new_tier)
    return WebhookResponse(status="ok", event_type=event_type)


async def _handle_subscription_deleted(
    event_type: str, subscription: dict, db: AsyncSession
) -> WebhookResponse:
    """Downgrade the org to "free" and unlink its Stripe customer ID."""
    customer_id = subscription.get("customer")
    if not customer_id:
        return WebhookResponse(status="error", event_type=event_type)

    org = await _find_org_by_customer(db, customer_id)
    if not org:
        logger.warning("No org found for Stripe customer %s", customer_id)
        return WebhookResponse(status="error", event_type=event_type)

    old_tier = org.tier
    org.tier = "free"
    org.stripe_customer_id = None  # Clear stripe ID on cancellation
    await db.flush()

    logger.info(
        "Organization %s downgraded from %s to free (subscription deleted)",
        org.slug,
        old_tier,
    )
    return WebhookResponse(status="ok", event_type=event_type)


async def _handle_invoice_payment_succeeded(
    event_type: str, invoice: dict, db: AsyncSession
) -> WebhookResponse:
    """Log a successful renewal; the org keeps its tier (no DB writes)."""
    customer_id = invoice.get("customer")
    if customer_id:
        org = await _find_org_by_customer(db, customer_id)
        if org:
            logger.info(
                "Invoice payment succeeded for org %s (tier: %s)", org.slug, org.tier
            )
    return WebhookResponse(status="ok", event_type=event_type)


async def _handle_invoice_payment_failed(
    event_type: str, invoice: dict, db: AsyncSession
) -> WebhookResponse:
    """Log a failed payment without downgrading (no DB writes)."""
    customer_id = invoice.get("customer")
    if customer_id:
        org = await _find_org_by_customer(db, customer_id)
        if org:
            logger.warning(
                "Invoice payment failed for org %s (tier: %s). "
                "Payment retry will be attempted by Stripe.",
                org.slug,
                org.tier,
            )
    # Don't downgrade yet — Stripe will retry per its configured rules
    # After all retries fail, customer.subscription.deleted will fire
    return WebhookResponse(status="ok", event_type=event_type)


# Dispatch table: each handler receives (event_type, event["data"]["object"], db).
_EVENT_HANDLERS = {
    "checkout.session.completed": _handle_checkout_completed,
    "customer.subscription.updated": _handle_subscription_updated,
    "customer.subscription.deleted": _handle_subscription_deleted,
    "invoice.payment_succeeded": _handle_invoice_payment_succeeded,
    "invoice.payment_failed": _handle_invoice_payment_failed,
}


@router.post("/webhook")
async def stripe_webhook(request: Request, db: AsyncSession = Depends(get_db)):
    """Handle Stripe webhook events for subscription changes.

    This is the ONLY server-side Stripe code needed for the Payment Links
    approach. It receives events when:
    - checkout.session.completed: User completed a checkout
    - customer.subscription.updated: Subscription tier changed
    - customer.subscription.deleted: Subscription canceled (downgrade to free)
    - invoice.payment_succeeded: Recurring payment succeeded (log only)
    - invoice.payment_failed: Payment failed (log only; no downgrade)

    The webhook verifies the Stripe signature to ensure authenticity.

    Raises:
        HTTPException: 400 if the signature is invalid or the body is not a
            JSON object.
    """
    body = await request.body()
    sig_header = request.headers.get("stripe-signature", "")

    # Verify the webhook signature
    event = await _verify_stripe_signature(body, sig_header)
    if event is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid webhook signature",
        )

    event_type = event.get("type", "")

    # Only handle known event types
    if event_type not in HANDLED_EVENTS:
        logger.info("Ignoring unhandled Stripe event type: %s", event_type)
        return WebhookResponse(status="ignored", event_type=event_type)

    handler = _EVENT_HANDLERS.get(event_type)
    if handler is None:
        # Listed in HANDLED_EVENTS but no handler registered.
        return WebhookResponse(status="ignored", event_type=event_type)

    return await handler(event_type, _event_object(event), db)