"""Tier enforcement: limits and feature flags for Free/Pro/Team plans. This module defines the per-tier limits and provides enforcement functions that raise HTTP 403 when a limit is exceeded or a feature is unavailable. """ from fastapi import HTTPException, status from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from app.models.saas_models import Organization, OrganizationMember, StatusPage from app.models.models import Service, Monitor, Subscriber # ── Tier limits configuration ────────────────────────────────────────────── TIER_LIMITS: dict[str, dict[str, int | bool]] = { "free": { "status_pages": 1, "services_per_page": 5, "monitors_per_service": 1, "subscribers": 25, "members": 1, "check_interval_min": 5, # minutes "custom_domain": False, "custom_branding": False, "webhooks": False, "api_access": False, "incident_history_days": 30, "sla_badge": False, "password_protection": False, }, "pro": { "status_pages": 5, "services_per_page": 50, "monitors_per_service": 5, "subscribers": 500, "members": 3, "check_interval_min": 1, # minutes "custom_domain": True, "custom_branding": True, "webhooks": True, "api_access": True, "incident_history_days": 365, "sla_badge": True, "password_protection": False, }, "team": { "status_pages": -1, # unlimited "services_per_page": -1, "monitors_per_service": -1, "subscribers": -1, "members": -1, "check_interval_min": 0, # 30 seconds (0 min) "custom_domain": True, "custom_branding": True, "webhooks": True, "api_access": True, "incident_history_days": -1, # unlimited "sla_badge": True, "password_protection": True, }, } # ── Tier info helpers ─────────────────────────────────────────────────────── def get_tier_limits(tier: str) -> dict: """Return the limits dict for a given tier name. Falls back to free.""" return TIER_LIMITS.get(tier, TIER_LIMITS["free"]) def get_org_limits(org: Organization) -> dict: """Return the limits dict for an organization based on its tier.""" return get_tier_limits(org.tier or "free") def get_limit(org: Organization, feature: str): """Return the limit value for a specific feature given the org's tier. Returns: int: numeric limit (-1 means unlimited) bool: feature flag (True/False) None: unknown feature """ limits = get_org_limits(org) return limits.get(feature) # ── Enforcement ───────────────────────────────────────────────────────────── class TierLimitExceeded(HTTPException): """Raised when a tier limit is exceeded.""" def __init__(self, feature: str, limit: int | bool): if limit is False: detail = f"Feature '{feature}' is not available on your current plan. Upgrade to access it." else: detail = ( f"Tier limit reached for '{feature}' ({limit}). " "Upgrade your plan to increase this limit." ) super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail) async def enforce_limit( db: AsyncSession, org: Organization, feature: str, current_count: int, ) -> None: """Raise TierLimitExceeded if the current count meets or exceeds the tier limit. Args: db: Database session for queries. org: The organization whose tier limits to check. feature: The feature name (key in TIER_LIMITS). current_count: How many of this feature the org currently has. Raises: TierLimitExceeded: If the limit is reached or the feature is disabled. """ limit = get_limit(org, feature) if limit is None: return # Unknown feature — don't block if limit is False: # Feature flag: not available on this tier raise TierLimitExceeded(feature, limit) if limit == -1: return # Unlimited if isinstance(limit, int) and current_count >= limit: raise TierLimitExceeded(feature, limit) # ── Concrete enforcement helpers ──────────────────────────────────────────── async def check_status_page_limit(db: AsyncSession, org: Organization) -> None: """Check that the org hasn't exceeded its status page limit.""" result = await db.execute( select(func.count(StatusPage.id)).where( StatusPage.organization_id == org.id ) ) count = result.scalar() or 0 await enforce_limit(db, org, "status_pages", count) async def check_service_limit( db: AsyncSession, org: Organization, status_page_id: str | None = None ) -> None: """Check that the org hasn't exceeded its services-per-page limit. If status_page_id is None, counts all services for the org. """ query = select(func.count(Service.id)).where( Service.organization_id == org.id ) if status_page_id: # In future, Service will have a status_page_id column # For now, count all services in the org pass result = await db.execute(query) count = result.scalar() or 0 await enforce_limit(db, org, "services_per_page", count) async def check_monitor_limit( db: AsyncSession, org: Organization, service_id: str ) -> None: """Check that the service hasn't exceeded its monitors-per-service limit.""" result = await db.execute( select(func.count(Monitor.id)).where(Monitor.service_id == service_id) ) count = result.scalar() or 0 await enforce_limit(db, org, "monitors_per_service", count) async def check_subscriber_limit(db: AsyncSession, org: Organization) -> None: """Check that the org hasn't exceeded its subscriber limit.""" result = await db.execute( select(func.count(Subscriber.id)).where( Subscriber.organization_id == org.id ) ) count = result.scalar() or 0 await enforce_limit(db, org, "subscribers", count) async def check_member_limit(db: AsyncSession, org: Organization) -> None: """Check that the org hasn't exceeded its team member limit.""" result = await db.execute( select(func.count(OrganizationMember.id)).where( OrganizationMember.organization_id == org.id ) ) count = result.scalar() or 0 await enforce_limit(db, org, "members", count) def enforce_feature(org: Organization, feature: str) -> None: """Enforce a boolean feature flag. Raises TierLimitExceeded if False. Use this for features that are either allowed or not (e.g., custom_domain, webhooks, api_access) without a numeric limit. """ limit = get_limit(org, feature) if limit is False: raise TierLimitExceeded(feature, False) if limit is None: # Unknown feature — don't block return def get_tier_info(org: Organization) -> dict: """Return a dict of the org's current tier with limits and feature flags. Useful for API responses that show the org what they can and can't do. """ limits = get_org_limits(org) return { "tier": org.tier or "free", "limits": limits, "organization_id": org.id, "organization_slug": org.slug, }