feat: indie status page MVP -- FastAPI + SQLite

- 8 DB models (services, incidents, monitors, subscribers, etc.)
- Full CRUD API for services, incidents, monitors
- Public status page with live data
- Incident detail page with timeline
- API key authentication
- Uptime monitoring scheduler
- 13 tests passing
- TECHNICAL_DESIGN.md with full spec
This commit is contained in:
IndieStatusBot 2026-04-25 05:00:00 +00:00
commit 902133edd3
4655 changed files with 1342691 additions and 0 deletions

27
.env.example Normal file
View file

@ -0,0 +1,27 @@
# Indie Status Page Settings (copy to .env and fill in)
# App
APP_NAME=Indie Status Page
DATABASE_URL=sqlite+aiosqlite:///./data/statuspage.db
SECRET_KEY=change-me-to-a-random-string
ADMIN_API_KEY=change-me-to-a-secure-api-key
DEBUG=true
# Site
SITE_NAME=My SaaS Status
SITE_URL=http://localhost:8000
SITE_LOGO_URL=
SITE_ACCENT_COLOR=#4f46e5
# SMTP (optional - leave blank to disable email)
SMTP_HOST=
SMTP_PORT=587
SMTP_USER=
SMTP_PASS=
SMTP_FROM=noreply@example.com
# Webhook (optional - leave blank to disable)
WEBHOOK_NOTIFY_URL=
# Uptime Monitoring
MONITOR_CHECK_INTERVAL=60

9
.gitignore vendored Normal file
View file

@ -0,0 +1,9 @@
.venv/
__pycache__/
*.pyc
data/
*.egg-info/
dist/
build/
.pytest_cache/
.env

25
Dockerfile Normal file
View file

@ -0,0 +1,25 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Copy project files
COPY pyproject.toml .
COPY app/ app/
COPY migrations/ migrations/
# Install Python dependencies
RUN pip install --no-cache-dir .
# Create data directory for SQLite
RUN mkdir -p /app/data
# Expose port
EXPOSE 8000
# Run with uvicorn
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

12
README.md Normal file
View file

@ -0,0 +1,12 @@
# Indie Status Page
Lightweight, self-hosted status page tool for indie SaaS developers.
## Quick Start
```bash
pip install -e ".[dev]"
uvicorn app.main:app --reload
```
See [TECHNICAL_DESIGN.md](TECHNICAL_DESIGN.md) for full documentation.

384
TECHNICAL_DESIGN.md Normal file
View file

@ -0,0 +1,384 @@
# Indie Status Page — Technical Design Document
## 1. Architecture Overview
Indie Status Page is a self-hosted, lightweight status page tool built for indie SaaS developers. It follows a minimal monolithic architecture with three concerns:
1. **REST API** — FastAPI endpoints for programmatic incident & service management
2. **Public Status Page** — Jinja2-rendered HTML pages (SSR, no JS framework)
3. **Background Workers** — Uptime monitor (HTTP checks) and notification dispatcher
```
┌──────────────────────────────────────────────┐
│ FastAPI App │
│ │
│ ┌──────────┐ ┌──────────┐ ┌───────────┐ │
│ │ REST API │ │ Status │ │ Background│ │
│ │Endpoints │ │ Pages │ │ Scheduler │ │
│ │ /api/v1 │ │ /status/ │ │ (APScheduler)│
│ └────┬─────┘ └────┬─────┘ └─────┬──────┘ │
│ │ │ │ │
│ ┌────┴──────────────┴──────────────┴──────┐ │
│ │ SQLAlchemy + SQLite │ │
│ └────────────────────────────────────────┘ │
└──────────────────────────────────────────────┘
│ │ │
HTTP clients Visitors SMTP / Webhooks
(API/CLI) (public) (notifications)
```
### Design Principles
- **Single binary, single process** — FastAPI + APScheduler in one uvicorn process
- **SQLite as the only store** — zero-config, file-based, easy backups
- **Server-rendered pages** — Jinja2 templates with minimal CSS, no build step
- **No paid dependencies** — all OSS libraries, SMTP for email, SQLite for data
- **Docker-first deployment** — single container, volume-mount the DB file
---
## 2. Database Schema (SQLite via SQLAlchemy)
### 2.1 `services`
| Column | Type | Constraints | Description |
|-------------|---------------|---------------------------|------------------------------------|
| id | UUID | PK, default=uuid4 | Unique service identifier |
| name | VARCHAR(100) | NOT NULL | Display name (e.g. "API") |
| slug | VARCHAR(50) | NOT NULL, UNIQUE | URL slug (e.g. "api") |
| description | TEXT | | Optional one-liner |
| group_name | VARCHAR(50) | | Grouping label (e.g. "Core") |
| position | INTEGER | DEFAULT 0 | Sort order on status page |
| is_visible | BOOLEAN | DEFAULT TRUE | Show on public page? |
| created_at | DATETIME | DEFAULT NOW | |
| updated_at | DATETIME | DEFAULT NOW, on update | |
### 2.2 `incidents`
| Column | Type | Constraints | Description |
|----------------|---------------|---------------------------|--------------------------------------|
| id | UUID | PK, default=uuid4 | |
| service_id | UUID | FK → services.id | Affected service |
| title | VARCHAR(200) | NOT NULL | Incident title |
| status | VARCHAR(20) | NOT NULL | investigating, identified, monitoring, resolved |
| severity | VARCHAR(20) | NOT NULL | minor, major, outage |
| started_at | DATETIME | NOT NULL | When incident began |
| resolved_at | DATETIME | NULLABLE | When resolved |
| created_at | DATETIME | DEFAULT NOW | |
| updated_at | DATETIME | DEFAULT NOW, on update | |
### 2.3 `incident_updates`
| Column | Type | Constraints | Description |
|----------------|---------------|---------------------------|--------------------------------------|
| id | UUID | PK, default=uuid4 | |
| incident_id | UUID | FK → incidents.id | Parent incident |
| status | VARCHAR(20) | NOT NULL | Status at time of update |
| body | TEXT | NOT NULL | Update content (markdown) |
| created_at | DATETIME | DEFAULT NOW | |
### 2.4 `monitors`
| Column | Type | Constraints | Description |
|----------------|---------------|---------------------------|--------------------------------------|
| id | UUID | PK, default=uuid4 | |
| service_id | UUID | FK → services.id | Monitored service |
| url | VARCHAR(500) | NOT NULL | URL to check |
| method | VARCHAR(10) | DEFAULT "GET" | HTTP method |
| expected_status| INTEGER | DEFAULT 200 | Expected HTTP status |
| timeout_seconds| INTEGER | DEFAULT 10 | Request timeout |
| interval_seconds| INTEGER | DEFAULT 60 | Check interval |
| is_active | BOOLEAN | DEFAULT TRUE | Enabled? |
| created_at | DATETIME | DEFAULT NOW | |
| updated_at | DATETIME | DEFAULT NOW, on update | |
### 2.5 `monitor_results`
| Column | Type | Constraints | Description |
|----------------|---------------|---------------------------|--------------------------------------|
| id | UUID | PK, default=uuid4 | |
| monitor_id | UUID | FK → monitors.id | |
| status | VARCHAR(20) | NOT NULL | up, down, degraded |
| response_time_ms| INTEGER | | Latency in ms |
| status_code | INTEGER | | HTTP response code |
| error_message | TEXT | | Error if failed |
| checked_at | DATETIME | NOT NULL | When check ran |
### 2.6 `subscribers`
| Column | Type | Constraints | Description |
|----------------|---------------|---------------------------|--------------------------------------|
| id | UUID | PK, default=uuid4 | |
| email | VARCHAR(255) | NOT NULL, UNIQUE | Subscriber email |
| is_confirmed | BOOLEAN | DEFAULT FALSE | Double-opt-in? |
| confirm_token | VARCHAR(100) | | Email confirmation token |
| created_at | DATETIME | DEFAULT NOW | |
### 2.7 `notification_logs`
| Column | Type | Constraints | Description |
|----------------|---------------|---------------------------|--------------------------------------|
| id | UUID | PK, default=uuid4 | |
| incident_id | UUID | FK → incidents.id | |
| subscriber_id | UUID | FK → subscribers.id | |
| channel | VARCHAR(20) | NOT NULL | email, webhook |
| status | VARCHAR(20) | NOT NULL | sent, failed |
| created_at | DATETIME | DEFAULT NOW | |
### 2.8 `site_settings`
| Column | Type | Constraints | Description |
|-------------|---------------|---------------------------|--------------------------------------|
| id | UUID | PK, default=uuid4 | |
| key | VARCHAR(50) | NOT NULL, UNIQUE | Setting name |
| value | TEXT | | Setting value (JSON-serializable) |
| updated_at | DATETIME | DEFAULT NOW, on update | |
Pre-seeded settings: `site_name`, `site_url`, `logo_url`, `accent_color`, `smtp_host`, `smtp_port`, `smtp_user`, `smtp_pass`, `notify_from`, `webhook_url`.
---
## 3. API Endpoints
All API endpoints are versioned under `/api/v1`. Authentication uses a simple API key via `X-API-Key` header (stored in `site_settings` as `admin_api_key`).
### 3.1 Services
| Method | Endpoint | Description |
|--------|-----------------------|-----------------------|
| GET | /api/v1/services | List all services |
| POST | /api/v1/services | Create a service |
| GET | /api/v1/services/{id} | Get a service |
| PATCH | /api/v1/services/{id} | Update a service |
| DELETE | /api/v1/services/{id} | Delete a service |
### 3.2 Incidents
| Method | Endpoint | Description |
|--------|--------------------------------|--------------------------|
| GET | /api/v1/incidents | List incidents (filterable) |
| POST | /api/v1/incidents | Create an incident |
| GET | /api/v1/incidents/{id} | Get incident + updates |
| PATCH | /api/v1/incidents/{id} | Update incident status |
| DELETE | /api/v1/incidents/{id} | Delete incident |
| POST | /api/v1/incidents/{id}/updates | Add an update |
### 3.3 Monitors
| Method | Endpoint | Description |
|--------|------------------------|-----------------------|
| GET | /api/v1/monitors | List all monitors |
| POST | /api/v1/monitors | Create a monitor |
| GET | /api/v1/monitors/{id} | Get monitor + recent results |
| PATCH | /api/v1/monitors/{id} | Update monitor |
| DELETE | /api/v1/monitors/{id} | Delete monitor |
| POST | /api/v1/monitors/{id}/check | Trigger manual check |
### 3.4 Subscribers
| Method | Endpoint | Description |
|--------|----------------------------|-----------------------|
| GET | /api/v1/subscribers | List subscribers |
| POST | /api/v1/subscribers | Add subscriber |
| DELETE | /api/v1/subscribers/{id} | Remove subscriber |
| POST | /api/v1/subscribers/{id}/confirm | Confirm subscription |
### 3.5 Site Settings
| Method | Endpoint | Description |
|--------|-----------------------------|-----------------------|
| GET | /api/v1/settings | List all settings |
| PATCH | /api/v1/settings | Update settings |
### 3.6 Public Status Page (HTML)
| Method | Endpoint | Description |
|--------|---------------------|-------------------------|
| GET | / | Status page (HTML) |
| GET | /incident/{id} | Incident detail (HTML) |
| GET | /subscribe | Subscribe form (HTML) |
| POST | /subscribe | Handle subscription |
| GET | /confirm/{token} | Confirm email (HTML) |
### 3.7 Health
| Method | Endpoint | Description |
|--------|--------------|-----------------|
| GET | /health | Health check |
---
## 4. File/Folder Structure
```
indie-status-page/
├── pyproject.toml # Project config, deps, scripts
├── README.md # Quick-start guide
├── TECHNICAL_DESIGN.md # This document
├── Dockerfile # Production container
├── docker-compose.yml # Local dev setup
├── .env.example # Environment variable template
├── alembic.ini # DB migration config
├── migrations/
│ └── versions/ # Alembic migration scripts
├── app/
│ ├── __init__.py
│ ├── main.py # FastAPI app factory + lifespan
│ ├── config.py # Pydantic settings from env
│ ├── database.py # SQLAlchemy engine, session, Base
│ ├── dependencies.py # FastAPI deps (DB session, auth)
│ ├── models/
│ │ ├── __init__.py # Re-exports all models
│ │ ├── service.py # Service model
│ │ ├── incident.py # Incident + IncidentUpdate models
│ │ ├── monitor.py # Monitor + MonitorResult models
│ │ ├── subscriber.py # Subscriber model
│ │ ├── notification.py # NotificationLog model
│ │ └── settings.py # SiteSettings model
│ ├── api/
│ │ ├── __init__.py
│ │ ├── router.py # v1 router aggregation
│ │ ├── services.py # Service CRUD endpoints
│ │ ├── incidents.py # Incident CRUD + updates
│ │ ├── monitors.py # Monitor CRUD + manual check
│ │ ├── subscribers.py # Subscriber management
│ │ └── settings.py # Settings endpoints
│ ├── services/
│ │ ├── __init__.py
│ │ ├── uptime.py # HTTP check logic
│ │ ├── notifier.py # Email (SMTP) + webhook dispatch
│ │ └── scheduler.py # APScheduler job registration
│ ├── templates/
│ │ ├── base.html # Layout template
│ │ ├── status.html # Public status page
│ │ ├── incident.html # Incident detail page
│ │ ├── subscribe.html # Subscribe form
│ │ └── confirm.html # Confirmation page
│ └── static/
│ ├── css/
│ │ └── style.css # Minimal responsive styles
│ └── js/
│ └── status.js # Auto-refresh (optional)
├── tests/
│ ├── __init__.py
│ ├── conftest.py # Fixtures, test DB
│ ├── test_api_services.py
│ ├── test_api_incidents.py
│ ├── test_api_monitors.py
│ ├── test_api_subscribers.py
│ ├── test_health.py
│ └── test_services_uptime.py
└── scripts/
└── cli.py # CLI tool for managing incidents
```
---
## 5. Core Models (Pydantic Schemas)
### Service
```python
class ServiceCreate(BaseModel):
name: str = Field(..., max_length=100)
slug: str = Field(..., max_length=50, pattern=r"^[a-z0-9-]+$")
description: str | None = None
group_name: str | None = Field(None, max_length=50)
position: int = 0
is_visible: bool = True
class ServiceRead(ServiceCreate):
id: UUID
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
```
### Incident
```python
class IncidentCreate(BaseModel):
service_id: UUID
title: str = Field(..., max_length=200)
status: Literal["investigating", "identified", "monitoring", "resolved"]
severity: Literal["minor", "major", "outage"]
started_at: datetime | None = None # defaults to now
class IncidentUpdate(BaseModel):
status: Literal["investigating", "identified", "monitoring", "resolved"]
body: str
```
### Monitor
```python
class MonitorCreate(BaseModel):
service_id: UUID
url: HttpUrl
method: Literal["GET", "POST", "HEAD"] = "GET"
expected_status: int = 200
timeout_seconds: int = Field(10, ge=1, le=60)
interval_seconds: int = Field(60, ge=30, le=3600)
```
### Subscriber
```python
class SubscriberCreate(BaseModel):
email: EmailStr
```
---
## 6. MVP Feature List (Prioritized)
### P0 — Ship Day 1 (Must Have)
| # | Feature | Description |
|---|-----------------------------------|------------------------------------------------|
| 1 | Service CRUD | Add/edit/delete services to track |
| 2 | Incident CRUD + Updates | Create incidents, post updates, resolve |
| 3 | Public Status Page | SSR status page with current service status |
| 4 | Uptime Monitoring (HTTP) | Periodic HTTP GET checks, store results |
| 5 | Status Derivation | Auto-derive service status from monitors |
| 6 | Health Check Endpoint | /health for container orchestration |
| 7 | API Key Auth | Simple X-API-Key header for protected routes |
| 8 | Docker Setup | Dockerfile + docker-compose |
### P1 — Ship Day 3-4 (Should Have)
| # | Feature | Description |
|---|-----------------------------------|------------------------------------------------|
| 9 | Subscriber Signup + Confirmation | Email opt-in with double confirmation |
| 10| Email Notifications (SMTP) | Send incident updates to subscribers |
| 11| Webhook Notifications | POST to external URL on incidents |
| 12| CLI Tool | Terminal commands for incident management |
| 13| 90-Day Uptime Calculator | Uptime % from monitor_results for each service |
### P2 — Post-MVP (Nice to Have)
| # | Feature | Description |
|---|-----------------------------------|------------------------------------------------|
| 14| RSS Feed | /feed.xml for incident history |
| 15| Custom Domain Support | CNAME-friendly setup guide |
| 16| Status Badge | SVG badge embed for READMEs |
| 17| Scheduled Maintenance Windows | Planned downtime with auto-resolve |
| 17| Multi-language UI | i18n via Jinja2 templates |
| 18| Slack/Discord Integration | Bot-based notifications |
---
## Technical Decisions & Rationale
| Decision | Choice | Rationale |
|-----------------------------|---------------------|--------------------------------------------------|
| Web Framework | FastAPI | Async support, auto-docs, type validation |
| ORM | SQLAlchemy 2.0 | Mature, async-capable, SQLite-friendly |
| Database | SQLite | Zero-config, single file, perfect for MVP |
| Migrations | Alembic | Standard for SQLAlchemy, easy schema evolution |
| Task Scheduler | APScheduler | In-process, no Redis/Celery needed |
| Templates | Jinja2 | Built into FastAPI, no build step |
| Email | stdlib smtplib | No paid service, any SMTP relay works |
| Settings | pydantic-settings | Env-based config, type validation, .env support |
| Testing | pytest + httpx | Standard Python testing with async ASGI client |
| CLI | Typer | Intuitive CLI with auto-help, same dep as FastAPI |

2
alembic.ini Normal file
View file

@ -0,0 +1,2 @@
# Alembic configuration will be added when migrations are set up.
# For MVP, use the auto-create tables approach via init_db().

0
app/__init__.py Normal file
View file

0
app/api/__init__.py Normal file
View file

195
app/api/incidents.py Normal file
View file

@ -0,0 +1,195 @@
"""Incidents API endpoints."""
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_db, verify_api_key
from app.models.models import Incident, IncidentUpdate
router = APIRouter()
class IncidentCreate(BaseModel):
service_id: UUID
title: str = Field(..., max_length=200)
status: str = Field(..., pattern=r"^(investigating|identified|monitoring|resolved)$")
severity: str = Field(..., pattern=r"^(minor|major|outage)$")
started_at: datetime | None = None
class IncidentUpdateCreate(BaseModel):
status: str = Field(..., pattern=r"^(investigating|identified|monitoring|resolved)$")
body: str
class IncidentPatch(BaseModel):
title: str | None = None
status: str | None = Field(None, pattern=r"^(investigating|identified|monitoring|resolved)$")
severity: str | None = Field(None, pattern=r"^(minor|major|outage)$")
def serialize_incident(i: Incident) -> dict:
return {
"id": i.id,
"service_id": i.service_id,
"title": i.title,
"status": i.status,
"severity": i.severity,
"started_at": i.started_at.isoformat() if i.started_at else None,
"resolved_at": i.resolved_at.isoformat() if i.resolved_at else None,
"created_at": i.created_at.isoformat() if i.created_at else None,
"updated_at": i.updated_at.isoformat() if i.updated_at else None,
}
async def serialize_incident_detail(i: Incident, db: AsyncSession) -> dict:
"""Serialize incident with its updates, querying explicitly to avoid lazy-load issues."""
data = serialize_incident(i)
# Explicitly query updates instead of relying on lazy-loaded relationship
result = await db.execute(
select(IncidentUpdate)
.where(IncidentUpdate.incident_id == i.id)
.order_by(IncidentUpdate.created_at)
)
updates = result.scalars().all()
data["updates"] = [
{
"id": u.id,
"status": u.status,
"body": u.body,
"created_at": u.created_at.isoformat() if u.created_at else None,
}
for u in updates
]
return data
@router.get("/")
async def list_incidents(
service_id: UUID | None = None,
status: str | None = None,
limit: int = 50,
offset: int = 0,
db: AsyncSession = Depends(get_db),
):
"""List incidents with optional filtering."""
query = select(Incident).order_by(Incident.started_at.desc())
if service_id:
query = query.where(Incident.service_id == str(service_id))
if status:
query = query.where(Incident.status == status)
query = query.offset(offset).limit(limit)
result = await db.execute(query)
incidents = result.scalars().all()
return [serialize_incident(i) for i in incidents]
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_incident(
data: IncidentCreate,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Create a new incident."""
incident = Incident(
service_id=str(data.service_id),
title=data.title,
status=data.status,
severity=data.severity,
started_at=data.started_at or datetime.utcnow(),
)
db.add(incident)
await db.flush()
await db.refresh(incident)
return serialize_incident(incident)
@router.get("/{incident_id}")
async def get_incident(incident_id: UUID, db: AsyncSession = Depends(get_db)):
"""Get an incident with its updates."""
result = await db.execute(select(Incident).where(Incident.id == str(incident_id)))
incident = result.scalar_one_or_none()
if not incident:
raise HTTPException(status_code=404, detail="Incident not found")
return await serialize_incident_detail(incident, db)
@router.patch("/{incident_id}")
async def update_incident(
incident_id: UUID,
data: IncidentPatch,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Update incident fields (title, status, severity)."""
result = await db.execute(select(Incident).where(Incident.id == str(incident_id)))
incident = result.scalar_one_or_none()
if not incident:
raise HTTPException(status_code=404, detail="Incident not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(incident, field, value)
# If status changed to resolved, set resolved_at
if data.status == "resolved" and "status" in update_data:
incident.resolved_at = datetime.utcnow()
await db.flush()
await db.refresh(incident)
return await serialize_incident_detail(incident, db)
@router.delete("/{incident_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_incident(
incident_id: UUID,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Delete an incident."""
result = await db.execute(select(Incident).where(Incident.id == str(incident_id)))
incident = result.scalar_one_or_none()
if not incident:
raise HTTPException(status_code=404, detail="Incident not found")
await db.delete(incident)
@router.post("/{incident_id}/updates", status_code=status.HTTP_201_CREATED)
async def create_incident_update(
incident_id: UUID,
data: IncidentUpdateCreate,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Add an update to an incident."""
result = await db.execute(select(Incident).where(Incident.id == str(incident_id)))
incident = result.scalar_one_or_none()
if not incident:
raise HTTPException(status_code=404, detail="Incident not found")
update = IncidentUpdate(
incident_id=str(incident_id),
status=data.status,
body=data.body,
)
db.add(update)
# Also update incident status
incident.status = data.status
# If resolved, set resolved_at
if data.status == "resolved":
incident.resolved_at = datetime.utcnow()
await db.flush()
await db.refresh(update)
return {
"id": update.id,
"incident_id": update.incident_id,
"status": update.status,
"body": update.body,
"created_at": update.created_at.isoformat() if update.created_at else None,
}

166
app/api/monitors.py Normal file
View file

@ -0,0 +1,166 @@
"""Monitors API endpoints."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_db, verify_api_key
from app.models.models import Monitor, MonitorResult
router = APIRouter()
class MonitorCreate(BaseModel):
service_id: UUID
url: str = Field(..., max_length=500)
method: str = Field("GET", pattern=r"^(GET|POST|HEAD)$")
expected_status: int = 200
timeout_seconds: int = Field(10, ge=1, le=60)
interval_seconds: int = Field(60, ge=30, le=3600)
class MonitorUpdate(BaseModel):
url: str | None = Field(None, max_length=500)
method: str | None = Field(None, pattern=r"^(GET|POST|HEAD)$")
expected_status: int | None = None
timeout_seconds: int | None = Field(None, ge=1, le=60)
interval_seconds: int | None = Field(None, ge=30, le=3600)
is_active: bool | None = None
def serialize_monitor(m: Monitor) -> dict:
return {
"id": m.id,
"service_id": m.service_id,
"url": m.url,
"method": m.method,
"expected_status": m.expected_status,
"timeout_seconds": m.timeout_seconds,
"interval_seconds": m.interval_seconds,
"is_active": m.is_active,
"created_at": m.created_at.isoformat() if m.created_at else None,
"updated_at": m.updated_at.isoformat() if m.updated_at else None,
}
@router.get("/")
async def list_monitors(db: AsyncSession = Depends(get_db)):
"""List all monitors."""
result = await db.execute(select(Monitor))
monitors = result.scalars().all()
return [serialize_monitor(m) for m in monitors]
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_monitor(
data: MonitorCreate,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Create a new monitor."""
monitor = Monitor(
service_id=str(data.service_id),
url=data.url,
method=data.method,
expected_status=data.expected_status,
timeout_seconds=data.timeout_seconds,
interval_seconds=data.interval_seconds,
)
db.add(monitor)
await db.flush()
await db.refresh(monitor)
return serialize_monitor(monitor)
@router.get("/{monitor_id}")
async def get_monitor(monitor_id: UUID, db: AsyncSession = Depends(get_db)):
"""Get a monitor with recent results."""
result = await db.execute(select(Monitor).where(Monitor.id == str(monitor_id)))
monitor = result.scalar_one_or_none()
if not monitor:
raise HTTPException(status_code=404, detail="Monitor not found")
# Query recent results separately
results_query = (
select(MonitorResult)
.where(MonitorResult.monitor_id == str(monitor_id))
.order_by(MonitorResult.checked_at.desc())
.limit(10)
)
results_result = await db.execute(results_query)
recent_results = results_result.scalars().all()
data = serialize_monitor(monitor)
data["recent_results"] = [
{
"id": r.id,
"status": r.status,
"response_time_ms": r.response_time_ms,
"status_code": r.status_code,
"error_message": r.error_message,
"checked_at": r.checked_at.isoformat() if r.checked_at else None,
}
for r in recent_results
]
return data
@router.patch("/{monitor_id}")
async def update_monitor(
monitor_id: UUID,
data: MonitorUpdate,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Update a monitor."""
result = await db.execute(select(Monitor).where(Monitor.id == str(monitor_id)))
monitor = result.scalar_one_or_none()
if not monitor:
raise HTTPException(status_code=404, detail="Monitor not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(monitor, field, value)
await db.flush()
await db.refresh(monitor)
return serialize_monitor(monitor)
@router.delete("/{monitor_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_monitor(
monitor_id: UUID,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Delete a monitor."""
result = await db.execute(select(Monitor).where(Monitor.id == str(monitor_id)))
monitor = result.scalar_one_or_none()
if not monitor:
raise HTTPException(status_code=404, detail="Monitor not found")
await db.delete(monitor)
@router.post("/{monitor_id}/check")
async def trigger_check(
monitor_id: UUID,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Trigger a manual uptime check for this monitor."""
from app.services.uptime import check_monitor
result = await db.execute(select(Monitor).where(Monitor.id == str(monitor_id)))
monitor = result.scalar_one_or_none()
if not monitor:
raise HTTPException(status_code=404, detail="Monitor not found")
monitor_result = await check_monitor(monitor, db)
return {
"status": monitor_result.status,
"response_time_ms": monitor_result.response_time_ms,
"status_code": monitor_result.status_code,
}

15
app/api/router.py Normal file
View file

@ -0,0 +1,15 @@
from fastapi import APIRouter
from app.api.services import router as services_router
from app.api.incidents import router as incidents_router
from app.api.monitors import router as monitors_router
from app.api.subscribers import router as subscribers_router
from app.api.settings import router as settings_router
api_v1_router = APIRouter()
api_v1_router.include_router(services_router, prefix="/services", tags=["services"])
api_v1_router.include_router(incidents_router, prefix="/incidents", tags=["incidents"])
api_v1_router.include_router(monitors_router, prefix="/monitors", tags=["monitors"])
api_v1_router.include_router(subscribers_router, prefix="/subscribers", tags=["subscribers"])
api_v1_router.include_router(settings_router, prefix="/settings", tags=["settings"])

120
app/api/services.py Normal file
View file

@ -0,0 +1,120 @@
"""Services API endpoints."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_db, verify_api_key
from app.models.models import Service
router = APIRouter()
class ServiceCreate(BaseModel):
name: str = Field(..., max_length=100)
slug: str = Field(..., max_length=50, pattern=r"^[a-z0-9-]+$")
description: str | None = None
group_name: str | None = Field(None, max_length=50)
position: int = 0
is_visible: bool = True
class ServiceUpdate(BaseModel):
name: str | None = None
slug: str | None = Field(None, max_length=50, pattern=r"^[a-z0-9-]+$")
description: str | None = None
group_name: str | None = None
position: int | None = None
is_visible: bool | None = None
def serialize_service(s: Service) -> dict:
return {
"id": s.id,
"name": s.name,
"slug": s.slug,
"description": s.description,
"group_name": s.group_name,
"position": s.position,
"is_visible": s.is_visible,
"created_at": s.created_at.isoformat() if s.created_at else None,
"updated_at": s.updated_at.isoformat() if s.updated_at else None,
}
@router.get("/")
async def list_services(db: AsyncSession = Depends(get_db)):
"""List all services."""
result = await db.execute(select(Service).order_by(Service.position, Service.name))
services = result.scalars().all()
return [serialize_service(s) for s in services]
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_service(
data: ServiceCreate,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Create a new service."""
service = Service(
name=data.name,
slug=data.slug,
description=data.description,
group_name=data.group_name,
position=data.position,
is_visible=data.is_visible,
)
db.add(service)
await db.flush()
await db.refresh(service)
return serialize_service(service)
@router.get("/{service_id}")
async def get_service(service_id: UUID, db: AsyncSession = Depends(get_db)):
"""Get a service by ID."""
result = await db.execute(select(Service).where(Service.id == str(service_id)))
service = result.scalar_one_or_none()
if not service:
raise HTTPException(status_code=404, detail="Service not found")
return serialize_service(service)
@router.patch("/{service_id}")
async def update_service(
service_id: UUID,
data: ServiceUpdate,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Update a service."""
result = await db.execute(select(Service).where(Service.id == str(service_id)))
service = result.scalar_one_or_none()
if not service:
raise HTTPException(status_code=404, detail="Service not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(service, field, value)
await db.flush()
await db.refresh(service)
return serialize_service(service)
@router.delete("/{service_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_service(
service_id: UUID,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Delete a service."""
result = await db.execute(select(Service).where(Service.id == str(service_id)))
service = result.scalar_one_or_none()
if not service:
raise HTTPException(status_code=404, detail="Service not found")
await db.delete(service)

36
app/api/settings.py Normal file
View file

@ -0,0 +1,36 @@
"""Site settings API endpoints."""
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_db, verify_api_key
from app.models.models import SiteSetting
router = APIRouter()
@router.get("/")
async def list_settings(db: AsyncSession = Depends(get_db)):
"""List all site settings."""
result = await db.execute(select(SiteSetting))
settings = result.scalars().all()
return {s.key: s.value for s in settings}
@router.patch("/")
async def update_settings(
updates: dict[str, str],
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Update site settings (key-value pairs)."""
for key, value in updates.items():
result = await db.execute(select(SiteSetting).where(SiteSetting.key == key))
setting = result.scalar_one_or_none()
if setting:
setting.value = value
else:
db.add(SiteSetting(key=key, value=value))
await db.flush()
return {"message": "Settings updated"}

84
app/api/subscribers.py Normal file
View file

@ -0,0 +1,84 @@
"""Subscribers API endpoints."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_db, verify_api_key
from app.models.models import Subscriber
router = APIRouter()
@router.get("/")
async def list_subscribers(db: AsyncSession = Depends(get_db)):
"""List all subscribers."""
result = await db.execute(select(Subscriber))
subscribers = result.scalars().all()
return [
{
"id": s.id,
"email": s.email,
"is_confirmed": s.is_confirmed,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in subscribers
]
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_subscriber(
email: str,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Add a new subscriber."""
import uuid
subscriber = Subscriber(
email=email,
confirm_token=str(uuid.uuid4()),
)
db.add(subscriber)
await db.flush()
await db.refresh(subscriber)
return {
"id": subscriber.id,
"email": subscriber.email,
"confirm_token": subscriber.confirm_token,
}
@router.delete("/{subscriber_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_subscriber(
subscriber_id: UUID,
db: AsyncSession = Depends(get_db),
api_key: str = Depends(verify_api_key),
):
"""Remove a subscriber."""
result = await db.execute(select(Subscriber).where(Subscriber.id == str(subscriber_id)))
subscriber = result.scalar_one_or_none()
if not subscriber:
raise HTTPException(status_code=404, detail="Subscriber not found")
await db.delete(subscriber)
@router.post("/{subscriber_id}/confirm")
async def confirm_subscriber(
subscriber_id: UUID,
token: str,
db: AsyncSession = Depends(get_db),
):
"""Confirm a subscriber's email address."""
result = await db.execute(select(Subscriber).where(Subscriber.id == str(subscriber_id)))
subscriber = result.scalar_one_or_none()
if not subscriber:
raise HTTPException(status_code=404, detail="Subscriber not found")
if subscriber.confirm_token != token:
raise HTTPException(status_code=400, detail="Invalid confirmation token")
subscriber.is_confirmed = True
subscriber.confirm_token = None
await db.flush()
return {"message": "Subscriber confirmed", "email": subscriber.email}

44
app/config.py Normal file
View file

@ -0,0 +1,44 @@
from pydantic_settings import BaseSettings
from pathlib import Path
class Settings(BaseSettings):
"""Application settings loaded from environment variables or .env file."""
# App
app_name: str = "Indie Status Page"
database_url: str = "sqlite+aiosqlite:///./data/statuspage.db"
secret_key: str = "change-me-to-a-random-string"
admin_api_key: str = "change-me-to-a-secure-api-key"
debug: bool = False
# Site
site_name: str = "My SaaS Status"
site_url: str = "http://localhost:8000"
site_logo_url: str = ""
site_accent_color: str = "#4f46e5"
# SMTP
smtp_host: str = ""
smtp_port: int = 587
smtp_user: str = ""
smtp_pass: str = ""
smtp_from: str = "noreply@example.com"
# Webhook
webhook_notify_url: str = ""
# Uptime monitoring
monitor_check_interval: int = 60
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
@property
def db_path(self) -> Path:
"""Extract filesystem path from SQLite URL for directory creation."""
# Remove the sqlite+aiosqlite:/// prefix
path_str = self.database_url.replace("sqlite+aiosqlite:///", "")
return Path(path_str)
settings = Settings()

26
app/database.py Normal file
View file

@ -0,0 +1,26 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
engine = create_async_engine(
settings.database_url,
echo=settings.debug,
future=True,
)
async_session_factory = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)
class Base(DeclarativeBase):
pass
async def init_db() -> None:
"""Create all tables (used for dev/first run; in production use Alembic)."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

26
app/dependencies.py Normal file
View file

@ -0,0 +1,26 @@
from fastapi import Depends, Header, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import async_session_factory
async def get_db() -> AsyncSession:
"""FastAPI dependency that yields an async database session."""
async with async_session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
async def verify_api_key(x_api_key: str = Header(...)) -> str:
"""Validate the X-API-Key header against the configured admin key."""
if x_api_key != settings.admin_api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return x_api_key

213
app/main.py Normal file
View file

@ -0,0 +1,213 @@
from contextlib import asynccontextmanager
from datetime import datetime
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from app.config import settings
from app.database import init_db
from app.api.router import api_v1_router
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan: create DB directories and tables on startup."""
# Ensure the data directory exists for SQLite
db_path = settings.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
# Create tables (dev mode; use Alembic in production)
await init_db()
# Start the uptime monitoring scheduler
from app.services.scheduler import start_scheduler, shutdown_scheduler
start_scheduler()
yield
# Shutdown scheduler on exit
shutdown_scheduler()
app = FastAPI(
title=settings.app_name,
version="0.1.0",
description="Lightweight status page tool for indie SaaS developers",
lifespan=lifespan,
)
# API routes
app.include_router(api_v1_router, prefix="/api/v1")
# Static files and templates
app.mount("/static", StaticFiles(directory="app/static"), name="static")
templates = Jinja2Templates(directory="app/templates")
@app.get("/health")
async def health_check():
"""Health check endpoint for container orchestration."""
return {"status": "ok", "version": "0.1.0"}
async def _get_service_status(service_id: str, db) -> str:
"""Derive a service's current status from its monitors' latest results."""
from sqlalchemy import select
from app.models.models import Monitor, MonitorResult
# Get all monitors for this service
result = await db.execute(
select(Monitor).where(Monitor.service_id == service_id, Monitor.is_active == True) # noqa: E712
)
monitors = result.scalars().all()
if not monitors:
return "up" # No monitors = assume operational
# For each monitor, get the latest result
worst_status = "up"
status_priority = {"up": 0, "degraded": 1, "down": 2}
for monitor in monitors:
r = await db.execute(
select(MonitorResult)
.where(MonitorResult.monitor_id == monitor.id)
.order_by(MonitorResult.checked_at.desc())
.limit(1)
)
latest = r.scalar_one_or_none()
if latest:
if status_priority.get(latest.status, 0) > status_priority.get(worst_status, 0):
worst_status = latest.status
return worst_status
@app.get("/")
async def status_page(request: Request):
"""Public status page — shows all visible services and recent incidents."""
from sqlalchemy import select
from app.database import async_session_factory
from app.models.models import Service, Incident
async with async_session_factory() as db:
# Get all visible services, ordered by position then name
result = await db.execute(
select(Service).where(Service.is_visible == True).order_by(Service.position, Service.name) # noqa: E712
)
services = result.scalars().all()
# Build services_by_group dict and attach current_status
services_by_group = {}
service_list = []
for s in services:
current_status = await _get_service_status(s.id, db)
svc_data = {
"id": s.id,
"name": s.name,
"slug": s.slug,
"description": s.description,
"group_name": s.group_name,
"position": s.position,
"current_status": current_status,
}
service_list.append(svc_data)
group = s.group_name or "Services"
if group not in services_by_group:
services_by_group[group] = []
services_by_group[group].append(svc_data)
# Get recent unresolved + recently resolved incidents
result = await db.execute(
select(Incident).order_by(Incident.started_at.desc()).limit(20)
)
incidents = result.scalars().all()
incident_list = [
{
"id": str(i.id),
"title": i.title,
"status": i.status,
"severity": i.severity,
"started_at": i.started_at.isoformat() if i.started_at else None,
"resolved_at": i.resolved_at.isoformat() if i.resolved_at else None,
"service_id": str(i.service_id),
}
for i in incidents
]
# Check for active (unresolved) incidents
has_active = any(i["status"] != "resolved" for i in incident_list)
return templates.TemplateResponse(
"status.html",
{
"request": request,
"site_name": settings.site_name,
"services_by_group": services_by_group,
"incidents": incident_list,
"has_active_incidents": has_active,
"now": datetime.utcnow(),
},
)
@app.get("/incident/{incident_id}")
async def incident_detail_page(request: Request, incident_id: str):
"""Public incident detail page with timeline of updates."""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.database import async_session_factory
from app.models.models import Incident, IncidentUpdate
async with async_session_factory() as db:
result = await db.execute(
select(Incident)
.options(selectinload(Incident.updates))
.where(Incident.id == incident_id)
)
incident = result.scalar_one_or_none()
if not incident:
from fastapi.responses import HTMLResponse
return HTMLResponse("<h1>Incident not found</h1>", status_code=404)
incident_data = {
"id": str(incident.id),
"title": incident.title,
"status": incident.status,
"severity": incident.severity,
"started_at": incident.started_at.isoformat() if incident.started_at else None,
"resolved_at": incident.resolved_at.isoformat() if incident.resolved_at else None,
}
# Eagerly load updates
updates_result = await db.execute(
select(IncidentUpdate)
.where(IncidentUpdate.incident_id == incident_id)
.order_by(IncidentUpdate.created_at.asc())
)
updates = updates_result.scalars().all()
updates_list = [
{
"id": str(u.id),
"status": u.status,
"body": u.body,
"created_at": u.created_at.isoformat() if u.created_at else None,
}
for u in updates
]
return templates.TemplateResponse(
"incident.html",
{
"request": request,
"site_name": settings.site_name,
"incident": incident_data,
"updates": updates_list,
"now": datetime.utcnow(),
},
)

21
app/models/__init__.py Normal file
View file

@ -0,0 +1,21 @@
from app.models.models import (
Service,
Incident,
IncidentUpdate,
Monitor,
MonitorResult,
Subscriber,
NotificationLog,
SiteSetting,
)
__all__ = [
"Service",
"Incident",
"IncidentUpdate",
"Monitor",
"MonitorResult",
"Subscriber",
"NotificationLog",
"SiteSetting",
]

151
app/models/models.py Normal file
View file

@ -0,0 +1,151 @@
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, Integer, String, Text, ForeignKey, Float
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
def _uuid_str() -> str:
return str(uuid.uuid4())
class Service(Base):
__tablename__ = "services"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
name: Mapped[str] = mapped_column(String(100), nullable=False)
slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
group_name: Mapped[str | None] = mapped_column(String(50), nullable=True)
position: Mapped[int] = mapped_column(Integer, default=0)
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
incidents: Mapped[list["Incident"]] = relationship(back_populates="service")
monitors: Mapped[list["Monitor"]] = relationship(back_populates="service")
class Incident(Base):
__tablename__ = "incidents"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
service_id: Mapped[str] = mapped_column(
String(36), ForeignKey("services.id"), nullable=False, index=True
)
title: Mapped[str] = mapped_column(String(200), nullable=False)
status: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
# investigating | identified | monitoring | resolved
severity: Mapped[str] = mapped_column(String(20), nullable=False)
# minor | major | outage
started_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow)
resolved_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
service: Mapped["Service"] = relationship(back_populates="incidents")
updates: Mapped[list["IncidentUpdate"]] = relationship(
back_populates="incident", cascade="all, delete-orphan"
)
notifications: Mapped[list["NotificationLog"]] = relationship(back_populates="incident")
class IncidentUpdate(Base):
__tablename__ = "incident_updates"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
incident_id: Mapped[str] = mapped_column(
String(36), ForeignKey("incidents.id"), nullable=False, index=True
)
status: Mapped[str] = mapped_column(String(20), nullable=False)
body: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
incident: Mapped["Incident"] = relationship(back_populates="updates")
class Monitor(Base):
__tablename__ = "monitors"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
service_id: Mapped[str] = mapped_column(
String(36), ForeignKey("services.id"), nullable=False, index=True
)
url: Mapped[str] = mapped_column(String(500), nullable=False)
method: Mapped[str] = mapped_column(String(10), default="GET")
expected_status: Mapped[int] = mapped_column(Integer, default=200)
timeout_seconds: Mapped[int] = mapped_column(Integer, default=10)
interval_seconds: Mapped[int] = mapped_column(Integer, default=60)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
service: Mapped["Service"] = relationship(back_populates="monitors")
results: Mapped[list["MonitorResult"]] = relationship(
back_populates="monitor", cascade="all, delete-orphan"
)
class MonitorResult(Base):
__tablename__ = "monitor_results"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
monitor_id: Mapped[str] = mapped_column(
String(36), ForeignKey("monitors.id"), nullable=False, index=True
)
status: Mapped[str] = mapped_column(String(20), nullable=False) # up | down | degraded
response_time_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
status_code: Mapped[int | None] = mapped_column(Integer, nullable=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
checked_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow)
monitor: Mapped["Monitor"] = relationship(back_populates="results")
class Subscriber(Base):
__tablename__ = "subscribers"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
is_confirmed: Mapped[bool] = mapped_column(Boolean, default=False)
confirm_token: Mapped[str | None] = mapped_column(String(100), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
notifications: Mapped[list["NotificationLog"]] = relationship(back_populates="subscriber")
class NotificationLog(Base):
__tablename__ = "notification_logs"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
incident_id: Mapped[str] = mapped_column(
String(36), ForeignKey("incidents.id"), nullable=False, index=True
)
subscriber_id: Mapped[str] = mapped_column(
String(36), ForeignKey("subscribers.id"), nullable=False
)
channel: Mapped[str] = mapped_column(String(20), nullable=False) # email | webhook
status: Mapped[str] = mapped_column(String(20), nullable=False) # sent | failed
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
incident: Mapped["Incident"] = relationship(back_populates="notifications")
subscriber: Mapped["Subscriber"] = relationship(back_populates="notifications")
class SiteSetting(Base):
__tablename__ = "site_settings"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid_str)
key: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True)
value: Mapped[str | None] = mapped_column(Text, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)

0
app/services/__init__.py Normal file
View file

110
app/services/notifier.py Normal file
View file

@ -0,0 +1,110 @@
"""Notification service — email (SMTP) and webhook dispatch."""
import json
import smtplib
from email.mime.text import MIMEText
from datetime import datetime
import httpx
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.models import Incident, NotificationLog, Subscriber
async def send_email_notification(
to_email: str,
subject: str,
body: str,
) -> bool:
"""Send an email notification via SMTP. Returns True if successful."""
if not settings.smtp_host:
return False
msg = MIMEText(body, "html")
msg["Subject"] = subject
msg["From"] = settings.smtp_from
msg["To"] = to_email
try:
with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server:
if settings.smtp_user:
server.starttls()
server.login(settings.smtp_user, settings.smtp_pass)
server.send_message(msg)
return True
except Exception:
return False
async def send_webhook_notification(
payload: dict,
) -> bool:
"""Send a webhook POST notification. Returns True if successful."""
if not settings.webhook_notify_url:
return False
try:
async with httpx.AsyncClient() as client:
response = await client.post(
settings.webhook_notify_url,
json=payload,
timeout=10.0,
)
return response.status_code < 400
except Exception:
return False
async def notify_subscribers(
incident: Incident,
db: AsyncSession,
) -> int:
"""Notify all confirmed subscribers about an incident update. Returns count notified."""
result = await db.execute(
select(Subscriber).where(Subscriber.is_confirmed == True) # noqa: E712
)
subscribers = result.scalars().all()
notified = 0
subject = f"[{incident.severity.upper()}] {incident.title}"
for subscriber in subscribers:
# Email notification
email_sent = await send_email_notification(
to_email=subscriber.email,
subject=subject,
body=f"<p>{incident.title}</p><p>Status: {incident.status}</p>",
)
if email_sent:
log = NotificationLog(
incident_id=incident.id,
subscriber_id=subscriber.id,
channel="email",
status="sent",
)
db.add(log)
notified += 1
# Webhook notification
webhook_sent = await send_webhook_notification(
payload={
"incident_id": incident.id,
"title": incident.title,
"status": incident.status,
"severity": incident.severity,
"started_at": incident.started_at.isoformat() if incident.started_at else None,
}
)
if webhook_sent:
log = NotificationLog(
incident_id=incident.id,
subscriber_id=subscriber.id,
channel="webhook",
status="sent",
)
db.add(log)
await db.flush()
return notified

59
app/services/scheduler.py Normal file
View file

@ -0,0 +1,59 @@
"""Background scheduler for uptime monitoring using APScheduler."""
import asyncio
import logging
from datetime import datetime
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import async_session_factory
from app.models.models import Monitor
from app.services.uptime import check_monitor
logger = logging.getLogger(__name__)
_scheduler: AsyncIOScheduler | None = None
async def _run_monitor_checks() -> None:
"""Check all active monitors."""
async with async_session_factory() as db:
result = await db.execute(select(Monitor).where(Monitor.is_active == True)) # noqa: E712
monitors = result.scalars().all()
for monitor in monitors:
try:
await check_monitor(monitor, db)
except Exception as exc:
logger.error(f"Monitor check failed for {monitor.url}: {exc}")
await db.commit()
def start_scheduler() -> None:
"""Start the APScheduler with periodic monitor checks."""
global _scheduler
if _scheduler is not None:
return
_scheduler = AsyncIOScheduler()
_scheduler.add_job(
_run_monitor_checks,
"interval",
seconds=60,
id="monitor_checks",
replace_existing=True,
)
_scheduler.start()
logger.info("Uptime monitoring scheduler started (interval: 60s)")
def shutdown_scheduler() -> None:
"""Gracefully shut down the scheduler."""
global _scheduler
if _scheduler is not None:
_scheduler.shutdown(wait=False)
_scheduler = None
logger.info("Uptime monitoring scheduler stopped")

59
app/services/uptime.py Normal file
View file

@ -0,0 +1,59 @@
"""Uptime monitoring service — performs HTTP health checks."""
import time
from datetime import datetime
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.models import Monitor, MonitorResult
async def check_monitor(monitor: Monitor, db: AsyncSession) -> MonitorResult:
"""Perform a single HTTP health check for a monitor and store the result."""
start = time.monotonic()
try:
async with httpx.AsyncClient() as client:
response = await client.request(
method=monitor.method,
url=monitor.url,
timeout=monitor.timeout_seconds,
follow_redirects=True,
)
elapsed_ms = int((time.monotonic() - start) * 1000)
if response.status_code == monitor.expected_status:
# Check response time threshold for "degraded"
status = "up" if elapsed_ms < 5000 else "degraded"
result = MonitorResult(
monitor_id=monitor.id,
status=status,
response_time_ms=elapsed_ms,
status_code=response.status_code,
error_message=None,
checked_at=datetime.utcnow(),
)
else:
result = MonitorResult(
monitor_id=monitor.id,
status="down",
response_time_ms=elapsed_ms,
status_code=response.status_code,
error_message=f"Expected {monitor.expected_status}, got {response.status_code}",
checked_at=datetime.utcnow(),
)
except Exception as exc:
elapsed_ms = int((time.monotonic() - start) * 1000)
result = MonitorResult(
monitor_id=monitor.id,
status="down",
response_time_ms=elapsed_ms,
status_code=None,
error_message=str(exc)[:500],
checked_at=datetime.utcnow(),
)
db.add(result)
await db.flush()
return result

172
app/static/css/style.css Normal file
View file

@ -0,0 +1,172 @@
/* Indie Status Page — Minimal responsive styles */
:root {
--accent: #4f46e5;
--up: #16a34a;
--down: #dc2626;
--degraded: #f59e0b;
--bg: #f9fafb;
--text: #111827;
--border: #e5e7eb;
--card-bg: #ffffff;
}
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
background: var(--bg);
color: var(--text);
line-height: 1.6;
}
.container {
max-width: 720px;
margin: 0 auto;
padding: 0 1.5rem;
}
header {
border-bottom: 1px solid var(--border);
padding: 1.5rem 0;
}
header h1 a {
color: var(--text);
text-decoration: none;
font-size: 1.5rem;
}
main {
padding: 2rem 0;
min-height: 70vh;
}
footer {
border-top: 1px solid var(--border);
padding: 1.5rem 0;
color: #6b7280;
font-size: 0.85rem;
}
/* Status banners */
.status-banner {
padding: 1rem;
border-radius: 8px;
text-align: center;
font-weight: 600;
margin-bottom: 2rem;
}
.status-banner--operational {
background: #d1fae5;
color: #065f46;
}
.status-banner--major {
background: #fee2e2;
color: #991b1b;
}
/* Service rows */
.service-row {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.75rem 1rem;
background: var(--card-bg);
border: 1px solid var(--border);
border-radius: 6px;
margin-bottom: 0.5rem;
}
.service-name {
font-weight: 500;
}
.service-status {
font-size: 0.8rem;
font-weight: 600;
padding: 0.25rem 0.75rem;
border-radius: 9999px;
}
.status-up { background: #d1fae5; color: #065f46; }
.status-down { background: #fee2e2; color: #991b1b; }
.status-degraded { background: #fef3c7; color: #92400e; }
/* Severity badges */
.severity {
font-size: 0.75rem;
font-weight: 600;
padding: 0.15rem 0.5rem;
border-radius: 4px;
margin-right: 0.5rem;
}
.severity-minor { background: #dbeafe; color: #1e40af; }
.severity-major { background: #fef3c7; color: #92400e; }
.severity-outage { background: #fee2e2; color: #991b1b; }
/* Incident cards */
.incident-card {
background: var(--card-bg);
border: 1px solid var(--border);
border-radius: 6px;
padding: 1rem;
margin-bottom: 1rem;
}
.incident-card h3 a { color: var(--text); }
.incident-card h3 a:hover { color: var(--accent); }
.incident-card .timestamp { color: #6b7280; font-size: 0.85rem; }
/* Subscribe form */
.subscribe, .subscribe-page {
margin-top: 2rem;
padding: 1.5rem;
background: var(--card-bg);
border: 1px solid var(--border);
border-radius: 8px;
}
form {
display: flex;
gap: 0.5rem;
margin-top: 0.5rem;
}
input[type="email"] {
flex: 1;
padding: 0.5rem 0.75rem;
border: 1px solid var(--border);
border-radius: 4px;
font-size: 0.95rem;
}
button[type="submit"] {
padding: 0.5rem 1.25rem;
background: var(--accent);
color: white;
border: none;
border-radius: 4px;
font-weight: 600;
cursor: pointer;
}
button[type="submit"]:hover { opacity: 0.9; }
/* Timeline */
.timeline-entry {
padding: 1rem 0;
border-left: 3px solid var(--border);
padding-left: 1.5rem;
margin-left: 0.5rem;
}
.timeline-status { font-weight: 600; }
.timeline-body { margin: 0.25rem 0; }
.timeline-time { color: #6b7280; font-size: 0.85rem; }
/* Confirm page */
.confirm-page { text-align: center; padding: 3rem 0; }

22
app/static/js/status.js Normal file
View file

@ -0,0 +1,22 @@
/* Minimal JS for auto-refreshing the status page every 60 seconds */
(function () {
const REFRESH_INTERVAL = 60000;
function autoRefresh() {
setTimeout(function () {
fetch(window.location.href, { headers: { "X-Requested-With": "XMLHttpRequest" } })
.then(function () {
window.location.reload();
})
.catch(function () {
// Silently fail — the page will try again next interval
});
}, REFRESH_INTERVAL);
}
if (document.readyState === "loading") {
document.addEventListener("DOMContentLoaded", autoRefresh);
} else {
autoRefresh();
}
})();

22
app/templates/base.html Normal file
View file

@ -0,0 +1,22 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{% block title %}Status{% endblock %} — {{ site_name }}</title>
<link rel="stylesheet" href="/static/css/style.css">
</head>
<body>
<header>
<div class="container">
<h1><a href="/">{{ site_name }}</a></h1>
</div>
</header>
<main class="container">
{% block content %}{% endblock %}
</main>
<footer class="container">
<p>&copy; {{ now.year }} {{ site_name }}. Powered by Indie Status Page.</p>
</footer>
</body>
</html>

View file

@ -0,0 +1,9 @@
{% extends "base.html" %}
{% block title %}Confirmed{% endblock %}
{% block content %}
<div class="confirm-page">
<h1>✅ Subscription Confirmed</h1>
<p>You're now subscribed to status updates. You'll receive notifications when incidents occur.</p>
<a href="/">← Back to Status Page</a>
</div>
{% endblock %}

View file

@ -0,0 +1,26 @@
{% extends "base.html" %}
{% block title %}Incident: {{ incident.title }}{% endblock %}
{% block content %}
<div class="incident-detail">
<a href="/">← Back to Status Page</a>
<h1>{{ incident.title }}</h1>
<div class="incident-meta">
<span class="severity severity-{{ incident.severity }}">{{ incident.severity | title }}</span>
<span class="incident-status">{{ incident.status | title }}</span>
<p>Started: {{ incident.started_at }}</p>
{% if incident.resolved_at %}
<p>Resolved: {{ incident.resolved_at }}</p>
{% endif %}
</div>
<div class="timeline">
{% for update in updates %}
<div class="timeline-entry">
<div class="timeline-status">{{ update.status | title }}</div>
<div class="timeline-body">{{ update.body }}</div>
<div class="timeline-time">{{ update.created_at }}</div>
</div>
{% endfor %}
</div>
</div>
{% endblock %}

49
app/templates/status.html Normal file
View file

@ -0,0 +1,49 @@
{% extends "base.html" %}
{% block title %}Status{% endblock %}
{% block content %}
<div class="status-page">
<section class="overall-status">
{% if has_active_incidents %}
<div class="status-banner status-banner--major">Active Incident</div>
{% else %}
<div class="status-banner status-banner--operational">All Systems Operational</div>
{% endif %}
</section>
<section class="services">
{% for group_name, group_services in services_by_group.items() %}
<h2>{{ group_name or "Services" }}</h2>
{% for service in group_services %}
<div class="service-row">
<span class="service-name">{{ service.name }}</span>
<span class="service-status {% if service.current_status == 'up' %}status-up{% elif service.current_status == 'down' %}status-down{% else %}status-degraded{% endif %}">
{{ service.current_status | title }}
</span>
</div>
{% endfor %}
{% endfor %}
</section>
{% if incidents %}
<section class="incidents">
<h2>Recent Incidents</h2>
{% for incident in incidents %}
<div class="incident-card">
<h3><a href="/incident/{{ incident.id }}">{{ incident.title }}</a></h3>
<span class="severity severity-{{ incident.severity }}">{{ incident.severity | title }}</span>
<span class="incident-status">{{ incident.status | title }}</span>
<p class="timestamp">{{ incident.started_at }}</p>
</div>
{% endfor %}
</section>
{% endif %}
<section class="subscribe">
<p>Get notified when incidents occur:</p>
<form action="/subscribe" method="POST">
<input type="email" name="email" placeholder="you@example.com" required>
<button type="submit">Subscribe</button>
</form>
</section>
</div>
{% endblock %}

View file

@ -0,0 +1,12 @@
{% extends "base.html" %}
{% block title %}Subscribe{% endblock %}
{% block content %}
<div class="subscribe-page">
<h1>Subscribe to Updates</h1>
<p>Get email notifications when we create or update incidents.</p>
<form action="/subscribe" method="POST">
<input type="email" name="email" placeholder="you@example.com" required>
<button type="submit">Subscribe</button>
</form>
</div>
{% endblock %}

23
docker-compose.yml Normal file
View file

@ -0,0 +1,23 @@
version: "3.8"
services:
app:
build: .
ports:
- "8000:8000"
environment:
- DATABASE_URL=sqlite+aiosqlite:///./data/statuspage.db
- SECRET_KEY=${SECRET_KEY:-change-me-to-a-random-string}
- ADMIN_API_KEY=${ADMIN_API_KEY:-change-me-to-a-secure-api-key}
- SITE_NAME=${SITE_NAME:-My SaaS Status}
- SITE_URL=${SITE_URL:-http://localhost:8000}
- SMTP_HOST=${SMTP_HOST:-}
- SMTP_PORT=${SMTP_PORT:-587}
- SMTP_USER=${SMTP_USER:-}
- SMTP_PASS=${SMTP_PASS:-}
- SMTP_FROM=${SMTP_FROM:-noreply@example.com}
volumes:
- statuspage-data:/app/data
volumes:
statuspage-data:

0
migrations/__init__.py Normal file
View file

View file

57
pyproject.toml Normal file
View file

@ -0,0 +1,57 @@
[project]
name = "indie-status-page"
version = "0.1.0"
description = "Lightweight, affordable status page tool for indie SaaS developers"
readme = "README.md"
requires-python = ">=3.11"
license = {text = "MIT"}
dependencies = [
"fastapi>=0.110,<1.0",
"uvicorn[standard]>=0.27,<1.0",
"starlette>=0.36,<1.0",
"sqlalchemy[asyncio]>=2.0,<3.0",
"aiosqlite>=0.19,<1.0",
"alembic>=1.13,<2.0",
"pydantic>=2.5,<3.0",
"pydantic-settings>=2.1,<3.0",
"python-multipart>=0.0.6,<1.0",
"jinja2>=3.1,<4.0",
"apscheduler>=3.10,<4.0",
"httpx>=0.27,<1.0",
"typer>=0.9,<1.0",
"rich>=13.0,<14.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0,<9.0",
"pytest-asyncio>=0.23,<1.0",
"httpx>=0.27,<1.0",
"ruff>=0.2,<1.0",
"mypy>=1.8,<2.0",
]
[project.scripts]
statuspage = "app.cli:app"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["app"]
[tool.ruff]
target-version = "py311"
line-length = 100
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
[tool.mypy]
python_version = "3.11"
strict = true

0
tests/__init__.py Normal file
View file

56
tests/conftest.py Normal file
View file

@ -0,0 +1,56 @@
"""Test fixtures for Indie Status Page."""
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.database import Base
from app.dependencies import get_db
from app.main import app
# Use an in-memory SQLite for tests
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
TestSessionLocal = async_sessionmaker(
test_engine, class_=AsyncSession, expire_on_commit=False
)
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database():
"""Create all tables once for the test session."""
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture
async def db_session():
"""Provide a clean database session for each test."""
async with TestSessionLocal() as session:
yield session
await session.rollback()
@pytest_asyncio.fixture
async def client(db_session: AsyncSession):
"""Provide an HTTP test client with DB dependency override."""
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def api_key():
"""Return the test API key."""
from app.config import settings
return settings.admin_api_key

154
tests/test_api_incidents.py Normal file
View file

@ -0,0 +1,154 @@
"""Test Incidents API endpoints."""
import pytest
@pytest.mark.asyncio
async def test_create_incident(client, api_key):
"""Should create a new incident after creating a service."""
# Create a service first
svc = await client.post(
"/api/v1/services/",
json={"name": "Auth Service", "slug": "auth-service"},
headers={"X-API-Key": api_key},
)
assert svc.status_code == 201
service_id = svc.json()["id"]
# Create an incident
response = await client.post(
"/api/v1/incidents/",
json={
"service_id": service_id,
"title": "Login failures",
"status": "investigating",
"severity": "major",
},
headers={"X-API-Key": api_key},
)
assert response.status_code == 201
data = response.json()
assert data["title"] == "Login failures"
assert data["status"] == "investigating"
@pytest.mark.asyncio
async def test_list_incidents(client, api_key):
"""Should list incidents."""
response = await client.get("/api/v1/incidents/")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
@pytest.mark.asyncio
async def test_get_incident_with_updates(client, api_key):
"""Should get an incident and add an update."""
# Create a service
svc = await client.post(
"/api/v1/services/",
json={"name": "Storage", "slug": "storage"},
headers={"X-API-Key": api_key},
)
service_id = svc.json()["id"]
# Create an incident
inc = await client.post(
"/api/v1/incidents/",
json={
"service_id": service_id,
"title": "Data loss",
"status": "investigating",
"severity": "outage",
},
headers={"X-API-Key": api_key},
)
incident_id = inc.json()["id"]
# Add an update
upd = await client.post(
f"/api/v1/incidents/{incident_id}/updates",
json={
"status": "identified",
"body": "We found the root cause.",
},
headers={"X-API-Key": api_key},
)
assert upd.status_code == 201
assert upd.json()["status"] == "identified"
# Get the incident - should include updates
get_resp = await client.get(f"/api/v1/incidents/{incident_id}")
assert get_resp.status_code == 200
data = get_resp.json()
assert data["status"] == "identified"
assert len(data["updates"]) == 1
assert data["updates"][0]["body"] == "We found the root cause."
@pytest.mark.asyncio
async def test_update_incident(client, api_key):
"""Should update incident status via PATCH."""
# Create service + incident
svc = await client.post(
"/api/v1/services/",
json={"name": "Cache", "slug": "cache"},
headers={"X-API-Key": api_key},
)
service_id = svc.json()["id"]
inc = await client.post(
"/api/v1/incidents/",
json={
"service_id": service_id,
"title": "Slow responses",
"status": "investigating",
"severity": "minor",
},
headers={"X-API-Key": api_key},
)
incident_id = inc.json()["id"]
# Patch the incident
patch_resp = await client.patch(
f"/api/v1/incidents/{incident_id}",
json={"status": "resolved"},
headers={"X-API-Key": api_key},
)
assert patch_resp.status_code == 200
data = patch_resp.json()
assert data["status"] == "resolved"
assert data["resolved_at"] is not None
@pytest.mark.asyncio
async def test_delete_incident(client, api_key):
"""Should delete an incident."""
svc = await client.post(
"/api/v1/services/",
json={"name": "Email", "slug": "email-svc"},
headers={"X-API-Key": api_key},
)
service_id = svc.json()["id"]
inc = await client.post(
"/api/v1/incidents/",
json={
"service_id": service_id,
"title": "Emails delayed",
"status": "monitoring",
"severity": "minor",
},
headers={"X-API-Key": api_key},
)
incident_id = inc.json()["id"]
del_resp = await client.delete(
f"/api/v1/incidents/{incident_id}",
headers={"X-API-Key": api_key},
)
assert del_resp.status_code == 204
# Verify it's gone
get_resp = await client.get(f"/api/v1/incidents/{incident_id}")
assert get_resp.status_code == 404

119
tests/test_api_services.py Normal file
View file

@ -0,0 +1,119 @@
"""Test Services API endpoints."""
import pytest
@pytest.mark.asyncio
async def test_create_service(client, api_key):
"""Should create a new service."""
response = await client.post(
"/api/v1/services/",
json={
"name": "API Server",
"slug": "api-server",
},
headers={"X-API-Key": api_key},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "API Server"
assert data["slug"] == "api-server"
assert data["is_visible"] is True
@pytest.mark.asyncio
async def test_list_services(client, api_key):
"""Should list all services."""
# Create a service first
await client.post(
"/api/v1/services/",
json={"name": "Database", "slug": "database"},
headers={"X-API-Key": api_key},
)
response = await client.get("/api/v1/services/")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) >= 1
@pytest.mark.asyncio
async def test_api_key_required(client):
"""Should reject requests without API key for protected routes."""
response = await client.post(
"/api/v1/services/",
json={"name": "Unauthorized", "slug": "unauth"},
)
assert response.status_code == 422 # Missing required header
@pytest.mark.asyncio
async def test_get_service(client, api_key):
"""Should get a single service by ID."""
create_resp = await client.post(
"/api/v1/services/",
json={"name": "Web App", "slug": "web-app"},
headers={"X-API-Key": api_key},
)
assert create_resp.status_code == 201
service_id = create_resp.json()["id"]
response = await client.get(f"/api/v1/services/{service_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == service_id
assert data["name"] == "Web App"
@pytest.mark.asyncio
async def test_update_service(client, api_key):
"""Should update a service with PATCH."""
create_resp = await client.post(
"/api/v1/services/",
json={"name": "Old Name", "slug": "old-slug"},
headers={"X-API-Key": api_key},
)
service_id = create_resp.json()["id"]
response = await client.patch(
f"/api/v1/services/{service_id}",
json={"name": "New Name"},
headers={"X-API-Key": api_key},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "New Name"
assert data["slug"] == "old-slug" # unchanged
@pytest.mark.asyncio
async def test_delete_service(client, api_key):
"""Should delete a service."""
create_resp = await client.post(
"/api/v1/services/",
json={"name": "Delete Me", "slug": "delete-me"},
headers={"X-API-Key": api_key},
)
service_id = create_resp.json()["id"]
response = await client.delete(
f"/api/v1/services/{service_id}",
headers={"X-API-Key": api_key},
)
assert response.status_code == 204
# Verify it's gone
get_resp = await client.get(f"/api/v1/services/{service_id}")
assert get_resp.status_code == 404
@pytest.mark.asyncio
async def test_update_service_not_found(client, api_key):
"""Should return 404 when updating nonexistent service."""
response = await client.patch(
"/api/v1/services/00000000-0000-0000-0000-000000000000",
json={"name": "Ghost"},
headers={"X-API-Key": api_key},
)
assert response.status_code == 404

13
tests/test_health.py Normal file
View file

@ -0,0 +1,13 @@
"""Test the health check endpoint."""
import pytest
@pytest.mark.asyncio
async def test_health_check(client):
"""Health check should return 200 with status ok."""
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert "version" in data

247
venv/bin/Activate.ps1 Normal file
View file

@ -0,0 +1,247 @@
<#
.Synopsis
Activate a Python virtual environment for the current PowerShell session.
.Description
Pushes the python executable for a virtual environment to the front of the
$Env:PATH environment variable and sets the prompt to signify that you are
in a Python virtual environment. Makes use of the command line switches as
well as the `pyvenv.cfg` file values present in the virtual environment.
.Parameter VenvDir
Path to the directory that contains the virtual environment to activate. The
default value for this is the parent of the directory that the Activate.ps1
script is located within.
.Parameter Prompt
The prompt prefix to display when this virtual environment is activated. By
default, this prompt is the name of the virtual environment folder (VenvDir)
surrounded by parentheses and followed by a single space (ie. '(.venv) ').
.Example
Activate.ps1
Activates the Python virtual environment that contains the Activate.ps1 script.
.Example
Activate.ps1 -Verbose
Activates the Python virtual environment that contains the Activate.ps1 script,
and shows extra information about the activation as it executes.
.Example
Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
Activates the Python virtual environment located in the specified location.
.Example
Activate.ps1 -Prompt "MyPython"
Activates the Python virtual environment that contains the Activate.ps1 script,
and prefixes the current prompt with the specified string (surrounded in
parentheses) while the virtual environment is active.
.Notes
On Windows, it may be required to enable this Activate.ps1 script by setting the
execution policy for the user. You can do this by issuing the following PowerShell
command:
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
For more information on Execution Policies:
https://go.microsoft.com/fwlink/?LinkID=135170
#>
Param(
[Parameter(Mandatory = $false)]
[String]
$VenvDir,
[Parameter(Mandatory = $false)]
[String]
$Prompt
)
<# Function declarations --------------------------------------------------- #>
<#
.Synopsis
Remove all shell session elements added by the Activate script, including the
addition of the virtual environment's Python executable from the beginning of
the PATH variable.
.Parameter NonDestructive
If present, do not remove this function from the global namespace for the
session.
#>
function global:deactivate ([switch]$NonDestructive) {
# Revert to original values
# The prior prompt:
if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
}
# The prior PYTHONHOME:
if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
}
# The prior PATH:
if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
Remove-Item -Path Env:_OLD_VIRTUAL_PATH
}
# Just remove the VIRTUAL_ENV altogether:
if (Test-Path -Path Env:VIRTUAL_ENV) {
Remove-Item -Path env:VIRTUAL_ENV
}
# Just remove VIRTUAL_ENV_PROMPT altogether.
if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) {
Remove-Item -Path env:VIRTUAL_ENV_PROMPT
}
# Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
}
# Leave deactivate function in the global namespace if requested:
if (-not $NonDestructive) {
Remove-Item -Path function:deactivate
}
}
<#
.Description
Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
given folder, and returns them in a map.
For each line in the pyvenv.cfg file, if that line can be parsed into exactly
two strings separated by `=` (with any amount of whitespace surrounding the =)
then it is considered a `key = value` line. The left hand string is the key,
the right hand is the value.
If the value starts with a `'` or a `"` then the first and last character is
stripped from the value before being captured.
.Parameter ConfigDir
Path to the directory that contains the `pyvenv.cfg` file.
#>
function Get-PyVenvConfig(
[String]
$ConfigDir
) {
Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
# Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
$pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
# An empty map will be returned if no config file is found.
$pyvenvConfig = @{ }
if ($pyvenvConfigPath) {
Write-Verbose "File exists, parse `key = value` lines"
$pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
$pyvenvConfigContent | ForEach-Object {
$keyval = $PSItem -split "\s*=\s*", 2
if ($keyval[0] -and $keyval[1]) {
$val = $keyval[1]
# Remove extraneous quotations around a string value.
if ("'""".Contains($val.Substring(0, 1))) {
$val = $val.Substring(1, $val.Length - 2)
}
$pyvenvConfig[$keyval[0]] = $val
Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
}
}
}
return $pyvenvConfig
}
<# Begin Activate script --------------------------------------------------- #>
# Determine the containing directory of this script
$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
$VenvExecDir = Get-Item -Path $VenvExecPath
Write-Verbose "Activation script is located in path: '$VenvExecPath'"
Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
# Set values required in priority: CmdLine, ConfigFile, Default
# First, get the location of the virtual environment, it might not be
# VenvExecDir if specified on the command line.
if ($VenvDir) {
Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
}
else {
Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
$VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
Write-Verbose "VenvDir=$VenvDir"
}
# Next, read the `pyvenv.cfg` file to determine any required value such
# as `prompt`.
$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
# Next, set the prompt from the command line, or the config file, or
# just use the name of the virtual environment folder.
if ($Prompt) {
Write-Verbose "Prompt specified as argument, using '$Prompt'"
}
else {
Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
$Prompt = $pyvenvCfg['prompt'];
}
else {
Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
$Prompt = Split-Path -Path $venvDir -Leaf
}
}
Write-Verbose "Prompt = '$Prompt'"
Write-Verbose "VenvDir='$VenvDir'"
# Deactivate any currently active virtual environment, but leave the
# deactivate function in place.
deactivate -nondestructive
# Now set the environment variable VIRTUAL_ENV, used by many tools to determine
# that there is an activated venv.
$env:VIRTUAL_ENV = $VenvDir
if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
Write-Verbose "Setting prompt to '$Prompt'"
# Set the prompt to include the env name
# Make sure _OLD_VIRTUAL_PROMPT is global
function global:_OLD_VIRTUAL_PROMPT { "" }
Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
function global:prompt {
Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
_OLD_VIRTUAL_PROMPT
}
$env:VIRTUAL_ENV_PROMPT = $Prompt
}
# Clear PYTHONHOME
if (Test-Path -Path Env:PYTHONHOME) {
Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
Remove-Item -Path Env:PYTHONHOME
}
# Add the venv to the PATH
Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"

63
venv/bin/activate Normal file
View file

@ -0,0 +1,63 @@
# This file must be used with "source bin/activate" *from bash*
# you cannot run it directly
deactivate () {
# reset old environment variables
if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
PATH="${_OLD_VIRTUAL_PATH:-}"
export PATH
unset _OLD_VIRTUAL_PATH
fi
if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
export PYTHONHOME
unset _OLD_VIRTUAL_PYTHONHOME
fi
# Call hash to forget past commands. Without forgetting
# past commands the $PATH changes we made may not be respected
hash -r 2> /dev/null
if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
PS1="${_OLD_VIRTUAL_PS1:-}"
export PS1
unset _OLD_VIRTUAL_PS1
fi
unset VIRTUAL_ENV
unset VIRTUAL_ENV_PROMPT
if [ ! "${1:-}" = "nondestructive" ] ; then
# Self destruct!
unset -f deactivate
fi
}
# unset irrelevant variables
deactivate nondestructive
VIRTUAL_ENV=/home/ubuntu/wealth-engine/indie-status-page/venv
export VIRTUAL_ENV
_OLD_VIRTUAL_PATH="$PATH"
PATH="$VIRTUAL_ENV/"bin":$PATH"
export PATH
# unset PYTHONHOME if set
# this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
# could use `if (set -u; : $PYTHONHOME) ;` in bash
if [ -n "${PYTHONHOME:-}" ] ; then
_OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
unset PYTHONHOME
fi
if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
_OLD_VIRTUAL_PS1="${PS1:-}"
PS1='(venv) '"${PS1:-}"
export PS1
VIRTUAL_ENV_PROMPT='(venv) '
export VIRTUAL_ENV_PROMPT
fi
# Call hash to forget past commands. Without forgetting
# past commands the $PATH changes we made may not be respected
hash -r 2> /dev/null

26
venv/bin/activate.csh Normal file
View file

@ -0,0 +1,26 @@
# This file must be used with "source bin/activate.csh" *from csh*.
# You cannot run it directly.
# Created by Davide Di Blasi <davidedb@gmail.com>.
# Ported to Python 3.3 venv by Andrew Svetlov <andrew.svetlov@gmail.com>
alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate'
# Unset irrelevant variables.
deactivate nondestructive
setenv VIRTUAL_ENV /home/ubuntu/wealth-engine/indie-status-page/venv
set _OLD_VIRTUAL_PATH="$PATH"
setenv PATH "$VIRTUAL_ENV/"bin":$PATH"
set _OLD_VIRTUAL_PROMPT="$prompt"
if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
set prompt = '(venv) '"$prompt"
setenv VIRTUAL_ENV_PROMPT '(venv) '
endif
alias pydoc python -m pydoc
rehash

69
venv/bin/activate.fish Normal file
View file

@ -0,0 +1,69 @@
# This file must be used with "source <venv>/bin/activate.fish" *from fish*
# (https://fishshell.com/); you cannot run it directly.
function deactivate -d "Exit virtual environment and return to normal shell environment"
# reset old environment variables
if test -n "$_OLD_VIRTUAL_PATH"
set -gx PATH $_OLD_VIRTUAL_PATH
set -e _OLD_VIRTUAL_PATH
end
if test -n "$_OLD_VIRTUAL_PYTHONHOME"
set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
set -e _OLD_VIRTUAL_PYTHONHOME
end
if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
set -e _OLD_FISH_PROMPT_OVERRIDE
# prevents error when using nested fish instances (Issue #93858)
if functions -q _old_fish_prompt
functions -e fish_prompt
functions -c _old_fish_prompt fish_prompt
functions -e _old_fish_prompt
end
end
set -e VIRTUAL_ENV
set -e VIRTUAL_ENV_PROMPT
if test "$argv[1]" != "nondestructive"
# Self-destruct!
functions -e deactivate
end
end
# Unset irrelevant variables.
deactivate nondestructive
set -gx VIRTUAL_ENV /home/ubuntu/wealth-engine/indie-status-page/venv
set -gx _OLD_VIRTUAL_PATH $PATH
set -gx PATH "$VIRTUAL_ENV/"bin $PATH
# Unset PYTHONHOME if set.
if set -q PYTHONHOME
set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
set -e PYTHONHOME
end
if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
# fish uses a function instead of an env var to generate the prompt.
# Save the current fish_prompt function as the function _old_fish_prompt.
functions -c fish_prompt _old_fish_prompt
# With the original prompt function renamed, we can override with our own.
function fish_prompt
# Save the return status of the last command.
set -l old_status $status
# Output the venv prompt; color taken from the blue of the Python logo.
printf "%s%s%s" (set_color 4B8BBE) '(venv) ' (set_color normal)
# Restore the return status of the previous command.
echo "exit $old_status" | .
# Output the original/"old" prompt.
_old_fish_prompt
end
set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
set -gx VIRTUAL_ENV_PROMPT '(venv) '
end

8
venv/bin/alembic Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from alembic.config import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/dmypy Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from mypy.dmypy.client import console_entry
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(console_entry())

8
venv/bin/dotenv Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from dotenv.__main__ import cli
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(cli())

8
venv/bin/fastapi Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from fastapi.cli import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/httpx Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from httpx import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/mako-render Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from mako.cmd import cmdline
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(cmdline())

8
venv/bin/markdown-it Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from markdown_it.cli.parse import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/mypy Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from mypy.__main__ import console_entry
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(console_entry())

8
venv/bin/mypyc Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from mypyc.__main__ import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/pip Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/pip3 Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/pip3.11 Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/py.test Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from pytest import console_main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(console_main())

8
venv/bin/pygmentize Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from pygments.cmdline import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/pytest Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from pytest import console_main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(console_main())

1
venv/bin/python Symbolic link
View file

@ -0,0 +1 @@
python3.11

1
venv/bin/python3 Symbolic link
View file

@ -0,0 +1 @@
python3.11

1
venv/bin/python3.11 Symbolic link
View file

@ -0,0 +1 @@
/home/ubuntu/.local/share/uv/python/cpython-3.11-linux-x86_64-gnu/bin/python3.11

BIN
venv/bin/ruff Executable file

Binary file not shown.

8
venv/bin/statuspage Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from app.cli import app
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(app())

8
venv/bin/stubgen Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from mypy.stubgen import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/stubtest Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from mypy.stubtest import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/typer Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from typer.cli import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/uvicorn Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from uvicorn.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
venv/bin/watchfiles Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from watchfiles.cli import cli
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(cli())

8
venv/bin/websockets Executable file
View file

@ -0,0 +1,8 @@
#!/home/ubuntu/wealth-engine/indie-status-page/venv/bin/python3.11
# -*- coding: utf-8 -*-
import re
import sys
from websockets.cli import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

View file

@ -0,0 +1,164 @@
/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
/* Greenlet object interface */
#ifndef Py_GREENLETOBJECT_H
#define Py_GREENLETOBJECT_H
#include <Python.h>
#ifdef __cplusplus
extern "C" {
#endif
/* This is deprecated and undocumented. It does not change. */
#define GREENLET_VERSION "1.0.0"
#ifndef GREENLET_MODULE
#define implementation_ptr_t void*
#endif
typedef struct _greenlet {
PyObject_HEAD
PyObject* weakreflist;
PyObject* dict;
implementation_ptr_t pimpl;
} PyGreenlet;
#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type))
/* C API functions */
/* Total number of symbols that are exported */
#define PyGreenlet_API_pointers 12
#define PyGreenlet_Type_NUM 0
#define PyExc_GreenletError_NUM 1
#define PyExc_GreenletExit_NUM 2
#define PyGreenlet_New_NUM 3
#define PyGreenlet_GetCurrent_NUM 4
#define PyGreenlet_Throw_NUM 5
#define PyGreenlet_Switch_NUM 6
#define PyGreenlet_SetParent_NUM 7
#define PyGreenlet_MAIN_NUM 8
#define PyGreenlet_STARTED_NUM 9
#define PyGreenlet_ACTIVE_NUM 10
#define PyGreenlet_GET_PARENT_NUM 11
#ifndef GREENLET_MODULE
/* This section is used by modules that uses the greenlet C API */
static void** _PyGreenlet_API = NULL;
# define PyGreenlet_Type \
(*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
# define PyExc_GreenletError \
((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
# define PyExc_GreenletExit \
((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
/*
* PyGreenlet_New(PyObject *args)
*
* greenlet.greenlet(run, parent=None)
*/
# define PyGreenlet_New \
(*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
_PyGreenlet_API[PyGreenlet_New_NUM])
/*
* PyGreenlet_GetCurrent(void)
*
* greenlet.getcurrent()
*/
# define PyGreenlet_GetCurrent \
(*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
/*
* PyGreenlet_Throw(
* PyGreenlet *greenlet,
* PyObject *typ,
* PyObject *val,
* PyObject *tb)
*
* g.throw(...)
*/
# define PyGreenlet_Throw \
(*(PyObject * (*)(PyGreenlet * self, \
PyObject * typ, \
PyObject * val, \
PyObject * tb)) \
_PyGreenlet_API[PyGreenlet_Throw_NUM])
/*
* PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
*
* g.switch(*args, **kwargs)
*/
# define PyGreenlet_Switch \
(*(PyObject * \
(*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
_PyGreenlet_API[PyGreenlet_Switch_NUM])
/*
* PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
*
* g.parent = new_parent
*/
# define PyGreenlet_SetParent \
(*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
_PyGreenlet_API[PyGreenlet_SetParent_NUM])
/*
* PyGreenlet_GetParent(PyObject* greenlet)
*
* return greenlet.parent;
*
* This could return NULL even if there is no exception active.
* If it does not return NULL, you are responsible for decrementing the
* reference count.
*/
# define PyGreenlet_GetParent \
(*(PyGreenlet* (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_GET_PARENT_NUM])
/*
* deprecated, undocumented alias.
*/
# define PyGreenlet_GET_PARENT PyGreenlet_GetParent
# define PyGreenlet_MAIN \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_MAIN_NUM])
# define PyGreenlet_STARTED \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_STARTED_NUM])
# define PyGreenlet_ACTIVE \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_ACTIVE_NUM])
/* Macro that imports greenlet and initializes C API */
/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
keep the older definition to be sure older code that might have a copy of
the header still works. */
# define PyGreenlet_Import() \
{ \
_PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
}
#endif /* GREENLET_MODULE */
#ifdef __cplusplus
}
#endif
#endif /* !Py_GREENLETOBJECT_H */

View file

@ -0,0 +1,239 @@
# don't import any costly modules
import os
import sys
report_url = (
"https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml"
)
def warn_distutils_present():
if 'distutils' not in sys.modules:
return
import warnings
warnings.warn(
"Distutils was imported before Setuptools, but importing Setuptools "
"also replaces the `distutils` module in `sys.modules`. This may lead "
"to undesirable behaviors or errors. To avoid these issues, avoid "
"using distutils directly, ensure that setuptools is installed in the "
"traditional way (e.g. not an editable install), and/or make sure "
"that setuptools is always imported before distutils."
)
def clear_distutils():
if 'distutils' not in sys.modules:
return
import warnings
warnings.warn(
"Setuptools is replacing distutils. Support for replacing "
"an already imported distutils is deprecated. In the future, "
"this condition will fail. "
f"Register concerns at {report_url}"
)
mods = [
name
for name in sys.modules
if name == "distutils" or name.startswith("distutils.")
]
for name in mods:
del sys.modules[name]
def enabled():
"""
Allow selection of distutils by environment variable.
"""
which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'local')
if which == 'stdlib':
import warnings
warnings.warn(
"Reliance on distutils from stdlib is deprecated. Users "
"must rely on setuptools to provide the distutils module. "
"Avoid importing distutils or import setuptools first, "
"and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. "
f"Register concerns at {report_url}"
)
return which == 'local'
def ensure_local_distutils():
import importlib
clear_distutils()
# With the DistutilsMetaFinder in place,
# perform an import to cause distutils to be
# loaded from setuptools._distutils. Ref #2906.
with shim():
importlib.import_module('distutils')
# check that submodules load as expected
core = importlib.import_module('distutils.core')
assert '_distutils' in core.__file__, core.__file__
assert 'setuptools._distutils.log' not in sys.modules
def do_override():
"""
Ensure that the local copy of distutils is preferred over stdlib.
See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401
for more motivation.
"""
if enabled():
warn_distutils_present()
ensure_local_distutils()
class _TrivialRe:
def __init__(self, *patterns) -> None:
self._patterns = patterns
def match(self, string):
return all(pat in string for pat in self._patterns)
class DistutilsMetaFinder:
def find_spec(self, fullname, path, target=None):
# optimization: only consider top level modules and those
# found in the CPython test suite.
if path is not None and not fullname.startswith('test.'):
return None
method_name = 'spec_for_{fullname}'.format(**locals())
method = getattr(self, method_name, lambda: None)
return method()
def spec_for_distutils(self):
if self.is_cpython():
return None
import importlib
import importlib.abc
import importlib.util
try:
mod = importlib.import_module('setuptools._distutils')
except Exception:
# There are a couple of cases where setuptools._distutils
# may not be present:
# - An older Setuptools without a local distutils is
# taking precedence. Ref #2957.
# - Path manipulation during sitecustomize removes
# setuptools from the path but only after the hook
# has been loaded. Ref #2980.
# In either case, fall back to stdlib behavior.
return None
class DistutilsLoader(importlib.abc.Loader):
def create_module(self, spec):
mod.__name__ = 'distutils'
return mod
def exec_module(self, module):
pass
return importlib.util.spec_from_loader(
'distutils', DistutilsLoader(), origin=mod.__file__
)
@staticmethod
def is_cpython():
"""
Suppress supplying distutils for CPython (build and tests).
Ref #2965 and #3007.
"""
return os.path.isfile('pybuilddir.txt')
def spec_for_pip(self):
"""
Ensure stdlib distutils when running under pip.
See pypa/pip#8761 for rationale.
"""
if sys.version_info >= (3, 12) or self.pip_imported_during_build():
return
clear_distutils()
self.spec_for_distutils = lambda: None
@classmethod
def pip_imported_during_build(cls):
"""
Detect if pip is being imported in a build script. Ref #2355.
"""
import traceback
return any(
cls.frame_file_is_setup(frame) for frame, line in traceback.walk_stack(None)
)
@staticmethod
def frame_file_is_setup(frame):
"""
Return True if the indicated frame suggests a setup.py file.
"""
# some frames may not have __file__ (#2940)
return frame.f_globals.get('__file__', '').endswith('setup.py')
def spec_for_sensitive_tests(self):
"""
Ensure stdlib distutils when running select tests under CPython.
python/cpython#91169
"""
clear_distutils()
self.spec_for_distutils = lambda: None
sensitive_tests = (
[
'test.test_distutils',
'test.test_peg_generator',
'test.test_importlib',
]
if sys.version_info < (3, 10)
else [
'test.test_distutils',
]
)
for name in DistutilsMetaFinder.sensitive_tests:
setattr(
DistutilsMetaFinder,
f'spec_for_{name}',
DistutilsMetaFinder.spec_for_sensitive_tests,
)
DISTUTILS_FINDER = DistutilsMetaFinder()
def add_shim():
DISTUTILS_FINDER in sys.meta_path or insert_shim()
class shim:
def __enter__(self) -> None:
insert_shim()
def __exit__(self, exc: object, value: object, tb: object) -> None:
_remove_shim()
def insert_shim():
sys.meta_path.insert(0, DISTUTILS_FINDER)
def _remove_shim():
try:
sys.meta_path.remove(DISTUTILS_FINDER)
except ValueError:
pass
if sys.version_info < (3, 12):
# DistutilsMetaFinder can only be disabled in Python < 3.12 (PEP 632)
remove_shim = _remove_shim

View file

@ -0,0 +1 @@
__import__('_distutils_hack').do_override()

View file

@ -0,0 +1 @@
/home/ubuntu/wealth-engine/indie-status-page

View file

@ -0,0 +1,13 @@
from __future__ import annotations
__all__ = ["__version__", "version_tuple"]
try:
from ._version import version as __version__
from ._version import version_tuple
except ImportError: # pragma: no cover
# broken installation, we don't even try
# unknown only works because we do poor mans version compare
__version__ = "unknown"
version_tuple = (0, 0, "unknown")

View file

@ -0,0 +1,117 @@
"""Allow bash-completion for argparse with argcomplete if installed.
Needs argcomplete>=0.5.6 for python 3.2/3.3 (older versions fail
to find the magic string, so _ARGCOMPLETE env. var is never set, and
this does not need special code).
Function try_argcomplete(parser) should be called directly before
the call to ArgumentParser.parse_args().
The filescompleter is what you normally would use on the positional
arguments specification, in order to get "dirname/" after "dirn<TAB>"
instead of the default "dirname ":
optparser.add_argument(Config._file_or_dir, nargs='*').completer=filescompleter
Other, application specific, completers should go in the file
doing the add_argument calls as they need to be specified as .completer
attributes as well. (If argcomplete is not installed, the function the
attribute points to will not be used).
SPEEDUP
=======
The generic argcomplete script for bash-completion
(/etc/bash_completion.d/python-argcomplete.sh)
uses a python program to determine startup script generated by pip.
You can speed up completion somewhat by changing this script to include
# PYTHON_ARGCOMPLETE_OK
so the python-argcomplete-check-easy-install-script does not
need to be called to find the entry point of the code and see if that is
marked with PYTHON_ARGCOMPLETE_OK.
INSTALL/DEBUGGING
=================
To include this support in another application that has setup.py generated
scripts:
- Add the line:
# PYTHON_ARGCOMPLETE_OK
near the top of the main python entry point.
- Include in the file calling parse_args():
from _argcomplete import try_argcomplete, filescompleter
Call try_argcomplete just before parse_args(), and optionally add
filescompleter to the positional arguments' add_argument().
If things do not work right away:
- Switch on argcomplete debugging with (also helpful when doing custom
completers):
export _ARC_DEBUG=1
- Run:
python-argcomplete-check-easy-install-script $(which appname)
echo $?
will echo 0 if the magic line has been found, 1 if not.
- Sometimes it helps to find early on errors using:
_ARGCOMPLETE=1 _ARC_DEBUG=1 appname
which should throw a KeyError: 'COMPLINE' (which is properly set by the
global argcomplete script).
"""
from __future__ import annotations
import argparse
from glob import glob
import os
import sys
from typing import Any
class FastFilesCompleter:
"""Fast file completer class."""
def __init__(self, directories: bool = True) -> None:
self.directories = directories
def __call__(self, prefix: str, **kwargs: Any) -> list[str]:
# Only called on non option completions.
if os.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.sep)
else:
prefix_dir = 0
completion = []
globbed = []
if "*" not in prefix and "?" not in prefix:
# We are on unix, otherwise no bash.
if not prefix or prefix[-1] == os.sep:
globbed.extend(glob(prefix + ".*"))
prefix += "*"
globbed.extend(glob(prefix))
for x in sorted(globbed):
if os.path.isdir(x):
x += "/"
# Append stripping the prefix (like bash, not like compgen).
completion.append(x[prefix_dir:])
return completion
if os.environ.get("_ARGCOMPLETE"):
try:
import argcomplete.completers
except ImportError:
sys.exit(-1)
filescompleter: FastFilesCompleter | None = FastFilesCompleter()
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
argcomplete.autocomplete(parser, always_complete_options=False)
else:
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
pass
filescompleter = None

View file

@ -0,0 +1,26 @@
"""Python inspection/code generation API."""
from __future__ import annotations
from .code import Code
from .code import ExceptionInfo
from .code import filter_traceback
from .code import Frame
from .code import getfslineno
from .code import Traceback
from .code import TracebackEntry
from .source import getrawcode
from .source import Source
__all__ = [
"Code",
"ExceptionInfo",
"Frame",
"Source",
"Traceback",
"TracebackEntry",
"filter_traceback",
"getfslineno",
"getrawcode",
]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,225 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import ast
from bisect import bisect_right
from collections.abc import Iterable
from collections.abc import Iterator
import inspect
import textwrap
import tokenize
import types
from typing import overload
import warnings
class Source:
"""An immutable object holding a source code fragment.
When using Source(...), the source lines are deindented.
"""
def __init__(self, obj: object = None) -> None:
if not obj:
self.lines: list[str] = []
self.raw_lines: list[str] = []
elif isinstance(obj, Source):
self.lines = obj.lines
self.raw_lines = obj.raw_lines
elif isinstance(obj, (tuple, list)):
self.lines = deindent(x.rstrip("\n") for x in obj)
self.raw_lines = list(x.rstrip("\n") for x in obj)
elif isinstance(obj, str):
self.lines = deindent(obj.split("\n"))
self.raw_lines = obj.split("\n")
else:
try:
rawcode = getrawcode(obj)
src = inspect.getsource(rawcode)
except TypeError:
src = inspect.getsource(obj) # type: ignore[arg-type]
self.lines = deindent(src.split("\n"))
self.raw_lines = src.split("\n")
def __eq__(self, other: object) -> bool:
if not isinstance(other, Source):
return NotImplemented
return self.lines == other.lines
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@overload
def __getitem__(self, key: int) -> str: ...
@overload
def __getitem__(self, key: slice) -> Source: ...
def __getitem__(self, key: int | slice) -> str | Source:
if isinstance(key, int):
return self.lines[key]
else:
if key.step not in (None, 1):
raise IndexError("cannot slice a Source with a step")
newsource = Source()
newsource.lines = self.lines[key.start : key.stop]
newsource.raw_lines = self.raw_lines[key.start : key.stop]
return newsource
def __iter__(self) -> Iterator[str]:
return iter(self.lines)
def __len__(self) -> int:
return len(self.lines)
def strip(self) -> Source:
"""Return new Source object with trailing and leading blank lines removed."""
start, end = 0, len(self)
while start < end and not self.lines[start].strip():
start += 1
while end > start and not self.lines[end - 1].strip():
end -= 1
source = Source()
source.raw_lines = self.raw_lines
source.lines[:] = self.lines[start:end]
return source
def indent(self, indent: str = " " * 4) -> Source:
"""Return a copy of the source object with all lines indented by the
given indent-string."""
newsource = Source()
newsource.raw_lines = self.raw_lines
newsource.lines = [(indent + line) for line in self.lines]
return newsource
def getstatement(self, lineno: int) -> Source:
"""Return Source statement which contains the given linenumber
(counted from 0)."""
start, end = self.getstatementrange(lineno)
return self[start:end]
def getstatementrange(self, lineno: int) -> tuple[int, int]:
"""Return (start, end) tuple which spans the minimal statement region
which containing the given lineno."""
if not (0 <= lineno < len(self)):
raise IndexError("lineno out of range")
ast, start, end = getstatementrange_ast(lineno, self)
return start, end
def deindent(self) -> Source:
"""Return a new Source object deindented."""
newsource = Source()
newsource.lines[:] = deindent(self.lines)
newsource.raw_lines = self.raw_lines
return newsource
def __str__(self) -> str:
return "\n".join(self.lines)
#
# helper functions
#
def findsource(obj) -> tuple[Source | None, int]:
try:
sourcelines, lineno = inspect.findsource(obj)
except Exception:
return None, -1
source = Source()
source.lines = [line.rstrip() for line in sourcelines]
source.raw_lines = sourcelines
return source, lineno
def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
"""Return code object for given function."""
try:
return obj.__code__ # type: ignore[attr-defined,no-any-return]
except AttributeError:
pass
if trycall:
call = getattr(obj, "__call__", None)
if call and not isinstance(obj, type):
return getrawcode(call, trycall=False)
raise TypeError(f"could not get code object for {obj!r}")
def deindent(lines: Iterable[str]) -> list[str]:
return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
# Flatten all statements and except handlers into one lineno-list.
# AST's line numbers start indexing at 1.
values: list[int] = []
for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
# The lineno points to the class/def, so need to include the decorators.
if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
for d in x.decorator_list:
values.append(d.lineno - 1)
values.append(x.lineno - 1)
for name in ("finalbody", "orelse"):
val: list[ast.stmt] | None = getattr(x, name, None)
if val:
# Treat the finally/orelse part as its own statement.
values.append(val[0].lineno - 1 - 1)
values.sort()
insert_index = bisect_right(values, lineno)
start = values[insert_index - 1]
if insert_index >= len(values):
end = None
else:
end = values[insert_index]
return start, end
def getstatementrange_ast(
lineno: int,
source: Source,
assertion: bool = False,
astnode: ast.AST | None = None,
) -> tuple[ast.AST, int, int]:
if astnode is None:
content = str(source)
# See #4260:
# Don't produce duplicate warnings when compiling source to find AST.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
astnode = ast.parse(content, "source", "exec")
start, end = get_statement_startend2(lineno, astnode)
# We need to correct the end:
# - ast-parsing strips comments
# - there might be empty lines
# - we might have lesser indented code blocks at the end
if end is None:
end = len(source.lines)
if end > start + 1:
# Make sure we don't span differently indented code blocks
# by using the BlockFinder helper used which inspect.getsource() uses itself.
block_finder = inspect.BlockFinder()
# If we start with an indented line, put blockfinder to "started" mode.
block_finder.started = (
bool(source.lines[start]) and source.lines[start][0].isspace()
)
it = ((x + "\n") for x in source.lines[start:end])
try:
for tok in tokenize.generate_tokens(lambda: next(it)):
block_finder.tokeneater(*tok)
except (inspect.EndOfBlock, IndentationError):
end = block_finder.last + start
except Exception:
pass
# The end might still point to a comment or empty line, correct it.
while end:
line = source.lines[end - 1].lstrip()
if line.startswith("#") or not line:
end -= 1
else:
break
return astnode, start, end

View file

@ -0,0 +1,10 @@
from __future__ import annotations
from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter
__all__ = [
"TerminalWriter",
"get_terminal_width",
]

View file

@ -0,0 +1,673 @@
# mypy: allow-untyped-defs
# This module was imported from the cpython standard library
# (https://github.com/python/cpython/) at commit
# c5140945c723ae6c4b7ee81ff720ac8ea4b52cfd (python3.12).
#
#
# Original Author: Fred L. Drake, Jr.
# fdrake@acm.org
#
# This is a simple little module I wrote to make life easier. I didn't
# see anything quite like it in the library, though I may have overlooked
# something. I wrote this when I was trying to read some heavily nested
# tuples with fairly non-descriptive content. This is modeled very much
# after Lisp/Scheme - style pretty-printing of lists. If you find it
# useful, thank small children who sleep at night.
from __future__ import annotations
import collections as _collections
from collections.abc import Callable
from collections.abc import Iterator
import dataclasses as _dataclasses
from io import StringIO as _StringIO
import re
import types as _types
from typing import Any
from typing import IO
class _safe_key:
"""Helper function for key functions when sorting unorderable objects.
The wrapped-object will fallback to a Py2.x style comparison for
unorderable types (sorting first comparing the type name and then by
the obj ids). Does not work recursively, so dict.items() must have
_safe_key applied to both the key and the value.
"""
__slots__ = ["obj"]
def __init__(self, obj):
self.obj = obj
def __lt__(self, other):
try:
return self.obj < other.obj
except TypeError:
return (str(type(self.obj)), id(self.obj)) < (
str(type(other.obj)),
id(other.obj),
)
def _safe_tuple(t):
"""Helper function for comparing 2-tuples"""
return _safe_key(t[0]), _safe_key(t[1])
class PrettyPrinter:
def __init__(
self,
indent: int = 4,
width: int = 80,
depth: int | None = None,
) -> None:
"""Handle pretty printing operations onto a stream using a set of
configured parameters.
indent
Number of spaces to indent for each level of nesting.
width
Attempted maximum number of columns in the output.
depth
The maximum depth to print out nested structures.
"""
if indent < 0:
raise ValueError("indent must be >= 0")
if depth is not None and depth <= 0:
raise ValueError("depth must be > 0")
if not width:
raise ValueError("width must be != 0")
self._depth = depth
self._indent_per_level = indent
self._width = width
def pformat(self, object: Any) -> str:
sio = _StringIO()
self._format(object, sio, 0, 0, set(), 0)
return sio.getvalue()
def _format(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
objid = id(object)
if objid in context:
stream.write(_recursion(object))
return
p = self._dispatch.get(type(object).__repr__, None)
if p is not None:
context.add(objid)
p(self, object, stream, indent, allowance, context, level + 1)
context.remove(objid)
elif (
_dataclasses.is_dataclass(object)
and not isinstance(object, type)
and object.__dataclass_params__.repr # type:ignore[attr-defined]
and
# Check dataclass has generated repr method.
hasattr(object.__repr__, "__wrapped__")
and "__create_fn__" in object.__repr__.__wrapped__.__qualname__
):
context.add(objid)
self._pprint_dataclass(
object, stream, indent, allowance, context, level + 1
)
context.remove(objid)
else:
stream.write(self._repr(object, context, level))
def _pprint_dataclass(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
cls_name = object.__class__.__name__
items = [
(f.name, getattr(object, f.name))
for f in _dataclasses.fields(object)
if f.repr
]
stream.write(cls_name + "(")
self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")")
_dispatch: dict[
Callable[..., str],
Callable[[PrettyPrinter, Any, IO[str], int, int, set[int], int], None],
] = {}
def _pprint_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
write("{")
items = sorted(object.items(), key=_safe_tuple)
self._format_dict_items(items, stream, indent, allowance, context, level)
write("}")
_dispatch[dict.__repr__] = _pprint_dict
def _pprint_ordered_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
cls = object.__class__
stream.write(cls.__name__ + "(")
self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict
def _pprint_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("[")
self._format_items(object, stream, indent, allowance, context, level)
stream.write("]")
_dispatch[list.__repr__] = _pprint_list
def _pprint_tuple(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("(")
self._format_items(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[tuple.__repr__] = _pprint_tuple
def _pprint_set(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
typ = object.__class__
if typ is set:
stream.write("{")
endchar = "}"
else:
stream.write(typ.__name__ + "({")
endchar = "})"
object = sorted(object, key=_safe_key)
self._format_items(object, stream, indent, allowance, context, level)
stream.write(endchar)
_dispatch[set.__repr__] = _pprint_set
_dispatch[frozenset.__repr__] = _pprint_set
def _pprint_str(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
if not len(object):
write(repr(object))
return
chunks = []
lines = object.splitlines(True)
if level == 1:
indent += 1
allowance += 1
max_width1 = max_width = self._width - indent
for i, line in enumerate(lines):
rep = repr(line)
if i == len(lines) - 1:
max_width1 -= allowance
if len(rep) <= max_width1:
chunks.append(rep)
else:
# A list of alternating (non-space, space) strings
parts = re.findall(r"\S*\s*", line)
assert parts
assert not parts[-1]
parts.pop() # drop empty last part
max_width2 = max_width
current = ""
for j, part in enumerate(parts):
candidate = current + part
if j == len(parts) - 1 and i == len(lines) - 1:
max_width2 -= allowance
if len(repr(candidate)) > max_width2:
if current:
chunks.append(repr(current))
current = part
else:
current = candidate
if current:
chunks.append(repr(current))
if len(chunks) == 1:
write(rep)
return
if level == 1:
write("(")
for i, rep in enumerate(chunks):
if i > 0:
write("\n" + " " * indent)
write(rep)
if level == 1:
write(")")
_dispatch[str.__repr__] = _pprint_str
def _pprint_bytes(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
if len(object) <= 4:
write(repr(object))
return
parens = level == 1
if parens:
indent += 1
allowance += 1
write("(")
delim = ""
for rep in _wrap_bytes_repr(object, self._width - indent, allowance):
write(delim)
write(rep)
if not delim:
delim = "\n" + " " * indent
if parens:
write(")")
_dispatch[bytes.__repr__] = _pprint_bytes
def _pprint_bytearray(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
write("bytearray(")
self._pprint_bytes(
bytes(object), stream, indent + 10, allowance + 1, context, level + 1
)
write(")")
_dispatch[bytearray.__repr__] = _pprint_bytearray
def _pprint_mappingproxy(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("mappingproxy(")
self._format(object.copy(), stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy
def _pprint_simplenamespace(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if type(object) is _types.SimpleNamespace:
# The SimpleNamespace repr is "namespace" instead of the class
# name, so we do the same here. For subclasses; use the class name.
cls_name = "namespace"
else:
cls_name = object.__class__.__name__
items = object.__dict__.items()
stream.write(cls_name + "(")
self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace
def _format_dict_items(
self,
items: list[tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for key, ent in items:
write(delimnl)
write(self._repr(key, context, level))
write(": ")
self._format(ent, stream, item_indent, 1, context, level)
write(",")
write("\n" + " " * indent)
def _format_namespace_items(
self,
items: list[tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for key, ent in items:
write(delimnl)
write(key)
write("=")
if id(ent) in context:
# Special-case representation of recursion to match standard
# recursive dataclass repr.
write("...")
else:
self._format(
ent,
stream,
item_indent + len(key) + 1,
1,
context,
level,
)
write(",")
write("\n" + " " * indent)
def _format_items(
self,
items: list[Any],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for item in items:
write(delimnl)
self._format(item, stream, item_indent, 1, context, level)
write(",")
write("\n" + " " * indent)
def _repr(self, object: Any, context: set[int], level: int) -> str:
return self._safe_repr(object, context.copy(), self._depth, level)
def _pprint_default_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
rdf = self._repr(object.default_factory, context, level)
stream.write(f"{object.__class__.__name__}({rdf}, ")
self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict
def _pprint_counter(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
if object:
stream.write("{")
items = object.most_common()
self._format_dict_items(items, stream, indent, allowance, context, level)
stream.write("}")
stream.write(")")
_dispatch[_collections.Counter.__repr__] = _pprint_counter
def _pprint_chain_map(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
stream.write(repr(object))
return
stream.write(object.__class__.__name__ + "(")
self._format_items(object.maps, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map
def _pprint_deque(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
if object.maxlen is not None:
stream.write(f"maxlen={object.maxlen}, ")
stream.write("[")
self._format_items(object, stream, indent, allowance + 1, context, level)
stream.write("])")
_dispatch[_collections.deque.__repr__] = _pprint_deque
def _pprint_user_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserDict.__repr__] = _pprint_user_dict
def _pprint_user_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserList.__repr__] = _pprint_user_list
def _pprint_user_string(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserString.__repr__] = _pprint_user_string
def _safe_repr(
self, object: Any, context: set[int], maxlevels: int | None, level: int
) -> str:
typ = type(object)
if typ in _builtin_scalars:
return repr(object)
r = getattr(typ, "__repr__", None)
if issubclass(typ, dict) and r is dict.__repr__:
if not object:
return "{}"
objid = id(object)
if maxlevels and level >= maxlevels:
return "{...}"
if objid in context:
return _recursion(object)
context.add(objid)
components: list[str] = []
append = components.append
level += 1
for k, v in sorted(object.items(), key=_safe_tuple):
krepr = self._safe_repr(k, context, maxlevels, level)
vrepr = self._safe_repr(v, context, maxlevels, level)
append(f"{krepr}: {vrepr}")
context.remove(objid)
return "{{{}}}".format(", ".join(components))
if (issubclass(typ, list) and r is list.__repr__) or (
issubclass(typ, tuple) and r is tuple.__repr__
):
if issubclass(typ, list):
if not object:
return "[]"
format = "[%s]"
elif len(object) == 1:
format = "(%s,)"
else:
if not object:
return "()"
format = "(%s)"
objid = id(object)
if maxlevels and level >= maxlevels:
return format % "..."
if objid in context:
return _recursion(object)
context.add(objid)
components = []
append = components.append
level += 1
for o in object:
orepr = self._safe_repr(o, context, maxlevels, level)
append(orepr)
context.remove(objid)
return format % ", ".join(components)
return repr(object)
_builtin_scalars = frozenset(
{str, bytes, bytearray, float, complex, bool, type(None), int}
)
def _recursion(object: Any) -> str:
return f"<Recursion on {type(object).__name__} with id={id(object)}>"
def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]:
current = b""
last = len(object) // 4 * 4
for i in range(0, len(object), 4):
part = object[i : i + 4]
candidate = current + part
if i == last:
width -= allowance
if len(repr(candidate)) > width:
if current:
yield repr(current)
current = part
else:
current = candidate
if current:
yield repr(current)

View file

@ -0,0 +1,130 @@
from __future__ import annotations
import pprint
import reprlib
def _try_repr_or_str(obj: object) -> str:
try:
return repr(obj)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException:
return f'{type(obj).__name__}("{obj}")'
def _format_repr_exception(exc: BaseException, obj: object) -> str:
try:
exc_info = _try_repr_or_str(exc)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as inner_exc:
exc_info = f"unpresentable exception ({_try_repr_or_str(inner_exc)})"
return (
f"<[{exc_info} raised in repr()] {type(obj).__name__} object at 0x{id(obj):x}>"
)
def _ellipsize(s: str, maxsize: int) -> str:
if len(s) > maxsize:
i = max(0, (maxsize - 3) // 2)
j = max(0, maxsize - 3 - i)
return s[:i] + "..." + s[len(s) - j :]
return s
class SafeRepr(reprlib.Repr):
"""
repr.Repr that limits the resulting size of repr() and includes
information on exceptions raised during the call.
"""
def __init__(self, maxsize: int | None, use_ascii: bool = False) -> None:
"""
:param maxsize:
If not None, will truncate the resulting repr to that specific size, using ellipsis
somewhere in the middle to hide the extra text.
If None, will not impose any size limits on the returning repr.
"""
super().__init__()
# ``maxstring`` is used by the superclass, and needs to be an int; using a
# very large number in case maxsize is None, meaning we want to disable
# truncation.
self.maxstring = maxsize if maxsize is not None else 1_000_000_000
self.maxsize = maxsize
self.use_ascii = use_ascii
def repr(self, x: object) -> str:
try:
if self.use_ascii:
s = ascii(x)
else:
s = super().repr(x)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
s = _format_repr_exception(exc, x)
if self.maxsize is not None:
s = _ellipsize(s, self.maxsize)
return s
def repr_instance(self, x: object, level: int) -> str:
try:
s = repr(x)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
s = _format_repr_exception(exc, x)
if self.maxsize is not None:
s = _ellipsize(s, self.maxsize)
return s
def safeformat(obj: object) -> str:
"""Return a pretty printed string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info.
"""
try:
return pprint.pformat(obj)
except Exception as exc:
return _format_repr_exception(exc, obj)
# Maximum size of overall repr of objects to display during assertion errors.
DEFAULT_REPR_MAX_SIZE = 240
def saferepr(
obj: object, maxsize: int | None = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False
) -> str:
"""Return a size-limited safe repr-string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info and 'saferepr' generally takes
care to never raise exceptions itself.
This function is a wrapper around the Repr/reprlib functionality of the
stdlib.
"""
return SafeRepr(maxsize, use_ascii).repr(obj)
def saferepr_unlimited(obj: object, use_ascii: bool = True) -> str:
"""Return an unlimited-size safe repr-string for the given object.
As with saferepr, failing __repr__ functions of user instances
will be represented with a short exception info.
This function is a wrapper around simple repr.
Note: a cleaner solution would be to alter ``saferepr``this way
when maxsize=None, but that might affect some other code.
"""
try:
if use_ascii:
return ascii(obj)
return repr(obj)
except Exception as exc:
return _format_repr_exception(exc, obj)

View file

@ -0,0 +1,254 @@
"""Helper functions for writing to terminals and files."""
from __future__ import annotations
from collections.abc import Sequence
import os
import shutil
import sys
from typing import final
from typing import Literal
from typing import TextIO
import pygments
from pygments.formatters.terminal import TerminalFormatter
from pygments.lexer import Lexer
from pygments.lexers.diff import DiffLexer
from pygments.lexers.python import PythonLexer
from ..compat import assert_never
from .wcwidth import wcswidth
# This code was initially copied from py 1.8.1, file _io/terminalwriter.py.
def get_terminal_width() -> int:
width, _ = shutil.get_terminal_size(fallback=(80, 24))
# The Windows get_terminal_size may be bogus, let's sanify a bit.
if width < 40:
width = 80
return width
def should_do_markup(file: TextIO) -> bool:
if os.environ.get("PY_COLORS") == "1":
return True
if os.environ.get("PY_COLORS") == "0":
return False
if os.environ.get("NO_COLOR"):
return False
if os.environ.get("FORCE_COLOR"):
return True
return (
hasattr(file, "isatty") and file.isatty() and os.environ.get("TERM") != "dumb"
)
@final
class TerminalWriter:
_esctable = dict(
black=30,
red=31,
green=32,
yellow=33,
blue=34,
purple=35,
cyan=36,
white=37,
Black=40,
Red=41,
Green=42,
Yellow=43,
Blue=44,
Purple=45,
Cyan=46,
White=47,
bold=1,
light=2,
blink=5,
invert=7,
)
def __init__(self, file: TextIO | None = None) -> None:
if file is None:
file = sys.stdout
if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
try:
import colorama
except ImportError:
pass
else:
file = colorama.AnsiToWin32(file).stream
assert file is not None
self._file = file
self.hasmarkup = should_do_markup(file)
self._current_line = ""
self._terminal_width: int | None = None
self.code_highlight = True
@property
def fullwidth(self) -> int:
if self._terminal_width is not None:
return self._terminal_width
return get_terminal_width()
@fullwidth.setter
def fullwidth(self, value: int) -> None:
self._terminal_width = value
@property
def width_of_current_line(self) -> int:
"""Return an estimate of the width so far in the current line."""
return wcswidth(self._current_line)
def markup(self, text: str, **markup: bool) -> str:
for name in markup:
if name not in self._esctable:
raise ValueError(f"unknown markup: {name!r}")
if self.hasmarkup:
esc = [self._esctable[name] for name, on in markup.items() if on]
if esc:
text = "".join(f"\x1b[{cod}m" for cod in esc) + text + "\x1b[0m"
return text
def sep(
self,
sepchar: str,
title: str | None = None,
fullwidth: int | None = None,
**markup: bool,
) -> None:
if fullwidth is None:
fullwidth = self.fullwidth
# The goal is to have the line be as long as possible
# under the condition that len(line) <= fullwidth.
if sys.platform == "win32":
# If we print in the last column on windows we are on a
# new line but there is no way to verify/neutralize this
# (we may not know the exact line width).
# So let's be defensive to avoid empty lines in the output.
fullwidth -= 1
if title is not None:
# we want 2 + 2*len(fill) + len(title) <= fullwidth
# i.e. 2 + 2*len(sepchar)*N + len(title) <= fullwidth
# 2*len(sepchar)*N <= fullwidth - len(title) - 2
# N <= (fullwidth - len(title) - 2) // (2*len(sepchar))
N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1)
fill = sepchar * N
line = f"{fill} {title} {fill}"
else:
# we want len(sepchar)*N <= fullwidth
# i.e. N <= fullwidth // len(sepchar)
line = sepchar * (fullwidth // len(sepchar))
# In some situations there is room for an extra sepchar at the right,
# in particular if we consider that with a sepchar like "_ " the
# trailing space is not important at the end of the line.
if len(line) + len(sepchar.rstrip()) <= fullwidth:
line += sepchar.rstrip()
self.line(line, **markup)
def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None:
if msg:
current_line = msg.rsplit("\n", 1)[-1]
if "\n" in msg:
self._current_line = current_line
else:
self._current_line += current_line
msg = self.markup(msg, **markup)
try:
self._file.write(msg)
except UnicodeEncodeError:
# Some environments don't support printing general Unicode
# strings, due to misconfiguration or otherwise; in that case,
# print the string escaped to ASCII.
# When the Unicode situation improves we should consider
# letting the error propagate instead of masking it (see #7475
# for one brief attempt).
msg = msg.encode("unicode-escape").decode("ascii")
self._file.write(msg)
if flush:
self.flush()
def line(self, s: str = "", **markup: bool) -> None:
self.write(s, **markup)
self.write("\n")
def flush(self) -> None:
self._file.flush()
def _write_source(self, lines: Sequence[str], indents: Sequence[str] = ()) -> None:
"""Write lines of source code possibly highlighted.
Keeping this private for now because the API is clunky. We should discuss how
to evolve the terminal writer so we can have more precise color support, for example
being able to write part of a line in one color and the rest in another, and so on.
"""
if indents and len(indents) != len(lines):
raise ValueError(
f"indents size ({len(indents)}) should have same size as lines ({len(lines)})"
)
if not indents:
indents = [""] * len(lines)
source = "\n".join(lines)
new_lines = self._highlight(source).splitlines()
for indent, new_line in zip(indents, new_lines):
self.line(indent + new_line)
def _get_pygments_lexer(self, lexer: Literal["python", "diff"]) -> Lexer:
if lexer == "python":
return PythonLexer()
elif lexer == "diff":
return DiffLexer()
else:
assert_never(lexer)
def _get_pygments_formatter(self) -> TerminalFormatter:
from _pytest.config.exceptions import UsageError
theme = os.getenv("PYTEST_THEME")
theme_mode = os.getenv("PYTEST_THEME_MODE", "dark")
try:
return TerminalFormatter(bg=theme_mode, style=theme)
except pygments.util.ClassNotFound as e:
raise UsageError(
f"PYTEST_THEME environment variable has an invalid value: '{theme}'. "
"Hint: See available pygments styles with `pygmentize -L styles`."
) from e
except pygments.util.OptionError as e:
raise UsageError(
f"PYTEST_THEME_MODE environment variable has an invalid value: '{theme_mode}'. "
"The allowed values are 'dark' (default) and 'light'."
) from e
def _highlight(
self, source: str, lexer: Literal["diff", "python"] = "python"
) -> str:
"""Highlight the given source if we have markup support."""
if not source or not self.hasmarkup or not self.code_highlight:
return source
pygments_lexer = self._get_pygments_lexer(lexer)
pygments_formatter = self._get_pygments_formatter()
highlighted: str = pygments.highlight(
source, pygments_lexer, pygments_formatter
)
# pygments terminal formatter may add a newline when there wasn't one.
# We don't want this, remove.
if highlighted[-1] == "\n" and source[-1] != "\n":
highlighted = highlighted[:-1]
# Some lexers will not set the initial color explicitly
# which may lead to the previous color being propagated to the
# start of the expression, so reset first.
highlighted = "\x1b[0m" + highlighted
return highlighted

View file

@ -0,0 +1,57 @@
from __future__ import annotations
from functools import lru_cache
import unicodedata
@lru_cache(100)
def wcwidth(c: str) -> int:
"""Determine how many columns are needed to display a character in a terminal.
Returns -1 if the character is not printable.
Returns 0, 1 or 2 for other characters.
"""
o = ord(c)
# ASCII fast path.
if 0x20 <= o < 0x07F:
return 1
# Some Cf/Zp/Zl characters which should be zero-width.
if (
o == 0x0000
or 0x200B <= o <= 0x200F
or 0x2028 <= o <= 0x202E
or 0x2060 <= o <= 0x2063
):
return 0
category = unicodedata.category(c)
# Control characters.
if category == "Cc":
return -1
# Combining characters with zero width.
if category in ("Me", "Mn"):
return 0
# Full/Wide east asian characters.
if unicodedata.east_asian_width(c) in ("F", "W"):
return 2
return 1
def wcswidth(s: str) -> int:
"""Determine how many columns are needed to display a string in a terminal.
Returns -1 if the string contains non-printable characters.
"""
width = 0
for c in unicodedata.normalize("NFC", s):
wc = wcwidth(c)
if wc < 0:
return -1
width += wc
return width

View file

@ -0,0 +1,119 @@
"""create errno-specific classes for IO or os calls."""
from __future__ import annotations
from collections.abc import Callable
import errno
import os
import sys
from typing import TYPE_CHECKING
from typing import TypeVar
if TYPE_CHECKING:
from typing_extensions import ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
class Error(EnvironmentError):
def __repr__(self) -> str:
return "{}.{} {!r}: {} ".format(
self.__class__.__module__,
self.__class__.__name__,
self.__class__.__doc__,
" ".join(map(str, self.args)),
# repr(self.args)
)
def __str__(self) -> str:
s = "[{}]: {}".format(
self.__class__.__doc__,
" ".join(map(str, self.args)),
)
return s
_winerrnomap = {
2: errno.ENOENT,
3: errno.ENOENT,
17: errno.EEXIST,
18: errno.EXDEV,
13: errno.EBUSY, # empty cd drive, but ENOMEDIUM seems unavailable
22: errno.ENOTDIR,
20: errno.ENOTDIR,
267: errno.ENOTDIR,
5: errno.EACCES, # anything better?
}
class ErrorMaker:
"""lazily provides Exception classes for each possible POSIX errno
(as defined per the 'errno' module). All such instances
subclass EnvironmentError.
"""
_errno2class: dict[int, type[Error]] = {}
def __getattr__(self, name: str) -> type[Error]:
if name[0] == "_":
raise AttributeError(name)
eno = getattr(errno, name)
cls = self._geterrnoclass(eno)
setattr(self, name, cls)
return cls
def _geterrnoclass(self, eno: int) -> type[Error]:
try:
return self._errno2class[eno]
except KeyError:
clsname = errno.errorcode.get(eno, f"UnknownErrno{eno}")
errorcls = type(
clsname,
(Error,),
{"__module__": "py.error", "__doc__": os.strerror(eno)},
)
self._errno2class[eno] = errorcls
return errorcls
def checked_call(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
"""Call a function and raise an errno-exception if applicable."""
__tracebackhide__ = True
try:
return func(*args, **kwargs)
except Error:
raise
except OSError as value:
if not hasattr(value, "errno"):
raise
if sys.platform == "win32":
try:
# error: Invalid index type "Optional[int]" for "dict[int, int]"; expected type "int" [index]
# OK to ignore because we catch the KeyError below.
cls = self._geterrnoclass(_winerrnomap[value.errno]) # type:ignore[index]
except KeyError:
raise value
else:
# we are not on Windows, or we got a proper OSError
if value.errno is None:
cls = type(
"UnknownErrnoNone",
(Error,),
{"__module__": "py.error", "__doc__": None},
)
else:
cls = self._geterrnoclass(value.errno)
raise cls(f"{func.__name__}{args!r}")
_error_maker = ErrorMaker()
checked_call = _error_maker.checked_call
def __getattr__(attr: str) -> type[Error]:
return getattr(_error_maker, attr) # type: ignore[no-any-return]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '8.4.2'
__version_tuple__ = version_tuple = (8, 4, 2)
__commit_id__ = commit_id = None

View file

@ -0,0 +1,208 @@
# mypy: allow-untyped-defs
"""Support for presenting detailed information in failing assertions."""
from __future__ import annotations
from collections.abc import Generator
import sys
from typing import Any
from typing import Protocol
from typing import TYPE_CHECKING
from _pytest.assertion import rewrite
from _pytest.assertion import truncate
from _pytest.assertion import util
from _pytest.assertion.rewrite import assertstate_key
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
if TYPE_CHECKING:
from _pytest.main import Session
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--assert",
action="store",
dest="assertmode",
choices=("rewrite", "plain"),
default="rewrite",
metavar="MODE",
help=(
"Control assertion debugging tools.\n"
"'plain' performs no assertion debugging.\n"
"'rewrite' (the default) rewrites assert statements in test modules"
" on import to provide assert expression information."
),
)
parser.addini(
"enable_assertion_pass_hook",
type="bool",
default=False,
help="Enables the pytest_assertion_pass hook. "
"Make sure to delete any previously generated pyc cache files.",
)
parser.addini(
"truncation_limit_lines",
default=None,
help="Set threshold of LINES after which truncation will take effect",
)
parser.addini(
"truncation_limit_chars",
default=None,
help=("Set threshold of CHARS after which truncation will take effect"),
)
Config._add_verbosity_ini(
parser,
Config.VERBOSITY_ASSERTIONS,
help=(
"Specify a verbosity level for assertions, overriding the main level. "
"Higher levels will provide more detailed explanation when an assertion fails."
),
)
def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside
the package will get their assert statements rewritten.
Thus you should make sure to call this before the module is
actually imported, usually in your __init__.py if you are a plugin
using a package.
:param names: The module names to register.
"""
for name in names:
if not isinstance(name, str):
msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
raise TypeError(msg.format(repr(names)))
rewrite_hook: RewriteHook
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
rewrite_hook = hook
break
else:
rewrite_hook = DummyRewriteHook()
rewrite_hook.mark_rewrite(*names)
class RewriteHook(Protocol):
def mark_rewrite(self, *names: str) -> None: ...
class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names: str) -> None:
pass
class AssertionState:
"""State for the assertion plugin."""
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook: rewrite.AssertionRewritingHook | None = None
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
config.stash[assertstate_key] = AssertionState(config, "rewrite")
config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config.stash[assertstate_key].trace("installed rewrite import hook")
def undo() -> None:
hook = config.stash[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)
config.add_cleanup(undo)
return hook
def pytest_collection(session: Session) -> None:
# This hook is only called when test modules are collected
# so for example not in the managing process of pytest-xdist
# (which does not collect test modules).
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(session)
@hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
The rewrite module will use util._reprcompare if it exists to use custom
reporting via the pytest_assertrepr_compare hook. This sets up this custom
comparison for the test.
"""
ihook = item.ihook
def callbinrepr(op, left: object, right: object) -> str | None:
"""Call the pytest_assertrepr_compare hook and prepare the result.
This uses the first result from the hook and then ensures the
following:
* Overly verbose explanations are truncated unless configured otherwise
(eg. if running in verbose mode).
* Embedded newlines are escaped to help util.format_explanation()
later.
* If the rewrite mode is used embedded %-characters are replaced
to protect later % formatting.
The result can be formatted by util.format_explanation() for
pretty printing.
"""
hook_result = ihook.pytest_assertrepr_compare(
config=item.config, op=op, left=left, right=right
)
for new_expl in hook_result:
if new_expl:
new_expl = truncate.truncate_if_required(new_expl, item)
new_expl = [line.replace("\n", "\\n") for line in new_expl]
res = "\n~".join(new_expl)
if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%")
return res
return None
saved_assert_hooks = util._reprcompare, util._assertion_pass
util._reprcompare = callbinrepr
util._config = item.config
if ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
util._assertion_pass = call_assertion_pass_hook
try:
return (yield)
finally:
util._reprcompare, util._assertion_pass = saved_assert_hooks
util._config = None
def pytest_sessionfinish(session: Session) -> None:
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(None)
def pytest_assertrepr_compare(
config: Config, op: str, left: Any, right: Any
) -> list[str] | None:
return util.assertrepr_compare(config=config, op=op, left=left, right=right)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,137 @@
"""Utilities for truncating assertion output.
Current default behaviour is to truncate assertion explanations at
terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
"""
from __future__ import annotations
from _pytest.assertion import util
from _pytest.config import Config
from _pytest.nodes import Item
DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = DEFAULT_MAX_LINES * 80
USAGE_MSG = "use '-vv' to show"
def truncate_if_required(explanation: list[str], item: Item) -> list[str]:
"""Truncate this assertion explanation if the given test item is eligible."""
should_truncate, max_lines, max_chars = _get_truncation_parameters(item)
if should_truncate:
return _truncate_explanation(
explanation,
max_lines=max_lines,
max_chars=max_chars,
)
return explanation
def _get_truncation_parameters(item: Item) -> tuple[bool, int, int]:
"""Return the truncation parameters related to the given item, as (should truncate, max lines, max chars)."""
# We do not need to truncate if one of conditions is met:
# 1. Verbosity level is 2 or more;
# 2. Test is being run in CI environment;
# 3. Both truncation_limit_lines and truncation_limit_chars
# .ini parameters are set to 0 explicitly.
max_lines = item.config.getini("truncation_limit_lines")
max_lines = int(max_lines if max_lines is not None else DEFAULT_MAX_LINES)
max_chars = item.config.getini("truncation_limit_chars")
max_chars = int(max_chars if max_chars is not None else DEFAULT_MAX_CHARS)
verbose = item.config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
should_truncate = verbose < 2 and not util.running_on_ci()
should_truncate = should_truncate and (max_lines > 0 or max_chars > 0)
return should_truncate, max_lines, max_chars
def _truncate_explanation(
input_lines: list[str],
max_lines: int,
max_chars: int,
) -> list[str]:
"""Truncate given list of strings that makes up the assertion explanation.
Truncates to either max_lines, or max_chars - whichever the input reaches
first, taking the truncation explanation into account. The remaining lines
will be replaced by a usage message.
"""
# Check if truncation required
input_char_count = len("".join(input_lines))
# The length of the truncation explanation depends on the number of lines
# removed but is at least 68 characters:
# The real value is
# 64 (for the base message:
# '...\n...Full output truncated (1 line hidden), use '-vv' to show")'
# )
# + 1 (for plural)
# + int(math.log10(len(input_lines) - max_lines)) (number of hidden line, at least 1)
# + 3 for the '...' added to the truncated line
# But if there's more than 100 lines it's very likely that we're going to
# truncate, so we don't need the exact value using log10.
tolerable_max_chars = (
max_chars + 70 # 64 + 1 (for plural) + 2 (for '99') + 3 for '...'
)
# The truncation explanation add two lines to the output
tolerable_max_lines = max_lines + 2
if (
len(input_lines) <= tolerable_max_lines
and input_char_count <= tolerable_max_chars
):
return input_lines
# Truncate first to max_lines, and then truncate to max_chars if necessary
if max_lines > 0:
truncated_explanation = input_lines[:max_lines]
else:
truncated_explanation = input_lines
truncated_char = True
# We reevaluate the need to truncate chars following removal of some lines
if len("".join(truncated_explanation)) > tolerable_max_chars and max_chars > 0:
truncated_explanation = _truncate_by_char_count(
truncated_explanation, max_chars
)
else:
truncated_char = False
if truncated_explanation == input_lines:
# No truncation happened, so we do not need to add any explanations
return truncated_explanation
truncated_line_count = len(input_lines) - len(truncated_explanation)
if truncated_explanation[-1]:
# Add ellipsis and take into account part-truncated final line
truncated_explanation[-1] = truncated_explanation[-1] + "..."
if truncated_char:
# It's possible that we did not remove any char from this line
truncated_line_count += 1
else:
# Add proper ellipsis when we were able to fit a full line exactly
truncated_explanation[-1] = "..."
return [
*truncated_explanation,
"",
f"...Full output truncated ({truncated_line_count} line"
f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}",
]
def _truncate_by_char_count(input_lines: list[str], max_chars: int) -> list[str]:
# Find point at which input length exceeds total allowed length
iterated_char_count = 0
for iterated_index, input_line in enumerate(input_lines):
if iterated_char_count + len(input_line) > max_chars:
break
iterated_char_count += len(input_line)
# Create truncated explanation with modified final line
truncated_result = input_lines[:iterated_index]
final_line = input_lines[iterated_index]
if final_line:
final_line_truncate_point = max_chars - iterated_char_count
final_line = final_line[:final_line_truncate_point]
truncated_result.append(final_line)
return truncated_result

View file

@ -0,0 +1,621 @@
# mypy: allow-untyped-defs
"""Utilities for assertion debugging."""
from __future__ import annotations
import collections.abc
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
from collections.abc import Set as AbstractSet
import os
import pprint
from typing import Any
from typing import Literal
from typing import Protocol
from unicodedata import normalize
from _pytest import outcomes
import _pytest._code
from _pytest._io.pprint import PrettyPrinter
from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited
from _pytest.config import Config
# The _reprcompare attribute on the util module is used by the new assertion
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
# DebugInterpreter.
_reprcompare: Callable[[str, object, object], str | None] | None = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass: Callable[[int, str, str], None] | None = None
# Config object which is assigned during pytest_runtest_protocol.
_config: Config | None = None
class _HighlightFunc(Protocol):
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Apply highlighting to the given source."""
def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Dummy highlighter that returns the text unprocessed.
Needed for _notin_text, as the diff gets post-processed to only show the "+" part.
"""
return source
def format_explanation(explanation: str) -> str:
r"""Format an explanation.
Normally all embedded newlines are escaped, however there are
three exceptions: \n{, \n} and \n~. The first two are intended
cover nested explanations, see function and attribute explanations
for examples (.visit_Call(), visit_Attribute()). The last one is
for when one explanation needs to span multiple lines, e.g. when
displaying diffs.
"""
lines = _split_explanation(explanation)
result = _format_lines(lines)
return "\n".join(result)
def _split_explanation(explanation: str) -> list[str]:
r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'.
Any other newlines will be escaped and appear in the line as the
literal '\n' characters.
"""
raw_lines = (explanation or "").split("\n")
lines = [raw_lines[0]]
for values in raw_lines[1:]:
if values and values[0] in ["{", "}", "~", ">"]:
lines.append(values)
else:
lines[-1] += "\\n" + values
return lines
def _format_lines(lines: Sequence[str]) -> list[str]:
"""Format the individual lines.
This will replace the '{', '}' and '~' characters of our mini formatting
language with the proper 'where ...', 'and ...' and ' + ...' text, taking
care of indentation along the way.
Return a list of formatted lines.
"""
result = list(lines[:1])
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith("{"):
if stackcnt[-1]:
s = "and "
else:
s = "where "
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
elif line.startswith("}"):
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
else:
assert line[0] in ["~", ">"]
stack[-1] += 1
indent = len(stack) if line.startswith("~") else len(stack) - 1
result.append(" " * indent + line[1:])
assert len(stack) == 1
return result
def issequence(x: Any) -> bool:
return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
def istext(x: Any) -> bool:
return isinstance(x, str)
def isdict(x: Any) -> bool:
return isinstance(x, dict)
def isset(x: Any) -> bool:
return isinstance(x, (set, frozenset))
def isnamedtuple(obj: Any) -> bool:
return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None
def isattrs(obj: Any) -> bool:
return getattr(obj, "__attrs_attrs__", None) is not None
def isiterable(obj: Any) -> bool:
try:
iter(obj)
return not istext(obj)
except Exception:
return False
def has_default_eq(
obj: object,
) -> bool:
"""Check if an instance of an object contains the default eq
First, we check if the object's __eq__ attribute has __code__,
if so, we check the equally of the method code filename (__code__.co_filename)
to the default one generated by the dataclass and attr module
for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated"
"""
# inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68
if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"):
code_filename = obj.__eq__.__code__.co_filename
if isattrs(obj):
return "attrs generated " in code_filename
return code_filename == "<string>" # data class
return True
def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False
) -> list[str] | None:
"""Return specialised explanations for some operators/operands."""
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
# See issue #3246.
use_ascii = (
isinstance(left, str)
and isinstance(right, str)
and normalize("NFD", left) == normalize("NFD", right)
)
if verbose > 1:
left_repr = saferepr_unlimited(left, use_ascii=use_ascii)
right_repr = saferepr_unlimited(right, use_ascii=use_ascii)
else:
# XXX: "15 chars indentation" is wrong
# ("E AssertionError: assert "); should use term width.
maxsize = (
80 - 15 - len(op) - 2
) // 2 # 15 chars indentation, 1 space around op
left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii)
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
summary = f"{left_repr} {op} {right_repr}"
highlighter = config.get_terminal_writer()._highlight
explanation = None
try:
if op == "==":
explanation = _compare_eq_any(left, right, highlighter, verbose)
elif op == "not in":
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
elif op == "!=":
if isset(left) and isset(right):
explanation = ["Both sets are equal"]
elif op == ">=":
if isset(left) and isset(right):
explanation = _compare_gte_set(left, right, highlighter, verbose)
elif op == "<=":
if isset(left) and isset(right):
explanation = _compare_lte_set(left, right, highlighter, verbose)
elif op == ">":
if isset(left) and isset(right):
explanation = _compare_gt_set(left, right, highlighter, verbose)
elif op == "<":
if isset(left) and isset(right):
explanation = _compare_lt_set(left, right, highlighter, verbose)
except outcomes.Exit:
raise
except Exception:
repr_crash = _pytest._code.ExceptionInfo.from_current()._getreprcrash()
explanation = [
f"(pytest_assertion plugin: representation of details failed: {repr_crash}.",
" Probably an object has a faulty __repr__.)",
]
if not explanation:
return None
if explanation[0] != "":
explanation = ["", *explanation]
return [summary, *explanation]
def _compare_eq_any(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
) -> list[str]:
explanation = []
if istext(left) and istext(right):
explanation = _diff_text(left, right, highlighter, verbose)
else:
from _pytest.python_api import ApproxBase
if isinstance(left, ApproxBase) or isinstance(right, ApproxBase):
# Although the common order should be obtained == expected, this ensures both ways
approx_side = left if isinstance(left, ApproxBase) else right
other_side = right if isinstance(left, ApproxBase) else left
explanation = approx_side._repr_compare(other_side)
elif type(left) is type(right) and (
isdatacls(left) or isattrs(left) or isnamedtuple(left)
):
# Note: unlike dataclasses/attrs, namedtuples compare only the
# field values, not the type or field names. But this branch
# intentionally only handles the same-type case, which was often
# used in older code bases before dataclasses/attrs were available.
explanation = _compare_eq_cls(left, right, highlighter, verbose)
elif issequence(left) and issequence(right):
explanation = _compare_eq_sequence(left, right, highlighter, verbose)
elif isset(left) and isset(right):
explanation = _compare_eq_set(left, right, highlighter, verbose)
elif isdict(left) and isdict(right):
explanation = _compare_eq_dict(left, right, highlighter, verbose)
if isiterable(left) and isiterable(right):
expl = _compare_eq_iterable(left, right, highlighter, verbose)
explanation.extend(expl)
return explanation
def _diff_text(
left: str, right: str, highlighter: _HighlightFunc, verbose: int = 0
) -> list[str]:
"""Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing
characters which are identical to keep the diff minimal.
"""
from difflib import ndiff
explanation: list[str] = []
if verbose < 1:
i = 0 # just in case left or right has zero length
for i in range(min(len(left), len(right))):
if left[i] != right[i]:
break
if i > 42:
i -= 10 # Provide some context
explanation = [
f"Skipping {i} identical leading characters in diff, use -v to show"
]
left = left[i:]
right = right[i:]
if len(left) == len(right):
for i in range(len(left)):
if left[-i] != right[-i]:
break
if i > 42:
i -= 10 # Provide some context
explanation += [
f"Skipping {i} identical trailing "
"characters in diff, use -v to show"
]
left = left[:-i]
right = right[:-i]
keepends = True
if left.isspace() or right.isspace():
left = repr(str(left))
right = repr(str(right))
explanation += ["Strings contain only whitespace, escaping them using repr()"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
highlighter(
"\n".join(
line.strip("\n")
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
),
lexer="diff",
).splitlines()
)
return explanation
def _compare_eq_iterable(
left: Iterable[Any],
right: Iterable[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
if verbose <= 0 and not running_on_ci():
return ["Use -v to get more diff"]
# dynamic import to speedup pytest
import difflib
left_formatting = PrettyPrinter().pformat(left).splitlines()
right_formatting = PrettyPrinter().pformat(right).splitlines()
explanation = ["", "Full diff:"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
highlighter(
"\n".join(
line.rstrip()
for line in difflib.ndiff(right_formatting, left_formatting)
),
lexer="diff",
).splitlines()
)
return explanation
def _compare_eq_sequence(
left: Sequence[Any],
right: Sequence[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation: list[str] = []
len_left = len(left)
len_right = len(right)
for i in range(min(len_left, len_right)):
if left[i] != right[i]:
if comparing_bytes:
# when comparing bytes, we want to see their ascii representation
# instead of their numeric values (#5260)
# using a slice gives us the ascii representation:
# >>> s = b'foo'
# >>> s[0]
# 102
# >>> s[0:1]
# b'f'
left_value = left[i : i + 1]
right_value = right[i : i + 1]
else:
left_value = left[i]
right_value = right[i]
explanation.append(
f"At index {i} diff:"
f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}"
)
break
if comparing_bytes:
# when comparing bytes, it doesn't help to show the "sides contain one or more
# items" longer explanation, so skip it
return explanation
len_diff = len_left - len_right
if len_diff:
if len_diff > 0:
dir_with_more = "Left"
extra = saferepr(left[len_right])
else:
len_diff = 0 - len_diff
dir_with_more = "Right"
extra = saferepr(right[len_left])
if len_diff == 1:
explanation += [
f"{dir_with_more} contains one more item: {highlighter(extra)}"
]
else:
explanation += [
f"{dir_with_more} contains {len_diff} more items, first extra item: {highlighter(extra)}"
]
return explanation
def _compare_eq_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = []
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
return explanation
def _compare_gt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_gte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation
def _compare_lt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_lte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation
def _compare_gte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("right", right, left, highlighter)
def _compare_lte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("left", left, right, highlighter)
def _set_one_sided_diff(
posn: str,
set1: AbstractSet[Any],
set2: AbstractSet[Any],
highlighter: _HighlightFunc,
) -> list[str]:
explanation = []
diff = set1 - set2
if diff:
explanation.append(f"Extra items in the {posn} set:")
for item in diff:
explanation.append(highlighter(saferepr(item)))
return explanation
def _compare_eq_dict(
left: Mapping[Any, Any],
right: Mapping[Any, Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation: list[str] = []
set_left = set(left)
set_right = set(right)
common = set_left.intersection(set_right)
same = {k: left[k] for k in common if left[k] == right[k]}
if same and verbose < 2:
explanation += [f"Omitting {len(same)} identical items, use -vv to show"]
elif same:
explanation += ["Common items:"]
explanation += highlighter(pprint.pformat(same)).splitlines()
diff = {k for k in common if left[k] != right[k]}
if diff:
explanation += ["Differing items:"]
for k in diff:
explanation += [
highlighter(saferepr({k: left[k]}))
+ " != "
+ highlighter(saferepr({k: right[k]}))
]
extra_left = set_left - set_right
len_extra_left = len(extra_left)
if len_extra_left:
explanation.append(
f"Left contains {len_extra_left} more item{'' if len_extra_left == 1 else 's'}:"
)
explanation.extend(
highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines()
)
extra_right = set_right - set_left
len_extra_right = len(extra_right)
if len_extra_right:
explanation.append(
f"Right contains {len_extra_right} more item{'' if len_extra_right == 1 else 's'}:"
)
explanation.extend(
highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines()
)
return explanation
def _compare_eq_cls(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
) -> list[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
import dataclasses
all_fields = dataclasses.fields(left)
fields_to_check = [info.name for info in all_fields if info.compare]
elif isattrs(left):
all_fields = left.__attrs_attrs__
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
elif isnamedtuple(left):
fields_to_check = left._fields
else:
assert False
indent = " "
same = []
diff = []
for field in fields_to_check:
if getattr(left, field) == getattr(right, field):
same.append(field)
else:
diff.append(field)
explanation = []
if same or diff:
explanation += [""]
if same and verbose < 2:
explanation.append(f"Omitting {len(same)} identical items, use -vv to show")
elif same:
explanation += ["Matching attributes:"]
explanation += highlighter(pprint.pformat(same)).splitlines()
if diff:
explanation += ["Differing attributes:"]
explanation += highlighter(pprint.pformat(diff)).splitlines()
for field in diff:
field_left = getattr(left, field)
field_right = getattr(right, field)
explanation += [
"",
f"Drill down into differing attribute {field}:",
f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}",
]
explanation += [
indent + line
for line in _compare_eq_any(
field_left, field_right, highlighter, verbose
)
]
return explanation
def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
index = text.find(term)
head = text[:index]
tail = text[index + len(term) :]
correct_text = head + tail
diff = _diff_text(text, correct_text, dummy_highlighter, verbose)
newdiff = [f"{saferepr(term, maxsize=42)} is contained here:"]
for line in diff:
if line.startswith("Skipping"):
continue
if line.startswith("- "):
continue
if line.startswith("+ "):
newdiff.append(" " + line[2:])
else:
newdiff.append(line)
return newdiff
def running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)

View file

@ -0,0 +1,625 @@
# mypy: allow-untyped-defs
"""Implementation of the cache provider."""
# This plugin was not named "cache" to avoid conflicts with the external
# pytest-cache version.
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Iterable
import dataclasses
import errno
import json
import os
from pathlib import Path
import tempfile
from typing import final
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.nodes import Directory
from _pytest.nodes import File
from _pytest.reports import TestReport
README_CONTENT = """\
# pytest cache directory #
This directory contains data from the pytest's cache plugin,
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
**Do not** commit this to version control.
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
"""
CACHEDIR_TAG_CONTENT = b"""\
Signature: 8a477f597d28d172789f06886806bc55
# This file is a cache directory tag created by pytest.
# For information about cache directory tags, see:
# https://bford.info/cachedir/spec.html
"""
@final
@dataclasses.dataclass
class Cache:
"""Instance of the `cache` fixture."""
_cachedir: Path = dataclasses.field(repr=False)
_config: Config = dataclasses.field(repr=False)
# Sub-directory under cache-dir for directories created by `mkdir()`.
_CACHE_PREFIX_DIRS = "d"
# Sub-directory under cache-dir for values created by `set()`.
_CACHE_PREFIX_VALUES = "v"
def __init__(
self, cachedir: Path, config: Config, *, _ispytest: bool = False
) -> None:
check_ispytest(_ispytest)
self._cachedir = cachedir
self._config = config
@classmethod
def for_config(cls, config: Config, *, _ispytest: bool = False) -> Cache:
"""Create the Cache instance for a Config.
:meta private:
"""
check_ispytest(_ispytest)
cachedir = cls.cache_dir_from_config(config, _ispytest=True)
if config.getoption("cacheclear") and cachedir.is_dir():
cls.clear_cache(cachedir, _ispytest=True)
return cls(cachedir, config, _ispytest=True)
@classmethod
def clear_cache(cls, cachedir: Path, _ispytest: bool = False) -> None:
"""Clear the sub-directories used to hold cached directories and values.
:meta private:
"""
check_ispytest(_ispytest)
for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES):
d = cachedir / prefix
if d.is_dir():
rm_rf(d)
@staticmethod
def cache_dir_from_config(config: Config, *, _ispytest: bool = False) -> Path:
"""Get the path to the cache directory for a Config.
:meta private:
"""
check_ispytest(_ispytest)
return resolve_from_str(config.getini("cache_dir"), config.rootpath)
def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
"""Issue a cache warning.
:meta private:
"""
check_ispytest(_ispytest)
import warnings
from _pytest.warning_types import PytestCacheWarning
warnings.warn(
PytestCacheWarning(fmt.format(**args) if args else fmt),
self._config.hook,
stacklevel=3,
)
def _mkdir(self, path: Path) -> None:
self._ensure_cache_dir_and_supporting_files()
path.mkdir(exist_ok=True, parents=True)
def mkdir(self, name: str) -> Path:
"""Return a directory path object with the given name.
If the directory does not yet exist, it will be created. You can use
it to manage files to e.g. store/retrieve database dumps across test
sessions.
.. versionadded:: 7.0
:param name:
Must be a string not containing a ``/`` separator.
Make sure the name contains your plugin or application
identifiers to prevent clashes with other cache users.
"""
path = Path(name)
if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators")
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
self._mkdir(res)
return res
def _getvaluepath(self, key: str) -> Path:
return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key))
def get(self, key: str, default):
"""Return the cached value for the given key.
If no value was yet cached or the value cannot be read, the specified
default is returned.
:param key:
Must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
:param default:
The value to return in case of a cache-miss or invalid cache value.
"""
path = self._getvaluepath(key)
try:
with path.open("r", encoding="UTF-8") as f:
return json.load(f)
except (ValueError, OSError):
return default
def set(self, key: str, value: object) -> None:
"""Save value for the given key.
:param key:
Must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
:param value:
Must be of any combination of basic python types,
including nested types like lists of dictionaries.
"""
path = self._getvaluepath(key)
try:
self._mkdir(path.parent)
except OSError as exc:
self.warn(
f"could not create cache path {path}: {exc}",
_ispytest=True,
)
return
data = json.dumps(value, ensure_ascii=False, indent=2)
try:
f = path.open("w", encoding="UTF-8")
except OSError as exc:
self.warn(
f"cache could not write path {path}: {exc}",
_ispytest=True,
)
else:
with f:
f.write(data)
def _ensure_cache_dir_and_supporting_files(self) -> None:
"""Create the cache dir and its supporting files."""
if self._cachedir.is_dir():
return
self._cachedir.parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory(
prefix="pytest-cache-files-",
dir=self._cachedir.parent,
) as newpath:
path = Path(newpath)
# Reset permissions to the default, see #12308.
# Note: there's no way to get the current umask atomically, eek.
umask = os.umask(0o022)
os.umask(umask)
path.chmod(0o777 - umask)
with open(path.joinpath("README.md"), "x", encoding="UTF-8") as f:
f.write(README_CONTENT)
with open(path.joinpath(".gitignore"), "x", encoding="UTF-8") as f:
f.write("# Created by pytest automatically.\n*\n")
with open(path.joinpath("CACHEDIR.TAG"), "xb") as f:
f.write(CACHEDIR_TAG_CONTENT)
try:
path.rename(self._cachedir)
except OSError as e:
# If 2 concurrent pytests both race to the rename, the loser
# gets "Directory not empty" from the rename. In this case,
# everything is handled so just continue (while letting the
# temporary directory be cleaned up).
# On Windows, the error is a FileExistsError which translates to EEXIST.
if e.errno not in (errno.ENOTEMPTY, errno.EEXIST):
raise
else:
# Create a directory in place of the one we just moved so that
# `TemporaryDirectory`'s cleanup doesn't complain.
#
# TODO: pass ignore_cleanup_errors=True when we no longer support python < 3.10.
# See https://github.com/python/cpython/issues/74168. Note that passing
# delete=False would do the wrong thing in case of errors and isn't supported
# until python 3.12.
path.mkdir()
class LFPluginCollWrapper:
def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin
self._collected_at_least_one_failure = False
@hookimpl(wrapper=True)
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> Generator[None, CollectReport, CollectReport]:
res = yield
if isinstance(collector, (Session, Directory)):
# Sort any lf-paths to the beginning.
lf_paths = self.lfplugin._last_failed_paths
# Use stable sort to prioritize last failed.
def sort_key(node: nodes.Item | nodes.Collector) -> bool:
return node.path in lf_paths
res.result = sorted(
res.result,
key=sort_key,
reverse=True,
)
elif isinstance(collector, File):
if collector.path in self.lfplugin._last_failed_paths:
result = res.result
lastfailed = self.lfplugin.lastfailed
# Only filter with known failures.
if not self._collected_at_least_one_failure:
if not any(x.nodeid in lastfailed for x in result):
return res
self.lfplugin.config.pluginmanager.register(
LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip"
)
self._collected_at_least_one_failure = True
session = collector.session
result[:] = [
x
for x in result
if x.nodeid in lastfailed
# Include any passed arguments (not trivial to filter).
or session.isinitpath(x.path)
# Keep all sub-collectors.
or isinstance(x, nodes.Collector)
]
return res
class LFPluginCollSkipfiles:
def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin
@hookimpl
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> CollectReport | None:
if isinstance(collector, File):
if collector.path not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1
return CollectReport(
collector.nodeid, "passed", longrepr=None, result=[]
)
return None
class LFPlugin:
"""Plugin which implements the --lf (run last-failing) option."""
def __init__(self, config: Config) -> None:
self.config = config
active_keys = "lf", "failedfirst"
self.active = any(config.getoption(key) for key in active_keys)
assert config.cache
self.lastfailed: dict[str, bool] = config.cache.get("cache/lastfailed", {})
self._previously_failed_count: int | None = None
self._report_status: str | None = None
self._skipped_files = 0 # count skipped files during collection due to --lf
if config.getoption("lf"):
self._last_failed_paths = self.get_last_failed_paths()
config.pluginmanager.register(
LFPluginCollWrapper(self), "lfplugin-collwrapper"
)
def get_last_failed_paths(self) -> set[Path]:
"""Return a set with all Paths of the previously failed nodeids and
their parents."""
rootpath = self.config.rootpath
result = set()
for nodeid in self.lastfailed:
path = rootpath / nodeid.split("::")[0]
result.add(path)
result.update(path.parents)
return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self) -> str | None:
if self.active and self.config.get_verbosity() >= 0:
return f"run-last-failure: {self._report_status}"
return None
def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped:
self.lastfailed.pop(report.nodeid, None)
elif report.failed:
self.lastfailed[report.nodeid] = True
def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped")
if passed:
if report.nodeid in self.lastfailed:
self.lastfailed.pop(report.nodeid)
self.lastfailed.update((item.nodeid, True) for item in report.result)
else:
self.lastfailed[report.nodeid] = True
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems(
self, config: Config, items: list[nodes.Item]
) -> Generator[None]:
res = yield
if not self.active:
return res
if self.lastfailed:
previously_failed = []
previously_passed = []
for item in items:
if item.nodeid in self.lastfailed:
previously_failed.append(item)
else:
previously_passed.append(item)
self._previously_failed_count = len(previously_failed)
if not previously_failed:
# Running a subset of all tests with recorded failures
# only outside of it.
self._report_status = (
f"{len(self.lastfailed)} known failures not in selected tests"
)
else:
if self.config.getoption("lf"):
items[:] = previously_failed
config.hook.pytest_deselected(items=previously_passed)
else: # --failedfirst
items[:] = previously_failed + previously_passed
noun = "failure" if self._previously_failed_count == 1 else "failures"
suffix = " first" if self.config.getoption("failedfirst") else ""
self._report_status = (
f"rerun previous {self._previously_failed_count} {noun}{suffix}"
)
if self._skipped_files > 0:
files_noun = "file" if self._skipped_files == 1 else "files"
self._report_status += f" (skipped {self._skipped_files} {files_noun})"
else:
self._report_status = "no previously failed tests, "
if self.config.getoption("last_failed_no_failures") == "none":
self._report_status += "deselecting all items."
config.hook.pytest_deselected(items=items[:])
items[:] = []
else:
self._report_status += "not deselecting items."
return res
def pytest_sessionfinish(self, session: Session) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
assert config.cache is not None
saved_lastfailed = config.cache.get("cache/lastfailed", {})
if saved_lastfailed != self.lastfailed:
config.cache.set("cache/lastfailed", self.lastfailed)
class NFPlugin:
"""Plugin which implements the --nf (run new-first) option."""
def __init__(self, config: Config) -> None:
self.config = config
self.active = config.option.newfirst
assert config.cache is not None
self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems(self, items: list[nodes.Item]) -> Generator[None]:
res = yield
if self.active:
new_items: dict[str, nodes.Item] = {}
other_items: dict[str, nodes.Item] = {}
for item in items:
if item.nodeid not in self.cached_nodeids:
new_items[item.nodeid] = item
else:
other_items[item.nodeid] = item
items[:] = self._get_increasing_order(
new_items.values()
) + self._get_increasing_order(other_items.values())
self.cached_nodeids.update(new_items)
else:
self.cached_nodeids.update(item.nodeid for item in items)
return res
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> list[nodes.Item]:
return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True)
def pytest_sessionfinish(self) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
if config.getoption("collectonly"):
return
assert config.cache is not None
config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--lf",
"--last-failed",
action="store_true",
dest="lf",
help="Rerun only the tests that failed at the last run (or all if none failed)",
)
group.addoption(
"--ff",
"--failed-first",
action="store_true",
dest="failedfirst",
help="Run all tests, but run the last failures first. "
"This may re-order tests and thus lead to "
"repeated fixture setup/teardown.",
)
group.addoption(
"--nf",
"--new-first",
action="store_true",
dest="newfirst",
help="Run tests from new files first, then the rest of the tests "
"sorted by file mtime",
)
group.addoption(
"--cache-show",
action="append",
nargs="?",
dest="cacheshow",
help=(
"Show cache contents, don't perform collection or tests. "
"Optional argument: glob (default: '*')."
),
)
group.addoption(
"--cache-clear",
action="store_true",
dest="cacheclear",
help="Remove all cache contents at start of test run",
)
cache_dir_default = ".pytest_cache"
if "TOX_ENV_DIR" in os.environ:
cache_dir_default = os.path.join(os.environ["TOX_ENV_DIR"], cache_dir_default)
parser.addini("cache_dir", default=cache_dir_default, help="Cache directory path")
group.addoption(
"--lfnf",
"--last-failed-no-failures",
action="store",
dest="last_failed_no_failures",
choices=("all", "none"),
default="all",
help="With ``--lf``, determines whether to execute tests when there "
"are no previously (known) failures or when no "
"cached ``lastfailed`` data was found. "
"``all`` (the default) runs the full test suite again. "
"``none`` just emits a message about no known failures and exits successfully.",
)
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.cacheshow and not config.option.help:
from _pytest.main import wrap_session
return wrap_session(config, cacheshow)
return None
@hookimpl(tryfirst=True)
def pytest_configure(config: Config) -> None:
config.cache = Cache.for_config(config, _ispytest=True)
config.pluginmanager.register(LFPlugin(config), "lfplugin")
config.pluginmanager.register(NFPlugin(config), "nfplugin")
@fixture
def cache(request: FixtureRequest) -> Cache:
"""Return a cache object that can persist state between testing sessions.
cache.get(key, default)
cache.set(key, value)
Keys must be ``/`` separated strings, where the first part is usually the
name of your plugin or application to avoid clashes with other cache users.
Values can be any object handled by the json stdlib module.
"""
assert request.config.cache is not None
return request.config.cache
def pytest_report_header(config: Config) -> str | None:
"""Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
assert config.cache is not None
cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths
# starting with .., ../.. if sensible
try:
displaypath = cachedir.relative_to(config.rootpath)
except ValueError:
displaypath = cachedir
return f"cachedir: {displaypath}"
return None
def cacheshow(config: Config, session: Session) -> int:
from pprint import pformat
assert config.cache is not None
tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir():
tw.line("cache is empty")
return 0
glob = config.option.cacheshow[0]
if glob is None:
glob = "*"
dummy = object()
basedir = config.cache._cachedir
vdir = basedir / Cache._CACHE_PREFIX_VALUES
tw.sep("-", f"cache values for {glob!r}")
for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()):
key = str(valpath.relative_to(vdir))
val = config.cache.get(key, dummy)
if val is dummy:
tw.line(f"{key} contains unreadable content, will be ignored")
else:
tw.line(f"{key} contains:")
for line in pformat(val).splitlines():
tw.line(" " + line)
ddir = basedir / Cache._CACHE_PREFIX_DIRS
if ddir.is_dir():
contents = sorted(ddir.rglob(glob))
tw.sep("-", f"cache directories for {glob!r}")
for p in contents:
# if p.is_dir():
# print("%s/" % p.relative_to(basedir))
if p.is_file():
key = str(p.relative_to(basedir))
tw.line(f"{key} is a file of length {p.stat().st_size}")
return 0

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,333 @@
# mypy: allow-untyped-defs
"""Python version compatibility code."""
from __future__ import annotations
from collections.abc import Callable
import enum
import functools
import inspect
from inspect import Parameter
from inspect import Signature
import os
from pathlib import Path
import sys
from typing import Any
from typing import Final
from typing import NoReturn
import py
if sys.version_info >= (3, 14):
from annotationlib import Format
#: constant to prepare valuing pylib path replacements/lazy proxies later on
# intended for removal in pytest 8.0 or 9.0
# fmt: off
# intentional space to create a fake difference for the verification
LEGACY_PATH = py.path. local
# fmt: on
def legacy_path(path: str | os.PathLike[str]) -> LEGACY_PATH:
"""Internal wrapper to prepare lazy proxies for legacy_path instances"""
return LEGACY_PATH(path)
# fmt: off
# Singleton type for NOTSET, as described in:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
class NotSetType(enum.Enum):
token = 0
NOTSET: Final = NotSetType.token
# fmt: on
def iscoroutinefunction(func: object) -> bool:
"""Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with
@asyncio.coroutine.
Note: copied and modified from Python 3.5's builtin coroutines.py to avoid
importing asyncio directly, which in turns also initializes the "logging"
module as a side-effect (see issue #8).
"""
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
def is_async_function(func: object) -> bool:
"""Return True if the given function seems to be an async function or
an async generator."""
return iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
def signature(obj: Callable[..., Any]) -> Signature:
"""Return signature without evaluating annotations."""
if sys.version_info >= (3, 14):
return inspect.signature(obj, annotation_format=Format.STRING)
return inspect.signature(obj)
def getlocation(function, curdir: str | os.PathLike[str] | None = None) -> str:
function = get_real_func(function)
fn = Path(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
if curdir is not None:
try:
relfn = fn.relative_to(curdir)
except ValueError:
pass
else:
return f"{relfn}:{lineno + 1}"
return f"{fn}:{lineno + 1}"
def num_mock_patch_args(function) -> int:
"""Return number of arguments used up by mock arguments (if any)."""
patchings = getattr(function, "patchings", None)
if not patchings:
return 0
mock_sentinel = getattr(sys.modules.get("mock"), "DEFAULT", object())
ut_mock_sentinel = getattr(sys.modules.get("unittest.mock"), "DEFAULT", object())
return len(
[
p
for p in patchings
if not p.attribute_name
and (p.new is mock_sentinel or p.new is ut_mock_sentinel)
]
)
def getfuncargnames(
function: Callable[..., object],
*,
name: str = "",
cls: type | None = None,
) -> tuple[str, ...]:
"""Return the names of a function's mandatory arguments.
Should return the names of all function arguments that:
* Aren't bound to an instance or type as in instance or class methods.
* Don't have default values.
* Aren't bound with functools.partial.
* Aren't replaced with mocks.
The cls arguments indicate that the function should be treated as a bound
method even though it's not unless the function is a static method.
The name parameter should be the original name in which the function was collected.
"""
# TODO(RonnyPfannschmidt): This function should be refactored when we
# revisit fixtures. The fixture mechanism should ask the node for
# the fixture names, and not try to obtain directly from the
# function object well after collection has occurred.
# The parameters attribute of a Signature object contains an
# ordered mapping of parameter names to Parameter instances. This
# creates a tuple of the names of the parameters that don't have
# defaults.
try:
parameters = signature(function).parameters.values()
except (ValueError, TypeError) as e:
from _pytest.outcomes import fail
fail(
f"Could not determine arguments of {function!r}: {e}",
pytrace=False,
)
arg_names = tuple(
p.name
for p in parameters
if (
p.kind is Parameter.POSITIONAL_OR_KEYWORD
or p.kind is Parameter.KEYWORD_ONLY
)
and p.default is Parameter.empty
)
if not name:
name = function.__name__
# If this function should be treated as a bound method even though
# it's passed as an unbound method or function, and its first parameter
# wasn't defined as positional only, remove the first parameter name.
if not any(p.kind is Parameter.POSITIONAL_ONLY for p in parameters) and (
# Not using `getattr` because we don't want to resolve the staticmethod.
# Not using `cls.__dict__` because we want to check the entire MRO.
cls
and not isinstance(
inspect.getattr_static(cls, name, default=None), staticmethod
)
):
arg_names = arg_names[1:]
# Remove any names that will be replaced with mocks.
if hasattr(function, "__wrapped__"):
arg_names = arg_names[num_mock_patch_args(function) :]
return arg_names
def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of
# getfuncargnames, to get the arguments which were excluded from its result
# because they had default values.
return tuple(
p.name
for p in signature(function).parameters.values()
if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
and p.default is not Parameter.empty
)
_non_printable_ascii_translate_table = {
i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127)
}
_non_printable_ascii_translate_table.update(
{ord("\t"): "\\t", ord("\r"): "\\r", ord("\n"): "\\n"}
)
def ascii_escaped(val: bytes | str) -> str:
r"""If val is pure ASCII, return it as an str, otherwise, escape
bytes objects into a sequence of escaped bytes:
b'\xc3\xb4\xc5\xd6' -> r'\xc3\xb4\xc5\xd6'
and escapes strings into a sequence of escaped unicode ids, e.g.:
r'4\nV\U00043efa\x0eMXWB\x1e\u3028\u15fd\xcd\U0007d944'
Note:
The obvious "v.decode('unicode-escape')" will return
valid UTF-8 unicode if it finds them in bytes, but we
want to return escaped bytes for any byte, even if they match
a UTF-8 string.
"""
if isinstance(val, bytes):
ret = val.decode("ascii", "backslashreplace")
else:
ret = val.encode("unicode_escape").decode("ascii")
return ret.translate(_non_printable_ascii_translate_table)
def get_real_func(obj):
"""Get the real function object of the (possibly) wrapped object by
:func:`functools.wraps`, or :func:`functools.partial`."""
obj = inspect.unwrap(obj)
if isinstance(obj, functools.partial):
obj = obj.func
return obj
def getimfunc(func):
try:
return func.__func__
except AttributeError:
return func
def safe_getattr(object: Any, name: str, default: Any) -> Any:
"""Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects.
See issue #214.
It catches OutcomeException because of #2490 (issue #580), new outcomes
are derived from BaseException instead of Exception (for more details
check #2707).
"""
from _pytest.outcomes import TEST_OUTCOME
try:
return getattr(object, name, default)
except TEST_OUTCOME:
return default
def safe_isclass(obj: object) -> bool:
"""Ignore any exception via isinstance on Python 3."""
try:
return inspect.isclass(obj)
except Exception:
return False
def get_user_id() -> int | None:
"""Return the current process's real user id or None if it could not be
determined.
:return: The user id or None if it could not be determined.
"""
# mypy follows the version and platform checking expectation of PEP 484:
# https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=platform#python-version-and-system-platform-checks
# Containment checks are too complex for mypy v1.5.0 and cause failure.
if sys.platform == "win32" or sys.platform == "emscripten":
# win32 does not have a getuid() function.
# Emscripten has a return 0 stub.
return None
else:
# On other platforms, a return value of -1 is assumed to indicate that
# the current process's real user id could not be determined.
ERROR = -1
uid = os.getuid()
return uid if uid != ERROR else None
# Perform exhaustiveness checking.
#
# Consider this example:
#
# MyUnion = Union[int, str]
#
# def handle(x: MyUnion) -> int {
# if isinstance(x, int):
# return 1
# elif isinstance(x, str):
# return 2
# else:
# raise Exception('unreachable')
#
# Now suppose we add a new variant:
#
# MyUnion = Union[int, str, bytes]
#
# After doing this, we must remember ourselves to go and update the handle
# function to handle the new variant.
#
# With `assert_never` we can do better:
#
# // raise Exception('unreachable')
# return assert_never(x)
#
# Now, if we forget to handle the new variant, the type-checker will emit a
# compile-time error, instead of the runtime error we would have gotten
# previously.
#
# This also work for Enums (if you use `is` to compare) and Literals.
def assert_never(value: NoReturn) -> NoReturn:
assert False, f"Unhandled value: {value} ({type(value).__name__})"
class CallableBool:
"""
A bool-like object that can also be called, returning its true/false value.
Used for backwards compatibility in cases where something was supposed to be a method
but was implemented as a simple attribute by mistake (see `TerminalReporter.isatty`).
Do not use in new code.
"""
def __init__(self, value: bool) -> None:
self._value = value
def __bool__(self) -> bool:
return self._value
def __call__(self) -> bool:
return self._value

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,535 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import argparse
from collections.abc import Callable
from collections.abc import Mapping
from collections.abc import Sequence
import os
from typing import Any
from typing import cast
from typing import final
from typing import Literal
from typing import NoReturn
import _pytest._io
from _pytest.config.exceptions import UsageError
from _pytest.deprecated import check_ispytest
FILE_OR_DIR = "file_or_dir"
class NotSet:
def __repr__(self) -> str:
return "<notset>"
NOT_SET = NotSet()
@final
class Parser:
"""Parser for command line arguments and ini-file values.
:ivar extra_info: Dict of generic param -> value to display in case
there's an error processing the command line arguments.
"""
prog: str | None = None
def __init__(
self,
usage: str | None = None,
processopt: Callable[[Argument], None] | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self._anonymous = OptionGroup("Custom options", parser=self, _ispytest=True)
self._groups: list[OptionGroup] = []
self._processopt = processopt
self._usage = usage
self._inidict: dict[str, tuple[str, str | None, Any]] = {}
self._ininames: list[str] = []
self.extra_info: dict[str, Any] = {}
def processoption(self, option: Argument) -> None:
if self._processopt:
if option.dest:
self._processopt(option)
def getgroup(
self, name: str, description: str = "", after: str | None = None
) -> OptionGroup:
"""Get (or create) a named option Group.
:param name: Name of the option group.
:param description: Long description for --help output.
:param after: Name of another group, used for ordering --help output.
:returns: The option group.
The returned group object has an ``addoption`` method with the same
signature as :func:`parser.addoption <pytest.Parser.addoption>` but
will be shown in the respective group in the output of
``pytest --help``.
"""
for group in self._groups:
if group.name == name:
return group
group = OptionGroup(name, description, parser=self, _ispytest=True)
i = 0
for i, grp in enumerate(self._groups):
if grp.name == after:
break
self._groups.insert(i + 1, group)
return group
def addoption(self, *opts: str, **attrs: Any) -> None:
"""Register a command line option.
:param opts:
Option names, can be short or long options.
:param attrs:
Same attributes as the argparse library's :meth:`add_argument()
<argparse.ArgumentParser.add_argument>` function accepts.
After command line parsing, options are available on the pytest config
object via ``config.option.NAME`` where ``NAME`` is usually set
by passing a ``dest`` attribute, for example
``addoption("--long", dest="NAME", ...)``.
"""
self._anonymous.addoption(*opts, **attrs)
def parse(
self,
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
from _pytest._argcomplete import try_argcomplete
self.optparser = self._getparser()
try_argcomplete(self.optparser)
strargs = [os.fspath(x) for x in args]
return self.optparser.parse_args(strargs, namespace=namespace)
def _getparser(self) -> MyOptionParser:
from _pytest._argcomplete import filescompleter
optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
groups = [*self._groups, self._anonymous]
for group in groups:
if group.options:
desc = group.description or group.name
arggroup = optparser.add_argument_group(desc)
for option in group.options:
n = option.names()
a = option.attrs()
arggroup.add_argument(*n, **a)
file_or_dir_arg = optparser.add_argument(FILE_OR_DIR, nargs="*")
# bash like autocompletion for dirs (appending '/')
# Type ignored because typeshed doesn't know about argcomplete.
file_or_dir_arg.completer = filescompleter # type: ignore
return optparser
def parse_setoption(
self,
args: Sequence[str | os.PathLike[str]],
option: argparse.Namespace,
namespace: argparse.Namespace | None = None,
) -> list[str]:
parsedoption = self.parse(args, namespace=namespace)
for name, value in parsedoption.__dict__.items():
setattr(option, name, value)
return cast(list[str], getattr(parsedoption, FILE_OR_DIR))
def parse_known_args(
self,
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
"""Parse the known arguments at this point.
:returns: An argparse namespace object.
"""
return self.parse_known_and_unknown_args(args, namespace=namespace)[0]
def parse_known_and_unknown_args(
self,
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> tuple[argparse.Namespace, list[str]]:
"""Parse the known arguments at this point, and also return the
remaining unknown arguments.
:returns:
A tuple containing an argparse namespace object for the known
arguments, and a list of the unknown arguments.
"""
optparser = self._getparser()
strargs = [os.fspath(x) for x in args]
return optparser.parse_known_args(strargs, namespace=namespace)
def addini(
self,
name: str,
help: str,
type: Literal[
"string", "paths", "pathlist", "args", "linelist", "bool", "int", "float"
]
| None = None,
default: Any = NOT_SET,
) -> None:
"""Register an ini-file option.
:param name:
Name of the ini-variable.
:param type:
Type of the variable. Can be:
* ``string``: a string
* ``bool``: a boolean
* ``args``: a list of strings, separated as in a shell
* ``linelist``: a list of strings, separated by line breaks
* ``paths``: a list of :class:`pathlib.Path`, separated as in a shell
* ``pathlist``: a list of ``py.path``, separated as in a shell
* ``int``: an integer
* ``float``: a floating-point number
.. versionadded:: 8.4
The ``float`` and ``int`` types.
For ``paths`` and ``pathlist`` types, they are considered relative to the ini-file.
In case the execution is happening without an ini-file defined,
they will be considered relative to the current working directory (for example with ``--override-ini``).
.. versionadded:: 7.0
The ``paths`` variable type.
.. versionadded:: 8.1
Use the current working directory to resolve ``paths`` and ``pathlist`` in the absence of an ini-file.
Defaults to ``string`` if ``None`` or not passed.
:param default:
Default value if no ini-file option exists but is queried.
The value of ini-variables can be retrieved via a call to
:py:func:`config.getini(name) <pytest.Config.getini>`.
"""
assert type in (
None,
"string",
"paths",
"pathlist",
"args",
"linelist",
"bool",
"int",
"float",
)
if default is NOT_SET:
default = get_ini_default_for_type(type)
self._inidict[name] = (help, type, default)
self._ininames.append(name)
def get_ini_default_for_type(
type: Literal[
"string", "paths", "pathlist", "args", "linelist", "bool", "int", "float"
]
| None,
) -> Any:
"""
Used by addini to get the default value for a given ini-option type, when
default is not supplied.
"""
if type is None:
return ""
elif type in ("paths", "pathlist", "args", "linelist"):
return []
elif type == "bool":
return False
elif type == "int":
return 0
elif type == "float":
return 0.0
else:
return ""
class ArgumentError(Exception):
"""Raised if an Argument instance is created with invalid or
inconsistent arguments."""
def __init__(self, msg: str, option: Argument | str) -> None:
self.msg = msg
self.option_id = str(option)
def __str__(self) -> str:
if self.option_id:
return f"option {self.option_id}: {self.msg}"
else:
return self.msg
class Argument:
"""Class that mimics the necessary behaviour of optparse.Option.
It's currently a least effort implementation and ignoring choices
and integer prefixes.
https://docs.python.org/3/library/optparse.html#optparse-standard-option-types
"""
def __init__(self, *names: str, **attrs: Any) -> None:
"""Store params in private vars for use in add_argument."""
self._attrs = attrs
self._short_opts: list[str] = []
self._long_opts: list[str] = []
try:
self.type = attrs["type"]
except KeyError:
pass
try:
# Attribute existence is tested in Config._processopt.
self.default = attrs["default"]
except KeyError:
pass
self._set_opt_strings(names)
dest: str | None = attrs.get("dest")
if dest:
self.dest = dest
elif self._long_opts:
self.dest = self._long_opts[0][2:].replace("-", "_")
else:
try:
self.dest = self._short_opts[0][1:]
except IndexError as e:
self.dest = "???" # Needed for the error repr.
raise ArgumentError("need a long or short option", self) from e
def names(self) -> list[str]:
return self._short_opts + self._long_opts
def attrs(self) -> Mapping[str, Any]:
# Update any attributes set by processopt.
attrs = "default dest help".split()
attrs.append(self.dest)
for attr in attrs:
try:
self._attrs[attr] = getattr(self, attr)
except AttributeError:
pass
return self._attrs
def _set_opt_strings(self, opts: Sequence[str]) -> None:
"""Directly from optparse.
Might not be necessary as this is passed to argparse later on.
"""
for opt in opts:
if len(opt) < 2:
raise ArgumentError(
f"invalid option string {opt!r}: "
"must be at least two characters long",
self,
)
elif len(opt) == 2:
if not (opt[0] == "-" and opt[1] != "-"):
raise ArgumentError(
f"invalid short option string {opt!r}: "
"must be of the form -x, (x any non-dash char)",
self,
)
self._short_opts.append(opt)
else:
if not (opt[0:2] == "--" and opt[2] != "-"):
raise ArgumentError(
f"invalid long option string {opt!r}: "
"must start with --, followed by non-dash",
self,
)
self._long_opts.append(opt)
def __repr__(self) -> str:
args: list[str] = []
if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)]
if self._long_opts:
args += ["_long_opts: " + repr(self._long_opts)]
args += ["dest: " + repr(self.dest)]
if hasattr(self, "type"):
args += ["type: " + repr(self.type)]
if hasattr(self, "default"):
args += ["default: " + repr(self.default)]
return "Argument({})".format(", ".join(args))
class OptionGroup:
"""A group of options shown in its own section."""
def __init__(
self,
name: str,
description: str = "",
parser: Parser | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self.name = name
self.description = description
self.options: list[Argument] = []
self.parser = parser
def addoption(self, *opts: str, **attrs: Any) -> None:
"""Add an option to this group.
If a shortened version of a long option is specified, it will
be suppressed in the help. ``addoption('--twowords', '--two-words')``
results in help showing ``--two-words`` only, but ``--twowords`` gets
accepted **and** the automatic destination is in ``args.twowords``.
:param opts:
Option names, can be short or long options.
:param attrs:
Same attributes as the argparse library's :meth:`add_argument()
<argparse.ArgumentParser.add_argument>` function accepts.
"""
conflict = set(opts).intersection(
name for opt in self.options for name in opt.names()
)
if conflict:
raise ValueError(f"option names {conflict} already added")
option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=False)
def _addoption(self, *opts: str, **attrs: Any) -> None:
option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=True)
def _addoption_instance(self, option: Argument, shortupper: bool = False) -> None:
if not shortupper:
for opt in option._short_opts:
if opt[0] == "-" and opt[1].islower():
raise ValueError("lowercase shortoptions reserved")
if self.parser:
self.parser.processoption(option)
self.options.append(option)
class MyOptionParser(argparse.ArgumentParser):
def __init__(
self,
parser: Parser,
extra_info: dict[str, Any] | None = None,
prog: str | None = None,
) -> None:
self._parser = parser
super().__init__(
prog=prog,
usage=parser._usage,
add_help=False,
formatter_class=DropShorterLongHelpFormatter,
allow_abbrev=False,
fromfile_prefix_chars="@",
)
# extra_info is a dict of (param -> value) to display if there's
# an usage error to provide more contextual information to the user.
self.extra_info = extra_info if extra_info else {}
def error(self, message: str) -> NoReturn:
"""Transform argparse error message into UsageError."""
msg = f"{self.prog}: error: {message}"
if hasattr(self._parser, "_config_source_hint"):
msg = f"{msg} ({self._parser._config_source_hint})"
raise UsageError(self.format_usage() + msg)
# Type ignored because typeshed has a very complex type in the superclass.
def parse_args( # type: ignore
self,
args: Sequence[str] | None = None,
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
"""Allow splitting of positional arguments."""
parsed, unrecognized = self.parse_known_args(args, namespace)
if unrecognized:
for arg in unrecognized:
if arg and arg[0] == "-":
lines = [
"unrecognized arguments: {}".format(" ".join(unrecognized))
]
for k, v in sorted(self.extra_info.items()):
lines.append(f" {k}: {v}")
self.error("\n".join(lines))
getattr(parsed, FILE_OR_DIR).extend(unrecognized)
return parsed
class DropShorterLongHelpFormatter(argparse.HelpFormatter):
"""Shorten help for long options that differ only in extra hyphens.
- Collapse **long** options that are the same except for extra hyphens.
- Shortcut if there are only two options and one of them is a short one.
- Cache result on the action object as this is called at least 2 times.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
# Use more accurate terminal width.
if "width" not in kwargs:
kwargs["width"] = _pytest._io.get_terminal_width()
super().__init__(*args, **kwargs)
def _format_action_invocation(self, action: argparse.Action) -> str:
orgstr = super()._format_action_invocation(action)
if orgstr and orgstr[0] != "-": # only optional arguments
return orgstr
res: str | None = getattr(action, "_formatted_action_invocation", None)
if res:
return res
options = orgstr.split(", ")
if len(options) == 2 and (len(options[0]) == 2 or len(options[1]) == 2):
# a shortcut for '-h, --help' or '--abc', '-a'
action._formatted_action_invocation = orgstr # type: ignore
return orgstr
return_list = []
short_long: dict[str, str] = {}
for option in options:
if len(option) == 2 or option[2] == " ":
continue
if not option.startswith("--"):
raise ArgumentError(
f'long optional argument without "--": [{option}]', option
)
xxoption = option[2:]
shortened = xxoption.replace("-", "")
if shortened not in short_long or len(short_long[shortened]) < len(
xxoption
):
short_long[shortened] = xxoption
# now short_long has been filled out to the longest with dashes
# **and** we keep the right option ordering from add_argument
for option in options:
if len(option) == 2 or option[2] == " ":
return_list.append(option)
if option[2:] == short_long.get(option.replace("-", "")):
return_list.append(option.replace(" ", "=", 1))
formatted_action_invocation = ", ".join(return_list)
action._formatted_action_invocation = formatted_action_invocation # type: ignore
return formatted_action_invocation
def _split_lines(self, text, width):
"""Wrap lines after splitting on original newlines.
This allows to have explicit line breaks in the help text.
"""
import textwrap
lines = []
for line in text.splitlines():
lines.extend(textwrap.wrap(line.strip(), width))
return lines

View file

@ -0,0 +1,85 @@
from __future__ import annotations
from collections.abc import Mapping
import functools
from pathlib import Path
from typing import Any
import warnings
import pluggy
from ..compat import LEGACY_PATH
from ..compat import legacy_path
from ..deprecated import HOOK_LEGACY_PATH_ARG
# hookname: (Path, LEGACY_PATH)
imply_paths_hooks: Mapping[str, tuple[str, str]] = {
"pytest_ignore_collect": ("collection_path", "path"),
"pytest_collect_file": ("file_path", "path"),
"pytest_pycollect_makemodule": ("module_path", "path"),
"pytest_report_header": ("start_path", "startdir"),
"pytest_report_collectionfinish": ("start_path", "startdir"),
}
def _check_path(path: Path, fspath: LEGACY_PATH) -> None:
if Path(fspath) != path:
raise ValueError(
f"Path({fspath!r}) != {path!r}\n"
"if both path and fspath are given they need to be equal"
)
class PathAwareHookProxy:
"""
this helper wraps around hook callers
until pluggy supports fixingcalls, this one will do
it currently doesn't return full hook caller proxies for fixed hooks,
this may have to be changed later depending on bugs
"""
def __init__(self, hook_relay: pluggy.HookRelay) -> None:
self._hook_relay = hook_relay
def __dir__(self) -> list[str]:
return dir(self._hook_relay)
def __getattr__(self, key: str) -> pluggy.HookCaller:
hook: pluggy.HookCaller = getattr(self._hook_relay, key)
if key not in imply_paths_hooks:
self.__dict__[key] = hook
return hook
else:
path_var, fspath_var = imply_paths_hooks[key]
@functools.wraps(hook)
def fixed_hook(**kw: Any) -> Any:
path_value: Path | None = kw.pop(path_var, None)
fspath_value: LEGACY_PATH | None = kw.pop(fspath_var, None)
if fspath_value is not None:
warnings.warn(
HOOK_LEGACY_PATH_ARG.format(
pylib_path_arg=fspath_var, pathlib_path_arg=path_var
),
stacklevel=2,
)
if path_value is not None:
if fspath_value is not None:
_check_path(path_value, fspath_value)
else:
fspath_value = legacy_path(path_value)
else:
assert fspath_value is not None
path_value = Path(fspath_value)
kw[path_var] = path_value
kw[fspath_var] = fspath_value
return hook(**kw)
fixed_hook.name = hook.name # type: ignore[attr-defined]
fixed_hook.spec = hook.spec # type: ignore[attr-defined]
fixed_hook.__name__ = key
self.__dict__[key] = fixed_hook
return fixed_hook # type: ignore[return-value]

View file

@ -0,0 +1,13 @@
from __future__ import annotations
from typing import final
@final
class UsageError(Exception):
"""Error in pytest usage or invocation."""
class PrintHelp(Exception):
"""Raised when pytest should print its help to skip the rest of the
argument parsing and validation."""

Some files were not shown because too many files have changed in this diff Show more