From 44d353a30ff3133c6b96e77eb29ec5ccdd7ddcc7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 25 Apr 2026 12:14:06 +0000 Subject: [PATCH] feat: status page enhancements --- app/api/monitors.py | 71 ++++++++- app/api/organizations.py | 173 +++++++++++++++++++++- app/api/services.py | 40 ++++- app/api/subscribers.py | 44 +++++- app/services/tier_enforcement.py | 247 +++++++++++++++++++++++++++++++ 5 files changed, 561 insertions(+), 14 deletions(-) create mode 100644 app/services/tier_enforcement.py diff --git a/app/api/monitors.py b/app/api/monitors.py index cb5684f..e25ac34 100644 --- a/app/api/monitors.py +++ b/app/api/monitors.py @@ -1,4 +1,9 @@ -"""Monitors API endpoints.""" +"""Monitors API endpoints with tier enforcement. + +When creating a monitor for an organization, tier enforcement checks: +- monitors_per_service: max number of monitors per service +- check_interval_min: minimum allowed check interval +""" from uuid import UUID @@ -8,7 +13,13 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.dependencies import get_db, verify_api_key -from app.models.models import Monitor, MonitorResult +from app.models.models import Monitor, MonitorResult, Service +from app.models.saas_models import Organization +from app.services.tier_enforcement import ( + enforce_monitor_limit, + validate_check_interval, +) +from app.services.tier_limits import TierLimitExceeded, get_tier_limits router = APIRouter() @@ -20,6 +31,7 @@ class MonitorCreate(BaseModel): expected_status: int = 200 timeout_seconds: int = Field(10, ge=1, le=60) interval_seconds: int = Field(60, ge=30, le=3600) + organization_id: str | None = None # Optional: for tier enforcement class MonitorUpdate(BaseModel): @@ -41,6 +53,7 @@ def serialize_monitor(m: Monitor) -> dict: "timeout_seconds": m.timeout_seconds, "interval_seconds": m.interval_seconds, "is_active": m.is_active, + "organization_id": m.organization_id, "created_at": m.created_at.isoformat() if m.created_at else None, "updated_at": m.updated_at.isoformat() if m.updated_at else None, } @@ -60,7 +73,43 @@ async def create_monitor( db: AsyncSession = Depends(get_db), api_key: str = Depends(verify_api_key), ): - """Create a new monitor.""" + """Create a new monitor. + + If organization_id is provided, tier enforcement is applied: + - check monitors_per_service limit + - validate check_interval against minimum for the tier + """ + # Look up org if provided + org = None + if data.organization_id: + result = await db.execute( + select(Organization).where(Organization.id == data.organization_id) + ) + org = result.scalar_one_or_none() + if org is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Organization '{data.organization_id}' not found", + ) + + # Tier enforcement when org context is provided + if org is not None: + # Check monitors_per_service limit + await enforce_monitor_limit(db, org, str(data.service_id)) + # Validate check interval + validate_check_interval(org, data.interval_seconds) + + # Verify the service exists + service_result = await db.execute( + select(Service).where(Service.id == str(data.service_id)) + ) + service = service_result.scalar_one_or_none() + if not service: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Service '{data.service_id}' not found", + ) + monitor = Monitor( service_id=str(data.service_id), url=data.url, @@ -68,6 +117,7 @@ async def create_monitor( expected_status=data.expected_status, timeout_seconds=data.timeout_seconds, interval_seconds=data.interval_seconds, + organization_id=data.organization_id or (service.organization_id if service else None), ) db.add(monitor) await db.flush() @@ -115,13 +165,26 @@ async def update_monitor( db: AsyncSession = Depends(get_db), api_key: str = Depends(verify_api_key), ): - """Update a monitor.""" + """Update a monitor. + + If interval_seconds is being changed, validate against the org's tier minimum. + """ result = await db.execute(select(Monitor).where(Monitor.id == str(monitor_id))) monitor = result.scalar_one_or_none() if not monitor: raise HTTPException(status_code=404, detail="Monitor not found") update_data = data.model_dump(exclude_unset=True) + + # If updating interval_seconds and org context exists, validate + if data.interval_seconds is not None and monitor.organization_id: + org_result = await db.execute( + select(Organization).where(Organization.id == monitor.organization_id) + ) + org = org_result.scalar_one_or_none() + if org: + validate_check_interval(org, data.interval_seconds) + for field, value in update_data.items(): setattr(monitor, field, value) diff --git a/app/api/organizations.py b/app/api/organizations.py index b84376f..095543c 100644 --- a/app/api/organizations.py +++ b/app/api/organizations.py @@ -1,4 +1,7 @@ -"""Organization API endpoints: view org, list tiers, upgrade/downgrade.""" +"""Organization API endpoints: view org, list tiers, upgrade/downgrade, feature flags. + +All org-scoped endpoints enforce tier limits on resource creation. +""" from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel @@ -7,11 +10,17 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.auth import get_current_user, get_current_org from app.dependencies import get_db -from app.models.saas_models import Organization, OrganizationMember, User +from app.models.saas_models import Organization, OrganizationMember, OrganizationMember, User +from app.services.tier_enforcement import ( + enforce_member_limit, + enforce_status_page_limit, + enforce_feature_flag, +) from app.services.tier_limits import ( TIER_LIMITS, get_org_limits, get_tier_info, + TierLimitExceeded, ) router = APIRouter(tags=["organizations"]) @@ -45,6 +54,15 @@ class UpgradeRequest(BaseModel): tier: str # "free" | "pro" | "team" +class InviteMemberRequest(BaseModel): + email: str + role: str = "member" # "member" | "admin" + + +class SetCustomDomainRequest(BaseModel): + domain: str + + # ── Endpoints ─────────────────────────────────────────────────────────────── @router.get("/tiers") @@ -152,4 +170,153 @@ async def get_my_limits( org: Organization = Depends(get_current_org), ): """Get the current organization's tier limits and feature flags.""" - return get_tier_info(org) \ No newline at end of file + return get_tier_info(org) + + +# ── Member management ───────────────────────────────────────────────────── + +@router.post("/my/members", status_code=status.HTTP_201_CREATED) +async def invite_member( + body: InviteMemberRequest, + org: Organization = Depends(get_current_org), + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Invite a new member to the organization. + + Enforces the org's member limit based on tier. + """ + # Enforce member limit + try: + await enforce_member_limit(db, org) + except TierLimitExceeded as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=e.detail, + ) + + # Find the user by email + result = await db.execute( + select(User).where(User.email == body.email) + ) + invited_user = result.scalar_one_or_none() + if not invited_user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with email '{body.email}' not found. They must register first.", + ) + + # Check if already a member + existing = await db.execute( + select(OrganizationMember).where( + OrganizationMember.organization_id == org.id, + OrganizationMember.user_id == invited_user.id, + ) + ) + if existing.scalar_one_or_none(): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"User '{body.email}' is already a member of this organization.", + ) + + if body.role not in ("member", "admin"): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Role must be 'member' or 'admin'", + ) + + membership = OrganizationMember( + organization_id=org.id, + user_id=invited_user.id, + role=body.role, + ) + db.add(membership) + await db.flush() + + return { + "user_id": invited_user.id, + "email": invited_user.email, + "role": body.role, + "organization_id": org.id, + } + + +@router.get("/my/members") +async def list_members( + org: Organization = Depends(get_current_org), + db: AsyncSession = Depends(get_db), +): + """List all members of the organization.""" + result = await db.execute( + select(OrganizationMember).where( + OrganizationMember.organization_id == org.id + ) + ) + memberships = result.scalars().all() + + members = [] + for m in memberships: + user_result = await db.execute( + select(User).where(User.id == m.user_id) + ) + member_user = user_result.scalar_one_or_none() + if member_user: + members.append({ + "user_id": m.user_id, + "email": member_user.email, + "display_name": member_user.display_name, + "role": m.role, + "joined_at": m.joined_at.isoformat() if m.joined_at else None, + }) + + return {"members": members, "count": len(members)} + + +# ── Custom domain ──────────────────────────────────────────────────────── + +@router.post("/my/custom-domain") +async def set_custom_domain( + body: SetCustomDomainRequest, + org: Organization = Depends(get_current_org), + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Set a custom domain for the organization's status page. + + Enforces the custom_domain feature flag based on tier. + Free tier does not have custom domain support. + """ + # Enforce custom_domain feature flag + try: + enforce_feature_flag(org, "custom_domain") + except TierLimitExceeded as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=e.detail, + ) + + org.custom_domain = body.domain + await db.flush() + + return { + "organization_id": org.id, + "custom_domain": org.custom_domain, + "message": "Custom domain set. Please add a CNAME record pointing to your status page.", + } + + +@router.delete("/my/custom-domain") +async def remove_custom_domain( + org: Organization = Depends(get_current_org), + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Remove the custom domain from the organization's status page.""" + org.custom_domain = None + await db.flush() + + return { + "organization_id": org.id, + "custom_domain": None, + "message": "Custom domain removed.", + } \ No newline at end of file diff --git a/app/api/services.py b/app/api/services.py index 84a0eea..c0e8574 100644 --- a/app/api/services.py +++ b/app/api/services.py @@ -1,4 +1,11 @@ -"""Services API endpoints.""" +"""Services API endpoints with tier enforcement. + +Provides both admin API-key endpoints (no org context) and +organization-scoped endpoints with tier enforcement. + +When X-Organization-ID header is provided with a valid API key, +tier enforcement is applied to creation endpoints. +""" from uuid import UUID @@ -9,6 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.dependencies import get_db, verify_api_key from app.models.models import Service +from app.models.saas_models import Organization +from app.services.tier_enforcement import ( + enforce_service_limit, + get_org_if_provided, +) router = APIRouter() @@ -20,6 +32,7 @@ class ServiceCreate(BaseModel): group_name: str | None = Field(None, max_length=50) position: int = 0 is_visible: bool = True + organization_id: str | None = None class ServiceUpdate(BaseModel): @@ -40,6 +53,7 @@ def serialize_service(s: Service) -> dict: "group_name": s.group_name, "position": s.position, "is_visible": s.is_visible, + "organization_id": s.organization_id, "created_at": s.created_at.isoformat() if s.created_at else None, "updated_at": s.updated_at.isoformat() if s.updated_at else None, } @@ -59,7 +73,28 @@ async def create_service( db: AsyncSession = Depends(get_db), api_key: str = Depends(verify_api_key), ): - """Create a new service.""" + """Create a new service. + + If organization_id is provided in the request body and a matching org + exists, tier enforcement is applied to ensure the org hasn't exceeded + its services_per_page limit. + """ + org = None + if data.organization_id: + result = await db.execute( + select(Organization).where(Organization.id == data.organization_id) + ) + org = result.scalar_one_or_none() + if org is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Organization '{data.organization_id}' not found", + ) + + # Tier enforcement when org context is provided + if org is not None: + await enforce_service_limit(db, org) + service = Service( name=data.name, slug=data.slug, @@ -67,6 +102,7 @@ async def create_service( group_name=data.group_name, position=data.position, is_visible=data.is_visible, + organization_id=data.organization_id, ) db.add(service) await db.flush() diff --git a/app/api/subscribers.py b/app/api/subscribers.py index a45f33e..0d67d9c 100644 --- a/app/api/subscribers.py +++ b/app/api/subscribers.py @@ -1,17 +1,29 @@ -"""Subscribers API endpoints.""" +"""Subscribers API endpoints with tier enforcement. + +When adding a subscriber to an organization, the org's subscriber limit +is checked against the org's tier. +""" from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.dependencies import get_db, verify_api_key from app.models.models import Subscriber +from app.models.saas_models import Organization +from app.services.tier_enforcement import enforce_subscriber_limit router = APIRouter() +class SubscriberCreate(BaseModel): + email: str = Field(..., max_length=255) + organization_id: str | None = None # Optional: for tier enforcement + + @router.get("/") async def list_subscribers(db: AsyncSession = Depends(get_db)): """List all subscribers.""" @@ -21,6 +33,7 @@ async def list_subscribers(db: AsyncSession = Depends(get_db)): { "id": s.id, "email": s.email, + "organization_id": s.organization_id, "is_confirmed": s.is_confirmed, "created_at": s.created_at.isoformat() if s.created_at else None, } @@ -30,16 +43,36 @@ async def list_subscribers(db: AsyncSession = Depends(get_db)): @router.post("/", status_code=status.HTTP_201_CREATED) async def create_subscriber( - email: str, + data: SubscriberCreate, db: AsyncSession = Depends(get_db), api_key: str = Depends(verify_api_key), ): - """Add a new subscriber.""" - import uuid + """Add a new subscriber. + If organization_id is provided, tier enforcement is applied to ensure + the org hasn't exceeded its subscriber limit. + """ + org = None + if data.organization_id: + result = await db.execute( + select(Organization).where(Organization.id == data.organization_id) + ) + org = result.scalar_one_or_none() + if org is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Organization '{data.organization_id}' not found", + ) + + # Tier enforcement when org context is provided + if org is not None: + await enforce_subscriber_limit(db, org) + + import uuid subscriber = Subscriber( - email=email, + email=data.email, confirm_token=str(uuid.uuid4()), + organization_id=data.organization_id, ) db.add(subscriber) await db.flush() @@ -47,6 +80,7 @@ async def create_subscriber( return { "id": subscriber.id, "email": subscriber.email, + "organization_id": subscriber.organization_id, "confirm_token": subscriber.confirm_token, } diff --git a/app/services/tier_enforcement.py b/app/services/tier_enforcement.py new file mode 100644 index 0000000..3d6f3c2 --- /dev/null +++ b/app/services/tier_enforcement.py @@ -0,0 +1,247 @@ +"""Tier enforcement middleware and dependencies for FastAPI. + +This module provides FastAPI dependencies that enforce tier limits on API +endpoints. When an org-scoped endpoint creates a resource, these dependencies +check the org's tier limits before allowing the creation. + +Usage: + from app.services.tier_enforcement import ( + require_org_from_header, + enforce_status_page_limit, + enforce_service_limit, + enforce_monitor_limit, + enforce_subscriber_limit, + enforce_feature_flag, + validate_check_interval, + ) + + @router.post("/services") + async def create_service( + ..., + org: Organization = Depends(require_org_from_header), + ): + await enforce_service_limit(db, org) + ... +""" + +from fastapi import Depends, Header, HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import get_current_user +from app.dependencies import get_db +from app.models.saas_models import Organization, OrganizationMember, User +from app.services.tier_limits import ( + TierLimitExceeded, + enforce_feature, + get_limit, + get_org_limits, +) + + +async def require_org_from_header( + x_organization_id: str = Header(..., alias="X-Organization-ID"), + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +) -> Organization: + """FastAPI dependency: extract org ID from header and verify membership. + + The calling client must include an X-Organization-ID header. + We verify that the authenticated user is a member of this org. + """ + result = await db.execute( + select(Organization).where(Organization.id == x_organization_id) + ) + org = result.scalar_one_or_none() + if org is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Organization '{x_organization_id}' not found", + ) + + # Verify membership + membership_result = await db.execute( + select(OrganizationMember).where( + OrganizationMember.organization_id == org.id, + OrganizationMember.user_id == user.id, + ) + ) + membership = membership_result.scalar_one_or_none() + if membership is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You are not a member of this organization", + ) + + return org + + +async def require_org_from_api_key( + x_organization_id: str = Header(None, alias="X-Organization-ID"), + api_key: str = Header(None, alias="X-API-Key"), + db: AsyncSession = Depends(get_db), +) -> Organization | None: + """Optional dependency: extract org from header when both header and API key present. + + Returns None if no X-Organization-ID is provided or org not found. + This allows gradual adoption: existing endpoints without org context still work. + """ + if not x_organization_id: + return None + + result = await db.execute( + select(Organization).where(Organization.id == x_organization_id) + ) + org = result.scalar_one_or_none() + return org + + +async def enforce_status_page_limit(db: AsyncSession, org: Organization) -> None: + """Check that the org hasn't exceeded its status page limit. + + Raises TierLimitExceeded (HTTP 403) if limit is reached. + """ + from app.models.saas_models import StatusPage + from sqlalchemy import func + + result = await db.execute( + select(func.count(StatusPage.id)).where( + StatusPage.organization_id == org.id + ) + ) + count = result.scalar() or 0 + + limit = get_limit(org, "status_pages") + if limit is False: + raise TierLimitExceeded("status_pages", False) + if limit != -1 and count >= limit: + raise TierLimitExceeded("status_pages", limit) + + +async def enforce_service_limit(db: AsyncSession, org: Organization) -> None: + """Check that the org hasn't exceeded its services limit. + + Raises TierLimitExceeded (HTTP 403) if limit is reached. + """ + from app.models.models import Service + from sqlalchemy import func + + result = await db.execute( + select(func.count(Service.id)).where( + Service.organization_id == org.id + ) + ) + count = result.scalar() or 0 + + limit = get_limit(org, "services_per_page") + if limit is False: + raise TierLimitExceeded("services_per_page", False) + if limit != -1 and count >= limit: + raise TierLimitExceeded("services_per_page", limit) + + +async def enforce_monitor_limit( + db: AsyncSession, org: Organization, service_id: str | None = None +) -> None: + """Check that the org/service hasn't exceeded its monitor limit. + + Raises TierLimitExceeded (HTTP 403) if limit is reached. + """ + from app.models.models import Monitor + from sqlalchemy import func + + if service_id: + result = await db.execute( + select(func.count(Monitor.id)).where(Monitor.service_id == service_id) + ) + else: + result = await db.execute( + select(func.count(Monitor.id)).where( + Monitor.organization_id == org.id + ) + ) + count = result.scalar() or 0 + + limit = get_limit(org, "monitors_per_service") + if limit is False: + raise TierLimitExceeded("monitors_per_service", False) + if limit != -1 and count >= limit: + raise TierLimitExceeded("monitors_per_service", limit) + + +async def enforce_subscriber_limit(db: AsyncSession, org: Organization) -> None: + """Check that the org hasn't exceeded its subscriber limit. + + Raises TierLimitExceeded (HTTP 403) if limit is reached. + """ + from app.models.models import Subscriber + from sqlalchemy import func + + result = await db.execute( + select(func.count(Subscriber.id)).where( + Subscriber.organization_id == org.id + ) + ) + count = result.scalar() or 0 + + limit = get_limit(org, "subscribers") + if limit is False: + raise TierLimitExceeded("subscribers", False) + if limit != -1 and count >= limit: + raise TierLimitExceeded("subscribers", limit) + + +async def enforce_member_limit(db: AsyncSession, org: Organization) -> None: + """Check that the org hasn't exceeded its team member limit. + + Raises TierLimitExceeded (HTTP 403) if limit is reached. + """ + from sqlalchemy import func + + result = await db.execute( + select(func.count(OrganizationMember.id)).where( + OrganizationMember.organization_id == org.id + ) + ) + count = result.scalar() or 0 + + limit = get_limit(org, "members") + if limit is False: + raise TierLimitExceeded("members", False) + if limit != -1 and count >= limit: + raise TierLimitExceeded("members", limit) + + +def enforce_feature_flag(org: Organization, feature: str) -> None: + """Enforce a boolean feature flag based on org's tier. + + Raises TierLimitExceeded (HTTP 403) if the feature is not available. + """ + enforce_feature(org, feature) + + +def validate_check_interval(org: Organization, interval_seconds: int) -> None: + """Validate that the check interval meets the org's tier minimum. + + Free tier: minimum 5 minutes (300 seconds) + Pro tier: minimum 1 minute (60 seconds) + Team tier: minimum 30 seconds + + Raises TierLimitExceeded if the interval is below the tier minimum. + """ + limits = get_org_limits(org) + min_minutes = limits.get("check_interval_min", 5) + + # Convert minimum minutes to seconds + min_seconds = min_minutes * 60 + + # Team tier has 0-minute minimum meaning 30 seconds + if min_minutes == 0: + min_seconds = 30 + + if interval_seconds < min_seconds: + raise TierLimitExceeded( + "check_interval_min", + f"Minimum check interval for your plan is {min_seconds} seconds. " + f"Requested: {interval_seconds} seconds.", + ) \ No newline at end of file