audio streaming app plyr.fm
38
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: add security headers middleware (#315)

Implements basic HTTP security headers as requested in #205.

Headers added:

- X-Content-Type-Options: nosniff

- X-Frame-Options: DENY

- X-XSS-Protection: 1; mode=block

- Referrer-Policy: strict-origin-when-cross-origin

- Strict-Transport-Security (in production)

authored by

nate nowack and committed by
GitHub
dd78a2bf c9004561

+95 -2
+34 -1
src/backend/main.py
··· 5 5 from collections.abc import AsyncIterator 6 6 from contextlib import asynccontextmanager 7 7 8 - from fastapi import FastAPI 8 + from fastapi import FastAPI, Request 9 9 from fastapi.middleware.cors import CORSMiddleware 10 + from starlette.middleware.base import BaseHTTPMiddleware 10 11 11 12 # filter pydantic warning from atproto library 12 13 warnings.filterwarnings( ··· 63 64 logger = logging.getLogger(__name__) 64 65 65 66 67 + class SecurityHeadersMiddleware(BaseHTTPMiddleware): 68 + """middleware to add security headers to all responses.""" 69 + 70 + async def dispatch(self, request: Request, call_next): 71 + """dispatch the request.""" 72 + response = await call_next(request) 73 + 74 + # prevent MIME sniffing 75 + response.headers["X-Content-Type-Options"] = "nosniff" 76 + 77 + # prevent clickjacking 78 + response.headers["X-Frame-Options"] = "DENY" 79 + 80 + # enable browser XSS protection 81 + response.headers["X-XSS-Protection"] = "1; mode=block" 82 + 83 + # control referrer information 84 + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" 85 + 86 + # enforce HTTPS in production (HSTS) 87 + # skip in debug mode (localhost usually doesn't have https) 88 + if not settings.app.debug: 89 + response.headers["Strict-Transport-Security"] = ( 90 + "max-age=31536000; includeSubDomains" 91 + ) 92 + 93 + return response 94 + 95 + 66 96 @asynccontextmanager 67 97 async def lifespan(app: FastAPI) -> AsyncIterator[None]: 68 98 """handle application lifespan events.""" ··· 92 122 # instrument fastapi with logfire 93 123 if logfire: 94 124 logfire.instrument_fastapi(app) 125 + 126 + # add security headers middleware 127 + app.add_middleware(SecurityHeadersMiddleware) 95 128 96 129 # configure CORS - allow localhost for dev and cloudflare pages for production 97 130 app.add_middleware(
+18 -1
tests/conftest.py
··· 1 1 """pytest configuration for relay tests.""" 2 2 3 - from collections.abc import AsyncGenerator 3 + from collections.abc import AsyncGenerator, Generator 4 4 from contextlib import asynccontextmanager 5 5 from datetime import UTC, datetime 6 6 7 7 import pytest 8 8 import sqlalchemy as sa 9 + from fastapi import FastAPI 10 + from fastapi.testclient import TestClient 9 11 from sqlalchemy.ext.asyncio import ( 10 12 AsyncConnection, 11 13 AsyncEngine, ··· 203 205 """ 204 206 async with session_context(engine=_engine) as session: 205 207 yield session 208 + 209 + 210 + @pytest.fixture 211 + def fastapi_app() -> FastAPI: 212 + """provides the FastAPI app instance.""" 213 + from backend.main import app as main_app 214 + 215 + return main_app 216 + 217 + 218 + @pytest.fixture 219 + def client(fastapi_app: FastAPI) -> Generator[TestClient, None, None]: 220 + """provides a TestClient for testing the FastAPI application.""" 221 + with TestClient(fastapi_app) as tc: 222 + yield tc
+43
tests/test_security_headers.py
··· 1 + """test security headers middleware.""" 2 + 3 + from fastapi.testclient import TestClient 4 + 5 + from backend.config import settings 6 + 7 + 8 + def test_security_headers_present(client: TestClient): 9 + """verify that security headers are present in responses.""" 10 + response = client.get("/health") 11 + assert response.status_code == 200 12 + 13 + headers = response.headers 14 + 15 + # check basic security headers 16 + assert headers["X-Content-Type-Options"] == "nosniff" 17 + assert headers["X-Frame-Options"] == "DENY" 18 + assert headers["X-XSS-Protection"] == "1; mode=block" 19 + assert headers["Referrer-Policy"] == "strict-origin-when-cross-origin" 20 + 21 + 22 + def test_hsts_header_logic(client: TestClient): 23 + """verify HSTS header logic based on debug mode.""" 24 + # save original setting 25 + original_debug = settings.app.debug 26 + 27 + try: 28 + # case 1: debug=True (default in tests) -> no HSTS 29 + settings.app.debug = True 30 + response = client.get("/health") 31 + assert "Strict-Transport-Security" not in response.headers 32 + 33 + # case 2: debug=False (production) -> HSTS present 34 + settings.app.debug = False 35 + response = client.get("/health") 36 + assert ( 37 + response.headers["Strict-Transport-Security"] 38 + == "max-age=31536000; includeSubDomains" 39 + ) 40 + 41 + finally: 42 + # restore setting 43 + settings.app.debug = original_debug