247 lines
No EOL
7.7 KiB
Python
247 lines
No EOL
7.7 KiB
Python
"""Tier enforcement middleware and dependencies for FastAPI.
|
|
|
|
This module provides FastAPI dependencies that enforce tier limits on API
|
|
endpoints. When an org-scoped endpoint creates a resource, these dependencies
|
|
check the org's tier limits before allowing the creation.
|
|
|
|
Usage:
|
|
from app.services.tier_enforcement import (
|
|
require_org_from_header,
|
|
enforce_status_page_limit,
|
|
enforce_service_limit,
|
|
enforce_monitor_limit,
|
|
enforce_subscriber_limit,
|
|
enforce_feature_flag,
|
|
validate_check_interval,
|
|
)
|
|
|
|
@router.post("/services")
|
|
async def create_service(
|
|
...,
|
|
org: Organization = Depends(require_org_from_header),
|
|
):
|
|
await enforce_service_limit(db, org)
|
|
...
|
|
"""
|
|
|
|
from fastapi import Depends, Header, HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.auth import get_current_user
|
|
from app.dependencies import get_db
|
|
from app.models.saas_models import Organization, OrganizationMember, User
|
|
from app.services.tier_limits import (
|
|
TierLimitExceeded,
|
|
enforce_feature,
|
|
get_limit,
|
|
get_org_limits,
|
|
)
|
|
|
|
|
|
async def require_org_from_header(
    x_organization_id: str = Header(..., alias="X-Organization-ID"),
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Organization:
    """Resolve the organization named in ``X-Organization-ID`` for the caller.

    Intended for use as a FastAPI dependency. The header is mandatory
    (declared with ``Header(...)``). The organization is looked up by ID, and
    then an ``OrganizationMember`` row linking it to the authenticated user
    must exist.

    Args:
        x_organization_id: Value of the ``X-Organization-ID`` request header.
        user: The authenticated user, supplied by ``get_current_user``.
        db: Async database session, supplied by ``get_db``.

    Returns:
        The matching ``Organization``.

    Raises:
        HTTPException: 404 when no organization has the given ID; 403 when
            the user has no membership row for that organization.
    """
    org_query = select(Organization).where(Organization.id == x_organization_id)
    org = (await db.execute(org_query)).scalar_one_or_none()
    if org is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Organization '{x_organization_id}' not found",
        )

    # The org exists; now make sure the caller actually belongs to it.
    membership_query = select(OrganizationMember).where(
        OrganizationMember.organization_id == org.id,
        OrganizationMember.user_id == user.id,
    )
    membership = (await db.execute(membership_query)).scalar_one_or_none()
    if membership is None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You are not a member of this organization",
        )

    return org
|
|
|
|
|
|
async def require_org_from_api_key(
    x_organization_id: str | None = Header(None, alias="X-Organization-ID"),
    api_key: str | None = Header(None, alias="X-API-Key"),
    db: AsyncSession = Depends(get_db),
) -> Organization | None:
    """Optional dependency: resolve the org only when both headers are present.

    Both ``X-Organization-ID`` and ``X-API-Key`` are optional headers. The
    organization is looked up only when *both* are supplied; otherwise the
    dependency yields ``None`` so that endpoints without org context keep
    working (gradual adoption).

    Args:
        x_organization_id: Value of the ``X-Organization-ID`` header, if any.
        api_key: Value of the ``X-API-Key`` header, if any.
        db: Async database session, supplied by ``get_db``.

    Returns:
        The matching ``Organization``, or ``None`` when either header is
        missing/empty or no organization has the given ID.
    """
    # Previously ``api_key`` was accepted but never read, so an org could be
    # resolved from an unauthenticated X-Organization-ID header alone.
    # Require the key to be present, as the documented contract states.
    # NOTE(review): the key is only checked for presence here. Validating
    # that it actually belongs to this organization must happen elsewhere
    # (or be added here once the API-key model is available) - confirm.
    if not x_organization_id or not api_key:
        return None

    result = await db.execute(
        select(Organization).where(Organization.id == x_organization_id)
    )
    return result.scalar_one_or_none()
|
|
|
|
|
|
async def enforce_status_page_limit(db: AsyncSession, org: Organization) -> None:
    """Reject the request if the org has no room for another status page.

    Counts the ``StatusPage`` rows owned by ``org`` and compares the total
    with the org's ``"status_pages"`` tier limit.

    Args:
        db: Async database session used for the count query.
        org: Organization whose limit is being enforced.

    Raises:
        TierLimitExceeded: When the limit is ``False`` (not available), or
            when the limit is not ``-1`` and the current count has reached
            it. The module documents this as an HTTP 403.
    """
    # Imported inside the function, as in the original implementation.
    from app.models.saas_models import StatusPage
    from sqlalchemy import func

    resource = "status_pages"

    count_query = select(func.count(StatusPage.id)).where(
        StatusPage.organization_id == org.id
    )
    existing = (await db.execute(count_query)).scalar() or 0

    cap = get_limit(org, resource)
    if cap is False:
        raise TierLimitExceeded(resource, False)
    if cap == -1:
        # -1 is treated as "no cap": nothing to enforce.
        return
    if existing >= cap:
        raise TierLimitExceeded(resource, cap)
|
|
|
|
|
|
async def enforce_service_limit(db: AsyncSession, org: Organization) -> None:
    """Reject the request if the org has no room for another service.

    Counts the ``Service`` rows owned by ``org`` and compares the total with
    the org's ``"services_per_page"`` tier limit.

    NOTE(review): the count covers every service in the organization, while
    the limit key is named "per page". Confirm whether an org-wide count is
    the intended comparison for orgs with several status pages.

    Args:
        db: Async database session used for the count query.
        org: Organization whose limit is being enforced.

    Raises:
        TierLimitExceeded: When the limit is ``False`` (not available), or
            when the limit is not ``-1`` and the current count has reached
            it. The module documents this as an HTTP 403.
    """
    # Imported inside the function, as in the original implementation.
    from app.models.models import Service
    from sqlalchemy import func

    resource = "services_per_page"

    count_query = select(func.count(Service.id)).where(
        Service.organization_id == org.id
    )
    existing = (await db.execute(count_query)).scalar() or 0

    cap = get_limit(org, resource)
    if cap is False:
        raise TierLimitExceeded(resource, False)
    if cap == -1:
        # -1 is treated as "no cap": nothing to enforce.
        return
    if existing >= cap:
        raise TierLimitExceeded(resource, cap)
|
|
|
|
|
|
async def enforce_monitor_limit(
    db: AsyncSession, org: Organization, service_id: str | None = None
) -> None:
    """Reject the request if the monitor limit has been reached.

    When ``service_id`` is truthy, monitors attached to that service are
    counted; otherwise all monitors owned by ``org`` are counted. Either way
    the total is compared with the org's ``"monitors_per_service"`` limit.

    NOTE(review): the org-wide fallback is compared against a per-service
    limit, and the per-service branch does not filter by organization -
    confirm both are intended.

    Args:
        db: Async database session used for the count query.
        org: Organization whose limit is being enforced.
        service_id: Optional service to scope the count to.

    Raises:
        TierLimitExceeded: When the limit is ``False`` (not available), or
            when the limit is not ``-1`` and the current count has reached
            it. The module documents this as an HTTP 403.
    """
    # Imported inside the function, as in the original implementation.
    from app.models.models import Monitor
    from sqlalchemy import func

    resource = "monitors_per_service"

    # Pick the filter first so a single count query serves both scopes.
    if service_id:
        scope = Monitor.service_id == service_id
    else:
        scope = Monitor.organization_id == org.id

    outcome = await db.execute(select(func.count(Monitor.id)).where(scope))
    existing = outcome.scalar() or 0

    cap = get_limit(org, resource)
    if cap is False:
        raise TierLimitExceeded(resource, False)
    if cap == -1:
        # -1 is treated as "no cap": nothing to enforce.
        return
    if existing >= cap:
        raise TierLimitExceeded(resource, cap)
|
|
|
|
|
|
async def enforce_subscriber_limit(db: AsyncSession, org: Organization) -> None:
    """Reject the request if the org has no room for another subscriber.

    Counts the ``Subscriber`` rows owned by ``org`` and compares the total
    with the org's ``"subscribers"`` tier limit.

    Args:
        db: Async database session used for the count query.
        org: Organization whose limit is being enforced.

    Raises:
        TierLimitExceeded: When the limit is ``False`` (not available), or
            when the limit is not ``-1`` and the current count has reached
            it. The module documents this as an HTTP 403.
    """
    # Imported inside the function, as in the original implementation.
    from app.models.models import Subscriber
    from sqlalchemy import func

    resource = "subscribers"

    count_query = select(func.count(Subscriber.id)).where(
        Subscriber.organization_id == org.id
    )
    existing = (await db.execute(count_query)).scalar() or 0

    cap = get_limit(org, resource)
    if cap is False:
        raise TierLimitExceeded(resource, False)
    if cap == -1:
        # -1 is treated as "no cap": nothing to enforce.
        return
    if existing >= cap:
        raise TierLimitExceeded(resource, cap)
|
|
|
|
|
|
async def enforce_member_limit(db: AsyncSession, org: Organization) -> None:
    """Reject the request if the org has no room for another team member.

    Counts the ``OrganizationMember`` rows for ``org`` and compares the
    total with the org's ``"members"`` tier limit.

    Args:
        db: Async database session used for the count query.
        org: Organization whose limit is being enforced.

    Raises:
        TierLimitExceeded: When the limit is ``False`` (not available), or
            when the limit is not ``-1`` and the current count has reached
            it. The module documents this as an HTTP 403.
    """
    from sqlalchemy import func

    resource = "members"

    count_query = select(func.count(OrganizationMember.id)).where(
        OrganizationMember.organization_id == org.id
    )
    existing = (await db.execute(count_query)).scalar() or 0

    cap = get_limit(org, resource)
    if cap is False:
        raise TierLimitExceeded(resource, False)
    if cap == -1:
        # -1 is treated as "no cap": nothing to enforce.
        return
    if existing >= cap:
        raise TierLimitExceeded(resource, cap)
|
|
|
|
|
|
def enforce_feature_flag(org: Organization, feature: str) -> None:
    """Require that ``feature`` is enabled for the org's tier.

    Thin wrapper that delegates to ``enforce_feature`` from
    ``app.services.tier_limits``; any exception it raises propagates
    unchanged.

    Args:
        org: Organization whose tier is consulted.
        feature: Name of the boolean feature flag to check.

    Raises:
        TierLimitExceeded: Per the module's documentation, raised (as an
            HTTP 403) when the feature is not available on the org's tier.
    """
    enforce_feature(org, feature)
    return None
|
|
|
|
|
|
def validate_check_interval(org: Organization, interval_seconds: int) -> None:
    """Ensure a requested check interval is not below the org's tier minimum.

    The minimum is read from the org's limits under ``"check_interval_min"``
    (in minutes, defaulting to 5 when the key is absent) and converted to
    seconds. A configured minimum of ``0`` is a special case meaning a
    30-second floor.

    As documented for this module's tiers (the values themselves live in
    ``tier_limits``): Free is 5 minutes, Pro is 1 minute, Team is 30 seconds.

    Args:
        org: Organization whose tier limits are consulted.
        interval_seconds: Requested interval between checks, in seconds.

    Raises:
        TierLimitExceeded: When ``interval_seconds`` is below the floor.
    """
    tier_minimum_minutes = get_org_limits(org).get("check_interval_min", 5)

    # A 0-minute minimum is the sentinel for the 30-second floor; every other
    # value is a plain minutes-to-seconds conversion.
    floor_seconds = 30 if tier_minimum_minutes == 0 else tier_minimum_minutes * 60

    if interval_seconds < floor_seconds:
        raise TierLimitExceeded(
            "check_interval_min",
            f"Minimum check interval for your plan is {floor_seconds} seconds. "
            f"Requested: {interval_seconds} seconds.",
        )