# File-viewer metadata (from the original listing): 229 lines, no trailing newline, 7.7 KiB, Python.
"""Tier enforcement: limits and feature flags for Free/Pro/Team plans.
|
|
|
|
This module defines the per-tier limits and provides enforcement functions
|
|
that raise HTTP 403 when a limit is exceeded or a feature is unavailable.
|
|
"""
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.saas_models import Organization, OrganizationMember, StatusPage
|
|
from app.models.models import Service, Monitor, Subscriber
|
|
|
|
|
|
# ── Tier limits configuration ──────────────────────────────────────────────
|
|
|
|
TIER_LIMITS: dict[str, dict[str, int | bool]] = {
|
|
"free": {
|
|
"status_pages": 1,
|
|
"services_per_page": 5,
|
|
"monitors_per_service": 1,
|
|
"subscribers": 25,
|
|
"members": 1,
|
|
"check_interval_min": 5, # minutes
|
|
"custom_domain": False,
|
|
"custom_branding": False,
|
|
"webhooks": False,
|
|
"api_access": False,
|
|
"incident_history_days": 30,
|
|
"sla_badge": False,
|
|
"password_protection": False,
|
|
},
|
|
"pro": {
|
|
"status_pages": 5,
|
|
"services_per_page": 50,
|
|
"monitors_per_service": 5,
|
|
"subscribers": 500,
|
|
"members": 3,
|
|
"check_interval_min": 1, # minutes
|
|
"custom_domain": True,
|
|
"custom_branding": True,
|
|
"webhooks": True,
|
|
"api_access": True,
|
|
"incident_history_days": 365,
|
|
"sla_badge": True,
|
|
"password_protection": False,
|
|
},
|
|
"team": {
|
|
"status_pages": -1, # unlimited
|
|
"services_per_page": -1,
|
|
"monitors_per_service": -1,
|
|
"subscribers": -1,
|
|
"members": -1,
|
|
"check_interval_min": 0, # 30 seconds (0 min)
|
|
"custom_domain": True,
|
|
"custom_branding": True,
|
|
"webhooks": True,
|
|
"api_access": True,
|
|
"incident_history_days": -1, # unlimited
|
|
"sla_badge": True,
|
|
"password_protection": True,
|
|
},
|
|
}
|
|
|
|
|
|
# ── Tier info helpers ───────────────────────────────────────────────────────
|
|
|
|
def get_tier_limits(tier: str) -> dict:
    """Return the limits dict for a given tier name. Falls back to free.

    Any name not present in TIER_LIMITS (including ``None``) resolves to the
    ``"free"`` limits.

    A shallow copy is returned rather than the module-level dict itself.
    Callers can then modify or serialize the result without altering
    TIER_LIMITS for the rest of the process. All values are ints or bools, so
    a shallow copy is sufficient.
    """
    return dict(TIER_LIMITS.get(tier, TIER_LIMITS["free"]))
|
|
|
|
|
|
def get_org_limits(org: Organization) -> dict:
    """Look up the limits that apply to ``org``'s tier.

    A falsy ``org.tier`` (e.g. ``None`` or ``""``) is treated as ``"free"``.
    """
    tier_name = org.tier
    if not tier_name:
        tier_name = "free"
    return get_tier_limits(tier_name)
|
|
|
|
|
|
def get_limit(org: Organization, feature: str):
    """Look up one tier setting for ``org``.

    Args:
        org: Organization whose tier is consulted.
        feature: Key into the tier's limits dict (e.g. ``"subscribers"``).

    Returns:
        int: numeric limit (-1 means unlimited)
        bool: feature flag (True/False)
        None: ``feature`` is not a known key
    """
    return get_org_limits(org).get(feature)
|
|
|
|
|
|
# ── Enforcement ─────────────────────────────────────────────────────────────
|
|
|
|
class TierLimitExceeded(HTTPException):
    """HTTP 403 raised when an organization hits a plan limit or a disabled feature.

    The ``detail`` text depends on ``limit``. ``False`` means the feature is not
    part of the plan at all. Any other value is reported as the cap that has
    been reached.
    """

    def __init__(self, feature: str, limit: int | bool):
        message = (
            f"Feature '{feature}' is not available on your current plan. Upgrade to access it."
            if limit is False
            else (
                f"Tier limit reached for '{feature}' ({limit}). "
                "Upgrade your plan to increase this limit."
            )
        )
        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=message)
|
|
|
|
|
|
async def enforce_limit(
    db: AsyncSession,
    org: Organization,
    feature: str,
    current_count: int,
) -> None:
    """Raise TierLimitExceeded if the current count meets or exceeds the tier limit.

    The tier value for ``feature`` is interpreted as follows:

    * ``None`` (unknown feature): never blocks.
    * ``False``: feature disabled on this tier; always raises.
    * ``True``: feature enabled with no numeric cap; never blocks.
    * ``-1``: unlimited; never blocks.
    * any other int: raises once ``current_count >= limit``.

    Args:
        db: Database session. Currently unused by this function.
        org: The organization whose tier limits to check.
        feature: The feature name (key in TIER_LIMITS).
        current_count: How many of this feature the org currently has.

    Raises:
        TierLimitExceeded: If the limit is reached or the feature is disabled.
    """
    limit = get_limit(org, feature)

    if limit is None:
        return  # Unknown feature — don't block

    # Booleans must be handled before the numeric comparison. ``bool`` is a
    # subclass of ``int``, so ``True`` would otherwise be treated as a cap of 1
    # and block any count >= 1 for an enabled feature.
    if isinstance(limit, bool):
        if not limit:
            # Feature flag: not available on this tier
            raise TierLimitExceeded(feature, limit)
        return  # Feature flag enabled; no numeric cap to enforce

    if limit == -1:
        return  # Unlimited

    if isinstance(limit, int) and current_count >= limit:
        raise TierLimitExceeded(feature, limit)
|
|
|
|
|
|
# ── Concrete enforcement helpers ────────────────────────────────────────────
|
|
|
|
async def check_status_page_limit(db: AsyncSession, org: Organization) -> None:
    """Raise if ``org`` already has as many status pages as its tier allows.

    Counts the organization's StatusPage rows and passes the total to
    ``enforce_limit`` under the ``"status_pages"`` key.

    Raises:
        TierLimitExceeded: propagated from ``enforce_limit``.
    """
    page_count_stmt = select(func.count(StatusPage.id)).where(
        StatusPage.organization_id == org.id
    )
    existing_pages = (await db.execute(page_count_stmt)).scalar() or 0
    await enforce_limit(db, org, "status_pages", existing_pages)
|
|
|
|
|
|
async def check_service_limit(
    db: AsyncSession, org: Organization, status_page_id: str | None = None
) -> None:
    """Raise if the org has reached its ``services_per_page`` limit.

    ``status_page_id`` is accepted but does not currently narrow the count.
    Every Service row belonging to the organization is counted, so the
    per-page limit effectively acts as a per-organization cap.

    Raises:
        TierLimitExceeded: propagated from ``enforce_limit``.
    """
    # TODO: once Service has a status_page_id column, filter on it when
    # status_page_id is provided. Until then, count all services in the org.
    service_count_stmt = select(func.count(Service.id)).where(
        Service.organization_id == org.id
    )
    service_total = (await db.execute(service_count_stmt)).scalar() or 0
    await enforce_limit(db, org, "services_per_page", service_total)
|
|
|
|
|
|
async def check_monitor_limit(
    db: AsyncSession, org: Organization, service_id: str
) -> None:
    """Raise if ``service_id`` already has as many monitors as ``org``'s tier allows.

    The count is filtered by ``service_id`` only. It does not check that the
    service belongs to ``org``; presumably the caller verifies ownership
    (TODO confirm).

    Raises:
        TierLimitExceeded: propagated from ``enforce_limit``.
    """
    monitor_count_stmt = select(func.count(Monitor.id)).where(
        Monitor.service_id == service_id
    )
    monitors_on_service = (await db.execute(monitor_count_stmt)).scalar() or 0
    await enforce_limit(db, org, "monitors_per_service", monitors_on_service)
|
|
|
|
|
|
async def check_subscriber_limit(db: AsyncSession, org: Organization) -> None:
    """Raise if ``org`` already has as many subscribers as its tier allows.

    Counts Subscriber rows for the organization and checks the total against
    the ``"subscribers"`` limit via ``enforce_limit``.

    Raises:
        TierLimitExceeded: propagated from ``enforce_limit``.
    """
    subscriber_count_stmt = select(func.count(Subscriber.id)).where(
        Subscriber.organization_id == org.id
    )
    subscriber_total = (await db.execute(subscriber_count_stmt)).scalar() or 0
    await enforce_limit(db, org, "subscribers", subscriber_total)
|
|
|
|
|
|
async def check_member_limit(db: AsyncSession, org: Organization) -> None:
    """Raise if ``org`` already has as many team members as its tier allows.

    Counts OrganizationMember rows for the organization and checks the total
    against the ``"members"`` limit via ``enforce_limit``.

    Raises:
        TierLimitExceeded: propagated from ``enforce_limit``.
    """
    member_count_stmt = select(func.count(OrganizationMember.id)).where(
        OrganizationMember.organization_id == org.id
    )
    member_total = (await db.execute(member_count_stmt)).scalar() or 0
    await enforce_limit(db, org, "members", member_total)
|
|
|
|
|
|
def enforce_feature(org: Organization, feature: str) -> None:
    """Block access to a feature that ``org``'s tier has switched off.

    Intended for on/off features without a numeric cap (e.g. custom_domain,
    webhooks, api_access). Only an explicit ``False`` blocks. Unknown
    features (``None``) and any other value are allowed through.

    Raises:
        TierLimitExceeded: if the tier's value for ``feature`` is ``False``.
    """
    # Identity check on purpose: ``None`` (unknown feature) and numeric
    # values such as 0 must not be treated as "disabled".
    if get_limit(org, feature) is False:
        raise TierLimitExceeded(feature, False)
|
|
|
|
|
|
def get_tier_info(org: Organization) -> dict:
    """Summarize ``org``'s plan for API responses.

    Returns:
        dict with keys ``tier`` (falling back to ``"free"`` when unset),
        ``limits`` (the tier's limits dict), ``organization_id`` and
        ``organization_slug``.
    """
    tier_limits = get_org_limits(org)
    info = {"tier": org.tier or "free"}
    info["limits"] = tier_limits
    info["organization_id"] = org.id
    info["organization_slug"] = org.slug
    return info