"""Test tier enforcement: limits, feature flags, and organization endpoints.

Covers three layers:

* the static ``TIER_LIMITS`` configuration and its lookup helpers,
* the ``enforce_*`` / ``check_*`` guards (``check_*`` using the ``db_session``
  fixture), and
* the ``/api/v1/organizations`` tier endpoints, via the ``client`` fixture.
"""

import pytest
from sqlalchemy import select  # noqa: F401

from app.models.saas_models import Organization, OrganizationMember, StatusPage, User
from app.models.models import Service, Monitor, Subscriber  # noqa: F401
from app.services.tier_limits import (  # noqa: F401
    TIER_LIMITS,
    TierLimitExceeded,
    enforce_limit,
    enforce_feature,
    get_limit,
    get_org_limits,
    get_tier_info,
    get_tier_limits,
    check_status_page_limit,
    check_service_limit,
    check_monitor_limit,
    check_subscriber_limit,
    check_member_limit,
)

# NOTE(review): select, Monitor, Subscriber, check_monitor_limit and
# check_subscriber_limit are imported but not exercised by any test in this
# module yet. TODO: add coverage for the monitor/subscriber guards.


# ── Unit tests for tier_limits module ────────────────────────────────────────


class TestTierLimitsConfig:
    """Sanity checks on the shape and values of the ``TIER_LIMITS`` table."""

    @staticmethod
    def _assert_limits(limits, expected):
        """Check every entry of ``expected`` against ``limits``.

        Boolean flags are compared by identity (``is``) so a numeric ``0``/``1``
        cannot stand in for ``False``/``True``; all other values by equality.
        The offending key is used as the assertion message.
        """
        for key, want in expected.items():
            if isinstance(want, bool):
                assert limits[key] is want, key
            else:
                assert limits[key] == want, key

    def test_all_tiers_defined(self):
        """Each of the free, pro and team tiers has an entry."""
        for tier in ("free", "pro", "team"):
            assert tier in TIER_LIMITS

    def test_free_tier_has_expected_keys(self):
        """The free tier exposes exactly the known set of limit keys."""
        assert set(TIER_LIMITS["free"]) == {
            "status_pages",
            "services_per_page",
            "monitors_per_service",
            "subscribers",
            "members",
            "check_interval_min",
            "custom_domain",
            "custom_branding",
            "webhooks",
            "api_access",
            "incident_history_days",
            "sla_badge",
            "password_protection",
        }

    def test_free_tier_values(self):
        """Free is the most restrictive tier: small quotas, paid features off."""
        self._assert_limits(
            TIER_LIMITS["free"],
            {
                "status_pages": 1,
                "services_per_page": 5,
                "monitors_per_service": 1,
                "subscribers": 25,
                "members": 1,
                "custom_domain": False,
                "webhooks": False,
                "api_access": False,
            },
        )

    def test_pro_tier_values(self):
        """Pro raises the quotas and unlocks custom domains and webhooks."""
        self._assert_limits(
            TIER_LIMITS["pro"],
            {
                "status_pages": 5,
                "services_per_page": 50,
                "monitors_per_service": 5,
                "subscribers": 500,
                "members": 3,
                "custom_domain": True,
                "webhooks": True,
            },
        )

    def test_team_tier_values(self):
        """Team uses -1 (unlimited) for the countable quotas."""
        self._assert_limits(
            TIER_LIMITS["team"],
            {
                "status_pages": -1,
                "services_per_page": -1,
                "monitors_per_service": -1,
                "subscribers": -1,
                "members": -1,
                "custom_domain": True,
                "password_protection": True,
            },
        )

    def test_all_tiers_have_same_keys(self):
        """Every tier defines the same keys as the free tier."""
        reference = set(TIER_LIMITS["free"])
        for tier_name, tier_data in TIER_LIMITS.items():
            assert set(tier_data) == reference, f"Tier '{tier_name}' has different keys"


class TestGetLimitHelpers:
    """Tests for the pure lookup helpers in ``tier_limits``."""

    def test_get_tier_limits_known_tier(self):
        """Known tier names map straight onto their ``TIER_LIMITS`` entry."""
        for tier in ("free", "pro", "team"):
            assert get_tier_limits(tier) == TIER_LIMITS[tier]

    def test_get_tier_limits_unknown_tier(self):
        """Unrecognized tier names, including the empty string, fall back to free."""
        for tier in ("enterprise", ""):
            assert get_tier_limits(tier) == TIER_LIMITS["free"]

    def test_get_org_limits_with_tier(self):
        """``get_org_limits`` follows the organization's ``tier`` attribute."""
        pro_org = Organization(slug="test", name="Test", tier="pro")
        assert get_org_limits(pro_org) == TIER_LIMITS["pro"]

    def test_get_org_limits_with_none_tier(self):
        """A ``None`` tier is treated as free."""
        untiered_org = Organization(slug="test", name="Test", tier=None)
        assert get_org_limits(untiered_org) == TIER_LIMITS["free"]

    def test_get_limit_numeric(self):
        """Numeric quotas are returned as-is."""
        free_org = Organization(slug="test", name="Test", tier="free")
        assert get_limit(free_org, "status_pages") == 1
        assert get_limit(free_org, "subscribers") == 25

    def test_get_limit_boolean(self):
        """Feature flags are returned as real booleans."""
        free_org = Organization(slug="free-org", name="Free Org", tier="free")
        pro_org = Organization(slug="pro-org", name="Pro Org", tier="pro")
        assert get_limit(free_org, "custom_domain") is False
        assert get_limit(pro_org, "custom_domain") is True

    def test_get_limit_unlimited(self):
        """Unlimited quotas are reported as -1."""
        team_org = Organization(slug="team-org", name="Team Org", tier="team")
        assert get_limit(team_org, "status_pages") == -1

    def test_get_limit_unknown_feature(self):
        """Feature names missing from the table yield ``None``."""
        free_org = Organization(slug="test", name="Test", tier="free")
        assert get_limit(free_org, "nonexistent_feature") is None


class TestEnforceLimit:
    """Tests for the count-based ``enforce_limit`` guard."""

    @staticmethod
    async def _persist_org(session, slug, name, tier):
        """Create an organization, add it to ``session``, flush, and return it."""
        org = Organization(slug=slug, name=name, tier=tier)
        session.add(org)
        await session.flush()
        return org

    @pytest.mark.asyncio
    async def test_enforce_limit_allows_under_limit(self, db_session):
        """A current count below the limit is accepted."""
        org = await self._persist_org(db_session, "test", "Test", "free")
        # Free allows one status page; with none yet, adding one is fine.
        await enforce_limit(db_session, org, "status_pages", 0)

    @pytest.mark.asyncio
    async def test_enforce_limit_blocks_at_limit(self, db_session):
        """A current count equal to the limit is rejected, naming the feature."""
        org = await self._persist_org(db_session, "test", "Test", "free")
        # One page already exists and the free limit is one.
        with pytest.raises(TierLimitExceeded) as exc_info:
            await enforce_limit(db_session, org, "status_pages", 1)
        assert "status_pages" in str(exc_info.value.detail)

    @pytest.mark.asyncio
    async def test_enforce_limit_blocks_over_limit(self, db_session):
        """A current count already past the limit is rejected."""
        org = await self._persist_org(db_session, "test", "Test", "free")
        with pytest.raises(TierLimitExceeded):
            await enforce_limit(db_session, org, "status_pages", 5)

    @pytest.mark.asyncio
    async def test_enforce_limit_allows_unlimited(self, db_session):
        """A limit of -1 never blocks, however large the count."""
        org = await self._persist_org(db_session, "test", "Test", "team")
        await enforce_limit(db_session, org, "status_pages", 1000)

    @pytest.mark.asyncio
    async def test_enforce_limit_blocks_feature_flag(self, db_session):
        """A feature whose tier value is False is reported as not available."""
        org = await self._persist_org(db_session, "test", "Test", "free")
        with pytest.raises(TierLimitExceeded) as exc_info:
            await enforce_limit(db_session, org, "custom_domain", 0)
        assert "not available" in str(exc_info.value.detail)

    @pytest.mark.asyncio
    async def test_enforce_limit_unknown_feature(self, db_session):
        """Features with no configured limit are never blocked."""
        org = await self._persist_org(db_session, "test", "Test", "free")
        await enforce_limit(db_session, org, "nonexistent", 999)

    @pytest.mark.asyncio
    async def test_enforce_limit_pro_tier(self, db_session):
        """Pro allows up to 5 status pages: count 5 blocks, count 4 passes."""
        org = await self._persist_org(db_session, "pro-org", "Pro Org", "pro")
        with pytest.raises(TierLimitExceeded):
            await enforce_limit(db_session, org, "status_pages", 5)
        await enforce_limit(db_session, org, "status_pages", 4)


class TestEnforceFeature:
    """Tests for the boolean ``enforce_feature`` gate."""

    @staticmethod
    def _free_org():
        """Build a free-tier organization not attached to any session."""
        return Organization(slug="free", name="Free", tier="free")

    def test_enforce_feature_allows_pro_custom_domain(self):
        """Pro organizations may use custom domains (no exception expected)."""
        pro_org = Organization(slug="pro", name="Pro", tier="pro")
        enforce_feature(pro_org, "custom_domain")

    def test_enforce_feature_blocks_free_custom_domain(self):
        """Free organizations are refused custom domains."""
        free_org = self._free_org()
        with pytest.raises(TierLimitExceeded):
            enforce_feature(free_org, "custom_domain")

    def test_enforce_feature_blocks_free_webhooks(self):
        """Free organizations are refused webhooks."""
        free_org = self._free_org()
        with pytest.raises(TierLimitExceeded):
            enforce_feature(free_org, "webhooks")

    def test_enforce_feature_blocks_free_api_access(self):
        """Free organizations are refused API access."""
        free_org = self._free_org()
        with pytest.raises(TierLimitExceeded):
            enforce_feature(free_org, "api_access")

    def test_enforce_feature_blocks_free_password_protection(self):
        """Free organizations are refused password protection."""
        free_org = self._free_org()
        with pytest.raises(TierLimitExceeded):
            enforce_feature(free_org, "password_protection")

    def test_enforce_feature_allows_pro_webhooks(self):
        """Pro organizations may use webhooks."""
        pro_org = Organization(slug="pro", name="Pro", tier="pro")
        enforce_feature(pro_org, "webhooks")

    def test_enforce_feature_unknown_feature(self):
        """Features absent from the tier table are not gated."""
        free_org = self._free_org()
        enforce_feature(free_org, "quantum_computing")

    def test_enforce_feature_team_password_protection(self):
        """Team organizations may use password protection."""
        team_org = Organization(slug="team", name="Team", tier="team")
        enforce_feature(team_org, "password_protection")


class TestGetTierInfo:
    """Tests for the ``get_tier_info`` summary helper."""

    def test_free_org_tier_info(self):
        """A free org's summary carries its tier, id and free limits."""
        summary = get_tier_info(
            Organization(id="org-1", slug="free-org", name="Free Org", tier="free")
        )
        assert summary["tier"] == "free"
        assert summary["organization_id"] == "org-1"
        limits = summary["limits"]
        assert limits["status_pages"] == 1
        assert limits["custom_domain"] is False

    def test_pro_org_tier_info(self):
        """A pro org's summary reports pro limits."""
        summary = get_tier_info(
            Organization(id="org-2", slug="pro-org", name="Pro Org", tier="pro")
        )
        assert summary["tier"] == "pro"
        limits = summary["limits"]
        assert limits["status_pages"] == 5
        assert limits["custom_domain"] is True

    def test_default_tier_info(self):
        """An org whose tier is None is summarized as free."""
        summary = get_tier_info(
            Organization(id="org-3", slug="default-org", name="Default Org", tier=None)
        )
        assert summary["tier"] == "free"
        assert summary["limits"]["status_pages"] == 1


# ── Integration tests for database-backed limit checks ───────────────────────


class TestCheckStatusPageLimit:
    """``check_status_page_limit`` counted against rows in the test database."""

    @staticmethod
    async def _persist_org(session, slug, name, tier):
        """Create an organization, add it to ``session``, flush, and return it."""
        org = Organization(slug=slug, name=name, tier=tier)
        session.add(org)
        await session.flush()
        return org

    @staticmethod
    async def _add_numbered_pages(session, org, count):
        """Add ``count`` pages slugged ``page-0``, ``page-1``, ... and flush them."""
        for n in range(count):
            session.add(
                StatusPage(organization_id=org.id, slug=f"page-{n}", title=f"Page {n}")
            )
        await session.flush()

    @pytest.mark.asyncio
    async def test_free_org_can_create_first_page(self, db_session):
        """With no pages yet, a free org may create one."""
        org = await self._persist_org(db_session, "free-org", "Free Org", "free")
        await check_status_page_limit(db_session, org)

    @pytest.mark.asyncio
    async def test_free_org_blocked_at_second_page(self, db_session):
        """Once its single page exists, a free org is blocked."""
        org = await self._persist_org(db_session, "free-org", "Free Org", "free")
        db_session.add(StatusPage(organization_id=org.id, slug="main", title="Main"))
        await db_session.flush()

        with pytest.raises(TierLimitExceeded):
            await check_status_page_limit(db_session, org)

    @pytest.mark.asyncio
    async def test_pro_org_allows_up_to_five_pages(self, db_session):
        """A pro org holding 4 pages may still create a fifth."""
        org = await self._persist_org(db_session, "pro-org", "Pro Org", "pro")
        await self._add_numbered_pages(db_session, org, 4)
        await check_status_page_limit(db_session, org)

    @pytest.mark.asyncio
    async def test_pro_org_blocked_at_sixth_page(self, db_session):
        """A pro org already at 5 pages is blocked from a sixth."""
        org = await self._persist_org(db_session, "pro-org", "Pro Org", "pro")
        await self._add_numbered_pages(db_session, org, 5)

        with pytest.raises(TierLimitExceeded):
            await check_status_page_limit(db_session, org)

    @pytest.mark.asyncio
    async def test_team_org_unlimited_pages(self, db_session):
        """A team org is never blocked, even with 50 pages."""
        org = await self._persist_org(db_session, "team-org", "Team Org", "team")
        await self._add_numbered_pages(db_session, org, 50)
        await check_status_page_limit(db_session, org)


class TestCheckMemberLimit:
    """``check_member_limit`` counted against real membership rows."""

    @staticmethod
    async def _persist_free_org(session):
        """Create the free-tier test organization, flush it, and return it."""
        org = Organization(slug="free-org", name="Free Org", tier="free")
        session.add(org)
        await session.flush()
        return org

    @pytest.mark.asyncio
    async def test_free_org_blocked_at_second_member(self, db_session):
        """After the owner joins, a free org cannot add anyone else."""
        org = await self._persist_free_org(db_session)

        owner = User(email="owner@example.com", password_hash="hash")
        db_session.add(owner)
        # Flush before owner.id is read for the membership row below.
        await db_session.flush()

        db_session.add(
            OrganizationMember(organization_id=org.id, user_id=owner.id, role="owner")
        )
        await db_session.flush()

        # The single free-tier seat is now taken.
        with pytest.raises(TierLimitExceeded):
            await check_member_limit(db_session, org)

    @pytest.mark.asyncio
    async def test_free_org_allows_one_member(self, db_session):
        """A free org with no members yet may add its first one."""
        org = await self._persist_free_org(db_session)
        await check_member_limit(db_session, org)


class TestCheckServiceLimit:
    """``check_service_limit`` counted against rows in the test database."""

    @staticmethod
    async def _persist_free_org(session):
        """Create the free-tier test organization, flush it, and return it."""
        org = Organization(slug="free-org", name="Free Org", tier="free")
        session.add(org)
        await session.flush()
        return org

    @staticmethod
    async def _add_services(session, org, count):
        """Add ``count`` services named/slugged ``svc-0``, ``svc-1``, ... and flush."""
        for n in range(count):
            session.add(
                Service(name=f"svc-{n}", slug=f"svc-{n}", organization_id=org.id)
            )
        await session.flush()

    @pytest.mark.asyncio
    async def test_free_org_can_create_first_services(self, db_session):
        """With 4 of 5 free services used, another may still be created."""
        org = await self._persist_free_org(db_session)
        await self._add_services(db_session, org, 4)
        await check_service_limit(db_session, org)

    @pytest.mark.asyncio
    async def test_free_org_blocked_at_sixth_service(self, db_session):
        """With all 5 free services used, a sixth is rejected."""
        org = await self._persist_free_org(db_session)
        await self._add_services(db_session, org, 5)

        with pytest.raises(TierLimitExceeded):
            await check_service_limit(db_session, org)


# ── API integration tests ──────────────────────────────────────────────────────

# Auth endpoints. LOGIN_URL and ME_URL are defined for completeness; the
# tests in this module only register.
REGISTER_URL = "/api/v1/auth/register"
LOGIN_URL = "/api/v1/auth/login"
ME_URL = "/api/v1/auth/me"

# Organization / tier endpoints.
ORGS_TIERS_URL = "/api/v1/organizations/tiers"  # tier catalogue, no auth needed
ORGS_MY_URL = "/api/v1/organizations/my"  # the caller's organization
ORGS_MY_LIMITS_URL = "/api/v1/organizations/my/limits"  # the caller's limits
ORGS_MY_TIER_URL = "/api/v1/organizations/my/tier"  # PATCH to change tier


class TestTierAPIEndpoints:
    """End-to-end tests for the organization/tier HTTP endpoints."""

    @staticmethod
    async def _register(client, email):
        """Register a new user and return their access token.

        Asserts the 201 response up front so a failed registration is reported
        as such, rather than surfacing later as a ``KeyError`` on
        ``access_token``.
        """
        response = await client.post(
            REGISTER_URL,
            json={"email": email, "password": "testpass123"},
        )
        assert response.status_code == 201, response.text
        return response.json()["access_token"]

    @staticmethod
    def _auth(token):
        """Return the bearer-token ``Authorization`` header for ``token``."""
        return {"Authorization": f"Bearer {token}"}

    @staticmethod
    async def _set_tier(client, token, tier):
        """PATCH the caller's organization to ``tier`` and return the response."""
        return await client.patch(
            ORGS_MY_TIER_URL,
            json={"tier": tier},
            headers={"Authorization": f"Bearer {token}"},
        )

    @pytest.mark.asyncio
    async def test_list_tiers_public(self, client):
        """The tier catalogue is public and lists exactly free, pro and team."""
        response = await client.get(ORGS_TIERS_URL)
        assert response.status_code == 200
        data = response.json()
        assert "tiers" in data
        assert len(data["tiers"]) == 3
        tier_names = [t["name"] for t in data["tiers"]]
        assert "free" in tier_names
        assert "pro" in tier_names
        assert "team" in tier_names

    @pytest.mark.asyncio
    async def test_get_my_org(self, client):
        """A freshly registered user sees their free-tier organization."""
        token = await self._register(client, "orguser@example.com")

        response = await client.get(ORGS_MY_URL, headers=self._auth(token))
        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "free"
        assert "tier_info" in data
        assert data["tier_info"]["tier"] == "free"
        assert data["member_count"] >= 1

    @pytest.mark.asyncio
    async def test_get_my_org_unauthorized(self, client):
        """Requests without a token are rejected."""
        response = await client.get(ORGS_MY_URL)
        assert response.status_code in (401, 403)

    @pytest.mark.asyncio
    async def test_get_my_limits(self, client):
        """A registered user can read their (free) tier limits."""
        token = await self._register(client, "limitsuser@example.com")

        response = await client.get(ORGS_MY_LIMITS_URL, headers=self._auth(token))
        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "free"
        assert data["limits"]["status_pages"] == 1
        assert data["limits"]["custom_domain"] is False

    @pytest.mark.asyncio
    async def test_upgrade_tier_to_pro(self, client):
        """The org owner can upgrade to pro and gets pro limits back."""
        token = await self._register(client, "upgradeuser@example.com")

        response = await self._set_tier(client, token, "pro")
        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "pro"
        assert data["tier_info"]["limits"]["status_pages"] == 5
        assert data["tier_info"]["limits"]["custom_domain"] is True

    @pytest.mark.asyncio
    async def test_upgrade_tier_to_team(self, client):
        """The org owner can upgrade to team and gets unlimited pages."""
        token = await self._register(client, "teamuser@example.com")

        response = await self._set_tier(client, token, "team")
        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "team"
        assert data["tier_info"]["limits"]["status_pages"] == -1

    @pytest.mark.asyncio
    async def test_upgrade_tier_invalid(self, client):
        """An unknown tier name is rejected."""
        token = await self._register(client, "invalidtier@example.com")

        response = await self._set_tier(client, token, "enterprise")
        assert response.status_code in (403, 422)

    @pytest.mark.asyncio
    async def test_downgrade_back_to_free(self, client):
        """An org upgraded to pro can downgrade back to free."""
        token = await self._register(client, "downgrade@example.com")

        # Verify the upgrade actually took effect; otherwise "downgrading" an org
        # that is still free would pass without exercising the downgrade path.
        upgrade = await self._set_tier(client, token, "pro")
        assert upgrade.status_code == 200
        assert upgrade.json()["tier"] == "pro"

        response = await self._set_tier(client, token, "free")
        assert response.status_code == 200
        data = response.json()
        assert data["tier"] == "free"
        # Limits must follow the tier back down, mirroring the upgrade tests.
        assert data["tier_info"]["limits"]["status_pages"] == 1
        assert data["tier_info"]["limits"]["custom_domain"] is False

    @pytest.mark.asyncio
    async def test_tier_upgrade_unauthorized(self, client):
        """Tier changes without a token are rejected."""
        response = await client.patch(
            ORGS_MY_TIER_URL,
            json={"tier": "pro"},
        )
        assert response.status_code in (401, 403)