feat: indie status page SaaS - initial release
This commit is contained in:
parent
ee2bc87ade
commit
b7a8142ca0
14 changed files with 2703 additions and 0 deletions
1262
SAAS_ENHANCEMENT_PLAN.md
Normal file
1262
SAAS_ENHANCEMENT_PLAN.md
Normal file
File diff suppressed because it is too large
Load diff
155
app/api/organizations.py
Normal file
155
app/api/organizations.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""Organization API endpoints: view org, list tiers, upgrade/downgrade."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth import get_current_user, get_current_org
|
||||
from app.dependencies import get_db
|
||||
from app.models.saas_models import Organization, OrganizationMember, User
|
||||
from app.services.tier_limits import (
|
||||
TIER_LIMITS,
|
||||
get_org_limits,
|
||||
get_tier_info,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["organizations"])
|
||||
|
||||
|
||||
# ── Response schemas ────────────────────────────────────────────────────────
|
||||
|
||||
class OrgMemberResponse(BaseModel):
|
||||
user_id: str
|
||||
email: str
|
||||
display_name: str | None
|
||||
role: str
|
||||
|
||||
|
||||
class OrgResponse(BaseModel):
|
||||
id: str
|
||||
slug: str
|
||||
name: str
|
||||
tier: str
|
||||
custom_domain: str | None
|
||||
member_count: int
|
||||
tier_info: dict
|
||||
|
||||
|
||||
class TierDetailResponse(BaseModel):
|
||||
tier: str
|
||||
limits: dict
|
||||
|
||||
|
||||
class UpgradeRequest(BaseModel):
|
||||
tier: str # "free" | "pro" | "team"
|
||||
|
||||
|
||||
# ── Endpoints ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/tiers")
|
||||
async def list_tiers():
|
||||
"""List all available tiers and their limits (public endpoint)."""
|
||||
return {
|
||||
"tiers": [
|
||||
{
|
||||
"name": tier_name,
|
||||
"display_name": {
|
||||
"free": "Free",
|
||||
"pro": "Pro ($9/mo)",
|
||||
"team": "Team ($29/mo)",
|
||||
}.get(tier_name, tier_name),
|
||||
"limits": limits,
|
||||
}
|
||||
for tier_name, limits in TIER_LIMITS.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/my", response_model=OrgResponse)
|
||||
async def get_my_org(
|
||||
org: Organization = Depends(get_current_org),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get the current user's organization with tier limits info."""
|
||||
# Count members
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
OrganizationMember.organization_id == org.id
|
||||
)
|
||||
)
|
||||
members = result.scalars().all()
|
||||
|
||||
return OrgResponse(
|
||||
id=org.id,
|
||||
slug=org.slug,
|
||||
name=org.name,
|
||||
tier=org.tier or "free",
|
||||
custom_domain=org.custom_domain,
|
||||
member_count=len(members),
|
||||
tier_info=get_tier_info(org),
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/my/tier", response_model=OrgResponse)
|
||||
async def update_org_tier(
|
||||
body: UpgradeRequest,
|
||||
org: Organization = Depends(get_current_org),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update the organization's tier.
|
||||
|
||||
In production, this would be gated by Stripe payment verification.
|
||||
For now, this is an admin-only endpoint that directly sets the tier.
|
||||
Only the org owner can change the tier.
|
||||
"""
|
||||
# Verify user is owner
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
OrganizationMember.organization_id == org.id,
|
||||
OrganizationMember.user_id == user.id,
|
||||
)
|
||||
)
|
||||
membership = result.scalar_one_or_none()
|
||||
if not membership or membership.role != "owner":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the organization owner can change the plan tier.",
|
||||
)
|
||||
|
||||
if body.tier not in ("free", "pro", "team"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"Invalid tier '{body.tier}'. Must be one of: free, pro, team.",
|
||||
)
|
||||
|
||||
org.tier = body.tier
|
||||
await db.flush()
|
||||
await db.refresh(org)
|
||||
|
||||
# Count members for response
|
||||
members_result = await db.execute(
|
||||
select(OrganizationMember).where(
|
||||
OrganizationMember.organization_id == org.id
|
||||
)
|
||||
)
|
||||
members = members_result.scalars().all()
|
||||
|
||||
return OrgResponse(
|
||||
id=org.id,
|
||||
slug=org.slug,
|
||||
name=org.name,
|
||||
tier=org.tier,
|
||||
custom_domain=org.custom_domain,
|
||||
member_count=len(members),
|
||||
tier_info=get_tier_info(org),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/my/limits")
|
||||
async def get_my_limits(
|
||||
org: Organization = Depends(get_current_org),
|
||||
):
|
||||
"""Get the current organization's tier limits and feature flags."""
|
||||
return get_tier_info(org)
|
||||
|
|
@ -5,9 +5,13 @@ from app.api.incidents import router as incidents_router
|
|||
from app.api.monitors import router as monitors_router
|
||||
from app.api.subscribers import router as subscribers_router
|
||||
from app.api.settings import router as settings_router
|
||||
from app.api.organizations import router as organizations_router
|
||||
from app.routes.auth import router as auth_router
|
||||
|
||||
api_v1_router = APIRouter()
|
||||
|
||||
api_v1_router.include_router(auth_router, tags=["auth"])
|
||||
api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"])
|
||||
api_v1_router.include_router(services_router, prefix="/services", tags=["services"])
|
||||
api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"])
|
||||
api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"])
|
||||
|
|
|
|||
99
app/auth.py
Normal file
99
app/auth.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""JWT authentication and password hashing utilities."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies import get_db
|
||||
from app.models.saas_models import User, Organization, OrganizationMember
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# OAuth2 scheme for token extraction
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(user_id: str, exp_hours: int = 72) -> str:
|
||||
"""Create a JWT access token for the given user ID."""
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"exp": datetime.utcnow() + timedelta(hours=exp_hours),
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict:
|
||||
"""Decode and verify a JWT access token. Raises JWTError on failure."""
|
||||
return jwt.decode(token, settings.secret_key, algorithms=["HS256"])
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""FastAPI dependency: extract and validate current user from JWT."""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = decode_access_token(token)
|
||||
user_id: str | None = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_org(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Organization:
|
||||
"""FastAPI dependency: get the user's first organization (simplified).
|
||||
|
||||
In the future, this will support org-switching via header or session.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(OrganizationMember).where(OrganizationMember.user_id == user.id)
|
||||
)
|
||||
membership = result.scalars().first()
|
||||
if not membership:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is not a member of any organization",
|
||||
)
|
||||
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == membership.organization_id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
if org is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Organization not found",
|
||||
)
|
||||
return org
|
||||
|
|
@ -8,6 +8,12 @@ from app.models.models import (
|
|||
NotificationLog,
|
||||
SiteSetting,
|
||||
)
|
||||
from app.models.saas_models import (
|
||||
User,
|
||||
Organization,
|
||||
OrganizationMember,
|
||||
StatusPage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Service",
|
||||
|
|
@ -18,4 +24,8 @@ __all__ = [
|
|||
"Subscriber",
|
||||
"NotificationLog",
|
||||
"SiteSetting",
|
||||
"User",
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"StatusPage",
|
||||
]
|
||||
|
|
@ -15,6 +15,9 @@ class Service(Base):
|
|||
__tablename__ = "services"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
|
||||
organization_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("organizations.id"), nullable=True, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
|
@ -74,6 +77,9 @@ class Monitor(Base):
|
|||
__tablename__ = "monitors"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
|
||||
organization_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("organizations.id"), nullable=True, index=True
|
||||
)
|
||||
service_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("services.id"), nullable=False, index=True
|
||||
)
|
||||
|
|
@ -114,6 +120,9 @@ class Subscriber(Base):
|
|||
__tablename__ = "subscribers"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
|
||||
organization_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("organizations.id"), nullable=True, index=True
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
is_confirmed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
confirm_token: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
|
|
|
|||
113
app/models/saas_models.py
Normal file
113
app/models/saas_models.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""SaaS multi-tenancy models: User, Organization, OrganizationMember, StatusPage."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
def _uuid_str() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Individual who can log in."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
|
||||
email: Mapped[str] = mapped_column(
|
||||
String(255), unique=True, nullable=False, index=True
|
||||
)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
display_name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
is_email_verified: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
email_verify_token: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
|
||||
memberships: Mapped[list["OrganizationMember"]] = relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
|
||||
|
||||
class Organization(Base):
|
||||
"""The tenant; owns status pages."""
|
||||
|
||||
__tablename__ = "organizations"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
|
||||
slug: Mapped[str] = mapped_column(
|
||||
String(50), unique=True, nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
tier: Mapped[str] = mapped_column(String(20), nullable=False, default="free")
|
||||
# "free" | "pro" | "team"
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
custom_domain: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
members: Mapped[list["OrganizationMember"]] = relationship(
|
||||
back_populates="organization"
|
||||
)
|
||||
status_pages: Mapped[list["StatusPage"]] = relationship(
|
||||
back_populates="organization"
|
||||
)
|
||||
# Services linked to this org (from app.models.models.Service)
|
||||
|
||||
|
||||
class OrganizationMember(Base):
|
||||
"""Joins users to orgs with roles."""
|
||||
|
||||
__tablename__ = "organization_members"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
|
||||
organization_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
role: Mapped[str] = mapped_column(String(20), nullable=False, default="member")
|
||||
# "owner" | "admin" | "member"
|
||||
joined_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
organization: Mapped["Organization"] = relationship(back_populates="members")
|
||||
user: Mapped["User"] = relationship(back_populates="memberships")
|
||||
|
||||
__table_args__ = (
|
||||
{"sqlite_autoincrement": True},
|
||||
)
|
||||
|
||||
|
||||
class StatusPage(Base):
|
||||
"""Per-organization status page."""
|
||||
|
||||
__tablename__ = "status_pages"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
|
||||
organization_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
slug: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
title: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
subdomain: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
custom_domain: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
|
||||
organization: Mapped["Organization"] = relationship(back_populates="status_pages")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("organization_id", "slug", name="uq_status_page_org_slug"),
|
||||
)
|
||||
0
app/routes/__init__.py
Normal file
0
app/routes/__init__.py
Normal file
122
app/routes/auth.py
Normal file
122
app/routes/auth.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Auth routes: register, login, and current user profile."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth import create_access_token, get_current_user, hash_password, verify_password
|
||||
from app.dependencies import get_db
|
||||
from app.models.saas_models import Organization, OrganizationMember, StatusPage, User
|
||||
|
||||
router = APIRouter(tags=["auth"])
|
||||
|
||||
|
||||
# ── Request / Response schemas ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class UserProfile(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
display_name: str | None = None
|
||||
is_email_verified: bool
|
||||
created_at: str | None = None
|
||||
|
||||
|
||||
# ── Routes ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/auth/register", status_code=status.HTTP_201_CREATED, response_model=AuthResponse)
|
||||
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Register a new user, create a default Organization + StatusPage, return JWT."""
|
||||
# Check for existing user
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
if result.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="A user with this email already exists",
|
||||
)
|
||||
|
||||
# Create user
|
||||
user = User(
|
||||
email=body.email,
|
||||
password_hash=hash_password(body.password),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # assign user.id
|
||||
|
||||
# Create default organization
|
||||
org_slug = body.email.split("@")[0].lower()
|
||||
org = Organization(
|
||||
name=org_slug,
|
||||
slug=org_slug,
|
||||
)
|
||||
db.add(org)
|
||||
await db.flush() # assign org.id
|
||||
|
||||
# Create organization membership (owner)
|
||||
membership = OrganizationMember(
|
||||
organization_id=org.id,
|
||||
user_id=user.id,
|
||||
role="owner",
|
||||
)
|
||||
db.add(membership)
|
||||
|
||||
# Create default status page
|
||||
status_page = StatusPage(
|
||||
organization_id=org.id,
|
||||
slug="main",
|
||||
title="Status Page",
|
||||
)
|
||||
db.add(status_page)
|
||||
|
||||
# Commit all together (the get_db dependency also commits, but flush ensures
|
||||
# relationships are consistent before we return)
|
||||
await db.flush()
|
||||
|
||||
# Generate JWT
|
||||
token = create_access_token(user.id)
|
||||
return AuthResponse(access_token=token)
|
||||
|
||||
|
||||
@router.post("/auth/login", response_model=AuthResponse)
|
||||
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Authenticate user with email + password, return JWT."""
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None or not verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
|
||||
token = create_access_token(user.id)
|
||||
return AuthResponse(access_token=token)
|
||||
|
||||
|
||||
@router.get("/auth/me", response_model=UserProfile)
|
||||
async def me(current_user: User = Depends(get_current_user)):
|
||||
"""Return the current authenticated user's profile."""
|
||||
return UserProfile(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
display_name=current_user.display_name,
|
||||
is_email_verified=current_user.is_email_verified,
|
||||
created_at=current_user.created_at.isoformat() if current_user.created_at else None,
|
||||
)
|
||||
229
app/services/tier_limits.py
Normal file
229
app/services/tier_limits.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""Tier enforcement: limits and feature flags for Free/Pro/Team plans.
|
||||
|
||||
This module defines the per-tier limits and provides enforcement functions
|
||||
that raise HTTP 403 when a limit is exceeded or a feature is unavailable.
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.saas_models import Organization, OrganizationMember, StatusPage
|
||||
from app.models.models import Service, Monitor, Subscriber
|
||||
|
||||
|
||||
# ── Tier limits configuration ──────────────────────────────────────────────
|
||||
|
||||
TIER_LIMITS: dict[str, dict[str, int | bool]] = {
|
||||
"free": {
|
||||
"status_pages": 1,
|
||||
"services_per_page": 5,
|
||||
"monitors_per_service": 1,
|
||||
"subscribers": 25,
|
||||
"members": 1,
|
||||
"check_interval_min": 5, # minutes
|
||||
"custom_domain": False,
|
||||
"custom_branding": False,
|
||||
"webhooks": False,
|
||||
"api_access": False,
|
||||
"incident_history_days": 30,
|
||||
"sla_badge": False,
|
||||
"password_protection": False,
|
||||
},
|
||||
"pro": {
|
||||
"status_pages": 5,
|
||||
"services_per_page": 50,
|
||||
"monitors_per_service": 5,
|
||||
"subscribers": 500,
|
||||
"members": 3,
|
||||
"check_interval_min": 1, # minutes
|
||||
"custom_domain": True,
|
||||
"custom_branding": True,
|
||||
"webhooks": True,
|
||||
"api_access": True,
|
||||
"incident_history_days": 365,
|
||||
"sla_badge": True,
|
||||
"password_protection": False,
|
||||
},
|
||||
"team": {
|
||||
"status_pages": -1, # unlimited
|
||||
"services_per_page": -1,
|
||||
"monitors_per_service": -1,
|
||||
"subscribers": -1,
|
||||
"members": -1,
|
||||
"check_interval_min": 0, # 30 seconds (0 min)
|
||||
"custom_domain": True,
|
||||
"custom_branding": True,
|
||||
"webhooks": True,
|
||||
"api_access": True,
|
||||
"incident_history_days": -1, # unlimited
|
||||
"sla_badge": True,
|
||||
"password_protection": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── Tier info helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def get_tier_limits(tier: str) -> dict:
|
||||
"""Return the limits dict for a given tier name. Falls back to free."""
|
||||
return TIER_LIMITS.get(tier, TIER_LIMITS["free"])
|
||||
|
||||
|
||||
def get_org_limits(org: Organization) -> dict:
|
||||
"""Return the limits dict for an organization based on its tier."""
|
||||
return get_tier_limits(org.tier or "free")
|
||||
|
||||
|
||||
def get_limit(org: Organization, feature: str):
|
||||
"""Return the limit value for a specific feature given the org's tier.
|
||||
|
||||
Returns:
|
||||
int: numeric limit (-1 means unlimited)
|
||||
bool: feature flag (True/False)
|
||||
None: unknown feature
|
||||
"""
|
||||
limits = get_org_limits(org)
|
||||
return limits.get(feature)
|
||||
|
||||
|
||||
# ── Enforcement ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TierLimitExceeded(HTTPException):
|
||||
"""Raised when a tier limit is exceeded."""
|
||||
|
||||
def __init__(self, feature: str, limit: int | bool):
|
||||
if limit is False:
|
||||
detail = f"Feature '{feature}' is not available on your current plan. Upgrade to access it."
|
||||
else:
|
||||
detail = (
|
||||
f"Tier limit reached for '{feature}' ({limit}). "
|
||||
"Upgrade your plan to increase this limit."
|
||||
)
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
async def enforce_limit(
|
||||
db: AsyncSession,
|
||||
org: Organization,
|
||||
feature: str,
|
||||
current_count: int,
|
||||
) -> None:
|
||||
"""Raise TierLimitExceeded if the current count meets or exceeds the tier limit.
|
||||
|
||||
Args:
|
||||
db: Database session for queries.
|
||||
org: The organization whose tier limits to check.
|
||||
feature: The feature name (key in TIER_LIMITS).
|
||||
current_count: How many of this feature the org currently has.
|
||||
|
||||
Raises:
|
||||
TierLimitExceeded: If the limit is reached or the feature is disabled.
|
||||
"""
|
||||
limit = get_limit(org, feature)
|
||||
|
||||
if limit is None:
|
||||
return # Unknown feature — don't block
|
||||
|
||||
if limit is False:
|
||||
# Feature flag: not available on this tier
|
||||
raise TierLimitExceeded(feature, limit)
|
||||
|
||||
if limit == -1:
|
||||
return # Unlimited
|
||||
|
||||
if isinstance(limit, int) and current_count >= limit:
|
||||
raise TierLimitExceeded(feature, limit)
|
||||
|
||||
|
||||
# ── Concrete enforcement helpers ────────────────────────────────────────────
|
||||
|
||||
async def check_status_page_limit(db: AsyncSession, org: Organization) -> None:
|
||||
"""Check that the org hasn't exceeded its status page limit."""
|
||||
result = await db.execute(
|
||||
select(func.count(StatusPage.id)).where(
|
||||
StatusPage.organization_id == org.id
|
||||
)
|
||||
)
|
||||
count = result.scalar() or 0
|
||||
await enforce_limit(db, org, "status_pages", count)
|
||||
|
||||
|
||||
async def check_service_limit(
|
||||
db: AsyncSession, org: Organization, status_page_id: str | None = None
|
||||
) -> None:
|
||||
"""Check that the org hasn't exceeded its services-per-page limit.
|
||||
|
||||
If status_page_id is None, counts all services for the org.
|
||||
"""
|
||||
query = select(func.count(Service.id)).where(
|
||||
Service.organization_id == org.id
|
||||
)
|
||||
if status_page_id:
|
||||
# In future, Service will have a status_page_id column
|
||||
# For now, count all services in the org
|
||||
pass
|
||||
result = await db.execute(query)
|
||||
count = result.scalar() or 0
|
||||
await enforce_limit(db, org, "services_per_page", count)
|
||||
|
||||
|
||||
async def check_monitor_limit(
|
||||
db: AsyncSession, org: Organization, service_id: str
|
||||
) -> None:
|
||||
"""Check that the service hasn't exceeded its monitors-per-service limit."""
|
||||
result = await db.execute(
|
||||
select(func.count(Monitor.id)).where(Monitor.service_id == service_id)
|
||||
)
|
||||
count = result.scalar() or 0
|
||||
await enforce_limit(db, org, "monitors_per_service", count)
|
||||
|
||||
|
||||
async def check_subscriber_limit(db: AsyncSession, org: Organization) -> None:
|
||||
"""Check that the org hasn't exceeded its subscriber limit."""
|
||||
result = await db.execute(
|
||||
select(func.count(Subscriber.id)).where(
|
||||
Subscriber.organization_id == org.id
|
||||
)
|
||||
)
|
||||
count = result.scalar() or 0
|
||||
await enforce_limit(db, org, "subscribers", count)
|
||||
|
||||
|
||||
async def check_member_limit(db: AsyncSession, org: Organization) -> None:
|
||||
"""Check that the org hasn't exceeded its team member limit."""
|
||||
result = await db.execute(
|
||||
select(func.count(OrganizationMember.id)).where(
|
||||
OrganizationMember.organization_id == org.id
|
||||
)
|
||||
)
|
||||
count = result.scalar() or 0
|
||||
await enforce_limit(db, org, "members", count)
|
||||
|
||||
|
||||
def enforce_feature(org: Organization, feature: str) -> None:
|
||||
"""Enforce a boolean feature flag. Raises TierLimitExceeded if False.
|
||||
|
||||
Use this for features that are either allowed or not (e.g., custom_domain,
|
||||
webhooks, api_access) without a numeric limit.
|
||||
"""
|
||||
limit = get_limit(org, feature)
|
||||
if limit is False:
|
||||
raise TierLimitExceeded(feature, False)
|
||||
if limit is None:
|
||||
# Unknown feature — don't block
|
||||
return
|
||||
|
||||
|
||||
def get_tier_info(org: Organization) -> dict:
|
||||
"""Return a dict of the org's current tier with limits and feature flags.
|
||||
|
||||
Useful for API responses that show the org what they can and can't do.
|
||||
"""
|
||||
limits = get_org_limits(org)
|
||||
return {
|
||||
"tier": org.tier or "free",
|
||||
"limits": limits,
|
||||
"organization_id": org.id,
|
||||
"organization_slug": org.slug,
|
||||
}
|
||||
|
|
@ -20,6 +20,10 @@ dependencies = [
|
|||
"httpx>=0.27,<1.0",
|
||||
"typer>=0.9,<1.0",
|
||||
"rich>=13.0,<14.0",
|
||||
"python-jose[cryptography]>=3.3,<4.0",
|
||||
"passlib[bcrypt]>=1.7,<2.0",
|
||||
"bcrypt==4.0.1",
|
||||
"email-validator>=2.0,<3.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ TestSessionLocal = async_sessionmaker(
|
|||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def setup_database():
|
||||
"""Create all tables once for the test session."""
|
||||
# Import SaaS models so their tables are registered on Base.metadata
|
||||
import app.models.saas_models # noqa: F401
|
||||
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
|
|
|
|||
103
tests/test_auth.py
Normal file
103
tests/test_auth.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""Test Auth API endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
REGISTER_URL = "/api/v1/auth/register"
|
||||
LOGIN_URL = "/api/v1/auth/login"
|
||||
ME_URL = "/api/v1/auth/me"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_new_user(client):
|
||||
"""Should register a new user and return 201 with a JWT token."""
|
||||
response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "newuser@example.com", "password": "securepassword123"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
assert len(data["access_token"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_duplicate_email(client):
|
||||
"""Should return 409 when registering with an email that already exists."""
|
||||
# Register first user
|
||||
await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "duplicate@example.com", "password": "password123"},
|
||||
)
|
||||
# Try to register again with same email
|
||||
response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "duplicate@example.com", "password": "differentpassword"},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_correct_password(client):
|
||||
"""Should return 200 with a JWT token on successful login."""
|
||||
# Register a user first
|
||||
await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "loginuser@example.com", "password": "mypassword"},
|
||||
)
|
||||
# Login with correct password
|
||||
response = await client.post(
|
||||
LOGIN_URL,
|
||||
json={"email": "loginuser@example.com", "password": "mypassword"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_wrong_password(client):
|
||||
"""Should return 401 when logging in with wrong password."""
|
||||
# Register a user first
|
||||
await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "wrongpw@example.com", "password": "correctpassword"},
|
||||
)
|
||||
# Login with wrong password
|
||||
response = await client.post(
|
||||
LOGIN_URL,
|
||||
json={"email": "wrongpw@example.com", "password": "wrongpassword"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_with_valid_token(client):
|
||||
"""Should return 200 with user profile when using a valid JWT token."""
|
||||
# Register a user and get token
|
||||
reg_response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "meuser@example.com", "password": "password123"},
|
||||
)
|
||||
token = reg_response.json()["access_token"]
|
||||
|
||||
# Get profile with valid token
|
||||
response = await client.get(
|
||||
ME_URL,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "meuser@example.com"
|
||||
assert "id" in data
|
||||
assert data["is_email_verified"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_without_token(client):
|
||||
"""Should return 401 when accessing /me without a token."""
|
||||
response = await client.get(ME_URL)
|
||||
assert response.status_code in (401, 403)
|
||||
590
tests/test_tier_limits.py
Normal file
590
tests/test_tier_limits.py
Normal file
|
|
@ -0,0 +1,590 @@
|
|||
"""Test tier enforcement: limits, feature flags, and organization endpoints."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.saas_models import Organization, OrganizationMember, StatusPage, User
|
||||
from app.models.models import Service, Monitor, Subscriber
|
||||
from app.services.tier_limits import (
|
||||
TIER_LIMITS,
|
||||
TierLimitExceeded,
|
||||
enforce_limit,
|
||||
enforce_feature,
|
||||
get_limit,
|
||||
get_org_limits,
|
||||
get_tier_info,
|
||||
get_tier_limits,
|
||||
check_status_page_limit,
|
||||
check_service_limit,
|
||||
check_monitor_limit,
|
||||
check_subscriber_limit,
|
||||
check_member_limit,
|
||||
)
|
||||
|
||||
|
||||
# ── Unit tests for tier_limits module ────────────────────────────────────────
|
||||
|
||||
class TestTierLimitsConfig:
|
||||
"""Test that the TIER_LIMITS config is well-formed."""
|
||||
|
||||
def test_all_tiers_defined(self):
|
||||
"""All three tiers should be defined."""
|
||||
assert "free" in TIER_LIMITS
|
||||
assert "pro" in TIER_LIMITS
|
||||
assert "team" in TIER_LIMITS
|
||||
|
||||
def test_free_tier_has_expected_keys(self):
|
||||
"""Free tier should have all expected limit keys."""
|
||||
free = TIER_LIMITS["free"]
|
||||
expected_keys = {
|
||||
"status_pages", "services_per_page", "monitors_per_service",
|
||||
"subscribers", "members", "check_interval_min",
|
||||
"custom_domain", "custom_branding", "webhooks",
|
||||
"api_access", "incident_history_days", "sla_badge",
|
||||
"password_protection",
|
||||
}
|
||||
assert set(free.keys()) == expected_keys
|
||||
|
||||
def test_free_tier_values(self):
|
||||
"""Free tier should have restrictive values."""
|
||||
free = TIER_LIMITS["free"]
|
||||
assert free["status_pages"] == 1
|
||||
assert free["services_per_page"] == 5
|
||||
assert free["monitors_per_service"] == 1
|
||||
assert free["subscribers"] == 25
|
||||
assert free["members"] == 1
|
||||
assert free["custom_domain"] is False
|
||||
assert free["webhooks"] is False
|
||||
assert free["api_access"] is False
|
||||
|
||||
def test_pro_tier_values(self):
|
||||
"""Pro tier should have moderate values."""
|
||||
pro = TIER_LIMITS["pro"]
|
||||
assert pro["status_pages"] == 5
|
||||
assert pro["services_per_page"] == 50
|
||||
assert pro["monitors_per_service"] == 5
|
||||
assert pro["subscribers"] == 500
|
||||
assert pro["members"] == 3
|
||||
assert pro["custom_domain"] is True
|
||||
assert pro["webhooks"] is True
|
||||
|
||||
def test_team_tier_values(self):
|
||||
"""Team tier should have unlimited (-1) for most things."""
|
||||
team = TIER_LIMITS["team"]
|
||||
assert team["status_pages"] == -1
|
||||
assert team["services_per_page"] == -1
|
||||
assert team["monitors_per_service"] == -1
|
||||
assert team["subscribers"] == -1
|
||||
assert team["members"] == -1
|
||||
assert team["custom_domain"] is True
|
||||
assert team["password_protection"] is True
|
||||
|
||||
def test_all_tiers_have_same_keys(self):
|
||||
"""All tiers should have exactly the same set of keys."""
|
||||
keys = set(TIER_LIMITS["free"].keys())
|
||||
for tier_name, tier_data in TIER_LIMITS.items():
|
||||
assert set(tier_data.keys()) == keys, f"Tier '{tier_name}' has different keys"
|
||||
|
||||
|
||||
class TestGetLimitHelpers:
|
||||
"""Test the helper functions."""
|
||||
|
||||
def test_get_tier_limits_known_tier(self):
|
||||
"""get_tier_limits should return the correct dict for known tiers."""
|
||||
assert get_tier_limits("free") == TIER_LIMITS["free"]
|
||||
assert get_tier_limits("pro") == TIER_LIMITS["pro"]
|
||||
assert get_tier_limits("team") == TIER_LIMITS["team"]
|
||||
|
||||
def test_get_tier_limits_unknown_tier(self):
|
||||
"""get_tier_limits should default to free for unknown tiers."""
|
||||
assert get_tier_limits("enterprise") == TIER_LIMITS["free"]
|
||||
assert get_tier_limits("") == TIER_LIMITS["free"]
|
||||
|
||||
def test_get_org_limits_with_tier(self):
|
||||
"""get_org_limits should use the org's tier."""
|
||||
org = Organization(slug="test", name="Test", tier="pro")
|
||||
assert get_org_limits(org) == TIER_LIMITS["pro"]
|
||||
|
||||
def test_get_org_limits_with_none_tier(self):
|
||||
"""get_org_limits should default to free when tier is None."""
|
||||
org = Organization(slug="test", name="Test", tier=None)
|
||||
assert get_org_limits(org) == TIER_LIMITS["free"]
|
||||
|
||||
def test_get_limit_numeric(self):
|
||||
"""get_limit should return numeric limits."""
|
||||
org = Organization(slug="test", name="Test", tier="free")
|
||||
assert get_limit(org, "status_pages") == 1
|
||||
assert get_limit(org, "subscribers") == 25
|
||||
|
||||
def test_get_limit_boolean(self):
|
||||
"""get_limit should return boolean feature flags."""
|
||||
org_free = Organization(slug="free-org", name="Free Org", tier="free")
|
||||
org_pro = Organization(slug="pro-org", name="Pro Org", tier="pro")
|
||||
assert get_limit(org_free, "custom_domain") is False
|
||||
assert get_limit(org_pro, "custom_domain") is True
|
||||
|
||||
def test_get_limit_unlimited(self):
|
||||
"""get_limit should return -1 for unlimited features."""
|
||||
org = Organization(slug="team-org", name="Team Org", tier="team")
|
||||
assert get_limit(org, "status_pages") == -1
|
||||
|
||||
def test_get_limit_unknown_feature(self):
|
||||
"""get_limit should return None for unknown feature names."""
|
||||
org = Organization(slug="test", name="Test", tier="free")
|
||||
assert get_limit(org, "nonexistent_feature") is None
|
||||
|
||||
|
||||
class TestEnforceLimit:
|
||||
"""Test the enforce_limit function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_limit_allows_under_limit(self, db_session):
|
||||
"""Should not raise when current_count is below the limit."""
|
||||
org = Organization(slug="test", name="Test", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# status_pages limit is 1 for free, current_count=0 should pass
|
||||
await enforce_limit(db_session, org, "status_pages", 0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_limit_blocks_at_limit(self, db_session):
|
||||
"""Should raise TierLimitExceeded when current_count equals the limit."""
|
||||
org = Organization(slug="test", name="Test", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# status_pages limit is 1 for free, current_count=1 should fail
|
||||
with pytest.raises(TierLimitExceeded) as exc_info:
|
||||
await enforce_limit(db_session, org, "status_pages", 1)
|
||||
assert "status_pages" in str(exc_info.value.detail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_limit_blocks_over_limit(self, db_session):
|
||||
"""Should raise TierLimitExceeded when current_count exceeds the limit."""
|
||||
org = Organization(slug="test", name="Test", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
await enforce_limit(db_session, org, "status_pages", 5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_limit_allows_unlimited(self, db_session):
|
||||
"""Should not raise when limit is -1 (unlimited)."""
|
||||
org = Organization(slug="test", name="Test", tier="team")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# team has unlimited status_pages, even count=1000 should pass
|
||||
await enforce_limit(db_session, org, "status_pages", 1000)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_limit_blocks_feature_flag(self, db_session):
|
||||
"""Should raise TierLimitExceeded when feature is False (not available)."""
|
||||
org = Organization(slug="test", name="Test", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
with pytest.raises(TierLimitExceeded) as exc_info:
|
||||
await enforce_limit(db_session, org, "custom_domain", 0)
|
||||
assert "not available" in str(exc_info.value.detail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_limit_unknown_feature(self, db_session):
|
||||
"""Should not raise for an unknown feature (no limit defined)."""
|
||||
org = Organization(slug="test", name="Test", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# Should not raise for an unknown feature
|
||||
await enforce_limit(db_session, org, "nonexistent", 999)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_limit_pro_tier(self, db_session):
|
||||
"""Pro tier allows up to 5 status pages."""
|
||||
org = Organization(slug="pro-org", name="Pro Org", tier="pro")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# 5 is the limit, should fail at 5
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
await enforce_limit(db_session, org, "status_pages", 5)
|
||||
# 4 should pass
|
||||
await enforce_limit(db_session, org, "status_pages", 4)
|
||||
|
||||
|
||||
class TestEnforceFeature:
|
||||
"""Test the enforce_feature boolean feature flag function."""
|
||||
|
||||
def test_enforce_feature_allows_pro_custom_domain(self):
|
||||
"""Pro org should be allowed custom_domain."""
|
||||
org = Organization(slug="pro", name="Pro", tier="pro")
|
||||
# Should not raise
|
||||
enforce_feature(org, "custom_domain")
|
||||
|
||||
def test_enforce_feature_blocks_free_custom_domain(self):
|
||||
"""Free org should be blocked from custom_domain."""
|
||||
org = Organization(slug="free", name="Free", tier="free")
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
enforce_feature(org, "custom_domain")
|
||||
|
||||
def test_enforce_feature_blocks_free_webhooks(self):
|
||||
"""Free org should be blocked from webhooks."""
|
||||
org = Organization(slug="free", name="Free", tier="free")
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
enforce_feature(org, "webhooks")
|
||||
|
||||
def test_enforce_feature_blocks_free_api_access(self):
|
||||
"""Free org should be blocked from api_access."""
|
||||
org = Organization(slug="free", name="Free", tier="free")
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
enforce_feature(org, "api_access")
|
||||
|
||||
def test_enforce_feature_blocks_free_password_protection(self):
|
||||
"""Free org should be blocked from password_protection."""
|
||||
org = Organization(slug="free", name="Free", tier="free")
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
enforce_feature(org, "password_protection")
|
||||
|
||||
def test_enforce_feature_allows_pro_webhooks(self):
|
||||
"""Pro org should be allowed webhooks."""
|
||||
org = Organization(slug="pro", name="Pro", tier="pro")
|
||||
enforce_feature(org, "webhooks")
|
||||
|
||||
def test_enforce_feature_unknown_feature(self):
|
||||
"""Unknown feature should not raise."""
|
||||
org = Organization(slug="free", name="Free", tier="free")
|
||||
# Should not raise for unknown
|
||||
enforce_feature(org, "quantum_computing")
|
||||
|
||||
def test_enforce_feature_team_password_protection(self):
|
||||
"""Team org should be allowed password_protection."""
|
||||
org = Organization(slug="team", name="Team", tier="team")
|
||||
enforce_feature(org, "password_protection")
|
||||
|
||||
|
||||
class TestGetTierInfo:
|
||||
"""Test get_tier_info helper."""
|
||||
|
||||
def test_free_org_tier_info(self):
|
||||
"""Free org should return complete tier info."""
|
||||
org = Organization(id="org-1", slug="free-org", name="Free Org", tier="free")
|
||||
info = get_tier_info(org)
|
||||
assert info["tier"] == "free"
|
||||
assert info["organization_id"] == "org-1"
|
||||
assert info["limits"]["status_pages"] == 1
|
||||
assert info["limits"]["custom_domain"] is False
|
||||
|
||||
def test_pro_org_tier_info(self):
|
||||
"""Pro org should return correct tier info."""
|
||||
org = Organization(id="org-2", slug="pro-org", name="Pro Org", tier="pro")
|
||||
info = get_tier_info(org)
|
||||
assert info["tier"] == "pro"
|
||||
assert info["limits"]["status_pages"] == 5
|
||||
assert info["limits"]["custom_domain"] is True
|
||||
|
||||
def test_default_tier_info(self):
|
||||
"""Org with None tier should default to free."""
|
||||
org = Organization(id="org-3", slug="default-org", name="Default Org", tier=None)
|
||||
info = get_tier_info(org)
|
||||
assert info["tier"] == "free"
|
||||
assert info["limits"]["status_pages"] == 1
|
||||
|
||||
|
||||
# ── Integration tests for database-backed limit checks ───────────────────────
|
||||
|
||||
class TestCheckStatusPageLimit:
|
||||
"""Test check_status_page_limit with real database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_org_can_create_first_page(self, db_session):
|
||||
"""Free org should be allowed to create its first status page."""
|
||||
org = Organization(slug="free-org", name="Free Org", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# Should not raise (0 existing pages, limit is 1)
|
||||
await check_status_page_limit(db_session, org)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_org_blocked_at_second_page(self, db_session):
|
||||
"""Free org should be blocked from creating a second status page."""
|
||||
org = Organization(slug="free-org", name="Free Org", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# Create first page
|
||||
page = StatusPage(organization_id=org.id, slug="main", title="Main")
|
||||
db_session.add(page)
|
||||
await db_session.flush()
|
||||
# Now limit should be reached
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
await check_status_page_limit(db_session, org)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pro_org_allows_up_to_five_pages(self, db_session):
|
||||
"""Pro org should be allowed up to 5 status pages."""
|
||||
org = Organization(slug="pro-org", name="Pro Org", tier="pro")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# Create 4 pages — should be fine (limit is 5)
|
||||
for i in range(4):
|
||||
page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}")
|
||||
db_session.add(page)
|
||||
await db_session.flush()
|
||||
await check_status_page_limit(db_session, org) # should not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pro_org_blocked_at_sixth_page(self, db_session):
|
||||
"""Pro org should be blocked at the 6th status page."""
|
||||
org = Organization(slug="pro-org", name="Pro Org", tier="pro")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# Create 5 pages (at the limit)
|
||||
for i in range(5):
|
||||
page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}")
|
||||
db_session.add(page)
|
||||
await db_session.flush()
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
await check_status_page_limit(db_session, org)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_org_unlimited_pages(self, db_session):
|
||||
"""Team org should never be blocked for status pages."""
|
||||
org = Organization(slug="team-org", name="Team Org", tier="team")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
for i in range(50):
|
||||
page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}")
|
||||
db_session.add(page)
|
||||
await db_session.flush()
|
||||
# Should not raise even with 50 pages
|
||||
await check_status_page_limit(db_session, org)
|
||||
|
||||
|
||||
class TestCheckMemberLimit:
|
||||
"""Test check_member_limit with real database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_org_blocked_at_second_member(self, db_session):
|
||||
"""Free org should be blocked from adding a second member."""
|
||||
org = Organization(slug="free-org", name="Free Org", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# Create one user
|
||||
user = User(email="owner@example.com", password_hash="hash")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
# Add owner membership
|
||||
membership = OrganizationMember(
|
||||
organization_id=org.id, user_id=user.id, role="owner"
|
||||
)
|
||||
db_session.add(membership)
|
||||
await db_session.flush()
|
||||
# Now at limit — should be blocked
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
await check_member_limit(db_session, org)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_org_allows_one_member(self, db_session):
|
||||
"""Free org with 0 members should be allowed to add one."""
|
||||
org = Organization(slug="free-org", name="Free Org", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# 0 members, limit is 1 — should pass
|
||||
await check_member_limit(db_session, org)
|
||||
|
||||
|
||||
class TestCheckServiceLimit:
|
||||
"""Test check_service_limit with real database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_org_can_create_first_services(self, db_session):
|
||||
"""Free org should be allowed to create services up to limit."""
|
||||
org = Organization(slug="free-org", name="Free Org", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# Create 4 services (limit is 5)
|
||||
for i in range(4):
|
||||
svc = Service(
|
||||
name=f"svc-{i}", slug=f"svc-{i}", organization_id=org.id
|
||||
)
|
||||
db_session.add(svc)
|
||||
await db_session.flush()
|
||||
# Should not raise
|
||||
await check_service_limit(db_session, org)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_org_blocked_at_sixth_service(self, db_session):
|
||||
"""Free org should be blocked at the 6th service."""
|
||||
org = Organization(slug="free-org", name="Free Org", tier="free")
|
||||
db_session.add(org)
|
||||
await db_session.flush()
|
||||
# Create 5 services (at limit)
|
||||
for i in range(5):
|
||||
svc = Service(
|
||||
name=f"svc-{i}", slug=f"svc-{i}", organization_id=org.id
|
||||
)
|
||||
db_session.add(svc)
|
||||
await db_session.flush()
|
||||
with pytest.raises(TierLimitExceeded):
|
||||
await check_service_limit(db_session, org)
|
||||
|
||||
|
||||
# ── API integration tests ──────────────────────────────────────────────────────
|
||||
|
||||
REGISTER_URL = "/api/v1/auth/register"
|
||||
LOGIN_URL = "/api/v1/auth/login"
|
||||
ME_URL = "/api/v1/auth/me"
|
||||
ORGS_TIERS_URL = "/api/v1/organizations/tiers"
|
||||
ORGS_MY_URL = "/api/v1/organizations/my"
|
||||
ORGS_MY_LIMITS_URL = "/api/v1/organizations/my/limits"
|
||||
ORGS_MY_TIER_URL = "/api/v1/organizations/my/tier"
|
||||
|
||||
|
||||
class TestTierAPIEndpoints:
|
||||
"""Test the organization/tier API endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tiers_public(self, client):
|
||||
"""List tiers endpoint should be accessible without auth."""
|
||||
response = await client.get(ORGS_TIERS_URL)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tiers" in data
|
||||
assert len(data["tiers"]) == 3
|
||||
tier_names = [t["name"] for t in data["tiers"]]
|
||||
assert "free" in tier_names
|
||||
assert "pro" in tier_names
|
||||
assert "team" in tier_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_org(self, client):
|
||||
"""Authenticated user should be able to get their org info."""
|
||||
# Register
|
||||
reg_response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "orguser@example.com", "password": "testpass123"},
|
||||
)
|
||||
assert reg_response.status_code == 201
|
||||
token = reg_response.json()["access_token"]
|
||||
|
||||
# Get org info
|
||||
response = await client.get(
|
||||
ORGS_MY_URL,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "free"
|
||||
assert "tier_info" in data
|
||||
assert data["tier_info"]["tier"] == "free"
|
||||
assert data["member_count"] >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_org_unauthorized(self, client):
|
||||
"""Unauthenticated request should be rejected."""
|
||||
response = await client.get(ORGS_MY_URL)
|
||||
assert response.status_code in (401, 403)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_limits(self, client):
|
||||
"""Authenticated user should be able to see their tier limits."""
|
||||
reg_response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "limitsuser@example.com", "password": "testpass123"},
|
||||
)
|
||||
token = reg_response.json()["access_token"]
|
||||
|
||||
response = await client.get(
|
||||
ORGS_MY_LIMITS_URL,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "free"
|
||||
assert data["limits"]["status_pages"] == 1
|
||||
assert data["limits"]["custom_domain"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_tier_to_pro(self, client):
|
||||
"""Org owner should be able to upgrade to Pro."""
|
||||
reg_response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "upgradeuser@example.com", "password": "testpass123"},
|
||||
)
|
||||
token = reg_response.json()["access_token"]
|
||||
|
||||
# Upgrade to pro
|
||||
response = await client.patch(
|
||||
ORGS_MY_TIER_URL,
|
||||
json={"tier": "pro"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "pro"
|
||||
assert data["tier_info"]["limits"]["status_pages"] == 5
|
||||
assert data["tier_info"]["limits"]["custom_domain"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_tier_to_team(self, client):
|
||||
"""Org owner should be able to upgrade to Team."""
|
||||
reg_response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "teamuser@example.com", "password": "testpass123"},
|
||||
)
|
||||
token = reg_response.json()["access_token"]
|
||||
|
||||
response = await client.patch(
|
||||
ORGS_MY_TIER_URL,
|
||||
json={"tier": "team"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "team"
|
||||
assert data["tier_info"]["limits"]["status_pages"] == -1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_tier_invalid(self, client):
|
||||
"""Upgrading to an invalid tier should be rejected."""
|
||||
reg_response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "invalidtier@example.com", "password": "testpass123"},
|
||||
)
|
||||
token = reg_response.json()["access_token"]
|
||||
|
||||
response = await client.patch(
|
||||
ORGS_MY_TIER_URL,
|
||||
json={"tier": "enterprise"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code in (403, 422)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_back_to_free(self, client):
|
||||
"""Org should be able to downgrade back to free."""
|
||||
reg_response = await client.post(
|
||||
REGISTER_URL,
|
||||
json={"email": "downgrade@example.com", "password": "testpass123"},
|
||||
)
|
||||
token = reg_response.json()["access_token"]
|
||||
|
||||
# Upgrade to pro first
|
||||
await client.patch(
|
||||
ORGS_MY_TIER_URL,
|
||||
json={"tier": "pro"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
# Downgrade back to free
|
||||
response = await client.patch(
|
||||
ORGS_MY_TIER_URL,
|
||||
json={"tier": "free"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["tier"] == "free"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tier_upgrade_unauthorized(self, client):
|
||||
"""Unauthenticated tier upgrade should be rejected."""
|
||||
response = await client.patch(
|
||||
ORGS_MY_TIER_URL,
|
||||
json={"tier": "pro"},
|
||||
)
|
||||
assert response.status_code in (401, 403)
|
||||
Loading…
Add table
Add a link
Reference in a new issue