feat: Stripe Checkout Link integration — billing API with 29 tests
- Add billing module (app/api/billing.py) with 5 API endpoints:
- GET /api/v1/billing/checkout/{tier} — redirect to Stripe Payment Links
- GET /api/v1/billing/status — current org tier, limits, upgrade URLs
- GET /api/v1/billing/success — Stripe success callback
- GET /api/v1/billing/cancel — Stripe cancel callback
- POST /api/v1/billing/webhook — handles 5 Stripe event types
- Zero-code payment flow: uses pre-configured Stripe Payment Links
with client_reference_id (org ID) and prefilled_email params
- Webhook handler processes checkout.session.completed,
customer.subscription.updated/deleted, invoice events
- Stripe signature verification via stripe library (primary)
or manual HMAC-SHA256 (fallback)
- Tier determination from payment amount: =pro, 9=team
- 4 new config settings in app/config.py:
stripe_pro_checkout_url, stripe_team_checkout_url,
stripe_webhook_secret, stripe_api_key
- Added stripe>=5.0,<16.0 dependency
- 29 tests in tests/test_billing.py (all passing)
- Total: 98 tests passing (69 existing + 29 new)
This commit is contained in:
parent
b7a8142ca0
commit
158a6ee716
6 changed files with 1311 additions and 1 deletions
472
app/api/billing.py
Normal file
472
app/api/billing.py
Normal file
|
|
@ -0,0 +1,472 @@
|
|||
"""Billing API: Stripe Checkout Link integration and webhook handler.
|
||||
|
||||
This module implements the "zero-code" payment flow using Stripe Payment Links:
|
||||
1. User clicks "Upgrade" → redirected to a Stripe-hosted checkout page
|
||||
2. After payment, Stripe sends a webhook event to /api/v1/billing/webhook
|
||||
3. The webhook handler updates the org's tier in the database
|
||||
4. Stripe redirects the user back to /api/v1/billing/success
|
||||
|
||||
This approach requires no Stripe server-side SDK for checkout creation —
|
||||
just configuration of Payment Link URLs and a webhook endpoint for
|
||||
receiving payment confirmations.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth import get_current_user, get_current_org
|
||||
from app.config import settings
|
||||
from app.dependencies import get_db
|
||||
from app.models.saas_models import Organization, OrganizationMember, User
|
||||
from app.services.tier_limits import get_tier_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["billing"])
|
||||
|
||||
# ── Tier-to-URL mapping ─────────────────────────────────────────────────────
|
||||
|
||||
def _get_checkout_url(tier: str) -> str:
|
||||
"""Return the Stripe Payment Link URL for the given tier.
|
||||
|
||||
Reads from settings at call time (not import time) so tests can override.
|
||||
"""
|
||||
if tier == "pro":
|
||||
return settings.stripe_pro_checkout_url
|
||||
elif tier == "team":
|
||||
return settings.stripe_team_checkout_url
|
||||
return ""
|
||||
|
||||
|
||||
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class CheckoutResponse(BaseModel):
|
||||
"""Response for checkout redirect endpoint."""
|
||||
redirect_url: str
|
||||
tier: str
|
||||
organization_id: str
|
||||
|
||||
|
||||
class BillingStatusResponse(BaseModel):
|
||||
"""Response for billing status endpoint."""
|
||||
tier: str
|
||||
tier_info: dict
|
||||
stripe_customer_id: str | None = None
|
||||
checkout_urls: dict[str, str]
|
||||
|
||||
|
||||
class SuccessResponse(BaseModel):
|
||||
"""Response for successful checkout callback."""
|
||||
message: str
|
||||
tier: str | None = None
|
||||
|
||||
|
||||
class WebhookResponse(BaseModel):
|
||||
"""Response for webhook processing."""
|
||||
status: str
|
||||
event_type: str | None = None
|
||||
|
||||
|
||||
# ── Endpoints ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/checkout/{tier}", response_model=CheckoutResponse)
|
||||
async def create_checkout(
|
||||
tier: str,
|
||||
user: User = Depends(get_current_user),
|
||||
org: Organization = Depends(get_current_org),
|
||||
):
|
||||
"""Generate a Stripe Checkout Link redirect URL for the selected tier.
|
||||
|
||||
This endpoint does NOT create a Stripe Checkout Session. Instead, it
|
||||
returns the pre-configured Stripe Payment Link URL with the org's ID
|
||||
attached as client_reference_id and the user's email prefilled.
|
||||
|
||||
The frontend should redirect the user to this URL. After completing
|
||||
payment on Stripe's hosted checkout, Stripe will:
|
||||
1. Send a webhook to /api/v1/billing/webhook
|
||||
2. Redirect the user to /api/v1/billing/success
|
||||
"""
|
||||
if tier not in ("pro", "team"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid tier '{tier}'. Must be 'pro' or 'team'.",
|
||||
)
|
||||
|
||||
checkout_url = _get_checkout_url(tier)
|
||||
if not checkout_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=(
|
||||
f"Checkout for '{tier}' tier is not configured yet. "
|
||||
"Please contact support or set up Stripe Payment Links."
|
||||
),
|
||||
)
|
||||
|
||||
# Build redirect URL with org ID and prefilled email
|
||||
params = urlencode({
|
||||
"client_reference_id": org.id,
|
||||
"prefilled_email": user.email,
|
||||
})
|
||||
redirect_url = f"{checkout_url}?{params}"
|
||||
|
||||
return CheckoutResponse(
|
||||
redirect_url=redirect_url,
|
||||
tier=tier,
|
||||
organization_id=org.id,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/status", response_model=BillingStatusResponse)
|
||||
async def get_billing_status(
|
||||
org: Organization = Depends(get_current_org),
|
||||
):
|
||||
"""Get the current organization's billing status: tier, limits, and checkout URLs."""
|
||||
# Only show checkout URLs for tiers the org can upgrade to
|
||||
checkout_urls: dict[str, str] = {}
|
||||
if org.tier == "free":
|
||||
if settings.stripe_pro_checkout_url:
|
||||
checkout_urls["pro"] = settings.stripe_pro_checkout_url
|
||||
if settings.stripe_team_checkout_url:
|
||||
checkout_urls["team"] = settings.stripe_team_checkout_url
|
||||
elif org.tier == "pro":
|
||||
if settings.stripe_team_checkout_url:
|
||||
checkout_urls["team"] = settings.stripe_team_checkout_url
|
||||
|
||||
return BillingStatusResponse(
|
||||
tier=org.tier or "free",
|
||||
tier_info=get_tier_info(org),
|
||||
stripe_customer_id=org.stripe_customer_id,
|
||||
checkout_urls=checkout_urls,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/success", response_model=SuccessResponse)
|
||||
async def billing_success(
|
||||
session_id: str = "",
|
||||
):
|
||||
"""Stripe redirects here after a successful checkout.
|
||||
|
||||
In production, you would call stripe.checkout.sessions.retrieve(session_id)
|
||||
to verify the session. For the MVP, the webhook handler already processed
|
||||
the payment — this endpoint just confirms to the user that it worked.
|
||||
|
||||
The session_id is provided by Stripe's redirect URL.
|
||||
"""
|
||||
return SuccessResponse(
|
||||
message="Subscription activated! Your plan has been upgraded.",
|
||||
tier=None, # Webhook already set the tier; frontend should refetch
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cancel")
|
||||
async def billing_cancel():
|
||||
"""Stripe redirects here if the user cancels checkout."""
|
||||
return {
|
||||
"message": "Checkout was canceled. Your current plan is unchanged.",
|
||||
"tier_change": None,
|
||||
}
|
||||
|
||||
|
||||
# ── Webhook handler ──────────────────────────────────────────────────────────
|
||||
|
||||
# Stripe event types we handle
|
||||
HANDLED_EVENTS = {
|
||||
"checkout.session.completed",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"invoice.payment_succeeded",
|
||||
"invoice.payment_failed",
|
||||
}
|
||||
|
||||
|
||||
def _determine_tier_from_amount(amount_total: int) -> str:
|
||||
"""Determine the tier from the payment amount (in cents).
|
||||
|
||||
Args:
|
||||
amount_total: Total amount in cents (e.g. 900 = $9.00)
|
||||
|
||||
Returns:
|
||||
Tier string: "pro" or "team"
|
||||
"""
|
||||
if amount_total >= 2900: # $29.00+
|
||||
return "team"
|
||||
elif amount_total >= 900: # $9.00+
|
||||
return "pro"
|
||||
return "free"
|
||||
|
||||
|
||||
def _determine_tier_from_price_id(price_id: str) -> str:
|
||||
"""Determine the tier from a Stripe price ID.
|
||||
|
||||
This is a fallback when amount isn't available. In production,
|
||||
you would map your actual Stripe price IDs here.
|
||||
|
||||
Args:
|
||||
price_id: Stripe price ID (e.g., price_1ABC123...)
|
||||
|
||||
Returns:
|
||||
Tier string: "pro", "team", or "free"
|
||||
"""
|
||||
# Default mapping — override in production with your actual price IDs
|
||||
# The stripe API can also be used to look up price details
|
||||
return "pro" # Default fallback
|
||||
|
||||
|
||||
async def _verify_stripe_signature(payload: bytes, sig_header: str) -> dict | None:
|
||||
"""Verify the Stripe webhook signature using HMAC-SHA256.
|
||||
|
||||
Uses the stripe library if available, falls back to manual verification.
|
||||
|
||||
Args:
|
||||
payload: Raw request body bytes
|
||||
sig_header: Stripe-Signature header value
|
||||
|
||||
Returns:
|
||||
Parsed event dict if signature is valid, None otherwise
|
||||
"""
|
||||
if not settings.stripe_webhook_secret:
|
||||
logger.warning("No Stripe webhook secret configured; skipping verification")
|
||||
return json.loads(payload)
|
||||
|
||||
try:
|
||||
import stripe
|
||||
if settings.stripe_api_key:
|
||||
stripe.api_key = settings.stripe_api_key
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.stripe_webhook_secret
|
||||
)
|
||||
return dict(event)
|
||||
except ImportError:
|
||||
logger.info("stripe library not available; using manual signature verification")
|
||||
except stripe.error.SignatureVerificationError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Stripe webhook verification error: {e}")
|
||||
return None
|
||||
|
||||
# Manual HMAC-SHA256 verification fallback
|
||||
# Stripe webhook signatures have the format: t=v1(v1_sig),v0=v0(v0_sig)
|
||||
try:
|
||||
elements = sig_header.split(",")
|
||||
timestamp = None
|
||||
signatures = []
|
||||
for element in elements:
|
||||
key, value = element.split("=", 1)
|
||||
if key == "t":
|
||||
timestamp = value
|
||||
elif key == "v1":
|
||||
signatures.append(value)
|
||||
|
||||
if not timestamp or not signatures:
|
||||
return None
|
||||
|
||||
# Compute expected signature
|
||||
signed_payload = f"{timestamp}.{payload.decode('utf-8')}"
|
||||
expected_sig = hmac.new(
|
||||
settings.stripe_webhook_secret.encode("utf-8"),
|
||||
signed_payload.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
# Check if any v1 signature matches
|
||||
for sig in signatures:
|
||||
if hmac.compare_digest(expected_sig, sig):
|
||||
return json.loads(payload)
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/webhook")
|
||||
async def stripe_webhook(request: Request, db: AsyncSession = Depends(get_db)):
|
||||
"""Handle Stripe webhook events for subscription changes.
|
||||
|
||||
This is the ONLY server-side Stripe code needed for the Payment Links
|
||||
approach. It receives events when:
|
||||
- checkout.session.completed: User completed a checkout
|
||||
- customer.subscription.updated: Subscription tier changed
|
||||
- customer.subscription.deleted: Subscription canceled (downgrade to free)
|
||||
- invoice.payment_succeeded: Recurring payment succeeded
|
||||
- invoice.payment_failed: Payment failed (may downgrade)
|
||||
|
||||
The webhook verifies the Stripe signature to ensure authenticity.
|
||||
"""
|
||||
body = await request.body()
|
||||
sig_header = request.headers.get("stripe-signature", "")
|
||||
|
||||
# Verify the webhook signature
|
||||
event = await _verify_stripe_signature(body, sig_header)
|
||||
if event is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid webhook signature",
|
||||
)
|
||||
|
||||
event_type = event.get("type", "")
|
||||
|
||||
# Only handle known event types
|
||||
if event_type not in HANDLED_EVENTS:
|
||||
logger.info(f"Ignoring unhandled Stripe event type: {event_type}")
|
||||
return WebhookResponse(status="ignored", event_type=event_type)
|
||||
|
||||
# ── Handle checkout.session.completed ────────────────────────────────
|
||||
if event_type == "checkout.session.completed":
|
||||
session_obj = event.get("data", {}).get("object", {})
|
||||
org_id = session_obj.get("client_reference_id")
|
||||
customer_id = session_obj.get("customer")
|
||||
amount_total = session_obj.get("amount_total", 0)
|
||||
|
||||
if not org_id:
|
||||
logger.warning("checkout.session.completed event missing client_reference_id")
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
# Find the organization
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if not org:
|
||||
logger.warning(f"Organization {org_id} not found for Stripe checkout")
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
# Determine the tier from the payment amount
|
||||
new_tier = _determine_tier_from_amount(amount_total)
|
||||
|
||||
# Try to determine tier from line items if available
|
||||
line_items = session_obj.get("line_items", {}).get("data", [])
|
||||
if line_items:
|
||||
price_id = line_items[0].get("price", {}).get("id", "")
|
||||
if price_id:
|
||||
new_tier = _determine_tier_from_price_id(price_id)
|
||||
|
||||
# Update the organization
|
||||
old_tier = org.tier
|
||||
org.tier = new_tier
|
||||
if customer_id:
|
||||
org.stripe_customer_id = customer_id
|
||||
|
||||
await db.flush()
|
||||
logger.info(
|
||||
f"Organization {org.slug} upgraded from {old_tier} to {new_tier} "
|
||||
f"(Stripe session: {session_obj.get('id', 'unknown')})"
|
||||
)
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# ── Handle customer.subscription.updated ─────────────────────────────
|
||||
if event_type == "customer.subscription.updated":
|
||||
subscription = event.get("data", {}).get("object", {})
|
||||
customer_id = subscription.get("customer")
|
||||
|
||||
if not customer_id:
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
result = await db.execute(
|
||||
select(Organization).where(
|
||||
Organization.stripe_customer_id == customer_id
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if not org:
|
||||
logger.warning(f"No org found for Stripe customer {customer_id}")
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
# Determine tier from subscription metadata or amount
|
||||
metadata = subscription.get("metadata", {})
|
||||
tier_from_meta = metadata.get("tier")
|
||||
|
||||
# Check plan amount from items
|
||||
items = subscription.get("items", {}).get("data", [])
|
||||
if items:
|
||||
amount = items[0].get("plan", {}).get("amount", 0)
|
||||
new_tier = _determine_tier_from_amount(amount)
|
||||
elif tier_from_meta in ("pro", "team"):
|
||||
new_tier = tier_from_meta
|
||||
else:
|
||||
new_tier = org.tier # Keep current if indeterminate
|
||||
|
||||
old_tier = org.tier
|
||||
org.tier = new_tier
|
||||
await db.flush()
|
||||
logger.info(f"Organization {org.slug} subscription updated: {old_tier} → {new_tier}")
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# ── Handle customer.subscription.deleted ──────────────────────────────
|
||||
if event_type == "customer.subscription.deleted":
|
||||
subscription = event.get("data", {}).get("object", {})
|
||||
customer_id = subscription.get("customer")
|
||||
|
||||
if not customer_id:
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
result = await db.execute(
|
||||
select(Organization).where(
|
||||
Organization.stripe_customer_id == customer_id
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if not org:
|
||||
logger.warning(f"No org found for Stripe customer {customer_id}")
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
old_tier = org.tier
|
||||
org.tier = "free"
|
||||
org.stripe_customer_id = None # Clear stripe ID on cancellation
|
||||
await db.flush()
|
||||
logger.info(f"Organization {org.slug} downgraded from {old_tier} to free (subscription deleted)")
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# ── Handle invoice.payment_succeeded ─────────────────────────────────
|
||||
if event_type == "invoice.payment_succeeded":
|
||||
# Payment renewed successfully — org keeps its tier
|
||||
# We could update stripe_customer_id here if needed
|
||||
invoice = event.get("data", {}).get("object", {})
|
||||
customer_id = invoice.get("customer")
|
||||
|
||||
if customer_id:
|
||||
result = await db.execute(
|
||||
select(Organization).where(
|
||||
Organization.stripe_customer_id == customer_id
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if org:
|
||||
logger.info(f"Invoice payment succeeded for org {org.slug} (tier: {org.tier})")
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# ── Handle invoice.payment_failed ────────────────────────────────────
|
||||
if event_type == "invoice.payment_failed":
|
||||
invoice = event.get("data", {}).get("object", {})
|
||||
customer_id = invoice.get("customer")
|
||||
|
||||
if customer_id:
|
||||
result = await db.execute(
|
||||
select(Organization).where(
|
||||
Organization.stripe_customer_id == customer_id
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if org:
|
||||
logger.warning(
|
||||
f"Invoice payment failed for org {org.slug} (tier: {org.tier}). "
|
||||
"Payment retry will be attempted by Stripe."
|
||||
)
|
||||
# Don't downgrade yet — Stripe will retry per its configured rules
|
||||
# After all retries fail, customer.subscription.deleted will fire
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# Fallback (shouldn't reach here due to HANDLED_EVENTS check)
|
||||
return WebhookResponse(status="ignored", event_type=event_type)
|
||||
|
|
@ -6,12 +6,14 @@ from app.api.monitors import router as monitors_router
|
|||
from app.api.subscribers import router as subscribers_router
|
||||
from app.api.settings import router as settings_router
|
||||
from app.api.organizations import router as organizations_router
|
||||
from app.api.billing import router as billing_router
|
||||
from app.routes.auth import router as auth_router
|
||||
|
||||
api_v1_router = APIRouter()
|
||||
|
||||
api_v1_router.include_router(auth_router, tags=["auth"])
|
||||
api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"])
|
||||
api_v1_router.include_router(billing_router, prefix="/billing", tags=["billing"])
|
||||
api_v1_router.include_router(services_router, prefix="/services", tags=["services"])
|
||||
api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"])
|
||||
api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue