diff --git a/SAAS_ENHANCEMENT_PLAN.md b/SAAS_ENHANCEMENT_PLAN.md new file mode 100644 index 0000000..799f4b8 --- /dev/null +++ b/SAAS_ENHANCEMENT_PLAN.md @@ -0,0 +1,1262 @@ +# Indie Status Page — SaaS Enhancement Plan + +> Transform the single-tenant status page into a multi-tenant SaaS product with +> user accounts, org-level status pages, feature tiers, and Stripe payment links. + +--- + +## 1. Current State Analysis + +### 1.1 Architecture Overview + +| Aspect | Detail | +|---|---| +| **Framework** | FastAPI 0.110+ with async SQLAlchemy 2.0 + aiosqlite | +| **Database** | SQLite (`data/statuspage.db`), auto-created via `init_db()` — no Alembic migrations yet | +| **Deployment** | Port 8765, nginx reverse proxy at `korpo.pro/status/`, systemd service `indie-status-page.service` | +| **Templating** | Jinja2 (5 HTML templates) | +| **Scheduling** | APScheduler (60s interval uptime checks) | + +### 1.2 Database Models (8 total) + +File: `app/models/models.py` + +| Model | Table | Key Fields | Relationships | +|---|---|---|---| +| `Service` | `services` | id, name, slug (unique), description, group_name, position, is_visible | → incidents, → monitors | +| `Incident` | `incidents` | id, service_id (FK), title, status, severity, started_at, resolved_at | → service, → updates, → notifications | +| `IncidentUpdate` | `incident_updates` | id, incident_id (FK), status, body, created_at | → incident | +| `Monitor` | `monitors` | id, service_id (FK), url, method, expected_status, timeout_seconds, interval_seconds, is_active | → service, → results | +| `MonitorResult` | `monitor_results` | id, monitor_id (FK), status, response_time_ms, status_code, error_message, checked_at | → monitor | +| `Subscriber` | `subscribers` | id, email (unique), is_confirmed, confirm_token | → notifications | +| `NotificationLog` | `notification_logs` | id, incident_id (FK), subscriber_id (FK), channel, status | → incident, → subscriber | +| `SiteSetting` | `site_settings` | id, key (unique), value | — | + +### 1.3 API Routes (25 endpoints) + +File: `app/api/router.py` — assembles 5 sub-routers under `/api/v1`: + +| Router | File | Prefix | Endpoints | +|---|---|---|---| +| **Services** | `app/api/services.py` | `/services` | GET /, POST /, GET /{id}, PATCH /{id}, DELETE /{id} | +| **Incidents** | `app/api/incidents.py` | `/incidents` | GET /, POST /, GET /{id}, PATCH /{id}, DELETE /{id}, POST /{id}/updates | +| **Monitors** | `app/api/monitors.py` | `/monitors` | GET /, POST /, GET /{id}, PATCH /{id}, DELETE /{id}, POST /{id}/check | +| **Subscribers** | `app/api/subscribers.py` | `/subscribers` | GET /, POST /, DELETE /{id}, POST /{id}/confirm | +| **Settings** | `app/api/settings.py` | `/settings` | GET /, PATCH / | + +### 1.4 Authentication + +File: `app/dependencies.py` + +- **Single shared API key**: `X-API-Key` header validated against `settings.admin_api_key` +- No user accounts, no sessions, no JWT tokens +- All write endpoints require API key; read endpoints are public + +### 1.5 Frontend Pages + +File: `app/main.py` + `app/templates/` + +| Route | Template | Purpose | +|---|---|---| +| `GET /` | `status.html` | Public status page (all visible services + recent incidents) | +| `GET /incident/{id}` | `incident.html` | Incident detail with update timeline | +| `GET /subscribe` | `subscribe.html` | Email subscription form | +| `GET /confirm` | `confirm.html` | Subscription confirmation | +| `GET /health` | JSON response | Health check for container orchestration | + +### 1.6 Background Services + +| Service | File | Description | +|---|---|---| +| **Uptime Checker** | `app/services/uptime.py` | HTTP health checks via httpx, stores `MonitorResult` | +| **Scheduler** | `app/services/scheduler.py` | APScheduler runs `_run_monitor_checks()` every 60s | +| **Notifier** | `app/services/notifier.py` | Email (SMTP) + webhook dispatch for incident updates | + +### 1.7 Tests + +6 passing tests across 3 files: +- `tests/test_health.py` — 1 test (health check) +- `tests/test_api_services.py` — 7 tests (full CRUD + 404 + auth) +- `tests/test_api_incidents.py` — 5 tests (CRUD + updates + delete) + +### 1.8 Key Gaps for SaaS + +| Gap | Impact | +|---|---| +| **No user accounts** | Cannot identify who owns what data | +| **No multi-tenancy** | All orgs share the same data; no isolation | +| **No feature limits** | Any user can create unlimited services/pages | +| **No payment integration** | No billing, no subscription management | +| **SQLite-only** | Won't handle concurrent multi-tenant writes well | +| **No Alembic migrations** | Schema changes require manual DB recreation | +| **Single-site config** | `SiteSetting` and `config.py` are global, not per-org | + +--- + +## 2. Multi-Tenancy Design + +### 2.1 Strategy: Shared Database, Tenant-ID Columns + +Use a **shared database with tenant-ID discriminator** approach. Every tenant-scoped +table gets an `organization_id` foreign key. All queries filter by the current tenant. +This is the simplest path that works with SQLite → PostgreSQL migration. + +### 2.2 New Models + +#### `User` — Individual who can log in + +```python +# app/models/models.py — NEW + +class User(Base): + __tablename__ = "users" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) + password_hash: Mapped[str] = mapped_column(String(255), nullable=False) + display_name: Mapped[str | None] = mapped_column(String(100), nullable=True) + is_email_verified: Mapped[bool] = mapped_column(Boolean, default=False) + email_verify_token: Mapped[str | None] = mapped_column(String(100), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + memberships: Mapped[list["OrganizationMember"]] = relationship(back_populates="user") +``` + +#### `Organization` — The tenant; owns status pages + +```python +class Organization(Base): + __tablename__ = "organizations" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True) + name: Mapped[str] = mapped_column(String(100), nullable=False) + tier: Mapped[str] = mapped_column(String(20), nullable=False, default="free") + # "free" | "pro" | "team" + stripe_customer_id: Mapped[str | None] = mapped_column(String(100), nullable=True) + custom_domain: Mapped[str | None] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + members: Mapped[list["OrganizationMember"]] = relationship(back_populates="organization") + status_pages: Mapped[list["StatusPage"]] = relationship(back_populates="organization") +``` + +#### `OrganizationMember` — Joins users to orgs with roles + +```python +class OrganizationMember(Base): + __tablename__ = "organization_members" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + organization_id: Mapped[str] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=False, index=True + ) + user_id: Mapped[str] = mapped_column( + String(36), ForeignKey("users.id"), nullable=False, index=True + ) + role: Mapped[str] = mapped_column(String(20), nullable=False, default="member") + # "owner" | "admin" | "member" + joined_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + organization: Mapped["Organization"] = relationship(back_populates="members") + user: Mapped["User"] = relationship(back_populates="memberships") + + __table_args__ = ( + # Prevent duplicate memberships + {"sqlite_autoincrement": True}, + ) +``` + +#### `StatusPage` — Replaces the single global status view + +```python +class StatusPage(Base): + __tablename__ = "status_pages" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + organization_id: Mapped[str] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=False, index=True + ) + slug: Mapped[str] = mapped_column(String(50), nullable=False, index=True) + title: Mapped[str] = mapped_column(String(100), nullable=False) + subdomain: Mapped[str | None] = mapped_column(String(100), nullable=True) + custom_domain: Mapped[str | None] = mapped_column(String(255), nullable=True) + is_public: Mapped[bool] = mapped_column(Boolean, default=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + organization: Mapped["Organization"] = relationship(back_populates="status_pages") + + __table_args__ = ( + # Unique slug within organization + UniqueConstraint("organization_id", "slug", name="uq_status_page_org_slug"), + ) +``` + +### 2.3 Tenant-ID on Existing Models + +Every existing model (except `SiteSetting`) gains an `organization_id` FK: + +```python +# Add to Service: +organization_id: Mapped[str] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=False, index=True +) + +# Add to Subscriber: +organization_id: Mapped[str] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=False, index=True +) + +# SiteSetting gets organization_id (nullable for global settings): +organization_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=True, index=True +) +``` + +`Incident`, `IncidentUpdate`, `Monitor`, `MonitorResult`, and `NotificationLog` are +implicitly scoped through their parent `Service` → no direct `organization_id` needed +if all queries join through `Service`. However, for query performance, add a +denormalized `organization_id` on `Incident` and `Monitor`: + +```python +# Add to Incident: +organization_id: Mapped[str] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=False, index=True +) + +# Add to Monitor: +organization_id: Mapped[str] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=False, index=True +) +``` + +### 2.4 Auth Flow + +Replace `verify_api_key` with JWT-based user auth: + +**New file: `app/api/auth.py`** + +```python +from datetime import datetime, timedelta +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import jwt, JWTError +from pydantic import BaseModel, EmailStr +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.dependencies import get_db +from app.models.models import User, OrganizationMember, Organization + +router = APIRouter() + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + +class LoginRequest(BaseModel): + email: EmailStr + password: str + +class RegisterRequest(BaseModel): + email: EmailStr + password: str + display_name: str | None = None + org_name: str # Name for auto-created org + +class TokenResponse(BaseModel): + access_token: str + token_type: str = "bearer" + +def create_access_token(user_id: str, exp_hours: int = 72) -> str: + payload = { + "sub": user_id, + "exp": datetime.utcnow() + timedelta(hours=exp_hours), + } + return jwt.encode(payload, settings.secret_key, algorithm="HS256") + +async def get_current_user( + token: str = Depends(oauth2_scheme), + db: AsyncSession = Depends(get_db), +) -> User: + """Dependency: extract and validate current user from JWT.""" + try: + payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"]) + user_id = payload.get("sub") + except JWTError: + raise HTTPException(status_code=401, detail="Invalid token") + + result = await db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=401, detail="User not found") + return user + +@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) +async def register(data: RegisterRequest, db: AsyncSession = Depends(get_db)): + """Register a new user + auto-create a personal organization.""" + # Check if email exists + existing = await db.execute(select(User).where(User.email == data.email)) + if existing.scalar_one_or_none(): + raise HTTPException(status_code=409, detail="Email already registered") + + from passlib.context import CryptContext + pwd_ctx = CryptContext(schemes=["bcrypt"], deprecated="auto") + + user = User( + email=data.email, + password_hash=pwd_ctx.hash(data.password), + display_name=data.display_name, + ) + db.add(user) + await db.flush() + + # Auto-create personal org + org_slug = data.email.split("@")[0].lower().replace(".", "-") + org = Organization(name=data.org_name, slug=org_slug, tier="free") + db.add(org) + await db.flush() + + membership = OrganizationMember( + organization_id=org.id, user_id=user.id, role="owner" + ) + db.add(membership) + await db.flush() + + token = create_access_token(user.id) + return TokenResponse(access_token=token) + +@router.post("/login", response_model=TokenResponse) +async def login(data: LoginRequest, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(User).where(User.email == data.email)) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=401, detail="Invalid credentials") + + from passlib.context import CryptContext + pwd_ctx = CryptContext(schemes=["bcrypt"], deprecated="auto") + if not pwd_ctx.verify(data.password, user.password_hash): + raise HTTPException(status_code=401, detail="Invalid credentials") + + token = create_access_token(user.id) + return TokenResponse(access_token=token) +``` + +### 2.5 Tenant-Scoped Query Pattern + +Replace all `select(Service)` queries with tenant-filtered variants: + +```python +# Before (app/api/services.py line 51): +result = await db.execute(select(Service).order_by(Service.position, Service.name)) + +# After: +async def get_current_org( + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +) -> Organization: + """Get the user's first org (simplified; add org-switching later).""" + result = await db.execute( + select(OrganizationMember).where(OrganizationMember.user_id == user.id) + ) + membership = result.scalars().first() + if not membership: + raise HTTPException(status_code=403, detail="No organization") + org_result = await db.execute( + select(Organization).where(Organization.id == membership.organization_id) + ) + return org_result.scalar_one() + +# In endpoints: +org = Depends(get_current_org) + +result = await db.execute( + select(Service) + .where(Service.organization_id == org.id) + .order_by(Service.position, Service.name) +) +``` + +### 2.6 Multi-Page Public Routes + +Current `GET /` serves one global page. New design: + +| Route | Purpose | +|---|---| +| `GET /` | Landing / marketing page | +| `GET /p/{org_slug}` | Org's default status page | +| `GET /p/{org_slug}/{page_slug}` | Specific status page | +| `GET /p/{org_slug}/incident/{id}` | Incident detail within org context | +| Custom domain | Resolves to `StatusPage.custom_domain` → renders that page | + +```python +# app/main.py — NEW route + +@app.get("/p/{org_slug}/{page_slug}") +async def public_status_page(request: Request, org_slug: str, page_slug: str): + from sqlalchemy import select + from app.database import async_session_factory + from app.models.models import Organization, StatusPage, Service, Incident + + async with async_session_factory() as db: + # Resolve org + page + org_result = await db.execute( + select(Organization).where(Organization.slug == org_slug) + ) + org = org_result.scalar_one_or_none() + if not org: + raise HTTPException(status_code=404) + + page_result = await db.execute( + select(StatusPage).where( + StatusPage.organization_id == org.id, + StatusPage.slug == page_slug, + ) + ) + page = page_result.scalar_one_or_none() + if not page: + raise HTTPException(status_code=404) + + # Get services for this org + result = await db.execute( + select(Service) + .where(Service.organization_id == org.id, Service.is_visible == True) + .order_by(Service.position, Service.name) + ) + services = result.scalars().all() + # ... render template with org + page context +``` + +### 2.7 Custom Domain Resolution + +Add middleware to resolve custom domains: + +```python +# app/middleware.py — NEW + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from sqlalchemy import select +from app.database import async_session_factory +from app.models.models import StatusPage, Organization + +class CustomDomainMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + host = request.headers.get("host", "").split(":")[0] + # Skip known app domains + if host in ("korpo.pro", "localhost", "127.0.0.1"): + return await call_next(request) + + async with async_session_factory() as db: + result = await db.execute( + select(StatusPage).where(StatusPage.custom_domain == host) + ) + page = result.scalar_one_or_none() + if page: + # Store page info in request state for downstream use + request.state.custom_domain_page = page + + return await call_next(request) +``` + +--- + +## 3. Feature Tiers + +### 3.1 Tier Definitions + +| Feature | Free | Pro ($9/mo) | Team ($29/mo) | +|---|---|---|---| +| **Status pages** | 1 | 5 | Unlimited | +| **Services per page** | 5 | 50 | Unlimited | +| **Monitors per service** | 1 | 5 | Unlimited | +| **Subscribers** | 25 | 500 | Unlimited | +| **Uptime check interval** | 5 min | 1 min | 30 sec | +| **Custom domain** | ❌ | ✅ | ✅ | +| **Custom branding/CSS** | ❌ | ✅ | ✅ | +| **Team members** | 1 | 3 | Unlimited | +| **Incident history** | 30 days | 1 year | Unlimited | +| **Webhook notifications** | ❌ | ✅ | ✅ | +| **API access** | ❌ | ✅ | ✅ | +| **Email notifications** | ✅ | ✅ | ✅ | +| **SLA badge widget** | ❌ | ✅ | ✅ | +| **Status page password protection** | ❌ | ❌ | ✅ | + +### 3.2 Enforcement Layer + +Create a tier-limits module that checks against the org's tier before allowing writes: + +**New file: `app/services/tier_limits.py`** + +```python +from fastapi import HTTPException +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.models import ( + Organization, Service, Monitor, Subscriber, StatusPage, OrganizationMember +) + +TIER_LIMITS = { + "free": { + "status_pages": 1, + "services_per_page": 5, + "monitors_per_service": 1, + "subscribers": 25, + "members": 1, + "check_interval_min": 5, + "custom_domain": False, + "webhooks": False, + "api_access": False, + }, + "pro": { + "status_pages": 5, + "services_per_page": 50, + "monitors_per_service": 5, + "subscribers": 500, + "members": 3, + "check_interval_min": 1, + "custom_domain": True, + "webhooks": True, + "api_access": True, + }, + "team": { + "status_pages": -1, # unlimited + "services_per_page": -1, + "monitors_per_service": -1, + "subscribers": -1, + "members": -1, + "check_interval_min": 0, # 30 sec + "custom_domain": True, + "webhooks": True, + "api_access": True, + }, +} + +def get_limit(org: Organization, feature: str): + """Return the limit for a feature given the org's tier.""" + tier = org.tier or "free" + return TIER_LIMITS.get(tier, TIER_LIMITS["free"]).get(feature) + +async def enforce_limit( + db: AsyncSession, org: Organization, feature: str, current_count: int +) -> None: + """Raise 403 if the current count meets or exceeds the tier limit.""" + limit = get_limit(org, feature) + if limit == -1: # unlimited + return + if limit is False: # feature not available + raise HTTPException( + status_code=403, + detail=f"Feature '{feature}' requires a plan upgrade", + ) + if current_count >= limit: + raise HTTPException( + status_code=403, + detail=f"Tier limit reached for '{feature}' ({limit}). Upgrade your plan.", + ) + +# Concrete helpers: + +async def check_status_page_limit(db: AsyncSession, org: Organization): + count = await db.execute( + select(func.count(StatusPage.id)).where(StatusPage.organization_id == org.id) + ) + await enforce_limit(db, org, "status_pages", count.scalar() or 0) + +async def check_service_limit(db: AsyncSession, org: Organization, page_id: str): + count = await db.execute( + select(func.count(Service.id)).where( + Service.organization_id == org.id, + Service.status_page_id == page_id, + ) + ) + await enforce_limit(db, org, "services_per_page", count.scalar() or 0) + +async def check_monitor_limit(db: AsyncSession, org: Organization, service_id: str): + count = await db.execute( + select(func.count(Monitor.id)).where(Monitor.service_id == service_id) + ) + await enforce_limit(db, org, "monitors_per_service", count.scalar() or 0) + +async def check_subscriber_limit(db: AsyncSession, org: Organization): + count = await db.execute( + select(func.count(Subscriber.id)).where(Subscriber.organization_id == org.id) + ) + await enforce_limit(db, org, "subscribers", count.scalar() or 0) + +async def check_member_limit(db: AsyncSession, org: Organization): + count = await db.execute( + select(func.count(OrganizationMember.id)).where( + OrganizationMember.organization_id == org.id + ) + ) + await enforce_limit(db, org, "members", count.scalar() or 0) +``` + +### 3.3 Usage in Endpoints + +```python +# app/api/services.py — modify create_service + +@router.post("/", status_code=status.HTTP_201_CREATED) +async def create_service( + data: ServiceCreate, + org: Organization = Depends(get_current_org), + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + from app.services.tier_limits import check_service_limit + await check_service_limit(db, org, data.status_page_id) # ENFORCE LIMIT + + service = Service( + organization_id=org.id, + name=data.name, + slug=data.slug, + # ... + ) + db.add(service) + # ... +``` + +--- + +## 4. Payment Integration — Stripe Checkout Links + +### 4.1 Approach: No-Code Payment Links + +Use **Stripe Checkout Links** (Payment Links) — pre-built hosted checkout pages +that require zero server-side Stripe integration. When a user clicks "Upgrade", +they're redirected to a Stripe-hosted checkout page. After payment, Stripe +redirects back with a `?session_id` parameter. A webhook marks the org as upgraded. + +### 4.2 Setup in Stripe Dashboard + +1. Create 2 Products: "Pro Plan" ($9/mo) and "Team Plan" ($29/mo) +2. Create recurring prices for each +3. Generate **Payment Links** for each price: + - Pro: `https://buy.stripe.com/xxxx_pro_9_monthly` + - Team: `https://buy.stripe.com/xxxx_team_29_monthly` +4. Set success URL: `https://korpo.pro/api/v1/billing/success?session_id={CHECKOUT_SESSION_ID}` +5. Set cancel URL: `https://korpo.pro/billing?canceled=1` + +### 4.3 Config — Add to `app/config.py` + +```python +class Settings(BaseSettings): + # ... existing fields ... + + # Stripe + stripe_pro_checkout_url: str = "https://buy.stripe.com/xxxx_pro_9_monthly" + stripe_team_checkout_url: str = "https://buy.stripe.com/xxxx_team_29_monthly" + stripe_webhook_secret: str = "whsec_xxxx" +``` + +### 4.4 Checkout Redirect Endpoint + +**New file: `app/api/billing.py`** + +```python +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.dependencies import get_db +from app.api.auth import get_current_user +from app.models.models import User, Organization, OrganizationMember + +router = APIRouter() + +@router.get("/checkout/{tier}") +async def create_checkout( + tier: str, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Redirect user to Stripe Checkout Link for the selected tier.""" + if tier not in ("pro", "team"): + raise HTTPException(status_code=400, detail="Invalid tier") + + # Get user's org + result = await db.execute( + select(OrganizationMember).where(OrganizationMember.user_id == user.id) + ) + membership = result.scalars().first() + if not membership: + raise HTTPException(status_code=403, detail="No organization") + + # Select checkout URL based on tier + checkout_url = ( + settings.stripe_pro_checkout_url if tier == "pro" + else settings.stripe_team_checkout_url + ) + + # Append metadata so Stripe webhook knows which org to upgrade + # Stripe Payment Links support client_reference_id via URL param: + from urllib.parse import urlencode + params = urlencode({ + "client_reference_id": membership.organization_id, + "prefilled_email": user.email, + }) + + return {"redirect_url": f"{checkout_url}?{params}"} +``` + +### 4.5 Stripe Webhook Handler + +This is the **only** server-side Stripe code needed. It receives events when +payments succeed or subscriptions change. + +```python +# app/api/billing.py — continued + +import stripe +from fastapi import Header + +@router.post("/webhook") +async def stripe_webhook(request: Request, db: AsyncSession = Depends(get_db)): + """Handle Stripe webhook events for subscription changes.""" + body = await request.body() + sig = request.headers.get("stripe-signature", "") + + try: + event = stripe.Webhook.construct_event( + body, sig, settings.stripe_webhook_secret + ) + except Exception: + raise HTTPException(status_code=400, detail="Invalid signature") + + if event["type"] == "checkout.session.completed": + session = event["data"]["object"] + org_id = session.get("client_reference_id") + customer_id = session.get("customer") + + if org_id: + result = await db.execute( + select(Organization).where(Organization.id == org_id) + ) + org = result.scalar_one_or_none() + if org: + # Determine tier from the line items + amount = session.get("amount_total", 0) + if amount >= 2900: # $29+ + org.tier = "team" + elif amount >= 900: # $9+ + org.tier = "pro" + org.stripe_customer_id = customer_id + await db.flush() + + elif event["type"] == "customer.subscription.deleted": + # Downgrade to free on cancellation + customer_id = event["data"]["object"]["customer"] + result = await db.execute( + select(Organization).where(Organization.stripe_customer_id == customer_id) + ) + org = result.scalar_one_or_none() + if org: + org.tier = "free" + await db.flush() + + return {"status": "ok"} +``` + +### 4.6 Success Callback + +```python +@router.get("/success") +async def billing_success(session_id: str, db: AsyncSession = Depends(get_db)): + """Stripe redirects here after successful checkout.""" + # In production, call stripe.checkout.sessions.retrieve(session_id) to verify. + # For MVP, webhook already processed it — just show a success page. + return {"message": "Subscription activated! Your plan has been upgraded."} +``` + +### 4.7 Billing Page Template + +```html + +{% extends "base.html" %} +{% block title %}Billing{% endblock %} +{% block content %} +
+

Choose Your Plan

+ +
+
+

Free

+

$0/mo

+
    +
  • 1 status page
  • +
  • 5 services
  • +
  • 25 subscribers
  • +
  • Email notifications
  • +
+ {% if org.tier == 'free' %}

✅ Current plan

+ {% else %}Downgrade{% endif %} +
+ +
+

Pro

+

$9/mo

+
    +
  • 5 status pages
  • +
  • 50 services per page
  • +
  • Custom domain
  • +
  • Webhook notifications
  • +
  • API access
  • +
+ {% if org.tier == 'pro' %}

✅ Current plan

+ {% else %}Upgrade to Pro{% endif %} +
+ +
+

Team

+

$29/mo

+
    +
  • Unlimited pages & services
  • +
  • Unlimited team members
  • +
  • 30-second check intervals
  • +
  • Password-protected pages
  • +
+ {% if org.tier == 'team' %}

✅ Current plan

+ {% else %}Upgrade to Team{% endif %} +
+
+
+{% endblock %} +``` + +--- + +## 5. Database Migration Steps + +### 5.1 Set Up Alembic (Currently a Placeholder) + +The `alembic.ini` file is a stub. Full setup required: + +```bash +cd ~/wealth-engine/indie-status-page +alembic init migrations # already has dir structure, but re-init config +``` + +Edit `alembic.ini`: + +```ini +sqlalchemy.url = sqlite+aiosqlite:///./data/statuspage.db +``` + +Edit `migrations/env.py` to import `Base` and all models: + +```python +from app.database import Base +from app.models.models import * # ensure all models imported +target_metadata = Base.metadata +``` + +### 5.2 Migration 001: Add User + Organization Tables + +**Create a baseline migration for existing schema, then add new tables.** + +```bash +alembic revision --autogenerate -m "001_add_users_and_organizations" +``` + +```python +# migrations/versions/001_add_users_and_organizations.py + +def upgrade() -> None: + # New tables — no existing data affected + op.create_table( + "users", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("email", sa.String(255), unique=True, nullable=False), + sa.Column("password_hash", sa.String(255), nullable=False), + sa.Column("display_name", sa.String(100)), + sa.Column("is_email_verified", sa.Boolean, default=False), + sa.Column("email_verify_token", sa.String(100)), + sa.Column("created_at", sa.DateTime), + sa.Column("updated_at", sa.DateTime), + ) + op.create_index("ix_users_email", "users", ["email"]) + + op.create_table( + "organizations", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("slug", sa.String(50), unique=True, nullable=False), + sa.Column("name", sa.String(100), nullable=False), + sa.Column("tier", sa.String(20), nullable=False, server_default="free"), + sa.Column("stripe_customer_id", sa.String(100)), + sa.Column("custom_domain", sa.String(255)), + sa.Column("created_at", sa.DateTime), + sa.Column("updated_at", sa.DateTime), + ) + op.create_index("ix_organizations_slug", "organizations", ["slug"]) + + op.create_table( + "organization_members", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("organization_id", sa.String(36), + sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("user_id", sa.String(36), + sa.ForeignKey("users.id"), nullable=False), + sa.Column("role", sa.String(20), nullable=False, server_default="member"), + sa.Column("joined_at", sa.DateTime), + ) + op.create_index("ix_org_members_org", "organization_members", ["organization_id"]) + op.create_index("ix_org_members_user", "organization_members", ["user_id"]) + + op.create_table( + "status_pages", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("organization_id", sa.String(36), + sa.ForeignKey("organizations.id"), nullable=False), + sa.Column("slug", sa.String(50), nullable=False), + sa.Column("title", sa.String(100), nullable=False), + sa.Column("subdomain", sa.String(100)), + sa.Column("custom_domain", sa.String(255)), + sa.Column("is_public", sa.Boolean, default=True), + sa.Column("created_at", sa.DateTime), + sa.Column("updated_at", sa.DateTime), + sa.UniqueConstraint("organization_id", "slug", name="uq_status_page_org_slug"), + ) + op.create_index("ix_status_pages_org", "status_pages", ["organization_id"]) + + +def downgrade() -> None: + op.drop_table("status_pages") + op.drop_table("organization_members") + op.drop_table("organizations") + op.drop_table("users") +``` + +### 5.3 Migration 002: Add organization_id to Existing Tables + +```bash +alembic revision -m "002_add_organization_id_to_existing_tables" +``` + +```python +# migrations/versions/002_add_organization_id_to_existing_tables.py + +def upgrade() -> None: + # Add nullable org_id columns first + op.add_column("services", sa.Column( + "organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=True + )) + op.add_column("incidents", sa.Column( + "organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=True + )) + op.add_column("monitors", sa.Column( + "organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=True + )) + op.add_column("subscribers", sa.Column( + "organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=True + )) + op.add_column("site_settings", sa.Column( + "organization_id", sa.String(36), sa.ForeignKey("organizations.id"), nullable=True + )) + + # Add status_page_id column to services + op.add_column("services", sa.Column( + "status_page_id", sa.String(36), sa.ForeignKey("status_pages.id"), nullable=True + )) + + # Create indexes + for table in ("services", "incidents", "monitors", "subscribers", "site_settings"): + op.create_index(f"ix_{table}_org", table, ["organization_id"]) + op.create_index("ix_services_page", "services", ["status_page_id"]) + + +def downgrade() -> None: + for table in ("services", "incidents", "monitors", "subscribers", "site_settings"): + op.drop_index(f"ix_{table}_org", table) + op.drop_column(table, "organization_id") + op.drop_index("ix_services_page", "services") + op.drop_column("services", "status_page_id") +``` + +### 5.4 Migration 003: Data Backfill — Migrate Existing Data to a Default Org + +```bash +alembic revision -m "003_backfill_organization_id" +``` + +```python +# migrations/versions/003_backfill_organization_id.py + +from uuid import uuid4 + +def upgrade() -> None: + # 1. Create a default org for existing data + default_org_id = str(uuid4()) + op.execute(f""" + INSERT INTO organizations (id, slug, name, tier, created_at, updated_at) + VALUES ('{default_org_id}', 'default', 'Default Organization', 'free', + datetime('now'), datetime('now')) + """) + + # 2. Create a default status page + default_page_id = str(uuid4()) + op.execute(f""" + INSERT INTO status_pages (id, organization_id, slug, title, is_public, + created_at, updated_at) + VALUES ('{default_page_id}', '{default_org_id}', 'default', + 'Status Page', 1, datetime('now'), datetime('now')) + """) + + # 3. Backfill organization_id on all existing rows + for table in ("services", "incidents", "monitors", "subscribers", "site_settings"): + op.execute(f""" + UPDATE {table} SET organization_id = '{default_org_id}' + WHERE organization_id IS NULL + """) + + # 4. Backfill status_page_id on services + op.execute(f""" + UPDATE services SET status_page_id = '{default_page_id}' + WHERE status_page_id IS NULL + """) + + # 5. Now make organization_id NOT NULL + # (SQLite doesn't support ALTER COLUMN, so recreate tables — + # in PostgreSQL, this is just ALTER COLUMN SET NOT NULL) + # For SQLite, we'll keep nullable and enforce at app level. + # Document: enforce in model definition with nullable=False + # after migrating to PostgreSQL. + + +def downgrade() -> None: + # Cannot reverse the backfill meaningfully + pass +``` + +### 5.5 Migration 004: SQLite → PostgreSQL (Production) + +For production multi-tenancy, migrate from SQLite to PostgreSQL: + +1. Add `psycopg2-binary` + `asyncpg` to `pyproject.toml` dependencies +2. Change `DATABASE_URL` env var to `postgresql+asyncpg://...` +3. Use `pg_dump`/`pg_restore` or SQLAlchemy data migration script +4. Enable proper `ALTER COLUMN SET NOT NULL` on `organization_id` fields + +```toml +# pyproject.toml — add to dependencies +"asyncpg>=0.29,<1.0", +"psycopg2-binary>=2.9,<3.0", +``` + +```python +# app/config.py — update default +database_url: str = "postgresql+asyncpg://statuspage:password@localhost/statuspage" +``` + +--- + +## 6. Priority-Ordered Implementation Steps + +### Phase 1: Foundation (Week 1–2) + +| # | Step | Files | Effort | +|---|---|---|---| +| 1 | **Set up Alembic properly** with all existing models as baseline | `alembic.ini`, `migrations/env.py`, `migrations/versions/000_baseline.py` | 2h | +| 2 | **Add User model** + password hashing (passlib/bcrypt) | `app/models/models.py`, `pyproject.toml` (add `passlib[bcrypt]`, `python-jose`) | 2h | +| 3 | **Add Organization + OrganizationMember + StatusPage models** | `app/models/models.py` | 2h | +| 4 | **Create Alembic migration 001** (new tables only) | `migrations/versions/001_*.py` | 1h | +| 5 | **Build auth endpoints** (register, login, JWT) | `app/api/auth.py` (new), `app/api/router.py` (add route) | 4h | +| 6 | **Create `get_current_user` and `get_current_org` dependencies** | `app/dependencies.py` (extend) | 2h | +| 7 | **Write auth tests** (register, login, token validation) | `tests/test_api_auth.py` (new) | 3h | + +### Phase 2: Multi-Tenancy (Week 2–3) + +| # | Step | Files | Effort | +|---|---|---|---| +| 8 | **Add `organization_id` to existing models + migration 002** | `app/models/models.py`, `migrations/versions/002_*.py` | 3h | +| 9 | **Data backfill migration 003** — create default org, assign all existing data | `migrations/versions/003_*.py` | 2h | +| 10 | **Refactor all API endpoints** to filter by `organization_id` | `app/api/services.py`, `incidents.py`, `monitors.py`, `subscribers.py`, `settings.py` | 6h | +| 11 | **Add `status_page_id` to Service** and update queries to scope by page | `app/models/models.py`, all API files | 2h | +| 12 | **Build multi-page public routes** (`/p/{org_slug}/{page_slug}`) | `app/main.py`, `app/templates/status.html` (parametrize by org) | 4h | +| 13 | **Build org management API** (create org, invite members, switch org) | `app/api/organizations.py` (new) | 4h | +| 14 | **Build status page CRUD API** (create/edit/delete pages) | `app/api/pages.py` (new) | 3h | +| 15 | **Write multi-tenancy tests** (data isolation, tenant-scoped CRUD) | `tests/test_api_tenancy.py` (new) | 3h | + +### Phase 3: Feature Tiers (Week 3–4) + +| # | Step | Files | Effort | +|---|---|---|---| +| 16 | **Create `tier_limits.py`** with all limit definitions and enforcement | `app/services/tier_limits.py` (new) | 2h | +| 17 | **Integrate limit checks** into all write endpoints | `app/api/services.py`, `monitors.py`, `subscribers.py`, `pages.py`, `organizations.py` | 3h | +| 18 | **Build billing UI** — plan selection page with upgrade CTAs | `app/templates/billing.html` (new), `app/static/css/style.css` | 3h | +| 19 | **Add billing route** that serves the billing page with org context | `app/main.py` (add route) | 1h | +| 20 | **Write tier limit tests** (enforce limits, verify free/pro/team boundaries) | `tests/test_tier_limits.py` (new) | 3h | + +### Phase 4: Stripe Payments (Week 4–5) + +| # | Step | Files | Effort | +|---|---|---|---| +| 21 | **Create Stripe Payment Links** in Stripe Dashboard (manual, no code) | Stripe Dashboard | 0.5h | +| 22 | **Add Stripe config** to settings | `app/config.py` | 0.5h | +| 23 | **Build billing API** — checkout redirect + webhook handler | `app/api/billing.py` (new), `app/api/router.py` | 4h | +| 24 | **Add stripe Python SDK** to dependencies | `pyproject.toml` (add `stripe>=8.0`) | 0.5h | +| 25 | **Configure Stripe webhook endpoint** in Stripe Dashboard → `https://korpo.pro/api/v1/billing/webhook` | Stripe Dashboard | 0.5h | +| 26 | **Build success/cancel pages** | `app/templates/billing_success.html` (new) | 1h | +| 27 | **Write webhook tests** with mock Stripe events | `tests/test_api_billing.py` (new) | 3h | + +### Phase 5: Custom Domains & Polish (Week 5–6) + +| # | Step | Files | Effort | +|---|---|---|---| +| 28 | **Add custom domain middleware** | `app/middleware.py` (new), `app/main.py` (register) | 3h | +| 29 | **Build landing page** at `GET /` (replace current global status route) | `app/templates/landing.html` (new), `app/main.py` | 4h | +| 30 | **Add admin dashboard** — overview of org's pages, services, incidents | `app/templates/dashboard.html` (new), `app/api/dashboard.py` (new) | 6h | +| 31 | **Proper PostgreSQL migration** (migration 004) | `pyproject.toml`, `app/config.py`, `docker-compose.yml` (add postgres service) | 4h | +| 32 | **Nginx update** — wildcard subdomain or custom domain routing | nginx config | 2h | +| 33 | **E2E test suite** — full user journey from signup → billing → status page | `tests/test_e2e.py` (new) | 4h | + +### Phase 6: Launch (Week 6–7) + +| # | Step | Files | Effort | +|---|---|---|---| +| 34 | **Add `python-jose` + `passlib[bcrypt]` + `stripe`** to deps | `pyproject.toml` | 0.5h | +| 35 | **Rate limiting middleware** for public endpoints | `app/middleware.py` | 2h | +| 36 | **Email verification flow** for user signups | `app/api/auth.py`, `app/services/notifier.py` | 3h | +| 37 | **CORS configuration** for API access (Pro+ feature) | `app/main.py` | 1h | +| 38 | **Production deployment** — update systemd service, env vars, DB | `indie-status-page.service` (env file), nginx | 2h | +| 39 | **Documentation** — API docs, self-serve guide | `README.md` update, `docs/` | 4h | + +--- + +## Appendix A: New Dependencies + +Add to `pyproject.toml`: + +```toml +dependencies = [ + # ... existing ... + "python-jose[cryptography]>=3.3,<4.0", # JWT tokens + "passlib[bcrypt]>=1.7,<2.0", # Password hashing + "python-multipart>=0.0.6,<1.0", # OAuth2 form parsing (already present) + "stripe>=8.0,<9.0", # Stripe SDK (webhook verification) + "email-validator>=2.1,<3.0", # Pydantic EmailStr validation +] + +# For PostgreSQL (Phase 5): +# "asyncpg>=0.29,<1.0", +``` + +## Appendix B: Updated `app/models/__init__.py` + +```python +from app.models.models import ( + User, + Organization, + OrganizationMember, + StatusPage, + Service, + Incident, + IncidentUpdate, + Monitor, + MonitorResult, + Subscriber, + NotificationLog, + SiteSetting, +) + +__all__ = [ + "User", + "Organization", + "OrganizationMember", + "StatusPage", + "Service", + "Incident", + "IncidentUpdate", + "Monitor", + "MonitorResult", + "Subscriber", + "NotificationLog", + "SiteSetting", +] +``` + +## Appendix C: Router Registration Updates + +```python +# app/api/router.py — updated + +from fastapi import APIRouter + +from app.api.auth import router as auth_router +from app.api.billing import router as billing_router +from app.api.services import router as services_router +from app.api.incidents import router as incidents_router +from app.api.monitors import router as monitors_router +from app.api.subscribers import router as subscribers_router +from app.api.settings import router as settings_router +from app.api.pages import router as pages_router # NEW +from app.api.organizations import router as org_router # NEW +from app.api.dashboard import router as dashboard_router # NEW + +api_v1_router = APIRouter() + +api_v1_router.include_router(auth_router, prefix="/auth", tags=["auth"]) +api_v1_router.include_router(billing_router, prefix="/billing", tags=["billing"]) +api_v1_router.include_router(services_router, prefix="/services", tags=["services"]) +api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"]) +api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"]) +api_v1_router.include_router(subscribers_router, prefix="/subscribers", tags=["subscribers"]) +api_v1_router.include_router(settings_router, prefix="/settings", tags=["settings"]) +api_v1_router.include_router(pages_router, prefix="/pages", tags=["pages"]) +api_v1_router.include_router(org_router, prefix="/organizations", tags=["organizations"]) +api_v1_router.include_router(dashboard_router, prefix="/dashboard", tags=["dashboard"]) +``` + +## Appendix D: Key File Map (New + Modified) + +| File | Action | Phase | +|---|---|---| +| `app/models/models.py` | **Modify** — add User, Organization, OrganizationMember, StatusPage; add organization_id to existing models | 1–2 | +| `app/dependencies.py` | **Modify** — add get_current_user, get_current_org | 1 | +| `app/config.py` | **Modify** — add Stripe URLs, JWT secret fields | 1, 4 | +| `app/api/auth.py` | **Create** — register, login, JWT logic | 1 | +| `app/api/billing.py` | **Create** — checkout redirect, webhook, success | 4 | +| `app/api/pages.py` | **Create** — StatusPage CRUD | 2 | +| `app/api/organizations.py` | **Create** — org management, invite | 2 | +| `app/api/dashboard.py` | **Create** — admin dashboard API | 5 | +| `app/services/tier_limits.py` | **Create** — tier enforcement | 3 | +| `app/middleware.py` | **Create** — custom domain resolution | 5 | +| `app/api/services.py` | **Modify** — add org scoping + tier limit checks | 2 | +| `app/api/incidents.py` | **Modify** — add org scoping | 2 | +| `app/api/monitors.py` | **Modify** — add org scoping + tier limit checks | 2 | +| `app/api/subscribers.py` | **Modify** — add org scoping + tier limit checks | 2 | +| `app/api/settings.py` | **Modify** — add org scoping | 2 | +| `app/api/router.py` | **Modify** — register new routers | 1, 2, 4 | +| `app/main.py` | **Modify** — add multi-page routes, landing, middleware, dashboard | 2, 5 | +| `app/templates/billing.html` | **Create** — plan selection page | 3 | +| `app/templates/landing.html` | **Create** — marketing homepage | 5 | +| `app/templates/dashboard.html` | **Create** — admin dashboard | 5 | +| `app/templates/billing_success.html` | **Create** — post-payment success | 4 | +| `app/static/css/style.css` | **Modify** — add billing, dashboard, plan card styles | 3 | +| `pyproject.toml` | **Modify** — add new dependencies | 1, 4 | +| `alembic.ini` | **Modify** — proper config | 1 | +| `migrations/env.py` | **Modify** — import Base + models | 1 | +| `migrations/versions/001_*.py` | **Create** — user/org tables | 1 | +| `migrations/versions/002_*.py` | **Create** — organization_id on existing tables | 2 | +| `migrations/versions/003_*.py` | **Create** — data backfill | 2 | +| `docker-compose.yml` | **Modify** — add PostgreSQL service | 5 | +| `tests/test_api_auth.py` | **Create** | 1 | +| `tests/test_api_tenancy.py` | **Create** | 2 | +| `tests/test_tier_limits.py` | **Create** | 3 | +| `tests/test_api_billing.py` | **Create** | 4 | +| `tests/test_e2e.py` | **Create** | 5 | + +--- + +*Plan generated: 2026-04-25 | Total estimated effort: ~6–7 weeks for a solo developer* \ No newline at end of file diff --git a/app/api/organizations.py b/app/api/organizations.py new file mode 100644 index 0000000..b84376f --- /dev/null +++ b/app/api/organizations.py @@ -0,0 +1,155 @@ +"""Organization API endpoints: view org, list tiers, upgrade/downgrade.""" + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import get_current_user, get_current_org +from app.dependencies import get_db +from app.models.saas_models import Organization, OrganizationMember, User +from app.services.tier_limits import ( + TIER_LIMITS, + get_org_limits, + get_tier_info, +) + +router = APIRouter(tags=["organizations"]) + + +# ── Response schemas ──────────────────────────────────────────────────────── + +class OrgMemberResponse(BaseModel): + user_id: str + email: str + display_name: str | None + role: str + + +class OrgResponse(BaseModel): + id: str + slug: str + name: str + tier: str + custom_domain: str | None + member_count: int + tier_info: dict + + +class TierDetailResponse(BaseModel): + tier: str + limits: dict + + +class UpgradeRequest(BaseModel): + tier: str # "free" | "pro" | "team" + + +# ── Endpoints ─────────────────────────────────────────────────────────────── + +@router.get("/tiers") +async def list_tiers(): + """List all available tiers and their limits (public endpoint).""" + return { + "tiers": [ + { + "name": tier_name, + "display_name": { + "free": "Free", + "pro": "Pro ($9/mo)", + "team": "Team ($29/mo)", + }.get(tier_name, tier_name), + "limits": limits, + } + for tier_name, limits in TIER_LIMITS.items() + ] + } + + +@router.get("/my", response_model=OrgResponse) +async def get_my_org( + org: Organization = Depends(get_current_org), + db: AsyncSession = Depends(get_db), +): + """Get the current user's organization with tier limits info.""" + # Count members + result = await db.execute( + select(OrganizationMember).where( + OrganizationMember.organization_id == org.id + ) + ) + members = result.scalars().all() + + return OrgResponse( + id=org.id, + slug=org.slug, + name=org.name, + tier=org.tier or "free", + custom_domain=org.custom_domain, + member_count=len(members), + tier_info=get_tier_info(org), + ) + + +@router.patch("/my/tier", response_model=OrgResponse) +async def update_org_tier( + body: UpgradeRequest, + org: Organization = Depends(get_current_org), + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Update the organization's tier. + + In production, this would be gated by Stripe payment verification. + For now, this is an admin-only endpoint that directly sets the tier. + Only the org owner can change the tier. + """ + # Verify user is owner + result = await db.execute( + select(OrganizationMember).where( + OrganizationMember.organization_id == org.id, + OrganizationMember.user_id == user.id, + ) + ) + membership = result.scalar_one_or_none() + if not membership or membership.role != "owner": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the organization owner can change the plan tier.", + ) + + if body.tier not in ("free", "pro", "team"): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Invalid tier '{body.tier}'. Must be one of: free, pro, team.", + ) + + org.tier = body.tier + await db.flush() + await db.refresh(org) + + # Count members for response + members_result = await db.execute( + select(OrganizationMember).where( + OrganizationMember.organization_id == org.id + ) + ) + members = members_result.scalars().all() + + return OrgResponse( + id=org.id, + slug=org.slug, + name=org.name, + tier=org.tier, + custom_domain=org.custom_domain, + member_count=len(members), + tier_info=get_tier_info(org), + ) + + +@router.get("/my/limits") +async def get_my_limits( + org: Organization = Depends(get_current_org), +): + """Get the current organization's tier limits and feature flags.""" + return get_tier_info(org) \ No newline at end of file diff --git a/app/api/router.py b/app/api/router.py index f25f82e..0bd6de9 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -5,9 +5,13 @@ from app.api.incidents import router as incidents_router from app.api.monitors import router as monitors_router from app.api.subscribers import router as subscribers_router from app.api.settings import router as settings_router +from app.api.organizations import router as organizations_router +from app.routes.auth import router as auth_router api_v1_router = APIRouter() +api_v1_router.include_router(auth_router, tags=["auth"]) +api_v1_router.include_router(organizations_router, prefix="/organizations", tags=["organizations"]) api_v1_router.include_router(services_router, prefix="/services", tags=["services"]) api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"]) api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"]) diff --git a/app/auth.py b/app/auth.py new file mode 100644 index 0000000..0ee0ac9 --- /dev/null +++ b/app/auth.py @@ -0,0 +1,99 @@ +"""JWT authentication and password hashing utilities.""" + +from datetime import datetime, timedelta + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt +from passlib.context import CryptContext +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.dependencies import get_db +from app.models.saas_models import User, Organization, OrganizationMember + +# Password hashing context +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# OAuth2 scheme for token extraction +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + + +def hash_password(password: str) -> str: + """Hash a password using bcrypt.""" + return pwd_context.hash(password) + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against a hash.""" + return pwd_context.verify(plain_password, hashed_password) + + +def create_access_token(user_id: str, exp_hours: int = 72) -> str: + """Create a JWT access token for the given user ID.""" + payload = { + "sub": user_id, + "exp": datetime.utcnow() + timedelta(hours=exp_hours), + } + return jwt.encode(payload, settings.secret_key, algorithm="HS256") + + +def decode_access_token(token: str) -> dict: + """Decode and verify a JWT access token. Raises JWTError on failure.""" + return jwt.decode(token, settings.secret_key, algorithms=["HS256"]) + + +async def get_current_user( + token: str = Depends(oauth2_scheme), + db: AsyncSession = Depends(get_db), +) -> User: + """FastAPI dependency: extract and validate current user from JWT.""" + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = decode_access_token(token) + user_id: str | None = payload.get("sub") + if user_id is None: + raise credentials_exception + except JWTError: + raise credentials_exception + + result = await db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if user is None: + raise credentials_exception + return user + + +async def get_current_org( + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +) -> Organization: + """FastAPI dependency: get the user's first organization (simplified). + + In the future, this will support org-switching via header or session. + """ + result = await db.execute( + select(OrganizationMember).where(OrganizationMember.user_id == user.id) + ) + membership = result.scalars().first() + if not membership: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User is not a member of any organization", + ) + + org_result = await db.execute( + select(Organization).where(Organization.id == membership.organization_id) + ) + org = org_result.scalar_one_or_none() + if org is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Organization not found", + ) + return org \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py index eb4b32e..61af1ee 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -8,6 +8,12 @@ from app.models.models import ( NotificationLog, SiteSetting, ) +from app.models.saas_models import ( + User, + Organization, + OrganizationMember, + StatusPage, +) __all__ = [ "Service", @@ -18,4 +24,8 @@ __all__ = [ "Subscriber", "NotificationLog", "SiteSetting", + "User", + "Organization", + "OrganizationMember", + "StatusPage", ] \ No newline at end of file diff --git a/app/models/models.py b/app/models/models.py index 0feeaec..b134934 100644 --- a/app/models/models.py +++ b/app/models/models.py @@ -15,6 +15,9 @@ class Service(Base): __tablename__ = "services" id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + organization_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=True, index=True + ) name: Mapped[str] = mapped_column(String(100), nullable=False) slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True) description: Mapped[str | None] = mapped_column(Text, nullable=True) @@ -74,6 +77,9 @@ class Monitor(Base): __tablename__ = "monitors" id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + organization_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=True, index=True + ) service_id: Mapped[str] = mapped_column( String(36), ForeignKey("services.id"), nullable=False, index=True ) @@ -114,6 +120,9 @@ class Subscriber(Base): __tablename__ = "subscribers" id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + organization_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=True, index=True + ) email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) is_confirmed: Mapped[bool] = mapped_column(Boolean, default=False) confirm_token: Mapped[str | None] = mapped_column(String(100), nullable=True) diff --git a/app/models/saas_models.py b/app/models/saas_models.py new file mode 100644 index 0000000..272f1af --- /dev/null +++ b/app/models/saas_models.py @@ -0,0 +1,113 @@ +"""SaaS multi-tenancy models: User, Organization, OrganizationMember, StatusPage.""" + +import uuid +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base + + +def _uuid_str() -> str: + return str(uuid.uuid4()) + + +class User(Base): + """Individual who can log in.""" + + __tablename__ = "users" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + email: Mapped[str] = mapped_column( + String(255), unique=True, nullable=False, index=True + ) + password_hash: Mapped[str] = mapped_column(String(255), nullable=False) + display_name: Mapped[str | None] = mapped_column(String(100), nullable=True) + is_email_verified: Mapped[bool] = mapped_column(Boolean, default=False) + email_verify_token: Mapped[str | None] = mapped_column(String(100), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + memberships: Mapped[list["OrganizationMember"]] = relationship( + back_populates="user" + ) + + +class Organization(Base): + """The tenant; owns status pages.""" + + __tablename__ = "organizations" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + slug: Mapped[str] = mapped_column( + String(50), unique=True, nullable=False, index=True + ) + name: Mapped[str] = mapped_column(String(100), nullable=False) + tier: Mapped[str] = mapped_column(String(20), nullable=False, default="free") + # "free" | "pro" | "team" + stripe_customer_id: Mapped[str | None] = mapped_column(String(100), nullable=True) + custom_domain: Mapped[str | None] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + members: Mapped[list["OrganizationMember"]] = relationship( + back_populates="organization" + ) + status_pages: Mapped[list["StatusPage"]] = relationship( + back_populates="organization" + ) + # Services linked to this org (from app.models.models.Service) + + +class OrganizationMember(Base): + """Joins users to orgs with roles.""" + + __tablename__ = "organization_members" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + organization_id: Mapped[str] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=False, index=True + ) + user_id: Mapped[str] = mapped_column( + String(36), ForeignKey("users.id"), nullable=False, index=True + ) + role: Mapped[str] = mapped_column(String(20), nullable=False, default="member") + # "owner" | "admin" | "member" + joined_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + organization: Mapped["Organization"] = relationship(back_populates="members") + user: Mapped["User"] = relationship(back_populates="memberships") + + __table_args__ = ( + {"sqlite_autoincrement": True}, + ) + + +class StatusPage(Base): + """Per-organization status page.""" + + __tablename__ = "status_pages" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str) + organization_id: Mapped[str] = mapped_column( + String(36), ForeignKey("organizations.id"), nullable=False, index=True + ) + slug: Mapped[str] = mapped_column(String(50), nullable=False, index=True) + title: Mapped[str] = mapped_column(String(100), nullable=False) + subdomain: Mapped[str | None] = mapped_column(String(100), nullable=True) + custom_domain: Mapped[str | None] = mapped_column(String(255), nullable=True) + is_public: Mapped[bool] = mapped_column(Boolean, default=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + organization: Mapped["Organization"] = relationship(back_populates="status_pages") + + __table_args__ = ( + UniqueConstraint("organization_id", "slug", name="uq_status_page_org_slug"), + ) \ No newline at end of file diff --git a/app/routes/__init__.py b/app/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/routes/auth.py b/app/routes/auth.py new file mode 100644 index 0000000..ed7fe51 --- /dev/null +++ b/app/routes/auth.py @@ -0,0 +1,122 @@ +"""Auth routes: register, login, and current user profile.""" + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel, EmailStr +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import create_access_token, get_current_user, hash_password, verify_password +from app.dependencies import get_db +from app.models.saas_models import Organization, OrganizationMember, StatusPage, User + +router = APIRouter(tags=["auth"]) + + +# ── Request / Response schemas ────────────────────────────────────────────── + + +class RegisterRequest(BaseModel): + email: EmailStr + password: str + + +class LoginRequest(BaseModel): + email: EmailStr + password: str + + +class AuthResponse(BaseModel): + access_token: str + token_type: str = "bearer" + + +class UserProfile(BaseModel): + id: str + email: str + display_name: str | None = None + is_email_verified: bool + created_at: str | None = None + + +# ── Routes ─────────────────────────────────────────────────────────────────── + + +@router.post("/auth/register", status_code=status.HTTP_201_CREATED, response_model=AuthResponse) +async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)): + """Register a new user, create a default Organization + StatusPage, return JWT.""" + # Check for existing user + result = await db.execute(select(User).where(User.email == body.email)) + if result.scalar_one_or_none() is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="A user with this email already exists", + ) + + # Create user + user = User( + email=body.email, + password_hash=hash_password(body.password), + ) + db.add(user) + await db.flush() # assign user.id + + # Create default organization + org_slug = body.email.split("@")[0].lower() + org = Organization( + name=org_slug, + slug=org_slug, + ) + db.add(org) + await db.flush() # assign org.id + + # Create organization membership (owner) + membership = OrganizationMember( + organization_id=org.id, + user_id=user.id, + role="owner", + ) + db.add(membership) + + # Create default status page + status_page = StatusPage( + organization_id=org.id, + slug="main", + title="Status Page", + ) + db.add(status_page) + + # Commit all together (the get_db dependency also commits, but flush ensures + # relationships are consistent before we return) + await db.flush() + + # Generate JWT + token = create_access_token(user.id) + return AuthResponse(access_token=token) + + +@router.post("/auth/login", response_model=AuthResponse) +async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)): + """Authenticate user with email + password, return JWT.""" + result = await db.execute(select(User).where(User.email == body.email)) + user = result.scalar_one_or_none() + + if user is None or not verify_password(body.password, user.password_hash): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid email or password", + ) + + token = create_access_token(user.id) + return AuthResponse(access_token=token) + + +@router.get("/auth/me", response_model=UserProfile) +async def me(current_user: User = Depends(get_current_user)): + """Return the current authenticated user's profile.""" + return UserProfile( + id=current_user.id, + email=current_user.email, + display_name=current_user.display_name, + is_email_verified=current_user.is_email_verified, + created_at=current_user.created_at.isoformat() if current_user.created_at else None, + ) \ No newline at end of file diff --git a/app/services/tier_limits.py b/app/services/tier_limits.py new file mode 100644 index 0000000..b26c4aa --- /dev/null +++ b/app/services/tier_limits.py @@ -0,0 +1,229 @@ +"""Tier enforcement: limits and feature flags for Free/Pro/Team plans. + +This module defines the per-tier limits and provides enforcement functions +that raise HTTP 403 when a limit is exceeded or a feature is unavailable. +""" + +from fastapi import HTTPException, status +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.saas_models import Organization, OrganizationMember, StatusPage +from app.models.models import Service, Monitor, Subscriber + + +# ── Tier limits configuration ────────────────────────────────────────────── + +TIER_LIMITS: dict[str, dict[str, int | bool]] = { + "free": { + "status_pages": 1, + "services_per_page": 5, + "monitors_per_service": 1, + "subscribers": 25, + "members": 1, + "check_interval_min": 5, # minutes + "custom_domain": False, + "custom_branding": False, + "webhooks": False, + "api_access": False, + "incident_history_days": 30, + "sla_badge": False, + "password_protection": False, + }, + "pro": { + "status_pages": 5, + "services_per_page": 50, + "monitors_per_service": 5, + "subscribers": 500, + "members": 3, + "check_interval_min": 1, # minutes + "custom_domain": True, + "custom_branding": True, + "webhooks": True, + "api_access": True, + "incident_history_days": 365, + "sla_badge": True, + "password_protection": False, + }, + "team": { + "status_pages": -1, # unlimited + "services_per_page": -1, + "monitors_per_service": -1, + "subscribers": -1, + "members": -1, + "check_interval_min": 0, # 30 seconds (0 min) + "custom_domain": True, + "custom_branding": True, + "webhooks": True, + "api_access": True, + "incident_history_days": -1, # unlimited + "sla_badge": True, + "password_protection": True, + }, +} + + +# ── Tier info helpers ─────────────────────────────────────────────────────── + +def get_tier_limits(tier: str) -> dict: + """Return the limits dict for a given tier name. Falls back to free.""" + return TIER_LIMITS.get(tier, TIER_LIMITS["free"]) + + +def get_org_limits(org: Organization) -> dict: + """Return the limits dict for an organization based on its tier.""" + return get_tier_limits(org.tier or "free") + + +def get_limit(org: Organization, feature: str): + """Return the limit value for a specific feature given the org's tier. + + Returns: + int: numeric limit (-1 means unlimited) + bool: feature flag (True/False) + None: unknown feature + """ + limits = get_org_limits(org) + return limits.get(feature) + + +# ── Enforcement ───────────────────────────────────────────────────────────── + +class TierLimitExceeded(HTTPException): + """Raised when a tier limit is exceeded.""" + + def __init__(self, feature: str, limit: int | bool): + if limit is False: + detail = f"Feature '{feature}' is not available on your current plan. Upgrade to access it." + else: + detail = ( + f"Tier limit reached for '{feature}' ({limit}). " + "Upgrade your plan to increase this limit." + ) + super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail) + + +async def enforce_limit( + db: AsyncSession, + org: Organization, + feature: str, + current_count: int, +) -> None: + """Raise TierLimitExceeded if the current count meets or exceeds the tier limit. + + Args: + db: Database session for queries. + org: The organization whose tier limits to check. + feature: The feature name (key in TIER_LIMITS). + current_count: How many of this feature the org currently has. + + Raises: + TierLimitExceeded: If the limit is reached or the feature is disabled. + """ + limit = get_limit(org, feature) + + if limit is None: + return # Unknown feature — don't block + + if limit is False: + # Feature flag: not available on this tier + raise TierLimitExceeded(feature, limit) + + if limit == -1: + return # Unlimited + + if isinstance(limit, int) and current_count >= limit: + raise TierLimitExceeded(feature, limit) + + +# ── Concrete enforcement helpers ──────────────────────────────────────────── + +async def check_status_page_limit(db: AsyncSession, org: Organization) -> None: + """Check that the org hasn't exceeded its status page limit.""" + result = await db.execute( + select(func.count(StatusPage.id)).where( + StatusPage.organization_id == org.id + ) + ) + count = result.scalar() or 0 + await enforce_limit(db, org, "status_pages", count) + + +async def check_service_limit( + db: AsyncSession, org: Organization, status_page_id: str | None = None +) -> None: + """Check that the org hasn't exceeded its services-per-page limit. + + If status_page_id is None, counts all services for the org. + """ + query = select(func.count(Service.id)).where( + Service.organization_id == org.id + ) + if status_page_id: + # In future, Service will have a status_page_id column + # For now, count all services in the org + pass + result = await db.execute(query) + count = result.scalar() or 0 + await enforce_limit(db, org, "services_per_page", count) + + +async def check_monitor_limit( + db: AsyncSession, org: Organization, service_id: str +) -> None: + """Check that the service hasn't exceeded its monitors-per-service limit.""" + result = await db.execute( + select(func.count(Monitor.id)).where(Monitor.service_id == service_id) + ) + count = result.scalar() or 0 + await enforce_limit(db, org, "monitors_per_service", count) + + +async def check_subscriber_limit(db: AsyncSession, org: Organization) -> None: + """Check that the org hasn't exceeded its subscriber limit.""" + result = await db.execute( + select(func.count(Subscriber.id)).where( + Subscriber.organization_id == org.id + ) + ) + count = result.scalar() or 0 + await enforce_limit(db, org, "subscribers", count) + + +async def check_member_limit(db: AsyncSession, org: Organization) -> None: + """Check that the org hasn't exceeded its team member limit.""" + result = await db.execute( + select(func.count(OrganizationMember.id)).where( + OrganizationMember.organization_id == org.id + ) + ) + count = result.scalar() or 0 + await enforce_limit(db, org, "members", count) + + +def enforce_feature(org: Organization, feature: str) -> None: + """Enforce a boolean feature flag. Raises TierLimitExceeded if False. + + Use this for features that are either allowed or not (e.g., custom_domain, + webhooks, api_access) without a numeric limit. + """ + limit = get_limit(org, feature) + if limit is False: + raise TierLimitExceeded(feature, False) + if limit is None: + # Unknown feature — don't block + return + + +def get_tier_info(org: Organization) -> dict: + """Return a dict of the org's current tier with limits and feature flags. + + Useful for API responses that show the org what they can and can't do. + """ + limits = get_org_limits(org) + return { + "tier": org.tier or "free", + "limits": limits, + "organization_id": org.id, + "organization_slug": org.slug, + } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 53b5d79..23279a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,10 @@ dependencies = [ "httpx>=0.27,<1.0", "typer>=0.9,<1.0", "rich>=13.0,<14.0", + "python-jose[cryptography]>=3.3,<4.0", + "passlib[bcrypt]>=1.7,<2.0", + "bcrypt==4.0.1", + "email-validator>=2.0,<3.0", ] [project.optional-dependencies] diff --git a/tests/conftest.py b/tests/conftest.py index c0493a1..3085a0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,9 @@ TestSessionLocal = async_sessionmaker( @pytest_asyncio.fixture(scope="session", autouse=True) async def setup_database(): """Create all tables once for the test session.""" + # Import SaaS models so their tables are registered on Base.metadata + import app.models.saas_models # noqa: F401 + async with test_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..3d7d0ad --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,103 @@ +"""Test Auth API endpoints.""" + +import pytest + + +REGISTER_URL = "/api/v1/auth/register" +LOGIN_URL = "/api/v1/auth/login" +ME_URL = "/api/v1/auth/me" + + +@pytest.mark.asyncio +async def test_register_new_user(client): + """Should register a new user and return 201 with a JWT token.""" + response = await client.post( + REGISTER_URL, + json={"email": "newuser@example.com", "password": "securepassword123"}, + ) + assert response.status_code == 201 + data = response.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + assert len(data["access_token"]) > 0 + + +@pytest.mark.asyncio +async def test_register_duplicate_email(client): + """Should return 409 when registering with an email that already exists.""" + # Register first user + await client.post( + REGISTER_URL, + json={"email": "duplicate@example.com", "password": "password123"}, + ) + # Try to register again with same email + response = await client.post( + REGISTER_URL, + json={"email": "duplicate@example.com", "password": "differentpassword"}, + ) + assert response.status_code == 409 + assert "already exists" in response.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_login_correct_password(client): + """Should return 200 with a JWT token on successful login.""" + # Register a user first + await client.post( + REGISTER_URL, + json={"email": "loginuser@example.com", "password": "mypassword"}, + ) + # Login with correct password + response = await client.post( + LOGIN_URL, + json={"email": "loginuser@example.com", "password": "mypassword"}, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + + +@pytest.mark.asyncio +async def test_login_wrong_password(client): + """Should return 401 when logging in with wrong password.""" + # Register a user first + await client.post( + REGISTER_URL, + json={"email": "wrongpw@example.com", "password": "correctpassword"}, + ) + # Login with wrong password + response = await client.post( + LOGIN_URL, + json={"email": "wrongpw@example.com", "password": "wrongpassword"}, + ) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_me_with_valid_token(client): + """Should return 200 with user profile when using a valid JWT token.""" + # Register a user and get token + reg_response = await client.post( + REGISTER_URL, + json={"email": "meuser@example.com", "password": "password123"}, + ) + token = reg_response.json()["access_token"] + + # Get profile with valid token + response = await client.get( + ME_URL, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["email"] == "meuser@example.com" + assert "id" in data + assert data["is_email_verified"] is False + + +@pytest.mark.asyncio +async def test_me_without_token(client): + """Should return 401 when accessing /me without a token.""" + response = await client.get(ME_URL) + assert response.status_code in (401, 403) \ No newline at end of file diff --git a/tests/test_tier_limits.py b/tests/test_tier_limits.py new file mode 100644 index 0000000..1ed316a --- /dev/null +++ b/tests/test_tier_limits.py @@ -0,0 +1,590 @@ +"""Test tier enforcement: limits, feature flags, and organization endpoints.""" + +import pytest +from sqlalchemy import select + +from app.models.saas_models import Organization, OrganizationMember, StatusPage, User +from app.models.models import Service, Monitor, Subscriber +from app.services.tier_limits import ( + TIER_LIMITS, + TierLimitExceeded, + enforce_limit, + enforce_feature, + get_limit, + get_org_limits, + get_tier_info, + get_tier_limits, + check_status_page_limit, + check_service_limit, + check_monitor_limit, + check_subscriber_limit, + check_member_limit, +) + + +# ── Unit tests for tier_limits module ──────────────────────────────────────── + +class TestTierLimitsConfig: + """Test that the TIER_LIMITS config is well-formed.""" + + def test_all_tiers_defined(self): + """All three tiers should be defined.""" + assert "free" in TIER_LIMITS + assert "pro" in TIER_LIMITS + assert "team" in TIER_LIMITS + + def test_free_tier_has_expected_keys(self): + """Free tier should have all expected limit keys.""" + free = TIER_LIMITS["free"] + expected_keys = { + "status_pages", "services_per_page", "monitors_per_service", + "subscribers", "members", "check_interval_min", + "custom_domain", "custom_branding", "webhooks", + "api_access", "incident_history_days", "sla_badge", + "password_protection", + } + assert set(free.keys()) == expected_keys + + def test_free_tier_values(self): + """Free tier should have restrictive values.""" + free = TIER_LIMITS["free"] + assert free["status_pages"] == 1 + assert free["services_per_page"] == 5 + assert free["monitors_per_service"] == 1 + assert free["subscribers"] == 25 + assert free["members"] == 1 + assert free["custom_domain"] is False + assert free["webhooks"] is False + assert free["api_access"] is False + + def test_pro_tier_values(self): + """Pro tier should have moderate values.""" + pro = TIER_LIMITS["pro"] + assert pro["status_pages"] == 5 + assert pro["services_per_page"] == 50 + assert pro["monitors_per_service"] == 5 + assert pro["subscribers"] == 500 + assert pro["members"] == 3 + assert pro["custom_domain"] is True + assert pro["webhooks"] is True + + def test_team_tier_values(self): + """Team tier should have unlimited (-1) for most things.""" + team = TIER_LIMITS["team"] + assert team["status_pages"] == -1 + assert team["services_per_page"] == -1 + assert team["monitors_per_service"] == -1 + assert team["subscribers"] == -1 + assert team["members"] == -1 + assert team["custom_domain"] is True + assert team["password_protection"] is True + + def test_all_tiers_have_same_keys(self): + """All tiers should have exactly the same set of keys.""" + keys = set(TIER_LIMITS["free"].keys()) + for tier_name, tier_data in TIER_LIMITS.items(): + assert set(tier_data.keys()) == keys, f"Tier '{tier_name}' has different keys" + + +class TestGetLimitHelpers: + """Test the helper functions.""" + + def test_get_tier_limits_known_tier(self): + """get_tier_limits should return the correct dict for known tiers.""" + assert get_tier_limits("free") == TIER_LIMITS["free"] + assert get_tier_limits("pro") == TIER_LIMITS["pro"] + assert get_tier_limits("team") == TIER_LIMITS["team"] + + def test_get_tier_limits_unknown_tier(self): + """get_tier_limits should default to free for unknown tiers.""" + assert get_tier_limits("enterprise") == TIER_LIMITS["free"] + assert get_tier_limits("") == TIER_LIMITS["free"] + + def test_get_org_limits_with_tier(self): + """get_org_limits should use the org's tier.""" + org = Organization(slug="test", name="Test", tier="pro") + assert get_org_limits(org) == TIER_LIMITS["pro"] + + def test_get_org_limits_with_none_tier(self): + """get_org_limits should default to free when tier is None.""" + org = Organization(slug="test", name="Test", tier=None) + assert get_org_limits(org) == TIER_LIMITS["free"] + + def test_get_limit_numeric(self): + """get_limit should return numeric limits.""" + org = Organization(slug="test", name="Test", tier="free") + assert get_limit(org, "status_pages") == 1 + assert get_limit(org, "subscribers") == 25 + + def test_get_limit_boolean(self): + """get_limit should return boolean feature flags.""" + org_free = Organization(slug="free-org", name="Free Org", tier="free") + org_pro = Organization(slug="pro-org", name="Pro Org", tier="pro") + assert get_limit(org_free, "custom_domain") is False + assert get_limit(org_pro, "custom_domain") is True + + def test_get_limit_unlimited(self): + """get_limit should return -1 for unlimited features.""" + org = Organization(slug="team-org", name="Team Org", tier="team") + assert get_limit(org, "status_pages") == -1 + + def test_get_limit_unknown_feature(self): + """get_limit should return None for unknown feature names.""" + org = Organization(slug="test", name="Test", tier="free") + assert get_limit(org, "nonexistent_feature") is None + + +class TestEnforceLimit: + """Test the enforce_limit function.""" + + @pytest.mark.asyncio + async def test_enforce_limit_allows_under_limit(self, db_session): + """Should not raise when current_count is below the limit.""" + org = Organization(slug="test", name="Test", tier="free") + db_session.add(org) + await db_session.flush() + # status_pages limit is 1 for free, current_count=0 should pass + await enforce_limit(db_session, org, "status_pages", 0) + + @pytest.mark.asyncio + async def test_enforce_limit_blocks_at_limit(self, db_session): + """Should raise TierLimitExceeded when current_count equals the limit.""" + org = Organization(slug="test", name="Test", tier="free") + db_session.add(org) + await db_session.flush() + # status_pages limit is 1 for free, current_count=1 should fail + with pytest.raises(TierLimitExceeded) as exc_info: + await enforce_limit(db_session, org, "status_pages", 1) + assert "status_pages" in str(exc_info.value.detail) + + @pytest.mark.asyncio + async def test_enforce_limit_blocks_over_limit(self, db_session): + """Should raise TierLimitExceeded when current_count exceeds the limit.""" + org = Organization(slug="test", name="Test", tier="free") + db_session.add(org) + await db_session.flush() + with pytest.raises(TierLimitExceeded): + await enforce_limit(db_session, org, "status_pages", 5) + + @pytest.mark.asyncio + async def test_enforce_limit_allows_unlimited(self, db_session): + """Should not raise when limit is -1 (unlimited).""" + org = Organization(slug="test", name="Test", tier="team") + db_session.add(org) + await db_session.flush() + # team has unlimited status_pages, even count=1000 should pass + await enforce_limit(db_session, org, "status_pages", 1000) + + @pytest.mark.asyncio + async def test_enforce_limit_blocks_feature_flag(self, db_session): + """Should raise TierLimitExceeded when feature is False (not available).""" + org = Organization(slug="test", name="Test", tier="free") + db_session.add(org) + await db_session.flush() + with pytest.raises(TierLimitExceeded) as exc_info: + await enforce_limit(db_session, org, "custom_domain", 0) + assert "not available" in str(exc_info.value.detail) + + @pytest.mark.asyncio + async def test_enforce_limit_unknown_feature(self, db_session): + """Should not raise for an unknown feature (no limit defined).""" + org = Organization(slug="test", name="Test", tier="free") + db_session.add(org) + await db_session.flush() + # Should not raise for an unknown feature + await enforce_limit(db_session, org, "nonexistent", 999) + + @pytest.mark.asyncio + async def test_enforce_limit_pro_tier(self, db_session): + """Pro tier allows up to 5 status pages.""" + org = Organization(slug="pro-org", name="Pro Org", tier="pro") + db_session.add(org) + await db_session.flush() + # 5 is the limit, should fail at 5 + with pytest.raises(TierLimitExceeded): + await enforce_limit(db_session, org, "status_pages", 5) + # 4 should pass + await enforce_limit(db_session, org, "status_pages", 4) + + +class TestEnforceFeature: + """Test the enforce_feature boolean feature flag function.""" + + def test_enforce_feature_allows_pro_custom_domain(self): + """Pro org should be allowed custom_domain.""" + org = Organization(slug="pro", name="Pro", tier="pro") + # Should not raise + enforce_feature(org, "custom_domain") + + def test_enforce_feature_blocks_free_custom_domain(self): + """Free org should be blocked from custom_domain.""" + org = Organization(slug="free", name="Free", tier="free") + with pytest.raises(TierLimitExceeded): + enforce_feature(org, "custom_domain") + + def test_enforce_feature_blocks_free_webhooks(self): + """Free org should be blocked from webhooks.""" + org = Organization(slug="free", name="Free", tier="free") + with pytest.raises(TierLimitExceeded): + enforce_feature(org, "webhooks") + + def test_enforce_feature_blocks_free_api_access(self): + """Free org should be blocked from api_access.""" + org = Organization(slug="free", name="Free", tier="free") + with pytest.raises(TierLimitExceeded): + enforce_feature(org, "api_access") + + def test_enforce_feature_blocks_free_password_protection(self): + """Free org should be blocked from password_protection.""" + org = Organization(slug="free", name="Free", tier="free") + with pytest.raises(TierLimitExceeded): + enforce_feature(org, "password_protection") + + def test_enforce_feature_allows_pro_webhooks(self): + """Pro org should be allowed webhooks.""" + org = Organization(slug="pro", name="Pro", tier="pro") + enforce_feature(org, "webhooks") + + def test_enforce_feature_unknown_feature(self): + """Unknown feature should not raise.""" + org = Organization(slug="free", name="Free", tier="free") + # Should not raise for unknown + enforce_feature(org, "quantum_computing") + + def test_enforce_feature_team_password_protection(self): + """Team org should be allowed password_protection.""" + org = Organization(slug="team", name="Team", tier="team") + enforce_feature(org, "password_protection") + + +class TestGetTierInfo: + """Test get_tier_info helper.""" + + def test_free_org_tier_info(self): + """Free org should return complete tier info.""" + org = Organization(id="org-1", slug="free-org", name="Free Org", tier="free") + info = get_tier_info(org) + assert info["tier"] == "free" + assert info["organization_id"] == "org-1" + assert info["limits"]["status_pages"] == 1 + assert info["limits"]["custom_domain"] is False + + def test_pro_org_tier_info(self): + """Pro org should return correct tier info.""" + org = Organization(id="org-2", slug="pro-org", name="Pro Org", tier="pro") + info = get_tier_info(org) + assert info["tier"] == "pro" + assert info["limits"]["status_pages"] == 5 + assert info["limits"]["custom_domain"] is True + + def test_default_tier_info(self): + """Org with None tier should default to free.""" + org = Organization(id="org-3", slug="default-org", name="Default Org", tier=None) + info = get_tier_info(org) + assert info["tier"] == "free" + assert info["limits"]["status_pages"] == 1 + + +# ── Integration tests for database-backed limit checks ─────────────────────── + +class TestCheckStatusPageLimit: + """Test check_status_page_limit with real database.""" + + @pytest.mark.asyncio + async def test_free_org_can_create_first_page(self, db_session): + """Free org should be allowed to create its first status page.""" + org = Organization(slug="free-org", name="Free Org", tier="free") + db_session.add(org) + await db_session.flush() + # Should not raise (0 existing pages, limit is 1) + await check_status_page_limit(db_session, org) + + @pytest.mark.asyncio + async def test_free_org_blocked_at_second_page(self, db_session): + """Free org should be blocked from creating a second status page.""" + org = Organization(slug="free-org", name="Free Org", tier="free") + db_session.add(org) + await db_session.flush() + # Create first page + page = StatusPage(organization_id=org.id, slug="main", title="Main") + db_session.add(page) + await db_session.flush() + # Now limit should be reached + with pytest.raises(TierLimitExceeded): + await check_status_page_limit(db_session, org) + + @pytest.mark.asyncio + async def test_pro_org_allows_up_to_five_pages(self, db_session): + """Pro org should be allowed up to 5 status pages.""" + org = Organization(slug="pro-org", name="Pro Org", tier="pro") + db_session.add(org) + await db_session.flush() + # Create 4 pages — should be fine (limit is 5) + for i in range(4): + page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}") + db_session.add(page) + await db_session.flush() + await check_status_page_limit(db_session, org) # should not raise + + @pytest.mark.asyncio + async def test_pro_org_blocked_at_sixth_page(self, db_session): + """Pro org should be blocked at the 6th status page.""" + org = Organization(slug="pro-org", name="Pro Org", tier="pro") + db_session.add(org) + await db_session.flush() + # Create 5 pages (at the limit) + for i in range(5): + page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}") + db_session.add(page) + await db_session.flush() + with pytest.raises(TierLimitExceeded): + await check_status_page_limit(db_session, org) + + @pytest.mark.asyncio + async def test_team_org_unlimited_pages(self, db_session): + """Team org should never be blocked for status pages.""" + org = Organization(slug="team-org", name="Team Org", tier="team") + db_session.add(org) + await db_session.flush() + for i in range(50): + page = StatusPage(organization_id=org.id, slug=f"page-{i}", title=f"Page {i}") + db_session.add(page) + await db_session.flush() + # Should not raise even with 50 pages + await check_status_page_limit(db_session, org) + + +class TestCheckMemberLimit: + """Test check_member_limit with real database.""" + + @pytest.mark.asyncio + async def test_free_org_blocked_at_second_member(self, db_session): + """Free org should be blocked from adding a second member.""" + org = Organization(slug="free-org", name="Free Org", tier="free") + db_session.add(org) + await db_session.flush() + # Create one user + user = User(email="owner@example.com", password_hash="hash") + db_session.add(user) + await db_session.flush() + # Add owner membership + membership = OrganizationMember( + organization_id=org.id, user_id=user.id, role="owner" + ) + db_session.add(membership) + await db_session.flush() + # Now at limit — should be blocked + with pytest.raises(TierLimitExceeded): + await check_member_limit(db_session, org) + + @pytest.mark.asyncio + async def test_free_org_allows_one_member(self, db_session): + """Free org with 0 members should be allowed to add one.""" + org = Organization(slug="free-org", name="Free Org", tier="free") + db_session.add(org) + await db_session.flush() + # 0 members, limit is 1 — should pass + await check_member_limit(db_session, org) + + +class TestCheckServiceLimit: + """Test check_service_limit with real database.""" + + @pytest.mark.asyncio + async def test_free_org_can_create_first_services(self, db_session): + """Free org should be allowed to create services up to limit.""" + org = Organization(slug="free-org", name="Free Org", tier="free") + db_session.add(org) + await db_session.flush() + # Create 4 services (limit is 5) + for i in range(4): + svc = Service( + name=f"svc-{i}", slug=f"svc-{i}", organization_id=org.id + ) + db_session.add(svc) + await db_session.flush() + # Should not raise + await check_service_limit(db_session, org) + + @pytest.mark.asyncio + async def test_free_org_blocked_at_sixth_service(self, db_session): + """Free org should be blocked at the 6th service.""" + org = Organization(slug="free-org", name="Free Org", tier="free") + db_session.add(org) + await db_session.flush() + # Create 5 services (at limit) + for i in range(5): + svc = Service( + name=f"svc-{i}", slug=f"svc-{i}", organization_id=org.id + ) + db_session.add(svc) + await db_session.flush() + with pytest.raises(TierLimitExceeded): + await check_service_limit(db_session, org) + + +# ── API integration tests ────────────────────────────────────────────────────── + +REGISTER_URL = "/api/v1/auth/register" +LOGIN_URL = "/api/v1/auth/login" +ME_URL = "/api/v1/auth/me" +ORGS_TIERS_URL = "/api/v1/organizations/tiers" +ORGS_MY_URL = "/api/v1/organizations/my" +ORGS_MY_LIMITS_URL = "/api/v1/organizations/my/limits" +ORGS_MY_TIER_URL = "/api/v1/organizations/my/tier" + + +class TestTierAPIEndpoints: + """Test the organization/tier API endpoints.""" + + @pytest.mark.asyncio + async def test_list_tiers_public(self, client): + """List tiers endpoint should be accessible without auth.""" + response = await client.get(ORGS_TIERS_URL) + assert response.status_code == 200 + data = response.json() + assert "tiers" in data + assert len(data["tiers"]) == 3 + tier_names = [t["name"] for t in data["tiers"]] + assert "free" in tier_names + assert "pro" in tier_names + assert "team" in tier_names + + @pytest.mark.asyncio + async def test_get_my_org(self, client): + """Authenticated user should be able to get their org info.""" + # Register + reg_response = await client.post( + REGISTER_URL, + json={"email": "orguser@example.com", "password": "testpass123"}, + ) + assert reg_response.status_code == 201 + token = reg_response.json()["access_token"] + + # Get org info + response = await client.get( + ORGS_MY_URL, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "free" + assert "tier_info" in data + assert data["tier_info"]["tier"] == "free" + assert data["member_count"] >= 1 + + @pytest.mark.asyncio + async def test_get_my_org_unauthorized(self, client): + """Unauthenticated request should be rejected.""" + response = await client.get(ORGS_MY_URL) + assert response.status_code in (401, 403) + + @pytest.mark.asyncio + async def test_get_my_limits(self, client): + """Authenticated user should be able to see their tier limits.""" + reg_response = await client.post( + REGISTER_URL, + json={"email": "limitsuser@example.com", "password": "testpass123"}, + ) + token = reg_response.json()["access_token"] + + response = await client.get( + ORGS_MY_LIMITS_URL, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "free" + assert data["limits"]["status_pages"] == 1 + assert data["limits"]["custom_domain"] is False + + @pytest.mark.asyncio + async def test_upgrade_tier_to_pro(self, client): + """Org owner should be able to upgrade to Pro.""" + reg_response = await client.post( + REGISTER_URL, + json={"email": "upgradeuser@example.com", "password": "testpass123"}, + ) + token = reg_response.json()["access_token"] + + # Upgrade to pro + response = await client.patch( + ORGS_MY_TIER_URL, + json={"tier": "pro"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "pro" + assert data["tier_info"]["limits"]["status_pages"] == 5 + assert data["tier_info"]["limits"]["custom_domain"] is True + + @pytest.mark.asyncio + async def test_upgrade_tier_to_team(self, client): + """Org owner should be able to upgrade to Team.""" + reg_response = await client.post( + REGISTER_URL, + json={"email": "teamuser@example.com", "password": "testpass123"}, + ) + token = reg_response.json()["access_token"] + + response = await client.patch( + ORGS_MY_TIER_URL, + json={"tier": "team"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "team" + assert data["tier_info"]["limits"]["status_pages"] == -1 + + @pytest.mark.asyncio + async def test_upgrade_tier_invalid(self, client): + """Upgrading to an invalid tier should be rejected.""" + reg_response = await client.post( + REGISTER_URL, + json={"email": "invalidtier@example.com", "password": "testpass123"}, + ) + token = reg_response.json()["access_token"] + + response = await client.patch( + ORGS_MY_TIER_URL, + json={"tier": "enterprise"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code in (403, 422) + + @pytest.mark.asyncio + async def test_downgrade_back_to_free(self, client): + """Org should be able to downgrade back to free.""" + reg_response = await client.post( + REGISTER_URL, + json={"email": "downgrade@example.com", "password": "testpass123"}, + ) + token = reg_response.json()["access_token"] + + # Upgrade to pro first + await client.patch( + ORGS_MY_TIER_URL, + json={"tier": "pro"}, + headers={"Authorization": f"Bearer {token}"}, + ) + + # Downgrade back to free + response = await client.patch( + ORGS_MY_TIER_URL, + json={"tier": "free"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "free" + + @pytest.mark.asyncio + async def test_tier_upgrade_unauthorized(self, client): + """Unauthenticated tier upgrade should be rejected.""" + response = await client.patch( + ORGS_MY_TIER_URL, + json={"tier": "pro"}, + ) + assert response.status_code in (401, 403) \ No newline at end of file