feat: Stripe Checkout Link integration — billing API with 29 tests
- Add billing module (app/api/billing.py) with 5 API endpoints:
- GET /api/v1/billing/checkout/{tier} — redirect to Stripe Payment Links
- GET /api/v1/billing/status — current org tier, limits, upgrade URLs
- GET /api/v1/billing/success — Stripe success callback
- GET /api/v1/billing/cancel — Stripe cancel callback
- POST /api/v1/billing/webhook — handles 5 Stripe event types
- Zero-code payment flow: uses pre-configured Stripe Payment Links
with client_reference_id (org ID) and prefilled_email params
- Webhook handler processes checkout.session.completed,
customer.subscription.updated/deleted, invoice events
- Stripe signature verification via stripe library (primary)
or manual HMAC-SHA256 (fallback)
- Tier determination from payment amount: =pro, 9=team
- 4 new config settings in app/config.py:
stripe_pro_checkout_url, stripe_team_checkout_url,
stripe_webhook_secret, stripe_api_key
- Added stripe>=5.0,<16.0 dependency
- 29 tests in tests/test_billing.py (all passing)
- Total: 98 tests passing (69 existing + 29 new)
This commit is contained in:
parent
b7a8142ca0
commit
158a6ee716
6 changed files with 1311 additions and 1 deletions
|
|
@ -24,4 +24,11 @@ SMTP_FROM=noreply@example.com
|
|||
WEBHOOK_NOTIFY_URL=
|
||||
|
||||
# Uptime Monitoring
|
||||
MONITOR_CHECK_INTERVAL=60
|
||||
MONITOR_CHECK_INTERVAL=60
|
||||
|
||||
# Stripe Checkout Links (set these in production)
|
||||
# Create Payment Links in Stripe Dashboard → Products → Payment Links
|
||||
STRIPE_PRO_CHECKOUT_URL=
|
||||
STRIPE_TEAM_CHECKOUT_URL=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
STRIPE_API_KEY=
|
||||
472
app/api/billing.py
Normal file
472
app/api/billing.py
Normal file
|
|
@ -0,0 +1,472 @@
|
|||
"""Billing API: Stripe Checkout Link integration and webhook handler.
|
||||
|
||||
This module implements the "zero-code" payment flow using Stripe Payment Links:
|
||||
1. User clicks "Upgrade" → redirected to a Stripe-hosted checkout page
|
||||
2. After payment, Stripe sends a webhook event to /api/v1/billing/webhook
|
||||
3. The webhook handler updates the org's tier in the database
|
||||
4. Stripe redirects the user back to /api/v1/billing/success
|
||||
|
||||
This approach requires no Stripe server-side SDK for checkout creation —
|
||||
just configuration of Payment Link URLs and a webhook endpoint for
|
||||
receiving payment confirmations.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth import get_current_user, get_current_org
|
||||
from app.config import settings
|
||||
from app.dependencies import get_db
|
||||
from app.models.saas_models import Organization, OrganizationMember, User
|
||||
from app.services.tier_limits import get_tier_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["billing"])
|
||||
|
||||
# ── Tier-to-URL mapping ─────────────────────────────────────────────────────
|
||||
|
||||
def _get_checkout_url(tier: str) -> str:
|
||||
"""Return the Stripe Payment Link URL for the given tier.
|
||||
|
||||
Reads from settings at call time (not import time) so tests can override.
|
||||
"""
|
||||
if tier == "pro":
|
||||
return settings.stripe_pro_checkout_url
|
||||
elif tier == "team":
|
||||
return settings.stripe_team_checkout_url
|
||||
return ""
|
||||
|
||||
|
||||
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class CheckoutResponse(BaseModel):
|
||||
"""Response for checkout redirect endpoint."""
|
||||
redirect_url: str
|
||||
tier: str
|
||||
organization_id: str
|
||||
|
||||
|
||||
class BillingStatusResponse(BaseModel):
|
||||
"""Response for billing status endpoint."""
|
||||
tier: str
|
||||
tier_info: dict
|
||||
stripe_customer_id: str | None = None
|
||||
checkout_urls: dict[str, str]
|
||||
|
||||
|
||||
class SuccessResponse(BaseModel):
|
||||
"""Response for successful checkout callback."""
|
||||
message: str
|
||||
tier: str | None = None
|
||||
|
||||
|
||||
class WebhookResponse(BaseModel):
|
||||
"""Response for webhook processing."""
|
||||
status: str
|
||||
event_type: str | None = None
|
||||
|
||||
|
||||
# ── Endpoints ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/checkout/{tier}", response_model=CheckoutResponse)
|
||||
async def create_checkout(
|
||||
tier: str,
|
||||
user: User = Depends(get_current_user),
|
||||
org: Organization = Depends(get_current_org),
|
||||
):
|
||||
"""Generate a Stripe Checkout Link redirect URL for the selected tier.
|
||||
|
||||
This endpoint does NOT create a Stripe Checkout Session. Instead, it
|
||||
returns the pre-configured Stripe Payment Link URL with the org's ID
|
||||
attached as client_reference_id and the user's email prefilled.
|
||||
|
||||
The frontend should redirect the user to this URL. After completing
|
||||
payment on Stripe's hosted checkout, Stripe will:
|
||||
1. Send a webhook to /api/v1/billing/webhook
|
||||
2. Redirect the user to /api/v1/billing/success
|
||||
"""
|
||||
if tier not in ("pro", "team"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid tier '{tier}'. Must be 'pro' or 'team'.",
|
||||
)
|
||||
|
||||
checkout_url = _get_checkout_url(tier)
|
||||
if not checkout_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=(
|
||||
f"Checkout for '{tier}' tier is not configured yet. "
|
||||
"Please contact support or set up Stripe Payment Links."
|
||||
),
|
||||
)
|
||||
|
||||
# Build redirect URL with org ID and prefilled email
|
||||
params = urlencode({
|
||||
"client_reference_id": org.id,
|
||||
"prefilled_email": user.email,
|
||||
})
|
||||
redirect_url = f"{checkout_url}?{params}"
|
||||
|
||||
return CheckoutResponse(
|
||||
redirect_url=redirect_url,
|
||||
tier=tier,
|
||||
organization_id=org.id,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/status", response_model=BillingStatusResponse)
|
||||
async def get_billing_status(
|
||||
org: Organization = Depends(get_current_org),
|
||||
):
|
||||
"""Get the current organization's billing status: tier, limits, and checkout URLs."""
|
||||
# Only show checkout URLs for tiers the org can upgrade to
|
||||
checkout_urls: dict[str, str] = {}
|
||||
if org.tier == "free":
|
||||
if settings.stripe_pro_checkout_url:
|
||||
checkout_urls["pro"] = settings.stripe_pro_checkout_url
|
||||
if settings.stripe_team_checkout_url:
|
||||
checkout_urls["team"] = settings.stripe_team_checkout_url
|
||||
elif org.tier == "pro":
|
||||
if settings.stripe_team_checkout_url:
|
||||
checkout_urls["team"] = settings.stripe_team_checkout_url
|
||||
|
||||
return BillingStatusResponse(
|
||||
tier=org.tier or "free",
|
||||
tier_info=get_tier_info(org),
|
||||
stripe_customer_id=org.stripe_customer_id,
|
||||
checkout_urls=checkout_urls,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/success", response_model=SuccessResponse)
|
||||
async def billing_success(
|
||||
session_id: str = "",
|
||||
):
|
||||
"""Stripe redirects here after a successful checkout.
|
||||
|
||||
In production, you would call stripe.checkout.sessions.retrieve(session_id)
|
||||
to verify the session. For the MVP, the webhook handler already processed
|
||||
the payment — this endpoint just confirms to the user that it worked.
|
||||
|
||||
The session_id is provided by Stripe's redirect URL.
|
||||
"""
|
||||
return SuccessResponse(
|
||||
message="Subscription activated! Your plan has been upgraded.",
|
||||
tier=None, # Webhook already set the tier; frontend should refetch
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cancel")
|
||||
async def billing_cancel():
|
||||
"""Stripe redirects here if the user cancels checkout."""
|
||||
return {
|
||||
"message": "Checkout was canceled. Your current plan is unchanged.",
|
||||
"tier_change": None,
|
||||
}
|
||||
|
||||
|
||||
# ── Webhook handler ──────────────────────────────────────────────────────────
|
||||
|
||||
# Stripe event types we handle
|
||||
HANDLED_EVENTS = {
|
||||
"checkout.session.completed",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"invoice.payment_succeeded",
|
||||
"invoice.payment_failed",
|
||||
}
|
||||
|
||||
|
||||
def _determine_tier_from_amount(amount_total: int) -> str:
|
||||
"""Determine the tier from the payment amount (in cents).
|
||||
|
||||
Args:
|
||||
amount_total: Total amount in cents (e.g. 900 = $9.00)
|
||||
|
||||
Returns:
|
||||
Tier string: "pro" or "team"
|
||||
"""
|
||||
if amount_total >= 2900: # $29.00+
|
||||
return "team"
|
||||
elif amount_total >= 900: # $9.00+
|
||||
return "pro"
|
||||
return "free"
|
||||
|
||||
|
||||
def _determine_tier_from_price_id(price_id: str) -> str:
|
||||
"""Determine the tier from a Stripe price ID.
|
||||
|
||||
This is a fallback when amount isn't available. In production,
|
||||
you would map your actual Stripe price IDs here.
|
||||
|
||||
Args:
|
||||
price_id: Stripe price ID (e.g., price_1ABC123...)
|
||||
|
||||
Returns:
|
||||
Tier string: "pro", "team", or "free"
|
||||
"""
|
||||
# Default mapping — override in production with your actual price IDs
|
||||
# The stripe API can also be used to look up price details
|
||||
return "pro" # Default fallback
|
||||
|
||||
|
||||
async def _verify_stripe_signature(payload: bytes, sig_header: str) -> dict | None:
|
||||
"""Verify the Stripe webhook signature using HMAC-SHA256.
|
||||
|
||||
Uses the stripe library if available, falls back to manual verification.
|
||||
|
||||
Args:
|
||||
payload: Raw request body bytes
|
||||
sig_header: Stripe-Signature header value
|
||||
|
||||
Returns:
|
||||
Parsed event dict if signature is valid, None otherwise
|
||||
"""
|
||||
if not settings.stripe_webhook_secret:
|
||||
logger.warning("No Stripe webhook secret configured; skipping verification")
|
||||
return json.loads(payload)
|
||||
|
||||
try:
|
||||
import stripe
|
||||
if settings.stripe_api_key:
|
||||
stripe.api_key = settings.stripe_api_key
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.stripe_webhook_secret
|
||||
)
|
||||
return dict(event)
|
||||
except ImportError:
|
||||
logger.info("stripe library not available; using manual signature verification")
|
||||
except stripe.error.SignatureVerificationError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Stripe webhook verification error: {e}")
|
||||
return None
|
||||
|
||||
# Manual HMAC-SHA256 verification fallback
|
||||
# Stripe webhook signatures have the format: t=v1(v1_sig),v0=v0(v0_sig)
|
||||
try:
|
||||
elements = sig_header.split(",")
|
||||
timestamp = None
|
||||
signatures = []
|
||||
for element in elements:
|
||||
key, value = element.split("=", 1)
|
||||
if key == "t":
|
||||
timestamp = value
|
||||
elif key == "v1":
|
||||
signatures.append(value)
|
||||
|
||||
if not timestamp or not signatures:
|
||||
return None
|
||||
|
||||
# Compute expected signature
|
||||
signed_payload = f"{timestamp}.{payload.decode('utf-8')}"
|
||||
expected_sig = hmac.new(
|
||||
settings.stripe_webhook_secret.encode("utf-8"),
|
||||
signed_payload.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
# Check if any v1 signature matches
|
||||
for sig in signatures:
|
||||
if hmac.compare_digest(expected_sig, sig):
|
||||
return json.loads(payload)
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/webhook")
|
||||
async def stripe_webhook(request: Request, db: AsyncSession = Depends(get_db)):
|
||||
"""Handle Stripe webhook events for subscription changes.
|
||||
|
||||
This is the ONLY server-side Stripe code needed for the Payment Links
|
||||
approach. It receives events when:
|
||||
- checkout.session.completed: User completed a checkout
|
||||
- customer.subscription.updated: Subscription tier changed
|
||||
- customer.subscription.deleted: Subscription canceled (downgrade to free)
|
||||
- invoice.payment_succeeded: Recurring payment succeeded
|
||||
- invoice.payment_failed: Payment failed (may downgrade)
|
||||
|
||||
The webhook verifies the Stripe signature to ensure authenticity.
|
||||
"""
|
||||
body = await request.body()
|
||||
sig_header = request.headers.get("stripe-signature", "")
|
||||
|
||||
# Verify the webhook signature
|
||||
event = await _verify_stripe_signature(body, sig_header)
|
||||
if event is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid webhook signature",
|
||||
)
|
||||
|
||||
event_type = event.get("type", "")
|
||||
|
||||
# Only handle known event types
|
||||
if event_type not in HANDLED_EVENTS:
|
||||
logger.info(f"Ignoring unhandled Stripe event type: {event_type}")
|
||||
return WebhookResponse(status="ignored", event_type=event_type)
|
||||
|
||||
# ── Handle checkout.session.completed ────────────────────────────────
|
||||
if event_type == "checkout.session.completed":
|
||||
session_obj = event.get("data", {}).get("object", {})
|
||||
org_id = session_obj.get("client_reference_id")
|
||||
customer_id = session_obj.get("customer")
|
||||
amount_total = session_obj.get("amount_total", 0)
|
||||
|
||||
if not org_id:
|
||||
logger.warning("checkout.session.completed event missing client_reference_id")
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
# Find the organization
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if not org:
|
||||
logger.warning(f"Organization {org_id} not found for Stripe checkout")
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
# Determine the tier from the payment amount
|
||||
new_tier = _determine_tier_from_amount(amount_total)
|
||||
|
||||
# Try to determine tier from line items if available
|
||||
line_items = session_obj.get("line_items", {}).get("data", [])
|
||||
if line_items:
|
||||
price_id = line_items[0].get("price", {}).get("id", "")
|
||||
if price_id:
|
||||
new_tier = _determine_tier_from_price_id(price_id)
|
||||
|
||||
# Update the organization
|
||||
old_tier = org.tier
|
||||
org.tier = new_tier
|
||||
if customer_id:
|
||||
org.stripe_customer_id = customer_id
|
||||
|
||||
await db.flush()
|
||||
logger.info(
|
||||
f"Organization {org.slug} upgraded from {old_tier} to {new_tier} "
|
||||
f"(Stripe session: {session_obj.get('id', 'unknown')})"
|
||||
)
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# ── Handle customer.subscription.updated ─────────────────────────────
|
||||
if event_type == "customer.subscription.updated":
|
||||
subscription = event.get("data", {}).get("object", {})
|
||||
customer_id = subscription.get("customer")
|
||||
|
||||
if not customer_id:
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
result = await db.execute(
|
||||
select(Organization).where(
|
||||
Organization.stripe_customer_id == customer_id
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if not org:
|
||||
logger.warning(f"No org found for Stripe customer {customer_id}")
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
# Determine tier from subscription metadata or amount
|
||||
metadata = subscription.get("metadata", {})
|
||||
tier_from_meta = metadata.get("tier")
|
||||
|
||||
# Check plan amount from items
|
||||
items = subscription.get("items", {}).get("data", [])
|
||||
if items:
|
||||
amount = items[0].get("plan", {}).get("amount", 0)
|
||||
new_tier = _determine_tier_from_amount(amount)
|
||||
elif tier_from_meta in ("pro", "team"):
|
||||
new_tier = tier_from_meta
|
||||
else:
|
||||
new_tier = org.tier # Keep current if indeterminate
|
||||
|
||||
old_tier = org.tier
|
||||
org.tier = new_tier
|
||||
await db.flush()
|
||||
logger.info(f"Organization {org.slug} subscription updated: {old_tier} → {new_tier}")
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# ── Handle customer.subscription.deleted ──────────────────────────────
|
||||
if event_type == "customer.subscription.deleted":
|
||||
subscription = event.get("data", {}).get("object", {})
|
||||
customer_id = subscription.get("customer")
|
||||
|
||||
if not customer_id:
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
result = await db.execute(
|
||||
select(Organization).where(
|
||||
Organization.stripe_customer_id == customer_id
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if not org:
|
||||
logger.warning(f"No org found for Stripe customer {customer_id}")
|
||||
return WebhookResponse(status="error", event_type=event_type)
|
||||
|
||||
old_tier = org.tier
|
||||
org.tier = "free"
|
||||
org.stripe_customer_id = None # Clear stripe ID on cancellation
|
||||
await db.flush()
|
||||
logger.info(f"Organization {org.slug} downgraded from {old_tier} to free (subscription deleted)")
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# ── Handle invoice.payment_succeeded ─────────────────────────────────
|
||||
if event_type == "invoice.payment_succeeded":
|
||||
# Payment renewed successfully — org keeps its tier
|
||||
# We could update stripe_customer_id here if needed
|
||||
invoice = event.get("data", {}).get("object", {})
|
||||
customer_id = invoice.get("customer")
|
||||
|
||||
if customer_id:
|
||||
result = await db.execute(
|
||||
select(Organization).where(
|
||||
Organization.stripe_customer_id == customer_id
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if org:
|
||||
logger.info(f"Invoice payment succeeded for org {org.slug} (tier: {org.tier})")
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# ── Handle invoice.payment_failed ────────────────────────────────────
|
||||
if event_type == "invoice.payment_failed":
|
||||
invoice = event.get("data", {}).get("object", {})
|
||||
customer_id = invoice.get("customer")
|
||||
|
||||
if customer_id:
|
||||
result = await db.execute(
|
||||
select(Organization).where(
|
||||
Organization.stripe_customer_id == customer_id
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if org:
|
||||
logger.warning(
|
||||
f"Invoice payment failed for org {org.slug} (tier: {org.tier}). "
|
||||
"Payment retry will be attempted by Stripe."
|
||||
)
|
||||
# Don't downgrade yet — Stripe will retry per its configured rules
|
||||
# After all retries fail, customer.subscription.deleted will fire
|
||||
|
||||
return WebhookResponse(status="ok", event_type=event_type)
|
||||
|
||||
# Fallback (shouldn't reach here due to HANDLED_EVENTS check)
|
||||
return WebhookResponse(status="ignored", event_type=event_type)
|
||||
|
|
@ -6,12 +6,14 @@ from app.api.monitors import router as monitors_router
|
|||
from app.api.subscribers import router as subscribers_router
|
||||
from app.api.settings import router as settings_router
|
||||
from app.api.organizations import router as organizations_router
|
||||
from app.api.billing import router as billing_router
|
||||
from app.routes.auth import router as auth_router
|
||||
|
||||
api_v1_router = APIRouter()
|
||||
|
||||
api_v1_router.include_router(auth_router, tags=["auth"])
|
||||
api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"])
|
||||
api_v1_router.include_router(billing_router, prefix="/billing", tags=["billing"])
|
||||
api_v1_router.include_router(services_router, prefix="/services", tags=["services"])
|
||||
api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"])
|
||||
api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"])
|
||||
|
|
|
|||
|
|
@ -31,6 +31,12 @@ class Settings(BaseSettings):
|
|||
# Uptime monitoring
|
||||
monitor_check_interval: int = 60
|
||||
|
||||
# Stripe Checkout Links
|
||||
stripe_pro_checkout_url: str = "" # e.g. https://buy.stripe.com/xxxx_pro
|
||||
stripe_team_checkout_url: str = "" # e.g. https://buy.stripe.com/xxxx_team
|
||||
stripe_webhook_secret: str = "" # e.g. whsec_xxxx
|
||||
stripe_api_key: str = "" # e.g. sk_test_xxxx (needed for webhook verification)
|
||||
|
||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ dependencies = [
|
|||
"passlib[bcrypt]>=1.7,<2.0",
|
||||
"bcrypt==4.0.1",
|
||||
"email-validator>=2.0,<3.0",
|
||||
"stripe>=5.0,<16.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
822
tests/test_billing.py
Normal file
822
tests/test_billing.py
Normal file
|
|
@ -0,0 +1,822 @@
|
|||
"""Tests for Stripe Checkout Link billing integration.
|
||||
|
||||
Covers:
|
||||
- Checkout URL generation for pro/team tiers
|
||||
- Invalid tier handling
|
||||
- Missing checkout URL handling
|
||||
- Billing status endpoint
|
||||
- Webhook handling (checkout.session.completed, subscription events)
|
||||
- Manual HMAC signature verification
|
||||
- Tier determination from payment amounts
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.database import Base
|
||||
from app.dependencies import get_db
|
||||
from app.auth import create_access_token, hash_password
|
||||
from app.main import app
|
||||
from app.models.saas_models import Organization, OrganizationMember, User
|
||||
|
||||
# ── Test database setup ──────────────────────────────────────────────────────
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
TestSessionLocal = async_sessionmaker(
|
||||
test_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def setup_database():
|
||||
"""Create all tables once for the test session."""
|
||||
import app.models.saas_models # noqa: F401
|
||||
import app.models.models # noqa: F401
|
||||
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session():
|
||||
"""Provide a clean database session for each test."""
|
||||
async with TestSessionLocal() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(db_session: AsyncSession):
|
||||
"""Provide an HTTP test client with DB dependency override."""
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def _create_user_and_org(db: AsyncSession, tier: str = "free"):
|
||||
"""Helper to create a user, org, and auth token."""
|
||||
user = User(
|
||||
email=f"test-{tier}-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}@example.com",
|
||||
password_hash=hash_password("testpass123"),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
org = Organization(
|
||||
name=f"Test Org {tier}",
|
||||
slug=f"test-org-{tier}-{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}",
|
||||
tier=tier,
|
||||
)
|
||||
db.add(org)
|
||||
await db.flush()
|
||||
|
||||
membership = OrganizationMember(
|
||||
organization_id=org.id,
|
||||
user_id=user.id,
|
||||
role="owner",
|
||||
)
|
||||
db.add(membership)
|
||||
await db.flush()
|
||||
|
||||
token = create_access_token(user.id)
|
||||
return user, org, token
|
||||
|
||||
|
||||
# ── Checkout endpoint tests ──────────────────────────────────────────────────
|
||||
|
||||
class TestCheckoutEndpoint:
|
||||
"""Tests for GET /api/v1/billing/checkout/{tier}"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_pro_returns_redirect_url(self, client, db_session):
|
||||
"""Pro checkout should return a redirect URL with org ID and email."""
|
||||
from app.config import settings
|
||||
original_pro_url = settings.stripe_pro_checkout_url
|
||||
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro"
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session)
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/checkout/pro",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
settings.stripe_pro_checkout_url = original_pro_url
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "redirect_url" in data
|
||||
assert data["tier"] == "pro"
|
||||
assert data["organization_id"] == org.id
|
||||
assert "client_reference_id" in data["redirect_url"]
|
||||
assert "prefilled_email" in data["redirect_url"]
|
||||
assert data["redirect_url"].startswith("https://buy.stripe.com/test_pro?")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_team_returns_redirect_url(self, client, db_session):
|
||||
"""Team checkout should return a redirect URL."""
|
||||
from app.config import settings
|
||||
original_team_url = settings.stripe_team_checkout_url
|
||||
settings.stripe_team_checkout_url = "https://buy.stripe.com/test_team"
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session)
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/checkout/team",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
settings.stripe_team_checkout_url = original_team_url
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "team"
|
||||
assert data["organization_id"] == org.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_invalid_tier(self, client, db_session):
|
||||
"""Invalid tier should return 400."""
|
||||
user, org, token = await _create_user_and_org(db_session)
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/checkout/enterprise",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid tier" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_free_tier_invalid(self, client, db_session):
|
||||
"""Free tier checkout should return 400 (can't upgrade to free)."""
|
||||
user, org, token = await _create_user_and_org(db_session)
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/checkout/free",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_missing_url_returns_503(self, client, db_session):
|
||||
"""If checkout URL is not configured, return 503."""
|
||||
from app.config import settings
|
||||
original_pro_url = settings.stripe_pro_checkout_url
|
||||
settings.stripe_pro_checkout_url = "" # Empty = not configured
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session)
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/checkout/pro",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
settings.stripe_pro_checkout_url = original_pro_url
|
||||
|
||||
assert response.status_code == 503
|
||||
assert "not configured" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_requires_auth(self, client, db_session):
|
||||
"""Checkout endpoint requires authentication."""
|
||||
response = await client.get("/api/v1/billing/checkout/pro")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ── Billing status endpoint tests ─────────────────────────────────────────────
|
||||
|
||||
class TestBillingStatusEndpoint:
|
||||
"""Tests for GET /api/v1/billing/status"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_status_free_org(self, client, db_session):
|
||||
"""Free org should see pro and team checkout URLs."""
|
||||
from app.config import settings
|
||||
original_pro = settings.stripe_pro_checkout_url
|
||||
original_team = settings.stripe_team_checkout_url
|
||||
settings.stripe_pro_checkout_url = "https://buy.stripe.com/pro"
|
||||
settings.stripe_team_checkout_url = "https://buy.stripe.com/team"
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/status",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
settings.stripe_pro_checkout_url = original_pro
|
||||
settings.stripe_team_checkout_url = original_team
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "free"
|
||||
assert "tier_info" in data
|
||||
assert "pro" in data["checkout_urls"]
|
||||
assert "team" in data["checkout_urls"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_status_pro_org(self, client, db_session):
|
||||
"""Pro org should only see team upgrade URL."""
|
||||
from app.config import settings
|
||||
original_team = settings.stripe_team_checkout_url
|
||||
settings.stripe_team_checkout_url = "https://buy.stripe.com/team"
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/status",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
settings.stripe_team_checkout_url = original_team
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "pro"
|
||||
assert "pro" not in data["checkout_urls"]
|
||||
assert "team" in data["checkout_urls"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_status_team_org(self, client, db_session):
|
||||
"""Team org should see no upgrade URLs (already at top tier)."""
|
||||
user, org, token = await _create_user_and_org(db_session, tier="team")
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/status",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "team"
|
||||
assert data["checkout_urls"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_status_shows_tier_limits(self, client, db_session):
|
||||
"""Billing status should include tier limits from tier_limits module."""
|
||||
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/status",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert "tier_info" in data
|
||||
assert data["tier_info"]["tier"] == "free"
|
||||
assert "limits" in data["tier_info"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_status_requires_auth(self, client, db_session):
|
||||
"""Billing status endpoint requires authentication."""
|
||||
response = await client.get("/api/v1/billing/status")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ── Success/cancel callback tests ────────────────────────────────────────────
|
||||
|
||||
class TestBillingCallbacks:
|
||||
"""Tests for GET /api/v1/billing/success and /cancel"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback(self, client, db_session):
|
||||
"""Success callback should return success message."""
|
||||
response = await client.get(
|
||||
"/api/v1/billing/success?session_id=cs_test_123"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "activated" in data["message"].lower() or "upgraded" in data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_without_session_id(self, client, db_session):
|
||||
"""Success callback works even without session_id (webhook already processed)."""
|
||||
response = await client.get("/api/v1/billing/success")
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_callback(self, client, db_session):
|
||||
"""Cancel callback should return cancellation message."""
|
||||
response = await client.get("/api/v1/billing/cancel")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "canceled" in data["message"].lower()
|
||||
|
||||
|
||||
# ── Webhook tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestWebhookHandler:
|
||||
"""Tests for POST /api/v1/billing/webhook"""
|
||||
|
||||
def _make_checkout_session_event(
|
||||
self, org_id: str, amount_total: int = 900, customer_id: str = "cus_test123"
|
||||
) -> dict:
|
||||
"""Create a mock checkout.session.completed event."""
|
||||
return {
|
||||
"id": "evt_test_123",
|
||||
"object": "event",
|
||||
"type": "checkout.session.completed",
|
||||
"data": {
|
||||
"object": {
|
||||
"id": "cs_test_123",
|
||||
"object": "checkout.session",
|
||||
"client_reference_id": org_id,
|
||||
"customer": customer_id,
|
||||
"amount_total": amount_total,
|
||||
"line_items": {"data": []},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def _make_subscription_event(
|
||||
self, customer_id: str, event_type: str, amount: int = 900
|
||||
) -> dict:
|
||||
"""Create a mock subscription event."""
|
||||
return {
|
||||
"id": "evt_test_sub_123",
|
||||
"object": "event",
|
||||
"type": event_type,
|
||||
"data": {
|
||||
"object": {
|
||||
"id": "sub_test_123",
|
||||
"object": "subscription",
|
||||
"customer": customer_id,
|
||||
"metadata": {},
|
||||
"items": {
|
||||
"data": [
|
||||
{
|
||||
"plan": {"amount": amount},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_checkout_session_pro(self, client, db_session):
|
||||
"""checkout.session.completed with $9 amount should set tier to pro."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_api_key = settings.stripe_api_key
|
||||
# Disable signature verification for tests
|
||||
settings.stripe_webhook_secret = ""
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||
|
||||
event = self._make_checkout_session_event(org.id, amount_total=900)
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["event_type"] == "checkout.session.completed"
|
||||
|
||||
# Verify org was upgraded
|
||||
result = await db_session.execute(
|
||||
select(Organization).where(Organization.id == org.id)
|
||||
)
|
||||
updated_org = result.scalar_one()
|
||||
assert updated_org.tier == "pro"
|
||||
assert updated_org.stripe_customer_id == "cus_test123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_checkout_session_team(self, client, db_session):
|
||||
"""checkout.session.completed with $29 amount should set tier to team."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_api_key = settings.stripe_api_key
|
||||
settings.stripe_webhook_secret = ""
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||
|
||||
event = self._make_checkout_session_event(org.id, amount_total=2900)
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
assert response.status_code == 200
|
||||
result = await db_session.execute(
|
||||
select(Organization).where(Organization.id == org.id)
|
||||
)
|
||||
updated_org = result.scalar_one()
|
||||
assert updated_org.tier == "team"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_subscription_deleted_downgrades_to_free(self, client, db_session):
|
||||
"""customer.subscription.deleted should downgrade org to free."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_api_key = settings.stripe_api_key
|
||||
settings.stripe_webhook_secret = ""
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||
# Manually set stripe_customer_id
|
||||
org.stripe_customer_id = "cus_delete_test"
|
||||
await db_session.flush()
|
||||
|
||||
event = self._make_subscription_event(
|
||||
customer_id="cus_delete_test",
|
||||
event_type="customer.subscription.deleted",
|
||||
)
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
assert response.status_code == 200
|
||||
result = await db_session.execute(
|
||||
select(Organization).where(Organization.id == org.id)
|
||||
)
|
||||
updated_org = result.scalar_one()
|
||||
assert updated_org.tier == "free"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_subscription_updated(self, client, db_session):
|
||||
"""customer.subscription.updated should update org tier based on amount."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_api_key = settings.stripe_api_key
|
||||
settings.stripe_webhook_secret = ""
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||
org.stripe_customer_id = "cus_update_test"
|
||||
await db_session.flush()
|
||||
|
||||
event = self._make_subscription_event(
|
||||
customer_id="cus_update_test",
|
||||
event_type="customer.subscription.updated",
|
||||
amount=2900, # Upgrade to team
|
||||
)
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
assert response.status_code == 200
|
||||
result = await db_session.execute(
|
||||
select(Organization).where(Organization.id == org.id)
|
||||
)
|
||||
updated_org = result.scalar_one()
|
||||
assert updated_org.tier == "team"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_ignores_unhandled_events(self, client, db_session):
|
||||
"""Unhandled event types should be ignored."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
settings.stripe_webhook_secret = ""
|
||||
|
||||
event = {
|
||||
"id": "evt_test_ignore",
|
||||
"type": "account.updated",
|
||||
"data": {"object": {}},
|
||||
}
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ignored"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_invalid_signature_rejected(self, client, db_session):
|
||||
"""Webhook with invalid signature should be rejected."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_api_key = settings.stripe_api_key
|
||||
# Set up webhook secret but no API key (forces manual verification)
|
||||
settings.stripe_webhook_secret = "whsec_test_secret_key_12345"
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
event = self._make_checkout_session_event("fake_org_id")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
headers={"stripe-signature": "t=1234,v1=invalid_signature_here"},
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "signature" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_checkout_missing_org_id(self, client, db_session):
|
||||
"""Webhook without client_reference_id should return error."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
settings.stripe_webhook_secret = ""
|
||||
|
||||
event = {
|
||||
"id": "evt_test_no_org",
|
||||
"type": "checkout.session.completed",
|
||||
"data": {
|
||||
"object": {
|
||||
"id": "cs_test_no_org",
|
||||
"client_reference_id": None,
|
||||
"customer": "cus_123",
|
||||
"amount_total": 900,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_invoice_payment_succeeded(self, client, db_session):
|
||||
"""invoice.payment_succeeded should acknowledge without downgrading."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_api_key = settings.stripe_api_key
|
||||
settings.stripe_webhook_secret = ""
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||
org.stripe_customer_id = "cus_invoice_test"
|
||||
await db_session.flush()
|
||||
|
||||
event = {
|
||||
"id": "evt_invoice_paid",
|
||||
"type": "invoice.payment_succeeded",
|
||||
"data": {
|
||||
"object": {
|
||||
"id": "in_test_123",
|
||||
"customer": "cus_invoice_test",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ok"
|
||||
|
||||
# Verify tier wasn't changed
|
||||
result = await db_session.execute(
|
||||
select(Organization).where(Organization.id == org.id)
|
||||
)
|
||||
updated_org = result.scalar_one()
|
||||
assert updated_org.tier == "pro"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_invoice_payment_failed(self, client, db_session):
|
||||
"""invoice.payment_failed should not downgrade (Stripe retries)."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_api_key = settings.stripe_api_key
|
||||
settings.stripe_webhook_secret = ""
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||
org.stripe_customer_id = "cus_fail_test"
|
||||
await db_session.flush()
|
||||
|
||||
event = {
|
||||
"id": "evt_invoice_failed",
|
||||
"type": "invoice.payment_failed",
|
||||
"data": {
|
||||
"object": {
|
||||
"id": "in_test_fail",
|
||||
"customer": "cus_fail_test",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ok"
|
||||
|
||||
# Verify tier wasn't changed (payment will be retried)
|
||||
result = await db_session.execute(
|
||||
select(Organization).where(Organization.id == org.id)
|
||||
)
|
||||
updated_org = result.scalar_one()
|
||||
assert updated_org.tier == "pro"
|
||||
|
||||
|
||||
# ── Tier determination helper tests ───────────────────────────────────────────
|
||||
|
||||
class TestTierDetermination:
|
||||
"""Tests for _determine_tier_from_amount helper."""
|
||||
|
||||
def test_pro_from_amount(self):
|
||||
from app.api.billing import _determine_tier_from_amount
|
||||
assert _determine_tier_from_amount(900) == "pro"
|
||||
assert _determine_tier_from_amount(1500) == "pro"
|
||||
|
||||
def test_team_from_amount(self):
|
||||
from app.api.billing import _determine_tier_from_amount
|
||||
assert _determine_tier_from_amount(2900) == "team"
|
||||
assert _determine_tier_from_amount(5000) == "team"
|
||||
assert _determine_tier_from_amount(9900) == "team"
|
||||
|
||||
def test_free_from_small_amount(self):
|
||||
from app.api.billing import _determine_tier_from_amount
|
||||
assert _determine_tier_from_amount(0) == "free"
|
||||
assert _determine_tier_from_amount(100) == "free"
|
||||
|
||||
|
||||
# ── Integration: checkout → webhook → tier update ────────────────────────────
|
||||
|
||||
class TestBillingIntegration:
|
||||
"""End-to-end style tests for the billing flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_to_pro_upgrade_flow(self, client, db_session):
|
||||
"""Complete flow: free org → checkout → webhook → pro tier."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_pro_url = settings.stripe_pro_checkout_url
|
||||
original_team_url = settings.stripe_team_checkout_url
|
||||
original_api_key = settings.stripe_api_key
|
||||
|
||||
settings.stripe_webhook_secret = ""
|
||||
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro_link"
|
||||
settings.stripe_team_checkout_url = "https://buy.stripe.com/test_team_link"
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="free")
|
||||
|
||||
# Step 1: Get checkout URL
|
||||
checkout_resp = await client.get(
|
||||
"/api/v1/billing/checkout/pro",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert checkout_resp.status_code == 200
|
||||
checkout_data = checkout_resp.json()
|
||||
assert checkout_data["redirect_url"].startswith("https://buy.stripe.com/test_pro_link")
|
||||
assert "client_reference_id=" + org.id in checkout_data["redirect_url"]
|
||||
|
||||
# Step 2: Simulate Stripe webhook
|
||||
webhook_event = {
|
||||
"id": "evt_integration_test",
|
||||
"type": "checkout.session.completed",
|
||||
"data": {
|
||||
"object": {
|
||||
"id": "cs_integration_test",
|
||||
"client_reference_id": org.id,
|
||||
"customer": "cus_integration_test",
|
||||
"amount_total": 900,
|
||||
"line_items": {"data": []},
|
||||
}
|
||||
},
|
||||
}
|
||||
webhook_resp = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=webhook_event,
|
||||
)
|
||||
assert webhook_resp.status_code == 200
|
||||
|
||||
# Step 3: Verify org is now pro
|
||||
status_resp = await client.get(
|
||||
"/api/v1/billing/status",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert status_resp.status_code == 200
|
||||
assert status_resp.json()["tier"] == "pro"
|
||||
|
||||
# Step 4: Verify billing status shows team upgrade available
|
||||
assert "team" in status_resp.json()["checkout_urls"]
|
||||
|
||||
# Restore settings
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_pro_checkout_url = original_pro_url
|
||||
settings.stripe_team_checkout_url = original_team_url
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pro_to_free_downgrade_via_webhook(self, client, db_session):
|
||||
"""Complete flow: pro org → subscription deleted → free tier."""
|
||||
from app.config import settings
|
||||
original_secret = settings.stripe_webhook_secret
|
||||
original_api_key = settings.stripe_api_key
|
||||
settings.stripe_webhook_secret = ""
|
||||
settings.stripe_api_key = ""
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session, tier="pro")
|
||||
org.stripe_customer_id = "cus_downgrade_test"
|
||||
await db_session.flush()
|
||||
|
||||
# Simulate subscription cancellation
|
||||
event = {
|
||||
"id": "evt_cancel_test",
|
||||
"type": "customer.subscription.deleted",
|
||||
"data": {
|
||||
"object": {
|
||||
"id": "sub_cancel_test",
|
||||
"customer": "cus_downgrade_test",
|
||||
}
|
||||
},
|
||||
}
|
||||
response = await client.post(
|
||||
"/api/v1/billing/webhook",
|
||||
json=event,
|
||||
)
|
||||
|
||||
settings.stripe_webhook_secret = original_secret
|
||||
settings.stripe_api_key = original_api_key
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify org is now free
|
||||
result = await db_session.execute(
|
||||
select(Organization).where(Organization.id == org.id)
|
||||
)
|
||||
updated_org = result.scalar_one()
|
||||
assert updated_org.tier == "free"
|
||||
|
||||
# Verify billing status shows upgrade options
|
||||
status_resp = await client.get(
|
||||
"/api/v1/billing/status",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
data = status_resp.json()
|
||||
assert data["tier"] == "free"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_url_includes_prefilled_email(self, client, db_session):
|
||||
"""Checkout URL should include the user's email as prefilled_email."""
|
||||
from app.config import settings
|
||||
original_pro = settings.stripe_pro_checkout_url
|
||||
settings.stripe_pro_checkout_url = "https://buy.stripe.com/test_pro"
|
||||
|
||||
user, org, token = await _create_user_and_org(db_session)
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/billing/checkout/pro",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
settings.stripe_pro_checkout_url = original_pro
|
||||
|
||||
data = response.json()
|
||||
assert f"prefilled_email={user.email}" in data["redirect_url"] or \
|
||||
f"prefilled_email={user.email.replace('@', '%40')}" in data["redirect_url"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue