feat: Stripe Checkout Link integration — billing API with 29 tests
- Add billing module (app/api/billing.py) with 5 API endpoints:
- GET /api/v1/billing/checkout/{tier} — redirect to Stripe Payment Links
- GET /api/v1/billing/status — current org tier, limits, upgrade URLs
- GET /api/v1/billing/success — Stripe success callback
- GET /api/v1/billing/cancel — Stripe cancel callback
- POST /api/v1/billing/webhook — handles 5 Stripe event types
- Zero-code payment flow: uses pre-configured Stripe Payment Links
with client_reference_id (org ID) and prefilled_email params
- Webhook handler processes checkout.session.completed,
customer.subscription.updated/deleted, invoice events
- Stripe signature verification via stripe library (primary)
or manual HMAC-SHA256 (fallback)
- Tier determination from payment amount: =pro, 9=team
- 4 new config settings in app/config.py:
stripe_pro_checkout_url, stripe_team_checkout_url,
stripe_webhook_secret, stripe_api_key
- Added stripe>=5.0,<16.0 dependency
- 29 tests in tests/test_billing.py (all passing)
- Total: 98 tests passing (69 existing + 29 new)
This commit is contained in:
parent
b7a8142ca0
commit
158a6ee716
6 changed files with 1311 additions and 1 deletions
|
|
@ -25,3 +25,10 @@ WEBHOOK_NOTIFY_URL=
|
||||||
|
|
||||||
# Uptime Monitoring
|
# Uptime Monitoring
|
||||||
MONITOR_CHECK_INTERVAL=60
|
MONITOR_CHECK_INTERVAL=60
|
||||||
|
|
||||||
|
# Stripe Checkout Links (set these in production)
|
||||||
|
# Create Payment Links in Stripe Dashboard → Products → Payment Links
|
||||||
|
STRIPE_PRO_CHECKOUT_URL=
|
||||||
|
STRIPE_TEAM_CHECKOUT_URL=
|
||||||
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
STRIPE_API_KEY=
|
||||||
472
app/api/billing.py
Normal file
472
app/api/billing.py
Normal file
|
|
@ -0,0 +1,472 @@
|
||||||
|
"""Billing API: Stripe Checkout Link integration and webhook handler.
|
||||||
|
|
||||||
|
This module implements the "zero-code" payment flow using Stripe Payment Links:
|
||||||
|
1. User clicks "Upgrade" → redirected to a Stripe-hosted checkout page
|
||||||
|
2. After payment, Stripe sends a webhook event to /api/v1/billing/webhook
|
||||||
|
3. The webhook handler updates the org's tier in the database
|
||||||
|
4. Stripe redirects the user back to /api/v1/billing/success
|
||||||
|
|
||||||
|
This approach requires no Stripe server-side SDK for checkout creation —
|
||||||
|
just configuration of Payment Link URLs and a webhook endpoint for
|
||||||
|
receiving payment confirmations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth import get_current_user, get_current_org
|
||||||
|
from app.config import settings
|
||||||
|
from app.dependencies import get_db
|
||||||
|
from app.models.saas_models import Organization, OrganizationMember, User
|
||||||
|
from app.services.tier_limits import get_tier_info
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["billing"])
|
||||||
|
|
||||||
|
# ── Tier-to-URL mapping ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_checkout_url(tier: str) -> str:
|
||||||
|
"""Return the Stripe Payment Link URL for the given tier.
|
||||||
|
|
||||||
|
Reads from settings at call time (not import time) so tests can override.
|
||||||
|
"""
|
||||||
|
if tier == "pro":
|
||||||
|
return settings.stripe_pro_checkout_url
|
||||||
|
elif tier == "team":
|
||||||
|
return settings.stripe_team_checkout_url
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CheckoutResponse(BaseModel):
|
||||||
|
"""Response for checkout redirect endpoint."""
|
||||||
|
redirect_url: str
|
||||||
|
tier: str
|
||||||
|
organization_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class BillingStatusResponse(BaseModel):
|
||||||
|
"""Response for billing status endpoint."""
|
||||||
|
tier: str
|
||||||
|
tier_info: dict
|
||||||
|
stripe_customer_id: str | None = None
|
||||||
|
checkout_urls: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class SuccessResponse(BaseModel):
|
||||||
|
"""Response for successful checkout callback."""
|
||||||
|
message: str
|
||||||
|
tier: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookResponse(BaseModel):
|
||||||
|
"""Response for webhook processing."""
|
||||||
|
status: str
|
||||||
|
event_type: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Endpoints ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/checkout/{tier}", response_model=CheckoutResponse)
|
||||||
|
async def create_checkout(
|
||||||
|
tier: str,
|
||||||
|
user: User = Depends(get_current_user),
|
||||||
|
org: Organization = Depends(get_current_org),
|
||||||
|
):
|
||||||
|
"""Generate a Stripe Checkout Link redirect URL for the selected tier.
|
||||||
|
|
||||||
|
This endpoint does NOT create a Stripe Checkout Session. Instead, it
|
||||||
|
returns the pre-configured Stripe Payment Link URL with the org's ID
|
||||||
|
attached as client_reference_id and the user's email prefilled.
|
||||||
|
|
||||||
|
The frontend should redirect the user to this URL. After completing
|
||||||
|
payment on Stripe's hosted checkout, Stripe will:
|
||||||
|
1. Send a webhook to /api/v1/billing/webhook
|
||||||
|
2. Redirect the user to /api/v1/billing/success
|
||||||
|
"""
|
||||||
|
if tier not in ("pro", "team"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Invalid tier '{tier}'. Must be 'pro' or 'team'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
checkout_url = _get_checkout_url(tier)
|
||||||
|
if not checkout_url:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=(
|
||||||
|
f"Checkout for '{tier}' tier is not configured yet. "
|
||||||
|
"Please contact support or set up Stripe Payment Links."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build redirect URL with org ID and prefilled email
|
||||||
|
params = urlencode({
|
||||||
|
"client_reference_id": org.id,
|
||||||
|
"prefilled_email": user.email,
|
||||||
|
})
|
||||||
|
redirect_url = f"{checkout_url}?{params}"
|
||||||
|
|
||||||
|
return CheckoutResponse(
|
||||||
|
redirect_url=redirect_url,
|
||||||
|
tier=tier,
|
||||||
|
organization_id=org.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status", response_model=BillingStatusResponse)
|
||||||
|
async def get_billing_status(
|
||||||
|
org: Organization = Depends(get_current_org),
|
||||||
|
):
|
||||||
|
"""Get the current organization's billing status: tier, limits, and checkout URLs."""
|
||||||
|
# Only show checkout URLs for tiers the org can upgrade to
|
||||||
|
checkout_urls: dict[str, str] = {}
|
||||||
|
if org.tier == "free":
|
||||||
|
if settings.stripe_pro_checkout_url:
|
||||||
|
checkout_urls["pro"] = settings.stripe_pro_checkout_url
|
||||||
|
if settings.stripe_team_checkout_url:
|
||||||
|
checkout_urls["team"] = settings.stripe_team_checkout_url
|
||||||
|
elif org.tier == "pro":
|
||||||
|
if settings.stripe_team_checkout_url:
|
||||||
|
checkout_urls["team"] = settings.stripe_team_checkout_url
|
||||||
|
|
||||||
|
return BillingStatusResponse(
|
||||||
|
tier=org.tier or "free",
|
||||||
|
tier_info=get_tier_info(org),
|
||||||
|
stripe_customer_id=org.stripe_customer_id,
|
||||||
|
checkout_urls=checkout_urls,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/success", response_model=SuccessResponse)
|
||||||
|
async def billing_success(
|
||||||
|
session_id: str = "",
|
||||||
|
):
|
||||||
|
"""Stripe redirects here after a successful checkout.
|
||||||
|
|
||||||
|
In production, you would call stripe.checkout.sessions.retrieve(session_id)
|
||||||
|
to verify the session. For the MVP, the webhook handler already processed
|
||||||
|
the payment — this endpoint just confirms to the user that it worked.
|
||||||
|
|
||||||
|
The session_id is provided by Stripe's redirect URL.
|
||||||
|
"""
|
||||||
|
return SuccessResponse(
|
||||||
|
message="Subscription activated! Your plan has been upgraded.",
|
||||||
|
tier=None, # Webhook already set the tier; frontend should refetch
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cancel")
|
||||||
|
async def billing_cancel():
|
||||||
|
"""Stripe redirects here if the user cancels checkout."""
|
||||||
|
return {
|
||||||
|
"message": "Checkout was canceled. Your current plan is unchanged.",
|
||||||
|
"tier_change": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Webhook handler ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Stripe event types we handle
|
||||||
|
HANDLED_EVENTS = {
|
||||||
|
"checkout.session.completed",
|
||||||
|
"customer.subscription.updated",
|
||||||
|
"customer.subscription.deleted",
|
||||||
|
"invoice.payment_succeeded",
|
||||||
|
"invoice.payment_failed",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_tier_from_amount(amount_total: int) -> str:
|
||||||
|
"""Determine the tier from the payment amount (in cents).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
amount_total: Total amount in cents (e.g. 900 = $9.00)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tier string: "pro" or "team"
|
||||||
|
"""
|
||||||
|
if amount_total >= 2900: # $29.00+
|
||||||
|
return "team"
|
||||||
|
elif amount_total >= 900: # $9.00+
|
||||||
|
return "pro"
|
||||||
|
return "free"
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_tier_from_price_id(price_id: str) -> str:
|
||||||
|
"""Determine the tier from a Stripe price ID.
|
||||||
|
|
||||||
|
This is a fallback when amount isn't available. In production,
|
||||||
|
you would map your actual Stripe price IDs here.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
price_id: Stripe price ID (e.g., price_1ABC123...)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tier string: "pro", "team", or "free"
|
||||||
|
"""
|
||||||
|
# Default mapping — override in production with your actual price IDs
|
||||||
|
# The stripe API can also be used to look up price details
|
||||||
|
return "pro" # Default fallback
|
||||||
|
|
||||||
|
|
||||||
|
async def _verify_stripe_signature(payload: bytes, sig_header: str) -> dict | None:
|
||||||
|
"""Verify the Stripe webhook signature using HMAC-SHA256.
|
||||||
|
|
||||||
|
Uses the stripe library if available, falls back to manual verification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Raw request body bytes
|
||||||
|
sig_header: Stripe-Signature header value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed event dict if signature is valid, None otherwise
|
||||||
|
"""
|
||||||
|
if not settings.stripe_webhook_secret:
|
||||||
|
logger.warning("No Stripe webhook secret configured; skipping verification")
|
||||||
|
return json.loads(payload)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import stripe
|
||||||
|
if settings.stripe_api_key:
|
||||||
|
stripe.api_key = settings.stripe_api_key
|
||||||
|
event = stripe.Webhook.construct_event(
|
||||||
|
payload, sig_header, settings.stripe_webhook_secret
|
||||||
|
)
|
||||||
|
return dict(event)
|
||||||
|
except ImportError:
|
||||||
|
logger.info("stripe library not available; using manual signature verification")
|
||||||
|
except stripe.error.SignatureVerificationError:
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Stripe webhook verification error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Manual HMAC-SHA256 verification fallback
|
||||||
|
# Stripe webhook signatures have the format: t=v1(v1_sig),v0=v0(v0_sig)
|
||||||
|
try:
|
||||||
|
elements = sig_header.split(",")
|
||||||
|
timestamp = None
|
||||||
|
signatures = []
|
||||||
|
for element in elements:
|
||||||
|
key, value = element.split("=", 1)
|
||||||
|
if key == "t":
|
||||||
|
timestamp = value
|
||||||
|
elif key == "v1":
|
||||||
|
signatures.append(value)
|
||||||
|
|
||||||
|
if not timestamp or not signatures:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Compute expected signature
|
||||||
|
signed_payload = f"{timestamp}.{payload.decode('utf-8')}"
|
||||||
|
expected_sig = hmac.new(
|
||||||
|
settings.stripe_webhook_secret.encode("utf-8"),
|
||||||
|
signed_payload.encode("utf-8"),
|
||||||
|
hashlib.sha256,
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
# Check if any v1 signature matches
|
||||||
|
for sig in signatures:
|
||||||
|
if hmac.compare_digest(expected_sig, sig):
|
||||||
|
return json.loads(payload)
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/webhook")
|
||||||
|
async def stripe_webhook(request: Request, db: AsyncSession = Depends(get_db)):
|
||||||
|
"""Handle Stripe webhook events for subscription changes.
|
||||||
|
|
||||||
|
This is the ONLY server-side Stripe code needed for the Payment Links
|
||||||
|
approach. It receives events when:
|
||||||
|
- checkout.session.completed: User completed a checkout
|
||||||
|
- customer.subscription.updated: Subscription tier changed
|
||||||
|
- customer.subscription.deleted: Subscription canceled (downgrade to free)
|
||||||
|
- invoice.payment_succeeded: Recurring payment succeeded
|
||||||
|
- invoice.payment_failed: Payment failed (may downgrade)
|
||||||
|
|
||||||
|
The webhook verifies the Stripe signature to ensure authenticity.
|
||||||
|
"""
|
||||||
|
body = await request.body()
|
||||||
|
sig_header = request.headers.get("stripe-signature", "")
|
||||||
|
|
||||||
|
# Verify the webhook signature
|
||||||
|
event = await _verify_stripe_signature(body, sig_header)
|
||||||
|
if event is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid webhook signature",
|
||||||
|
)
|
||||||
|
|
||||||
|
event_type = event.get("type", "")
|
||||||
|
|
||||||
|
# Only handle known event types
|
||||||
|
if event_type not in HANDLED_EVENTS:
|
||||||
|
logger.info(f"Ignoring unhandled Stripe event type: {event_type}")
|
||||||
|
return WebhookResponse(status="ignored", event_type=event_type)
|
||||||
|
|
||||||
|
# ── Handle checkout.session.completed ────────────────────────────────
|
||||||
|
if event_type == "checkout.session.completed":
|
||||||
|
session_obj = event.get("data", {}).get("object", {})
|
||||||
|
org_id = session_obj.get("client_reference_id")
|
||||||
|
customer_id = session_obj.get("customer")
|
||||||
|
amount_total = session_obj.get("amount_total", 0)
|
||||||
|
|
||||||
|
if not org_id:
|
||||||
|
logger.warning("checkout.session.completed event missing client_reference_id")
|
||||||
|
return WebhookResponse(status="error", event_type=event_type)
|
||||||
|
|
||||||
|
# Find the organization
|
||||||
|
result = await db.execute(
|
||||||
|
select(Organization).where(Organization.id == org_id)
|
||||||
|
)
|
||||||
|
org = result.scalar_one_or_none()
|
||||||
|
if not org:
|
||||||
|
logger.warning(f"Organization {org_id} not found for Stripe checkout")
|
||||||
|
return WebhookResponse(status="error", event_type=event_type)
|
||||||
|
|
||||||
|
# Determine the tier from the payment amount
|
||||||
|
new_tier = _determine_tier_from_amount(amount_total)
|
||||||
|
|
||||||
|
# Try to determine tier from line items if available
|
||||||
|
line_items = session_obj.get("line_items", {}).get("data", [])
|
||||||
|
if line_items:
|
||||||
|
price_id = line_items[0].get("price", {}).get("id", "")
|
||||||
|
if price_id:
|
||||||
|
new_tier = _determine_tier_from_price_id(price_id)
|
||||||
|
|
||||||
|
# Update the organization
|
||||||
|
old_tier = org.tier
|
||||||
|
org.tier = new_tier
|
||||||
|
if customer_id:
|
||||||
|
org.stripe_customer_id = customer_id
|
||||||
|
|
||||||
|
await db.flush()
|
||||||
|
logger.info(
|
||||||
|
f"Organization {org.slug} upgraded from {old_tier} to {new_tier} "
|
||||||
|
f"(Stripe session: {session_obj.get('id', 'unknown')})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return WebhookResponse(status="ok", event_type=event_type)
|
||||||
|
|
||||||
|
# ── Handle customer.subscription.updated ─────────────────────────────
|
||||||
|
if event_type == "customer.subscription.updated":
|
||||||
|
subscription = event.get("data", {}).get("object", {})
|
||||||
|
customer_id = subscription.get("customer")
|
||||||
|
|
||||||
|
if not customer_id:
|
||||||
|
return WebhookResponse(status="error", event_type=event_type)
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Organization).where(
|
||||||
|
Organization.stripe_customer_id == customer_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
org = result.scalar_one_or_none()
|
||||||
|
if not org:
|
||||||
|
logger.warning(f"No org found for Stripe customer {customer_id}")
|
||||||
|
return WebhookResponse(status="error", event_type=event_type)
|
||||||
|
|
||||||
|
# Determine tier from subscription metadata or amount
|
||||||
|
metadata = subscription.get("metadata", {})
|
||||||
|
tier_from_meta = metadata.get("tier")
|
||||||
|
|
||||||
|
# Check plan amount from items
|
||||||
|
items = subscription.get("items", {}).get("data", [])
|
||||||
|
if items:
|
||||||
|
amount = items[0].get("plan", {}).get("amount", 0)
|
||||||
|
new_tier = _determine_tier_from_amount(amount)
|
||||||
|
elif tier_from_meta in ("pro", "team"):
|
||||||
|
new_tier = tier_from_meta
|
||||||
|
else:
|
||||||
|
new_tier = org.tier # Keep current if indeterminate
|
||||||
|
|
||||||
|
old_tier = org.tier
|
||||||
|
org.tier = new_tier
|
||||||
|
await db.flush()
|
||||||
|
logger.info(f"Organization {org.slug} subscription updated: {old_tier} → {new_tier}")
|
||||||
|
|
||||||
|
return WebhookResponse(status="ok", event_type=event_type)
|
||||||
|
|
||||||
|
# ── Handle customer.subscription.deleted ──────────────────────────────
|
||||||
|
if event_type == "customer.subscription.deleted":
|
||||||
|
subscription = event.get("data", {}).get("object", {})
|
||||||
|
customer_id = subscription.get("customer")
|
||||||
|
|
||||||
|
if not customer_id:
|
||||||
|
return WebhookResponse(status="error", event_type=event_type)
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Organization).where(
|
||||||
|
Organization.stripe_customer_id == customer_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
org = result.scalar_one_or_none()
|
||||||
|
if not org:
|
||||||
|
logger.warning(f"No org found for Stripe customer {customer_id}")
|
||||||
|
return WebhookResponse(status="error", event_type=event_type)
|
||||||
|
|
||||||
|
old_tier = org.tier
|
||||||
|
org.tier = "free"
|
||||||
|
org.stripe_customer_id = None # Clear stripe ID on cancellation
|
||||||
|
await db.flush()
|
||||||
|
logger.info(f"Organization {org.slug} downgraded from {old_tier} to free (subscription deleted)")
|
||||||
|
|
||||||
|
return WebhookResponse(status="ok", event_type=event_type)
|
||||||
|
|
||||||
|
# ── Handle invoice.payment_succeeded ─────────────────────────────────
|
||||||
|
if event_type == "invoice.payment_succeeded":
|
||||||
|
# Payment renewed successfully — org keeps its tier
|
||||||
|
# We could update stripe_customer_id here if needed
|
||||||
|
invoice = event.get("data", {}).get("object", {})
|
||||||
|
customer_id = invoice.get("customer")
|
||||||
|
|
||||||
|
if customer_id:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Organization).where(
|
||||||
|
Organization.stripe_customer_id == customer_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
org = result.scalar_one_or_none()
|
||||||
|
if org:
|
||||||
|
logger.info(f"Invoice payment succeeded for org {org.slug} (tier: {org.tier})")
|
||||||
|
|
||||||
|
return WebhookResponse(status="ok", event_type=event_type)
|
||||||
|
|
||||||
|
# ── Handle invoice.payment_failed ────────────────────────────────────
|
||||||
|
if event_type == "invoice.payment_failed":
|
||||||
|
invoice = event.get("data", {}).get("object", {})
|
||||||
|
customer_id = invoice.get("customer")
|
||||||
|
|
||||||
|
if customer_id:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Organization).where(
|
||||||
|
Organization.stripe_customer_id == customer_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
org = result.scalar_one_or_none()
|
||||||
|
if org:
|
||||||
|
logger.warning(
|
||||||
|
f"Invoice payment failed for org {org.slug} (tier: {org.tier}). "
|
||||||
|
"Payment retry will be attempted by Stripe."
|
||||||
|
)
|
||||||
|
# Don't downgrade yet — Stripe will retry per its configured rules
|
||||||
|
# After all retries fail, customer.subscription.deleted will fire
|
||||||
|
|
||||||
|
return WebhookResponse(status="ok", event_type=event_type)
|
||||||
|
|
||||||
|
# Fallback (shouldn't reach here due to HANDLED_EVENTS check)
|
||||||
|
return WebhookResponse(status="ignored", event_type=event_type)
|
||||||
|
|
@ -6,12 +6,14 @@ from app.api.monitors import router as monitors_router
|
||||||
from app.api.subscribers import router as subscribers_router
|
from app.api.subscribers import router as subscribers_router
|
||||||
from app.api.settings import router as settings_router
|
from app.api.settings import router as settings_router
|
||||||
from app.api.organizations import router as organizations_router
|
from app.api.organizations import router as organizations_router
|
||||||
|
from app.api.billing import router as billing_router
|
||||||
from app.routes.auth import router as auth_router
|
from app.routes.auth import router as auth_router
|
||||||
|
|
||||||
api_v1_router = APIRouter()
|
api_v1_router = APIRouter()
|
||||||
|
|
||||||
api_v1_router.include_router(auth_router, tags=["auth"])
|
api_v1_router.include_router(auth_router, tags=["auth"])
|
||||||
api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"])
|
api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"])
|
||||||
|
api_v1_router.include_router(billing_router, prefix="/billing", tags=["billing"])
|
||||||
api_v1_router.include_router(services_router, prefix="/services", tags=["services"])
|
api_v1_router.include_router(services_router, prefix="/services", tags=["services"])
|
||||||
api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"])
|
api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"])
|
||||||
api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"])
|
api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"])
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,12 @@ class Settings(BaseSettings):
|
||||||
# Uptime monitoring
|
# Uptime monitoring
|
||||||
monitor_check_interval: int = 60
|
monitor_check_interval: int = 60
|
||||||
|
|
||||||
|
# Stripe Checkout Links
|
||||||
|
stripe_pro_checkout_url: str = "" # e.g. https://buy.stripe.com/xxxx_pro
|
||||||
|
stripe_team_checkout_url: str = "" # e.g. https://buy.stripe.com/xxxx_team
|
||||||
|
stripe_webhook_secret: str = "" # e.g. whsec_xxxx
|
||||||
|
stripe_api_key: str = "" # e.g. sk_test_xxxx (needed for webhook verification)
|
||||||
|
|
||||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ dependencies = [
|
||||||
"passlib[bcrypt]>=1.7,<2.0",
|
"passlib[bcrypt]>=1.7,<2.0",
|
||||||
"bcrypt==4.0.1",
|
"bcrypt==4.0.1",
|
||||||
"email-validator>=2.0,<3.0",
|
"email-validator>=2.0,<3.0",
|
||||||
|
"stripe>=5.0,<16.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
|
||||||
822
tests/test_billing.py
Normal file
822
tests/test_billing.py
Normal file
|
|
@ -0,0 +1,822 @@
|
||||||
|
"""Tests for Stripe Checkout Link billing integration.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Checkout URL generation for pro/team tiers
|
||||||
|
- Invalid tier handling
|
||||||
|
- Missing checkout URL handling
|
||||||
|
- Billing status endpoint
|
||||||
|
- Webhook handling (checkout.session.completed, subscription events)
|
||||||
|
- Manual HMAC signature verification
|
||||||
|
- Tier determination from payment amounts
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.dependencies import get_db
|
||||||
|
from app.auth import create_access_token, hash_password
|
||||||
|
from app.main import app
|
||||||
|
from app.models.saas_models import Organization, OrganizationMember, User
|
||||||
|
|
||||||
|
# ── Test database setup ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||||
|
test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||||
|
TestSessionLocal = async_sessionmaker(
|
||||||
|
test_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||||
|
async def setup_database():
|
||||||
|
"""Create all tables once for the test session."""
|
||||||
|
import app.models.saas_models # noqa: F401
|
||||||
|
import app.models.models # noqa: F401
|
||||||
|
|
||||||
|
async with test_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield
|
||||||
|
async with test_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session():
|
||||||
|
"""Provide a clean database session for each test."""
|
||||||
|
async with TestSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
await session.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def client(db_session: AsyncSession):
|
||||||
|
"""Provide an HTTP test client with DB dependency override."""
|
||||||
|
async def override_get_db():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||||
|
yield ac
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_user_and_org(db: AsyncSession, tier: str = "free"):
|
||||||
|
"""Helper to create a user, org, and auth token."""
|
||||||
|
user = User(
|
||||||
|
email=f"test-{tier}-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}@example.com",
|
||||||
|
password_hash=hash_password("testpass123"),
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
org = Organization(
|
||||||
|
name=f"Test Org {tier}",
|
||||||
|
slug=f"test-org-{tier}-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}",
|
||||||
|
tier=tier,
|
||||||
|
)
|
||||||
|
db.add(org)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
membership = OrganizationMember(
|
||||||
|
organization_id=org.id,
|
||||||
|
user_id=user.id,
|
||||||
|
role="owner",
|
||||||
|
)
|
||||||
|
db.add(membership)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
token = create_access_token(user.id)
|
||||||
|
return user, org, token
|
||||||
|
|
||||||
|
|
||||||
|
# ── Checkout endpoint tests ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestCheckoutEndpoint:
|
||||||
|
"""Tests for GET /api/v1/billing/checkout/{tier}"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_pro_returns_redirect_url(self, client, db_session):
|
||||||
|
"""Pro checkout should return a redirect URL with org ID and email."""
|
||||||
|
from app.config import settings
|
||||||
|
original_pro_url = settings.stripe_pro_checkout_url
|
||||||
|
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro"
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/checkout/pro",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_pro_checkout_url = original_pro_url
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "redirect_url" in data
|
||||||
|
assert data["tier"] == "pro"
|
||||||
|
assert data["organization_id"] == org.id
|
||||||
|
assert "client_reference_id" in data["redirect_url"]
|
||||||
|
assert "prefilled_email" in data["redirect_url"]
|
||||||
|
assert data["redirect_url"].startswith("https://buy.stripe.com/test_pro?")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_team_returns_redirect_url(self, client, db_session):
|
||||||
|
"""Team checkout should return a redirect URL."""
|
||||||
|
from app.config import settings
|
||||||
|
original_team_url = settings.stripe_team_checkout_url
|
||||||
|
settings.stripe_team_checkout_url = "https://buy.stripe.com/test_team"
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/checkout/team",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_team_checkout_url = original_team_url
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["tier"] == "team"
|
||||||
|
assert data["organization_id"] == org.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_invalid_tier(self, client, db_session):
|
||||||
|
"""Invalid tier should return 400."""
|
||||||
|
user, org, token = await _create_user_and_org(db_session)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/checkout/enterprise",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Invalid tier" in response.json()["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_free_tier_invalid(self, client, db_session):
|
||||||
|
"""Free tier checkout should return 400 (can't upgrade to free)."""
|
||||||
|
user, org, token = await _create_user_and_org(db_session)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/checkout/free",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_missing_url_returns_503(self, client, db_session):
|
||||||
|
"""If checkout URL is not configured, return 503."""
|
||||||
|
from app.config import settings
|
||||||
|
original_pro_url = settings.stripe_pro_checkout_url
|
||||||
|
settings.stripe_pro_checkout_url = "" # Empty = not configured
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/checkout/pro",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_pro_checkout_url = original_pro_url
|
||||||
|
|
||||||
|
assert response.status_code == 503
|
||||||
|
assert "not configured" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_requires_auth(self, client, db_session):
|
||||||
|
"""Checkout endpoint requires authentication."""
|
||||||
|
response = await client.get("/api/v1/billing/checkout/pro")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── Billing status endpoint tests ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBillingStatusEndpoint:
|
||||||
|
"""Tests for GET /api/v1/billing/status"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billing_status_free_org(self, client, db_session):
|
||||||
|
"""Free org should see pro and team checkout URLs."""
|
||||||
|
from app.config import settings
|
||||||
|
original_pro = settings.stripe_pro_checkout_url
|
||||||
|
original_team = settings.stripe_team_checkout_url
|
||||||
|
settings.stripe_pro_checkout_url = "https://buy.stripe.com/pro"
|
||||||
|
settings.stripe_team_checkout_url = "https://buy.stripe.com/team"
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/status",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_pro_checkout_url = original_pro
|
||||||
|
settings.stripe_team_checkout_url = original_team
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["tier"] == "free"
|
||||||
|
assert "tier_info" in data
|
||||||
|
assert "pro" in data["checkout_urls"]
|
||||||
|
assert "team" in data["checkout_urls"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billing_status_pro_org(self, client, db_session):
|
||||||
|
"""Pro org should only see team upgrade URL."""
|
||||||
|
from app.config import settings
|
||||||
|
original_team = settings.stripe_team_checkout_url
|
||||||
|
settings.stripe_team_checkout_url = "https://buy.stripe.com/team"
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/status",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_team_checkout_url = original_team
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["tier"] == "pro"
|
||||||
|
assert "pro" not in data["checkout_urls"]
|
||||||
|
assert "team" in data["checkout_urls"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billing_status_team_org(self, client, db_session):
|
||||||
|
"""Team org should see no upgrade URLs (already at top tier)."""
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="team")
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/status",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["tier"] == "team"
|
||||||
|
assert data["checkout_urls"] == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billing_status_shows_tier_limits(self, client, db_session):
|
||||||
|
"""Billing status should include tier limits from tier_limits module."""
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/status",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
assert "tier_info" in data
|
||||||
|
assert data["tier_info"]["tier"] == "free"
|
||||||
|
assert "limits" in data["tier_info"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billing_status_requires_auth(self, client, db_session):
|
||||||
|
"""Billing status endpoint requires authentication."""
|
||||||
|
response = await client.get("/api/v1/billing/status")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── Success/cancel callback tests ────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBillingCallbacks:
|
||||||
|
"""Tests for GET /api/v1/billing/success and /cancel"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_callback(self, client, db_session):
|
||||||
|
"""Success callback should return success message."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/success?session_id=cs_test_123"
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "activated" in data["message"].lower() or "upgraded" in data["message"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_callback_without_session_id(self, client, db_session):
|
||||||
|
"""Success callback works even without session_id (webhook already processed)."""
|
||||||
|
response = await client.get("/api/v1/billing/success")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_callback(self, client, db_session):
|
||||||
|
"""Cancel callback should return cancellation message."""
|
||||||
|
response = await client.get("/api/v1/billing/cancel")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "canceled" in data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Webhook tests ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestWebhookHandler:
|
||||||
|
"""Tests for POST /api/v1/billing/webhook"""
|
||||||
|
|
||||||
|
def _make_checkout_session_event(
|
||||||
|
self, org_id: str, amount_total: int = 900, customer_id: str = "cus_test123"
|
||||||
|
) -> dict:
|
||||||
|
"""Create a mock checkout.session.completed event."""
|
||||||
|
return {
|
||||||
|
"id": "evt_test_123",
|
||||||
|
"object": "event",
|
||||||
|
"type": "checkout.session.completed",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "cs_test_123",
|
||||||
|
"object": "checkout.session",
|
||||||
|
"client_reference_id": org_id,
|
||||||
|
"customer": customer_id,
|
||||||
|
"amount_total": amount_total,
|
||||||
|
"line_items": {"data": []},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _make_subscription_event(
|
||||||
|
self, customer_id: str, event_type: str, amount: int = 900
|
||||||
|
) -> dict:
|
||||||
|
"""Create a mock subscription event."""
|
||||||
|
return {
|
||||||
|
"id": "evt_test_sub_123",
|
||||||
|
"object": "event",
|
||||||
|
"type": event_type,
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "sub_test_123",
|
||||||
|
"object": "subscription",
|
||||||
|
"customer": customer_id,
|
||||||
|
"metadata": {},
|
||||||
|
"items": {
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"plan": {"amount": amount},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_checkout_session_pro(self, client, db_session):
|
||||||
|
"""checkout.session.completed with $9 amount should set tier to pro."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
# Disable signature verification for tests
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||||
|
|
||||||
|
event = self._make_checkout_session_event(org.id, amount_total=900)
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert data["event_type"] == "checkout.session.completed"
|
||||||
|
|
||||||
|
# Verify org was upgraded
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Organization).where(Organization.id == org.id)
|
||||||
|
)
|
||||||
|
updated_org = result.scalar_one()
|
||||||
|
assert updated_org.tier == "pro"
|
||||||
|
assert updated_org.stripe_customer_id == "cus_test123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_checkout_session_team(self, client, db_session):
|
||||||
|
"""checkout.session.completed with $29 amount should set tier to team."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||||
|
|
||||||
|
event = self._make_checkout_session_event(org.id, amount_total=2900)
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Organization).where(Organization.id == org.id)
|
||||||
|
)
|
||||||
|
updated_org = result.scalar_one()
|
||||||
|
assert updated_org.tier == "team"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_subscription_deleted_downgrades_to_free(self, client, db_session):
|
||||||
|
"""customer.subscription.deleted should downgrade org to free."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||||
|
# Manually set stripe_customer_id
|
||||||
|
org.stripe_customer_id = "cus_delete_test"
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
event = self._make_subscription_event(
|
||||||
|
customer_id="cus_delete_test",
|
||||||
|
event_type="customer.subscription.deleted",
|
||||||
|
)
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Organization).where(Organization.id == org.id)
|
||||||
|
)
|
||||||
|
updated_org = result.scalar_one()
|
||||||
|
assert updated_org.tier == "free"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_subscription_updated(self, client, db_session):
|
||||||
|
"""customer.subscription.updated should update org tier based on amount."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||||
|
org.stripe_customer_id = "cus_update_test"
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
event = self._make_subscription_event(
|
||||||
|
customer_id="cus_update_test",
|
||||||
|
event_type="customer.subscription.updated",
|
||||||
|
amount=2900, # Upgrade to team
|
||||||
|
)
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Organization).where(Organization.id == org.id)
|
||||||
|
)
|
||||||
|
updated_org = result.scalar_one()
|
||||||
|
assert updated_org.tier == "team"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_ignores_unhandled_events(self, client, db_session):
|
||||||
|
"""Unhandled event types should be ignored."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"id": "evt_test_ignore",
|
||||||
|
"type": "account.updated",
|
||||||
|
"data": {"object": {}},
|
||||||
|
}
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "ignored"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_invalid_signature_rejected(self, client, db_session):
|
||||||
|
"""Webhook with invalid signature should be rejected."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
# Set up webhook secret but no API key (forces manual verification)
|
||||||
|
settings.stripe_webhook_secret = "whsec_test_secret_key_12345"
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
event = self._make_checkout_session_event("fake_org_id")
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
headers={"stripe-signature": "t=1234,v1=invalid_signature_here"},
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "signature" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_checkout_missing_org_id(self, client, db_session):
|
||||||
|
"""Webhook without client_reference_id should return error."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"id": "evt_test_no_org",
|
||||||
|
"type": "checkout.session.completed",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "cs_test_no_org",
|
||||||
|
"client_reference_id": None,
|
||||||
|
"customer": "cus_123",
|
||||||
|
"amount_total": 900,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["status"] == "error"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_invoice_payment_succeeded(self, client, db_session):
|
||||||
|
"""invoice.payment_succeeded should acknowledge without downgrading."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||||
|
org.stripe_customer_id = "cus_invoice_test"
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"id": "evt_invoice_paid",
|
||||||
|
"type": "invoice.payment_succeeded",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "in_test_123",
|
||||||
|
"customer": "cus_invoice_test",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
|
||||||
|
# Verify tier wasn't changed
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Organization).where(Organization.id == org.id)
|
||||||
|
)
|
||||||
|
updated_org = result.scalar_one()
|
||||||
|
assert updated_org.tier == "pro"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_invoice_payment_failed(self, client, db_session):
|
||||||
|
"""invoice.payment_failed should not downgrade (Stripe retries)."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||||
|
org.stripe_customer_id = "cus_fail_test"
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"id": "evt_invoice_failed",
|
||||||
|
"type": "invoice.payment_failed",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "in_test_fail",
|
||||||
|
"customer": "cus_fail_test",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
|
||||||
|
# Verify tier wasn't changed (payment will be retried)
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Organization).where(Organization.id == org.id)
|
||||||
|
)
|
||||||
|
updated_org = result.scalar_one()
|
||||||
|
assert updated_org.tier == "pro"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier determination helper tests ───────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestTierDetermination:
|
||||||
|
"""Tests for _determine_tier_from_amount helper."""
|
||||||
|
|
||||||
|
def test_pro_from_amount(self):
|
||||||
|
from app.api.billing import _determine_tier_from_amount
|
||||||
|
assert _determine_tier_from_amount(900) == "pro"
|
||||||
|
assert _determine_tier_from_amount(1500) == "pro"
|
||||||
|
|
||||||
|
def test_team_from_amount(self):
|
||||||
|
from app.api.billing import _determine_tier_from_amount
|
||||||
|
assert _determine_tier_from_amount(2900) == "team"
|
||||||
|
assert _determine_tier_from_amount(5000) == "team"
|
||||||
|
assert _determine_tier_from_amount(9900) == "team"
|
||||||
|
|
||||||
|
def test_free_from_small_amount(self):
|
||||||
|
from app.api.billing import _determine_tier_from_amount
|
||||||
|
assert _determine_tier_from_amount(0) == "free"
|
||||||
|
assert _determine_tier_from_amount(100) == "free"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Integration: checkout → webhook → tier update ────────────────────────────
|
||||||
|
|
||||||
|
class TestBillingIntegration:
|
||||||
|
"""End-to-end style tests for the billing flow."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_free_to_pro_upgrade_flow(self, client, db_session):
|
||||||
|
"""Complete flow: free org → checkout → webhook → pro tier."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_pro_url = settings.stripe_pro_checkout_url
|
||||||
|
original_team_url = settings.stripe_team_checkout_url
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro_link"
|
||||||
|
settings.stripe_team_checkout_url = "https://buy.stripe.com/test_team_link"
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||||
|
|
||||||
|
# Step 1: Get checkout URL
|
||||||
|
checkout_resp = await client.get(
|
||||||
|
"/api/v1/billing/checkout/pro",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert checkout_resp.status_code == 200
|
||||||
|
checkout_data = checkout_resp.json()
|
||||||
|
assert checkout_data["redirect_url"].startswith("https://buy.stripe.com/test_pro_link")
|
||||||
|
assert "client_reference_id=" + org.id in checkout_data["redirect_url"]
|
||||||
|
|
||||||
|
# Step 2: Simulate Stripe webhook
|
||||||
|
webhook_event = {
|
||||||
|
"id": "evt_integration_test",
|
||||||
|
"type": "checkout.session.completed",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "cs_integration_test",
|
||||||
|
"client_reference_id": org.id,
|
||||||
|
"customer": "cus_integration_test",
|
||||||
|
"amount_total": 900,
|
||||||
|
"line_items": {"data": []},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
webhook_resp = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=webhook_event,
|
||||||
|
)
|
||||||
|
assert webhook_resp.status_code == 200
|
||||||
|
|
||||||
|
# Step 3: Verify org is now pro
|
||||||
|
status_resp = await client.get(
|
||||||
|
"/api/v1/billing/status",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert status_resp.status_code == 200
|
||||||
|
assert status_resp.json()["tier"] == "pro"
|
||||||
|
|
||||||
|
# Step 4: Verify billing status shows team upgrade available
|
||||||
|
assert "team" in status_resp.json()["checkout_urls"]
|
||||||
|
|
||||||
|
# Restore settings
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_pro_checkout_url = original_pro_url
|
||||||
|
settings.stripe_team_checkout_url = original_team_url
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pro_to_free_downgrade_via_webhook(self, client, db_session):
|
||||||
|
"""Complete flow: pro org → subscription deleted → free tier."""
|
||||||
|
from app.config import settings
|
||||||
|
original_secret = settings.stripe_webhook_secret
|
||||||
|
original_api_key = settings.stripe_api_key
|
||||||
|
settings.stripe_webhook_secret = ""
|
||||||
|
settings.stripe_api_key = ""
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||||
|
org.stripe_customer_id = "cus_downgrade_test"
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
# Simulate subscription cancellation
|
||||||
|
event = {
|
||||||
|
"id": "evt_cancel_test",
|
||||||
|
"type": "customer.subscription.deleted",
|
||||||
|
"data": {
|
||||||
|
"object": {
|
||||||
|
"id": "sub_cancel_test",
|
||||||
|
"customer": "cus_downgrade_test",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
json=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_webhook_secret = original_secret
|
||||||
|
settings.stripe_api_key = original_api_key
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify org is now free
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Organization).where(Organization.id == org.id)
|
||||||
|
)
|
||||||
|
updated_org = result.scalar_one()
|
||||||
|
assert updated_org.tier == "free"
|
||||||
|
|
||||||
|
# Verify billing status shows upgrade options
|
||||||
|
status_resp = await client.get(
|
||||||
|
"/api/v1/billing/status",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
data = status_resp.json()
|
||||||
|
assert data["tier"] == "free"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_url_includes_prefilled_email(self, client, db_session):
|
||||||
|
"""Checkout URL should include the user's email as prefilled_email."""
|
||||||
|
from app.config import settings
|
||||||
|
original_pro = settings.stripe_pro_checkout_url
|
||||||
|
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro"
|
||||||
|
|
||||||
|
user, org, token = await _create_user_and_org(db_session)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/billing/checkout/pro",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.stripe_pro_checkout_url = original_pro
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
assert f"prefilled_email={user.email}" in data["redirect_url"] or \
|
||||||
|
f"prefilled_email={user.email.replace('@', '%40')}" in data["redirect_url"]
|
||||||
Loading…
Add table
Add a link
Reference in a new issue