"""Tier enforcement middleware and dependencies for FastAPI.

This module provides FastAPI dependencies that enforce tier limits on API
endpoints. When an org-scoped endpoint creates a resource, these dependencies
check the org's tier limits before allowing the creation.

Usage:
    from app.services.tier_enforcement import (
        require_org_from_header,
        enforce_status_page_limit,
        enforce_service_limit,
        enforce_monitor_limit,
        enforce_subscriber_limit,
        enforce_feature_flag,
        validate_check_interval,
    )

    @router.post("/services")
    async def create_service(
        ...,
        org: Organization = Depends(require_org_from_header),
    ):
        await enforce_service_limit(db, org)
        ...
"""

from fastapi import Depends, Header, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth import get_current_user
from app.dependencies import get_db
from app.models.saas_models import Organization, OrganizationMember, User
from app.services.tier_limits import (
    TierLimitExceeded,
    enforce_feature,
    get_limit,
    get_org_limits,
)

# Minimum interval applied when a tier's "check_interval_min" is 0 minutes.
_ZERO_MINUTE_FLOOR_SECONDS = 30


async def require_org_from_header(
    x_organization_id: str = Header(..., alias="X-Organization-ID"),
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Organization:
    """FastAPI dependency: extract org ID from header and verify membership.

    The calling client must include an X-Organization-ID header. We verify
    that the authenticated user is a member of this org.

    Args:
        x_organization_id: Value of the required ``X-Organization-ID`` header.
        user: The authenticated user, resolved by ``get_current_user``.
        db: Async database session, resolved by ``get_db``.

    Returns:
        The ``Organization`` whose ID matches the header.

    Raises:
        HTTPException: 404 if no organization has that ID; 403 if the
            organization exists but ``user`` has no membership row in it.
    """
    org_result = await db.execute(
        select(Organization).where(Organization.id == x_organization_id)
    )
    org = org_result.scalar_one_or_none()
    if org is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Organization '{x_organization_id}' not found",
        )

    # Existence is checked before membership, so a non-member can tell a
    # missing organization (404) from one they cannot access (403).
    membership_result = await db.execute(
        select(OrganizationMember).where(
            OrganizationMember.organization_id == org.id,
            OrganizationMember.user_id == user.id,
        )
    )
    if membership_result.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You are not a member of this organization",
        )
    return org


async def require_org_from_api_key(
    x_organization_id: str = Header(None, alias="X-Organization-ID"),
    api_key: str = Header(None, alias="X-API-Key"),
    db: AsyncSession = Depends(get_db),
) -> Organization | None:
    """Optional dependency: look up the org named by ``X-Organization-ID``.

    Returns None if no X-Organization-ID is provided or the org is not found.
    This allows gradual adoption: existing endpoints without org context
    still work.

    NOTE(review): ``api_key`` is declared (so the ``X-API-Key`` header is
    accepted) but it is never read or validated in this function; the
    organization is returned for any matching ID. If the key is not verified
    elsewhere in the request pipeline, this dependency performs no
    authentication -- confirm with callers before relying on it.

    Args:
        x_organization_id: Optional ``X-Organization-ID`` header value.
        api_key: Optional ``X-API-Key`` header value (currently unused).
        db: Async database session, resolved by ``get_db``.

    Returns:
        The matching ``Organization``, or ``None``.
    """
    if not x_organization_id:
        return None

    org_result = await db.execute(
        select(Organization).where(Organization.id == x_organization_id)
    )
    return org_result.scalar_one_or_none()


async def _enforce_count_limit(
    db: AsyncSession,
    org: Organization,
    count_stmt,
    limit_key: str,
) -> None:
    """Run a COUNT statement and compare the result to one tier limit.

    Shared implementation behind the ``enforce_*_limit`` functions.

    Args:
        db: Async database session used to execute ``count_stmt``.
        org: Organization whose tier limit is looked up via ``get_limit``.
        count_stmt: A ``select(func.count(...))`` statement yielding the
            current number of resources.
        limit_key: Limit name passed to ``get_limit`` and reported in the
            raised exception.

    Raises:
        TierLimitExceeded: If the limit is ``False``, or if the limit is not
            ``-1`` and the current count has already reached it.
    """
    # The count query intentionally runs before the limit lookup, matching
    # the order of the original per-resource implementations.
    count_result = await db.execute(count_stmt)
    current = count_result.scalar() or 0

    limit = get_limit(org, limit_key)
    if limit is False:
        raise TierLimitExceeded(limit_key, False)
    # -1 skips the numeric comparison entirely (no cap is applied).
    if limit != -1 and current >= limit:
        raise TierLimitExceeded(limit_key, limit)


async def enforce_status_page_limit(db: AsyncSession, org: Organization) -> None:
    """Check that the org hasn't exceeded its status page limit.

    Counts ``StatusPage`` rows belonging to ``org`` and compares the count
    to the ``"status_pages"`` limit.

    Raises:
        TierLimitExceeded: If the limit is reached (described by this module
            as an HTTP 403; the mapping lives outside this file).
    """
    # Kept function-local as in the original; presumably avoids an import
    # cycle -- confirm before hoisting to module level.
    from app.models.saas_models import StatusPage

    count_stmt = select(func.count(StatusPage.id)).where(
        StatusPage.organization_id == org.id
    )
    await _enforce_count_limit(db, org, count_stmt, "status_pages")


async def enforce_service_limit(db: AsyncSession, org: Organization) -> None:
    """Check that the org hasn't exceeded its services limit.

    Counts ``Service`` rows belonging to ``org`` and compares the count to
    the ``"services_per_page"`` limit.

    Raises:
        TierLimitExceeded: If the limit is reached (described by this module
            as an HTTP 403; the mapping lives outside this file).
    """
    # Kept function-local as in the original; presumably avoids an import
    # cycle -- confirm before hoisting to module level.
    from app.models.models import Service

    count_stmt = select(func.count(Service.id)).where(
        Service.organization_id == org.id
    )
    await _enforce_count_limit(db, org, count_stmt, "services_per_page")


async def enforce_monitor_limit(
    db: AsyncSession, org: Organization, service_id: str | None = None
) -> None:
    """Check that the org/service hasn't exceeded its monitor limit.

    When ``service_id`` is truthy, monitors attached to that service are
    counted; otherwise all monitors belonging to ``org`` are counted. Either
    count is compared to the ``"monitors_per_service"`` limit.

    NOTE(review): the per-service count filters on ``service_id`` only and
    does not check that the service belongs to ``org`` -- callers are
    presumably expected to have verified ownership already.

    Raises:
        TierLimitExceeded: If the limit is reached (described by this module
            as an HTTP 403; the mapping lives outside this file).
    """
    # Kept function-local as in the original; presumably avoids an import
    # cycle -- confirm before hoisting to module level.
    from app.models.models import Monitor

    if service_id:
        scope = Monitor.service_id == service_id
    else:
        scope = Monitor.organization_id == org.id

    count_stmt = select(func.count(Monitor.id)).where(scope)
    await _enforce_count_limit(db, org, count_stmt, "monitors_per_service")


async def enforce_subscriber_limit(db: AsyncSession, org: Organization) -> None:
    """Check that the org hasn't exceeded its subscriber limit.

    Counts ``Subscriber`` rows belonging to ``org`` and compares the count
    to the ``"subscribers"`` limit.

    Raises:
        TierLimitExceeded: If the limit is reached (described by this module
            as an HTTP 403; the mapping lives outside this file).
    """
    # Kept function-local as in the original; presumably avoids an import
    # cycle -- confirm before hoisting to module level.
    from app.models.models import Subscriber

    count_stmt = select(func.count(Subscriber.id)).where(
        Subscriber.organization_id == org.id
    )
    await _enforce_count_limit(db, org, count_stmt, "subscribers")


async def enforce_member_limit(db: AsyncSession, org: Organization) -> None:
    """Check that the org hasn't exceeded its team member limit.

    Counts ``OrganizationMember`` rows belonging to ``org`` and compares the
    count to the ``"members"`` limit.

    Raises:
        TierLimitExceeded: If the limit is reached (described by this module
            as an HTTP 403; the mapping lives outside this file).
    """
    count_stmt = select(func.count(OrganizationMember.id)).where(
        OrganizationMember.organization_id == org.id
    )
    await _enforce_count_limit(db, org, count_stmt, "members")


def enforce_feature_flag(org: Organization, feature: str) -> None:
    """Enforce a boolean feature flag based on org's tier.

    Thin wrapper that delegates to ``enforce_feature`` from
    ``app.services.tier_limits``.

    Raises:
        TierLimitExceeded: If the feature is not available (raised by the
            delegate, per the original documentation of this function).
    """
    enforce_feature(org, feature)


def validate_check_interval(org: Organization, interval_seconds: int) -> None:
    """Validate that the check interval meets the org's tier minimum.

    The tier's ``"check_interval_min"`` limit is expressed in minutes
    (defaulting to 5 when absent) and converted to seconds. A value of 0
    minutes is special-cased to a 30-second minimum.

    Per the original documentation:
        Free tier: minimum 5 minutes (300 seconds)
        Pro tier: minimum 1 minute (60 seconds)
        Team tier: minimum 30 seconds

    Args:
        org: Organization whose limits are read via ``get_org_limits``.
        interval_seconds: Requested check interval, in seconds.

    Raises:
        TierLimitExceeded: If ``interval_seconds`` is below the tier minimum.
    """
    min_minutes = get_org_limits(org).get("check_interval_min", 5)

    # 0 minutes does not mean "no minimum"; it maps to a fixed floor.
    min_seconds = (
        _ZERO_MINUTE_FLOOR_SECONDS if min_minutes == 0 else min_minutes * 60
    )

    if interval_seconds < min_seconds:
        raise TierLimitExceeded(
            "check_interval_min",
            f"Minimum check interval for your plan is {min_seconds} seconds. "
            f"Requested: {interval_seconds} seconds.",
        )