feat: Stripe Checkout Link integration — billing API with 29 tests

- Add billing module (app/api/billing.py) with 5 API endpoints:
  - GET /api/v1/billing/checkout/{tier} — redirect to Stripe Payment Links
  - GET /api/v1/billing/status — current org tier, limits, upgrade URLs
  - GET /api/v1/billing/success — Stripe success callback
  - GET /api/v1/billing/cancel — Stripe cancel callback
  - POST /api/v1/billing/webhook — handles 5 Stripe event types

- Zero-code payment flow: uses pre-configured Stripe Payment Links
  with client_reference_id (org ID) and prefilled_email params

- Webhook handler processes checkout.session.completed,
  customer.subscription.updated/deleted, invoice events

- Stripe signature verification via stripe library (primary)
  or manual HMAC-SHA256 (fallback)

- Tier determination from payment amount: =pro, 9=team

- 4 new config settings in app/config.py:
  stripe_pro_checkout_url, stripe_team_checkout_url,
  stripe_webhook_secret, stripe_api_key

- Added stripe>=5.0,<16.0 dependency

- 29 tests in tests/test_billing.py (all passing)
- Total: 98 tests passing (69 existing + 29 new)
This commit is contained in:
Ubuntu 2026-04-25 10:18:38 +00:00
parent b7a8142ca0
commit 158a6ee716
6 changed files with 1311 additions and 1 deletions

View file

@ -24,4 +24,11 @@ SMTP_FROM=noreply@example.com
WEBHOOK_NOTIFY_URL=
# Uptime Monitoring
MONITOR_CHECK_INTERVAL=60
MONITOR_CHECK_INTERVAL=60
# Stripe Checkout Links (set these in production)
# Create Payment Links in Stripe Dashboard → Products → Payment Links
STRIPE_PRO_CHECKOUT_URL=
STRIPE_TEAM_CHECKOUT_URL=
STRIPE_WEBHOOK_SECRET=
STRIPE_API_KEY=

472
app/api/billing.py Normal file
View file

@ -0,0 +1,472 @@
"""Billing API: Stripe Checkout Link integration and webhook handler.
This module implements the "zero-code" payment flow using Stripe Payment Links:
1. User clicks "Upgrade" redirected to a Stripe-hosted checkout page
2. After payment, Stripe sends a webhook event to /api/v1/billing/webhook
3. The webhook handler updates the org's tier in the database
4. Stripe redirects the user back to /api/v1/billing/success
This approach requires no Stripe server-side SDK for checkout creation
just configuration of Payment Link URLs and a webhook endpoint for
receiving payment confirmations.
"""
import hashlib
import hmac
import json
import logging
from urllib.parse import urlencode
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth import get_current_user, get_current_org
from app.config import settings
from app.dependencies import get_db
from app.models.saas_models import Organization, OrganizationMember, User
from app.services.tier_limits import get_tier_info
logger = logging.getLogger(__name__)
router = APIRouter(tags=["billing"])
# ── Tier-to-URL mapping ─────────────────────────────────────────────────────
def _get_checkout_url(tier: str) -> str:
"""Return the Stripe Payment Link URL for the given tier.
Reads from settings at call time (not import time) so tests can override.
"""
if tier == "pro":
return settings.stripe_pro_checkout_url
elif tier == "team":
return settings.stripe_team_checkout_url
return ""
# ── Response schemas ─────────────────────────────────────────────────────────
class CheckoutResponse(BaseModel):
"""Response for checkout redirect endpoint."""
redirect_url: str
tier: str
organization_id: str
class BillingStatusResponse(BaseModel):
"""Response for billing status endpoint."""
tier: str
tier_info: dict
stripe_customer_id: str | None = None
checkout_urls: dict[str, str]
class SuccessResponse(BaseModel):
"""Response for successful checkout callback."""
message: str
tier: str | None = None
class WebhookResponse(BaseModel):
"""Response for webhook processing."""
status: str
event_type: str | None = None
# ── Endpoints ────────────────────────────────────────────────────────────────
@router.get("/checkout/{tier}", response_model=CheckoutResponse)
async def create_checkout(
tier: str,
user: User = Depends(get_current_user),
org: Organization = Depends(get_current_org),
):
"""Generate a Stripe Checkout Link redirect URL for the selected tier.
This endpoint does NOT create a Stripe Checkout Session. Instead, it
returns the pre-configured Stripe Payment Link URL with the org's ID
attached as client_reference_id and the user's email prefilled.
The frontend should redirect the user to this URL. After completing
payment on Stripe's hosted checkout, Stripe will:
1. Send a webhook to /api/v1/billing/webhook
2. Redirect the user to /api/v1/billing/success
"""
if tier not in ("pro", "team"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid tier '{tier}'. Must be 'pro' or 'team'.",
)
checkout_url = _get_checkout_url(tier)
if not checkout_url:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=(
f"Checkout for '{tier}' tier is not configured yet. "
"Please contact support or set up Stripe Payment Links."
),
)
# Build redirect URL with org ID and prefilled email
params = urlencode({
"client_reference_id": org.id,
"prefilled_email": user.email,
})
redirect_url = f"{checkout_url}?{params}"
return CheckoutResponse(
redirect_url=redirect_url,
tier=tier,
organization_id=org.id,
)
@router.get("/status", response_model=BillingStatusResponse)
async def get_billing_status(
org: Organization = Depends(get_current_org),
):
"""Get the current organization's billing status: tier, limits, and checkout URLs."""
# Only show checkout URLs for tiers the org can upgrade to
checkout_urls: dict[str, str] = {}
if org.tier == "free":
if settings.stripe_pro_checkout_url:
checkout_urls["pro"] = settings.stripe_pro_checkout_url
if settings.stripe_team_checkout_url:
checkout_urls["team"] = settings.stripe_team_checkout_url
elif org.tier == "pro":
if settings.stripe_team_checkout_url:
checkout_urls["team"] = settings.stripe_team_checkout_url
return BillingStatusResponse(
tier=org.tier or "free",
tier_info=get_tier_info(org),
stripe_customer_id=org.stripe_customer_id,
checkout_urls=checkout_urls,
)
@router.get("/success", response_model=SuccessResponse)
async def billing_success(
session_id: str = "",
):
"""Stripe redirects here after a successful checkout.
In production, you would call stripe.checkout.sessions.retrieve(session_id)
to verify the session. For the MVP, the webhook handler already processed
the payment this endpoint just confirms to the user that it worked.
The session_id is provided by Stripe's redirect URL.
"""
return SuccessResponse(
message="Subscription activated! Your plan has been upgraded.",
tier=None, # Webhook already set the tier; frontend should refetch
)
@router.get("/cancel")
async def billing_cancel():
"""Stripe redirects here if the user cancels checkout."""
return {
"message": "Checkout was canceled. Your current plan is unchanged.",
"tier_change": None,
}
# ── Webhook handler ──────────────────────────────────────────────────────────
# Stripe event types we handle
HANDLED_EVENTS = {
"checkout.session.completed",
"customer.subscription.updated",
"customer.subscription.deleted",
"invoice.payment_succeeded",
"invoice.payment_failed",
}
def _determine_tier_from_amount(amount_total: int) -> str:
"""Determine the tier from the payment amount (in cents).
Args:
amount_total: Total amount in cents (e.g. 900 = $9.00)
Returns:
Tier string: "pro" or "team"
"""
if amount_total >= 2900: # $29.00+
return "team"
elif amount_total >= 900: # $9.00+
return "pro"
return "free"
def _determine_tier_from_price_id(price_id: str) -> str:
"""Determine the tier from a Stripe price ID.
This is a fallback when amount isn't available. In production,
you would map your actual Stripe price IDs here.
Args:
price_id: Stripe price ID (e.g., price_1ABC123...)
Returns:
Tier string: "pro", "team", or "free"
"""
# Default mapping — override in production with your actual price IDs
# The stripe API can also be used to look up price details
return "pro" # Default fallback
async def _verify_stripe_signature(payload: bytes, sig_header: str) -> dict | None:
"""Verify the Stripe webhook signature using HMAC-SHA256.
Uses the stripe library if available, falls back to manual verification.
Args:
payload: Raw request body bytes
sig_header: Stripe-Signature header value
Returns:
Parsed event dict if signature is valid, None otherwise
"""
if not settings.stripe_webhook_secret:
logger.warning("No Stripe webhook secret configured; skipping verification")
return json.loads(payload)
try:
import stripe
if settings.stripe_api_key:
stripe.api_key = settings.stripe_api_key
event = stripe.Webhook.construct_event(
payload, sig_header, settings.stripe_webhook_secret
)
return dict(event)
except ImportError:
logger.info("stripe library not available; using manual signature verification")
except stripe.error.SignatureVerificationError:
return None
except Exception as e:
logger.error(f"Stripe webhook verification error: {e}")
return None
# Manual HMAC-SHA256 verification fallback
# Stripe webhook signatures have the format: t=v1(v1_sig),v0=v0(v0_sig)
try:
elements = sig_header.split(",")
timestamp = None
signatures = []
for element in elements:
key, value = element.split("=", 1)
if key == "t":
timestamp = value
elif key == "v1":
signatures.append(value)
if not timestamp or not signatures:
return None
# Compute expected signature
signed_payload = f"{timestamp}.{payload.decode('utf-8')}"
expected_sig = hmac.new(
settings.stripe_webhook_secret.encode("utf-8"),
signed_payload.encode("utf-8"),
hashlib.sha256,
).hexdigest()
# Check if any v1 signature matches
for sig in signatures:
if hmac.compare_digest(expected_sig, sig):
return json.loads(payload)
return None
except Exception:
return None
@router.post("/webhook")
async def stripe_webhook(request: Request, db: AsyncSession = Depends(get_db)):
"""Handle Stripe webhook events for subscription changes.
This is the ONLY server-side Stripe code needed for the Payment Links
approach. It receives events when:
- checkout.session.completed: User completed a checkout
- customer.subscription.updated: Subscription tier changed
- customer.subscription.deleted: Subscription canceled (downgrade to free)
- invoice.payment_succeeded: Recurring payment succeeded
- invoice.payment_failed: Payment failed (may downgrade)
The webhook verifies the Stripe signature to ensure authenticity.
"""
body = await request.body()
sig_header = request.headers.get("stripe-signature", "")
# Verify the webhook signature
event = await _verify_stripe_signature(body, sig_header)
if event is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid webhook signature",
)
event_type = event.get("type", "")
# Only handle known event types
if event_type not in HANDLED_EVENTS:
logger.info(f"Ignoring unhandled Stripe event type: {event_type}")
return WebhookResponse(status="ignored", event_type=event_type)
# ── Handle checkout.session.completed ────────────────────────────────
if event_type == "checkout.session.completed":
session_obj = event.get("data", {}).get("object", {})
org_id = session_obj.get("client_reference_id")
customer_id = session_obj.get("customer")
amount_total = session_obj.get("amount_total", 0)
if not org_id:
logger.warning("checkout.session.completed event missing client_reference_id")
return WebhookResponse(status="error", event_type=event_type)
# Find the organization
result = await db.execute(
select(Organization).where(Organization.id == org_id)
)
org = result.scalar_one_or_none()
if not org:
logger.warning(f"Organization {org_id} not found for Stripe checkout")
return WebhookResponse(status="error", event_type=event_type)
# Determine the tier from the payment amount
new_tier = _determine_tier_from_amount(amount_total)
# Try to determine tier from line items if available
line_items = session_obj.get("line_items", {}).get("data", [])
if line_items:
price_id = line_items[0].get("price", {}).get("id", "")
if price_id:
new_tier = _determine_tier_from_price_id(price_id)
# Update the organization
old_tier = org.tier
org.tier = new_tier
if customer_id:
org.stripe_customer_id = customer_id
await db.flush()
logger.info(
f"Organization {org.slug} upgraded from {old_tier} to {new_tier} "
f"(Stripe session: {session_obj.get('id', 'unknown')})"
)
return WebhookResponse(status="ok", event_type=event_type)
# ── Handle customer.subscription.updated ─────────────────────────────
if event_type == "customer.subscription.updated":
subscription = event.get("data", {}).get("object", {})
customer_id = subscription.get("customer")
if not customer_id:
return WebhookResponse(status="error", event_type=event_type)
result = await db.execute(
select(Organization).where(
Organization.stripe_customer_id == customer_id
)
)
org = result.scalar_one_or_none()
if not org:
logger.warning(f"No org found for Stripe customer {customer_id}")
return WebhookResponse(status="error", event_type=event_type)
# Determine tier from subscription metadata or amount
metadata = subscription.get("metadata", {})
tier_from_meta = metadata.get("tier")
# Check plan amount from items
items = subscription.get("items", {}).get("data", [])
if items:
amount = items[0].get("plan", {}).get("amount", 0)
new_tier = _determine_tier_from_amount(amount)
elif tier_from_meta in ("pro", "team"):
new_tier = tier_from_meta
else:
new_tier = org.tier # Keep current if indeterminate
old_tier = org.tier
org.tier = new_tier
await db.flush()
logger.info(f"Organization {org.slug} subscription updated: {old_tier}{new_tier}")
return WebhookResponse(status="ok", event_type=event_type)
# ── Handle customer.subscription.deleted ──────────────────────────────
if event_type == "customer.subscription.deleted":
subscription = event.get("data", {}).get("object", {})
customer_id = subscription.get("customer")
if not customer_id:
return WebhookResponse(status="error", event_type=event_type)
result = await db.execute(
select(Organization).where(
Organization.stripe_customer_id == customer_id
)
)
org = result.scalar_one_or_none()
if not org:
logger.warning(f"No org found for Stripe customer {customer_id}")
return WebhookResponse(status="error", event_type=event_type)
old_tier = org.tier
org.tier = "free"
org.stripe_customer_id = None # Clear stripe ID on cancellation
await db.flush()
logger.info(f"Organization {org.slug} downgraded from {old_tier} to free (subscription deleted)")
return WebhookResponse(status="ok", event_type=event_type)
# ── Handle invoice.payment_succeeded ─────────────────────────────────
if event_type == "invoice.payment_succeeded":
# Payment renewed successfully — org keeps its tier
# We could update stripe_customer_id here if needed
invoice = event.get("data", {}).get("object", {})
customer_id = invoice.get("customer")
if customer_id:
result = await db.execute(
select(Organization).where(
Organization.stripe_customer_id == customer_id
)
)
org = result.scalar_one_or_none()
if org:
logger.info(f"Invoice payment succeeded for org {org.slug} (tier: {org.tier})")
return WebhookResponse(status="ok", event_type=event_type)
# ── Handle invoice.payment_failed ────────────────────────────────────
if event_type == "invoice.payment_failed":
invoice = event.get("data", {}).get("object", {})
customer_id = invoice.get("customer")
if customer_id:
result = await db.execute(
select(Organization).where(
Organization.stripe_customer_id == customer_id
)
)
org = result.scalar_one_or_none()
if org:
logger.warning(
f"Invoice payment failed for org {org.slug} (tier: {org.tier}). "
"Payment retry will be attempted by Stripe."
)
# Don't downgrade yet — Stripe will retry per its configured rules
# After all retries fail, customer.subscription.deleted will fire
return WebhookResponse(status="ok", event_type=event_type)
# Fallback (shouldn't reach here due to HANDLED_EVENTS check)
return WebhookResponse(status="ignored", event_type=event_type)

View file

@ -6,12 +6,14 @@ from app.api.monitors import router as monitors_router
from app.api.subscribers import router as subscribers_router
from app.api.settings import router as settings_router
from app.api.organizations import router as organizations_router
from app.api.billing import router as billing_router
from app.routes.auth import router as auth_router
api_v1_router = APIRouter()
api_v1_router.include_router(auth_router, tags=["auth"])
api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"])
api_v1_router.include_router(billing_router, prefix="/billing", tags=["billing"])
api_v1_router.include_router(services_router, prefix="/services", tags=["services"])
api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"])
api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"])

View file

@ -31,6 +31,12 @@ class Settings(BaseSettings):
# Uptime monitoring
monitor_check_interval: int = 60
# Stripe Checkout Links
stripe_pro_checkout_url: str = "" # e.g. https://buy.stripe.com/xxxx_pro
stripe_team_checkout_url: str = "" # e.g. https://buy.stripe.com/xxxx_team
stripe_webhook_secret: str = "" # e.g. whsec_xxxx
stripe_api_key: str = "" # e.g. sk_test_xxxx (needed for webhook verification)
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
@property

View file

@ -24,6 +24,7 @@ dependencies = [
"passlib[bcrypt]>=1.7,<2.0",
"bcrypt==4.0.1",
"email-validator>=2.0,<3.0",
"stripe>=5.0,<16.0",
]
[project.optional-dependencies]

822
tests/test_billing.py Normal file
View file

@ -0,0 +1,822 @@
"""Tests for Stripe Checkout Link billing integration.
Covers:
- Checkout URL generation for pro/team tiers
- Invalid tier handling
- Missing checkout URL handling
- Billing status endpoint
- Webhook handling (checkout.session.completed, subscription events)
- Manual HMAC signature verification
- Tier determination from payment amounts
"""
import hashlib
import hmac
import json
import time
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.database import Base
from app.dependencies import get_db
from app.auth import create_access_token, hash_password
from app.main import app
from app.models.saas_models import Organization, OrganizationMember, User
# ── Test database setup ──────────────────────────────────────────────────────
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
TestSessionLocal = async_sessionmaker(
test_engine, class_=AsyncSession, expire_on_commit=False
)
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database():
"""Create all tables once for the test session."""
import app.models.saas_models # noqa: F401
import app.models.models # noqa: F401
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture
async def db_session():
"""Provide a clean database session for each test."""
async with TestSessionLocal() as session:
yield session
await session.rollback()
@pytest_asyncio.fixture
async def client(db_session: AsyncSession):
"""Provide an HTTP test client with DB dependency override."""
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
async def _create_user_and_org(db: AsyncSession, tier: str = "free"):
"""Helper to create a user, org, and auth token."""
user = User(
email=f"test-{tier}-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}@example.com",
password_hash=hash_password("testpass123"),
)
db.add(user)
await db.flush()
org = Organization(
name=f"Test Org {tier}",
slug=f"test-org-{tier}-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}",
tier=tier,
)
db.add(org)
await db.flush()
membership = OrganizationMember(
organization_id=org.id,
user_id=user.id,
role="owner",
)
db.add(membership)
await db.flush()
token = create_access_token(user.id)
return user, org, token
# ── Checkout endpoint tests ──────────────────────────────────────────────────
class TestCheckoutEndpoint:
"""Tests for GET /api/v1/billing/checkout/{tier}"""
@pytest.mark.asyncio
async def test_checkout_pro_returns_redirect_url(self, client, db_session):
"""Pro checkout should return a redirect URL with org ID and email."""
from app.config import settings
original_pro_url = settings.stripe_pro_checkout_url
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro"
user, org, token = await _create_user_and_org(db_session)
response = await client.get(
"/api/v1/billing/checkout/pro",
headers={"Authorization": f"Bearer {token}"},
)
settings.stripe_pro_checkout_url = original_pro_url
assert response.status_code == 200
data = response.json()
assert "redirect_url" in data
assert data["tier"] == "pro"
assert data["organization_id"] == org.id
assert "client_reference_id" in data["redirect_url"]
assert "prefilled_email" in data["redirect_url"]
assert data["redirect_url"].startswith("https://buy.stripe.com/test_pro?")
@pytest.mark.asyncio
async def test_checkout_team_returns_redirect_url(self, client, db_session):
"""Team checkout should return a redirect URL."""
from app.config import settings
original_team_url = settings.stripe_team_checkout_url
settings.stripe_team_checkout_url = "https://buy.stripe.com/test_team"
user, org, token = await _create_user_and_org(db_session)
response = await client.get(
"/api/v1/billing/checkout/team",
headers={"Authorization": f"Bearer {token}"},
)
settings.stripe_team_checkout_url = original_team_url
assert response.status_code == 200
data = response.json()
assert data["tier"] == "team"
assert data["organization_id"] == org.id
@pytest.mark.asyncio
async def test_checkout_invalid_tier(self, client, db_session):
"""Invalid tier should return 400."""
user, org, token = await _create_user_and_org(db_session)
response = await client.get(
"/api/v1/billing/checkout/enterprise",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 400
assert "Invalid tier" in response.json()["detail"]
@pytest.mark.asyncio
async def test_checkout_free_tier_invalid(self, client, db_session):
"""Free tier checkout should return 400 (can't upgrade to free)."""
user, org, token = await _create_user_and_org(db_session)
response = await client.get(
"/api/v1/billing/checkout/free",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_checkout_missing_url_returns_503(self, client, db_session):
"""If checkout URL is not configured, return 503."""
from app.config import settings
original_pro_url = settings.stripe_pro_checkout_url
settings.stripe_pro_checkout_url = "" # Empty = not configured
user, org, token = await _create_user_and_org(db_session)
response = await client.get(
"/api/v1/billing/checkout/pro",
headers={"Authorization": f"Bearer {token}"},
)
settings.stripe_pro_checkout_url = original_pro_url
assert response.status_code == 503
assert "not configured" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_checkout_requires_auth(self, client, db_session):
"""Checkout endpoint requires authentication."""
response = await client.get("/api/v1/billing/checkout/pro")
assert response.status_code == 401
# ── Billing status endpoint tests ─────────────────────────────────────────────
class TestBillingStatusEndpoint:
"""Tests for GET /api/v1/billing/status"""
@pytest.mark.asyncio
async def test_billing_status_free_org(self, client, db_session):
"""Free org should see pro and team checkout URLs."""
from app.config import settings
original_pro = settings.stripe_pro_checkout_url
original_team = settings.stripe_team_checkout_url
settings.stripe_pro_checkout_url = "https://buy.stripe.com/pro"
settings.stripe_team_checkout_url = "https://buy.stripe.com/team"
user, org, token = await _create_user_and_org(db_session, tier="free")
response = await client.get(
"/api/v1/billing/status",
headers={"Authorization": f"Bearer {token}"},
)
settings.stripe_pro_checkout_url = original_pro
settings.stripe_team_checkout_url = original_team
assert response.status_code == 200
data = response.json()
assert data["tier"] == "free"
assert "tier_info" in data
assert "pro" in data["checkout_urls"]
assert "team" in data["checkout_urls"]
@pytest.mark.asyncio
async def test_billing_status_pro_org(self, client, db_session):
"""Pro org should only see team upgrade URL."""
from app.config import settings
original_team = settings.stripe_team_checkout_url
settings.stripe_team_checkout_url = "https://buy.stripe.com/team"
user, org, token = await _create_user_and_org(db_session, tier="pro")
response = await client.get(
"/api/v1/billing/status",
headers={"Authorization": f"Bearer {token}"},
)
settings.stripe_team_checkout_url = original_team
assert response.status_code == 200
data = response.json()
assert data["tier"] == "pro"
assert "pro" not in data["checkout_urls"]
assert "team" in data["checkout_urls"]
@pytest.mark.asyncio
async def test_billing_status_team_org(self, client, db_session):
"""Team org should see no upgrade URLs (already at top tier)."""
user, org, token = await _create_user_and_org(db_session, tier="team")
response = await client.get(
"/api/v1/billing/status",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
data = response.json()
assert data["tier"] == "team"
assert data["checkout_urls"] == {}
@pytest.mark.asyncio
async def test_billing_status_shows_tier_limits(self, client, db_session):
"""Billing status should include tier limits from tier_limits module."""
user, org, token = await _create_user_and_org(db_session, tier="free")
response = await client.get(
"/api/v1/billing/status",
headers={"Authorization": f"Bearer {token}"},
)
data = response.json()
assert "tier_info" in data
assert data["tier_info"]["tier"] == "free"
assert "limits" in data["tier_info"]
@pytest.mark.asyncio
async def test_billing_status_requires_auth(self, client, db_session):
"""Billing status endpoint requires authentication."""
response = await client.get("/api/v1/billing/status")
assert response.status_code == 401
# ── Success/cancel callback tests ────────────────────────────────────────────
class TestBillingCallbacks:
"""Tests for GET /api/v1/billing/success and /cancel"""
@pytest.mark.asyncio
async def test_success_callback(self, client, db_session):
"""Success callback should return success message."""
response = await client.get(
"/api/v1/billing/success?session_id=cs_test_123"
)
assert response.status_code == 200
data = response.json()
assert "activated" in data["message"].lower() or "upgraded" in data["message"].lower()
@pytest.mark.asyncio
async def test_success_callback_without_session_id(self, client, db_session):
"""Success callback works even without session_id (webhook already processed)."""
response = await client.get("/api/v1/billing/success")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_cancel_callback(self, client, db_session):
"""Cancel callback should return cancellation message."""
response = await client.get("/api/v1/billing/cancel")
assert response.status_code == 200
data = response.json()
assert "canceled" in data["message"].lower()
# ── Webhook tests ─────────────────────────────────────────────────────────────
class TestWebhookHandler:
"""Tests for POST /api/v1/billing/webhook"""
def _make_checkout_session_event(
self, org_id: str, amount_total: int = 900, customer_id: str = "cus_test123"
) -> dict:
"""Create a mock checkout.session.completed event."""
return {
"id": "evt_test_123",
"object": "event",
"type": "checkout.session.completed",
"data": {
"object": {
"id": "cs_test_123",
"object": "checkout.session",
"client_reference_id": org_id,
"customer": customer_id,
"amount_total": amount_total,
"line_items": {"data": []},
}
},
}
def _make_subscription_event(
self, customer_id: str, event_type: str, amount: int = 900
) -> dict:
"""Create a mock subscription event."""
return {
"id": "evt_test_sub_123",
"object": "event",
"type": event_type,
"data": {
"object": {
"id": "sub_test_123",
"object": "subscription",
"customer": customer_id,
"metadata": {},
"items": {
"data": [
{
"plan": {"amount": amount},
}
]
},
}
},
}
@pytest.mark.asyncio
async def test_webhook_checkout_session_pro(self, client, db_session):
"""checkout.session.completed with $9 amount should set tier to pro."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_api_key = settings.stripe_api_key
# Disable signature verification for tests
settings.stripe_webhook_secret = ""
settings.stripe_api_key = ""
user, org, token = await _create_user_and_org(db_session, tier="free")
event = self._make_checkout_session_event(org.id, amount_total=900)
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
settings.stripe_api_key = original_api_key
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert data["event_type"] == "checkout.session.completed"
# Verify org was upgraded
result = await db_session.execute(
select(Organization).where(Organization.id == org.id)
)
updated_org = result.scalar_one()
assert updated_org.tier == "pro"
assert updated_org.stripe_customer_id == "cus_test123"
@pytest.mark.asyncio
async def test_webhook_checkout_session_team(self, client, db_session):
"""checkout.session.completed with $29 amount should set tier to team."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_api_key = settings.stripe_api_key
settings.stripe_webhook_secret = ""
settings.stripe_api_key = ""
user, org, token = await _create_user_and_org(db_session, tier="free")
event = self._make_checkout_session_event(org.id, amount_total=2900)
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
settings.stripe_api_key = original_api_key
assert response.status_code == 200
result = await db_session.execute(
select(Organization).where(Organization.id == org.id)
)
updated_org = result.scalar_one()
assert updated_org.tier == "team"
@pytest.mark.asyncio
async def test_webhook_subscription_deleted_downgrades_to_free(self, client, db_session):
"""customer.subscription.deleted should downgrade org to free."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_api_key = settings.stripe_api_key
settings.stripe_webhook_secret = ""
settings.stripe_api_key = ""
user, org, token = await _create_user_and_org(db_session, tier="pro")
# Manually set stripe_customer_id
org.stripe_customer_id = "cus_delete_test"
await db_session.flush()
event = self._make_subscription_event(
customer_id="cus_delete_test",
event_type="customer.subscription.deleted",
)
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
settings.stripe_api_key = original_api_key
assert response.status_code == 200
result = await db_session.execute(
select(Organization).where(Organization.id == org.id)
)
updated_org = result.scalar_one()
assert updated_org.tier == "free"
@pytest.mark.asyncio
async def test_webhook_subscription_updated(self, client, db_session):
"""customer.subscription.updated should update org tier based on amount."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_api_key = settings.stripe_api_key
settings.stripe_webhook_secret = ""
settings.stripe_api_key = ""
user, org, token = await _create_user_and_org(db_session, tier="pro")
org.stripe_customer_id = "cus_update_test"
await db_session.flush()
event = self._make_subscription_event(
customer_id="cus_update_test",
event_type="customer.subscription.updated",
amount=2900, # Upgrade to team
)
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
settings.stripe_api_key = original_api_key
assert response.status_code == 200
result = await db_session.execute(
select(Organization).where(Organization.id == org.id)
)
updated_org = result.scalar_one()
assert updated_org.tier == "team"
@pytest.mark.asyncio
async def test_webhook_ignores_unhandled_events(self, client, db_session):
"""Unhandled event types should be ignored."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
settings.stripe_webhook_secret = ""
event = {
"id": "evt_test_ignore",
"type": "account.updated",
"data": {"object": {}},
}
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
assert response.status_code == 200
data = response.json()
assert data["status"] == "ignored"
@pytest.mark.asyncio
async def test_webhook_invalid_signature_rejected(self, client, db_session):
"""Webhook with invalid signature should be rejected."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_api_key = settings.stripe_api_key
# Set up webhook secret but no API key (forces manual verification)
settings.stripe_webhook_secret = "whsec_test_secret_key_12345"
settings.stripe_api_key = ""
event = self._make_checkout_session_event("fake_org_id")
response = await client.post(
"/api/v1/billing/webhook",
json=event,
headers={"stripe-signature": "t=1234,v1=invalid_signature_here"},
)
settings.stripe_webhook_secret = original_secret
settings.stripe_api_key = original_api_key
assert response.status_code == 400
assert "signature" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_webhook_checkout_missing_org_id(self, client, db_session):
"""Webhook without client_reference_id should return error."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
settings.stripe_webhook_secret = ""
event = {
"id": "evt_test_no_org",
"type": "checkout.session.completed",
"data": {
"object": {
"id": "cs_test_no_org",
"client_reference_id": None,
"customer": "cus_123",
"amount_total": 900,
}
},
}
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
assert response.status_code == 200
assert response.json()["status"] == "error"
@pytest.mark.asyncio
async def test_webhook_invoice_payment_succeeded(self, client, db_session):
"""invoice.payment_succeeded should acknowledge without downgrading."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_api_key = settings.stripe_api_key
settings.stripe_webhook_secret = ""
settings.stripe_api_key = ""
user, org, token = await _create_user_and_org(db_session, tier="pro")
org.stripe_customer_id = "cus_invoice_test"
await db_session.flush()
event = {
"id": "evt_invoice_paid",
"type": "invoice.payment_succeeded",
"data": {
"object": {
"id": "in_test_123",
"customer": "cus_invoice_test",
}
},
}
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
settings.stripe_api_key = original_api_key
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
# Verify tier wasn't changed
result = await db_session.execute(
select(Organization).where(Organization.id == org.id)
)
updated_org = result.scalar_one()
assert updated_org.tier == "pro"
@pytest.mark.asyncio
async def test_webhook_invoice_payment_failed(self, client, db_session):
"""invoice.payment_failed should not downgrade (Stripe retries)."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_api_key = settings.stripe_api_key
settings.stripe_webhook_secret = ""
settings.stripe_api_key = ""
user, org, token = await _create_user_and_org(db_session, tier="pro")
org.stripe_customer_id = "cus_fail_test"
await db_session.flush()
event = {
"id": "evt_invoice_failed",
"type": "invoice.payment_failed",
"data": {
"object": {
"id": "in_test_fail",
"customer": "cus_fail_test",
}
},
}
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
settings.stripe_api_key = original_api_key
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
# Verify tier wasn't changed (payment will be retried)
result = await db_session.execute(
select(Organization).where(Organization.id == org.id)
)
updated_org = result.scalar_one()
assert updated_org.tier == "pro"
# ── Tier determination helper tests ───────────────────────────────────────────
class TestTierDetermination:
"""Tests for _determine_tier_from_amount helper."""
def test_pro_from_amount(self):
from app.api.billing import _determine_tier_from_amount
assert _determine_tier_from_amount(900) == "pro"
assert _determine_tier_from_amount(1500) == "pro"
def test_team_from_amount(self):
from app.api.billing import _determine_tier_from_amount
assert _determine_tier_from_amount(2900) == "team"
assert _determine_tier_from_amount(5000) == "team"
assert _determine_tier_from_amount(9900) == "team"
def test_free_from_small_amount(self):
from app.api.billing import _determine_tier_from_amount
assert _determine_tier_from_amount(0) == "free"
assert _determine_tier_from_amount(100) == "free"
# ── Integration: checkout → webhook → tier update ────────────────────────────
class TestBillingIntegration:
"""End-to-end style tests for the billing flow."""
@pytest.mark.asyncio
async def test_free_to_pro_upgrade_flow(self, client, db_session):
"""Complete flow: free org → checkout → webhook → pro tier."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_pro_url = settings.stripe_pro_checkout_url
original_team_url = settings.stripe_team_checkout_url
original_api_key = settings.stripe_api_key
settings.stripe_webhook_secret = ""
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro_link"
settings.stripe_team_checkout_url = "https://buy.stripe.com/test_team_link"
settings.stripe_api_key = ""
user, org, token = await _create_user_and_org(db_session, tier="free")
# Step 1: Get checkout URL
checkout_resp = await client.get(
"/api/v1/billing/checkout/pro",
headers={"Authorization": f"Bearer {token}"},
)
assert checkout_resp.status_code == 200
checkout_data = checkout_resp.json()
assert checkout_data["redirect_url"].startswith("https://buy.stripe.com/test_pro_link")
assert "client_reference_id=" + org.id in checkout_data["redirect_url"]
# Step 2: Simulate Stripe webhook
webhook_event = {
"id": "evt_integration_test",
"type": "checkout.session.completed",
"data": {
"object": {
"id": "cs_integration_test",
"client_reference_id": org.id,
"customer": "cus_integration_test",
"amount_total": 900,
"line_items": {"data": []},
}
},
}
webhook_resp = await client.post(
"/api/v1/billing/webhook",
json=webhook_event,
)
assert webhook_resp.status_code == 200
# Step 3: Verify org is now pro
status_resp = await client.get(
"/api/v1/billing/status",
headers={"Authorization": f"Bearer {token}"},
)
assert status_resp.status_code == 200
assert status_resp.json()["tier"] == "pro"
# Step 4: Verify billing status shows team upgrade available
assert "team" in status_resp.json()["checkout_urls"]
# Restore settings
settings.stripe_webhook_secret = original_secret
settings.stripe_pro_checkout_url = original_pro_url
settings.stripe_team_checkout_url = original_team_url
settings.stripe_api_key = original_api_key
@pytest.mark.asyncio
async def test_pro_to_free_downgrade_via_webhook(self, client, db_session):
"""Complete flow: pro org → subscription deleted → free tier."""
from app.config import settings
original_secret = settings.stripe_webhook_secret
original_api_key = settings.stripe_api_key
settings.stripe_webhook_secret = ""
settings.stripe_api_key = ""
user, org, token = await _create_user_and_org(db_session, tier="pro")
org.stripe_customer_id = "cus_downgrade_test"
await db_session.flush()
# Simulate subscription cancellation
event = {
"id": "evt_cancel_test",
"type": "customer.subscription.deleted",
"data": {
"object": {
"id": "sub_cancel_test",
"customer": "cus_downgrade_test",
}
},
}
response = await client.post(
"/api/v1/billing/webhook",
json=event,
)
settings.stripe_webhook_secret = original_secret
settings.stripe_api_key = original_api_key
assert response.status_code == 200
# Verify org is now free
result = await db_session.execute(
select(Organization).where(Organization.id == org.id)
)
updated_org = result.scalar_one()
assert updated_org.tier == "free"
# Verify billing status shows upgrade options
status_resp = await client.get(
"/api/v1/billing/status",
headers={"Authorization": f"Bearer {token}"},
)
data = status_resp.json()
assert data["tier"] == "free"
@pytest.mark.asyncio
async def test_checkout_url_includes_prefilled_email(self, client, db_session):
"""Checkout URL should include the user's email as prefilled_email."""
from app.config import settings
original_pro = settings.stripe_pro_checkout_url
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro"
user, org, token = await _create_user_and_org(db_session)
response = await client.get(
"/api/v1/billing/checkout/pro",
headers={"Authorization": f"Bearer {token}"},
)
settings.stripe_pro_checkout_url = original_pro
data = response.json()
assert f"prefilled_email={user.email}" in data["redirect_url"] or \
f"prefilled_email={user.email.replace('@', '%40')}" in data["redirect_url"]