- Add billing module (app/api/billing.py) with 5 API endpoints:
- GET /api/v1/billing/checkout/{tier} — redirect to Stripe Payment Links
- GET /api/v1/billing/status — current org tier, limits, upgrade URLs
- GET /api/v1/billing/success — Stripe success callback
- GET /api/v1/billing/cancel — Stripe cancel callback
- POST /api/v1/billing/webhook — handles 5 Stripe event types
- Zero-code payment flow: uses pre-configured Stripe Payment Links
with client_reference_id (org ID) and prefilled_email params
- Webhook handler processes checkout.session.completed,
customer.subscription.updated/deleted, invoice events
- Stripe signature verification via stripe library (primary)
or manual HMAC-SHA256 (fallback)
- Tier determination from payment amount: $9=pro, $29=team
- 4 new config settings in app/config.py:
stripe_pro_checkout_url, stripe_team_checkout_url,
stripe_webhook_secret, stripe_api_key
- Added stripe>=5.0,<16.0 dependency
- 29 tests in tests/test_billing.py (all passing)
- Total: 98 tests passing (69 existing + 29 new)
822 lines
No EOL
30 KiB
Python
822 lines
No EOL
30 KiB
Python
"""Tests for Stripe Checkout Link billing integration.

Covers:
- Checkout URL generation for pro/team tiers
- Invalid tier handling
- Missing checkout URL handling
- Billing status endpoint
- Success/cancel callback endpoints
- Webhook handling (checkout.session.completed, subscription and invoice events)
- Manual HMAC signature verification (invalid signatures are rejected)
- Tier determination from payment amounts
- End-to-end upgrade/downgrade flows (checkout → webhook → status)
"""
|
|
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import time
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from app.database import Base
|
|
from app.dependencies import get_db
|
|
from app.auth import create_access_token, hash_password
|
|
from app.main import app
|
|
from app.models.saas_models import Organization, OrganizationMember, User
|
|
|
|
# ── Test database setup ──────────────────────────────────────────────────────

# In-memory SQLite. SQLAlchemy's aiosqlite dialect defaults to StaticPool for
# ":memory:", which is what lets separate sessions share the same tables.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)

# expire_on_commit=False: ORM objects keep their loaded attribute values after
# a commit instead of being expired and lazily refreshed.
TestSessionLocal = async_sessionmaker(
    bind=test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database():
    """Build the schema once for the whole test session, then drop it.

    The model modules are imported first so their tables are registered on
    ``Base.metadata`` before ``create_all`` runs.
    """
    import app.models.saas_models  # noqa: F401
    import app.models.models  # noqa: F401

    async def _run_ddl(ddl):
        # Each DDL pass runs in its own transaction on the shared test engine.
        async with test_engine.begin() as connection:
            await connection.run_sync(ddl)

    await _run_ddl(Base.metadata.create_all)
    yield
    await _run_ddl(Base.metadata.drop_all)
|
|
|
|
|
|
@pytest_asyncio.fixture
async def db_session():
    """Yield a fresh ``AsyncSession`` for one test, rolling back afterwards."""
    async with TestSessionLocal() as test_session:
        yield test_session
        # Discards uncommitted (including flushed) work; anything the code
        # under test committed is not undone by this rollback.
        await test_session.rollback()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def client(db_session: AsyncSession):
    """Yield an ``httpx.AsyncClient`` that talks to ``app`` in-process.

    ``get_db`` is overridden so every request handled during the test uses the
    same ``db_session`` the test holds, which lets tests assert on rows the
    endpoints wrote. On teardown only this fixture's ``get_db`` override is
    removed (a pre-existing override for ``get_db`` is restored); any other
    entries in ``app.dependency_overrides`` are left untouched.
    """

    async def override_get_db():
        yield db_session

    previous_override = app.dependency_overrides.get(get_db)
    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        # Undo only what this fixture changed instead of clearing every override.
        if previous_override is None:
            app.dependency_overrides.pop(get_db, None)
        else:
            app.dependency_overrides[get_db] = previous_override
|
|
|
|
|
|
async def _create_user_and_org(db: AsyncSession, tier: str = "free"):
    """Create a user, an organization owned by that user, and an access token.

    Rows are added and flushed (not committed) on ``db`` so generated IDs are
    available to the caller.

    Args:
        db: Session the new rows are added to.
        tier: Value stored on ``Organization.tier``.

    Returns:
        ``(user, org, token)`` where ``token`` is ``create_access_token(user.id)``.
    """
    import uuid

    # A random suffix keeps emails/slugs unique even for back-to-back calls.
    # The previous md5(str(time.time())) could repeat within the clock's
    # resolution, and hashlib.md5 may be unavailable on FIPS-restricted builds.
    suffix = uuid.uuid4().hex[:8]

    user = User(
        email=f"test-{tier}-{suffix}@example.com",
        password_hash=hash_password("testpass123"),
    )
    db.add(user)
    await db.flush()

    org = Organization(
        name=f"Test Org {tier}",
        slug=f"test-org-{tier}-{suffix}",
        tier=tier,
    )
    db.add(org)
    await db.flush()

    membership = OrganizationMember(
        organization_id=org.id,
        user_id=user.id,
        role="owner",
    )
    db.add(membership)
    await db.flush()

    token = create_access_token(user.id)
    return user, org, token
|
|
|
|
|
|
# ── Checkout endpoint tests ──────────────────────────────────────────────────


class TestCheckoutEndpoint:
    """Tests for GET /api/v1/billing/checkout/{tier}"""

    @staticmethod
    def _patch_settings(monkeypatch, **overrides):
        """Override ``app.config.settings`` attributes for the current test only.

        ``monkeypatch`` restores the original values at teardown even when the
        request or an assertion raises, so config never leaks into later tests.
        """
        from app.config import settings

        for name, value in overrides.items():
            monkeypatch.setattr(settings, name, value)

    @pytest.mark.asyncio
    async def test_checkout_pro_returns_redirect_url(self, client, db_session, monkeypatch):
        """Pro checkout should return a redirect URL with org ID and email."""
        from urllib.parse import parse_qs, urlsplit

        self._patch_settings(
            monkeypatch, stripe_pro_checkout_url="https://buy.stripe.com/test_pro"
        )
        user, org, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/pro",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert "redirect_url" in data
        assert data["tier"] == "pro"
        assert data["organization_id"] == org.id
        assert data["redirect_url"].startswith("https://buy.stripe.com/test_pro?")
        # Parse the query string so the check does not depend on how the
        # endpoint percent-encodes values (e.g. "@" vs "%40").
        params = parse_qs(urlsplit(data["redirect_url"]).query)
        assert params["client_reference_id"] == [str(org.id)]
        assert params["prefilled_email"] == [user.email]

    @pytest.mark.asyncio
    async def test_checkout_team_returns_redirect_url(self, client, db_session, monkeypatch):
        """Team checkout should return a redirect URL."""
        self._patch_settings(
            monkeypatch, stripe_team_checkout_url="https://buy.stripe.com/test_team"
        )
        _, org, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/team",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "team"
        assert data["organization_id"] == org.id
        assert data["redirect_url"].startswith("https://buy.stripe.com/test_team?")

    @pytest.mark.asyncio
    async def test_checkout_invalid_tier(self, client, db_session):
        """Invalid tier should return 400."""
        _, _, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/enterprise",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 400
        assert "Invalid tier" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_checkout_free_tier_invalid(self, client, db_session):
        """Free tier checkout should return 400 (can't upgrade to free)."""
        _, _, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/free",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 400

    @pytest.mark.asyncio
    async def test_checkout_missing_url_returns_503(self, client, db_session, monkeypatch):
        """If checkout URL is not configured, return 503."""
        self._patch_settings(monkeypatch, stripe_pro_checkout_url="")  # Empty = not configured
        _, _, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/pro",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 503
        assert "not configured" in response.json()["detail"].lower()

    @pytest.mark.asyncio
    async def test_checkout_requires_auth(self, client, db_session):
        """Checkout endpoint requires authentication."""
        response = await client.get("/api/v1/billing/checkout/pro")
        assert response.status_code == 401
|
|
|
|
|
|
# ── Billing status endpoint tests ─────────────────────────────────────────────


class TestBillingStatusEndpoint:
    """Tests for GET /api/v1/billing/status"""

    PRO_URL = "https://buy.stripe.com/pro"
    TEAM_URL = "https://buy.stripe.com/team"

    @staticmethod
    def _patch_settings(monkeypatch, **overrides):
        """Override ``app.config.settings`` attributes for the current test only.

        ``monkeypatch`` restores the original values at teardown even when the
        request or an assertion raises, so config never leaks into later tests.
        """
        from app.config import settings

        for name, value in overrides.items():
            monkeypatch.setattr(settings, name, value)

    def _configure_both_urls(self, monkeypatch):
        """Configure both checkout URLs so "absent" assertions test tier filtering."""
        self._patch_settings(
            monkeypatch,
            stripe_pro_checkout_url=self.PRO_URL,
            stripe_team_checkout_url=self.TEAM_URL,
        )

    @pytest.mark.asyncio
    async def test_billing_status_free_org(self, client, db_session, monkeypatch):
        """Free org should see pro and team checkout URLs."""
        self._configure_both_urls(monkeypatch)
        _, _, token = await _create_user_and_org(db_session, tier="free")

        response = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "free"
        assert "tier_info" in data
        assert "pro" in data["checkout_urls"]
        assert "team" in data["checkout_urls"]

    @pytest.mark.asyncio
    async def test_billing_status_pro_org(self, client, db_session, monkeypatch):
        """Pro org should only see team upgrade URL."""
        # The pro URL is configured too, so its absence below proves the endpoint
        # filters by tier rather than merely skipping an unset URL.
        self._configure_both_urls(monkeypatch)
        _, _, token = await _create_user_and_org(db_session, tier="pro")

        response = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "pro"
        assert "pro" not in data["checkout_urls"]
        assert "team" in data["checkout_urls"]

    @pytest.mark.asyncio
    async def test_billing_status_team_org(self, client, db_session, monkeypatch):
        """Team org should see no upgrade URLs (already at top tier)."""
        # Both URLs are configured so an empty result reflects tier filtering.
        self._configure_both_urls(monkeypatch)
        _, _, token = await _create_user_and_org(db_session, tier="team")

        response = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "team"
        assert data["checkout_urls"] == {}

    @pytest.mark.asyncio
    async def test_billing_status_shows_tier_limits(self, client, db_session):
        """Billing status should include tier limits from tier_limits module."""
        _, _, token = await _create_user_and_org(db_session, tier="free")

        response = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert "tier_info" in data
        assert data["tier_info"]["tier"] == "free"
        assert "limits" in data["tier_info"]

    @pytest.mark.asyncio
    async def test_billing_status_requires_auth(self, client, db_session):
        """Billing status endpoint requires authentication."""
        response = await client.get("/api/v1/billing/status")
        assert response.status_code == 401
|
|
|
|
|
|
# ── Success/cancel callback tests ────────────────────────────────────────────


class TestBillingCallbacks:
    """Tests for GET /api/v1/billing/success and /cancel.

    Both endpoints are called here without an Authorization header.
    """

    @pytest.mark.asyncio
    async def test_success_callback(self, client, db_session):
        """The success page answers 200 with an "activated"/"upgraded" message."""
        resp = await client.get(
            "/api/v1/billing/success", params={"session_id": "cs_test_123"}
        )
        assert resp.status_code == 200

        message = resp.json()["message"].lower()
        assert any(word in message for word in ("activated", "upgraded"))

    @pytest.mark.asyncio
    async def test_success_callback_without_session_id(self, client, db_session):
        """The success page still answers 200 when no session_id is supplied."""
        resp = await client.get("/api/v1/billing/success")
        assert resp.status_code == 200

    @pytest.mark.asyncio
    async def test_cancel_callback(self, client, db_session):
        """The cancel page answers 200 with a message mentioning cancellation."""
        resp = await client.get("/api/v1/billing/cancel")
        assert resp.status_code == 200

        message = resp.json()["message"]
        assert "canceled" in message.lower()
|
|
|
|
|
|
# ── Webhook tests ─────────────────────────────────────────────────────────────


class TestWebhookHandler:
    """Tests for POST /api/v1/billing/webhook"""

    WEBHOOK_URL = "/api/v1/billing/webhook"

    @pytest.fixture(autouse=True)
    def _disable_signature_verification(self, monkeypatch):
        """Blank the Stripe webhook secret and API key for every test in this class.

        The tests post unsigned JSON, so signature verification is disabled by
        clearing both settings. ``monkeypatch`` restores the real values at
        teardown even if a request or assertion fails. A test that needs a
        secret overrides it again through ``monkeypatch``.
        """
        from app.config import settings

        monkeypatch.setattr(settings, "stripe_webhook_secret", "")
        monkeypatch.setattr(settings, "stripe_api_key", "")

    def _make_checkout_session_event(
        self, org_id: str, amount_total: int = 900, customer_id: str = "cus_test123"
    ) -> dict:
        """Create a mock checkout.session.completed event."""
        return {
            "id": "evt_test_123",
            "object": "event",
            "type": "checkout.session.completed",
            "data": {
                "object": {
                    "id": "cs_test_123",
                    "object": "checkout.session",
                    "client_reference_id": org_id,
                    "customer": customer_id,
                    "amount_total": amount_total,
                    "line_items": {"data": []},
                }
            },
        }

    def _make_subscription_event(
        self, customer_id: str, event_type: str, amount: int = 900
    ) -> dict:
        """Create a mock subscription event with one item priced at *amount*."""
        return {
            "id": "evt_test_sub_123",
            "object": "event",
            "type": event_type,
            "data": {
                "object": {
                    "id": "sub_test_123",
                    "object": "subscription",
                    "customer": customer_id,
                    "metadata": {},
                    "items": {"data": [{"plan": {"amount": amount}}]},
                }
            },
        }

    def _make_invoice_event(
        self, event_id: str, event_type: str, invoice_id: str, customer_id: str
    ) -> dict:
        """Create a minimal mock invoice event (only ``id`` and ``customer``)."""
        return {
            "id": event_id,
            "type": event_type,
            "data": {"object": {"id": invoice_id, "customer": customer_id}},
        }

    @staticmethod
    async def _fetch_org(db_session, org_id):
        """Select the organization with *org_id* through the test session."""
        result = await db_session.execute(
            select(Organization).where(Organization.id == org_id)
        )
        return result.scalar_one()

    @pytest.mark.asyncio
    async def test_webhook_checkout_session_pro(self, client, db_session):
        """checkout.session.completed with $9 amount should set tier to pro."""
        _, org, _ = await _create_user_and_org(db_session, tier="free")

        event = self._make_checkout_session_event(org.id, amount_total=900)
        response = await client.post(self.WEBHOOK_URL, json=event)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ok"
        assert data["event_type"] == "checkout.session.completed"

        # Verify org was upgraded
        updated_org = await self._fetch_org(db_session, org.id)
        assert updated_org.tier == "pro"
        assert updated_org.stripe_customer_id == "cus_test123"

    @pytest.mark.asyncio
    async def test_webhook_checkout_session_team(self, client, db_session):
        """checkout.session.completed with $29 amount should set tier to team."""
        _, org, _ = await _create_user_and_org(db_session, tier="free")

        event = self._make_checkout_session_event(org.id, amount_total=2900)
        response = await client.post(self.WEBHOOK_URL, json=event)

        assert response.status_code == 200
        updated_org = await self._fetch_org(db_session, org.id)
        assert updated_org.tier == "team"

    @pytest.mark.asyncio
    async def test_webhook_subscription_deleted_downgrades_to_free(self, client, db_session):
        """customer.subscription.deleted should downgrade org to free."""
        _, org, _ = await _create_user_and_org(db_session, tier="pro")
        # Manually set stripe_customer_id
        org.stripe_customer_id = "cus_delete_test"
        await db_session.flush()

        event = self._make_subscription_event(
            customer_id="cus_delete_test",
            event_type="customer.subscription.deleted",
        )
        response = await client.post(self.WEBHOOK_URL, json=event)

        assert response.status_code == 200
        updated_org = await self._fetch_org(db_session, org.id)
        assert updated_org.tier == "free"

    @pytest.mark.asyncio
    async def test_webhook_subscription_updated(self, client, db_session):
        """customer.subscription.updated should update org tier based on amount."""
        _, org, _ = await _create_user_and_org(db_session, tier="pro")
        org.stripe_customer_id = "cus_update_test"
        await db_session.flush()

        event = self._make_subscription_event(
            customer_id="cus_update_test",
            event_type="customer.subscription.updated",
            amount=2900,  # Upgrade to team
        )
        response = await client.post(self.WEBHOOK_URL, json=event)

        assert response.status_code == 200
        updated_org = await self._fetch_org(db_session, org.id)
        assert updated_org.tier == "team"

    @pytest.mark.asyncio
    async def test_webhook_ignores_unhandled_events(self, client, db_session):
        """Unhandled event types should be ignored."""
        event = {
            "id": "evt_test_ignore",
            "type": "account.updated",
            "data": {"object": {}},
        }
        response = await client.post(self.WEBHOOK_URL, json=event)

        assert response.status_code == 200
        assert response.json()["status"] == "ignored"

    @pytest.mark.asyncio
    async def test_webhook_invalid_signature_rejected(self, client, db_session, monkeypatch):
        """Webhook with invalid signature should be rejected."""
        from app.config import settings

        # Secret set but API key left blank (by the autouse fixture), which
        # forces the manual HMAC verification path.
        monkeypatch.setattr(settings, "stripe_webhook_secret", "whsec_test_secret_key_12345")

        event = self._make_checkout_session_event("fake_org_id")
        response = await client.post(
            self.WEBHOOK_URL,
            json=event,
            headers={"stripe-signature": "t=1234,v1=invalid_signature_here"},
        )

        assert response.status_code == 400
        assert "signature" in response.json()["detail"].lower()

    @pytest.mark.asyncio
    async def test_webhook_checkout_missing_org_id(self, client, db_session):
        """Webhook without client_reference_id should return error."""
        event = {
            "id": "evt_test_no_org",
            "type": "checkout.session.completed",
            "data": {
                "object": {
                    "id": "cs_test_no_org",
                    "client_reference_id": None,
                    "customer": "cus_123",
                    "amount_total": 900,
                }
            },
        }
        response = await client.post(self.WEBHOOK_URL, json=event)

        assert response.status_code == 200
        assert response.json()["status"] == "error"

    @pytest.mark.asyncio
    async def test_webhook_invoice_payment_succeeded(self, client, db_session):
        """invoice.payment_succeeded should acknowledge without downgrading."""
        _, org, _ = await _create_user_and_org(db_session, tier="pro")
        org.stripe_customer_id = "cus_invoice_test"
        await db_session.flush()

        event = self._make_invoice_event(
            event_id="evt_invoice_paid",
            event_type="invoice.payment_succeeded",
            invoice_id="in_test_123",
            customer_id="cus_invoice_test",
        )
        response = await client.post(self.WEBHOOK_URL, json=event)

        assert response.status_code == 200
        assert response.json()["status"] == "ok"

        # Verify tier wasn't changed
        updated_org = await self._fetch_org(db_session, org.id)
        assert updated_org.tier == "pro"

    @pytest.mark.asyncio
    async def test_webhook_invoice_payment_failed(self, client, db_session):
        """invoice.payment_failed should not downgrade (Stripe retries)."""
        _, org, _ = await _create_user_and_org(db_session, tier="pro")
        org.stripe_customer_id = "cus_fail_test"
        await db_session.flush()

        event = self._make_invoice_event(
            event_id="evt_invoice_failed",
            event_type="invoice.payment_failed",
            invoice_id="in_test_fail",
            customer_id="cus_fail_test",
        )
        response = await client.post(self.WEBHOOK_URL, json=event)

        assert response.status_code == 200
        assert response.json()["status"] == "ok"

        # Verify tier wasn't changed (payment will be retried)
        updated_org = await self._fetch_org(db_session, org.id)
        assert updated_org.tier == "pro"
|
|
|
|
|
|
# ── Tier determination helper tests ───────────────────────────────────────────


class TestTierDetermination:
    """Unit tests for ``app.api.billing._determine_tier_from_amount``.

    Amounts are given in cents, matching the webhook tests (900 is $9).
    """

    def test_pro_from_amount(self):
        from app.api.billing import _determine_tier_from_amount as tier_for

        for cents in (900, 1500):
            assert tier_for(cents) == "pro"

    def test_team_from_amount(self):
        from app.api.billing import _determine_tier_from_amount as tier_for

        for cents in (2900, 5000, 9900):
            assert tier_for(cents) == "team"

    def test_free_from_small_amount(self):
        from app.api.billing import _determine_tier_from_amount as tier_for

        for cents in (0, 100):
            assert tier_for(cents) == "free"
|
|
|
|
|
|
# ── Integration: checkout → webhook → tier update ────────────────────────────


class TestBillingIntegration:
    """End-to-end style tests for the billing flow."""

    @staticmethod
    def _patch_settings(monkeypatch, **overrides):
        """Override ``app.config.settings`` attributes for the current test only.

        ``monkeypatch`` restores the original values at teardown even when a
        request or an assertion raises, so config never leaks into later tests.
        """
        from app.config import settings

        for name, value in overrides.items():
            monkeypatch.setattr(settings, name, value)

    @staticmethod
    def _query_params(url):
        """Return the decoded query parameters of *url* as ``{name: [values]}``.

        Decoding makes assertions independent of how values were percent-encoded.
        """
        from urllib.parse import parse_qs, urlsplit

        return parse_qs(urlsplit(url).query)

    @pytest.mark.asyncio
    async def test_free_to_pro_upgrade_flow(self, client, db_session, monkeypatch):
        """Complete flow: free org → checkout → webhook → pro tier."""
        self._patch_settings(
            monkeypatch,
            stripe_webhook_secret="",
            stripe_api_key="",
            stripe_pro_checkout_url="https://buy.stripe.com/test_pro_link",
            stripe_team_checkout_url="https://buy.stripe.com/test_team_link",
        )
        _, org, token = await _create_user_and_org(db_session, tier="free")
        auth_headers = {"Authorization": f"Bearer {token}"}

        # Step 1: Get checkout URL
        checkout_resp = await client.get("/api/v1/billing/checkout/pro", headers=auth_headers)
        assert checkout_resp.status_code == 200
        redirect_url = checkout_resp.json()["redirect_url"]
        assert redirect_url.startswith("https://buy.stripe.com/test_pro_link")
        assert self._query_params(redirect_url)["client_reference_id"] == [str(org.id)]

        # Step 2: Simulate Stripe webhook
        webhook_event = {
            "id": "evt_integration_test",
            "type": "checkout.session.completed",
            "data": {
                "object": {
                    "id": "cs_integration_test",
                    "client_reference_id": org.id,
                    "customer": "cus_integration_test",
                    "amount_total": 900,
                    "line_items": {"data": []},
                }
            },
        }
        webhook_resp = await client.post("/api/v1/billing/webhook", json=webhook_event)
        assert webhook_resp.status_code == 200

        # Step 3: Verify org is now pro
        status_resp = await client.get("/api/v1/billing/status", headers=auth_headers)
        assert status_resp.status_code == 200
        status_data = status_resp.json()
        assert status_data["tier"] == "pro"

        # Step 4: Verify billing status shows team upgrade available
        assert "team" in status_data["checkout_urls"]

    @pytest.mark.asyncio
    async def test_pro_to_free_downgrade_via_webhook(self, client, db_session, monkeypatch):
        """Complete flow: pro org → subscription deleted → free tier."""
        self._patch_settings(monkeypatch, stripe_webhook_secret="", stripe_api_key="")

        _, org, token = await _create_user_and_org(db_session, tier="pro")
        org.stripe_customer_id = "cus_downgrade_test"
        await db_session.flush()

        # Simulate subscription cancellation
        event = {
            "id": "evt_cancel_test",
            "type": "customer.subscription.deleted",
            "data": {
                "object": {
                    "id": "sub_cancel_test",
                    "customer": "cus_downgrade_test",
                }
            },
        }
        response = await client.post("/api/v1/billing/webhook", json=event)
        assert response.status_code == 200

        # Verify org is now free
        result = await db_session.execute(
            select(Organization).where(Organization.id == org.id)
        )
        updated_org = result.scalar_one()
        assert updated_org.tier == "free"

        # Verify the billing status endpoint reports the downgrade too
        status_resp = await client.get(
            "/api/v1/billing/status",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert status_resp.status_code == 200
        assert status_resp.json()["tier"] == "free"

    @pytest.mark.asyncio
    async def test_checkout_url_includes_prefilled_email(self, client, db_session, monkeypatch):
        """Checkout URL should include the user's email as prefilled_email."""
        self._patch_settings(
            monkeypatch, stripe_pro_checkout_url="https://buy.stripe.com/test_pro"
        )
        user, _, token = await _create_user_and_org(db_session)

        response = await client.get(
            "/api/v1/billing/checkout/pro",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        # Accepts both raw "@" and percent-encoded "%40" since the query is decoded.
        params = self._query_params(response.json()["redirect_url"])
        assert params["prefilled_email"] == [user.email]