feat: indie status page SaaS - initial release

This commit is contained in:
Ubuntu 2026-04-25 09:39:57 +00:00
parent ee2bc87ade
commit b7a8142ca0
14 changed files with 2703 additions and 0 deletions

1262
SAAS_ENHANCEMENT_PLAN.md Normal file

File diff suppressed because it is too large Load diff

155
app/api/organizations.py Normal file
View file

@ -0,0 +1,155 @@
"""Organization API endpoints: view org, list tiers, upgrade/downgrade."""
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth import get_current_user, get_current_org
from app.dependencies import get_db
from app.models.saas_models import Organization, OrganizationMember, User
from app.services.tier_limits import (
TIER_LIMITS,
get_org_limits,
get_tier_info,
)
router = APIRouter(tags=["organizations"])
# ── Response schemas ────────────────────────────────────────────────────────
class OrgMemberResponse(BaseModel):
user_id: str
email: str
display_name: str | None
role: str
class OrgResponse(BaseModel):
id: str
slug: str
name: str
tier: str
custom_domain: str | None
member_count: int
tier_info: dict
class TierDetailResponse(BaseModel):
tier: str
limits: dict
class UpgradeRequest(BaseModel):
tier: str # "free" | "pro" | "team"
# ── Endpoints ───────────────────────────────────────────────────────────────
@router.get("/tiers")
async def list_tiers():
"""List all available tiers and their limits (public endpoint)."""
return {
"tiers": [
{
"name": tier_name,
"display_name": {
"free": "Free",
"pro": "Pro ($9/mo)",
"team": "Team ($29/mo)",
}.get(tier_name, tier_name),
"limits": limits,
}
for tier_name, limits in TIER_LIMITS.items()
]
}
@router.get("/my", response_model=OrgResponse)
async def get_my_org(
org: Organization = Depends(get_current_org),
db: AsyncSession = Depends(get_db),
):
"""Get the current user's organization with tier limits info."""
# Count members
result = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id == org.id
)
)
members = result.scalars().all()
return OrgResponse(
id=org.id,
slug=org.slug,
name=org.name,
tier=org.tier or "free",
custom_domain=org.custom_domain,
member_count=len(members),
tier_info=get_tier_info(org),
)
@router.patch("/my/tier", response_model=OrgResponse)
async def update_org_tier(
body: UpgradeRequest,
org: Organization = Depends(get_current_org),
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Update the organization's tier.
In production, this would be gated by Stripe payment verification.
For now, this is an admin-only endpoint that directly sets the tier.
Only the org owner can change the tier.
"""
# Verify user is owner
result = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id == org.id,
OrganizationMember.user_id == user.id,
)
)
membership = result.scalar_one_or_none()
if not membership or membership.role != "owner":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only the organization owner can change the plan tier.",
)
if body.tier not in ("free", "pro", "team"):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid tier '{body.tier}'. Must be one of: free, pro, team.",
)
org.tier = body.tier
await db.flush()
await db.refresh(org)
# Count members for response
members_result = await db.execute(
select(OrganizationMember).where(
OrganizationMember.organization_id == org.id
)
)
members = members_result.scalars().all()
return OrgResponse(
id=org.id,
slug=org.slug,
name=org.name,
tier=org.tier,
custom_domain=org.custom_domain,
member_count=len(members),
tier_info=get_tier_info(org),
)
@router.get("/my/limits")
async def get_my_limits(
org: Organization = Depends(get_current_org),
):
"""Get the current organization's tier limits and feature flags."""
return get_tier_info(org)

View file

@ -5,9 +5,13 @@ from app.api.incidents import router as incidents_router
from app.api.monitors import router as monitors_router
from app.api.subscribers import router as subscribers_router
from app.api.settings import router as settings_router
from app.api.organizations import router as organizations_router
from app.routes.auth import router as auth_router
api_v1_router = APIRouter()
api_v1_router.include_router(auth_router, tags=["auth"])
api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"])
api_v1_router.include_router(services_router, prefix="/services", tags=["services"])
api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"])
api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"])

99
app/auth.py Normal file
View file

@ -0,0 +1,99 @@
"""JWT authentication and password hashing utilities."""
from datetime import datetime, timedelta
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.dependencies import get_db
from app.models.saas_models import User, Organization, OrganizationMember
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 scheme for token extraction
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash."""
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(user_id: str, exp_hours: int = 72) -> str:
"""Create a JWT access token for the given user ID."""
payload = {
"sub": user_id,
"exp": datetime.utcnow() + timedelta(hours=exp_hours),
}
return jwt.encode(payload, settings.secret_key, algorithm="HS256")
def decode_access_token(token: str) -> dict:
"""Decode and verify a JWT access token. Raises JWTError on failure."""
return jwt.decode(token, settings.secret_key, algorithms=["HS256"])
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db),
) -> User:
"""FastAPI dependency: extract and validate current user from JWT."""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = decode_access_token(token)
user_id: str | None = payload.get("sub")
if user_id is None:
raise credentials_exception
except JWTError:
raise credentials_exception
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise credentials_exception
return user
async def get_current_org(
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Organization:
"""FastAPI dependency: get the user's first organization (simplified).
In the future, this will support org-switching via header or session.
"""
result = await db.execute(
select(OrganizationMember).where(OrganizationMember.user_id == user.id)
)
membership = result.scalars().first()
if not membership:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User is not a member of any organization",
)
org_result = await db.execute(
select(Organization).where(Organization.id == membership.organization_id)
)
org = org_result.scalar_one_or_none()
if org is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Organization not found",
)
return org

View file

@ -8,6 +8,12 @@ from app.models.models import (
NotificationLog,
SiteSetting,
)
from app.models.saas_models import (
User,
Organization,
OrganizationMember,
StatusPage,
)
__all__ = [
"Service",
@ -18,4 +24,8 @@ __all__ = [
"Subscriber",
"NotificationLog",
"SiteSetting",
"User",
"Organization",
"OrganizationMember",
"StatusPage",
]

View file

@ -15,6 +15,9 @@ class Service(Base):
__tablename__ = "services"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
organization_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("organizations.id"), nullable=True, index=True
)
name: Mapped[str] = mapped_column(String(100), nullable=False)
slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
@ -74,6 +77,9 @@ class Monitor(Base):
__tablename__ = "monitors"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
organization_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("organizations.id"), nullable=True, index=True
)
service_id: Mapped[str] = mapped_column(
String(36), ForeignKey("services.id"), nullable=False, index=True
)
@ -114,6 +120,9 @@ class Subscriber(Base):
__tablename__ = "subscribers"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
organization_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("organizations.id"), nullable=True, index=True
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
is_confirmed: Mapped[bool] = mapped_column(Boolean, default=False)
confirm_token: Mapped[str | None] = mapped_column(String(100), nullable=True)

113
app/models/saas_models.py Normal file
View file

@ -0,0 +1,113 @@
"""SaaS multi-tenancy models: User, Organization, OrganizationMember, StatusPage."""
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
def _uuid_str() -> str:
return str(uuid.uuid4())
class User(Base):
"""Individual who can log in."""
__tablename__ = "users"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
email: Mapped[str] = mapped_column(
String(255), unique=True, nullable=False, index=True
)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
display_name: Mapped[str | None] = mapped_column(String(100), nullable=True)
is_email_verified: Mapped[bool] = mapped_column(Boolean, default=False)
email_verify_token: Mapped[str | None] = mapped_column(String(100), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
memberships: Mapped[list["OrganizationMember"]] = relationship(
back_populates="user"
)
class Organization(Base):
"""The tenant; owns status pages."""
__tablename__ = "organizations"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
slug: Mapped[str] = mapped_column(
String(50), unique=True, nullable=False, index=True
)
name: Mapped[str] = mapped_column(String(100), nullable=False)
tier: Mapped[str] = mapped_column(String(20), nullable=False, default="free")
# "free" | "pro" | "team"
stripe_customer_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
custom_domain: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
members: Mapped[list["OrganizationMember"]] = relationship(
back_populates="organization"
)
status_pages: Mapped[list["StatusPage"]] = relationship(
back_populates="organization"
)
# Services linked to this org (from app.models.models.Service)
class OrganizationMember(Base):
"""Joins users to orgs with roles."""
__tablename__ = "organization_members"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
organization_id: Mapped[str] = mapped_column(
String(36), ForeignKey("organizations.id"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
String(36), ForeignKey("users.id"), nullable=False, index=True
)
role: Mapped[str] = mapped_column(String(20), nullable=False, default="member")
# "owner" | "admin" | "member"
joined_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
organization: Mapped["Organization"] = relationship(back_populates="members")
user: Mapped["User"] = relationship(back_populates="memberships")
__table_args__ = (
{"sqlite_autoincrement": True},
)
class StatusPage(Base):
"""Per-organization status page."""
__tablename__ = "status_pages"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
organization_id: Mapped[str] = mapped_column(
String(36), ForeignKey("organizations.id"), nullable=False, index=True
)
slug: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
title: Mapped[str] = mapped_column(String(100), nullable=False)
subdomain: Mapped[str | None] = mapped_column(String(100), nullable=True)
custom_domain: Mapped[str | None] = mapped_column(String(255), nullable=True)
is_public: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
organization: Mapped["Organization"] = relationship(back_populates="status_pages")
__table_args__ = (
UniqueConstraint("organization_id", "slug", name="uq_status_page_org_slug"),
)

0
app/routes/__init__.py Normal file
View file

122
app/routes/auth.py Normal file
View file

@ -0,0 +1,122 @@
"""Auth routes: register, login, and current user profile."""
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, EmailStr
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth import create_access_token, get_current_user, hash_password, verify_password
from app.dependencies import get_db
from app.models.saas_models import Organization, OrganizationMember, StatusPage, User
router = APIRouter(tags=["auth"])
# ── Request / Response schemas ──────────────────────────────────────────────
class RegisterRequest(BaseModel):
email: EmailStr
password: str
class LoginRequest(BaseModel):
email: EmailStr
password: str
class AuthResponse(BaseModel):
access_token: str
token_type: str = "bearer"
class UserProfile(BaseModel):
id: str
email: str
display_name: str | None = None
is_email_verified: bool
created_at: str | None = None
# ── Routes ───────────────────────────────────────────────────────────────────
@router.post("/auth/register", status_code=status.HTTP_201_CREATED, response_model=AuthResponse)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
"""Register a new user, create a default Organization + StatusPage, return JWT."""
# Check for existing user
result = await db.execute(select(User).where(User.email == body.email))
if result.scalar_one_or_none() is not None:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="A user with this email already exists",
)
# Create user
user = User(
email=body.email,
password_hash=hash_password(body.password),
)
db.add(user)
await db.flush() # assign user.id
# Create default organization
org_slug = body.email.split("@")[0].lower()
org = Organization(
name=org_slug,
slug=org_slug,
)
db.add(org)
await db.flush() # assign org.id
# Create organization membership (owner)
membership = OrganizationMember(
organization_id=org.id,
user_id=user.id,
role="owner",
)
db.add(membership)
# Create default status page
status_page = StatusPage(
organization_id=org.id,
slug="main",
title="Status Page",
)
db.add(status_page)
# Commit all together (the get_db dependency also commits, but flush ensures
# relationships are consistent before we return)
await db.flush()
# Generate JWT
token = create_access_token(user.id)
return AuthResponse(access_token=token)
@router.post("/auth/login", response_model=AuthResponse)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
"""Authenticate user with email + password, return JWT."""
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not verify_password(body.password, user.password_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
token = create_access_token(user.id)
return AuthResponse(access_token=token)
@router.get("/auth/me", response_model=UserProfile)
async def me(current_user: User = Depends(get_current_user)):
"""Return the current authenticated user's profile."""
return UserProfile(
id=current_user.id,
email=current_user.email,
display_name=current_user.display_name,
is_email_verified=current_user.is_email_verified,
created_at=current_user.created_at.isoformat() if current_user.created_at else None,
)

229
app/services/tier_limits.py Normal file
View file

@ -0,0 +1,229 @@
"""Tier enforcement: limits and feature flags for Free/Pro/Team plans.
This module defines the per-tier limits and provides enforcement functions
that raise HTTP 403 when a limit is exceeded or a feature is unavailable.
"""
from fastapi import HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.saas_models import Organization, OrganizationMember, StatusPage
from app.models.models import Service, Monitor, Subscriber
# ── Tier limits configuration ──────────────────────────────────────────────
TIER_LIMITS: dict[str, dict[str, int | bool]] = {
"free": {
"status_pages": 1,
"services_per_page": 5,
"monitors_per_service": 1,
"subscribers": 25,
"members": 1,
"check_interval_min": 5, # minutes
"custom_domain": False,
"custom_branding": False,
"webhooks": False,
"api_access": False,
"incident_history_days": 30,
"sla_badge": False,
"password_protection": False,
},
"pro": {
"status_pages": 5,
"services_per_page": 50,
"monitors_per_service": 5,
"subscribers": 500,
"members": 3,
"check_interval_min": 1, # minutes
"custom_domain": True,
"custom_branding": True,
"webhooks": True,
"api_access": True,
"incident_history_days": 365,
"sla_badge": True,
"password_protection": False,
},
"team": {
"status_pages": -1, # unlimited
"services_per_page": -1,
"monitors_per_service": -1,
"subscribers": -1,
"members": -1,
"check_interval_min": 0, # 30 seconds (0 min)
"custom_domain": True,
"custom_branding": True,
"webhooks": True,
"api_access": True,
"incident_history_days": -1, # unlimited
"sla_badge": True,
"password_protection": True,
},
}
# ── Tier info helpers ───────────────────────────────────────────────────────
def get_tier_limits(tier: str) -> dict:
"""Return the limits dict for a given tier name. Falls back to free."""
return TIER_LIMITS.get(tier, TIER_LIMITS["free"])
def get_org_limits(org: Organization) -> dict:
"""Return the limits dict for an organization based on its tier."""
return get_tier_limits(org.tier or "free")
def get_limit(org: Organization, feature: str):
"""Return the limit value for a specific feature given the org's tier.
Returns:
int: numeric limit (-1 means unlimited)
bool: feature flag (True/False)
None: unknown feature
"""
limits = get_org_limits(org)
return limits.get(feature)
# ── Enforcement ─────────────────────────────────────────────────────────────
class TierLimitExceeded(HTTPException):
"""Raised when a tier limit is exceeded."""
def __init__(self, feature: str, limit: int | bool):
if limit is False:
detail = f"Feature '{feature}' is not available on your current plan. Upgrade to access it."
else:
detail = (
f"Tier limit reached for '{feature}' ({limit}). "
"Upgrade your plan to increase this limit."
)
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
async def enforce_limit(
db: AsyncSession,
org: Organization,
feature: str,
current_count: int,
) -> None:
"""Raise TierLimitExceeded if the current count meets or exceeds the tier limit.
Args:
db: Database session for queries.
org: The organization whose tier limits to check.
feature: The feature name (key in TIER_LIMITS).
current_count: How many of this feature the org currently has.
Raises:
TierLimitExceeded: If the limit is reached or the feature is disabled.
"""
limit = get_limit(org, feature)
if limit is None:
return # Unknown feature — don't block
if limit is False:
# Feature flag: not available on this tier
raise TierLimitExceeded(feature, limit)
if limit == -1:
return # Unlimited
if isinstance(limit, int) and current_count >= limit:
raise TierLimitExceeded(feature, limit)
# ── Concrete enforcement helpers ────────────────────────────────────────────
async def check_status_page_limit(db: AsyncSession, org: Organization) -> None:
"""Check that the org hasn't exceeded its status page limit."""
result = await db.execute(
select(func.count(StatusPage.id)).where(
StatusPage.organization_id == org.id
)
)
count = result.scalar() or 0
await enforce_limit(db, org, "status_pages", count)
async def check_service_limit(
db: AsyncSession, org: Organization, status_page_id: str | None = None
) -> None:
"""Check that the org hasn't exceeded its services-per-page limit.
If status_page_id is None, counts all services for the org.
"""
query = select(func.count(Service.id)).where(
Service.organization_id == org.id
)
if status_page_id:
# In future, Service will have a status_page_id column
# For now, count all services in the org
pass
result = await db.execute(query)
count = result.scalar() or 0
await enforce_limit(db, org, "services_per_page", count)
async def check_monitor_limit(
db: AsyncSession, org: Organization, service_id: str
) -> None:
"""Check that the service hasn't exceeded its monitors-per-service limit."""
result = await db.execute(
select(func.count(Monitor.id)).where(Monitor.service_id == service_id)
)
count = result.scalar() or 0
await enforce_limit(db, org, "monitors_per_service", count)
async def check_subscriber_limit(db: AsyncSession, org: Organization) -> None:
"""Check that the org hasn't exceeded its subscriber limit."""
result = await db.execute(
select(func.count(Subscriber.id)).where(
Subscriber.organization_id == org.id
)
)
count = result.scalar() or 0
await enforce_limit(db, org, "subscribers", count)
async def check_member_limit(db: AsyncSession, org: Organization) -> None:
"""Check that the org hasn't exceeded its team member limit."""
result = await db.execute(
select(func.count(OrganizationMember.id)).where(
OrganizationMember.organization_id == org.id
)
)
count = result.scalar() or 0
await enforce_limit(db, org, "members", count)
def enforce_feature(org: Organization, feature: str) -> None:
"""Enforce a boolean feature flag. Raises TierLimitExceeded if False.
Use this for features that are either allowed or not (e.g., custom_domain,
webhooks, api_access) without a numeric limit.
"""
limit = get_limit(org, feature)
if limit is False:
raise TierLimitExceeded(feature, False)
if limit is None:
# Unknown feature — don't block
return
def get_tier_info(org: Organization) -> dict:
"""Return a dict of the org's current tier with limits and feature flags.
Useful for API responses that show the org what they can and can't do.
"""
limits = get_org_limits(org)
return {
"tier": org.tier or "free",
"limits": limits,
"organization_id": org.id,
"organization_slug": org.slug,
}

View file

@ -20,6 +20,10 @@ dependencies = [
"httpx>=0.27,<1.0",
"typer>=0.9,<1.0",
"rich>=13.0,<14.0",
"python-jose[cryptography]>=3.3,<4.0",
"passlib[bcrypt]>=1.7,<2.0",
"bcrypt==4.0.1",
"email-validator>=2.0,<3.0",
]
[project.optional-dependencies]

View file

@ -21,6 +21,9 @@ TestSessionLocal = async_sessionmaker(
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database():
"""Create all tables once for the test session."""
# Import SaaS models so their tables are registered on Base.metadata
import app.models.saas_models # noqa: F401
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield

103
tests/test_auth.py Normal file
View file

@ -0,0 +1,103 @@
"""Test Auth API endpoints."""
import pytest
REGISTER_URL = "/api/v1/auth/register"
LOGIN_URL = "/api/v1/auth/login"
ME_URL = "/api/v1/auth/me"
@pytest.mark.asyncio
async def test_register_new_user(client):
"""Should register a new user and return 201 with a JWT token."""
response = await client.post(
REGISTER_URL,
json={"email": "newuser@example.com", "password": "securepassword123"},
)
assert response.status_code == 201
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
assert len(data["access_token"]) > 0
@pytest.mark.asyncio
async def test_register_duplicate_email(client):
"""Should return 409 when registering with an email that already exists."""
# Register first user
await client.post(
REGISTER_URL,
json={"email": "duplicate@example.com", "password": "password123"},
)
# Try to register again with same email
response = await client.post(
REGISTER_URL,
json={"email": "duplicate@example.com", "password": "differentpassword"},
)
assert response.status_code == 409
assert "already exists" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_login_correct_password(client):
"""Should return 200 with a JWT token on successful login."""
# Register a user first
await client.post(
REGISTER_URL,
json={"email": "loginuser@example.com", "password": "mypassword"},
)
# Login with correct password
response = await client.post(
LOGIN_URL,
json={"email": "loginuser@example.com", "password": "mypassword"},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
@pytest.mark.asyncio
async def test_login_wrong_password(client):
"""Should return 401 when logging in with wrong password."""
# Register a user first
await client.post(
REGISTER_URL,
json={"email": "wrongpw@example.com", "password": "correctpassword"},
)
# Login with wrong password
response = await client.post(
LOGIN_URL,
json={"email": "wrongpw@example.com", "password": "wrongpassword"},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_me_with_valid_token(client):
"""Should return 200 with user profile when using a valid JWT token."""
# Register a user and get token
reg_response = await client.post(
REGISTER_URL,
json={"email": "meuser@example.com", "password": "password123"},
)
token = reg_response.json()["access_token"]
# Get profile with valid token
response = await client.get(
ME_URL,
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "meuser@example.com"
assert "id" in data
assert data["is_email_verified"] is False
@pytest.mark.asyncio
async def test_me_without_token(client):
"""Should return 401 when accessing /me without a token."""
response = await client.get(ME_URL)
assert response.status_code in (401, 403)

590
tests/test_tier_limits.py Normal file
View file

@ -0,0 +1,590 @@
"""Test tier enforcement: limits, feature flags, and organization endpoints."""
import pytest
from sqlalchemy import select
from app.models.saas_models import Organization, OrganizationMember, StatusPage, User
from app.models.models import Service, Monitor, Subscriber
from app.services.tier_limits import (
TIER_LIMITS,
TierLimitExceeded,
enforce_limit,
enforce_feature,
get_limit,
get_org_limits,
get_tier_info,
get_tier_limits,
check_status_page_limit,
check_service_limit,
check_monitor_limit,
check_subscriber_limit,
check_member_limit,
)
# ── Unit tests for tier_limits module ────────────────────────────────────────
class TestTierLimitsConfig:
"""Test that the TIER_LIMITS config is well-formed."""
def test_all_tiers_defined(self):
"""All three tiers should be defined."""
assert "free" in TIER_LIMITS
assert "pro" in TIER_LIMITS
assert "team" in TIER_LIMITS
def test_free_tier_has_expected_keys(self):
"""Free tier should have all expected limit keys."""
free = TIER_LIMITS["free"]
expected_keys = {
"status_pages", "services_per_page", "monitors_per_service",
"subscribers", "members", "check_interval_min",
"custom_domain", "custom_branding", "webhooks",
"api_access", "incident_history_days", "sla_badge",
"password_protection",
}
assert set(free.keys()) == expected_keys
def test_free_tier_values(self):
"""Free tier should have restrictive values."""
free = TIER_LIMITS["free"]
assert free["status_pages"] == 1
assert free["services_per_page"] == 5
assert free["monitors_per_service"] == 1
assert free["subscribers"] == 25
assert free["members"] == 1
assert free["custom_domain"] is False
assert free["webhooks"] is False
assert free["api_access"] is False
def test_pro_tier_values(self):
"""Pro tier should have moderate values."""
pro = TIER_LIMITS["pro"]
assert pro["status_pages"] == 5
assert pro["services_per_page"] == 50
assert pro["monitors_per_service"] == 5
assert pro["subscribers"] == 500
assert pro["members"] == 3
assert pro["custom_domain"] is True
assert pro["webhooks"] is True
def test_team_tier_values(self):
"""Team tier should have unlimited (-1) for most things."""
team = TIER_LIMITS["team"]
assert team["status_pages"] == -1
assert team["services_per_page"] == -1
assert team["monitors_per_service"] == -1
assert team["subscribers"] == -1
assert team["members"] == -1
assert team["custom_domain"] is True
assert team["password_protection"] is True
def test_all_tiers_have_same_keys(self):
"""All tiers should have exactly the same set of keys."""
keys = set(TIER_LIMITS["free"].keys())
for tier_name, tier_data in TIER_LIMITS.items():
assert set(tier_data.keys()) == keys, f"Tier '{tier_name}' has different keys"
class TestGetLimitHelpers:
"""Test the helper functions."""
def test_get_tier_limits_known_tier(self):
"""get_tier_limits should return the correct dict for known tiers."""
assert get_tier_limits("free") == TIER_LIMITS["free"]
assert get_tier_limits("pro") == TIER_LIMITS["pro"]
assert get_tier_limits("team") == TIER_LIMITS["team"]
def test_get_tier_limits_unknown_tier(self):
"""get_tier_limits should default to free for unknown tiers."""
assert get_tier_limits("enterprise") == TIER_LIMITS["free"]
assert get_tier_limits("") == TIER_LIMITS["free"]
def test_get_org_limits_with_tier(self):
"""get_org_limits should use the org's tier."""
org = Organization(slug="test", name="Test", tier="pro")
assert get_org_limits(org) == TIER_LIMITS["pro"]
def test_get_org_limits_with_none_tier(self):
"""get_org_limits should default to free when tier is None."""
org = Organization(slug="test", name="Test", tier=None)
assert get_org_limits(org) == TIER_LIMITS["free"]
def test_get_limit_numeric(self):
"""get_limit should return numeric limits."""
org = Organization(slug="test", name="Test", tier="free")
assert get_limit(org, "status_pages") == 1
assert get_limit(org, "subscribers") == 25
def test_get_limit_boolean(self):
"""get_limit should return boolean feature flags."""
org_free = Organization(slug="free-org", name="Free Org", tier="free")
org_pro = Organization(slug="pro-org", name="Pro Org", tier="pro")
assert get_limit(org_free, "custom_domain") is False
assert get_limit(org_pro, "custom_domain") is True
def test_get_limit_unlimited(self):
"""get_limit should return -1 for unlimited features."""
org = Organization(slug="team-org", name="Team Org", tier="team")
assert get_limit(org, "status_pages") == -1
def test_get_limit_unknown_feature(self):
"""get_limit should return None for unknown feature names."""
org = Organization(slug="test", name="Test", tier="free")
assert get_limit(org, "nonexistent_feature") is None
class TestEnforceLimit:
"""Test the enforce_limit function."""
@pytest.mark.asyncio
async def test_enforce_limit_allows_under_limit(self, db_session):
"""Should not raise when current_count is below the limit."""
org = Organization(slug="test", name="Test", tier="free")
db_session.add(org)
await db_session.flush()
# status_pages limit is 1 for free, current_count=0 should pass
await enforce_limit(db_session, org, "status_pages", 0)
@pytest.mark.asyncio
async def test_enforce_limit_blocks_at_limit(self, db_session):
"""Should raise TierLimitExceeded when current_count equals the limit."""
org = Organization(slug="test", name="Test", tier="free")
db_session.add(org)
await db_session.flush()
# status_pages limit is 1 for free, current_count=1 should fail
with pytest.raises(TierLimitExceeded) as exc_info:
await enforce_limit(db_session, org, "status_pages", 1)
assert "status_pages" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_enforce_limit_blocks_over_limit(self, db_session):
"""Should raise TierLimitExceeded when current_count exceeds the limit."""
org = Organization(slug="test", name="Test", tier="free")
db_session.add(org)
await db_session.flush()
with pytest.raises(TierLimitExceeded):
await enforce_limit(db_session, org, "status_pages", 5)
@pytest.mark.asyncio
async def test_enforce_limit_allows_unlimited(self, db_session):
"""Should not raise when limit is -1 (unlimited)."""
org = Organization(slug="test", name="Test", tier="team")
db_session.add(org)
await db_session.flush()
# team has unlimited status_pages, even count=1000 should pass
await enforce_limit(db_session, org, "status_pages", 1000)
@pytest.mark.asyncio
async def test_enforce_limit_blocks_feature_flag(self, db_session):
"""Should raise TierLimitExceeded when feature is False (not available)."""
org = Organization(slug="test", name="Test", tier="free")
db_session.add(org)
await db_session.flush()
with pytest.raises(TierLimitExceeded) as exc_info:
await enforce_limit(db_session, org, "custom_domain", 0)
assert "not available" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_enforce_limit_unknown_feature(self, db_session):
"""Should not raise for an unknown feature (no limit defined)."""
org = Organization(slug="test", name="Test", tier="free")
db_session.add(org)
await db_session.flush()
# Should not raise for an unknown feature
await enforce_limit(db_session, org, "nonexistent", 999)
@pytest.mark.asyncio
async def test_enforce_limit_pro_tier(self, db_session):
"""Pro tier allows up to 5 status pages."""
org = Organization(slug="pro-org", name="Pro Org", tier="pro")
db_session.add(org)
await db_session.flush()
# 5 is the limit, should fail at 5
with pytest.raises(TierLimitExceeded):
await enforce_limit(db_session, org, "status_pages", 5)
# 4 should pass
await enforce_limit(db_session, org, "status_pages", 4)
class TestEnforceFeature:
"""Test the enforce_feature boolean feature flag function."""
def test_enforce_feature_allows_pro_custom_domain(self):
"""Pro org should be allowed custom_domain."""
org = Organization(slug="pro", name="Pro", tier="pro")
# Should not raise
enforce_feature(org, "custom_domain")
def test_enforce_feature_blocks_free_custom_domain(self):
"""Free org should be blocked from custom_domain."""
org = Organization(slug="free", name="Free", tier="free")
with pytest.raises(TierLimitExceeded):
enforce_feature(org, "custom_domain")
def test_enforce_feature_blocks_free_webhooks(self):
"""Free org should be blocked from webhooks."""
org = Organization(slug="free", name="Free", tier="free")
with pytest.raises(TierLimitExceeded):
enforce_feature(org, "webhooks")
def test_enforce_feature_blocks_free_api_access(self):
"""Free org should be blocked from api_access."""
org = Organization(slug="free", name="Free", tier="free")
with pytest.raises(TierLimitExceeded):
enforce_feature(org, "api_access")
def test_enforce_feature_blocks_free_password_protection(self):
"""Free org should be blocked from password_protection."""
org = Organization(slug="free", name="Free", tier="free")
with pytest.raises(TierLimitExceeded):
enforce_feature(org, "password_protection")
def test_enforce_feature_allows_pro_webhooks(self):
"""Pro org should be allowed webhooks."""
org = Organization(slug="pro", name="Pro", tier="pro")
enforce_feature(org, "webhooks")
def test_enforce_feature_unknown_feature(self):
"""Unknown feature should not raise."""
org = Organization(slug="free", name="Free", tier="free")
# Should not raise for unknown
enforce_feature(org, "quantum_computing")
def test_enforce_feature_team_password_protection(self):
"""Team org should be allowed password_protection."""
org = Organization(slug="team", name="Team", tier="team")
enforce_feature(org, "password_protection")
class TestGetTierInfo:
"""Test get_tier_info helper."""
def test_free_org_tier_info(self):
"""Free org should return complete tier info."""
org = Organization(id="org-1", slug="free-org", name="Free Org", tier="free")
info = get_tier_info(org)
assert info["tier"] == "free"
assert info["organization_id"] == "org-1"
assert info["limits"]["status_pages"] == 1
assert info["limits"]["custom_domain"] is False
def test_pro_org_tier_info(self):
"""Pro org should return correct tier info."""
org = Organization(id="org-2", slug="pro-org", name="Pro Org", tier="pro")
info = get_tier_info(org)
assert info["tier"] == "pro"
assert info["limits"]["status_pages"] == 5
assert info["limits"]["custom_domain"] is True
def test_default_tier_info(self):
"""Org with None tier should default to free."""
org = Organization(id="org-3", slug="default-org", name="Default Org", tier=None)
info = get_tier_info(org)
assert info["tier"] == "free"
assert info["limits"]["status_pages"] == 1
# ── Integration tests for database-backed limit checks ───────────────────────
class TestCheckStatusPageLimit:
"""Test check_status_page_limit with real database."""
@pytest.mark.asyncio
async def test_free_org_can_create_first_page(self, db_session):
"""Free org should be allowed to create its first status page."""
org = Organization(slug="free-org", name="Free Org", tier="free")
db_session.add(org)
await db_session.flush()
# Should not raise (0 existing pages, limit is 1)
await check_status_page_limit(db_session, org)
@pytest.mark.asyncio
async def test_free_org_blocked_at_second_page(self, db_session):
"""Free org should be blocked from creating a second status page."""
org = Organization(slug="free-org", name="Free Org", tier="free")
db_session.add(org)
await db_session.flush()
# Create first page
page = StatusPage(organization_id=org.id, slug="main", title="Main")
db_session.add(page)
await db_session.flush()
# Now limit should be reached
with pytest.raises(TierLimitExceeded):
await check_status_page_limit(db_session, org)
@pytest.mark.asyncio
async def test_pro_org_allows_up_to_five_pages(self, db_session):
"""Pro org should be allowed up to 5 status pages."""
org = Organization(slug="pro-org", name="Pro Org", tier="pro")
db_session.add(org)
await db_session.flush()
# Create 4 pages — should be fine (limit is 5)
for i in range(4):
page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}")
db_session.add(page)
await db_session.flush()
await check_status_page_limit(db_session, org) # should not raise
@pytest.mark.asyncio
async def test_pro_org_blocked_at_sixth_page(self, db_session):
"""Pro org should be blocked at the 6th status page."""
org = Organization(slug="pro-org", name="Pro Org", tier="pro")
db_session.add(org)
await db_session.flush()
# Create 5 pages (at the limit)
for i in range(5):
page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}")
db_session.add(page)
await db_session.flush()
with pytest.raises(TierLimitExceeded):
await check_status_page_limit(db_session, org)
@pytest.mark.asyncio
async def test_team_org_unlimited_pages(self, db_session):
"""Team org should never be blocked for status pages."""
org = Organization(slug="team-org", name="Team Org", tier="team")
db_session.add(org)
await db_session.flush()
for i in range(50):
page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}")
db_session.add(page)
await db_session.flush()
# Should not raise even with 50 pages
await check_status_page_limit(db_session, org)
class TestCheckMemberLimit:
"""Test check_member_limit with real database."""
@pytest.mark.asyncio
async def test_free_org_blocked_at_second_member(self, db_session):
"""Free org should be blocked from adding a second member."""
org = Organization(slug="free-org", name="Free Org", tier="free")
db_session.add(org)
await db_session.flush()
# Create one user
user = User(email="owner@example.com", password_hash="hash")
db_session.add(user)
await db_session.flush()
# Add owner membership
membership = OrganizationMember(
organization_id=org.id, user_id=user.id, role="owner"
)
db_session.add(membership)
await db_session.flush()
# Now at limit — should be blocked
with pytest.raises(TierLimitExceeded):
await check_member_limit(db_session, org)
@pytest.mark.asyncio
async def test_free_org_allows_one_member(self, db_session):
"""Free org with 0 members should be allowed to add one."""
org = Organization(slug="free-org", name="Free Org", tier="free")
db_session.add(org)
await db_session.flush()
# 0 members, limit is 1 — should pass
await check_member_limit(db_session, org)
class TestCheckServiceLimit:
"""Test check_service_limit with real database."""
@pytest.mark.asyncio
async def test_free_org_can_create_first_services(self, db_session):
"""Free org should be allowed to create services up to limit."""
org = Organization(slug="free-org", name="Free Org", tier="free")
db_session.add(org)
await db_session.flush()
# Create 4 services (limit is 5)
for i in range(4):
svc = Service(
name=f"svc-{i}", slug=f"svc-{i}", organization_id=org.id
)
db_session.add(svc)
await db_session.flush()
# Should not raise
await check_service_limit(db_session, org)
@pytest.mark.asyncio
async def test_free_org_blocked_at_sixth_service(self, db_session):
"""Free org should be blocked at the 6th service."""
org = Organization(slug="free-org", name="Free Org", tier="free")
db_session.add(org)
await db_session.flush()
# Create 5 services (at limit)
for i in range(5):
svc = Service(
name=f"svc-{i}", slug=f"svc-{i}", organization_id=org.id
)
db_session.add(svc)
await db_session.flush()
with pytest.raises(TierLimitExceeded):
await check_service_limit(db_session, org)
# ── API integration tests ──────────────────────────────────────────────────────
REGISTER_URL = "/api/v1/auth/register"
LOGIN_URL = "/api/v1/auth/login"
ME_URL = "/api/v1/auth/me"
ORGS_TIERS_URL = "/api/v1/organizations/tiers"
ORGS_MY_URL = "/api/v1/organizations/my"
ORGS_MY_LIMITS_URL = "/api/v1/organizations/my/limits"
ORGS_MY_TIER_URL = "/api/v1/organizations/my/tier"
class TestTierAPIEndpoints:
"""Test the organization/tier API endpoints."""
@pytest.mark.asyncio
async def test_list_tiers_public(self, client):
"""List tiers endpoint should be accessible without auth."""
response = await client.get(ORGS_TIERS_URL)
assert response.status_code == 200
data = response.json()
assert "tiers" in data
assert len(data["tiers"]) == 3
tier_names = [t["name"] for t in data["tiers"]]
assert "free" in tier_names
assert "pro" in tier_names
assert "team" in tier_names
@pytest.mark.asyncio
async def test_get_my_org(self, client):
"""Authenticated user should be able to get their org info."""
# Register
reg_response = await client.post(
REGISTER_URL,
json={"email": "orguser@example.com", "password": "testpass123"},
)
assert reg_response.status_code == 201
token = reg_response.json()["access_token"]
# Get org info
response = await client.get(
ORGS_MY_URL,
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
data = response.json()
assert data["tier"] == "free"
assert "tier_info" in data
assert data["tier_info"]["tier"] == "free"
assert data["member_count"] >= 1
@pytest.mark.asyncio
async def test_get_my_org_unauthorized(self, client):
"""Unauthenticated request should be rejected."""
response = await client.get(ORGS_MY_URL)
assert response.status_code in (401, 403)
@pytest.mark.asyncio
async def test_get_my_limits(self, client):
"""Authenticated user should be able to see their tier limits."""
reg_response = await client.post(
REGISTER_URL,
json={"email": "limitsuser@example.com", "password": "testpass123"},
)
token = reg_response.json()["access_token"]
response = await client.get(
ORGS_MY_LIMITS_URL,
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
data = response.json()
assert data["tier"] == "free"
assert data["limits"]["status_pages"] == 1
assert data["limits"]["custom_domain"] is False
@pytest.mark.asyncio
async def test_upgrade_tier_to_pro(self, client):
"""Org owner should be able to upgrade to Pro."""
reg_response = await client.post(
REGISTER_URL,
json={"email": "upgradeuser@example.com", "password": "testpass123"},
)
token = reg_response.json()["access_token"]
# Upgrade to pro
response = await client.patch(
ORGS_MY_TIER_URL,
json={"tier": "pro"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
data = response.json()
assert data["tier"] == "pro"
assert data["tier_info"]["limits"]["status_pages"] == 5
assert data["tier_info"]["limits"]["custom_domain"] is True
@pytest.mark.asyncio
async def test_upgrade_tier_to_team(self, client):
"""Org owner should be able to upgrade to Team."""
reg_response = await client.post(
REGISTER_URL,
json={"email": "teamuser@example.com", "password": "testpass123"},
)
token = reg_response.json()["access_token"]
response = await client.patch(
ORGS_MY_TIER_URL,
json={"tier": "team"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
data = response.json()
assert data["tier"] == "team"
assert data["tier_info"]["limits"]["status_pages"] == -1
@pytest.mark.asyncio
async def test_upgrade_tier_invalid(self, client):
"""Upgrading to an invalid tier should be rejected."""
reg_response = await client.post(
REGISTER_URL,
json={"email": "invalidtier@example.com", "password": "testpass123"},
)
token = reg_response.json()["access_token"]
response = await client.patch(
ORGS_MY_TIER_URL,
json={"tier": "enterprise"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code in (403, 422)
@pytest.mark.asyncio
async def test_downgrade_back_to_free(self, client):
"""Org should be able to downgrade back to free."""
reg_response = await client.post(
REGISTER_URL,
json={"email": "downgrade@example.com", "password": "testpass123"},
)
token = reg_response.json()["access_token"]
# Upgrade to pro first
await client.patch(
ORGS_MY_TIER_URL,
json={"tier": "pro"},
headers={"Authorization": f"Bearer {token}"},
)
# Downgrade back to free
response = await client.patch(
ORGS_MY_TIER_URL,
json={"tier": "free"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
data = response.json()
assert data["tier"] == "free"
@pytest.mark.asyncio
async def test_tier_upgrade_unauthorized(self, client):
"""Unauthenticated tier upgrade should be rejected."""
response = await client.patch(
ORGS_MY_TIER_URL,
json={"tier": "pro"},
)
assert response.status_code in (401, 403)