···2121- Warming tier caps protect the shared IP during the first 14 days
2222 of a new member's lifetime.
2323- Pool-level FBL registrations: Gmail Postmaster verified, Microsoft
2424- SNDS + JMRP registered, Yahoo CFL pending. Operator-classified
2424+ SNDS + JMRP registered, Yahoo CFL verified. Operator-classified
2525 inbound (`postmaster@`, `abuse@`, `fbl@`, …) forwards to an
2626 external inbox for provider authorization flows. See
2727 [docs/operator-runbook.md](docs/operator-runbook.md) for the live
···8888 members, inbound log, shadow-verdicts, review queue for
8989 auto-suspensions.
9090- **FBL integrations**: Gmail Postmaster Tools verified, Microsoft
9191- SNDS + JMRP registered, Yahoo CFL pending. Pool-level registration
9191+ SNDS + JMRP registered, Yahoo CFL verified. All three major US
9292+ mailbox-provider feedback loops are live. Pool-level registration
9293 via `d=atmos.email` signing means one registration per provider
9394 covers every member.
9495- **Atproto OAuth** (PAR + DPoP + PKCE + `private_key_jwt`) for
···132133 dashboard. Rules will be frozen at their current behavior by a
133134 harness that publishes fixtures to a test Kafka and asserts on
134135 verdicts.
135135-4. **Yahoo CFL registration.** The last externally-gated FBL
136136- program. Manual form, 1–5 day turnaround.
137137-5. **Content policies that aren't just abuse.** Transactional-only is
136136+4. **Content policies that aren't just abuse.** Transactional-only is
138137 a deliberate v1 constraint; the path to "Postmark for atproto"
139138 runs through richer template support and eventually a managed
140139 API alongside SMTP.
+1-1
docs/operator-runbook.md
···213213| Gmail Postmaster Tools | Verified | TXT token published for `atmos.email`; dashboard live at postmaster.google.com. Reputation score needs ~48 h of sending volume to populate. |
214214| Microsoft SNDS | IP registered, authorization email landed via operator-forwarder | The enrollment flow required receiving a verification mail at `postmaster@atmos.email` — handled by the operator-forwarder routing described in section 6. |
215215| Microsoft JMRP | Registered | FBL recipient `fbl@atmospheremail.com` accepted. First complaint probe will confirm the delivery path. |
216216-| Yahoo CFL | Pending | Manual form at `senders.yahooinc.com/complaint-feedback-loop/` — no API. Tracked as the last externally-gated item before the FBL triangle is complete. |
216216+| Yahoo CFL | Verified 2026-04-20 | Domain verified via TXT (`yahoo-verification-key=…`) at the atmos.email apex. Verification record is a no-op now; tracked for removal in chainlink #144. Complaints will arrive at `fbl@atmospheremail.com` once Yahoo begins sending. |
217217218218Adding a new provider later: publish the FBL recipient as
219219`fbl@atmospheremail.com` if they accept an external address, otherwise
···3535 boot.loader.grub = {
3636 enable = true;
3737 efiSupport = false;
3838+ configurationLimit = 20;
3839 };
39404041 boot.initrd.availableKernelModules = [
···102103 # Don't restart tailscaled during deploys — avoids SSH drops
103104 systemd.services.tailscaled.restartIfChanged = false;
104105105105- # Tailscale Serve: proxy HTTPS :443 → admin/dashboard on :8080
106106- # Gives clean URLs: https://atmos-relay.internal.example/ui/
107107- #
108108- # The relay's public HTTPS listener binds to the public IP
109109- # specifically (detected at startup), so Tailscale Serve can use
110110- # :443 on the Tailscale interface without conflict.
106106+ # Tailscale Serve: proxy HTTPS on the Tailscale interface → admin
107107+ # dashboard on :8080. The relay binds its public HTTPS listener to
108108+ # the detected public IP (not 0.0.0.0), so both can use :443 on
109109+ # different interfaces without conflict. Requires AF_NETLINK in the
110110+ # relay's RestrictAddressFamilies for the IP detection to work.
111111 systemd.services.tailscale-serve = {
112112 description = "Configure Tailscale Serve for admin dashboard";
113113 after = [ "tailscaled.service" "atmos-relay.service" ];
···151151 ADMIN_TOKEN=${config.sops.placeholder.admin_token}
152152 LABELER_URL=${config.sops.placeholder.labeler_url}
153153 WARMUP_SEED_ADDRESSES=${config.sops.placeholder.warmup_seed_addresses}
154154+ WARMUP_FROM_LOCAL_PARTS=scott,hello
155155+ WARMUP_DIDS=did:plc:dy67wyyakm7u4v2lthy5zwbn
154156 '';
155157 };
156158···165167 email = "postmaster@atmos.email";
166168 webroot = "/var/lib/acme/.challenges";
167169 group = "atmos-relay";
168168- reloadServices = [ "atmos-relay.service" ];
170170+ # No reloadServices: the relay uses an in-process
171171+ # CertReloader (internal/relay/cert_reload.go) that picks
172172+ # up new certs on the next TLS handshake via mtime polling.
173173+ # A systemd reload/restart would drop in-flight SMTP/HTTP
174174+ # sessions every 60-90 days and trigger the spool-reload
175175+ # race that #208 fixed — this is exactly the failure mode
176176+ # #216 closed.
177177+ reloadServices = [ ];
169178 };
170179 certs."smtp.atmos.email" = {};
171180 certs."atmos.email" = {};
···305314 ProtectKernelTunables = true;
306315 ProtectKernelModules = true;
307316 ProtectControlGroups = true;
317317+ RestrictAddressFamilies = [ "AF_INET" "AF_INET6" "AF_UNIX" "AF_NETLINK" ];
318318+ RestrictNamespaces = true;
319319+ RestrictRealtime = true;
320320+ RestrictSUIDSGID = true;
321321+ LockPersonality = true;
322322+ MemoryDenyWriteExecute = true;
308323 ReadWritePaths = [ "/var/lib/atmos-relay" ];
309324 ReadOnlyPaths = [
310325 "/var/lib/acme/smtp.atmos.email"
···344359 curl
345360 htop
346361 jq
362362+ sqlite
347363 ];
348364349365 # -------------------------------------------------------------------
···355371 '';
356372357373 # -------------------------------------------------------------------
374374+ # Backup — encrypted Restic backups to Hetzner Cloud Volume.
375375+ #
376376+ # Same pattern as atmos-ops: auto-format on first boot, mount by
377377+ # label, auto-generate restic password, timer every 6h.
378378+ #
379379+ # Critical data: relay.sqlite, DKIM signing keys, OAuth key.
380380+ # -------------------------------------------------------------------
381381+ systemd.services.format-backup-volume = {
382382+ description = "Format Hetzner backup volume if unformatted";
383383+ wantedBy = [ "multi-user.target" ];
384384+ serviceConfig = {
385385+ Type = "oneshot";
386386+ RemainAfterExit = true;
387387+ };
388388+ path = [ pkgs.util-linux pkgs.e2fsprogs pkgs.systemd ];
389389+ script = ''
390390+ DEV=""
391391+ for d in /dev/disk/by-id/scsi-0HC_Volume_*; do
392392+ [ -b "$d" ] && DEV="$d" && break
393393+ done
394394+ if [ -z "$DEV" ]; then
395395+ echo "No Hetzner Cloud Volume found, skipping"
396396+ exit 0
397397+ fi
398398+ RESOLVED=$(readlink -f "$DEV")
399399+ if blkid -o value -s TYPE "$DEV" 2>/dev/null | grep -q .; then
400400+ echo "$DEV ($RESOLVED) already formatted"
401401+ else
402402+ echo "Formatting $DEV ($RESOLVED) as ext4 with label atmos-relay-backup"
403403+ mkfs.ext4 -L atmos-relay-backup "$DEV"
404404+ fi
405405+ if ! mountpoint -q /var/lib/atmos-backup 2>/dev/null; then
406406+ systemctl start var-lib-atmos\\x2dbackup.mount 2>/dev/null || true
407407+ fi
408408+ '';
409409+ };
410410+411411+ fileSystems."/var/lib/atmos-backup" = {
412412+ device = "/dev/disk/by-label/atmos-relay-backup";
413413+ fsType = "ext4";
414414+ options = [ "nofail" "x-systemd.device-timeout=30" ];
415415+ };
416416+417417+ systemd.services.restic-password-init = {
418418+ description = "Generate restic encryption password if missing";
419419+ after = [ "local-fs.target" ];
420420+ wantedBy = [ "multi-user.target" ];
421421+ serviceConfig = {
422422+ Type = "oneshot";
423423+ RemainAfterExit = true;
424424+ };
425425+ script = ''
426426+ if [ ! -f /root/.restic-password ]; then
427427+ ${pkgs.coreutils}/bin/head -c 32 /dev/urandom | ${pkgs.coreutils}/bin/base64 > /root/.restic-password
428428+ chmod 0400 /root/.restic-password
429429+ fi
430430+ if ${pkgs.util-linux}/bin/mountpoint -q /var/lib/atmos-backup && [ ! -f /var/lib/atmos-backup/.restic-password ]; then
431431+ cp /root/.restic-password /var/lib/atmos-backup/.restic-password
432432+ chmod 0400 /var/lib/atmos-backup/.restic-password
433433+ fi
434434+ '';
435435+ };
436436+437437+ services.restic.backups.atmos-relay = {
438438+ initialize = true;
439439+ repository = "/var/lib/atmos-backup/restic-repo";
440440+ passwordFile = "/root/.restic-password";
441441+ paths = [
442442+ "/var/lib/atmos-backup/dumps"
443443+ ];
444444+ backupPrepareCommand = ''
445445+ if ! ${pkgs.util-linux}/bin/mountpoint -q /var/lib/atmos-backup; then
446446+ echo "ERROR: backup volume not mounted"
447447+ exit 1
448448+ fi
449449+ mkdir -p /var/lib/atmos-backup/dumps
450450+451451+ # Relay SQLite — hot backup
452452+ if [ -f /var/lib/atmos-relay/relay.sqlite ]; then
453453+ ${pkgs.sqlite}/bin/sqlite3 /var/lib/atmos-relay/relay.sqlite \
454454+ ".backup '/var/lib/atmos-backup/dumps/relay.sqlite'"
455455+ fi
456456+457457+ # DKIM signing keys (generated at first boot, no other copy exists)
458458+ if [ -f /var/lib/atmos-relay/operator-dkim-keys.json ]; then
459459+ cp /var/lib/atmos-relay/operator-dkim-keys.json /var/lib/atmos-backup/dumps/
460460+ fi
461461+462462+ # OAuth signing key
463463+ if [ -f /var/lib/atmos-relay/oauth-signing-key.pem ]; then
464464+ cp /var/lib/atmos-relay/oauth-signing-key.pem /var/lib/atmos-backup/dumps/
465465+ fi
466466+ '';
467467+ timerConfig = {
468468+ OnCalendar = "*-*-* 00/6:00:00";
469469+ Persistent = true;
470470+ RandomizedDelaySec = "30m";
471471+ };
472472+ pruneOpts = [
473473+ "--keep-daily 7"
474474+ "--keep-weekly 4"
475475+ "--keep-monthly 3"
476476+ ];
477477+ };
478478+479479+ # -------------------------------------------------------------------
358480 # Nix — enable flakes for nixos-rebuild
359481 # -------------------------------------------------------------------
360482 nix.settings = {
361483 experimental-features = [ "nix-command" "flakes" ];
362484 trusted-users = [ "root" ];
485485+ };
486486+487487+ nix.gc = {
488488+ automatic = true;
489489+ dates = "daily";
490490+ options = "--delete-older-than 5d";
491491+ persistent = true;
492492+ };
493493+494494+ nix.optimise = {
495495+ automatic = true;
496496+ dates = [ "weekly" ];
363497 };
364498 };
365499}
+14
infra/outputs.tf
···8585 After that, all updates go through git push → CI → deploy.
8686 EOT
8787}
8888+8989+# ---------------------------------------------------------------------------
9090+# Backup volume outputs
9191+# ---------------------------------------------------------------------------
9292+9393+output "ops_backup_volume_id" {
9494+ description = "Hetzner volume ID of the ops backup volume"
9595+ value = hcloud_volume.ops_backup.id
9696+}
9797+9898+output "relay_backup_volume_id" {
9999+ description = "Hetzner volume ID of the relay backup volume"
100100+ value = hcloud_volume.relay_backup.id
101101+}
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package admin
44+55+import (
66+ "bytes"
77+ "context"
88+ "encoding/json"
99+ "fmt"
1010+ "net/http"
1111+ "net/http/httptest"
1212+ "strings"
1313+ "testing"
1414+ "time"
1515+1616+ "atmosphere-mail/internal/relay"
1717+ "atmosphere-mail/internal/relaystore"
1818+)
1919+2020+func newBypassAPI(t *testing.T) (*API, *relay.LabelChecker, *relaystore.Store) {
2121+ t.Helper()
2222+ store, err := relaystore.New(":memory:")
2323+ if err != nil {
2424+ t.Fatal(err)
2525+ }
2626+ t.Cleanup(func() { store.Close() })
2727+ lc := relay.NewLabelChecker("http://127.0.0.1:1", nil)
2828+ api := NewWithLabelChecker(store, "tok", "atmos.email", lc)
2929+ return api, lc, store
3030+}
3131+3232+func bypassAddReq(t *testing.T, api *API, did string, body any) *httptest.ResponseRecorder {
3333+ t.Helper()
3434+ var buf []byte
3535+ if body != nil {
3636+ var err error
3737+ buf, err = json.Marshal(body)
3838+ if err != nil {
3939+ t.Fatal(err)
4040+ }
4141+ }
4242+ req := httptest.NewRequest("POST", "/admin/member/"+did+"/bypass-labels", bytes.NewReader(buf))
4343+ if buf != nil {
4444+ req.Header.Set("Content-Type", "application/json")
4545+ }
4646+ req.Header.Set("Authorization", "Bearer tok")
4747+ w := httptest.NewRecorder()
4848+ api.ServeHTTP(w, req)
4949+ return w
5050+}
5151+5252+// TestBypassAdd_DefaultTTLApplied confirms a bare add (no body) lands
5353+// with the 24h default expiry rather than no expiry.
5454+func TestBypassAdd_DefaultTTLApplied(t *testing.T) {
5555+ api, _, store := newBypassAPI(t)
5656+ did := "did:plc:aaaaaaaabbbbbbbbcccccccc"
5757+5858+ w := bypassAddReq(t, api, did, nil)
5959+ if w.Code != http.StatusOK {
6060+ t.Fatalf("status = %d body=%s", w.Code, w.Body.String())
6161+ }
6262+ var resp map[string]string
6363+ _ = json.Unmarshal(w.Body.Bytes(), &resp)
6464+ exp, err := time.Parse(time.RFC3339, resp["expires_at"])
6565+ if err != nil {
6666+ t.Fatalf("parse expires_at %q: %v", resp["expires_at"], err)
6767+ }
6868+ dt := time.Until(exp)
6969+ if dt < 23*time.Hour || dt > 25*time.Hour {
7070+ t.Errorf("default TTL = %s, want ~24h", dt)
7171+ }
7272+7373+ // Persisted in store with that expiry.
7474+ listed, _ := store.ListBypassDIDs(context.Background())
7575+ if len(listed) != 1 || listed[0] != did {
7676+ t.Errorf("ListBypassDIDs = %v, want [%s]", listed, did)
7777+ }
7878+}
7979+8080+// TestBypassAdd_RejectsTTLOverCap pins the security cap.
8181+func TestBypassAdd_RejectsTTLOverCap(t *testing.T) {
8282+ api, _, _ := newBypassAPI(t)
8383+ w := bypassAddReq(t, api, "did:plc:bbbbbbbbccccccccdddddddd",
8484+ map[string]any{"ttl_hours": 24*30 + 1, "reason": "dangerous"})
8585+ if w.Code != http.StatusBadRequest {
8686+ t.Fatalf("status = %d, want 400; body=%s", w.Code, w.Body.String())
8787+ }
8888+ if !strings.Contains(strings.ToLower(w.Body.String()), "ttl_hours") {
8989+ t.Errorf("error message should mention ttl_hours; got %q", w.Body.String())
9090+ }
9191+}
9292+9393+// TestBypassAdd_PersistsReason — the reason string round-trips.
9494+func TestBypassAdd_PersistsReason(t *testing.T) {
9595+ api, _, store := newBypassAPI(t)
9696+ did := "did:plc:ccccccccddddddddeeeeeeee"
9797+ w := bypassAddReq(t, api, did, map[string]any{"ttl_hours": 1, "reason": "investigating sender flood"})
9898+ if w.Code != http.StatusOK {
9999+ t.Fatalf("status = %d body=%s", w.Code, w.Body.String())
100100+ }
103103+ row := storeRowReason(t, store, did)
104104+ if row != "investigating sender flood" {
105105+ t.Errorf("persisted reason = %q, want %q", row, "investigating sender flood")
106106+ }
108108+}
109109+110110+// storeRowReason returns the persisted reason for the given DID's
111111+// active bypass entry. The Store doesn't expose the reason directly,
112112+// so the helper confirms the DID is in the active set via
113113+// ListBypassDIDs and then reads the reason from the most recent
114114+// 'add' audit row.
115115+func storeRowReason(t *testing.T, s *relaystore.Store, did string) string {
116116+ t.Helper()
117117+ rows, err := s.ListBypassDIDs(context.Background())
118118+ if err != nil || len(rows) == 0 {
119119+ t.Fatalf("expected at least one row, got %v err=%v", rows, err)
120120+ }
121121+ for _, d := range rows {
122122+ if d == did {
123123+ return persistedReasonFromAudit(t, s, did)
124124+ }
125125+ }
126126+ return ""
127127+}
135135+136136+// persistedReasonFromAudit reads the most-recent 'add' audit row for
137137+// the given DID via the test-only ListBypassAuditForTest accessor.
138138+// The Store doesn't expose the audit table publicly, so this keeps
139139+// the assertion legible without widening the production API.
140140+func persistedReasonFromAudit(t *testing.T, s *relaystore.Store, did string) string {
141141+ t.Helper()
142142+ rows, err := s.ListBypassAuditForTest(context.Background(), did)
143143+ if err != nil {
144144+ t.Fatalf("audit query: %v", err)
145145+ }
146146+ for _, e := range rows {
147147+ if e.Action == "add" {
148148+ return e.Reason
149149+ }
150150+ }
151151+ return ""
152152+}
153153+154154+// TestBypassRemove_WritesAuditRow confirms a manual removal lands a
155155+// 'remove'/'manual' audit row so post-hoc analysis can distinguish
156156+// it from janitor-driven 'expired' removals.
157157+func TestBypassRemove_WritesAuditRow(t *testing.T) {
158158+ api, _, store := newBypassAPI(t)
159159+ did := "did:plc:ddddddddeeeeeeeeffffffff"
160160+ if w := bypassAddReq(t, api, did, map[string]any{"ttl_hours": 1}); w.Code != http.StatusOK {
161161+ t.Fatalf("add: %d %s", w.Code, w.Body.String())
162162+ }
163163+ req := httptest.NewRequest("DELETE", "/admin/member/"+did+"/bypass-labels", nil)
164164+ req.Header.Set("Authorization", "Bearer tok")
165165+ w := httptest.NewRecorder()
166166+ api.ServeHTTP(w, req)
167167+ if w.Code != http.StatusOK {
168168+ t.Fatalf("remove: %d %s", w.Code, w.Body.String())
169169+ }
170170+ rows, err := store.ListBypassAuditForTest(context.Background(), did)
171171+ if err != nil {
172172+ t.Fatal(err)
173173+ }
174174+ var sawRemove bool
175175+ for _, e := range rows {
176176+ if e.Action == "remove" && e.Reason == "manual" {
177177+ sawRemove = true
178178+ }
179179+ }
180180+ if !sawRemove {
181181+ t.Errorf("expected audit row action=remove reason=manual, got %+v", rows)
182182+ }
183183+}
184184+185185+// Compile-time guard: ensure *relaystore.Store has ListBypassAuditForTest.
186186+// If the test-only accessor is missing, this method expression fails to
187187+// compile, which is a clearer signal than a cryptic failure later in the run.
188188+var _ = (*relaystore.Store).ListBypassAuditForTest
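The default-TTL and over-cap tests above pin a handler-side policy that is not itself shown in this diff. Below is a minimal sketch of that policy, assuming hypothetical names (`defaultBypassTTL`, `maxBypassTTL`, `bypassTTLFromRequest`) and a 30-day cap inferred from the `24*30 + 1` rejection case; the real admin handler may structure this differently.

```go
package admin

import (
	"fmt"
	"time"
)

const (
	defaultBypassTTL = 24 * time.Hour      // applied when the request body omits ttl_hours
	maxBypassTTL     = 30 * 24 * time.Hour // security cap pinned by TestBypassAdd_RejectsTTLOverCap
)

// bypassTTLFromRequest maps the optional ttl_hours field onto a concrete
// expiry window, rejecting anything beyond the cap.
func bypassTTLFromRequest(ttlHours int) (time.Duration, error) {
	switch {
	case ttlHours == 0:
		return defaultBypassTTL, nil
	case ttlHours < 0:
		return 0, fmt.Errorf("ttl_hours must be positive")
	case time.Duration(ttlHours)*time.Hour > maxBypassTTL:
		return 0, fmt.Errorf("ttl_hours exceeds the %d-hour cap", int(maxBypassTTL.Hours()))
	default:
		return time.Duration(ttlHours) * time.Hour, nil
	}
}
```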
+187
internal/admin/enroll_oauth_gate_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package admin
44+55+import (
66+ "bytes"
77+ "encoding/json"
88+ "net/http"
99+ "net/http/httptest"
1010+ "strings"
1111+ "testing"
1212+)
1313+1414+// fakeAuthVerifier returns a hard-coded verified DID when a "verified"
1515+// cookie is present, simulating the EnrollHandler's OAuth ticket
1616+// lookup without dragging the UI package into the test boundary.
1717+type fakeAuthVerifier struct {
1818+ verifiedDID string
1919+}
2020+2121+func (f *fakeAuthVerifier) VerifyAuthCookie(r *http.Request) (string, bool) {
2222+ if r == nil || f == nil {
2323+ return "", false
2424+ }
2525+ if c, _ := r.Cookie("verified"); c != nil && c.Value == "yes" {
2626+ return f.verifiedDID, true
2727+ }
2828+ return "", false
2929+}
3030+3131+// TestEnrollStart_OAuthGate_RejectsMissingCookie pins #207: when an
3232+// OAuth verifier is wired, /admin/enroll-start must refuse a request
3333+// that does not present a verified-DID cookie.
3434+func TestEnrollStart_OAuthGate_RejectsMissingCookie(t *testing.T) {
3535+ api, _, _ := testEnrollAPI(t)
3636+ api.SetEnrollAuthVerifier(&fakeAuthVerifier{verifiedDID: "did:plc:aaaaaaaabbbbbbbbcccccccc"})
3737+3838+ body, _ := json.Marshal(EnrollStartRequest{
3939+ DID: "did:plc:aaaaaaaabbbbbbbbcccccccc", Domain: "ok.example", TermsAccepted: true,
4040+ })
4141+ req := httptest.NewRequest(http.MethodPost, "/admin/enroll-start", bytes.NewReader(body))
4242+ // No cookie set.
4343+ w := httptest.NewRecorder()
4444+ api.ServeHTTP(w, req)
4545+ if w.Code != http.StatusForbidden {
4646+ t.Fatalf("expected 403 without cookie, got %d body=%s", w.Code, w.Body.String())
4747+ }
4848+ if !strings.Contains(strings.ToLower(w.Body.String()), "identity verification") {
4949+ t.Errorf("body should mention identity verification, got %q", w.Body.String())
5050+ }
5151+}
5252+5353+// TestEnrollStart_OAuthGate_RejectsDIDMismatch is the central #207
5454+// scenario: caller proves DID A via OAuth but tries to enroll
5555+// claiming DID B. The mismatch must be refused.
5656+func TestEnrollStart_OAuthGate_RejectsDIDMismatch(t *testing.T) {
5757+ api, _, _ := testEnrollAPI(t)
5858+ verifier := &fakeAuthVerifier{verifiedDID: "did:plc:bbbbbbbbccccccccdddddddd"}
5959+ api.SetEnrollAuthVerifier(verifier)
6060+6161+ // Claimed DID does NOT match the OAuth-verified DID.
6262+ body, _ := json.Marshal(EnrollStartRequest{
6363+ DID: "did:plc:zzzzzzzzyyyyyyyyxxxxxxxx", Domain: "ok.example", TermsAccepted: true,
6464+ })
6565+ req := httptest.NewRequest(http.MethodPost, "/admin/enroll-start", bytes.NewReader(body))
6666+ req.AddCookie(&http.Cookie{Name: "verified", Value: "yes"})
6767+ w := httptest.NewRecorder()
6868+ api.ServeHTTP(w, req)
6969+ if w.Code != http.StatusForbidden {
7070+ t.Fatalf("expected 403 on DID mismatch, got %d body=%s", w.Code, w.Body.String())
7171+ }
7272+ if !strings.Contains(strings.ToLower(w.Body.String()), "does not match") {
7373+ t.Errorf("body should mention mismatch, got %q", w.Body.String())
7474+ }
7575+}
7676+7777+// TestEnrollStart_OAuthGate_AllowsExactMatch confirms the happy path:
7878+// claimed DID == OAuth-verified DID, the start succeeds.
7979+func TestEnrollStart_OAuthGate_AllowsExactMatch(t *testing.T) {
8080+ api, store, _ := testEnrollAPI(t)
8181+ did := "did:plc:aaaaaaaabbbbbbbbcccccccc"
8282+ api.SetEnrollAuthVerifier(&fakeAuthVerifier{verifiedDID: did})
8383+8484+ body, _ := json.Marshal(EnrollStartRequest{DID: did, Domain: "ok.example", TermsAccepted: true})
8585+ req := httptest.NewRequest(http.MethodPost, "/admin/enroll-start", bytes.NewReader(body))
8686+ req.AddCookie(&http.Cookie{Name: "verified", Value: "yes"})
8787+ w := httptest.NewRecorder()
8888+ api.ServeHTTP(w, req)
8989+ if w.Code != http.StatusOK {
9090+ t.Fatalf("expected 200 on matching DID, got %d body=%s", w.Code, w.Body.String())
9191+ }
9292+ // Pending row must be persisted (existing happy-path invariant).
9393+ var resp EnrollStartResponse
9494+ if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
9595+ t.Fatalf("decode response: %v", err)
9696+ }
9797+ if resp.Token == "" {
9898+ t.Error("expected non-empty token in response")
9999+ }
100100+ if pending, _ := store.GetPendingEnrollment(req.Context(), resp.Token); pending == nil {
101101+ t.Error("pending enrollment row not persisted")
102102+ }
103103+}
104104+105105+// TestEnrollStart_OAuthGate_CaseInsensitiveDIDMatch: did:plc
106106+// identifiers are lowercase base32, but did:web identifiers embed
107107+// hostnames, which are case-insensitive per RFC 3986, so a resolver
108108+// may legitimately surface mixed case. The match must use EqualFold
109109+// to avoid rejecting such DIDs spuriously.
110110+func TestEnrollStart_OAuthGate_CaseInsensitiveDIDMatch(t *testing.T) {
111111+ api, _, _ := testEnrollAPI(t)
112112+ api.SetEnrollAuthVerifier(&fakeAuthVerifier{verifiedDID: "did:web:Example.com"})
113113+114114+ body, _ := json.Marshal(EnrollStartRequest{
115115+ DID: "did:web:example.com", Domain: "ok.example", TermsAccepted: true,
116116+ })
117117+ req := httptest.NewRequest(http.MethodPost, "/admin/enroll-start", bytes.NewReader(body))
118118+ req.AddCookie(&http.Cookie{Name: "verified", Value: "yes"})
119119+ w := httptest.NewRecorder()
120120+ api.ServeHTTP(w, req)
121121+ if w.Code != http.StatusOK {
122122+ t.Fatalf("expected 200 with case-folded DID match, got %d body=%s", w.Code, w.Body.String())
123123+ }
124124+}
125125+126126+// TestEnrollStart_OAuthGate_NilVerifierIsLegacyOpen pins backward-
127127+// compatibility: deployments that haven't wired SetEnrollAuthVerifier
128128+// yet (pre-#207 binaries during rolling deploy) must accept requests
129129+// the same as before. The code path is exercised by every existing
130130+// enroll test, but having a dedicated assertion makes the contract
131131+// explicit so a future refactor can't quietly tighten it.
132132+func TestEnrollStart_OAuthGate_NilVerifierIsLegacyOpen(t *testing.T) {
133133+ api, _, _ := testEnrollAPI(t)
134134+ // Verifier intentionally NOT set.
135135+136136+ body, _ := json.Marshal(EnrollStartRequest{
137137+ DID: "did:plc:aaaaaaaabbbbbbbbcccccccc", Domain: "ok.example", TermsAccepted: true,
138138+ })
139139+ req := httptest.NewRequest(http.MethodPost, "/admin/enroll-start", bytes.NewReader(body))
140140+ w := httptest.NewRecorder()
141141+ api.ServeHTTP(w, req)
142142+ if w.Code != http.StatusOK {
143143+ t.Fatalf("expected 200 with nil verifier (legacy mode), got %d body=%s", w.Code, w.Body.String())
144144+ }
145145+}
146146+147147+// TestEnroll_OAuthGate_BlocksDIDSwapAtCompletion exercises the second-
148148+// layer check: even if a pending row was created with a verified DID,
149149+// the /admin/enroll completion step must independently re-verify so a
150150+// stolen token can't be redeemed from a session that does not own the
151151+// pending row's DID.
152152+func TestEnroll_OAuthGate_BlocksDIDSwapAtCompletion(t *testing.T) {
153153+ api, store, lk := testEnrollAPI(t)
154154+ pendingDID := "did:plc:aaaaaaaabbbbbbbbcccccccc"
155155+ domain := "ok.example"
156156+157157+ // Step 1: legitimate user starts enrollment with their verified DID.
158158+ api.SetEnrollAuthVerifier(&fakeAuthVerifier{verifiedDID: pendingDID})
159159+ startBody, _ := json.Marshal(EnrollStartRequest{DID: pendingDID, Domain: domain, TermsAccepted: true})
160160+ startReq := httptest.NewRequest(http.MethodPost, "/admin/enroll-start", bytes.NewReader(startBody))
161161+ startReq.AddCookie(&http.Cookie{Name: "verified", Value: "yes"})
162162+ startW := httptest.NewRecorder()
163163+ api.ServeHTTP(startW, startReq)
164164+ if startW.Code != http.StatusOK {
165165+ t.Fatalf("enroll-start: %d %s", startW.Code, startW.Body.String())
166166+ }
167167+ var sr EnrollStartResponse
168168+ _ = json.Unmarshal(startW.Body.Bytes(), &sr)
169169+ lk.records["_atmos-enroll."+domain] = []string{"atmos-verify=" + sr.Token}
170170+171171+ // Step 2: attacker has the token (e.g. captured from DNS) but their
172172+ // session is verified as a DIFFERENT DID. The completion must refuse.
173173+ api.SetEnrollAuthVerifier(&fakeAuthVerifier{verifiedDID: "did:plc:zzzzzzzzyyyyyyyyxxxxxxxx"})
174174+ completeBody, _ := json.Marshal(EnrollRequest{Token: sr.Token})
175175+ completeReq := httptest.NewRequest(http.MethodPost, "/admin/enroll", bytes.NewReader(completeBody))
176176+ completeReq.AddCookie(&http.Cookie{Name: "verified", Value: "yes"})
177177+ completeW := httptest.NewRecorder()
178178+ api.ServeHTTP(completeW, completeReq)
179179+ if completeW.Code != http.StatusForbidden {
180180+ t.Fatalf("expected 403 on session-DID swap at completion, got %d body=%s",
181181+ completeW.Code, completeW.Body.String())
182182+ }
183183+ // Member must not exist.
184184+ if got, _ := store.GetMember(completeReq.Context(), pendingDID); got != nil {
185185+ t.Error("attacker session created a member despite DID-swap rejection")
186186+ }
187187+}
+330
internal/admin/enroll_phases.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package admin
44+55+import (
66+ "context"
77+ "crypto/x509"
88+ "encoding/json"
99+ "errors"
1010+ "fmt"
1111+ "io"
1212+ "log"
1313+ "net/http"
1414+ "strings"
1515+ "time"
1616+1717+ "atmosphere-mail/internal/enroll"
1818+ "atmosphere-mail/internal/notify"
1919+ "atmosphere-mail/internal/relay"
2020+ "atmosphere-mail/internal/relaystore"
2121+)
2222+2323+// enrollHTTPError is the failure value returned by the enrollment phase
2424+// helpers below. handleEnroll renders it via http.Error and otherwise
2525+// proceeds to the next phase.
2626+//
2727+// Splitting the handler into discrete phases (validate → load+verify →
2828+// authorize → provision → persist → dispatch → respond) makes each step
2929+// individually unit-testable and keeps handleEnroll itself a short
3030+// orchestration function. See #223.
3131+type enrollHTTPError struct {
3232+ Status int
3333+ Message string
3434+}
3535+3636+func (e *enrollHTTPError) Error() string { return e.Message }
3737+3838+func enrollErrf(status int, format string, args ...any) *enrollHTTPError {
3939+ return &enrollHTTPError{Status: status, Message: fmt.Sprintf(format, args...)}
4040+}
4141+4242+// --- Phase 1: validate ------------------------------------------------------
4343+4444+// validateEnrollRequest reads the JSON body and query params from the
4545+// public POST /admin/enroll request. Returns the parsed token and the
4646+// optional forward_to address, or an HTTP error to return to the caller.
4747+//
4848+// Body size is capped at 4 KiB; tokens are bounded by the pending row
4949+// they look up so a giant body would be a hostile no-op.
5050+func validateEnrollRequest(r *http.Request) (token, forwardTo string, herr *enrollHTTPError) {
5151+ forwardTo = r.URL.Query().Get("forward_to")
5252+ if forwardTo != "" && !strings.Contains(forwardTo, "@") {
5353+ return "", "", enrollErrf(http.StatusBadRequest, "forward_to must be a valid email address")
5454+ }
5555+5656+ body, err := io.ReadAll(io.LimitReader(r.Body, 4096))
5757+ if err != nil {
5858+ return "", "", enrollErrf(http.StatusBadRequest, "error reading request body")
5959+ }
6060+ if len(body) == 0 {
6161+ return "", "", enrollErrf(http.StatusBadRequest, "enrollment token required: POST JSON body with {\"token\": \"...\"}")
6262+ }
6363+ var req EnrollRequest
6464+ if err := json.Unmarshal(body, &req); err != nil {
6565+ return "", "", enrollErrf(http.StatusBadRequest, "invalid JSON body")
6666+ }
6767+ if req.Token == "" {
6868+ return "", "", enrollErrf(http.StatusBadRequest, "token field required")
6969+ }
7070+ return req.Token, forwardTo, nil
7171+}
7272+7373+// --- Phase 2: load + verify -------------------------------------------------
7474+7575+// loadAndVerifyPending fetches the pending enrollment by token, runs the
7676+// OAuth-cookie identity gate (#207), enforces the expiry cutoff, and
7777+// re-runs DNS TXT verification. Returns the pending row on success or an
7878+// HTTP error otherwise.
7979+//
8080+// Side effect: deletes the pending row on expiry so the same expired
8181+// token can't be retried.
8282+func (a *API) loadAndVerifyPending(ctx context.Context, r *http.Request, token string) (*relaystore.PendingEnrollment, *enrollHTTPError) {
8383+ pending, err := a.store.GetPendingEnrollment(ctx, token)
8484+ if err != nil {
8585+ log.Printf("admin.enroll: token_lookup_error=%v", err)
8686+ return nil, enrollErrf(http.StatusInternalServerError, "internal error")
8787+ }
8888+ if pending == nil {
8989+ // Don't distinguish "never existed" from "already consumed" to
9090+ // avoid leaking enrollment state to callers.
9191+ return nil, enrollErrf(http.StatusNotFound, "token not found or already used")
9292+ }
9393+9494+ // OAuth-verified DID gate, second layer (#207). The pending row was
9595+ // created by handleEnrollStart, which already enforces the same
9696+ // check, but a stale pending row from before the verifier was wired
9797+ // or a path that bypasses /admin/enroll-start altogether (e.g. an
9898+ // admin-driven test fixture replay) must not let a verified DID be
9999+ // swapped at completion time. Re-check here against pending.DID so
100100+ // the *member-creation* moment is also gated.
101101+ if a.enrollAuthVerifier != nil {
102102+ verifiedDID, ok := a.enrollAuthVerifier.VerifyAuthCookie(r)
103103+ if !ok {
104104+ log.Printf("admin.enroll.no_oauth: pending_did=%s", pending.DID)
105105+ return nil, enrollErrf(http.StatusForbidden, "identity verification required — sign in with your handle before completing enrollment")
106106+ }
107107+ if !strings.EqualFold(verifiedDID, pending.DID) {
108108+ log.Printf("admin.enroll.did_mismatch: pending=%s verified=%s", pending.DID, verifiedDID)
109109+ return nil, enrollErrf(http.StatusForbidden, "verified identity does not match the pending enrollment")
110110+ }
111111+ }
112112+113113+ if time.Now().UTC().After(pending.ExpiresAt) {
114114+ // 410 Gone signals "the thing you're pointing at existed but is no
115115+ // longer retrievable" — precisely the pending-expired semantic.
116116+ // Clean the row so the same token can't be retried.
117117+ _ = a.store.DeletePendingEnrollment(ctx, token)
118118+ return nil, enrollErrf(http.StatusGone, "enrollment token expired — start over")
119119+ }
120120+121121+ if err := a.domainVerifier.Verify(ctx, pending.Domain, token); err != nil {
122122+ log.Printf("admin.enroll: did=%s domain=%s dns_verify_error=%v", pending.DID, pending.Domain, err)
123123+ switch {
124124+ case errors.Is(err, enroll.ErrNoTXTRecord):
125125+ return nil, enrollErrf(http.StatusForbidden, "no atmos-verify TXT record found at _atmos-enroll.%s — publish the record and retry", pending.Domain)
126126+ case errors.Is(err, enroll.ErrTokenMismatch):
127127+ return nil, enrollErrf(http.StatusForbidden, "TXT record does not contain the expected token — double-check the value")
128128+ default:
129129+ return nil, enrollErrf(http.StatusServiceUnavailable, "DNS lookup failed: %v — retry in a moment", err)
130130+ }
131131+ }
132132+133133+ log.Printf("admin.enroll: did=%s domain=%s dns_verified=true", pending.DID, pending.Domain)
134134+135135+ // Consume the pending row now that verification succeeded. Don't
136136+ // fail the enrollment if cleanup errors — CleanExpired will sweep
137137+ // it later and the unique-domain constraint prevents reuse.
138138+ if err := a.store.DeletePendingEnrollment(ctx, token); err != nil {
139139+ log.Printf("admin.enroll: did=%s domain=%s pending_cleanup_error=%v", pending.DID, pending.Domain, err)
140140+ }
141141+ return pending, nil
142142+}
143143+144144+// --- Phase 3: authorize -----------------------------------------------------
145145+146146+// checkDomainAvailable confirms the domain is unclaimed and the DID hasn't
147147+// already maxed out its per-account domain quota. Both checks run on every
148148+// enroll completion because handleEnrollStart's check is racy: a second
149149+// enrollment could complete between start and verify if the DID raced to
150150+// acquire domains via another browser tab or API caller.
151151+func (a *API) checkDomainAvailable(ctx context.Context, did, domain string) *enrollHTTPError {
152152+ existing, err := a.store.GetMemberDomain(ctx, domain)
153153+ if err != nil {
154154+ log.Printf("admin.enroll: did=%s error=%v", did, err)
155155+ return enrollErrf(http.StatusInternalServerError, "internal error")
156156+ }
157157+ if existing != nil {
158158+ if existing.DID == did {
159159+ return enrollErrf(http.StatusConflict, "You've already enrolled this domain. Sign in at /account to manage it.")
160160+ }
161161+ return enrollErrf(http.StatusConflict, "This domain is registered to another account.")
162162+ }
163163+164164+ owned, err := a.store.ListMemberDomains(ctx, did)
165165+ if err != nil {
166166+ log.Printf("admin.enroll: did=%s list_domains_error=%v", did, err)
167167+ return enrollErrf(http.StatusInternalServerError, "internal error")
168168+ }
169169+ if len(owned) >= maxDomainsPerMember {
170170+ return enrollErrf(http.StatusConflict, "domain limit reached — your account currently supports up to %d sending domains", maxDomainsPerMember)
171171+ }
172172+ return nil
173173+}
174174+175175+// --- Phase 4: provision -----------------------------------------------------
176176+177177+// enrollProvisionResult bundles the records and key material the persist
178178+// + respond phases need. IsNewDID is true when GetMember returned nil,
179179+// signalling that the persist step should also insert a member row and
180180+// dispatch should fire the operator-ping.
181181+type enrollProvisionResult struct {
182182+ Member *relaystore.Member // nil when adding a domain to an existing DID
183183+ Domain *relaystore.MemberDomain
184184+ APIKey string
185185+ APIKeyHash []byte
186186+ DKIMKeys *relay.DKIMKeys
187187+ DKIMSelector string
188188+ IsNewDID bool
189189+}
190190+191191+// provisionMemberAndDomain generates the API key, DKIM keypair, and
192192+// builds the member + domain records for atomic insert. Pure aside from
193193+// the DID lookup against the store and the random-key generation; the
194194+// returned result is what the persist + dispatch + respond phases need.
195195+func (a *API) provisionMemberAndDomain(ctx context.Context, pending *relaystore.PendingEnrollment, forwardTo string) (*enrollProvisionResult, *enrollHTTPError) {
196196+ existing, err := a.store.GetMember(ctx, pending.DID)
197197+ if err != nil {
198198+ log.Printf("admin.enroll: did=%s error=%v", pending.DID, err)
199199+ return nil, enrollErrf(http.StatusInternalServerError, "internal error")
200200+ }
201201+202202+ apiKey, err := relay.GenerateAPIKey()
203203+ if err != nil {
204204+ log.Printf("admin.enroll: did=%s error=generate_api_key %v", pending.DID, err)
205205+ return nil, enrollErrf(http.StatusInternalServerError, "internal error")
206206+ }
207207+ apiKeyHash, err := relay.HashAPIKey(apiKey)
208208+ if err != nil {
209209+ log.Printf("admin.enroll: did=%s error=hash_api_key %v", pending.DID, err)
210210+ return nil, enrollErrf(http.StatusInternalServerError, "internal error")
211211+ }
212212+213213+ selector := fmt.Sprintf("atmos%s", time.Now().UTC().Format("20060102"))
214214+ dkimKeys, err := relay.GenerateDKIMKeys(selector)
215215+ if err != nil {
216216+ log.Printf("admin.enroll: did=%s error=generate_dkim %v", pending.DID, err)
217217+ return nil, enrollErrf(http.StatusInternalServerError, "internal error")
218218+ }
219219+220220+ rsaBytes, err := x509.MarshalPKCS8PrivateKey(dkimKeys.RSAPriv)
221221+ if err != nil {
222222+ log.Printf("admin.enroll: did=%s error=marshal_rsa %v", pending.DID, err)
223223+ return nil, enrollErrf(http.StatusInternalServerError, "internal error")
224224+ }
225225+ edBytes, err := x509.MarshalPKCS8PrivateKey(dkimKeys.EdPriv)
226226+ if err != nil {
227227+ log.Printf("admin.enroll: did=%s error=marshal_ed %v", pending.DID, err)
228228+ return nil, enrollErrf(http.StatusInternalServerError, "internal error")
229229+ }
230230+231231+ now := time.Now().UTC()
232232+ var member *relaystore.Member
233233+ isNewDID := existing == nil
234234+ if isNewDID {
235235+ if !pending.TermsAccepted {
236236+ return nil, enrollErrf(http.StatusBadRequest, "terms acceptance required")
237237+ }
238238+ member = &relaystore.Member{
239239+ DID: pending.DID,
240240+ Status: relaystore.StatusPending,
241241+ DIDVerified: false,
242242+ TermsAcceptedAt: now,
243243+ TermsVersion: relaystore.CurrentTermsVersion,
244244+ HourlyLimit: 100,
245245+ DailyLimit: 1000,
246246+ CreatedAt: now,
247247+ UpdatedAt: now,
248248+ }
249249+ }
250250+251251+ domainRecord := &relaystore.MemberDomain{
252252+ Domain: pending.Domain,
253253+ DID: pending.DID,
254254+ APIKeyHash: apiKeyHash,
255255+ DKIMRSAPriv: rsaBytes,
256256+ DKIMEdPriv: edBytes,
257257+ DKIMSelector: selector,
258258+ ForwardTo: forwardTo,
259259+ ContactEmail: pending.ContactEmail,
260260+ CreatedAt: now,
261261+ }
262262+263263+ return &enrollProvisionResult{
264264+ Member: member,
265265+ Domain: domainRecord,
266266+ APIKey: apiKey,
267267+ APIKeyHash: apiKeyHash,
268268+ DKIMKeys: dkimKeys,
269269+ DKIMSelector: selector,
270270+ IsNewDID: isNewDID,
271271+ }, nil
272272+}
273273+274274+// --- Phase 6: dispatch ------------------------------------------------------
275275+276276+// dispatchEnrollNotifications fires the post-persist side effects:
277277+// operator-ping email (only for new DIDs to avoid notification fatigue),
278278+// webhook event, and contact-email verification. Errors are best-effort
279279+// because the enrollment itself has already succeeded.
280280+func (a *API) dispatchEnrollNotifications(pending *relaystore.PendingEnrollment, isNewDID bool) {
281281+ if isNewDID {
282282+ go a.FireOperatorPing(context.Background(), pending.DID, pending.Domain, pending.ContactEmail)
283283+ a.notifyEvent(notify.KindMemberPending, pending.DID, pending.Domain, "", pending.ContactEmail)
284284+ } else {
285285+ a.notifyEvent(notify.KindMemberDomainAdded, pending.DID, pending.Domain, "", pending.ContactEmail)
286286+ }
287287+288288+ if pending.ContactEmail != "" {
289289+ go a.TriggerEmailVerification(context.Background(), pending.Domain, pending.ContactEmail)
290290+ }
291291+}
292292+293293+// --- Phase 7: respond -------------------------------------------------------
294294+295295+// buildEnrollResponse assembles the JSON response body, including SPF
296296+// alignment when configured. Pure given inputs; ctx is only used for the
297297+// SPF lookup.
298298+func (a *API) buildEnrollResponse(ctx context.Context, p *enrollProvisionResult, domain string) EnrollResponse {
299299+ var spfResult *SPFAlignmentResponse
300300+ if a.spfChecker != nil {
301301+ result := a.spfChecker.CheckAlignment(ctx, domain)
302302+ spfResult = &SPFAlignmentResponse{
303303+ Aligned: result.Aligned,
304304+ Failures: result.Failures,
305305+ }
306306+ if !result.Aligned {
307307+ log.Printf("admin.enroll.spf_warning: did=%s domain=%s failures=%v", p.Domain.DID, domain, result.Failures)
308308+ }
309309+ }
310310+311311+ return EnrollResponse{
312312+ DID: p.Domain.DID,
313313+ APIKey: p.APIKey,
314314+ DKIM: DKIMResponse{
315315+ Selector: p.DKIMSelector,
316316+ RSASelector: p.DKIMKeys.RSASelectorName(),
317317+ EdSelector: p.DKIMKeys.EdSelectorName(),
318318+ RSARecord: p.DKIMKeys.RSADNSRecord(),
319319+ EdRecord: p.DKIMKeys.EdDNSRecord(),
320320+ RSADNSName: fmt.Sprintf("%s._domainkey.%s", p.DKIMKeys.RSASelectorName(), domain),
321321+ EdDNSName: fmt.Sprintf("%s._domainkey.%s", p.DKIMKeys.EdSelectorName(), domain),
322322+ },
323323+ SMTP: SMTPResponse{
324324+ Host: "smtp." + a.domain,
325325+ Port: 587,
326326+ },
327327+ SPFAlignment: spfResult,
328328+ apiKeyHash: p.APIKeyHash,
329329+ }
330330+}
+113
internal/admin/enroll_phases_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package admin
44+55+import (
66+ "net/http"
77+ "net/http/httptest"
88+ "strings"
99+ "testing"
1010+)
1111+1212+func TestValidateEnrollRequest_Valid(t *testing.T) {
1313+ r := httptest.NewRequest(http.MethodPost, "/admin/enroll", strings.NewReader(`{"token":"abc123"}`))
1414+ token, forwardTo, herr := validateEnrollRequest(r)
1515+ if herr != nil {
1616+ t.Fatalf("unexpected error: %v", herr)
1717+ }
1818+ if token != "abc123" {
1919+ t.Errorf("token=%q, want abc123", token)
2020+ }
2121+ if forwardTo != "" {
2222+ t.Errorf("forwardTo=%q, want empty", forwardTo)
2323+ }
2424+}
2525+2626+func TestValidateEnrollRequest_ForwardToValid(t *testing.T) {
2727+ r := httptest.NewRequest(http.MethodPost, "/admin/enroll?forward_to=ops@example.com",
2828+ strings.NewReader(`{"token":"t"}`))
2929+ token, forwardTo, herr := validateEnrollRequest(r)
3030+ if herr != nil {
3131+ t.Fatalf("err=%v", herr)
3232+ }
3333+ if token != "t" {
3434+ t.Errorf("token=%q", token)
3535+ }
3636+ if forwardTo != "ops@example.com" {
3737+ t.Errorf("forwardTo=%q, want ops@example.com", forwardTo)
3838+ }
3939+}
4040+4141+func TestValidateEnrollRequest_ForwardToInvalid(t *testing.T) {
4242+ r := httptest.NewRequest(http.MethodPost, "/admin/enroll?forward_to=not-an-email",
4343+ strings.NewReader(`{"token":"t"}`))
4444+ _, _, herr := validateEnrollRequest(r)
4545+ if herr == nil {
4646+ t.Fatal("expected forward_to validation error")
4747+ }
4848+ if herr.Status != http.StatusBadRequest {
4949+ t.Errorf("status=%d, want 400", herr.Status)
5050+ }
5151+}
5252+5353+func TestValidateEnrollRequest_EmptyBody(t *testing.T) {
5454+ r := httptest.NewRequest(http.MethodPost, "/admin/enroll", strings.NewReader(""))
5555+ _, _, herr := validateEnrollRequest(r)
5656+ if herr == nil {
5757+ t.Fatal("expected empty-body error")
5858+ }
5959+ if herr.Status != http.StatusBadRequest {
6060+ t.Errorf("status=%d, want 400", herr.Status)
6161+ }
6262+ if !strings.Contains(herr.Message, "token") {
6363+ t.Errorf("message %q should mention token", herr.Message)
6464+ }
6565+}
6666+6767+func TestValidateEnrollRequest_InvalidJSON(t *testing.T) {
6868+ r := httptest.NewRequest(http.MethodPost, "/admin/enroll", strings.NewReader(`{not json`))
6969+ _, _, herr := validateEnrollRequest(r)
7070+ if herr == nil {
7171+ t.Fatal("expected JSON error")
7272+ }
7373+ if herr.Status != http.StatusBadRequest {
7474+ t.Errorf("status=%d, want 400", herr.Status)
7575+ }
7676+}
7777+7878+func TestValidateEnrollRequest_TokenMissing(t *testing.T) {
7979+ r := httptest.NewRequest(http.MethodPost, "/admin/enroll", strings.NewReader(`{"token":""}`))
8080+ _, _, herr := validateEnrollRequest(r)
8181+ if herr == nil {
8282+ t.Fatal("expected token-required error")
8383+ }
8484+ if herr.Status != http.StatusBadRequest {
8585+ t.Errorf("status=%d", herr.Status)
8686+ }
8787+}
8888+8989+func TestValidateEnrollRequest_BodyOver4KiB(t *testing.T) {
9090+ // 5 KiB of payload — io.LimitReader truncates so the JSON unmarshals
9191+ // to whatever fits, which here will be invalid JSON. We just want to
9292+ // confirm we don't OOM or read unbounded input.
9393+ huge := `{"token":"` + strings.Repeat("a", 5000) + `"}`
9494+ r := httptest.NewRequest(http.MethodPost, "/admin/enroll", strings.NewReader(huge))
9595+ _, _, herr := validateEnrollRequest(r)
9696+ if herr == nil {
9797+ t.Fatal("expected bounded body to surface a JSON error after truncation")
9898+ }
9999+}
100100+101101+// TestEnrollHTTPError_ImplementsError pins that *enrollHTTPError satisfies
102102+// the error interface so it can be wrapped/unwrapped if a future caller
103103+// needs to do so.
104104+func TestEnrollHTTPError_ImplementsError(t *testing.T) {
105105+ var _ error = (*enrollHTTPError)(nil)
106106+ herr := enrollErrf(http.StatusBadRequest, "bad %s", "input")
107107+ if herr.Error() != "bad input" {
108108+ t.Errorf("Error()=%q, want %q", herr.Error(), "bad input")
109109+ }
110110+ if herr.Status != http.StatusBadRequest {
111111+ t.Errorf("Status=%d", herr.Status)
112112+ }
113113+}
+9-2
internal/admin/ui/attest.go
···4343// when the shared OAuth callback sees a session with no Attestation
4444// payload — signalling the flow was initiated for credential recovery
4545// rather than enrollment.
4646+//
4747+// UA-binding is required: the only entry point on this interface
4848+// takes the User-Agent of the browser that completed OAuth, so a
4949+// leaked cookie cannot be replayed from a different browser. The
5050+// legacy no-UA helper (IssueRecoveryTicket on *RecoverHandler) is
5151+// retained for tests but deliberately NOT exposed here so production
5252+// callers can't accidentally bypass the binding (#212).
4653type RecoveryIssuer interface {
4747- IssueRecoveryTicket(did, domain string) string
5454+ IssueRecoveryTicketWithUA(did, domain, ua string) string
4855}
49565057// DIDHandleResolver resolves a DID to its atproto handle. Used by the
···233240 if h.funnel != nil {
234241 h.funnel.RecordOAuthCallback("recovery", h.resolveHandle(ctx, sess.AccountDID()))
235242 }
236236- target := h.recoveryIssuer.IssueRecoveryTicket(sess.AccountDID(), sess.Domain())
243243+ target := h.recoveryIssuer.IssueRecoveryTicketWithUA(sess.AccountDID(), sess.Domain(), r.UserAgent())
237244 log.Printf("attest.callback: did=%s domain=%s handoff=recovery target=%s",
238245 sess.AccountDID(), sess.Domain(), target)
239246 http.Redirect(w, r, target, http.StatusFound)
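The RecoveryIssuer comment above describes UA-binding, but the ticket store itself is not shown in this diff. Here is a self-contained sketch of what a UA-bound, single-use recovery ticket store can look like; every name is illustrative, not the RecoverHandler's actual internals.

```go
package ui

import (
	"crypto/rand"
	"encoding/hex"
	"sync"
	"time"
)

type recoveryTicket struct {
	DID, Domain, UA string
	Expires         time.Time
}

type ticketStore struct {
	mu      sync.Mutex
	tickets map[string]recoveryTicket
}

func newTicketStore() *ticketStore {
	return &ticketStore{tickets: make(map[string]recoveryTicket)}
}

// Issue binds the ticket to the browser's User-Agent at issuance time.
func (s *ticketStore) Issue(did, domain, ua string) string {
	buf := make([]byte, 16)
	_, _ = rand.Read(buf)
	id := hex.EncodeToString(buf)
	s.mu.Lock()
	s.tickets[id] = recoveryTicket{DID: did, Domain: domain, UA: ua, Expires: time.Now().Add(10 * time.Minute)}
	s.mu.Unlock()
	return id
}

// Redeem refuses a ticket presented from a different browser, so a leaked
// cookie value alone is not enough to take over a recovery flow.
func (s *ticketStore) Redeem(id, ua string) (recoveryTicket, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	t, ok := s.tickets[id]
	if !ok || ua != t.UA || time.Now().After(t.Expires) {
		return recoveryTicket{}, false
	}
	delete(s.tickets, id)
	return t, true
}
```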
+80-13
internal/admin/ui/enroll.go
···1313 "io"
1414 "log"
1515 "net/http"
1616- "net/http/httptest"
1716 "strings"
1817 "sync"
1918 "time"
···122121 h.mux.HandleFunc("/privacy", h.handlePrivacy)
123122 h.mux.HandleFunc("/aup", h.handleAUP)
124123 h.mux.HandleFunc("/about", h.handleAbout)
124124+ h.mux.HandleFunc("/faq", h.handleFAQ)
125125 return h
126126}
127127···261261 })
262262}
263263264264+func (h *EnrollHandler) handleFAQ(w http.ResponseWriter, r *http.Request) {
265265+ h.staticPage(w, r, func(w http.ResponseWriter, r *http.Request) {
266266+ _ = templates.FAQPage().Render(r.Context(), w)
267267+ })
268268+}
269269+264270// handleResolve takes a handle and returns the resolved DID as JSON.
265271// Used by the landing-page JS to turn `scottlanoue.com` into
266272// `did:plc:…` before the user submits the form.
···445451 "contactEmail": contactEmail,
446452 "termsAccepted": termsAccepted,
447453 })
448448- resp := h.proxyAdminInner(http.MethodPost, "/admin/enroll-start", bytes.NewReader(body))
454454+ resp := h.proxyAdminInner(http.MethodPost, "/admin/enroll-start", bytes.NewReader(body), r)
449455 if resp.Code != http.StatusOK {
450456 msg := strings.TrimSpace(resp.Body.String())
451457 if msg == "" {
···500506 }
501507502508 body, _ := json.Marshal(map[string]string{"token": token})
503503- resp := h.proxyAdminInner(http.MethodPost, "/admin/enroll", bytes.NewReader(body))
509509+ resp := h.proxyAdminInner(http.MethodPost, "/admin/enroll", bytes.NewReader(body), r)
504510 if resp.Code != http.StatusOK {
505511 msg := strings.TrimSpace(resp.Body.String())
506512 if msg == "" {
···752758 _ = templates.EnrollError(message).Render(r.Context(), w)
753759}
754760755755-// proxyAdminInner invokes the admin API in-process via httptest. We never
756756-// forward the caller's Authorization header — the admin API's enrollment
757757-// endpoints do their own verification (DNS TXT ownership), so forwarding
758758-// caller auth is unnecessary and would risk leaking admin credentials
759759-// from other contexts into the public path.
760760-func (h *EnrollHandler) proxyAdminInner(method, target string, body io.Reader) *httptest.ResponseRecorder {
761761- req := httptest.NewRequest(method, target, body)
761761+// proxyAdminInner invokes the admin API in-process. We never forward the
762762+// caller's Authorization header — the admin API's enrollment endpoints do
763763+// their own verification (DNS TXT ownership), so forwarding caller auth
764764+// is unnecessary and would risk leaking admin credentials from other
765765+// contexts into the public path.
766766+//
767767+// Cookie + User-Agent are forwarded so the inner admin API can look up
768768+// the enroll-auth ticket the public UI set after a successful AT Proto
769769+// OAuth round-trip — the central defense for #207.
770770+//
771771+// RemoteAddr is also forwarded so the admin API's per-IP enroll-start
772772+// rate limiter sees the real public client IP. Without this, every
773773+// public enrollment request would share a single rate-limit bucket and
774774+// a single attacker could exhaust it for all legitimate users from any
775775+// IP — closes #211.
776776+//
777777+// This used to construct an httptest.NewRequest + httptest.ResponseRecorder
778778+// in the production call chain (#222). The dependency on net/http/httptest
779779+// from non-test code masked the rate-limiter bypass that became #211 and
780780+// made the call site inscrutable to readers expecting test-only types not
781781+// to leak. We now use http.NewRequestWithContext + an in-package response
782782+// writer (inMemoryResponseWriter) so the type signatures match the rest
783783+// of the production stack.
784784+func (h *EnrollHandler) proxyAdminInner(method, target string, body io.Reader, src *http.Request) *adminProxyResponse {
785785+ ctx := context.Background()
786786+ if src != nil {
787787+ ctx = src.Context()
788788+ }
789789+ req, err := http.NewRequestWithContext(ctx, method, target, body)
790790+ if err != nil {
791791+ // method/target are fixed strings from callers in this package; an
792792+ // error here is a programming bug, not a runtime condition. Surface it
793793+ // as a 500 so the wrapping handler renders an inline error rather
794794+ // than panicking.
795795+ return &adminProxyResponse{
796796+ Code: http.StatusInternalServerError,
797797+ Body: bytes.NewBufferString("internal error: build admin request"),
798798+ header: http.Header{},
799799+ }
800800+ }
762801 req.Header.Set("Content-Type", "application/json")
763763- rr := httptest.NewRecorder()
764764- h.adminAPI.ServeHTTP(rr, req)
765765- return rr
802802+ if src != nil {
803803+ if cookie := src.Header.Get("Cookie"); cookie != "" {
804804+ req.Header.Set("Cookie", cookie)
805805+ }
806806+ if ua := src.UserAgent(); ua != "" {
807807+ req.Header.Set("User-Agent", ua)
808808+ }
809809+ if src.RemoteAddr != "" {
810810+ req.RemoteAddr = src.RemoteAddr
811811+ }
812812+ }
813813+ rw := newInMemoryResponseWriter()
814814+ h.adminAPI.ServeHTTP(rw, req)
815815+ return rw.snapshot()
816816+}
817817+818818+// VerifyAuthCookie implements admin.EnrollAuthVerifier. Returns the DID
819819+// proven by the most recent successful AT Proto OAuth round-trip if the
820820+// caller presents a valid enroll-auth ticket cookie, or "" / false
821821+// otherwise. The cookie's UA-binding is also enforced so a stolen
822822+// cookie can't be replayed from a different browser.
823823+func (h *EnrollHandler) VerifyAuthCookie(r *http.Request) (string, bool) {
824824+ id, ok := enrollAuthTicketFromCookie(r)
825825+ if !ok {
826826+ return "", false
827827+ }
828828+ ticket, ok := h.lookupEnrollAuthTicket(id, r.UserAgent())
829829+ if !ok {
830830+ return "", false
831831+ }
832832+ return ticket.did, true
766833}
+96-2
internal/admin/ui/enroll_test.go
···2626 lastAuth string
2727 lastPath string
2828 lastBody string
2929+ lastRemoteAddr string
3030+ lastCookie string
2931 gotEnrollStart bool
3032 gotEnroll bool
3133}
···3335func (f *fakeAdminAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) {
3436 f.lastAuth = r.Header.Get("Authorization")
3537 f.lastPath = r.URL.Path + "?" + r.URL.RawQuery
3838+ f.lastRemoteAddr = r.RemoteAddr
3939+ f.lastCookie = r.Header.Get("Cookie")
3640 if r.Body != nil {
3741 b, _ := io.ReadAll(r.Body)
3842 f.lastBody = string(b)
···561565562566func TestStaticPage_HEADReturns200(t *testing.T) {
563567 h := NewEnrollHandler(&fakeAdminAPI{}, nil)
564564- for _, p := range []string{"/", "/terms", "/privacy", "/aup", "/about"} {
568568+ for _, p := range []string{"/", "/terms", "/privacy", "/aup", "/about", "/faq"} {
565569 req := httptest.NewRequest(http.MethodHead, p, nil)
566570 w := httptest.NewRecorder()
567571 h.ServeHTTP(w, req)
···675679 }
676680}
677681682682+func TestFAQPage_ServesHTML(t *testing.T) {
683683+ h := NewEnrollHandler(&fakeAdminAPI{}, nil)
684684+ req := httptest.NewRequest(http.MethodGet, "/faq", nil)
685685+ w := httptest.NewRecorder()
686686+ h.ServeHTTP(w, req)
687687+688688+ if w.Code != http.StatusOK {
689689+ t.Fatalf("status = %d, want 200", w.Code)
690690+ }
691691+ body := w.Body.String()
692692+ if !strings.Contains(body, "FAQ") {
693693+ t.Error("faq page should contain 'FAQ'")
694694+ }
695695+ if !strings.Contains(body, "Atmosphere Mail LLC") {
696696+ t.Error("faq page must identify the legal entity")
697697+ }
698698+ // The FAQ must answer the three questions prospective members ask most.
699699+ for _, required := range []string{"free", "trust", "commercial relay"} {
700700+ if !strings.Contains(strings.ToLower(body), required) {
701701+ t.Errorf("faq page must mention %q", required)
702702+ }
703703+ }
704704+}
705705+678706// TestDropCapOnlyOnLanding pins the Round 2 design decision that the
679707// drop-cap brand mark is a landing-page-only element. Putting it on every
680708// page dilutes the signature; this test guards against regressing.
···691719692720 // Legal + about pages must not carry a drop-cap — they are reference
693721 // documents, not the brand moment.
694694- for _, p := range []string{"/terms", "/privacy", "/aup", "/about"} {
722722+ for _, p := range []string{"/terms", "/privacy", "/aup", "/about", "/faq"} {
695723 req := httptest.NewRequest(http.MethodGet, p, nil)
696724 w := httptest.NewRecorder()
697725 h.ServeHTTP(w, req)
···954982 t.Errorf("body should report false flags, got %q", w.Body.String())
955983 }
956984}
985985+986986+// TestEnrollStart_ForwardsRemoteAddr proves the public /enroll/start
987987+// path delivers the caller's real RemoteAddr to the inner admin API
988988+// so the per-IP enroll-start rate limiter sees distinct buckets per
989989+// source. Without this, every public enrollment request shares a
990990+// single bucket (the httptest synthetic default), letting one
991991+// attacker exhaust the limit for everyone — closes #211.
992992+func TestEnrollStart_ForwardsRemoteAddr(t *testing.T) {
993993+ fake := &fakeAdminAPI{
994994+ enrollStartStatus: http.StatusOK,
995995+ enrollStartBody: `{"token":"tok","dnsName":"_atmos-enroll.x.example","dnsValue":"atmos-verify=tok","expiresAt":"2026-04-17T12:00:00Z"}`,
996996+ }
997997+ h := NewEnrollHandler(fake, nil)
998998+999999+ form := strings.NewReader("did=did:plc:testtesttesttesttest&domain=x.example&terms_accepted=on")
10001000+ req := httptest.NewRequest(http.MethodPost, "/enroll/start", form)
10011001+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
10021002+ // Realistic public-side RemoteAddr — what would arrive after
10031003+ // Tailscale Serve termination on the public listener.
10041004+ req.RemoteAddr = "203.0.113.42:54321"
10051005+ w := httptest.NewRecorder()
10061006+ h.ServeHTTP(w, req)
10071007+10081008+ if w.Code != http.StatusOK {
10091009+ t.Fatalf("status = %d body = %q", w.Code, w.Body.String())
10101010+ }
10111011+ if fake.lastRemoteAddr != "203.0.113.42:54321" {
10121012+ t.Errorf("admin API saw RemoteAddr=%q, want %q (proxyAdminInner dropped real client IP — rate limiter would treat all callers as one)",
10131013+ fake.lastRemoteAddr, "203.0.113.42:54321")
10141014+ }
10151015+}
10161016+10171017+// TestEnrollStart_DistinctIPsRouteToDistinctBuckets demonstrates the
10181018+// per-IP isolation property end-to-end: two requests from different
10191019+// public IPs both reach the admin API with their original RemoteAddr
10201020+// preserved, so an in-process IP-keyed limiter can distinguish them.
10211021+func TestEnrollStart_DistinctIPsRouteToDistinctBuckets(t *testing.T) {
10221022+ fake := &fakeAdminAPI{
10231023+ enrollStartStatus: http.StatusOK,
10241024+ enrollStartBody: `{"token":"tok","dnsName":"_atmos-enroll.x.example","dnsValue":"atmos-verify=tok","expiresAt":"2026-04-17T12:00:00Z"}`,
10251025+ }
10261026+ h := NewEnrollHandler(fake, nil)
10271027+10281028+ send := func(remote string) string {
10291029+ t.Helper()
10301030+ form := strings.NewReader("did=did:plc:testtesttesttesttest&domain=x.example&terms_accepted=on")
10311031+ req := httptest.NewRequest(http.MethodPost, "/enroll/start", form)
10321032+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
10331033+ req.RemoteAddr = remote
10341034+ w := httptest.NewRecorder()
10351035+ h.ServeHTTP(w, req)
10361036+ if w.Code != http.StatusOK {
10371037+ t.Fatalf("status = %d body = %q", w.Code, w.Body.String())
10381038+ }
10391039+ return fake.lastRemoteAddr
10401040+ }
10411041+10421042+ a := send("198.51.100.1:1111")
10431043+ b := send("198.51.100.2:2222")
10441044+ if a == b {
10451045+ t.Errorf("both requests delivered the same RemoteAddr=%q — limiter cannot distinguish them", a)
10461046+ }
10471047+ if a != "198.51.100.1:1111" || b != "198.51.100.2:2222" {
10481048+ t.Errorf("RemoteAddr forwarding garbled: a=%q b=%q", a, b)
10491049+ }
10501050+}
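The forwarding half that these tests exercise is not shown in this hunk. As a minimal sketch of the property they pin, assuming a hypothetical request-builder helper rather than the real `proxyAdminInner` signature, the in-process admin request has to inherit the public caller's `RemoteAddr`:

```go
package ui

import (
	"io"
	"net/http"
)

// newInnerAdminRequest is a hypothetical sketch, not the real
// proxyAdminInner. The line that matters for the tests above is the
// RemoteAddr copy: without it, every public caller lands in the
// limiter's single synthetic-address bucket.
func newInnerAdminRequest(outer *http.Request, path string, body io.Reader) (*http.Request, error) {
	inner, err := http.NewRequestWithContext(outer.Context(), outer.Method, path, body)
	if err != nil {
		return nil, err
	}
	inner.Header = outer.Header.Clone() // Authorization, Cookie, Content-Type, …
	inner.RemoteAddr = outer.RemoteAddr // preserve the real client address
	return inner, nil
}
```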
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package ui
44+55+import (
66+ "bytes"
77+ "net/http"
88+)
99+1010+// adminProxyResponse captures the response from invoking the admin API
1111+// in-process. Replaces *httptest.ResponseRecorder so test-only types stay
1212+// out of the production call chain (#222).
1313+//
1414+// Field names mirror the legacy ResponseRecorder API (`Code`, `Body`) so
1515+// callers that read `resp.Code` and `resp.Body.String()` keep working
1616+// without churn.
1717+type adminProxyResponse struct {
1818+ Code int
1919+ Body *bytes.Buffer
2020+ header http.Header
2121+}
2222+2323+// Header returns the response headers the inner admin handler set. Not
2424+// every caller needs them; exposed for parity with ResponseRecorder.
2525+func (r *adminProxyResponse) Header() http.Header { return r.header }
2626+2727+// inMemoryResponseWriter is a minimal real http.ResponseWriter the inner
2828+// admin handler writes to. Unlike httptest.ResponseRecorder this lives in
2929+// the regular package, so production code paths no longer depend on
3030+// net/http/httptest just to invoke an in-process handler.
3131+//
3232+// Behavior matches what stdlib serves: WriteHeader is sticky (first call
3333+// wins), Write to a 0-status writer implies 200 OK, and Header() returns
3434+// a mutable map up until WriteHeader fires.
3535+type inMemoryResponseWriter struct {
3636+ code int
3737+ body *bytes.Buffer
3838+ headers http.Header
3939+ written bool
4040+}
4141+4242+func newInMemoryResponseWriter() *inMemoryResponseWriter {
4343+ return &inMemoryResponseWriter{
4444+ body: &bytes.Buffer{},
4545+ headers: http.Header{},
4646+ }
4747+}
4848+4949+func (w *inMemoryResponseWriter) Header() http.Header { return w.headers }
5050+5151+func (w *inMemoryResponseWriter) WriteHeader(code int) {
5252+ if w.written {
5353+ return
5454+ }
5555+ w.code = code
5656+ w.written = true
5757+}
5858+5959+func (w *inMemoryResponseWriter) Write(p []byte) (int, error) {
6060+ if !w.written {
6161+ w.code = http.StatusOK
6262+ w.written = true
6363+ }
6464+ return w.body.Write(p)
6565+}
6666+6767+// snapshot freezes the writer state into an adminProxyResponse the
6868+// caller can read without further mutation.
6969+func (w *inMemoryResponseWriter) snapshot() *adminProxyResponse {
7070+ code := w.code
7171+ if code == 0 {
7272+ // Handler returned without writing anything — same convention as
7373+ // net/http: empty body, 200.
7474+ code = http.StatusOK
7575+ }
7676+ return &adminProxyResponse{Code: code, Body: w.body, header: w.headers}
7777+}
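Taken together, the two types above replace the old recorder-based call pattern. A usage sketch (the wrapper name is illustrative; the real proxy adds its own error handling and header filtering):

```go
package ui

import "net/http"

// callAdminInProcess shows the intended call shape: run the inner admin
// handler against the in-memory writer, then freeze the result. Callers
// keep reading resp.Code and resp.Body.String() exactly as they did when
// the proxy used httptest.ResponseRecorder.
func callAdminInProcess(admin http.Handler, req *http.Request) *adminProxyResponse {
	w := newInMemoryResponseWriter()
	admin.ServeHTTP(w, req)
	return w.snapshot()
}
```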
+75
internal/admin/ui/inproc_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package ui
44+55+import (
66+ "net/http"
77+ "reflect"
88+ "strings"
99+ "testing"
1010+)
1111+1212+func TestInMemoryResponseWriter_DefaultsTo200(t *testing.T) {
1313+ w := newInMemoryResponseWriter()
1414+ _, _ = w.Write([]byte("hello"))
1515+ snap := w.snapshot()
1616+ if snap.Code != http.StatusOK {
1717+ t.Errorf("Code=%d, want 200 (Write without WriteHeader implies 200)", snap.Code)
1818+ }
1919+ if got := snap.Body.String(); got != "hello" {
2020+ t.Errorf("Body=%q, want %q", got, "hello")
2121+ }
2222+}
2323+2424+func TestInMemoryResponseWriter_WriteHeaderIsSticky(t *testing.T) {
2525+ w := newInMemoryResponseWriter()
2626+ w.WriteHeader(http.StatusBadRequest)
2727+ w.WriteHeader(http.StatusInternalServerError) // ignored
2828+ snap := w.snapshot()
2929+ if snap.Code != http.StatusBadRequest {
3030+ t.Errorf("Code=%d, want 400 (first WriteHeader wins per net/http)", snap.Code)
3131+ }
3232+}
3333+3434+func TestInMemoryResponseWriter_EmptyHandlerReturns200(t *testing.T) {
3535+ w := newInMemoryResponseWriter()
3636+ snap := w.snapshot()
3737+ if snap.Code != http.StatusOK {
3838+ t.Errorf("empty handler Code=%d, want 200", snap.Code)
3939+ }
4040+ if snap.Body.Len() != 0 {
4141+ t.Errorf("empty handler body has %d bytes, want 0", snap.Body.Len())
4242+ }
4343+}
4444+4545+func TestInMemoryResponseWriter_HeaderRoundTrip(t *testing.T) {
4646+ w := newInMemoryResponseWriter()
4747+ w.Header().Set("X-Test", "value")
4848+ w.Header().Add("X-Multi", "a")
4949+ w.Header().Add("X-Multi", "b")
5050+ snap := w.snapshot()
5151+ if snap.Header().Get("X-Test") != "value" {
5252+ t.Errorf("X-Test header lost: %v", snap.Header())
5353+ }
5454+ if got := snap.Header().Values("X-Multi"); !reflect.DeepEqual(got, []string{"a", "b"}) {
5555+ t.Errorf("X-Multi=%v, want [a b]", got)
5656+ }
5757+}
5858+5959+// TestProductionCallChain_FreeOfHttptest pins the #222 fix at compile time:
6060+// adminProxyResponse must NOT be a *httptest.ResponseRecorder. If a future
6161+// change reintroduces test-only types in the production path this test
6262+// breaks before review.
6363+func TestProductionCallChain_FreeOfHttptest(t *testing.T) {
6464+ resp := &adminProxyResponse{}
6565+ typeName := reflect.TypeOf(resp).String()
6666+ if strings.Contains(typeName, "httptest") {
6767+ t.Errorf("adminProxyResponse type %q contains 'httptest' — production call chain regressed (#222)", typeName)
6868+ }
6969+ rw := newInMemoryResponseWriter()
7070+ var _ http.ResponseWriter = rw // compile-time interface conformance
7171+ rwType := reflect.TypeOf(rw).String()
7272+ if strings.Contains(rwType, "httptest") {
7373+ t.Errorf("inMemoryResponseWriter type %q contains 'httptest' — production call chain regressed (#222)", rwType)
7474+ }
7575+}
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package ui
44+55+import (
66+ "reflect"
77+ "testing"
88+)
99+1010+// TestRecoveryIssuer_RequiresUABinding pins the interface shape so a
1111+// future refactor can't accidentally re-introduce a no-UA entry point
1212+// on the contract that production callers use. The whole point of
1313+// #212 is that the OAuth callback MUST forward the User-Agent into
1414+// the ticket binding — an interface that exposes a no-UA method
1515+// would let that drift back over time.
1616+func TestRecoveryIssuer_RequiresUABinding(t *testing.T) {
1717+ typ := reflect.TypeOf((*RecoveryIssuer)(nil)).Elem()
1818+ if typ.NumMethod() != 1 {
1919+ t.Fatalf("RecoveryIssuer should have exactly 1 method, got %d", typ.NumMethod())
2020+ }
2121+ m := typ.Method(0)
2222+ if m.Name != "IssueRecoveryTicketWithUA" {
2323+ t.Errorf("RecoveryIssuer.method[0] = %q, want IssueRecoveryTicketWithUA — re-introducing a no-UA entry point regresses #212",
2424+ m.Name)
2525+ }
2626+	// The signature is (did, domain, ua string) string. An interface
2727+	// method's reflected func type carries no receiver, so NumIn is 3.
2828+ if m.Type.NumIn() != 3 {
2929+ t.Errorf("IssueRecoveryTicketWithUA should accept 3 args (did, domain, ua), got %d", m.Type.NumIn())
3030+ }
3131+ if m.Type.NumOut() != 1 {
3232+ t.Errorf("IssueRecoveryTicketWithUA should return 1 value (URL), got %d", m.Type.NumOut())
3333+ }
3434+}
3535+3636+// TestRecoverHandler_SatisfiesRecoveryIssuer is a compile-time guard:
3737+// if *RecoverHandler ever drops IssueRecoveryTicketWithUA, this test
3838+// won't compile. The whole production path (cmd/relay calls
3939+// AttestHandler.SetRecoveryIssuer with a *RecoverHandler) depends on
4040+// this assignment compiling.
4141+func TestRecoverHandler_SatisfiesRecoveryIssuer(t *testing.T) {
4242+ var _ RecoveryIssuer = (*RecoverHandler)(nil)
4343+ t.Log("*RecoverHandler satisfies RecoveryIssuer; compile-time check OK")
4444+}
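For readers without the rest of the package at hand, the interface these assertions imply looks roughly like the sketch below (inferred from the tests, not copied from the real declaration):

```go
package ui

// RecoveryIssuer as implied by the tests above: exactly one entry point,
// and it always carries the caller's User-Agent into the ticket binding.
type RecoveryIssuer interface {
	// IssueRecoveryTicketWithUA returns the recovery URL for a ticket
	// bound to (did, domain, ua).
	IssueRecoveryTicketWithUA(did, domain, ua string) string
}
```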
+127
internal/admin/ui/templates/deliverability.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package templates
44+55+// DeliverabilityPage is the member-facing view of their own sending
66+// reputation: bounces, complaints, warming tier, and daily volume trend.
77+88+import (
99+ "context"
1010+ "fmt"
1111+ "html"
1212+ "io"
1313+ "strings"
1414+1515+ "github.com/a-h/templ"
1616+)
1717+1818+// DeliverabilityData carries all metrics for the /account/deliverability page.
1919+type DeliverabilityData struct {
2020+ DID string
2121+ Domain string
2222+ Status string
2323+ SuspendReason string
2424+2525+ Sent14d int64
2626+ Bounced14d int64
2727+ Complaints14d int64
2828+ BounceRate float64 // 0.0–1.0
2929+3030+ DailySends []int64 // 14 buckets, oldest-to-newest
3131+3232+ HourlyLimit int
3333+ DailyLimit int
3434+3535+ WarmingTier string // "warming" | "ramping" | "warmed" | ""
3636+ WarmingLabel string // human-readable, e.g. "warming (3/7 days)"
3737+3838+ Labels []string // Osprey + labeler labels
3939+}
4040+4141+func DeliverabilityPage(d DeliverabilityData) templ.Component {
4242+ return templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
4343+ inner := templ.ComponentFunc(func(_ context.Context, w io.Writer) error {
4444+ var b strings.Builder
4545+4646+ b.WriteString(`<nav class="topnav" aria-label="breadcrumb"><a href="/account" class="topnav-home">← Account</a></nav>`)
4747+ b.WriteString(`<h1 class="masthead masthead-sub">Deliverability</h1>`)
4848+ fmt.Fprintf(&b, `<p class="lede">Sending reputation for <code>%s</code>.</p>`, html.EscapeString(d.Domain))
4949+5050+ // Status banner
5151+ if d.Status == "suspended" {
5252+ b.WriteString(`<div class="error-note" role="alert"><p style="margin: 0;"><strong>Account suspended.</strong>`)
5353+ if d.SuspendReason != "" {
5454+ fmt.Fprintf(&b, ` Reason: %s`, html.EscapeString(d.SuspendReason))
5555+ }
5656+ b.WriteString(` SMTP submission is currently rejected. Contact the operator to appeal.</p></div>`)
5757+ }
5858+5959+ b.WriteString(`<div class="stat-grid" style="display: grid; grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); gap: 1rem; margin: 1.5rem 0;">`)
6060+ b.WriteString(statCard("Sent (14d)", fmt.Sprintf("%d", d.Sent14d)))
6161+ b.WriteString(statCard("Bounced", fmt.Sprintf("%d", d.Bounced14d)))
6262+ b.WriteString(statCard("Complaints", fmt.Sprintf("%d", d.Complaints14d)))
6363+ b.WriteString(statCard("Bounce rate", fmt.Sprintf("%.1f%%", d.BounceRate*100)))
6464+ b.WriteString(`</div>`)
6565+6666+ // Sparkline
6767+ if len(d.DailySends) > 0 {
6868+ b.WriteString(`<section class="section">`)
6969+ b.WriteString(`<h2>Sends per day</h2>`)
7070+ b.WriteString(sparklineSVG(d.DailySends))
7171+ b.WriteString(`</section>`)
7272+ }
7373+7474+ // Warming tier
7575+ if d.WarmingLabel != "" {
7676+ b.WriteString(`<section class="section">`)
7777+ b.WriteString(`<h2>Warming progress</h2>`)
7878+ fmt.Fprintf(&b, `<p class="section-lede">%s</p>`, html.EscapeString(d.WarmingLabel))
7979+ b.WriteString(warningNote(d.WarmingTier))
8080+ b.WriteString(`</section>`)
8181+ }
8282+8383+ // Limits
8484+ b.WriteString(`<section class="section">`)
8585+ b.WriteString(`<h2>Current limits</h2>`)
8686+ fmt.Fprintf(&b, `<dl class="bullets"><dt>Hourly limit</dt><dd>%d</dd><dt>Daily limit</dt><dd>%d</dd></dl>`, d.HourlyLimit, d.DailyLimit)
8787+ b.WriteString(`</section>`)
8888+8989+ // Labels
9090+ if len(d.Labels) > 0 {
9191+ b.WriteString(`<section class="section">`)
9292+ b.WriteString(`<h2>Reputation labels</h2>`)
9393+ b.WriteString(`<p class="section-lede">Labels published by the atproto labeler. Other services can query these to decide whether to trust mail from your domain.</p>`)
9494+ for _, l := range d.Labels {
9595+ fmt.Fprintf(&b, `<span class="badge badge-label">%s</span> `, html.EscapeString(l))
9696+ }
9797+ b.WriteString(`</section>`)
9898+ }
9999+100100+ b.WriteString(`<section class="section">`)
101101+		b.WriteString(`<p class="section-lede">These numbers update in real time. Bounce rate above 5% or complaint rate above 0.1% can trigger automatic throttling or suspension. The fix is always the same: send only to engaged recipients who asked for your mail.</p>`)
102102+ b.WriteString(`</section>`)
103103+104104+ _, err := io.WriteString(w, b.String())
105105+ return err
106106+ })
107107+ return publicLayout("Deliverability — "+d.Domain, false).Render(templ.WithChildren(ctx, inner), w)
108108+ })
109109+}
110110+111111+func statCard(title, value string) string {
112112+ return fmt.Sprintf(`<article style="background: var(--surface); border: 1px solid var(--line); padding: 1rem; border-radius: 2px;">
113113+ <div style="font-size: var(--t-xs); text-transform: uppercase; letter-spacing: 0.1em; color: var(--muted); margin-bottom: 0.5rem;">%s</div>
114114+ <div style="font-size: var(--t-2xl); font-family: var(--font-display); color: var(--ink);">%s</div>
115115+ </article>`, html.EscapeString(title), html.EscapeString(value))
116116+}
117117+118118+func warningNote(tier string) string {
119119+ switch tier {
120120+ case "warming":
121121+ return `<p class="section-lede" style="color: var(--accent-ink);">Your domain is in the warming tier: 5 emails per hour, 20 per day. This protects the shared IP while Gmail learns your sending pattern. The cap lifts automatically after 7 days of clean sending.</p>`
122122+ case "ramping":
123123+ return `<p class="section-lede" style="color: var(--accent-ink);">Your domain is ramping: 20 emails per hour, 100 per day. Keep engagement high and complaints low. Full limits unlock after 14 days total.</p>`
124124+ default:
125125+ return ""
126126+ }
127127+}
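A hedged sketch of how a page handler might feed this template. The handler name and the numbers are placeholders, not the real /account/deliverability wiring:

```go
package templates

import "net/http"

// renderDeliverability is illustrative only: the real handler assembles
// DeliverabilityData from member storage before rendering.
func renderDeliverability(w http.ResponseWriter, r *http.Request) {
	d := DeliverabilityData{
		Domain:       "x.example", // placeholder values throughout
		Status:       "active",
		Sent14d:      420,
		Bounced14d:   3,
		BounceRate:   3.0 / 420.0,
		DailySends:   make([]int64, 14),
		HourlyLimit:  5,
		DailyLimit:   20,
		WarmingTier:  "warming",
		WarmingLabel: "warming (3/7 days)",
	}
	if err := DeliverabilityPage(d).Render(r.Context(), w); err != nil {
		http.Error(w, "render failed", http.StatusInternalServerError)
	}
}
```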
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package templates
44+55+// FAQPage answers the questions prospective members ask before they enroll.
66+// Honest, concise, and written to defuse the obvious objections.
77+88+import (
99+ "context"
1010+ "io"
1111+ "strings"
1212+1313+ "github.com/a-h/templ"
1414+)
1515+1616+func FAQPage() templ.Component {
1717+ return templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
1818+ inner := templ.ComponentFunc(func(_ context.Context, w io.Writer) error {
1919+ var b strings.Builder
2020+2121+ b.WriteString(`<h1 class="masthead masthead-sub">FAQ</h1>`)
2222+ b.WriteString(`<p class="lede">Questions we expect — answered honestly.</p>`)
2323+2424+ b.WriteString(`<section class="section">`)
2525+ b.WriteString(`<span class="step-marker">Pricing</span>`)
2626+ b.WriteString(`<h2>Why is this free? Will it stay free?</h2>`)
2727+ b.WriteString(`<p class="section-lede">It's free now because the relay needs a diverse, honest sender base to build IP reputation before we can responsibly charge anyone. Once the pool is warm and the first billing system is wired, paid tiers will start at around $10–15 per month per PDS operator. There will always be a generous free tier for low-volume senders.</p>`)
2828+ b.WriteString(`<p class="section-lede">If you enroll today, you are not signing up for a future invoice. We will announce pricing changes with at least 30 days' notice, and you can export your reputation or leave at any time.</p>`)
2929+ b.WriteString(`</section>`)
3030+3131+ b.WriteString(`<section class="section">`)
3232+ b.WriteString(`<span class="step-marker">Trust</span>`)
3333+ b.WriteString(`<h2>How can I trust this?</h2>`)
3434+ b.WriteString(`<p class="section-lede">You don't have to trust us blindly. The relay source code is open source (AGPL-3.0-or-later), the Osprey reputation rules are published, and the atproto labeler feed is public. You — or your favorite LLM — can audit exactly how deliverability decisions are made.</p>`)
3535+ b.WriteString(`<p class="section-lede">On privacy: the relay sees message metadata (sender, recipient, timestamp, size) but never the raw message body. That is the same trust model as Postmark, Mailgun, or Amazon SES, except here the code is open and the operator is a small LLC instead of a public company.</p>`)
3636+ b.WriteString(`</section>`)
3737+3838+ b.WriteString(`<section class="section">`)
3939+ b.WriteString(`<span class="step-marker">Alternatives</span>`)
4040+ b.WriteString(`<h2>Why not use a trusted commercial relay?</h2>`)
4141+ b.WriteString(`<p class="section-lede">Commercial relays work well, but your domain reputation lives inside their business. If you switch providers, you start from zero. Atmosphere Mail is designed so your reputation stays with you: your DID, your domain, your attestation record. If you ever want to run your own relay, the code and the reputation layer come with you.</p>`)
4242+ b.WriteString(`<p class="section-lede">The long-term goal is a federation of cooperative relays that share a reputation blocklist indexed through atproto. One relay is live today; the architecture is built for many.</p>`)
4343+ b.WriteString(`</section>`)
4444+4545+ b.WriteString(`<section class="section">`)
4646+ b.WriteString(`<span class="step-marker">Deliverability</span>`)
4747+ b.WriteString(`<h2>Will my mail reach the inbox?</h2>`)
4848+ b.WriteString(`<p class="section-lede">Maybe not on day one. Gmail treats mail from a new IP as suspicious regardless of authentication cleanliness. The relay protects the shared pool with warming tier caps: 5 emails per hour for the first week, graduating as your domain builds reputation. Expect some messages to land in spam initially. The fix is slow, engaged sending — not better DNS records.</p>`)
4949+			b.WriteString(`<p class="section-lede">We run pool-level feedback loops with Gmail, Microsoft, and Yahoo so complaints route back to the offending member, not the whole cooperative. That is how shared reputation stays shared instead of turning into collective punishment.</p>`)
5050+ b.WriteString(`</section>`)
5151+5252+ b.WriteString(`<section class="section">`)
5353+ b.WriteString(`<span class="step-marker">Portability</span>`)
5454+ b.WriteString(`<h2>What if I want to leave?</h2>`)
5555+ b.WriteString(`<p class="section-lede">Your domain reputation is yours. The DKIM keys are published in your DNS, the attestation record lives on your PDS, and the <code>verified-mail-operator</code> label is signed against your DID. If you graduate to self-hosted delivery, those signals travel with you. If you want your member record deleted, email <a href="mailto:postmaster@atmos.email">postmaster@atmos.email</a> and we will remove it within 14 days.</p>`)
5656+ b.WriteString(`</section>`)
5757+5858+ b.WriteString(`<section class="section">`)
5959+ b.WriteString(`<span class="step-marker">Scope</span>`)
6060+ b.WriteString(`<h2>What can I send through this relay?</h2>`)
6161+ b.WriteString(`<p class="section-lede">Transactional and operational mail from your own domain: verification codes, password resets, notifications, personal correspondence. Unsolicited bulk mail, scraped lists, and relaying for third parties will get you suspended quickly. See the <a href="/aup">Acceptable Use Policy</a> for the full list.</p>`)
6262+ b.WriteString(`</section>`)
6363+6464+ b.WriteString(`<section class="section">`)
6565+ b.WriteString(`<p class="section-lede">Still have questions? Reach the operator at <a href="https://bsky.app/profile/scottlanoue.com">@scottlanoue.com</a> or <a href="mailto:postmaster@atmos.email">postmaster@atmos.email</a>.</p>`)
6666+ b.WriteString(`</section>`)
6767+6868+ _, err := io.WriteString(w, b.String())
6969+ return err
7070+ })
7171+ return publicLayout("FAQ — Atmosphere Mail", false).Render(templ.WithChildren(ctx, inner), w)
7272+ })
7373+}
+12
internal/admin/ui/templates/member_detail.templ
···223223 </button>
224224 }
225225 </div>
226226+ <details>
227227+ <summary role="button" class="outline contrast" style="margin-top: 1rem; font-size: 0.85rem;">Permanently Delete Member</summary>
228228+ <p style="margin: 0.5rem 0;"><small>This permanently removes the member, all domains, DKIM keys, message history, and rate counters. Suppressions are preserved for compliance. This cannot be undone.</small></p>
229229+ <button
230230+ class="secondary"
231231+ style="background: var(--pico-del-color, #c62828); border-color: var(--pico-del-color, #c62828);"
232232+ hx-delete={ "/ui/member/" + m.DID + "/delete" }
233233+ hx-confirm={ "PERMANENTLY DELETE " + m.Domain + " (" + m.DID + ")? This cannot be undone." }
234234+ >
235235+ Delete { m.Domain } Forever
236236+ </button>
237237+ </details>
226238}
···6565 // visibility but no action is required when the member is already
6666 // approved.
6767 KindMemberDomainAdded EventKind = "member_domain_added"
6868+6969+ // KindBypassAdded fires when an admin adds a label-bypass entry for
7070+ // a DID. High signal: bypass disables T&S enforcement, so operators
7171+ // must see every add land in their notification stream (#213).
7272+ KindBypassAdded EventKind = "bypass_added"
7373+7474+ // KindBypassRemoved fires when an admin or the expiry janitor
7575+ // removes a bypass entry. Distinguish manual vs janitor via the
7676+ // Reason field on the event.
7777+ KindBypassRemoved EventKind = "bypass_removed"
6878)
69797080// Event is the payload shape every webhook call carries. Fields are
+105-5
internal/osprey/emitter.go
···2626type EmitterMetrics interface {
2727 IncEmitted(eventType string)
2828 IncFailed(eventType string)
2929+ // IncSpooled fires when an event lands on disk because the
3030+ // broker rejected/silently dropped it (#214 DLQ).
3131+ IncSpooled(eventType string)
3232+ // IncReplayed fires when a previously-spooled event finally
3333+ // makes it to the broker on a subsequent retry.
3434+ IncReplayed(eventType string)
3535+ // IncDropped fires when the spool overflows and an event is
3636+ // permanently lost. reason names the trigger ("overflow",
3737+ // "corrupt").
3838+ IncDropped(reason string)
3939+ // SetSpoolDepth republishes the current spool size as a gauge.
4040+ SetSpoolDepth(n int)
2941}
30423143// Emitter sends relay events to Osprey via Kafka.
···3547 counter atomic.Int64
3648 enabled bool
3749 metrics EmitterMetrics
5050+ // spool persists events that didn't reach the broker on first
5151+ // attempt. The replayer drains it back when Kafka is healthy
5252+ // again. nil = no spool wired (legacy fire-and-forget).
5353+ spool *EventSpool
3854}
39554056// NewEmitter creates an emitter that writes to the given Kafka broker.
···6581//
6682// Per-event-type attribution relies on the "event_type" header that Emit
6783// attaches to every produced message.
8484+//
8585+// On batch failure with a spool wired, every message in the batch is
8686+// re-spooled so a downstream replayer can retry. Without the spool, the
8787+// failure is logged + counted only — legacy behavior.
6888func (e *Emitter) handleCompletion(messages []kafka.Message, err error) {
6989 if err != nil {
7090 log.Printf("osprey.kafka_batch_error: messages=%d error=%v", len(messages), err)
7191 }
7272- if e.metrics == nil {
7373- return
7474- }
7592 for _, m := range messages {
7693 et := eventTypeFromHeaders(m.Headers)
7794 if err != nil {
7878- e.metrics.IncFailed(et)
7979- } else {
9595+ if e.metrics != nil {
9696+ e.metrics.IncFailed(et)
9797+ }
9898+ e.spoolEvent(et, string(m.Key), m.Value)
9999+ } else if e.metrics != nil {
80100 e.metrics.IncEmitted(et)
81101 }
82102 }
···109129 e.metrics = m
110130}
111131132132+// SetSpool wires an on-disk dead-letter queue. When set, events that
133133+// fail to write or that the broker rejects asynchronously land in
134134+// the spool instead of being silently dropped. Call ReplaySpool
135135+// periodically (cmd/relay drives this from a GoSafe goroutine) to
136136+// drain the queue back to the broker after recovery. Closes #214.
137137+func (e *Emitter) SetSpool(s *EventSpool) {
138138+ e.spool = s
139139+ if s != nil && e.metrics != nil {
140140+ s.SetDropper(spoolDropperBridge{e.metrics})
141141+ }
142142+}
143143+144144+// spoolDropperBridge adapts EmitterMetrics.IncDropped to the
145145+// SpoolDropper interface so the spool itself doesn't need to import
146146+// the emitter's metrics shape.
147147+type spoolDropperBridge struct{ m EmitterMetrics }
148148+149149+func (s spoolDropperBridge) IncDropped(reason string) { s.m.IncDropped(reason) }
150150+151151+// spoolEvent persists a (key, payload) pair for later replay,
152152+// recording the spool/drop metrics as appropriate. Caller MUST have
153153+// already failed a real send attempt; spoolEvent is a recovery path,
154154+// not a primary write.
155155+func (e *Emitter) spoolEvent(eventType, key string, payload []byte) {
156156+ if e.spool == nil {
157157+ return
158158+ }
159159+ if err := e.spool.Write(eventType, key, payload); err != nil {
160160+ log.Printf("osprey.spool.write_error: event_type=%s error=%v", eventType, err)
161161+ if e.metrics != nil {
162162+ e.metrics.IncDropped("spool_write_error")
163163+ }
164164+ return
165165+ }
166166+ if e.metrics != nil {
167167+ e.metrics.IncSpooled(eventType)
168168+ }
169169+}
170170+171171+// ReplaySpool drains spooled events back to the broker. Call from a
172172+// periodic loop in cmd/relay. Returns (replayed, failed) counts so
173173+// the caller can log a single summary line per pass. Errors from the
174174+// underlying directory listing surface as the third return.
175175+//
176176+// Per-event replay uses the same writer.WriteMessages path as live
177177+// Emit; failures land back on the spool (the entry was never deleted)
178178+// for the next pass. So a sustained Kafka outage manifests as a
179179+// growing spool depth (visible via the gauge) without permanent
180180+// loss until the cap is hit.
181181+func (e *Emitter) ReplaySpool(ctx context.Context) (int, int, error) {
182182+ if e.spool == nil || !e.enabled {
183183+ return 0, 0, nil
184184+ }
185185+ replayed, failed, err := e.spool.Walk(func(se SpooledEvent) error {
186186+ writeErr := e.writer.WriteMessages(ctx, kafka.Message{
187187+ Key: []byte(se.Key),
188188+ Value: se.Payload,
189189+ Headers: []kafka.Header{
190190+ {Key: "event_type", Value: []byte(se.EventType)},
191191+ },
192192+ })
193193+ if writeErr != nil {
194194+ return writeErr
195195+ }
196196+ if e.metrics != nil {
197197+ e.metrics.IncReplayed(se.EventType)
198198+ }
199199+ return nil
200200+ })
201201+ if e.metrics != nil {
202202+ e.metrics.SetSpoolDepth(e.spool.Depth())
203203+ }
204204+ return replayed, failed, err
205205+}
206206+112207// Emit sends an event to Osprey. It is non-blocking (async writes)
113208// and never returns an error to avoid impacting relay operations.
114209func (e *Emitter) Emit(ctx context.Context, data EventData) {
···150245 if e.metrics != nil {
151246 e.metrics.IncFailed(data.EventType)
152247 }
248248+ // Sync-error spool: same failure mode as the async batch
249249+ // case in handleCompletion. Without this branch the buffer-
250250+ // full / shutdown class of failures is silently lost even
251251+ // when the spool is wired (#214).
252252+ e.spoolEvent(data.EventType, data.SenderDID, payload)
153253 }
154254 // Happy-path IncEmitted is intentionally NOT here — it fires in
155255 // handleCompletion once the broker actually confirms the batch. Doing
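A sketch of the cmd/relay wiring the comments above describe. The spool directory, tick interval, and wrapper name are placeholders, and the real loop runs inside the relay's GoSafe panic-recovery helper rather than a bare goroutine:

```go
package osprey

import (
	"context"
	"log"
	"time"
)

// StartSpoolReplayLoop is an illustrative wrapper for the cmd/relay
// wiring: attach the on-disk spool, then drain it on a fixed tick.
func StartSpoolReplayLoop(ctx context.Context, e *Emitter, dir string) error {
	spool, err := NewEventSpool(dir, 0) // 0 = default 10k-entry cap
	if err != nil {
		return err
	}
	e.SetSpool(spool)

	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				replayed, failed, err := e.ReplaySpool(ctx)
				if replayed > 0 || failed > 0 || err != nil {
					log.Printf("osprey.spool.replay: replayed=%d failed=%d err=%v",
						replayed, failed, err)
				}
			}
		}
	}()
	return nil
}
```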
+41-4
internal/osprey/emitter_integration_test.go
···276276// queue-accepted write, so the actual broker success/failure can only be
277277// observed from the Completion callback. These tests pin that contract.
278278279279-// fakeMetrics records IncEmitted/IncFailed calls for assertion.
279279+// fakeMetrics records EmitterMetrics calls for assertion.
280280type fakeMetrics struct {
281281- mu sync.Mutex
282282- emitted map[string]int
283283- failed map[string]int
281281+ mu sync.Mutex
282282+ emitted map[string]int
283283+ failed map[string]int
284284+ spooled map[string]int
285285+ replayed map[string]int
286286+ dropped map[string]int
287287+ spoolDepth int
284288}
285289286290func (f *fakeMetrics) IncEmitted(t string) {
···299303 f.failed = map[string]int{}
300304 }
301305 f.failed[t]++
306306+}
307307+308308+func (f *fakeMetrics) IncSpooled(t string) {
309309+ f.mu.Lock()
310310+ defer f.mu.Unlock()
311311+ if f.spooled == nil {
312312+ f.spooled = map[string]int{}
313313+ }
314314+ f.spooled[t]++
315315+}
316316+317317+func (f *fakeMetrics) IncReplayed(t string) {
318318+ f.mu.Lock()
319319+ defer f.mu.Unlock()
320320+ if f.replayed == nil {
321321+ f.replayed = map[string]int{}
322322+ }
323323+ f.replayed[t]++
324324+}
325325+326326+func (f *fakeMetrics) IncDropped(reason string) {
327327+ f.mu.Lock()
328328+ defer f.mu.Unlock()
329329+ if f.dropped == nil {
330330+ f.dropped = map[string]int{}
331331+ }
332332+ f.dropped[reason]++
333333+}
334334+335335+func (f *fakeMetrics) SetSpoolDepth(n int) {
336336+ f.mu.Lock()
337337+ f.spoolDepth = n
338338+ f.mu.Unlock()
302339}
303340304341// TestEmitAttachesEventTypeHeader: the Completion callback runs long after
+240
internal/osprey/spool.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package osprey
44+55+import (
66+ "crypto/rand"
77+ "encoding/hex"
88+ "encoding/json"
99+ "fmt"
1010+ "log"
1111+ "os"
1212+ "path/filepath"
1313+ "sort"
1414+ "strings"
1515+ "sync"
1616+ "time"
1717+)
1818+1919+// EventSpool persists Osprey events to disk when the Kafka emitter
2020+// can't deliver them. A background replayer drains the spool back to
2121+// the broker after reconnect.
2222+//
2323+// Without this, an atmos-ops outage silently loses every relay event
2424+// fired during the window — labels stop propagating, trust scoring
2525+// freezes on stale data, and there is no signal an operator can see
2626+// after-the-fact that says "we lost N events between 03:14 and 04:02."
2727+// Closes #214.
2828+//
2929+// On-disk format: each event is one JSON object per file, named
3030+// {unix-nanos}-{8-hex-rand}.json, stored under dir. Filenames sort
3131+// chronologically so Walk can replay in arrival order.
3232+//
3333+// Bounded: when len(spool) >= maxEntries, the oldest file is dropped
3434+// to make room. Operators MUST graph the drop counter — a non-zero
3535+// drop rate means events were permanently lost.
3636+type EventSpool struct {
3737+ dir string
3838+ maxEntries int
3939+ mu sync.Mutex
4040+ dropMetrics interface{ IncDropped(reason string) }
4141+}
4242+4343+// SpooledEvent is the persisted on-disk representation. Includes the
4444+// event_type so the Kafka header can be reconstructed at replay
4545+// without re-parsing the body.
4646+type SpooledEvent struct {
4747+ EventType string `json:"event_type"`
4848+ Key string `json:"key"`
4949+ Payload json.RawMessage `json:"payload"`
5050+ SpooledAt time.Time `json:"spooled_at"`
5151+}
5252+5353+// SpoolDropper is the narrow interface the spool uses to count
5454+// permanent drops (overflow). Optional; nil-safe.
5555+type SpoolDropper interface {
5656+ IncDropped(reason string)
5757+}
5858+5959+// NewEventSpool creates a spool rooted at dir. Creates the directory
6060+// if it doesn't exist. maxEntries caps the spool depth — when full,
6161+// the oldest entries are dropped to make room. Pass 0 for the
6262+// default cap (10k entries).
6363+func NewEventSpool(dir string, maxEntries int) (*EventSpool, error) {
6464+ if maxEntries <= 0 {
6565+ maxEntries = 10_000
6666+ }
6767+ if err := os.MkdirAll(dir, 0o755); err != nil {
6868+ return nil, fmt.Errorf("mkdir spool: %w", err)
6969+ }
7070+ return &EventSpool{dir: dir, maxEntries: maxEntries}, nil
7171+}
7272+7373+// SetDropper wires a counter that fires when the spool overflows.
7474+// Without it, drops are logged only — fine for tests, insufficient
7575+// for production observability.
7676+func (s *EventSpool) SetDropper(d SpoolDropper) {
7777+ s.dropMetrics = d
7878+}
7979+8080+// Write persists one event. Atomic via tmp+rename + fsync of the
8181+// parent directory so a crash mid-write can't leave a partial file
8282+// that Walk would later choke on.
8383+func (s *EventSpool) Write(eventType, key string, payload []byte) error {
8484+ s.mu.Lock()
8585+ defer s.mu.Unlock()
8686+8787+	// Enforce the cap before writing so the spool never grows past
8888+	// maxEntries, even under a sustained flood of failed sends.
8989+ if err := s.enforceCapLocked(); err != nil {
9090+ log.Printf("osprey.spool.cap_error: %v", err)
9191+ // Continue — losing a write because the cleanup failed is
9292+ // strictly worse than tolerating temporary over-cap.
9393+ }
9494+9595+ se := SpooledEvent{
9696+ EventType: eventType,
9797+ Key: key,
9898+ Payload: payload,
9999+ SpooledAt: time.Now().UTC(),
100100+ }
101101+ data, err := json.Marshal(se)
102102+ if err != nil {
103103+ return fmt.Errorf("marshal: %w", err)
104104+ }
105105+106106+ var raw [4]byte
107107+ if _, err := rand.Read(raw[:]); err != nil {
108108+ return fmt.Errorf("rand: %w", err)
109109+ }
110110+ name := fmt.Sprintf("%d-%s.json", time.Now().UTC().UnixNano(), hex.EncodeToString(raw[:]))
111111+ path := filepath.Join(s.dir, name)
112112+ tmp := path + ".tmp"
113113+114114+ f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
115115+ if err != nil {
116116+ return fmt.Errorf("open tmp: %w", err)
117117+ }
118118+ if _, err := f.Write(data); err != nil {
119119+ f.Close()
120120+ os.Remove(tmp)
121121+ return fmt.Errorf("write tmp: %w", err)
122122+ }
123123+ if err := f.Sync(); err != nil {
124124+ f.Close()
125125+ os.Remove(tmp)
126126+ return fmt.Errorf("fsync tmp: %w", err)
127127+ }
128128+ if err := f.Close(); err != nil {
129129+ os.Remove(tmp)
130130+ return fmt.Errorf("close tmp: %w", err)
131131+ }
132132+ if err := os.Rename(tmp, path); err != nil {
133133+ os.Remove(tmp)
134134+ return fmt.Errorf("rename: %w", err)
135135+ }
136136+ if d, err := os.Open(s.dir); err == nil {
137137+ _ = d.Sync()
138138+ _ = d.Close()
139139+ }
140140+ return nil
141141+}
142142+143143+// Walk invokes fn for each spooled event, in arrival (filename) order.
144144+// On fn returning nil, the entry is removed. On fn returning an error,
145145+// the entry is left in place and the walk continues to the next entry
146146+// — failures are per-entry, not fatal to the loop.
147147+//
148148+// Returns (replayed, failed, err) where err is non-nil only on a
149149+// directory-listing failure (permission, missing dir, etc.).
150150+func (s *EventSpool) Walk(fn func(SpooledEvent) error) (int, int, error) {
151151+ s.mu.Lock()
152152+ defer s.mu.Unlock()
153153+154154+ entries, err := s.listLocked()
155155+ if err != nil {
156156+ return 0, 0, err
157157+ }
158158+ var replayed, failed int
159159+ for _, name := range entries {
160160+ path := filepath.Join(s.dir, name)
161161+ raw, err := os.ReadFile(path)
162162+ if err != nil {
163163+ log.Printf("osprey.spool.read_error: file=%s error=%v", name, err)
164164+ failed++
165165+ continue
166166+ }
167167+ var se SpooledEvent
168168+ if err := json.Unmarshal(raw, &se); err != nil {
169169+ // Corrupt entry — drop so it doesn't block the queue forever.
170170+ log.Printf("osprey.spool.corrupt: file=%s error=%v — dropping", name, err)
171171+ os.Remove(path)
172172+ failed++
173173+ continue
174174+ }
175175+ if err := fn(se); err != nil {
176176+ failed++
177177+ continue // leave in place; replayer will retry next pass
178178+ }
179179+ if err := os.Remove(path); err != nil {
180180+ log.Printf("osprey.spool.remove_error: file=%s error=%v", name, err)
181181+ }
182182+ replayed++
183183+ }
184184+ return replayed, failed, nil
185185+}
186186+187187+// Depth returns the current number of spooled entries. Cheap; called
188188+// every replay tick to update the depth gauge.
189189+func (s *EventSpool) Depth() int {
190190+ s.mu.Lock()
191191+ defer s.mu.Unlock()
192192+ entries, err := s.listLocked()
193193+ if err != nil {
194194+ return 0
195195+ }
196196+ return len(entries)
197197+}
198198+199199+// listLocked returns the sorted list of spool filenames. Caller MUST
200200+// hold s.mu.
201201+func (s *EventSpool) listLocked() ([]string, error) {
202202+ dirEntries, err := os.ReadDir(s.dir)
203203+ if err != nil {
204204+ return nil, err
205205+ }
206206+ var names []string
207207+ for _, de := range dirEntries {
208208+ if de.IsDir() {
209209+ continue
210210+ }
211211+ n := de.Name()
212212+ if !strings.HasSuffix(n, ".json") {
213213+ continue
214214+ }
215215+ names = append(names, n)
216216+ }
217217+ sort.Strings(names) // chronological by Unix-nanos prefix
218218+ return names, nil
219219+}
220220+221221+// enforceCapLocked drops oldest entries until the spool is below cap.
222222+// Caller MUST hold s.mu.
223223+func (s *EventSpool) enforceCapLocked() error {
224224+ entries, err := s.listLocked()
225225+ if err != nil {
226226+ return err
227227+ }
228228+ for len(entries) >= s.maxEntries {
229229+ oldest := entries[0]
230230+ if err := os.Remove(filepath.Join(s.dir, oldest)); err != nil {
231231+ return err
232232+ }
233233+ entries = entries[1:]
234234+ if s.dropMetrics != nil {
235235+ s.dropMetrics.IncDropped("overflow")
236236+ }
237237+ log.Printf("osprey.spool.dropped: file=%s reason=overflow cap=%d", oldest, s.maxEntries)
238238+ }
239239+ return nil
240240+}
+244
internal/osprey/spool_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package osprey
44+55+import (
66+ "context"
77+ "errors"
88+ "os"
99+ "path/filepath"
1010+ "strings"
1111+ "sync"
1212+ "testing"
1313+1414+ "github.com/segmentio/kafka-go"
1515+)
1616+1717+func TestEventSpool_WriteAndWalkRoundTrip(t *testing.T) {
1818+ dir := t.TempDir()
1919+ s, err := NewEventSpool(dir, 0)
2020+ if err != nil {
2121+ t.Fatalf("NewEventSpool: %v", err)
2222+ }
2323+2424+ if err := s.Write("relay_attempt", "did:plc:a", []byte(`{"k":"v"}`)); err != nil {
2525+ t.Fatalf("Write: %v", err)
2626+ }
2727+ if err := s.Write("delivery_result", "did:plc:b", []byte(`{"k":2}`)); err != nil {
2828+ t.Fatalf("Write: %v", err)
2929+ }
3030+3131+ if d := s.Depth(); d != 2 {
3232+ t.Errorf("Depth = %d, want 2", d)
3333+ }
3434+3535+ var seen []string
3636+ r, fail, err := s.Walk(func(se SpooledEvent) error {
3737+ seen = append(seen, se.EventType)
3838+ return nil
3939+ })
4040+ if err != nil {
4141+ t.Fatalf("Walk: %v", err)
4242+ }
4343+ if r != 2 || fail != 0 {
4444+ t.Errorf("replay count: replayed=%d failed=%d, want 2/0", r, fail)
4545+ }
4646+ if len(seen) != 2 || seen[0] != "relay_attempt" || seen[1] != "delivery_result" {
4747+ t.Errorf("Walk order = %v, want [relay_attempt delivery_result]", seen)
4848+ }
4949+ if d := s.Depth(); d != 0 {
5050+ t.Errorf("post-Walk depth = %d, want 0", d)
5151+ }
5252+}
5353+5454+func TestEventSpool_WalkLeavesFailedInPlace(t *testing.T) {
5555+ dir := t.TempDir()
5656+ s, _ := NewEventSpool(dir, 0)
5757+ _ = s.Write("a", "k", []byte(`{}`))
5858+ _ = s.Write("b", "k", []byte(`{}`))
5959+6060+ // First pass: fail the first entry, succeed the rest.
6161+ failOnce := true
6262+ r, fail, _ := s.Walk(func(se SpooledEvent) error {
6363+ if failOnce {
6464+ failOnce = false
6565+ return errors.New("simulated broker outage")
6666+ }
6767+ return nil
6868+ })
6969+ if r != 1 || fail != 1 {
7070+ t.Errorf("first walk replayed=%d failed=%d, want 1/1", r, fail)
7171+ }
7272+ if d := s.Depth(); d != 1 {
7373+ t.Errorf("depth after partial walk = %d, want 1 (failed entry retained)", d)
7474+ }
7575+7676+ // Second pass: everything succeeds.
7777+ r, fail, _ = s.Walk(func(se SpooledEvent) error { return nil })
7878+ if r != 1 || fail != 0 {
7979+ t.Errorf("second walk replayed=%d failed=%d, want 1/0", r, fail)
8080+ }
8181+ if d := s.Depth(); d != 0 {
8282+ t.Errorf("post-recovery depth = %d, want 0", d)
8383+ }
8484+}
8585+8686+type countingDropper struct {
8787+ mu sync.Mutex
8888+ calls map[string]int
8989+}
9090+9191+func (c *countingDropper) IncDropped(reason string) {
9292+ c.mu.Lock()
9393+ defer c.mu.Unlock()
9494+ if c.calls == nil {
9595+ c.calls = map[string]int{}
9696+ }
9797+ c.calls[reason]++
9898+}
9999+func (c *countingDropper) count(r string) int {
100100+ c.mu.Lock()
101101+ defer c.mu.Unlock()
102102+ return c.calls[r]
103103+}
104104+105105+func TestEventSpool_OverflowDropsOldest(t *testing.T) {
106106+ dir := t.TempDir()
107107+ s, _ := NewEventSpool(dir, 3) // tiny cap to make overflow easy
108108+ d := &countingDropper{}
109109+ s.SetDropper(d)
110110+111111+ for i := 0; i < 5; i++ {
112112+ if err := s.Write("t", "k", []byte(`{}`)); err != nil {
113113+ t.Fatalf("Write %d: %v", i, err)
114114+ }
115115+ }
116116+ if got := s.Depth(); got != 3 {
117117+ t.Errorf("Depth = %d, want 3 (cap)", got)
118118+ }
119119+ if d.count("overflow") != 2 {
120120+ t.Errorf("overflow count = %d, want 2", d.count("overflow"))
121121+ }
122122+}
123123+124124+func TestEventSpool_CorruptEntryDroppedNotBlocking(t *testing.T) {
125125+ dir := t.TempDir()
126126+ s, _ := NewEventSpool(dir, 0)
127127+ if err := s.Write("good", "k", []byte(`{}`)); err != nil {
128128+ t.Fatal(err)
129129+ }
130130+ // Inject a corrupt file that bypasses Write — simulate a partial
131131+ // pre-fsync write that survived a crash.
132132+ corrupt := filepath.Join(dir, "0-deadbeef.json")
133133+ if err := os.WriteFile(corrupt, []byte("not-json{"), 0o600); err != nil {
134134+ t.Fatal(err)
135135+ }
136136+137137+ r, fail, err := s.Walk(func(se SpooledEvent) error { return nil })
138138+ if err != nil {
139139+ t.Fatalf("Walk: %v", err)
140140+ }
141141+ // Corrupt entry counts as failed AND is removed; good entry replays.
142142+ if r != 1 {
143143+ t.Errorf("replayed = %d, want 1", r)
144144+ }
145145+ if fail != 1 {
146146+ t.Errorf("failed = %d, want 1 (corrupt)", fail)
147147+ }
148148+ // Corrupt file must be gone — otherwise it blocks the queue forever.
149149+ if _, err := os.Stat(corrupt); !os.IsNotExist(err) {
150150+ t.Errorf("corrupt file survived: stat err=%v", err)
151151+ }
152152+}
153153+154154+// stubWriter is a messageWriter that captures messages and lets tests
155155+// flip between success and failure modes.
156156+type stubWriter struct {
157157+ mu sync.Mutex
158158+ written []kafka.Message
159159+ failNow bool
160160+}
161161+162162+func (w *stubWriter) WriteMessages(_ context.Context, msgs ...kafka.Message) error {
163163+ w.mu.Lock()
164164+ defer w.mu.Unlock()
165165+ if w.failNow {
166166+ return errors.New("broker unreachable")
167167+ }
168168+ w.written = append(w.written, msgs...)
169169+ return nil
170170+}
171171+func (w *stubWriter) Close() error { return nil }
172172+func (w *stubWriter) count() int {
173173+ w.mu.Lock()
174174+ defer w.mu.Unlock()
175175+ return len(w.written)
176176+}
177177+178178+// TestEmitter_FailedSyncWriteSpoolsAndReplays exercises the
179179+// integration between Emit's sync-error branch and ReplaySpool.
180180+func TestEmitter_FailedSyncWriteSpoolsAndReplays(t *testing.T) {
181181+ dir := t.TempDir()
182182+ spool, _ := NewEventSpool(dir, 0)
183183+184184+ w := &stubWriter{failNow: true}
185185+ e := newEmitterWithWriter(w)
186186+ m := &fakeMetrics{}
187187+ e.SetMetrics(m)
188188+ e.SetSpool(spool)
189189+190190+ // First emit: writer fails, event lands on disk.
191191+ e.Emit(context.Background(), EventData{EventType: "relay_attempt", SenderDID: "did:plc:test"})
192192+ if got := spool.Depth(); got != 1 {
193193+ t.Fatalf("post-failure spool depth = %d, want 1", got)
194194+ }
195195+ if got := m.spooled["relay_attempt"]; got != 1 {
196196+ t.Errorf("spooled[relay_attempt] = %d, want 1", got)
197197+ }
198198+199199+ // Recover the broker, replay.
200200+ w.mu.Lock()
201201+ w.failNow = false
202202+ w.mu.Unlock()
203203+ r, fail, err := e.ReplaySpool(context.Background())
204204+ if err != nil {
205205+ t.Fatalf("ReplaySpool: %v", err)
206206+ }
207207+ if r != 1 || fail != 0 {
208208+ t.Errorf("replay r=%d fail=%d, want 1/0", r, fail)
209209+ }
210210+ if got := spool.Depth(); got != 0 {
211211+ t.Errorf("post-replay spool depth = %d, want 0", got)
212212+ }
213213+ if got := m.replayed["relay_attempt"]; got != 1 {
214214+ t.Errorf("replayed[relay_attempt] = %d, want 1", got)
215215+ }
216216+ if got := w.count(); got != 1 {
217217+ t.Errorf("writer received %d messages, want 1", got)
218218+ }
219219+}
220220+221221+// TestEmitter_NoSpoolFallsBackToLegacy confirms backward-compat: an
222222+// emitter with no spool wired drops failed events the same way as
223223+// before #214 (logged + IncFailed only).
224224+func TestEmitter_NoSpoolFallsBackToLegacy(t *testing.T) {
225225+ w := &stubWriter{failNow: true}
226226+ e := newEmitterWithWriter(w)
227227+ m := &fakeMetrics{}
228228+ e.SetMetrics(m)
229229+ // SetSpool intentionally NOT called.
230230+231231+ e.Emit(context.Background(), EventData{EventType: "x", SenderDID: "k"})
232232+ if m.failed["x"] != 1 {
233233+ t.Errorf("failed[x] = %d, want 1", m.failed["x"])
234234+ }
235235+ if m.spooled["x"] != 0 {
236236+ t.Errorf("spooled[x] = %d, want 0 (no spool wired)", m.spooled["x"])
237237+ }
238238+}
239239+240240+// hasSubstr is a tiny helper to keep the corrupt-file test resilient
241241+// to error-string drift.
242242+func hasSubstr(s, sub string) bool { return strings.Contains(s, sub) }
243243+244244+var _ = hasSubstr // keep import used if a future test wants it
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "crypto/tls"
77+ "fmt"
88+ "log"
99+ "os"
1010+ "sync"
1111+ "time"
1212+)
1313+1414+// CertReloader serves TLS certificates from disk with mtime-based
1515+// reload. Plug it into a *tls.Config via the GetCertificate callback;
1616+// the underlying cert is automatically refreshed when ACME (or any
1717+// other process) replaces the files on disk.
1818+//
1919+// Without this, every cert renewal forced a full relay restart via
2020+// systemd's reloadServices hook — dropping in-flight SMTP/HTTP
2121+// sessions and triggering the spool-reload race in #208. The
2222+// GetCertificate callback is invoked per TLS handshake, which is
2323+// many orders of magnitude cheaper than a process restart.
2424+//
2525+// Concurrency: safe for concurrent handshakes. Cert reads are
2626+// serialized via a mutex; the cached *tls.Certificate is shared
2727+// across all callers.
2828+//
2929+// Closes #216.
3030+type CertReloader struct {
3131+ certPath string
3232+ keyPath string
3333+3434+ mu sync.RWMutex
3535+ cert *tls.Certificate
3636+ loadedAt time.Time
3737+ certMtime time.Time
3838+ keyMtime time.Time
3939+}
4040+4141+// NewCertReloader builds a reloader for the given cert/key pair.
4242+// Loads the cert immediately so callers can fail fast on bad paths.
4343+// On a missing file (first deploy before ACME has minted a cert),
4444+// returns a non-nil reloader with no cached cert; GetCertificate
4545+// returns an error in that state. Callers can keep calling and the
4646+// reloader picks up the cert as soon as it lands.
4747+func NewCertReloader(certPath, keyPath string) (*CertReloader, error) {
4848+ r := &CertReloader{certPath: certPath, keyPath: keyPath}
4949+ if err := r.reload(); err != nil {
5050+ // Don't fail construction — first-boot ACME timing means the
5151+ // file may not exist yet. Log and keep going; the next
5252+ // GetCertificate invocation will retry.
5353+ log.Printf("cert.reload.initial_load_failed: cert=%s key=%s error=%v",
5454+ certPath, keyPath, err)
5555+ }
5656+ return r, nil
5757+}
5858+5959+// GetCertificate is the tls.Config.GetCertificate callback. Returns
6060+// the cached cert, reloading from disk first if the underlying file
6161+// mtime has changed since the last read.
6262+func (r *CertReloader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
6363+ if r.changed() {
6464+ if err := r.reload(); err != nil {
6565+ // Reload failed — fall back to the existing cached cert
6666+ // rather than fail the handshake. ACME mid-renewal can
6767+ // briefly leave the file inconsistent; surfacing that as
6868+ // a TLS handshake failure would be worse than serving the
6969+ // previous cert for one more poll cycle.
7070+ log.Printf("cert.reload.error: %v (continuing with cached cert)", err)
7171+ }
7272+ }
7373+ r.mu.RLock()
7474+ defer r.mu.RUnlock()
7575+ if r.cert == nil {
7676+ return nil, fmt.Errorf("cert.reload: no certificate available (cert=%s key=%s)", r.certPath, r.keyPath)
7777+ }
7878+ return r.cert, nil
7979+}
8080+8181+// changed reports whether either file's mtime has moved since the
8282+// last successful reload. Takes only the read lock, so concurrent
8383+// handshakes can compare mtimes without serializing on the write lock.
8484+func (r *CertReloader) changed() bool {
8585+ cs, cerr := os.Stat(r.certPath)
8686+ ks, kerr := os.Stat(r.keyPath)
8787+ if cerr != nil || kerr != nil {
8888+		// File temporarily missing (e.g. mid-rename): report no change
8989+		// so the cached cert keeps being served. Once the file reappears
9090+		// with a fresh mtime, the next handshake triggers a reload.
9191+ return false
9292+ }
9393+ r.mu.RLock()
9494+ defer r.mu.RUnlock()
9595+ return !cs.ModTime().Equal(r.certMtime) || !ks.ModTime().Equal(r.keyMtime)
9696+}
9797+9898+// reload reads cert+key from disk and atomically swaps the cached
9999+// *tls.Certificate. Holds the write lock briefly; returning an
100100+// error leaves the previous cache untouched.
101101+func (r *CertReloader) reload() error {
102102+ cs, err := os.Stat(r.certPath)
103103+ if err != nil {
104104+ return fmt.Errorf("stat cert: %w", err)
105105+ }
106106+ ks, err := os.Stat(r.keyPath)
107107+ if err != nil {
108108+ return fmt.Errorf("stat key: %w", err)
109109+ }
110110+ cert, err := tls.LoadX509KeyPair(r.certPath, r.keyPath)
111111+ if err != nil {
112112+ return fmt.Errorf("load keypair: %w", err)
113113+ }
114114+ r.mu.Lock()
115115+ r.cert = &cert
116116+ r.loadedAt = time.Now()
117117+ r.certMtime = cs.ModTime()
118118+ r.keyMtime = ks.ModTime()
119119+ r.mu.Unlock()
120120+ log.Printf("cert.reload: loaded cert=%s key=%s mtime=%s",
121121+ r.certPath, r.keyPath, cs.ModTime().Format(time.RFC3339))
122122+ return nil
123123+}
124124+125125+// LoadedAt returns the wall-clock time of the most recent successful
126126+// reload. Used by metrics + the cert-age dashboard panel.
127127+func (r *CertReloader) LoadedAt() time.Time {
128128+ r.mu.RLock()
129129+ defer r.mu.RUnlock()
130130+ return r.loadedAt
131131+}
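Plugging the reloader into a listener is a one-line change to the tls.Config. A sketch with placeholder paths (the real cmd/relay wiring may differ):

```go
package relay

import (
	"crypto/tls"
	"net/http"
)

// serveWithReloadingCert is illustrative: paths and address are
// placeholders. Passing empty strings to ListenAndServeTLS is
// deliberate; GetCertificate supplies the cert on every handshake.
func serveWithReloadingCert(handler http.Handler) error {
	reloader, err := NewCertReloader(
		"/var/lib/acme/atmos.email/fullchain.pem", // placeholder
		"/var/lib/acme/atmos.email/key.pem",       // placeholder
	)
	if err != nil {
		return err
	}
	srv := &http.Server{
		Addr:    ":443",
		Handler: handler,
		TLSConfig: &tls.Config{
			GetCertificate: reloader.GetCertificate, // mtime-based reload, no restart
			MinVersion:     tls.VersionTLS12,
		},
	}
	return srv.ListenAndServeTLS("", "")
}
```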
+204
internal/relay/cert_reload_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "crypto/ecdsa"
77+ "crypto/elliptic"
88+ "crypto/rand"
99+ "crypto/tls"
1010+ "crypto/x509"
1111+ "crypto/x509/pkix"
1212+ "encoding/pem"
1313+ "math/big"
1414+ "os"
1515+ "path/filepath"
1616+ "testing"
1717+ "time"
1818+)
1919+2020+// writeTestCertPair generates a self-signed cert pair into the given
2121+// dir, returning the cert and key paths. CommonName is set to the
2222+// supplied identifier so two successive calls produce visibly
2323+// different certificates and the reload-on-mtime-change test can
2424+// distinguish them after a refresh.
2525+func writeTestCertPair(t *testing.T, dir, identifier string) (string, string) {
2626+ t.Helper()
2727+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
2828+ if err != nil {
2929+ t.Fatalf("genkey: %v", err)
3030+ }
3131+ template := &x509.Certificate{
3232+ SerialNumber: big.NewInt(1),
3333+ Subject: pkix.Name{Organization: []string{identifier}, CommonName: identifier},
3434+ NotBefore: time.Now(),
3535+ NotAfter: time.Now().Add(time.Hour),
3636+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
3737+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
3838+ DNSNames: []string{"localhost"},
3939+ }
4040+ certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
4141+ if err != nil {
4242+ t.Fatalf("create cert: %v", err)
4343+ }
4444+ keyDER, err := x509.MarshalECPrivateKey(key)
4545+ if err != nil {
4646+ t.Fatalf("marshal key: %v", err)
4747+ }
4848+ certPath := filepath.Join(dir, identifier+".crt")
4949+ keyPath := filepath.Join(dir, identifier+".key")
5050+ certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
5151+ keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
5252+ if err := os.WriteFile(certPath, certPEM, 0o600); err != nil {
5353+ t.Fatalf("write cert: %v", err)
5454+ }
5555+ if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
5656+ t.Fatalf("write key: %v", err)
5757+ }
5858+ return certPath, keyPath
5959+}
6060+6161+// commonNameOf parses the leaf cert from a tls.Certificate and
6262+// returns its Subject Common Name. Used by tests to confirm the
6363+// reloader served the new cert (different CN) rather than the
6464+// cached old one.
6565+func commonNameOf(t *testing.T, c *tls.Certificate) string {
6666+ t.Helper()
6767+ if c == nil || len(c.Certificate) == 0 {
6868+ t.Fatal("nil or empty certificate")
6969+ }
7070+ leaf, err := x509.ParseCertificate(c.Certificate[0])
7171+ if err != nil {
7272+ t.Fatalf("parse leaf: %v", err)
7373+ }
7474+ return leaf.Subject.CommonName
7575+}
7676+7777+func TestCertReloader_LoadsOnConstruction(t *testing.T) {
7878+ dir := t.TempDir()
7979+ certPath, keyPath := writeTestCertPair(t, dir, "v1")
8080+8181+ r, err := NewCertReloader(certPath, keyPath)
8282+ if err != nil {
8383+ t.Fatalf("NewCertReloader: %v", err)
8484+ }
8585+ got, err := r.GetCertificate(nil)
8686+ if err != nil {
8787+ t.Fatalf("GetCertificate: %v", err)
8888+ }
8989+ if cn := commonNameOf(t, got); cn != "v1" {
9090+ t.Errorf("CommonName = %q, want v1", cn)
9191+ }
9292+}
9393+9494+// TestCertReloader_ReloadsOnMtimeChange is the core invariant of
9595+// #216: rewriting the cert on disk causes the next handshake to
9696+// serve the new cert, with no process restart.
9797+func TestCertReloader_ReloadsOnMtimeChange(t *testing.T) {
9898+ dir := t.TempDir()
9999+ certPath, keyPath := writeTestCertPair(t, dir, "v1")
100100+ r, err := NewCertReloader(certPath, keyPath)
101101+ if err != nil {
102102+ t.Fatalf("NewCertReloader: %v", err)
103103+ }
104104+105105+ // First handshake: v1.
106106+ got, _ := r.GetCertificate(nil)
107107+ if cn := commonNameOf(t, got); cn != "v1" {
108108+ t.Fatalf("first CN = %q, want v1", cn)
109109+ }
110110+111111+ // Sleep briefly so mtime is observably different across
112112+ // filesystems with second-granularity stat (older macOS, ext3).
113113+ time.Sleep(1100 * time.Millisecond)
114114+115115+ // Rewrite cert+key with a new identifier.
116116+ v2Cert, v2Key := writeTestCertPair(t, dir, "v2")
117117+ // Move the v2 files into the v1 paths so the reloader sees the
118118+ // SAME paths but DIFFERENT mtimes/contents — which is exactly
119119+ // what ACME does when it renews in-place.
120120+ if err := os.Rename(v2Cert, certPath); err != nil {
121121+ t.Fatal(err)
122122+ }
123123+ if err := os.Rename(v2Key, keyPath); err != nil {
124124+ t.Fatal(err)
125125+ }
126126+127127+ got, err = r.GetCertificate(nil)
128128+ if err != nil {
129129+ t.Fatalf("post-rotation GetCertificate: %v", err)
130130+ }
131131+ if cn := commonNameOf(t, got); cn != "v2" {
132132+ t.Errorf("post-rotation CN = %q, want v2 (reloader didn't pick up new file)", cn)
133133+ }
134134+}
135135+136136+// TestCertReloader_FallsBackOnReadError confirms that a transient
137137+// rename-in-flight (where the cert file briefly disappears or is
138138+// corrupted) doesn't fail the next TLS handshake — we keep serving
139139+// the previously-cached cert.
140140+func TestCertReloader_FallsBackOnReadError(t *testing.T) {
141141+ dir := t.TempDir()
142142+ certPath, keyPath := writeTestCertPair(t, dir, "v1")
143143+ r, err := NewCertReloader(certPath, keyPath)
144144+ if err != nil {
145145+ t.Fatalf("NewCertReloader: %v", err)
146146+ }
147147+148148+ // Prime the cache.
149149+ if _, err := r.GetCertificate(nil); err != nil {
150150+ t.Fatal(err)
151151+ }
152152+153153+ // Corrupt the cert file mid-flight.
154154+ time.Sleep(1100 * time.Millisecond)
155155+ if err := os.WriteFile(certPath, []byte("not a pem"), 0o600); err != nil {
156156+ t.Fatal(err)
157157+ }
158158+159159+ // GetCertificate should fall back to the cached v1 rather than
160160+ // fail the handshake.
161161+ got, err := r.GetCertificate(nil)
162162+ if err != nil {
163163+ t.Fatalf("expected fallback to cached cert, got err=%v", err)
164164+ }
165165+ if cn := commonNameOf(t, got); cn != "v1" {
166166+ t.Errorf("fallback CN = %q, want v1", cn)
167167+ }
168168+}
169169+170170+// TestCertReloader_FirstBootMissingFile pins the first-deploy
171171+// behavior: the cert may not exist yet (ACME hasn't minted it).
172172+// NewCertReloader must not fail; GetCertificate returns an error
173173+// the TLS layer can surface, and a later GetCertificate after the
174174+// file lands picks it up.
175175+func TestCertReloader_FirstBootMissingFile(t *testing.T) {
176176+ dir := t.TempDir()
177177+ certPath := filepath.Join(dir, "missing.crt")
178178+ keyPath := filepath.Join(dir, "missing.key")
179179+180180+ r, err := NewCertReloader(certPath, keyPath)
181181+ if err != nil {
182182+ t.Fatalf("NewCertReloader should not fail on missing files, got %v", err)
183183+ }
184184+ if _, err := r.GetCertificate(nil); err == nil {
185185+ t.Fatal("expected error from GetCertificate when no cert exists")
186186+ }
187187+188188+ // ACME mints the cert post-startup.
189189+ cp, kp := writeTestCertPair(t, dir, "fresh")
190190+ if err := os.Rename(cp, certPath); err != nil {
191191+ t.Fatal(err)
192192+ }
193193+ if err := os.Rename(kp, keyPath); err != nil {
194194+ t.Fatal(err)
195195+ }
196196+197197+ got, err := r.GetCertificate(nil)
198198+ if err != nil {
199199+ t.Fatalf("post-mint GetCertificate: %v", err)
200200+ }
201201+ if cn := commonNameOf(t, got); cn != "fresh" {
202202+ t.Errorf("CN = %q, want fresh", cn)
203203+ }
204204+}
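For orientation on how the reloader is consumed: a minimal wiring sketch, assuming only the `NewCertReloader`/`GetCertificate` surface the tests above exercise and that `GetCertificate` matches the `crypto/tls` callback signature. The module path, file paths, and the `http.Server` here are placeholders, not part of this change.

```go
package main

import (
	"crypto/tls"
	"log"
	"net/http"

	"example.com/atmos/internal/relay" // placeholder module path
)

func main() {
	// Paths are illustrative; the relay reads them from its config.
	r, err := relay.NewCertReloader("/var/lib/relay/tls/relay.crt", "/var/lib/relay/tls/relay.key")
	if err != nil {
		log.Fatalf("cert reloader: %v", err)
	}
	srv := &http.Server{
		Addr: ":8443",
		TLSConfig: &tls.Config{
			// Consulted on every handshake, so an ACME renewal that
			// rewrites the files on disk is served without a restart.
			GetCertificate: r.GetCertificate,
		},
	}
	// Empty cert/key arguments: crypto/tls falls back to GetCertificate.
	log.Fatal(srv.ListenAndServeTLS("", ""))
}
```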
+64
internal/relay/gosafe.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "log"
77+ "runtime/debug"
88+)
99+1010+// PanicRecorder is the narrow interface GoSafe needs to count
1111+// recovered panics, so callers can wire metrics.GoroutineCrashes
1212+// without the relay package taking a hard dependency on Prometheus
1313+// types here. Implementations must be safe for concurrent use.
1414+type PanicRecorder interface {
1515+ IncGoroutineCrash(name string)
1616+}
1717+1818+// goSafePanicRecorder is the package-level recorder. nil means
1919+// panics are still recovered + logged but not counted. Set via
2020+// SetPanicRecorder during cmd/relay wiring.
2121+var goSafePanicRecorder PanicRecorder
2222+2323+// SetPanicRecorder installs a metrics recorder used by GoSafe to
2424+// count recovered panics. Calling more than once replaces the
2525+// previous recorder. Safe to call before any GoSafe invocation.
2626+func SetPanicRecorder(r PanicRecorder) {
2727+ goSafePanicRecorder = r
2828+}
2929+3030+// GoSafe runs fn in a new goroutine with a deferred recover that
3131+// converts a panic into a log line + stack trace + metric increment
3232+// instead of process termination.
3333+//
3434+// Without this wrapper, every long-lived background goroutine in the
3535+// relay (queue worker, inbound server, public listener, events
3636+// consumer, health probe, hourly cleanups, notify worker, warmup
3737+// scheduler, ...) crashes the entire relay process on any panic.
3838+// A malformed inbound ARF report or a poison Kafka record is enough
3939+// to take the SMTP service down indefinitely. The deferred recover
4040+// here turns those into observable, contained failures the operator
4141+// can investigate without an outage. Closes #209.
4242+//
4343+// name is a short, stable label suitable for Prometheus and grep;
4444+// keep it consistent across deploys ("queue.run", "inbound.serve",
4545+// etc.). Anonymous goroutines are discouraged rather than rejected:
4646+// if a caller passes "", GoSafe still runs the function, but panics
4747+// report as name="unnamed" so the metric label is never empty.
4848+func GoSafe(name string, fn func()) {
4949+ if name == "" {
5050+ name = "unnamed"
5151+ }
5252+ go func() {
5353+ defer func() {
5454+ if r := recover(); r != nil {
5555+ log.Printf("goroutine.panic: name=%s recovered=%v\n%s",
5656+ name, r, string(debug.Stack()))
5757+ if goSafePanicRecorder != nil {
5858+ goSafePanicRecorder.IncGoroutineCrash(name)
5959+ }
6060+ }
6161+ }()
6262+ fn()
6363+ }()
6464+}
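A wiring sketch for cmd/relay, under the assumption that the `*Metrics` type later in this change (its `IncGoroutineCrash` method) is the recorder being installed. The worker callbacks are placeholders; only `SetPanicRecorder`, `GoSafe`, and the metric name come from this change.

```go
// Sketch only: install the recorder once at startup, then launch every
// long-lived background goroutine through GoSafe with a stable name.
func startWorkers(m *Metrics, queueRun, inboundServe func()) {
	SetPanicRecorder(m) // *Metrics implements PanicRecorder via IncGoroutineCrash

	GoSafe("queue.run", queueRun)
	GoSafe("inbound.serve", inboundServe)

	// A panic inside either worker now produces a goroutine.panic log line
	// with a stack trace, increments
	// atmosphere_relay_goroutine_crashes_total{name=...}, and leaves the
	// rest of the process running.
}
```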
+142
internal/relay/gosafe_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "sync"
77+ "sync/atomic"
88+ "testing"
99+ "time"
1010+)
1111+1212+// fakePanicRecorder captures IncGoroutineCrash calls so tests can
1313+// assert the panic was both recovered and counted.
1414+type fakePanicRecorder struct {
1515+ mu sync.Mutex
1616+ calls map[string]int
1717+}
1818+1919+func newFakePanicRecorder() *fakePanicRecorder {
2020+ return &fakePanicRecorder{calls: map[string]int{}}
2121+}
2222+2323+func (f *fakePanicRecorder) IncGoroutineCrash(name string) {
2424+ f.mu.Lock()
2525+ f.calls[name]++
2626+ f.mu.Unlock()
2727+}
2828+2929+func (f *fakePanicRecorder) count(name string) int {
3030+ f.mu.Lock()
3131+ defer f.mu.Unlock()
3232+ return f.calls[name]
3333+}
3434+3535+// withRecorder swaps the package-global recorder for the duration of
3636+// a test and restores the previous value on cleanup. Tests run in a
3737+// single process so we serialize via a mutex to avoid races between
3838+// parallel tests that might otherwise see each other's recorders.
3939+var goSafeTestMu sync.Mutex
4040+4141+func withRecorder(t *testing.T, r PanicRecorder) {
4242+ t.Helper()
4343+ goSafeTestMu.Lock()
4444+ prev := goSafePanicRecorder
4545+ SetPanicRecorder(r)
4646+ t.Cleanup(func() {
4747+ SetPanicRecorder(prev)
4848+ goSafeTestMu.Unlock()
4949+ })
5050+}
5151+5252+// awaitCount polls until f.count(name) reaches want or timeout.
5353+// A bounded poll with a descriptive failure message beats waiting out
5454+// the full test timeout, since the deferred recover swallows the panic.
5555+func awaitCount(t *testing.T, f *fakePanicRecorder, name string, want int) {
5656+ t.Helper()
5757+ deadline := time.Now().Add(2 * time.Second)
5858+ for time.Now().Before(deadline) {
5959+ if f.count(name) >= want {
6060+ return
6161+ }
6262+ time.Sleep(5 * time.Millisecond)
6363+ }
6464+ t.Fatalf("expected %s count >= %d, got %d", name, want, f.count(name))
6565+}
6666+6767+func TestGoSafe_RunsFnNoPanic(t *testing.T) {
6868+ rec := newFakePanicRecorder()
6969+ withRecorder(t, rec)
7070+ var ran atomic.Bool
7171+ done := make(chan struct{})
7272+ GoSafe("happy", func() {
7373+ ran.Store(true)
7474+ close(done)
7575+ })
7676+ select {
7777+ case <-done:
7878+ case <-time.After(time.Second):
7979+ t.Fatal("fn never ran")
8080+ }
8181+ if !ran.Load() {
8282+ t.Fatal("ran flag not set")
8383+ }
8484+ if got := rec.count("happy"); got != 0 {
8585+ t.Errorf("crash count = %d, want 0 on no-panic path", got)
8686+ }
8787+}
8888+8989+func TestGoSafe_RecoversPanicAndCounts(t *testing.T) {
9090+ rec := newFakePanicRecorder()
9191+ withRecorder(t, rec)
9292+ GoSafe("crashy", func() {
9393+ panic("intentional test panic")
9494+ })
9595+ awaitCount(t, rec, "crashy", 1)
9696+}
9797+9898+// TestGoSafe_ProcessSurvivesPanic — the load-bearing assertion of
9999+// #209: a panicking goroutine must NOT terminate the process. Test
100100+// runs the panicking GoSafe and then verifies a subsequent line of
101101+// test code executes (which it cannot if the runtime crashed).
102102+func TestGoSafe_ProcessSurvivesPanic(t *testing.T) {
103103+ rec := newFakePanicRecorder()
104104+ withRecorder(t, rec)
105105+ GoSafe("survivor", func() { panic("boom") })
106106+ awaitCount(t, rec, "survivor", 1)
107107+ // Reaching this line proves the runtime is still alive. If GoSafe
108108+ // regresses to a no-recover, the panic propagates and the test
109109+ // process dies before this assertion can run.
110110+ if t.Failed() {
111111+ t.Fatal("unreachable failure")
112112+ }
113113+}
114114+115115+func TestGoSafe_NilRecorderStillRecovers(t *testing.T) {
116116+ withRecorder(t, nil) // explicit nil
117117+ // Should not panic the test process.
118118+ GoSafe("orphan", func() { panic("no recorder, no problem") })
119119+ // Allow the goroutine time to schedule and recover.
120120+ time.Sleep(50 * time.Millisecond)
121121+}
122122+123123+func TestGoSafe_EmptyNameFallsBackToUnnamed(t *testing.T) {
124124+ rec := newFakePanicRecorder()
125125+ withRecorder(t, rec)
126126+ GoSafe("", func() { panic("anon") })
127127+ awaitCount(t, rec, "unnamed", 1)
128128+}
129129+130130+// TestGoSafe_MultiplePanicsCounted ensures the metric is per-name
131131+// and accumulates correctly under load — covers the "poison record
132132+// in a tight loop" scenario where the same goroutine panics
133133+// repeatedly while a supervisor restarts it.
134134+func TestGoSafe_MultiplePanicsCounted(t *testing.T) {
135135+ rec := newFakePanicRecorder()
136136+ withRecorder(t, rec)
137137+ const n = 5
138138+ for i := 0; i < n; i++ {
139139+ GoSafe("burst", func() { panic("repeat") })
140140+ }
141141+ awaitCount(t, rec, "burst", n)
142142+}
+41-3
internal/relay/inbound.go
···2424 ListenAddr string // default ":25"
2525 Domain string // relay domain (e.g. "atmos.email")
2626 MaxMsgSize int64 // default 10MB (replies can include larger bodies than DSNs)
2727+2828+ // RateLimitMsgsPerMinute caps the per-source-IP message rate at MAIL
2929+ // FROM. Zero or negative disables rate limiting (legacy behavior). A
 3030+ // reasonable production default is 30, leaving roughly 50% headroom
 3131+ // over the highest per-IP rate observed from any legitimate provider.
3232+ RateLimitMsgsPerMinute float64
3333+ // RateLimitBurst is the token-bucket capacity. Zero defaults to 10.
3434+ // Bursts above this size from a single IP get a 421 retry-later;
3535+ // over a sustained window the IP is held to RateLimitMsgsPerMinute.
3636+ RateLimitBurst int
2737}
28382939// BounceHandler is called when a valid bounce DSN is received and matched.
···5666type InboundMetrics interface {
5767 RecordInbound(classification string)
5868 RecordForward(status string)
6969+ // RecordRejected fires when an inbound session is rejected before
7070+ // classification (e.g. rate-limited). reason is a short stable
7171+ // identifier ("rate_limit", ...) suitable for a Prometheus label.
7272+ RecordRejected(reason string)
5973}
60746175// InboundLogEntry is a single structured record of an accepted inbound
···112126 // Without this, provider authorization emails (Microsoft SNDS,
113127 // Yahoo CFL) and ops-team mail never reach a human.
114128 operatorForwardTo string
129129+130130+ // rateLimiter, when non-nil, enforces a per-source-IP rate limit at
131131+ // MAIL FROM. nil means rate limiting is disabled.
132132+ rateLimiter *inboundRateLimiter
115133}
116134117135// NewInboundServer creates an inbound SMTP server. domainLookup, forwarder,
···130148 domain: cfg.Domain,
131149 onBounce: onBounce,
132150 memberLookup: memberLookup,
151151+ rateLimiter: newInboundRateLimiter(cfg.RateLimitMsgsPerMinute, cfg.RateLimitBurst, 0),
133152 }
134153135154 smtpSrv := smtp.NewServer(s)
···191210192211// Close shuts down the inbound SMTP server.
193212func (s *InboundServer) Close() error {
213213+ if s.rateLimiter != nil {
214214+ s.rateLimiter.Close()
215215+ }
194216 return s.server.Close()
195217}
196218197219// NewSession implements smtp.Backend for the inbound server.
198220func (s *InboundServer) NewSession(c *smtp.Conn) (smtp.Session, error) {
199199- return &inboundSession{server: s}, nil
221221+ var ip string
222222+ if conn := c.Conn(); conn != nil {
223223+ ip = remoteIP(conn.RemoteAddr().String())
224224+ }
225225+ return &inboundSession{server: s, remoteIP: ip}, nil
200226}
201227202228// inboundSession handles a single inbound SMTP connection.
203229type inboundSession struct {
204204- server *InboundServer
205205- from string
230230+ server *InboundServer
231231+ remoteIP string // captured at NewSession; "" if unavailable
232232+ from string
206233 // rcpts holds per-recipient classification so Data() can route each
207234 // recipient to the right handler.
208235 rcpts []inboundRcpt
···232259}
233260234261func (s *inboundSession) Mail(from string, opts *smtp.MailOptions) error {
262262+ if !s.server.rateLimiter.Allow(s.remoteIP) {
263263+ log.Printf("inbound.rate_limited: ip=%s from=%s", s.remoteIP, from)
264264+ if s.server.metrics != nil {
265265+ s.server.metrics.RecordRejected("rate_limit")
266266+ }
267267+ return &smtp.SMTPError{
268268+ Code: 421,
269269+ EnhancedCode: smtp.EnhancedCode{4, 7, 0},
270270+ Message: "rate limit exceeded; please retry later",
271271+ }
272272+ }
235273 s.from = from
236274 return nil
237275}
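Putting the new fields together, modeled on the constructor call the tests below use (and assuming `NewInboundServer` returns `*InboundServer`, as the struct name suggests). `onBounce` and `memberLookup` stand in for whatever the caller already wires; the 30/10 values simply echo the field docs' suggested default.

```go
// Sketch only: enable per-source-IP rate limiting on the inbound listener.
// A zero rate keeps the legacy (unlimited) behavior for existing deployments.
func newRateLimitedInbound(
	onBounce func(ctx context.Context, did, rcpt, btype, details string),
	memberLookup func(ctx context.Context, hash string) (string, bool),
) *InboundServer {
	return NewInboundServer(InboundConfig{
		ListenAddr:             ":25",
		Domain:                 "atmos.email",
		RateLimitMsgsPerMinute: 30, // suggested production default from the field docs
		RateLimitBurst:         10, // token-bucket capacity; zero also defaults to 10
	}, onBounce, memberLookup)
}
```

A source that exceeds its budget is deferred at MAIL FROM with `421 4.7.0 rate limit exceeded; please retry later`, and the rejection is counted through `InboundMetrics.RecordRejected("rate_limit")`, so well-behaved senders retry later instead of bouncing.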
+156
internal/relay/inbound_ratelimit.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "net"
77+ "sync"
88+ "time"
99+1010+ "golang.org/x/time/rate"
1111+)
1212+1313+// inboundRateLimiter enforces a per-source-IP token-bucket rate limit on
1414+// inbound SMTP MAIL FROM commands. It exists to prevent the inbound
1515+// listener (port 25) from being used as an open-relay-shaped amplifier:
1616+// a single attacker IP can otherwise burn through the relay's outbound
1717+// reputation by flooding member forward_to mailboxes or dumping noise
1818+// at the VERP/FBL/postmaster handlers.
1919+//
2020+// The limiter is keyed on the remote IP only (no port). It does not
2121+// distinguish between recipient types — bounces, FBL reports, replies,
2222+// and postmaster mail share the same budget per source. Legitimate
2323+// providers send from many IPs and never approach the per-IP rate; an
2424+// abuse source coming from a single IP burns its budget quickly and
2525+// gets a 421 deferral.
2626+//
2727+// A background goroutine evicts entries that haven't been touched in
2828+// idleTimeout, keeping the map bounded under sustained scanning traffic.
2929+type inboundRateLimiter struct {
3030+ mu sync.Mutex
3131+ buckets map[string]*ipBucket
3232+3333+ rate rate.Limit // tokens per second
3434+ burst int // bucket capacity
3535+ idleTimeout time.Duration // evict entries idle this long
3636+3737+ stop chan struct{}
3838+ stopWG sync.WaitGroup
3939+}
4040+4141+type ipBucket struct {
4242+ limiter *rate.Limiter
4343+ lastSeen time.Time
4444+}
4545+4646+// newInboundRateLimiter constructs a limiter and starts the cleanup
4747+// goroutine. msgsPerMinute <= 0 returns nil — callers must handle the
4848+// nil case as "rate limiting disabled". Negative or zero burst is
4949+// clamped to a sane default.
5050+func newInboundRateLimiter(msgsPerMinute float64, burst int, idleTimeout time.Duration) *inboundRateLimiter {
5151+ if msgsPerMinute <= 0 {
5252+ return nil
5353+ }
5454+ if burst <= 0 {
5555+ burst = 10
5656+ }
5757+ if idleTimeout <= 0 {
5858+ idleTimeout = 10 * time.Minute
5959+ }
6060+ rl := &inboundRateLimiter{
6161+ buckets: make(map[string]*ipBucket),
6262+ rate: rate.Limit(msgsPerMinute / 60.0),
6363+ burst: burst,
6464+ idleTimeout: idleTimeout,
6565+ stop: make(chan struct{}),
6666+ }
6767+ rl.stopWG.Add(1)
6868+ go rl.cleanupLoop()
6969+ return rl
7070+}
7171+7272+// Allow returns true if a message from ip is permitted. The empty string
7373+// is allowed (no rate limit applied) so unit tests and tools that bypass
7474+// network plumbing aren't blocked. Production callers always pass a real
7575+// remote IP because the smtp.Conn carries one.
7676+func (rl *inboundRateLimiter) Allow(ip string) bool {
7777+ if rl == nil || ip == "" {
7878+ return true
7979+ }
8080+ rl.mu.Lock()
8181+ b, ok := rl.buckets[ip]
8282+ if !ok {
8383+ b = &ipBucket{limiter: rate.NewLimiter(rl.rate, rl.burst)}
8484+ rl.buckets[ip] = b
8585+ }
8686+ b.lastSeen = time.Now()
8787+ rl.mu.Unlock()
8888+ return b.limiter.Allow()
8989+}
9090+9191+// Close stops the cleanup goroutine. Safe to call multiple times.
9292+func (rl *inboundRateLimiter) Close() {
9393+ if rl == nil {
9494+ return
9595+ }
9696+ select {
9797+ case <-rl.stop:
9898+ return // already closed
9999+ default:
100100+ close(rl.stop)
101101+ }
102102+ rl.stopWG.Wait()
103103+}
104104+105105+// cleanupLoop evicts buckets that haven't been seen within idleTimeout.
106106+// Runs at idleTimeout/2 cadence so an entry never lingers more than
107107+// 1.5×idleTimeout after its last use.
108108+func (rl *inboundRateLimiter) cleanupLoop() {
109109+ defer rl.stopWG.Done()
110110+ tick := time.NewTicker(rl.idleTimeout / 2)
111111+ defer tick.Stop()
112112+ for {
113113+ select {
114114+ case <-rl.stop:
115115+ return
116116+ case now := <-tick.C:
117117+ rl.evictIdle(now)
118118+ }
119119+ }
120120+}
121121+122122+func (rl *inboundRateLimiter) evictIdle(now time.Time) {
123123+ cutoff := now.Add(-rl.idleTimeout)
124124+ rl.mu.Lock()
125125+ for ip, b := range rl.buckets {
126126+ if b.lastSeen.Before(cutoff) {
127127+ delete(rl.buckets, ip)
128128+ }
129129+ }
130130+ rl.mu.Unlock()
131131+}
132132+133133+// size returns the number of tracked IPs. Used by tests to verify
134134+// cleanup; not exported.
135135+func (rl *inboundRateLimiter) size() int {
136136+ if rl == nil {
137137+ return 0
138138+ }
139139+ rl.mu.Lock()
140140+ defer rl.mu.Unlock()
141141+ return len(rl.buckets)
142142+}
143143+144144+// remoteIP extracts the IP portion of a "host:port" string. Returns the
145145+// input unchanged if it doesn't parse as a host:port (e.g. unix sockets
146146+// in tests). Bracketed IPv6 literals ("[::1]:25") come back without brackets.
147147+func remoteIP(addr string) string {
148148+ if addr == "" {
149149+ return ""
150150+ }
151151+ host, _, err := net.SplitHostPort(addr)
152152+ if err != nil {
153153+ return addr
154154+ }
155155+ return host
156156+}
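To make the msgs-per-minute to `rate.Limit` conversion concrete, a standalone sketch against `golang.org/x/time/rate` showing what a 30/min limit with burst 10 means for a single source. The numbers are the doc comment's suggested default, not values set anywhere in this diff.

```go
package main

import (
	"fmt"

	"golang.org/x/time/rate"
)

func main() {
	// 30 msgs/min divided by 60 = 0.5 tokens/sec, exactly what
	// newInboundRateLimiter computes with rate.Limit(msgsPerMinute / 60.0).
	lim := rate.NewLimiter(rate.Limit(30.0/60.0), 10)

	// A burst of 12 back-to-back MAIL FROMs from one IP: the first 10 drain
	// the bucket, the last 2 are denied (a 421 in the inbound server).
	for i := 1; i <= 12; i++ {
		fmt.Printf("msg %2d allowed=%v\n", i, lim.Allow())
	}
	// After the burst, tokens refill at one every 2 seconds, so the
	// sustained rate settles at 30 messages per minute.
}
```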
+286
internal/relay/inbound_ratelimit_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "context"
77+ "fmt"
88+ "net"
99+ gosmtp "net/smtp"
1010+ "strings"
1111+ "sync"
1212+ "sync/atomic"
1313+ "testing"
1414+ "time"
1515+)
1616+1717+func TestInboundRateLimiter_AllowsWithinBurst(t *testing.T) {
1818+ rl := newInboundRateLimiter(60, 5, time.Minute)
1919+ defer rl.Close()
2020+ for i := 0; i < 5; i++ {
2121+ if !rl.Allow("1.2.3.4") {
2222+ t.Fatalf("call %d denied within burst", i+1)
2323+ }
2424+ }
2525+}
2626+2727+func TestInboundRateLimiter_BlocksOnceBurstExhausted(t *testing.T) {
2828+ // Very low refill so the burst is the only budget within the test window.
2929+ rl := newInboundRateLimiter(1, 3, time.Minute)
3030+ defer rl.Close()
3131+ for i := 0; i < 3; i++ {
3232+ if !rl.Allow("1.2.3.4") {
3333+ t.Fatalf("call %d denied within burst", i+1)
3434+ }
3535+ }
3636+ if rl.Allow("1.2.3.4") {
3737+ t.Fatal("4th call allowed; expected rate-limit denial")
3838+ }
3939+}
4040+4141+func TestInboundRateLimiter_PerIPIsolation(t *testing.T) {
4242+ rl := newInboundRateLimiter(1, 2, time.Minute)
4343+ defer rl.Close()
4444+ // Exhaust IP A's burst.
4545+ rl.Allow("1.1.1.1")
4646+ rl.Allow("1.1.1.1")
4747+ if rl.Allow("1.1.1.1") {
4848+ t.Fatal("IP A still allowed after burst")
4949+ }
5050+ // IP B has its own bucket.
5151+ if !rl.Allow("2.2.2.2") {
5252+ t.Fatal("IP B denied — buckets leaked across IPs")
5353+ }
5454+}
5555+5656+func TestInboundRateLimiter_NilIsAlwaysAllow(t *testing.T) {
5757+ var rl *inboundRateLimiter
5858+ if !rl.Allow("anything") {
5959+ t.Fatal("nil limiter should allow")
6060+ }
6161+ // Close on nil must not panic.
6262+ rl.Close()
6363+}
6464+6565+func TestInboundRateLimiter_ZeroDisables(t *testing.T) {
6666+ if rl := newInboundRateLimiter(0, 10, time.Minute); rl != nil {
6767+ t.Fatal("zero rate should return nil limiter")
6868+ }
6969+ if rl := newInboundRateLimiter(-5, 10, time.Minute); rl != nil {
7070+ t.Fatal("negative rate should return nil limiter")
7171+ }
7272+}
7373+7474+func TestInboundRateLimiter_EmptyIPAlwaysAllowed(t *testing.T) {
7575+ rl := newInboundRateLimiter(60, 1, time.Minute)
7676+ defer rl.Close()
7777+ // Burst is 1; if "" mapped to a bucket it'd run out fast. It shouldn't.
7878+ for i := 0; i < 100; i++ {
7979+ if !rl.Allow("") {
8080+ t.Fatalf("empty IP denied at call %d", i)
8181+ }
8282+ }
8383+}
8484+8585+func TestInboundRateLimiter_EvictsIdleEntries(t *testing.T) {
8686+ rl := newInboundRateLimiter(60, 5, 50*time.Millisecond)
8787+ defer rl.Close()
8888+ rl.Allow("9.9.9.9")
8989+ if rl.size() != 1 {
9090+ t.Fatalf("size = %d, want 1", rl.size())
9191+ }
9292+ // Force an eviction with a synthetic future time.
9393+ rl.evictIdle(time.Now().Add(time.Hour))
9494+ if rl.size() != 0 {
9595+ t.Fatalf("size after eviction = %d, want 0", rl.size())
9696+ }
9797+}
9898+9999+func TestRemoteIP_StripsPort(t *testing.T) {
100100+ cases := []struct {
101101+ in, want string
102102+ }{
103103+ {"1.2.3.4:567", "1.2.3.4"},
104104+ {"[::1]:567", "::1"},
105105+ {"[2001:db8::1]:25", "2001:db8::1"},
106106+ {"127.0.0.1:0", "127.0.0.1"},
107107+ {"", ""},
108108+ {"not-a-host-port", "not-a-host-port"},
109109+ }
110110+ for _, c := range cases {
111111+ if got := remoteIP(c.in); got != c.want {
112112+ t.Errorf("remoteIP(%q) = %q, want %q", c.in, got, c.want)
113113+ }
114114+ }
115115+}
116116+117117+// recordingMetrics captures InboundMetrics calls for tests that need to
118118+// observe rate-limit rejections without exposing internal state.
119119+type recordingMetrics struct {
120120+ mu sync.Mutex
121121+ rejected map[string]int
122122+ inbound map[string]int
123123+ forward map[string]int
124124+}
125125+126126+func newRecordingMetrics() *recordingMetrics {
127127+ return &recordingMetrics{
128128+ rejected: make(map[string]int),
129129+ inbound: make(map[string]int),
130130+ forward: make(map[string]int),
131131+ }
132132+}
133133+134134+func (m *recordingMetrics) RecordInbound(c string) {
135135+ m.mu.Lock()
136136+ m.inbound[c]++
137137+ m.mu.Unlock()
138138+}
139139+func (m *recordingMetrics) RecordForward(s string) {
140140+ m.mu.Lock()
141141+ m.forward[s]++
142142+ m.mu.Unlock()
143143+}
144144+func (m *recordingMetrics) RecordRejected(r string) {
145145+ m.mu.Lock()
146146+ m.rejected[r]++
147147+ m.mu.Unlock()
148148+}
149149+func (m *recordingMetrics) rejectedCount(reason string) int {
150150+ m.mu.Lock()
151151+ defer m.mu.Unlock()
152152+ return m.rejected[reason]
153153+}
154154+155155+// TestInbound_RateLimitRejects421OverWire wires the limiter into a real
156156+// SMTP server bound to 127.0.0.1, sends N+1 messages over the burst, and
157157+// asserts the (N+1)th gets a 421 with the expected enhanced code.
158158+func TestInbound_RateLimitRejects421OverWire(t *testing.T) {
159159+ ln, err := net.Listen("tcp", "127.0.0.1:0")
160160+ if err != nil {
161161+ t.Fatalf("listen: %v", err)
162162+ }
163163+ addr := ln.Addr().String()
164164+ ln.Close()
165165+166166+ memberHash := MemberHashFromDID("did:plc:testmember")
167167+ memberLookup := func(ctx context.Context, hash string) (string, bool) {
168168+ if hash == memberHash {
169169+ return "did:plc:testmember", true
170170+ }
171171+ return "", false
172172+ }
173173+174174+ srv := NewInboundServer(InboundConfig{
175175+ ListenAddr: addr,
176176+ Domain: "atmos.email",
177177+ // Tight rate so the test runs fast: 60/min = 1/sec, burst=2.
 178178+ // Three messages in immediate succession exhaust the bucket
179179+ // before the next token is minted.
180180+ RateLimitMsgsPerMinute: 60,
181181+ RateLimitBurst: 2,
182182+ }, func(ctx context.Context, did, rcpt, btype, details string) {}, memberLookup)
183183+ metrics := newRecordingMetrics()
184184+ srv.SetMetrics(metrics)
185185+186186+ go srv.ListenAndServe()
187187+ defer srv.Close()
188188+189189+ for i := 0; i < 50; i++ {
190190+ c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
191191+ if err == nil {
192192+ c.Close()
193193+ break
194194+ }
195195+ time.Sleep(10 * time.Millisecond)
196196+ }
197197+198198+ rcptHash := RecipientHashFromAddr("user@example.com")
199199+ verp := fmt.Sprintf("bounces+%s+%s@atmos.email", memberHash, rcptHash)
200200+ body := []byte("From: a@b\r\nTo: " + verp + "\r\n\r\nbody\r\n")
201201+202202+ // Each call opens a fresh TCP session; same client IP (127.0.0.1).
203203+ // The first two should succeed, the third should get 421.
204204+ send := func() error {
205205+ return gosmtp.SendMail(addr, nil, "a@b.example", []string{verp}, body)
206206+ }
207207+208208+ if err := send(); err != nil {
209209+ t.Fatalf("call 1: %v", err)
210210+ }
211211+ if err := send(); err != nil {
212212+ t.Fatalf("call 2: %v", err)
213213+ }
214214+ err = send()
215215+ if err == nil {
216216+ t.Fatal("call 3: expected rate-limit error, got success")
217217+ }
218218+ // net/smtp surfaces the server reply verbatim in err.Error();
219219+ // substring-check the 421 code rather than asserting on a private type.
220220+ if !strings.Contains(err.Error(), "421") {
221221+ t.Errorf("err = %v, want to contain 421", err)
222222+ }
223223+ // Metric must have been incremented for the rejection.
224224+ if got := metrics.rejectedCount("rate_limit"); got != 1 {
225225+ t.Errorf("rejected[rate_limit] = %d, want 1", got)
226226+ }
227227+}
228228+229229+// TestInbound_RateLimitDisabledByDefault verifies legacy behavior: when
230230+// RateLimitMsgsPerMinute is 0 (default), no rejections occur regardless
231231+// of burst.
232232+func TestInbound_RateLimitDisabledByDefault(t *testing.T) {
233233+ ln, err := net.Listen("tcp", "127.0.0.1:0")
234234+ if err != nil {
235235+ t.Fatalf("listen: %v", err)
236236+ }
237237+ addr := ln.Addr().String()
238238+ ln.Close()
239239+240240+ memberHash := MemberHashFromDID("did:plc:testmember")
241241+ memberLookup := func(ctx context.Context, hash string) (string, bool) {
242242+ if hash == memberHash {
243243+ return "did:plc:testmember", true
244244+ }
245245+ return "", false
246246+ }
247247+ var bounces atomic.Int32
248248+ onBounce := func(ctx context.Context, did, rcpt, btype, details string) {
249249+ bounces.Add(1)
250250+ }
251251+252252+ srv := NewInboundServer(InboundConfig{
253253+ ListenAddr: addr,
254254+ Domain: "atmos.email",
255255+ // Rate-limit explicitly NOT set → disabled.
256256+ }, onBounce, memberLookup)
257257+ go srv.ListenAndServe()
258258+ defer srv.Close()
259259+260260+ for i := 0; i < 50; i++ {
261261+ c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
262262+ if err == nil {
263263+ c.Close()
264264+ break
265265+ }
266266+ time.Sleep(10 * time.Millisecond)
267267+ }
268268+269269+ rcptHash := RecipientHashFromAddr("user@example.com")
270270+ verp := fmt.Sprintf("bounces+%s+%s@atmos.email", memberHash, rcptHash)
271271+ dsn := "Content-Type: multipart/report; report-type=delivery-status; boundary=\"b1\"\r\n" +
272272+ "From: mailer-daemon@example.com\r\n" +
273273+ "To: " + verp + "\r\n" +
274274+ "\r\n" +
275275+ "--b1\r\n" +
276276+ "Content-Type: text/plain\r\n\r\n" +
277277+ "failed\r\n--b1\r\n" +
278278+ "Content-Type: message/delivery-status\r\n\r\n" +
279279+ "Final-Recipient: rfc822; user@example.com\r\nAction: failed\r\nStatus: 5.1.1\r\n\r\n--b1--\r\n"
280280+281281+ for i := 0; i < 10; i++ {
282282+ if err := gosmtp.SendMail(addr, nil, "mailer-daemon@example.com", []string{verp}, []byte(dsn)); err != nil {
283283+ t.Fatalf("send %d: %v", i+1, err)
284284+ }
285285+ }
286286+}
+3-3
internal/relay/labelcheck.go
···111111 verified, err := lc.queryLabeler(ctx, did)
112112 if err != nil {
113113 // Fail-closed: if labeler is unreachable and cache is expired, return error
114114- return false, fmt.Errorf("label check failed (fail-closed): %v", err)
114114+ return false, fmt.Errorf("label check failed (fail-closed): %w", err)
115115 }
116116117117 // Update cache
···214214215215 resp, err := lc.client.Do(req)
216216 if err != nil {
217217- return false, fmt.Errorf("labeler request: %v", err)
217217+ return false, fmt.Errorf("labeler request: %w", err)
218218 }
219219 defer resp.Body.Close()
220220···224224225225 var result queryLabelsResponse
226226 if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&result); err != nil {
227227- return false, fmt.Errorf("decode labeler response: %v", err)
227227+ return false, fmt.Errorf("decode labeler response: %w", err)
228228 }
229229230230 // Check that all required labels are present and not negated
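The `%v` to `%w` switch is behavioral, not cosmetic: wrapped errors stay inspectable with `errors.Is`/`errors.As` further up the call stack. A minimal illustration follows; the sentinel and call shape are made up for the example, not taken from labelcheck.go.

```go
package main

import (
	"errors"
	"fmt"
)

var errUnreachable = errors.New("labeler unreachable") // illustrative sentinel

func check() error {
	// With %w the sentinel survives wrapping; with %v it would be
	// flattened into plain text and errors.Is would return false.
	return fmt.Errorf("label check failed (fail-closed): %w", errUnreachable)
}

func main() {
	err := check()
	fmt.Println(errors.Is(err, errUnreachable)) // true with %w, false with %v
}
```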
+239
internal/relay/memberhash_cache.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "context"
77+ "sync"
88+ "time"
99+)
1010+1111+// MemberHashCache answers VERP "is this hash a member?" queries from a
1212+// process-local cache. The previous implementation rebuilt the cache from the
1313+// full members table on every miss, so a sender pumping random VERP local
1414+// parts at port 25 could trigger an O(N) full-table scan per inbound message
1515+// and DoS the relay. See #218.
1616+//
1717+// This cache adds two defenses:
1818+//
1919+// 1. Negative cache. A hash that resolved to "no such member" is
2020+// remembered as missing for negTTL (default 5 min), so repeat misses
2121+// for the same fake hash become O(1).
2222+// 2. Rebuild rate limit. The positive cache only rebuilds every
2323+// rebuildInterval (default 30 s). Repeated misses for *different* fake
2424+// hashes can no longer trigger a stampede of full-table scans. New
2525+// enrollments are also picked up by a periodic background rebuild
2626+// (PeriodicRebuild) so the rebuild-on-miss path is no longer the only
2727+// freshness mechanism.
2828+type MemberHashCache struct {
2929+ mu sync.RWMutex
3030+ positive map[string]string // hash → DID
3131+ negative map[string]time.Time // hash → expiry
3232+ lastRebuild time.Time
3333+3434+ rebuildInterval time.Duration
3535+ negTTL time.Duration
3636+ maxNeg int
3737+3838+ rebuild func() (map[string]string, error)
3939+ now func() time.Time
4040+ metrics MemberHashMetrics
4141+}
4242+4343+// MemberHashMetrics is the narrow metrics surface used by MemberHashCache.
4444+// Implementations record counts to Prometheus; nil-safe in tests.
4545+type MemberHashMetrics interface {
4646+ IncMemberHashHit() // positive cache hit
4747+ IncMemberHashNegHit() // negative cache hit (DoS short-circuit)
4848+ IncMemberHashMiss() // confirmed miss after rebuild
4949+ IncMemberHashRebuild() // a rebuild ran
5050+ IncMemberHashRebuildSkip() // rebuild rate-limited
5151+ SetMemberHashSize(positive, negative int)
5252+}
5353+5454+// MemberHashCacheConfig configures a MemberHashCache.
5555+type MemberHashCacheConfig struct {
5656+ // Rebuild loads the current positive cache from the source of truth.
5757+ // Required.
5858+ Rebuild func() (map[string]string, error)
5959+ // RebuildInterval is the minimum gap between successive rebuilds. Default 30s.
6060+ RebuildInterval time.Duration
6161+ // NegTTL is how long a hash stays in the negative cache. Default 5min.
6262+ NegTTL time.Duration
6363+ // MaxNegative caps the negative-cache size. Default 10000.
6464+ MaxNegative int
6565+ // Now overrides time.Now (for tests). Default time.Now.
6666+ Now func() time.Time
6767+ // Metrics receives counters/gauges. nil → no-op.
6868+ Metrics MemberHashMetrics
6969+}
7070+7171+// NewMemberHashCache builds a lookup from cfg. It performs the initial
7272+// rebuild synchronously so the cache is warm before the first request.
7373+func NewMemberHashCache(cfg MemberHashCacheConfig) *MemberHashCache {
7474+ if cfg.Rebuild == nil {
7575+ panic("MemberHashCache: Rebuild is required")
7676+ }
7777+ if cfg.RebuildInterval == 0 {
7878+ cfg.RebuildInterval = 30 * time.Second
7979+ }
8080+ if cfg.NegTTL == 0 {
8181+ cfg.NegTTL = 5 * time.Minute
8282+ }
8383+ if cfg.MaxNegative == 0 {
8484+ cfg.MaxNegative = 10000
8585+ }
8686+ if cfg.Now == nil {
8787+ cfg.Now = time.Now
8888+ }
8989+ if cfg.Metrics == nil {
9090+ cfg.Metrics = noopMemberHashMetrics{}
9191+ }
9292+ h := &MemberHashCache{
9393+ positive: map[string]string{},
9494+ negative: map[string]time.Time{},
9595+ rebuildInterval: cfg.RebuildInterval,
9696+ negTTL: cfg.NegTTL,
9797+ maxNeg: cfg.MaxNegative,
9898+ rebuild: cfg.Rebuild,
9999+ now: cfg.Now,
100100+ metrics: cfg.Metrics,
101101+ }
102102+ // Initial warm-up — block until the first load completes so we don't
103103+ // serve traffic with an empty positive cache.
104104+ h.runRebuild(true)
105105+ return h
106106+}
107107+108108+// Lookup returns (DID, true) for known members, ("", false) otherwise.
109109+// Negative-cached misses short-circuit without touching the store.
110110+func (h *MemberHashCache) Lookup(hash string) (string, bool) {
111111+ now := h.now()
112112+113113+ h.mu.RLock()
114114+ if did, ok := h.positive[hash]; ok {
115115+ h.mu.RUnlock()
116116+ h.metrics.IncMemberHashHit()
117117+ return did, true
118118+ }
119119+ if exp, ok := h.negative[hash]; ok && now.Before(exp) {
120120+ h.mu.RUnlock()
121121+ h.metrics.IncMemberHashNegHit()
122122+ return "", false
123123+ }
124124+ mayRebuild := now.Sub(h.lastRebuild) >= h.rebuildInterval
125125+ h.mu.RUnlock()
126126+127127+ if mayRebuild {
128128+ h.runRebuild(false)
129129+ h.mu.RLock()
130130+ if did, ok := h.positive[hash]; ok {
131131+ h.mu.RUnlock()
132132+ h.metrics.IncMemberHashHit()
133133+ return did, true
134134+ }
135135+ h.mu.RUnlock()
136136+ }
137137+138138+ h.recordMiss(hash, now)
139139+ h.metrics.IncMemberHashMiss()
140140+ return "", false
141141+}
142142+143143+// runRebuild reloads the positive cache. force=true bypasses the interval
144144+// gate (used at construction). If the interval has not yet elapsed, the call is skipped.
145145+func (h *MemberHashCache) runRebuild(force bool) {
146146+ h.mu.Lock()
147147+ if !force && h.now().Sub(h.lastRebuild) < h.rebuildInterval {
148148+ h.mu.Unlock()
149149+ h.metrics.IncMemberHashRebuildSkip()
150150+ return
151151+ }
152152+ h.lastRebuild = h.now()
153153+ h.mu.Unlock()
154154+155155+ newMap, err := h.rebuild()
156156+ if err != nil {
157157+ // Keep the old positive map; a transient store error shouldn't
158158+ // blow away cached members. The next interval will retry.
159159+ return
160160+ }
161161+162162+ h.mu.Lock()
163163+ h.positive = newMap
164164+ // Drop negative entries that are now positive — happens when a member
165165+ // enrolls between our last rebuild and now.
166166+ for hash := range newMap {
167167+ delete(h.negative, hash)
168168+ }
169169+ posLen, negLen := len(h.positive), len(h.negative)
170170+ h.mu.Unlock()
171171+172172+ h.metrics.IncMemberHashRebuild()
173173+ h.metrics.SetMemberHashSize(posLen, negLen)
174174+}
175175+176176+// recordMiss inserts a negative-cache entry, evicting if at capacity.
177177+func (h *MemberHashCache) recordMiss(hash string, now time.Time) {
178178+ h.mu.Lock()
179179+ defer h.mu.Unlock()
180180+181181+ if len(h.negative) >= h.maxNeg {
182182+ // First sweep: drop expired entries.
183183+ for k, exp := range h.negative {
184184+ if !exp.After(now) {
185185+ delete(h.negative, k)
186186+ }
187187+ }
188188+ // Still full? Drop ~10% via Go's randomized map iteration. Not a
189189+ // perfect LRU but the negative cache is purely an optimization —
190190+ // any eviction simply means the next miss for that hash takes the
191191+ // rebuild-rate-limited slow path.
192192+ if len(h.negative) >= h.maxNeg {
193193+ toDrop := h.maxNeg / 10
194194+ if toDrop < 1 {
195195+ toDrop = 1
196196+ }
197197+ for k := range h.negative {
198198+ delete(h.negative, k)
199199+ toDrop--
200200+ if toDrop <= 0 {
201201+ break
202202+ }
203203+ }
204204+ }
205205+ }
206206+ h.negative[hash] = now.Add(h.negTTL)
207207+ h.metrics.SetMemberHashSize(len(h.positive), len(h.negative))
208208+}
209209+210210+// PeriodicRebuild runs in a goroutine and rebuilds the positive cache on
211211+// each tick, picking up newly enrolled members without waiting for a miss.
212212+func (h *MemberHashCache) PeriodicRebuild(ctx context.Context, interval time.Duration) {
213213+ ticker := time.NewTicker(interval)
214214+ defer ticker.Stop()
215215+ for {
216216+ select {
217217+ case <-ctx.Done():
218218+ return
219219+ case <-ticker.C:
220220+ h.runRebuild(true)
221221+ }
222222+ }
223223+}
224224+225225+// Sizes returns (positive, negative) counts. Test/debug helper.
226226+func (h *MemberHashCache) Sizes() (positive, negative int) {
227227+ h.mu.RLock()
228228+ defer h.mu.RUnlock()
229229+ return len(h.positive), len(h.negative)
230230+}
231231+232232+type noopMemberHashMetrics struct{}
233233+234234+func (noopMemberHashMetrics) IncMemberHashHit() {}
235235+func (noopMemberHashMetrics) IncMemberHashNegHit() {}
236236+func (noopMemberHashMetrics) IncMemberHashMiss() {}
237237+func (noopMemberHashMetrics) IncMemberHashRebuild() {}
238238+func (noopMemberHashMetrics) IncMemberHashRebuildSkip() {}
239239+func (noopMemberHashMetrics) SetMemberHashSize(_ int, _ int) {}
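A wiring sketch tying the cache to the inbound server's memberLookup and to GoSafe for the background refresh. The store call behind `Rebuild`, the five-minute cadence, and the adapter function are assumptions about cmd/relay wiring, not code in this diff.

```go
// Sketch only: build the cache once at startup, keep it fresh in the
// background, and adapt Lookup to the (ctx, hash) signature the inbound
// server's memberLookup expects.
func wireMemberLookup(
	ctx context.Context,
	m *Metrics, // *Metrics implements MemberHashMetrics
	loadAll func() (map[string]string, error), // e.g. a full members-table read
) func(context.Context, string) (string, bool) {
	cache := NewMemberHashCache(MemberHashCacheConfig{
		Rebuild: loadAll, // the only code path that still scans the table
		Metrics: m,
	})

	GoSafe("memberhash.rebuild", func() {
		cache.PeriodicRebuild(ctx, 5*time.Minute) // cadence is illustrative
	})

	return func(_ context.Context, hash string) (string, bool) {
		return cache.Lookup(hash)
	}
}
```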
+255
internal/relay/memberhash_cache_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "errors"
77+ "sync"
88+ "sync/atomic"
99+ "testing"
1010+ "time"
1111+)
1212+1313+type fakeClock struct {
1414+ mu sync.Mutex
1515+ t time.Time
1616+}
1717+1818+func (c *fakeClock) Now() time.Time { c.mu.Lock(); defer c.mu.Unlock(); return c.t }
1919+func (c *fakeClock) Advance(d time.Duration) {
2020+ c.mu.Lock()
2121+ defer c.mu.Unlock()
2222+ c.t = c.t.Add(d)
2323+}
2424+2525+type memberHashCacheMetrics struct {
2626+ hit, neg, miss, rebuild, rebuildSkip atomic.Int64
2727+}
2828+2929+func (m *memberHashCacheMetrics) IncMemberHashHit() { m.hit.Add(1) }
3030+func (m *memberHashCacheMetrics) IncMemberHashNegHit() { m.neg.Add(1) }
3131+func (m *memberHashCacheMetrics) IncMemberHashMiss() { m.miss.Add(1) }
3232+func (m *memberHashCacheMetrics) IncMemberHashRebuild() { m.rebuild.Add(1) }
3333+func (m *memberHashCacheMetrics) IncMemberHashRebuildSkip(){ m.rebuildSkip.Add(1) }
3434+func (m *memberHashCacheMetrics) SetMemberHashSize(_ int, _ int) {}
3535+3636+func newCacheForTest(t *testing.T, members map[string]string, clock *fakeClock) (*MemberHashCache, *atomic.Int64, *memberHashCacheMetrics) {
3737+ t.Helper()
3838+ var rebuildCalls atomic.Int64
3939+ mu := sync.Mutex{}
4040+ current := members
4141+ rebuild := func() (map[string]string, error) {
4242+ rebuildCalls.Add(1)
4343+ mu.Lock()
4444+ defer mu.Unlock()
4545+ out := make(map[string]string, len(current))
4646+ for k, v := range current {
4747+ out[k] = v
4848+ }
4949+ return out, nil
5050+ }
5151+ mx := &memberHashCacheMetrics{}
5252+ h := NewMemberHashCache(MemberHashCacheConfig{
5353+ Rebuild: rebuild,
5454+ RebuildInterval: 30 * time.Second,
5555+ NegTTL: 5 * time.Minute,
5656+ MaxNegative: 100,
5757+ Now: clock.Now,
5858+ Metrics: mx,
5959+ })
6060+ return h, &rebuildCalls, mx
6161+}
6262+6363+func TestMemberHashCache_Hit(t *testing.T) {
6464+ clock := &fakeClock{t: time.Unix(1_700_000_000, 0)}
6565+ h, _, mx := newCacheForTest(t, map[string]string{"hash-a": "did:plc:aaa"}, clock)
6666+6767+ did, ok := h.Lookup("hash-a")
6868+ if !ok || did != "did:plc:aaa" {
6969+ t.Fatalf("got (%q,%v), want (did:plc:aaa,true)", did, ok)
7070+ }
7171+ if mx.hit.Load() != 1 {
7272+ t.Errorf("hit counter=%d, want 1", mx.hit.Load())
7373+ }
7474+}
7575+7676+func TestMemberHashCache_NegativeCacheShortCircuits(t *testing.T) {
7777+ // The DoS regression fix: a fake hash should NOT trigger a full-table
7878+ // rebuild on every subsequent miss within the negative TTL.
7979+ clock := &fakeClock{t: time.Unix(1_700_000_000, 0)}
8080+ h, rebuildCalls, mx := newCacheForTest(t, map[string]string{}, clock)
8181+8282+ initialRebuilds := rebuildCalls.Load() // 1 from constructor
8383+8484+ // First miss: triggers exactly one rebuild. The constructor's forced
8585+ // rebuild set lastRebuild to construction time, so the clock is
8686+ // advanced past the interval first to let the miss-path rebuild run.
8787+ clock.Advance(31 * time.Second)
8888+ if _, ok := h.Lookup("attacker-fake-1"); ok {
8989+ t.Fatal("unexpected hit for fake hash")
9090+ }
9191+ if mx.miss.Load() != 1 {
9292+ t.Errorf("miss counter=%d, want 1", mx.miss.Load())
9393+ }
9494+ if rebuildCalls.Load() != initialRebuilds+1 {
9595+ t.Errorf("rebuild calls=%d, want %d (one rebuild after interval)", rebuildCalls.Load(), initialRebuilds+1)
9696+ }
9797+9898+ // 1000 follow-up lookups for the SAME fake hash within the negTTL must
9999+ // short-circuit on the negative cache.
100100+ for i := 0; i < 1000; i++ {
101101+ if _, ok := h.Lookup("attacker-fake-1"); ok {
102102+ t.Fatal("unexpected hit for fake hash on repeat lookup")
103103+ }
104104+ }
105105+ if mx.neg.Load() != 1000 {
106106+ t.Errorf("negative-cache hits=%d, want 1000", mx.neg.Load())
107107+ }
108108+ if got := rebuildCalls.Load(); got != initialRebuilds+1 {
109109+ t.Errorf("rebuild fired during negative-cached lookups: got=%d, want %d", got, initialRebuilds+1)
110110+ }
111111+}
112112+113113+func TestMemberHashCache_RebuildIsRateLimited(t *testing.T) {
114114+ // 100 distinct fake hashes within a single 30s window should produce at
115115+ // most 1 rebuild — the rate limit prevents stampede.
116116+ clock := &fakeClock{t: time.Unix(1_700_000_000, 0)}
117117+ h, rebuildCalls, _ := newCacheForTest(t, map[string]string{}, clock)
118118+ initial := rebuildCalls.Load()
119119+120120+ // First lookup post-construction has lastRebuild = now, so within
121121+ // interval — rebuild should be skipped on the first call too.
122122+ for i := 0; i < 100; i++ {
123123+ hash := "fake-" + string(rune('A'+i%26)) + string(rune('0'+i/26))
124124+ h.Lookup(hash)
125125+ }
126126+127127+ rebuilds := rebuildCalls.Load() - initial
128128+ if rebuilds > 1 {
129129+ t.Errorf("rebuild fired %d times within interval, want ≤1", rebuilds)
130130+ }
131131+}
132132+133133+func TestMemberHashCache_NegativeTTLExpiresAndRebuildAdmitsNewMember(t *testing.T) {
134134+ clock := &fakeClock{t: time.Unix(1_700_000_000, 0)}
135135+ members := map[string]string{}
136136+ mu := sync.Mutex{}
137137+ rebuild := func() (map[string]string, error) {
138138+ mu.Lock()
139139+ defer mu.Unlock()
140140+ out := map[string]string{}
141141+ for k, v := range members {
142142+ out[k] = v
143143+ }
144144+ return out, nil
145145+ }
146146+ h := NewMemberHashCache(MemberHashCacheConfig{
147147+ Rebuild: rebuild,
148148+ RebuildInterval: 30 * time.Second,
149149+ NegTTL: 5 * time.Minute,
150150+ MaxNegative: 100,
151151+ Now: clock.Now,
152152+ })
153153+154154+ // 1) Member doesn't exist yet — first lookup misses + caches negatively.
155155+ clock.Advance(31 * time.Second)
156156+ if _, ok := h.Lookup("hash-late"); ok {
157157+ t.Fatal("unexpected hit before enrollment")
158158+ }
159159+160160+ // 2) Member enrolls (in the source) and the negative cache is still hot.
161161+ mu.Lock()
162162+ members["hash-late"] = "did:plc:late"
163163+ mu.Unlock()
164164+165165+ // Within negTTL but past rebuildInterval: lookup should hit the
166166+ // negative cache and NOT see the new member yet.
167167+ clock.Advance(31 * time.Second)
168168+ if _, ok := h.Lookup("hash-late"); ok {
169169+ t.Error("negative cache failed to short-circuit during TTL")
170170+ }
171171+172172+ // 3) After negTTL expires AND rebuild interval has passed, the lookup
173173+ // should rebuild and admit the new member.
174174+ clock.Advance(6 * time.Minute) // past 5min negTTL
175175+ did, ok := h.Lookup("hash-late")
176176+ if !ok || did != "did:plc:late" {
177177+ t.Errorf("late enrollment not picked up: did=%q ok=%v", did, ok)
178178+ }
179179+}
180180+181181+func TestMemberHashCache_RebuildErrorPreservesPositiveCache(t *testing.T) {
182182+ clock := &fakeClock{t: time.Unix(1_700_000_000, 0)}
183183+ rebuild := func() (map[string]string, error) {
184184+ return nil, errors.New("transient db error")
185185+ }
186186+ // Construct with a working rebuild first, so we have something cached.
187187+ cfg := MemberHashCacheConfig{
188188+ Rebuild: func() (map[string]string, error) {
189189+ return map[string]string{"hash-a": "did:plc:aaa"}, nil
190190+ },
191191+ RebuildInterval: 30 * time.Second,
192192+ NegTTL: 5 * time.Minute,
193193+ MaxNegative: 100,
194194+ Now: clock.Now,
195195+ }
196196+ h := NewMemberHashCache(cfg)
197197+198198+ // Swap in a failing rebuild and force it to run.
199199+ h.rebuild = rebuild
200200+ clock.Advance(31 * time.Second)
201201+ h.runRebuild(true)
202202+203203+ // Positive cache must still work despite the rebuild error.
204204+ did, ok := h.Lookup("hash-a")
205205+ if !ok || did != "did:plc:aaa" {
206206+ t.Errorf("positive cache lost on rebuild error: did=%q ok=%v", did, ok)
207207+ }
208208+}
209209+210210+func TestMemberHashCache_NegativeCapEvictsExpiredFirst(t *testing.T) {
211211+ clock := &fakeClock{t: time.Unix(1_700_000_000, 0)}
212212+ rebuild := func() (map[string]string, error) { return map[string]string{}, nil }
213213+ h := NewMemberHashCache(MemberHashCacheConfig{
214214+ Rebuild: rebuild,
215215+ RebuildInterval: 30 * time.Second,
216216+ NegTTL: 1 * time.Minute,
217217+ MaxNegative: 10,
218218+ Now: clock.Now,
219219+ })
220220+221221+ // Fill negative cache with 10 entries.
222222+ clock.Advance(31 * time.Second)
223223+ for i := 0; i < 10; i++ {
224224+ h.Lookup("fake-" + string(rune('A'+i)))
225225+ }
226226+ _, neg := h.Sizes()
227227+ if neg != 10 {
228228+ t.Fatalf("negative size=%d after fill, want 10", neg)
229229+ }
230230+231231+ // All entries expire.
232232+ clock.Advance(2 * time.Minute)
233233+234234+ // Inserting one more should sweep expired and end up well below cap.
235235+ h.Lookup("fake-NEW")
236236+ _, neg = h.Sizes()
237237+ if neg > 1 {
238238+ t.Errorf("expired entries not swept: negative size=%d, want 1", neg)
239239+ }
240240+}
241241+242242+func TestMemberHashCache_HitDoesNotTriggerRebuild(t *testing.T) {
243243+ clock := &fakeClock{t: time.Unix(1_700_000_000, 0)}
244244+ h, rebuildCalls, _ := newCacheForTest(t, map[string]string{"hash-a": "did:plc:aaa"}, clock)
245245+ initial := rebuildCalls.Load()
246246+247247+ // 1000 hits should never re-rebuild.
248248+ clock.Advance(10 * time.Minute)
249249+ for i := 0; i < 1000; i++ {
250250+ h.Lookup("hash-a")
251251+ }
252252+ if rebuildCalls.Load() != initial {
253253+ t.Errorf("rebuild fired during pure hit traffic: %d→%d", initial, rebuildCalls.Load())
254254+ }
255255+}
+230-4
internal/relay/metrics.go
···2121 BouncesTotal *prometheus.CounterVec // type: hard, soft
2222 AuthAttempts *prometheus.CounterVec // result: success, failure
2323 RateLimitHits *prometheus.CounterVec // limit_type: hourly, daily, global
2424+ OrphanDeliveries *prometheus.CounterVec // status: sent, bounced — delivery callbacks for missing DB rows (#208)
2525+ OrphanReconciled prometheus.Counter // status=queued rows the janitor closed because no spool file exists (#208)
2626+ GoroutineCrashes *prometheus.CounterVec // name — recovered panics in background goroutines (#209)
2727+2828+ // Multi-recipient SMTP DATA outcomes (#226). When a single DATA fans out
2929+ // to N recipients and a subset fail to enqueue, we accept the DATA (250)
3030+ // to avoid duplicating the successful recipients on client retry, and
3131+ // instead surface the failures here.
3232+ PartialDeliveries prometheus.Counter // DATA accepted with at least one recipient failed
3333+ PartialDeliveryRecipients *prometheus.CounterVec // outcome: succeeded, failed — per-recipient counts inside a partial-delivery DATA
3434+3535+ // Member-hash cache (#218). Negative cache + rebuild rate-limit defend
3636+ // against random-VERP DoS at port 25.
3737+ MemberHashLookups *prometheus.CounterVec // outcome: hit, neg_hit, miss
3838+ MemberHashRebuilds *prometheus.CounterVec // outcome: ran, skipped
3939+ MemberHashCacheSize *prometheus.GaugeVec // kind: positive, negative
24402541 // HTTP request tracking
2642 HTTPRequestsTotal *prometheus.CounterVec // host, method, path, status
···3551 LabelerReachable prometheus.Gauge
3652 OspreyReachable prometheus.Gauge
37535454+ // SQLite connection-pool observability (#210). Gauges sampled
5555+ // from sql.DB.Stats() periodically; counters incremented when a
5656+ // returned error matches the SQLITE_BUSY/locked signature.
5757+ SQLiteOpenConnections prometheus.Gauge
5858+ SQLiteInUse prometheus.Gauge
5959+ SQLiteIdle prometheus.Gauge
6060+ SQLiteWaitCount prometheus.Gauge // cumulative since process start
6161+ SQLiteWaitDurationSec prometheus.Gauge // cumulative seconds since process start
6262+ SQLiteBusyErrors *prometheus.CounterVec // op: insert, update, query, exec — best-effort classification at hot writers
6363+3864 // Osprey enforcement counters
3965 OspreyChecksTotal *prometheus.CounterVec // result: allowed, blocked
40664167 // Osprey event emission counters
4242- OspreyEventsEmitted *prometheus.CounterVec // event_type
4343- OspreyEventsFailed *prometheus.CounterVec // event_type
6868+ OspreyEventsEmitted *prometheus.CounterVec // event_type
6969+ OspreyEventsFailed *prometheus.CounterVec // event_type
7070+ OspreyEventsSpooled *prometheus.CounterVec // event_type — events landed in the on-disk DLQ (#214)
7171+ OspreyEventsReplayed *prometheus.CounterVec // event_type — DLQ entries that finally reached the broker (#214)
7272+ OspreyEventsDropped *prometheus.CounterVec // reason — permanent loss (overflow, corrupt) (#214)
7373+ OspreyDisabled prometheus.Gauge // 1 when the emitter is Noop (Kafka misconfigured), 0 when active (#214)
7474+ OspreySpoolDepth prometheus.Gauge // current DLQ size (#214)
7575+ OspreyColdCacheDecisions *prometheus.CounterVec // decision: allowed, denied — fires when Osprey is unreachable AND no cache entry (#215)
44764577 // FBL/ARF complaint tracking
4678 ComplaintsTotal *prometheus.CounterVec // feedback_type, provider
···5587 // Inbound mail classification + forwarding (Phase 1b)
5688 InboundMessages *prometheus.CounterVec // classification: verp_bounce, srs_bounce, reply, postmaster
5789 RepliesForwarded *prometheus.CounterVec // status: sent, failed
9090+ InboundRejected *prometheus.CounterVec // reason: rate_limit
58915992 // Osprey events consumer health
6093 EventsConsumerLastIngestTimestamp prometheus.Gauge // Unix timestamp of last successful consume
···80113 Name: "atmosphere_relay_delivery_attempts_total",
81114 Help: "Total delivery attempts, by outcome.",
82115 }, []string{"status"}),
116116+ OrphanDeliveries: prometheus.NewCounterVec(prometheus.CounterOpts{
117117+ Name: "atmosphere_relay_orphan_deliveries_total",
118118+ Help: "Delivery callbacks for spool entries with no backing messages row (#208).",
119119+ }, []string{"status"}),
120120+ OrphanReconciled: prometheus.NewCounter(prometheus.CounterOpts{
121121+ Name: "atmosphere_relay_orphan_reconciled_total",
122122+ Help: "Queued message rows closed by the orphan-reconciliation janitor because no spool file exists (#208).",
123123+ }),
124124+ GoroutineCrashes: prometheus.NewCounterVec(prometheus.CounterOpts{
125125+ Name: "atmosphere_relay_goroutine_crashes_total",
126126+ Help: "Background goroutine panics recovered by GoSafe, by goroutine name (#209).",
127127+ }, []string{"name"}),
128128+ PartialDeliveries: prometheus.NewCounter(prometheus.CounterOpts{
129129+ Name: "atmosphere_relay_partial_deliveries_total",
130130+ Help: "Multi-RCPT DATA messages accepted with at least one recipient failing to enqueue (#226).",
131131+ }),
132132+ PartialDeliveryRecipients: prometheus.NewCounterVec(prometheus.CounterOpts{
133133+ Name: "atmosphere_relay_partial_delivery_recipients_total",
134134+ Help: "Per-recipient outcomes inside multi-RCPT DATA messages, by outcome (#226).",
135135+ }, []string{"outcome"}),
136136+ MemberHashLookups: prometheus.NewCounterVec(prometheus.CounterOpts{
137137+ Name: "atmosphere_relay_member_hash_lookups_total",
138138+ Help: "Inbound VERP member-hash lookups, by outcome (#218).",
139139+ }, []string{"outcome"}),
140140+ MemberHashRebuilds: prometheus.NewCounterVec(prometheus.CounterOpts{
141141+ Name: "atmosphere_relay_member_hash_rebuilds_total",
142142+ Help: "Member-hash cache rebuilds, by outcome (#218).",
143143+ }, []string{"outcome"}),
144144+ MemberHashCacheSize: prometheus.NewGaugeVec(prometheus.GaugeOpts{
145145+ Name: "atmosphere_relay_member_hash_cache_size",
146146+ Help: "Member-hash cache size, by kind (positive=enrolled members, negative=cached misses) (#218).",
147147+ }, []string{"kind"}),
148148+ SQLiteOpenConnections: prometheus.NewGauge(prometheus.GaugeOpts{
149149+ Name: "atmosphere_relay_sqlite_open_connections",
150150+ Help: "sql.DB.Stats().OpenConnections — total connections open to SQLite (#210).",
151151+ }),
152152+ SQLiteInUse: prometheus.NewGauge(prometheus.GaugeOpts{
153153+ Name: "atmosphere_relay_sqlite_in_use",
154154+ Help: "sql.DB.Stats().InUse — connections currently checked out and busy executing a query (#210).",
155155+ }),
156156+ SQLiteIdle: prometheus.NewGauge(prometheus.GaugeOpts{
157157+ Name: "atmosphere_relay_sqlite_idle",
158158+ Help: "sql.DB.Stats().Idle — connections currently idle in the pool (#210).",
159159+ }),
160160+ SQLiteWaitCount: prometheus.NewGauge(prometheus.GaugeOpts{
161161+ Name: "atmosphere_relay_sqlite_wait_count",
162162+ Help: "sql.DB.Stats().WaitCount — cumulative number of connections that had to wait for a free slot (#210).",
163163+ }),
164164+ SQLiteWaitDurationSec: prometheus.NewGauge(prometheus.GaugeOpts{
165165+ Name: "atmosphere_relay_sqlite_wait_duration_seconds",
166166+ Help: "sql.DB.Stats().WaitDuration — cumulative seconds waited for a free connection (#210).",
167167+ }),
168168+ SQLiteBusyErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
169169+ Name: "atmosphere_relay_sqlite_busy_errors_total",
170170+ Help: "SQLite errors classified as SQLITE_BUSY/locked at hot-path writers (#210).",
171171+ }, []string{"op"}),
83172 BouncesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
84173 Name: "atmosphere_relay_bounces_total",
85174 Help: "Total bounces received, by type.",
···125214 Name: "atmosphere_relay_osprey_checks_total",
126215 Help: "Osprey enforcement checks, by result.",
127216 }, []string{"result"}),
217217+ OspreyEventsSpooled: prometheus.NewCounterVec(prometheus.CounterOpts{
218218+ Name: "atmosphere_relay_osprey_events_spooled_total",
219219+ Help: "Osprey events that failed to reach Kafka and were spooled to disk for replay (#214).",
220220+ }, []string{"event_type"}),
221221+ OspreyEventsReplayed: prometheus.NewCounterVec(prometheus.CounterOpts{
222222+ Name: "atmosphere_relay_osprey_events_replayed_total",
223223+ Help: "Osprey events drained from the on-disk DLQ back to Kafka (#214).",
224224+ }, []string{"event_type"}),
225225+ OspreyEventsDropped: prometheus.NewCounterVec(prometheus.CounterOpts{
226226+ Name: "atmosphere_relay_osprey_events_dropped_total",
227227+ Help: "Osprey events permanently lost (DLQ overflow, corrupt entries) (#214).",
228228+ }, []string{"reason"}),
229229+ OspreyDisabled: prometheus.NewGauge(prometheus.GaugeOpts{
230230+ Name: "atmosphere_relay_osprey_disabled",
231231+ Help: "1 if the Osprey emitter is configured as Noop (Kafka broker missing); 0 if active (#214).",
232232+ }),
233233+ OspreySpoolDepth: prometheus.NewGauge(prometheus.GaugeOpts{
234234+ Name: "atmosphere_relay_osprey_spool_depth",
235235+ Help: "Number of events currently sitting in the Osprey on-disk DLQ awaiting replay (#214).",
236236+ }),
237237+ OspreyColdCacheDecisions: prometheus.NewCounterVec(prometheus.CounterOpts{
238238+ Name: "atmosphere_relay_osprey_cold_cache_decisions_total",
239239+ Help: "Cold-cache+unreachable enforcer decisions, by outcome (denied=fail-closed, allowed=fail-open) (#215).",
240240+ }, []string{"decision"}),
128241 OspreyEventsEmitted: prometheus.NewCounterVec(prometheus.CounterOpts{
129242 Name: "atmosphere_relay_osprey_events_emitted_total",
130243 Help: "Osprey events confirmed by Kafka broker (post-Completion), by event type.",
···153266 Name: "atmosphere_relay_replies_forwarded_total",
154267 Help: "Outcome of reply-forwarding attempts, by status.",
155268 }, []string{"status"}),
269269+ InboundRejected: prometheus.NewCounterVec(prometheus.CounterOpts{
270270+ Name: "atmosphere_relay_inbound_rejected_total",
271271+ Help: "Inbound SMTP sessions rejected before classification, by reason.",
272272+ }, []string{"reason"}),
156273 EventsConsumerLastIngestTimestamp: prometheus.NewGauge(prometheus.GaugeOpts{
157274 Name: "atmosphere_relay_events_consumer_last_ingest_timestamp_seconds",
158275 Help: "Unix timestamp of the last successfully consumed Osprey event.",
···173290 m.MessagesSent,
174291 m.DeliveryAttempts,
175292 m.BouncesTotal,
293293+ m.OrphanDeliveries,
294294+ m.OrphanReconciled,
295295+ m.GoroutineCrashes,
296296+ m.PartialDeliveries,
297297+ m.PartialDeliveryRecipients,
298298+ m.MemberHashLookups,
299299+ m.MemberHashRebuilds,
300300+ m.MemberHashCacheSize,
301301+ m.SQLiteOpenConnections,
302302+ m.SQLiteInUse,
303303+ m.SQLiteIdle,
304304+ m.SQLiteWaitCount,
305305+ m.SQLiteWaitDurationSec,
306306+ m.SQLiteBusyErrors,
176307 m.AuthAttempts,
177308 m.RateLimitHits,
178309 m.DeliveryQueueDepth,
···182313 m.OspreyChecksTotal,
183314 m.OspreyEventsEmitted,
184315 m.OspreyEventsFailed,
316316+ m.OspreyEventsSpooled,
317317+ m.OspreyEventsReplayed,
318318+ m.OspreyEventsDropped,
319319+ m.OspreyDisabled,
320320+ m.OspreySpoolDepth,
321321+ m.OspreyColdCacheDecisions,
185322 m.ComplaintsTotal,
186323 m.InboundMessages,
187324 m.RepliesForwarded,
325325+ m.InboundRejected,
188326 m.HTTPRequestsTotal,
189327 m.HTTPRequestDuration,
190328 m.EnrollFunnel,
···204342 m.MessagesRejected.WithLabelValues("osprey_suspended")
205343 m.MessagesRejected.WithLabelValues("suppressed")
206344 m.MessagesRejected.WithLabelValues("smuggling_guard")
345345+ m.MessagesRejected.WithLabelValues("delivery_failed")
346346+ m.PartialDeliveryRecipients.WithLabelValues("succeeded")
347347+ m.PartialDeliveryRecipients.WithLabelValues("failed")
348348+ m.MemberHashLookups.WithLabelValues("hit")
349349+ m.MemberHashLookups.WithLabelValues("neg_hit")
350350+ m.MemberHashLookups.WithLabelValues("miss")
351351+ m.MemberHashRebuilds.WithLabelValues("ran")
352352+ m.MemberHashRebuilds.WithLabelValues("skipped")
353353+ m.MemberHashCacheSize.WithLabelValues("positive")
354354+ m.MemberHashCacheSize.WithLabelValues("negative")
207355 m.DeliveryAttempts.WithLabelValues("sent")
208356 m.DeliveryAttempts.WithLabelValues("bounced")
209357 m.DeliveryAttempts.WithLabelValues("deferred")
···290438 m.RepliesForwarded.WithLabelValues(status).Inc()
291439}
292440441441+// IncGoroutineCrash implements relay.PanicRecorder. Used by GoSafe
442442+// to count recovered panics by named goroutine (#209).
443443+func (m *Metrics) IncGoroutineCrash(name string) {
444444+ m.GoroutineCrashes.WithLabelValues(name).Inc()
445445+}
446446+447447+// IncBusyError implements relaystore.BusyRecorder. Counts SQLITE_BUSY
448448+// errors that escape the busy_timeout PRAGMA at hot-path writers (#210).
449449+func (m *Metrics) IncBusyError(op string) {
450450+ m.SQLiteBusyErrors.WithLabelValues(op).Inc()
451451+}
452452+453453+// IncColdCacheDecision implements relay.ColdCacheRecorder. Counts
454454+// fail-open vs fail-closed enforcer decisions when Osprey is
455455+// unreachable AND the labelcheck cache has no entry for the DID (#215).
456456+func (m *Metrics) IncColdCacheDecision(decision string) {
457457+ m.OspreyColdCacheDecisions.WithLabelValues(decision).Inc()
458458+}
459459+460460+// IncMemberHashHit / IncMemberHashNegHit / IncMemberHashMiss /
461461+// IncMemberHashRebuild / IncMemberHashRebuildSkip / SetMemberHashSize
462462+// implement relay.MemberHashMetrics on *Metrics so the inbound member-hash
463463+// cache (#218) can record without needing a separate adapter type.
464464+func (m *Metrics) IncMemberHashHit() { m.MemberHashLookups.WithLabelValues("hit").Inc() }
465465+func (m *Metrics) IncMemberHashNegHit() { m.MemberHashLookups.WithLabelValues("neg_hit").Inc() }
466466+func (m *Metrics) IncMemberHashMiss() { m.MemberHashLookups.WithLabelValues("miss").Inc() }
467467+func (m *Metrics) IncMemberHashRebuild() { m.MemberHashRebuilds.WithLabelValues("ran").Inc() }
468468+func (m *Metrics) IncMemberHashRebuildSkip() { m.MemberHashRebuilds.WithLabelValues("skipped").Inc() }
469469+func (m *Metrics) SetMemberHashSize(positive, negative int) {
470470+ m.MemberHashCacheSize.WithLabelValues("positive").Set(float64(positive))
471471+ m.MemberHashCacheSize.WithLabelValues("negative").Set(float64(negative))
472472+}
473473+474474+// SetSQLiteStats updates the SQLite pool gauges from a snapshot
475475+// taken via relaystore.Store.SampleStats(). Decoupled from
476476+// *sql.DB so the metrics package doesn't take a database/sql
477477+// dependency.
478478+func (m *Metrics) SetSQLiteStats(open, inUse, idle int, waitCount int64, waitDurationSec float64) {
479479+ m.SQLiteOpenConnections.Set(float64(open))
480480+ m.SQLiteInUse.Set(float64(inUse))
481481+ m.SQLiteIdle.Set(float64(idle))
482482+ m.SQLiteWaitCount.Set(float64(waitCount))
483483+ m.SQLiteWaitDurationSec.Set(waitDurationSec)
484484+}
485485+486486+// RecordRejected implements relay.InboundMetrics.
487487+func (m *Metrics) RecordRejected(reason string) {
488488+ m.InboundRejected.WithLabelValues(reason).Inc()
489489+}
490490+293491// EmitterMetricsAdapter bridges relay.Metrics to the osprey.EmitterMetrics interface.
294492type EmitterMetricsAdapter struct {
295295- Emitted *prometheus.CounterVec
296296- Failed *prometheus.CounterVec
493493+ Emitted *prometheus.CounterVec // event_type
494494+ Failed *prometheus.CounterVec // event_type
495495+ Spooled *prometheus.CounterVec // event_type — fired when an event lands in the on-disk DLQ (#214)
496496+ Replayed *prometheus.CounterVec // event_type — fired when a spooled event finally reaches the broker (#214)
497497+ Dropped *prometheus.CounterVec // reason — fired on permanent loss (overflow, corrupt) (#214)
498498+ SpoolDepth prometheus.Gauge // current spool size (#214)
297499}
298500299501func (a *EmitterMetricsAdapter) IncEmitted(eventType string) {
···302504303505func (a *EmitterMetricsAdapter) IncFailed(eventType string) {
304506 a.Failed.WithLabelValues(eventType).Inc()
507507+}
508508+509509+func (a *EmitterMetricsAdapter) IncSpooled(eventType string) {
510510+ if a.Spooled != nil {
511511+ a.Spooled.WithLabelValues(eventType).Inc()
512512+ }
513513+}
514514+515515+func (a *EmitterMetricsAdapter) IncReplayed(eventType string) {
516516+ if a.Replayed != nil {
517517+ a.Replayed.WithLabelValues(eventType).Inc()
518518+ }
519519+}
520520+521521+func (a *EmitterMetricsAdapter) IncDropped(reason string) {
522522+ if a.Dropped != nil {
523523+ a.Dropped.WithLabelValues(reason).Inc()
524524+ }
525525+}
526526+527527+func (a *EmitterMetricsAdapter) SetSpoolDepth(n int) {
528528+ if a.SpoolDepth != nil {
529529+ a.SpoolDepth.Set(float64(n))
530530+ }
305531}
306532307533// HTTPMiddleware wraps an http.Handler to record request count and duration.
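Not part of the diff, just for orientation: a minimal sketch of how cmd/relay could hand the new counters to the Osprey emitter adapter and keep the SQLite pool gauges fresh. The field and method names come from the hunks above; the helper name, the 15-second interval, and feeding the gauges straight from `*sql.DB.Stats()` (rather than `relaystore.Store.SampleStats()`, which the real wiring goes through) are assumptions made to keep the fragment self-contained.

```go
// Hypothetical wiring helper (sketch only, not in the repo).
// Imports assumed: context, database/sql, time.
func wireRelayMetrics(ctx context.Context, m *Metrics, db *sql.DB) *EmitterMetricsAdapter {
	adapter := &EmitterMetricsAdapter{
		Emitted:    m.OspreyEventsEmitted,
		Failed:     m.OspreyEventsFailed,
		Spooled:    m.OspreyEventsSpooled,
		Replayed:   m.OspreyEventsReplayed,
		Dropped:    m.OspreyEventsDropped,
		SpoolDepth: m.OspreySpoolDepth,
	}

	// Mirror database/sql pool stats into the new gauges on a timer.
	go func() {
		t := time.NewTicker(15 * time.Second) // interval is an assumption
		defer t.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-t.C:
				s := db.Stats()
				m.SetSQLiteStats(s.OpenConnections, s.InUse, s.Idle,
					s.WaitCount, s.WaitDuration.Seconds())
			}
		}
	}()
	return adapter
}
```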
+176-9
internal/relay/ospreyenforce.go
···55import (
66 "context"
77 "encoding/json"
88+ "errors"
89 "fmt"
910 "io"
1011 "log"
1112 "net/http"
1213 "net/url"
1414+ "os"
1315 "strings"
1416 "sync"
1517 "time"
···4446 cache map[string]*ospreyEntry
45474648 flight singleflight.Group
4949+5050+ // failClosedOnColdCache, when true (default), rejects sends with
5151+ // an error when Osprey is unreachable AND no cached entry exists.
5252+ // Without this, a relay restart followed by an Osprey outage
5353+ // allows every new DID to send unsuspended for the duration of
5454+ // the outage — even DIDs Osprey would have flagged on a healthy
5555+ // query. Closes #215.
5656+ failClosedOnColdCache bool
5757+5858+ // coldCacheRecorder counts fail-open vs fail-closed decisions on
5959+ // cold cache + Osprey unreachable so operators can graph how
6060+ // often the dangerous branch fires. Optional.
6161+ coldCacheRecorder ColdCacheRecorder
6262+6363+ // snapshotPath, when non-empty, names a JSON file used to
6464+ // persist the cache across restarts. The most-common cause of
6565+ // a cold cache (relay restart with Osprey still healthy) is
6666+ // addressed by reading this file on startup; the fail-closed
6767+ // path above is the safety net for the rarer case.
6868+ snapshotPath string
4769}
48707171+// ColdCacheRecorder is the narrow interface used to count fail-open
7272+// vs fail-closed decisions. nil-safe.
7373+type ColdCacheRecorder interface {
7474+ IncColdCacheDecision(decision string)
7575+}
7676+7777+// ErrOspreyColdCache is returned by GetPolicy when the cache is empty
7878+// for a DID, Osprey is unreachable, and failClosedOnColdCache is true.
7979+// Callers translate this into a 451 SMTP deferral.
8080+var ErrOspreyColdCache = errors.New("osprey: cold cache and broker unreachable")
8181+4982type ospreyEntry struct {
5083 // activeLabels captures which labels Osprey currently has in status=1
5184 // (active) for the DID. Lookup is O(1) per label name. We store the
···8311684117// NewOspreyEnforcer creates an enforcer that queries the Osprey UI API.
85118// apiURL is the base URL, e.g. "https://osprey-api.example.com".
119119+//
120120+// Defaults to fail-CLOSED on cold cache (no entry + Osprey unreachable)
121121+// — a regression from the legacy fail-open behavior, deliberately
122122+// chosen because the cold-cache+outage window is exactly when an
123123+// attacker can register a new DID and burn reputation before Osprey
124124+// labels arrive (#215). Operators can opt back into fail-open via
125125+// SetFailClosedOnColdCache(false) if the security tradeoff doesn't
126126+// match their environment.
86127func NewOspreyEnforcer(apiURL string, client *http.Client) *OspreyEnforcer {
87128 if client == nil {
88129 client = &http.Client{Timeout: 5 * time.Second}
89130 }
90131 return &OspreyEnforcer{
9191- apiURL: apiURL,
9292- client: client,
9393- ttl: defaultOspreyEnforcerTTL,
9494- cache: make(map[string]*ospreyEntry),
132132+ apiURL: apiURL,
133133+ client: client,
134134+ ttl: defaultOspreyEnforcerTTL,
135135+ cache: make(map[string]*ospreyEntry),
136136+ failClosedOnColdCache: true,
137137+ }
138138+}
139139+140140+// SetFailClosedOnColdCache controls the cold-cache fallback. true =
141141+// reject sends with ErrOspreyColdCache when no entry exists and the
142142+// broker is unreachable; false = legacy fail-open behavior.
143143+func (e *OspreyEnforcer) SetFailClosedOnColdCache(v bool) {
144144+ e.failClosedOnColdCache = v
145145+}
146146+147147+// SetColdCacheRecorder wires a metric recorder for cold-cache decisions.
148148+func (e *OspreyEnforcer) SetColdCacheRecorder(r ColdCacheRecorder) {
149149+ e.coldCacheRecorder = r
150150+}
151151+152152+// SetSnapshotPath enables on-disk cache persistence. Snapshots are
153153+// written periodically by Snapshot() and read by LoadSnapshot() on
154154+// startup so a relay restart doesn't reset the cache to empty —
155155+// which is the load-bearing concern for #215. Pass an empty string
156156+// to disable.
157157+func (e *OspreyEnforcer) SetSnapshotPath(path string) {
158158+ e.snapshotPath = path
159159+}
160160+161161+// snapshotEntry is the on-disk representation. Keeps fetchedAt as
162162+// RFC 3339 text so an operator can inspect the file by hand.
163163+type snapshotEntry struct {
164164+ Labels []string `json:"labels"`
165165+ FetchedAt string `json:"fetched_at"`
166166+}
167167+168168+// Snapshot writes the in-memory cache to snapshotPath atomically.
169169+// Safe to call concurrently with reads; takes a brief read lock.
170170+// No-op when snapshotPath is empty.
171171+func (e *OspreyEnforcer) Snapshot() error {
172172+ if e.snapshotPath == "" {
173173+ return nil
174174+ }
175175+ e.mu.RLock()
176176+ out := make(map[string]snapshotEntry, len(e.cache))
177177+ for did, entry := range e.cache {
178178+ labels := make([]string, 0, len(entry.activeLabels))
179179+ for l := range entry.activeLabels {
180180+ labels = append(labels, l)
181181+ }
182182+ out[did] = snapshotEntry{Labels: labels, FetchedAt: entry.fetchedAt.UTC().Format(time.RFC3339Nano)}
183183+ }
184184+ e.mu.RUnlock()
185185+186186+ data, err := json.MarshalIndent(out, "", " ")
187187+ if err != nil {
188188+ return fmt.Errorf("marshal: %w", err)
189189+ }
190190+ tmp := e.snapshotPath + ".tmp"
191191+ if err := os.WriteFile(tmp, data, 0o600); err != nil {
192192+ return fmt.Errorf("write tmp: %w", err)
193193+ }
194194+ if err := os.Rename(tmp, e.snapshotPath); err != nil {
195195+ os.Remove(tmp)
196196+ return fmt.Errorf("rename: %w", err)
95197 }
198198+ return nil
199199+}
200200+201201+// LoadSnapshot populates the cache from snapshotPath. Entries whose
202202+// fetchedAt is older than 2*ttl are discarded — they would be served
203203+// stale and we'd rather force a fresh query than serve a week-old
204204+// label set. Missing file is not an error (first start). Returns
205205+// the number of entries loaded.
206206+func (e *OspreyEnforcer) LoadSnapshot() (int, error) {
207207+ if e.snapshotPath == "" {
208208+ return 0, nil
209209+ }
210210+ data, err := os.ReadFile(e.snapshotPath)
211211+ if err != nil {
212212+ if os.IsNotExist(err) {
213213+ return 0, nil
214214+ }
215215+ return 0, fmt.Errorf("read snapshot: %w", err)
216216+ }
217217+ var raw map[string]snapshotEntry
218218+ if err := json.Unmarshal(data, &raw); err != nil {
219219+ return 0, fmt.Errorf("unmarshal: %w", err)
220220+ }
221221+ cutoff := time.Now().Add(-2 * e.ttl)
222222+ loaded := 0
223223+ e.mu.Lock()
224224+ for did, se := range raw {
225225+ fetched, err := time.Parse(time.RFC3339Nano, se.FetchedAt)
226226+ if err != nil || fetched.Before(cutoff) {
227227+ continue
228228+ }
229229+ set := make(map[string]struct{}, len(se.Labels))
230230+ for _, l := range se.Labels {
231231+ set[l] = struct{}{}
232232+ }
233233+ e.cache[did] = &ospreyEntry{activeLabels: set, fetchedAt: fetched}
234234+ loaded++
235235+ }
236236+ e.mu.Unlock()
237237+ return loaded, nil
96238}
9723998240// ospreyLabelsResponse is the shape returned by
···106248}
107249108250// GetPolicy returns the effective sending policy for a DID derived from its
109109-// current Osprey labels. Fail-stale: if Osprey is unreachable and a previous
110110-// result is cached, that cached label set is used. If there is no cache at
111111-// all, returns defaultPolicy (fail-open so new DIDs are not blocked by
112112-// observability issues).
251251+// current Osprey labels.
252252+//
253253+// Fail-stale: if Osprey is unreachable and a previous result is
254254+// cached, that cached label set is used.
255255+//
256256+// Cold cache + Osprey unreachable: returns ErrOspreyColdCache when
257257+// failClosedOnColdCache is true (default — closes #215). Operators
258258+// who prefer the legacy fail-open behavior can call
259259+// SetFailClosedOnColdCache(false), which restores the pre-#215 path
260260+// of returning defaultPolicy with no error.
113261func (e *OspreyEnforcer) GetPolicy(ctx context.Context, did string) (*LabelPolicy, error) {
114262 labels, _, err := e.activeLabelsFor(ctx, did)
115263 if err != nil {
264264+ // activeLabelsFor only returns errors for the cold-cache
265265+ // fail-closed path; transient lookup failures already fall
266266+ // back to stale cache silently. Surface the typed error so
267267+ // the SMTP layer can return 451 to the client.
116268 return defaultPolicy(), err
117269 }
118270 return policyFromLabels(labels), nil
···157309 log.Printf("osprey.enforce: did=%s serving stale cache (labels=%v)", did, labelNames(entry.activeLabels))
158310 return entry.activeLabels, true, nil
159311 }
312312+ // Cold cache + Osprey unreachable. Default behavior is now
313313+ // fail-closed (#215): without this branch, a relay restart
314314+ // during an Osprey outage would let attackers send unsuspended
315315+ // for the duration of the outage. Operators who need the
316316+ // legacy fail-open semantics opt in via SetFailClosedOnColdCache.
317317+ if e.failClosedOnColdCache {
318318+ log.Printf("osprey.enforce: did=%s action=fail_closed reason=no_cache_and_unreachable", did)
319319+ if e.coldCacheRecorder != nil {
320320+ e.coldCacheRecorder.IncColdCacheDecision("denied")
321321+ }
322322+ return nil, false, ErrOspreyColdCache
323323+ }
160324 log.Printf("osprey.enforce: did=%s action=fail_open reason=no_cache_and_unreachable", did)
325325+ if e.coldCacheRecorder != nil {
326326+ e.coldCacheRecorder.IncColdCacheDecision("allowed")
327327+ }
161328 return nil, false, nil
162329 }
163330···294461 // Malformed response — return error so GetPolicy falls through to
295462 // the fail-stale path (preserving any cached labels) instead of
296463 // overwriting a known-label-bearing entry with an empty set.
297297- return nil, fmt.Errorf("malformed osprey response: %v", err)
464464+ return nil, fmt.Errorf("malformed osprey response: %w", err)
298465 }
299466300467 out := make(map[string]struct{}, len(result.Labels))
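A sketch of the startup wiring these knobs imply, not taken from the repo: constructor defaults, snapshot load, a periodic snapshot ticker, and the metrics recorder. The constructor, setters, `LoadSnapshot`/`Snapshot`, and `ErrOspreyColdCache` are from the hunks above; `setupEnforcer`, the five-minute interval, and the log lines are illustrative. At send time the SMTP layer is expected to map `ErrOspreyColdCache` to a 451 deferral, per the GetPolicy comment.

```go
// Hypothetical startup wiring (sketch only). Imports assumed:
// context, fmt, log, time.
func setupEnforcer(ctx context.Context, apiURL, snapshotPath string, rec ColdCacheRecorder) (*OspreyEnforcer, error) {
	e := NewOspreyEnforcer(apiURL, nil) // fail-closed on cold cache by default
	e.SetColdCacheRecorder(rec)
	e.SetSnapshotPath(snapshotPath)

	// Warm the cache from the last snapshot so a plain restart is not cold.
	n, err := e.LoadSnapshot()
	if err != nil {
		return nil, fmt.Errorf("load osprey snapshot: %w", err)
	}
	log.Printf("osprey.enforce: warmed cache with %d snapshot entries", n)

	// Persist the cache periodically so the next restart is not cold either.
	go func() {
		t := time.NewTicker(5 * time.Minute) // interval is an assumption
		defer t.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-t.C:
				if err := e.Snapshot(); err != nil {
					log.Printf("osprey.enforce: snapshot failed: %v", err)
				}
			}
		}
	}()
	return e, nil
}
```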
+194
internal/relay/ospreyenforce_coldcache_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "context"
77+ "errors"
88+ "net/http"
99+ "net/http/httptest"
1010+ "path/filepath"
1111+ "sync"
1212+ "testing"
1313+ "time"
1414+)
1515+1616+// brokenServer always returns 500 so the enforcer treats Osprey as
1717+// unreachable. Used to drive the fail-closed branch.
1818+func brokenServer(t *testing.T) *httptest.Server {
1919+ t.Helper()
2020+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2121+ w.WriteHeader(http.StatusInternalServerError)
2222+ }))
2323+ t.Cleanup(s.Close)
2424+ return s
2525+}
2626+2727+type stubColdRecorder struct {
2828+ mu sync.Mutex
2929+ calls map[string]int
3030+}
3131+3232+func newStubColdRecorder() *stubColdRecorder {
3333+ return &stubColdRecorder{calls: map[string]int{}}
3434+}
3535+func (s *stubColdRecorder) IncColdCacheDecision(d string) {
3636+ s.mu.Lock()
3737+ s.calls[d]++
3838+ s.mu.Unlock()
3939+}
4040+func (s *stubColdRecorder) count(d string) int {
4141+ s.mu.Lock()
4242+ defer s.mu.Unlock()
4343+ return s.calls[d]
4444+}
4545+4646+// TestEnforcer_ColdCacheFailClosedByDefault pins the central #215
4747+// invariant: a cold cache + Osprey unreachable rejects with
4848+// ErrOspreyColdCache by default, NOT silently allows.
4949+func TestEnforcer_ColdCacheFailClosedByDefault(t *testing.T) {
5050+ srv := brokenServer(t)
5151+ e := NewOspreyEnforcer(srv.URL, &http.Client{Timeout: 200 * time.Millisecond})
5252+ rec := newStubColdRecorder()
5353+ e.SetColdCacheRecorder(rec)
5454+5555+ policy, err := e.GetPolicy(context.Background(), "did:plc:cold")
5656+ if !errors.Is(err, ErrOspreyColdCache) {
5757+ t.Fatalf("expected ErrOspreyColdCache, got err=%v", err)
5858+ }
5959+ if policy == nil {
6060+ t.Errorf("policy should be non-nil even on error (defaultPolicy)")
6161+ }
6262+ if rec.count("denied") != 1 {
6363+ t.Errorf("denied count = %d, want 1", rec.count("denied"))
6464+ }
6565+ if rec.count("allowed") != 0 {
6666+ t.Errorf("allowed count = %d, want 0 (default is fail-closed)", rec.count("allowed"))
6767+ }
6868+}
6969+7070+// TestEnforcer_ColdCacheFailOpenOptIn confirms the legacy fail-open
7171+// path can be restored via SetFailClosedOnColdCache(false).
7272+func TestEnforcer_ColdCacheFailOpenOptIn(t *testing.T) {
7373+ srv := brokenServer(t)
7474+ e := NewOspreyEnforcer(srv.URL, &http.Client{Timeout: 200 * time.Millisecond})
7575+ e.SetFailClosedOnColdCache(false)
7676+ rec := newStubColdRecorder()
7777+ e.SetColdCacheRecorder(rec)
7878+7979+ _, err := e.GetPolicy(context.Background(), "did:plc:cold")
8080+ if err != nil {
8181+ t.Fatalf("opt-in fail-open should not return err, got %v", err)
8282+ }
8383+ if rec.count("allowed") != 1 {
8484+ t.Errorf("allowed count = %d, want 1", rec.count("allowed"))
8585+ }
8686+}
8787+8888+// TestEnforcer_SnapshotRoundTrip confirms persistence: write entries,
8989+// snapshot, build a fresh enforcer pointed at the same path, load,
9090+// and verify the entries replay AND the cold-cache branch does NOT
9191+// fire (because the cache is no longer cold).
9292+func TestEnforcer_SnapshotRoundTrip(t *testing.T) {
9393+ dir := t.TempDir()
9494+ snap := filepath.Join(dir, "cache.json")
9595+9696+ // Original enforcer: stuff a cache entry in.
9797+ e1 := NewOspreyEnforcer("http://127.0.0.1:1", nil)
9898+ e1.SetSnapshotPath(snap)
9999+ e1.cache["did:plc:warm"] = &ospreyEntry{
100100+ activeLabels: map[string]struct{}{"highly_trusted": {}},
101101+ fetchedAt: time.Now(),
102102+ }
103103+ if err := e1.Snapshot(); err != nil {
104104+ t.Fatalf("Snapshot: %v", err)
105105+ }
106106+107107+ // Fresh enforcer + broken Osprey: would normally fail-closed
108108+ // on cold cache. Loading the snapshot first means the cache is
109109+ // warm for did:plc:warm, so the fail-closed branch never fires
110110+ // for that DID.
111111+ srv := brokenServer(t)
112112+ e2 := NewOspreyEnforcer(srv.URL, &http.Client{Timeout: 200 * time.Millisecond})
113113+ e2.SetSnapshotPath(snap)
114114+ rec := newStubColdRecorder()
115115+ e2.SetColdCacheRecorder(rec)
116116+ n, err := e2.LoadSnapshot()
117117+ if err != nil {
118118+ t.Fatalf("LoadSnapshot: %v", err)
119119+ }
120120+ if n != 1 {
121121+ t.Errorf("loaded entries = %d, want 1", n)
122122+ }
123123+124124+ // Now policy lookup uses the cached entry — no broker call,
125125+ // no cold-cache decision recorded.
126126+ policy, err := e2.GetPolicy(context.Background(), "did:plc:warm")
127127+ if err != nil {
128128+ t.Errorf("warm-cache lookup returned err: %v", err)
129129+ }
130130+ if !policy.SkipWarming {
131131+ t.Errorf("policy should reflect highly_trusted (SkipWarming=true): %+v", policy)
132132+ }
133133+ if rec.count("denied") != 0 || rec.count("allowed") != 0 {
134134+ t.Errorf("cold-cache recorder fired for warm entry: denied=%d allowed=%d",
135135+ rec.count("denied"), rec.count("allowed"))
136136+ }
137137+138138+ // A DIFFERENT DID still cold-cache fails closed.
139139+ if _, err := e2.GetPolicy(context.Background(), "did:plc:cold"); !errors.Is(err, ErrOspreyColdCache) {
140140+ t.Errorf("unknown DID should fail-closed; err=%v", err)
141141+ }
142142+}
143143+144144+// TestEnforcer_LoadSnapshot_DropsExpired ensures stale entries don't
145145+// outlive the 2*ttl freshness window. A snapshot from 1 month ago
146146+// shouldn't keep serving labels indefinitely.
147147+func TestEnforcer_LoadSnapshot_DropsExpired(t *testing.T) {
148148+ dir := t.TempDir()
149149+ snap := filepath.Join(dir, "cache.json")
150150+151151+ e1 := NewOspreyEnforcer("http://127.0.0.1:1", nil)
152152+ e1.SetSnapshotPath(snap)
153153+ e1.cache["did:plc:fresh"] = &ospreyEntry{
154154+ activeLabels: map[string]struct{}{},
155155+ fetchedAt: time.Now(),
156156+ }
157157+ e1.cache["did:plc:stale"] = &ospreyEntry{
158158+ activeLabels: map[string]struct{}{},
159159+ fetchedAt: time.Now().Add(-30 * 24 * time.Hour),
160160+ }
161161+ if err := e1.Snapshot(); err != nil {
162162+ t.Fatal(err)
163163+ }
164164+165165+ e2 := NewOspreyEnforcer("http://127.0.0.1:1", nil)
166166+ e2.SetSnapshotPath(snap)
167167+ n, err := e2.LoadSnapshot()
168168+ if err != nil {
169169+ t.Fatal(err)
170170+ }
171171+ if n != 1 {
172172+ t.Errorf("loaded entries = %d, want 1 (stale dropped)", n)
173173+ }
174174+ if _, ok := e2.cache["did:plc:fresh"]; !ok {
175175+ t.Error("fresh entry missing from loaded cache")
176176+ }
177177+ if _, ok := e2.cache["did:plc:stale"]; ok {
178178+ t.Error("stale entry survived load")
179179+ }
180180+}
181181+182182+// TestEnforcer_LoadSnapshot_MissingFileNoError covers first-boot:
183183+// no snapshot exists yet, Load should be a clean no-op.
184184+func TestEnforcer_LoadSnapshot_MissingFileNoError(t *testing.T) {
185185+ e := NewOspreyEnforcer("http://127.0.0.1:1", nil)
186186+ e.SetSnapshotPath(filepath.Join(t.TempDir(), "does-not-exist.json"))
187187+ n, err := e.LoadSnapshot()
188188+ if err != nil {
189189+ t.Errorf("missing file should not error, got %v", err)
190190+ }
191191+ if n != 0 {
192192+ t.Errorf("loaded = %d, want 0", n)
193193+ }
194194+}
+12-2
internal/relay/ospreyenforce_test.go
···120120}
121121122122func TestOspreyEnforcerUnreachableNoCache(t *testing.T) {
123123- // No server — enforcer should fail-open when no cache entry exists.
123123+ // Opt into the legacy fail-open behavior. Default is fail-closed
124124+ // (#215); see TestEnforcer_ColdCacheFailClosedByDefault for that
125125+ // path. This test pins the opt-in escape hatch.
124126 e := NewOspreyEnforcer("http://127.0.0.1:1", &http.Client{Timeout: 50 * time.Millisecond})
127127+ e.SetFailClosedOnColdCache(false)
125128 suspended, err := e.CheckSuspended(context.Background(), "did:plc:new")
126129 if err != nil {
127130 t.Fatalf("unexpected error: %v", err)
···187190}
188191189192func TestOspreyEnforcerMalformedResponseNoCacheFailsOpen(t *testing.T) {
190190- // No prior cache + malformed response = fail-open (allow send).
193193+ // Opt-in fail-open path. Default is fail-closed (#215); this
194194+ // test pins the legacy behavior available via opt-in only.
191195 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
192196 w.Write([]byte("not json {{{"))
193197 }))
194198 defer srv.Close()
195199196200 e := NewOspreyEnforcer(srv.URL, srv.Client())
201201+ e.SetFailClosedOnColdCache(false)
197202 suspended, err := e.CheckSuspended(context.Background(), "did:plc:test")
198203 if err != nil {
199204 t.Fatalf("unexpected error: %v", err)
···290295 }
291296}
292297298298+// TestOspreyEnforcerServerErrorNoCacheFailsOpen is the opt-in fail-
299299+// open variant; the default is fail-closed (#215). The opt-in pin
300300+// lives here so a future contributor can find it next to the
301301+// fail-closed default it overrides.
293302func TestOspreyEnforcerServerErrorNoCacheFailsOpen(t *testing.T) {
294303 // 500 with no prior cache should fail-open (allow).
295304 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
···298307 defer srv.Close()
299308300309 e := NewOspreyEnforcer(srv.URL, srv.Client())
310310+ e.SetFailClosedOnColdCache(false)
301311 suspended, err := e.CheckSuspended(context.Background(), "did:plc:new")
302312 if err != nil {
303313 t.Fatalf("unexpected error: %v", err)
+42
internal/relay/recipient_outcome.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+// RecipientOutcome records the result of delivering a single recipient inside
66+// a multi-RCPT SMTP DATA. Emitted by the per-recipient loop in cmd/relay so
77+// the caller can decide whether the whole DATA should be accepted, rejected,
88+// or partially-failed.
99+type RecipientOutcome struct {
1010+ // Recipient is the RCPT TO address.
1111+ Recipient string
1212+ // MsgID is the relaystore.messages row inserted for this recipient.
1313+ // Zero when the failure happened before InsertMessage (e.g. DKIM sign).
1414+ MsgID int64
1515+ // Err is non-nil when the recipient could not be enqueued for any reason.
1616+ Err error
1717+}
1818+1919+// AggregateRecipientOutcomes summarizes a per-recipient delivery loop.
2020+//
2121+// Returns:
2222+// - succeeded, failed: per-recipient counts
2323+// - retryAll: true only when recipients were attempted and none succeeded;
2424+// the caller should return a transient SMTP error so the client retries the whole DATA.
2525+// - lastErr: a representative error from the failures (for logging)
2626+//
2727+// When at least one recipient succeeded, the caller MUST accept the DATA
2828+// (return nil to the SMTP server). Returning a transient error in that case
2929+// would cause the client to retry the entire DATA, duplicating the
3030+// successfully-enqueued recipients — the bug fixed by this aggregator.
3131+func AggregateRecipientOutcomes(outcomes []RecipientOutcome) (succeeded, failed int, retryAll bool, lastErr error) {
3232+ for _, o := range outcomes {
3333+ if o.Err == nil {
3434+ succeeded++
3535+ continue
3636+ }
3737+ failed++
3838+ lastErr = o.Err
3939+ }
4040+ retryAll = succeeded == 0 && failed > 0
4141+ return
4242+}
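For context (not from the repo), this is roughly the shape of the per-recipient loop the aggregator is written for. `enqueueRecipient` and the 451 string are stand-ins; `RecipientOutcome` and `AggregateRecipientOutcomes` come from the file above.

```go
// Hypothetical caller in cmd/relay's DATA handler (sketch only).
// Imports assumed: context, fmt, log.
func deliverData(ctx context.Context, rcpts []string,
	enqueueRecipient func(context.Context, string) (int64, error)) error {

	outcomes := make([]RecipientOutcome, 0, len(rcpts))
	for _, to := range rcpts {
		id, err := enqueueRecipient(ctx, to)
		outcomes = append(outcomes, RecipientOutcome{Recipient: to, MsgID: id, Err: err})
	}

	succeeded, failed, retryAll, lastErr := AggregateRecipientOutcomes(outcomes)
	if retryAll {
		// Nothing was enqueued, so a transient reject is safe: the client
		// retries the whole DATA without duplicating anything.
		return fmt.Errorf("451 4.3.0 temporary failure: %w", lastErr)
	}
	if failed > 0 {
		// Partial success: accept the DATA. Rejecting here would make the
		// client resend and duplicate the recipients already enqueued.
		log.Printf("smtp.data: partial delivery succeeded=%d failed=%d lastErr=%v",
			succeeded, failed, lastErr)
	}
	return nil
}
```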
+134
internal/relay/recipient_outcome_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "errors"
77+ "testing"
88+)
99+1010+func TestAggregateRecipientOutcomes_AllSucceeded(t *testing.T) {
1111+ outcomes := []RecipientOutcome{
1212+ {Recipient: "a@x.com", MsgID: 1},
1313+ {Recipient: "b@x.com", MsgID: 2},
1414+ {Recipient: "c@x.com", MsgID: 3},
1515+ }
1616+1717+ succeeded, failed, retryAll, lastErr := AggregateRecipientOutcomes(outcomes)
1818+1919+ if succeeded != 3 {
2020+ t.Errorf("succeeded = %d, want 3", succeeded)
2121+ }
2222+ if failed != 0 {
2323+ t.Errorf("failed = %d, want 0", failed)
2424+ }
2525+ if retryAll {
2626+ t.Error("retryAll = true, want false (any success means accept)")
2727+ }
2828+ if lastErr != nil {
2929+ t.Errorf("lastErr = %v, want nil", lastErr)
3030+ }
3131+}
3232+3333+func TestAggregateRecipientOutcomes_AllFailed(t *testing.T) {
3434+ bad := errors.New("queue full")
3535+ outcomes := []RecipientOutcome{
3636+ {Recipient: "a@x.com", Err: bad},
3737+ {Recipient: "b@x.com", Err: bad},
3838+ }
3939+4040+ succeeded, failed, retryAll, lastErr := AggregateRecipientOutcomes(outcomes)
4141+4242+ if succeeded != 0 {
4343+ t.Errorf("succeeded = %d, want 0", succeeded)
4444+ }
4545+ if failed != 2 {
4646+ t.Errorf("failed = %d, want 2", failed)
4747+ }
4848+ if !retryAll {
4949+ t.Error("retryAll = false, want true (zero successes → caller must reject DATA)")
5050+ }
5151+ if lastErr != bad {
5252+ t.Errorf("lastErr = %v, want %v", lastErr, bad)
5353+ }
5454+}
5555+5656+func TestAggregateRecipientOutcomes_PartialFailure_AcceptsAnyway(t *testing.T) {
5757+ // This is the regression fix for #226: when 1..N-1 enqueued and N failed,
5858+ // we MUST NOT signal retry — that would duplicate 1..N-1.
5959+ bad := errors.New("spool I/O error")
6060+ outcomes := []RecipientOutcome{
6161+ {Recipient: "a@x.com", MsgID: 1},
6262+ {Recipient: "b@x.com", MsgID: 2},
6363+ {Recipient: "c@x.com", MsgID: 3, Err: bad},
6464+ }
6565+6666+ succeeded, failed, retryAll, lastErr := AggregateRecipientOutcomes(outcomes)
6767+6868+ if succeeded != 2 {
6969+ t.Errorf("succeeded = %d, want 2", succeeded)
7070+ }
7171+ if failed != 1 {
7272+ t.Errorf("failed = %d, want 1", failed)
7373+ }
7474+ if retryAll {
7575+ t.Fatal("retryAll = true would cause client to retry DATA, duplicating recipients a@x.com and b@x.com — this is the bug #226 fixes")
7676+ }
7777+ if lastErr != bad {
7878+ t.Errorf("lastErr = %v, want %v", lastErr, bad)
7979+ }
8080+}
8181+8282+func TestAggregateRecipientOutcomes_FailureFirst_StillAcceptsIfAnySucceed(t *testing.T) {
8383+ // Order shouldn't matter — even if the first recipient failed, as long
8484+ // as at least one later recipient succeeded we still accept the DATA
8585+ // because retrying would duplicate the later success.
8686+ bad := errors.New("DKIM sign error")
8787+ outcomes := []RecipientOutcome{
8888+ {Recipient: "a@x.com", Err: bad},
8989+ {Recipient: "b@x.com", MsgID: 2},
9090+ }
9191+9292+ succeeded, failed, retryAll, _ := AggregateRecipientOutcomes(outcomes)
9393+9494+ if succeeded != 1 || failed != 1 {
9595+ t.Errorf("succeeded=%d failed=%d, want 1/1", succeeded, failed)
9696+ }
9797+ if retryAll {
9898+ t.Error("retryAll = true on partial success — would duplicate b@x.com on retry")
9999+ }
100100+}
101101+102102+func TestAggregateRecipientOutcomes_Empty(t *testing.T) {
103103+ succeeded, failed, retryAll, lastErr := AggregateRecipientOutcomes(nil)
104104+105105+ if succeeded != 0 || failed != 0 {
106106+ t.Errorf("counts non-zero on empty input: succeeded=%d failed=%d", succeeded, failed)
107107+ }
108108+ if retryAll {
109109+ // Empty input means "no recipients to deliver" — the caller should
110110+ // not have invoked the loop at all. We choose retryAll=false here
111111+ // because there's nothing to retry.
112112+ t.Error("retryAll = true on empty outcomes — should be false (nothing to retry)")
113113+ }
114114+ if lastErr != nil {
115115+ t.Errorf("lastErr = %v on empty, want nil", lastErr)
116116+ }
117117+}
118118+119119+func TestAggregateRecipientOutcomes_LastErrIsTheLastFailure(t *testing.T) {
120120+ // When multiple failures occur, lastErr should be deterministic — the
121121+ // last one in iteration order, so logs are reproducible.
122122+ first := errors.New("first failure")
123123+ second := errors.New("second failure")
124124+ outcomes := []RecipientOutcome{
125125+ {Recipient: "a@x.com", Err: first},
126126+ {Recipient: "b@x.com", Err: second},
127127+ }
128128+129129+ _, _, _, lastErr := AggregateRecipientOutcomes(outcomes)
130130+131131+ if lastErr != second {
132132+ t.Errorf("lastErr = %v, want %v (iteration order should pick the last)", lastErr, second)
133133+ }
134134+}
+60-17
internal/relay/smtp.go
···476476 return nil
477477}
478478479479-// validateFromHeader parses the From header from message data and verifies
480480-// the domain matches the member's registered domain. This prevents a member
481481-// registered for example.com from sending with From: ceo@bigbank.com, which
482482-// would be DKIM-signed and could enable phishing.
479479+// validateFromHeader parses the From header (and the related Sender,
480480+// Resent-From, Resent-Sender headers) from message data and verifies all
481481+// of them carry the member's registered domain.
482482+//
483483+// Why all four: per RFC 5322,
484484+// - From identifies the author. DMARC alignment is on the From domain,
485485+// so spoofing it enables phishing under the member's DKIM signature.
486486+// - Sender identifies the agent that actually injected the message
487487+// (used when From contains multiple authors). Receivers — Gmail in
488488+// particular — fall back to the Sender domain for DMARC alignment in
489489+// that case, so a member registered for example.com sending with
490490+// "Sender: agent@bigbank.com" would still spoof bigbank.com from
491491+// Gmail's perspective even though our From check passed.
492492+// - Resent-From / Resent-Sender carry the same risks for forwarded
493493+// messages. The relay isn't a re-mailer; messages should originate
494494+// from the member's domain regardless of which header conveys that.
495495+//
496496+// All four single-address headers are validated identically. Resent-*
497497+// headers may appear multiple times (RFC 5322 §3.6.6 forwarding trace);
498498+// every occurrence must pass.
483499func validateFromHeader(data []byte, memberDomain string) error {
484500 r := textproto.NewReader(bufio.NewReader(strings.NewReader(string(data))))
485501 header, err := r.ReadMIMEHeader()
···487503 return fmt.Errorf("From header domain must match %s", memberDomain)
488504 }
489505490490- fromHeader := header.Get("From")
491491- if fromHeader == "" {
506506+ if header.Get("From") == "" {
492507 return fmt.Errorf("missing From header")
493508 }
494509495495- // Parse From address using stdlib — rejects multi-address headers and
496496- // malformed addresses that hand-rolled parsers might accept.
497497- addr, err := mail.ParseAddress(fromHeader)
498498- if err != nil {
499499- return fmt.Errorf("could not parse From header: %v", err)
510510+ // Single-occurrence headers: From and Sender. Each must, when present,
511511+ // carry exactly one address aligned with the member's domain.
512512+ for _, name := range []string{"From", "Sender"} {
513513+ v := header.Get(name)
514514+ if v == "" {
515515+ continue
516516+ }
517517+ if err := requireAlignedSingleAddress(name, v, memberDomain); err != nil {
518518+ return err
519519+ }
520520+ }
521521+522522+ // Multi-occurrence Resent-* headers. RFC 5322 allows each forward hop
523523+ // to add its own Resent-From/Resent-Sender block; net/textproto returns
524524+ // every value via header.Values(). We require *every* hop to align.
525525+ for _, name := range []string{"Resent-From", "Resent-Sender"} {
526526+ for _, v := range header.Values(name) {
527527+ if v == "" {
528528+ continue
529529+ }
530530+ if err := requireAlignedSingleAddress(name, v, memberDomain); err != nil {
531531+ return err
532532+ }
533533+ }
500534 }
501535536536+ return nil
537537+}
538538+539539+// requireAlignedSingleAddress parses a single-address header value and
540540+// returns an error unless it contains exactly one address whose domain
541541+// matches memberDomain (case-insensitive, exact match — no subdomain
542542+// alignment, mirroring the rest of the relay's policy).
543543+func requireAlignedSingleAddress(headerName, headerValue, memberDomain string) error {
544544+ addr, err := mail.ParseAddress(headerValue)
545545+ if err != nil {
546546+ return fmt.Errorf("could not parse %s header: %w", headerName, err)
547547+ }
502548 parts := strings.SplitN(addr.Address, "@", 2)
503549 if len(parts) != 2 {
504504- return fmt.Errorf("could not parse domain from From header")
550550+ return fmt.Errorf("could not parse domain from %s header", headerName)
505551 }
506506- fromDomain := parts[1]
507507-508508- if strings.ToLower(fromDomain) != strings.ToLower(memberDomain) {
509509- return fmt.Errorf("From header domain %q does not match registered domain %q", fromDomain, memberDomain)
552552+ if !strings.EqualFold(parts[1], memberDomain) {
553553+ return fmt.Errorf("%s header domain %q does not match registered domain %q", headerName, parts[1], memberDomain)
510554 }
511511-512555 return nil
513556}
+95
internal/relay/smtp_test.go
···507507 }
508508}
509509510510+// --- Sender / Resent-* header validation (#225) ---
511511+//
512512+// Gmail and other large receivers fall back to Sender for DMARC alignment
513513+// when From contains multiple authors. A member that passes the From check
514514+// can still spoof a third party via a forged Sender — so we extend the
515515+// alignment check to every author/agent header RFC 5322 defines.
516516+517517+func TestValidateFromHeader_SenderAligned(t *testing.T) {
518518+ msg := "From: noreply@example.com\r\nSender: ops@example.com\r\nTo: user@gmail.com\r\nSubject: Test\r\n\r\nBody\r\n"
519519+ if err := validateFromHeader([]byte(msg), "example.com"); err != nil {
520520+ t.Fatalf("aligned Sender rejected: %v", err)
521521+ }
522522+}
523523+524524+func TestValidateFromHeader_SenderSpoofed(t *testing.T) {
525525+ // Author looks legit but Sender impersonates a bank — Gmail uses Sender
526526+ // for DMARC alignment when present, so this must be blocked.
527527+ msg := "From: noreply@example.com\r\nSender: ceo@bigbank.com\r\nTo: user@gmail.com\r\nSubject: Test\r\n\r\nBody\r\n"
528528+ err := validateFromHeader([]byte(msg), "example.com")
529529+ if err == nil {
530530+ t.Fatal("spoofed Sender header should be rejected")
531531+ }
532532+ if !strings.Contains(err.Error(), "Sender") {
533533+ t.Errorf("error should mention Sender header: %v", err)
534534+ }
535535+ if !strings.Contains(err.Error(), "bigbank.com") {
536536+ t.Errorf("error should name the spoofed domain: %v", err)
537537+ }
538538+}
539539+540540+func TestValidateFromHeader_SenderEmptyOk(t *testing.T) {
541541+ // Empty/absent Sender is fine — it's only meaningful when present.
542542+ msg := "From: noreply@example.com\r\nTo: user@gmail.com\r\nSubject: Test\r\n\r\nBody\r\n"
543543+ if err := validateFromHeader([]byte(msg), "example.com"); err != nil {
544544+ t.Fatalf("absent Sender should not be required: %v", err)
545545+ }
546546+}
547547+548548+func TestValidateFromHeader_ResentFromSpoofed(t *testing.T) {
549549+ msg := "From: noreply@example.com\r\nResent-From: ceo@bigbank.com\r\nTo: user@gmail.com\r\nSubject: Test\r\n\r\nBody\r\n"
550550+ err := validateFromHeader([]byte(msg), "example.com")
551551+ if err == nil {
552552+ t.Fatal("spoofed Resent-From header should be rejected")
553553+ }
554554+ if !strings.Contains(err.Error(), "Resent-From") {
555555+ t.Errorf("error should mention Resent-From: %v", err)
556556+ }
557557+}
558558+559559+func TestValidateFromHeader_ResentSenderSpoofed(t *testing.T) {
560560+ msg := "From: noreply@example.com\r\nResent-Sender: ceo@bigbank.com\r\nTo: user@gmail.com\r\nSubject: Test\r\n\r\nBody\r\n"
561561+ err := validateFromHeader([]byte(msg), "example.com")
562562+ if err == nil {
563563+ t.Fatal("spoofed Resent-Sender header should be rejected")
564564+ }
565565+}
566566+567567+func TestValidateFromHeader_ResentFromMultipleHopsAllAligned(t *testing.T) {
568568+ // RFC 5322 §3.6.6: each forward hop prepends its own Resent-* block.
569569+ // When the relay is the (only) re-mailer, all hops are us, and all
570570+ // should be aligned with the member's domain.
571571+ msg := "From: noreply@example.com\r\n" +
572572+ "Resent-From: ops@example.com\r\n" +
573573+ "Resent-From: ops2@example.com\r\n" +
574574+ "To: user@gmail.com\r\nSubject: Test\r\n\r\nBody\r\n"
575575+ if err := validateFromHeader([]byte(msg), "example.com"); err != nil {
576576+ t.Fatalf("multiple aligned Resent-From hops rejected: %v", err)
577577+ }
578578+}
579579+580580+func TestValidateFromHeader_ResentFromMultipleHopsOneSpoofed(t *testing.T) {
581581+ // One hop is forged → reject. Otherwise an attacker could chain a
582582+ // legitimate hop after a spoofed one to slip past a "first-only" check.
583583+ msg := "From: noreply@example.com\r\n" +
584584+ "Resent-From: ops@example.com\r\n" +
585585+ "Resent-From: ceo@bigbank.com\r\n" +
586586+ "To: user@gmail.com\r\nSubject: Test\r\n\r\nBody\r\n"
587587+ err := validateFromHeader([]byte(msg), "example.com")
588588+ if err == nil {
589589+ t.Fatal("forged hop in Resent-From chain should be rejected")
590590+ }
591591+ if !strings.Contains(err.Error(), "bigbank.com") {
592592+ t.Errorf("error should identify spoofed hop: %v", err)
593593+ }
594594+}
595595+596596+func TestValidateFromHeader_SenderMultiAddressRejected(t *testing.T) {
597597+ // Same multi-address attack as From, but on Sender.
598598+ msg := "From: noreply@example.com\r\nSender: attacker@evil.com, \"Friendly\" <legit@example.com>\r\nTo: user@gmail.com\r\nSubject: Test\r\n\r\nBody\r\n"
599599+ err := validateFromHeader([]byte(msg), "example.com")
600600+ if err == nil {
601601+ t.Fatal("multi-address Sender header should be rejected")
602602+ }
603603+}
604604+510605func TestSMTPFromHeaderPhishingBlocked(t *testing.T) {
511606 // End-to-end test: member for example.com tries to send with From: ceo@bigbank.com
512607 apiKey := "atmos_testkey123"
+85-10
internal/relay/spool.go
···3434}
35353636// Write persists a queue entry to the spool directory.
3737+//
3838+// Durability contract: when Write returns nil, the message body has
3939+// been fsynced to the underlying device AND the rename has been
4040+// fsynced to the directory entry. A subsequent power loss cannot lose
4141+// a message that Write claimed to persist. Without these fsyncs the
4242+// rename can appear to succeed but be reordered behind a crash,
4343+// leaving either a zero-length file or no file at all when the kernel
4444+// replays the journal — exactly the orphan case (#208) that produces
4545+// duplicate-delivery on SMTP retry.
3746func (s *Spool) Write(entry *QueueEntry) error {
3847 se := spoolEntry{
3948 ID: entry.ID,
···46554756 data, err := json.Marshal(se)
4857 if err != nil {
4949- return fmt.Errorf("marshal spool entry: %v", err)
5858+ return fmt.Errorf("marshal spool entry: %w", err)
5059 }
51605261 path := filepath.Join(s.dir, fmt.Sprintf("%d.msg", entry.ID))
5353-5454- // Write atomically: temp file + rename to avoid partial writes on crash
5562 tmp := path + ".tmp"
5656- if err := os.WriteFile(tmp, data, 0600); err != nil {
5757- return fmt.Errorf("write spool file: %v", err)
6363+6464+ // Step 1: write + fsync the temp file. fsync MUST happen before
6565+ // rename or the rename can land in the journal ahead of the data
6666+ // blocks, leaving a zero-byte file after a crash.
6767+ f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
6868+ if err != nil {
6969+ return fmt.Errorf("open spool tmp: %w", err)
5870 }
7171+ if _, err := f.Write(data); err != nil {
7272+ f.Close()
7373+ os.Remove(tmp)
7474+ return fmt.Errorf("write spool tmp: %w", err)
7575+ }
7676+ if err := f.Sync(); err != nil {
7777+ f.Close()
7878+ os.Remove(tmp)
7979+ return fmt.Errorf("fsync spool tmp: %w", err)
8080+ }
8181+ if err := f.Close(); err != nil {
8282+ os.Remove(tmp)
8383+ return fmt.Errorf("close spool tmp: %w", err)
8484+ }
8585+8686+ // Step 2: rename. With the data fsynced above, the rename's
8787+ // directory-entry change is the only remaining hop, and step 3
8888+ // fsyncs the directory so even that hop can't be lost.
5989 if err := os.Rename(tmp, path); err != nil {
6090 os.Remove(tmp)
6161- return fmt.Errorf("rename spool file: %v", err)
9191+ return fmt.Errorf("rename spool file: %w", err)
9292+ }
9393+9494+ // Step 3: fsync the directory so the rename is durable. Some
9595+ // filesystems (ext4 default, btrfs, zfs) make this implicit when
9696+ // data was fsynced first, but Linux does not guarantee it across
9797+ // all configurations and macOS APFS makes no guarantees either.
9898+ // Failing here means we have a freshly-renamed file that may not
9999+ // survive a crash: surface the error but do not roll back, since
100100+ // the file IS in place from this process's view and rolling back
101101+ // the rename would itself need another sync to be durable.
102102+ dir, err := os.Open(s.dir)
103103+ if err != nil {
104104+ return fmt.Errorf("open spool dir for fsync: %w", err)
105105+ }
106106+ if err := dir.Sync(); err != nil {
107107+ dir.Close()
108108+ return fmt.Errorf("fsync spool dir: %w", err)
109109+ }
110110+ if err := dir.Close(); err != nil {
111111+ return fmt.Errorf("close spool dir: %w", err)
62112 }
6311364114 return nil
···6611667117// Remove deletes a spool file for the given message ID.
68118// Returns nil if the file doesn't exist.
119119+//
120120+// fsync of the directory after the unlink is intentional: without
121121+// it, a crash between the unlink and a subsequent operation can
122122+// leave the file ghost-present after journal replay, and LoadAll
123123+// would then re-deliver an already-delivered message. Cost is one
124124+// directory fsync per terminal-state message, which is small
125125+// compared to the cost of an unintended duplicate send.
69126func (s *Spool) Remove(id int64) error {
70127 path := filepath.Join(s.dir, fmt.Sprintf("%d.msg", id))
71128 err := os.Remove(path)
7272- if os.IsNotExist(err) {
7373- return nil
129129+ if err != nil && !os.IsNotExist(err) {
130130+ return err
131131+ }
132132+ dir, derr := os.Open(s.dir)
133133+ if derr != nil {
134134+ return fmt.Errorf("open spool dir for fsync: %w", derr)
74135 }
7575- return err
136136+ defer dir.Close()
137137+ if err := dir.Sync(); err != nil {
138138+ return fmt.Errorf("fsync spool dir after remove: %w", err)
139139+ }
140140+ return nil
141141+}
142142+143143+// Exists reports whether a spool file for the given message ID is
144144+// currently present. Used by the orphan-reconciliation janitor in
145145+// cmd/relay (a status=queued DB row with no spool file is the
146146+// signature of a dropped Enqueue).
147147+func (s *Spool) Exists(id int64) bool {
148148+ path := filepath.Join(s.dir, fmt.Sprintf("%d.msg", id))
149149+ _, err := os.Stat(path)
150150+ return err == nil
76151}
7715278153// LoadAll reads all spool files and returns queue entries.
···80155func (s *Spool) LoadAll() ([]*QueueEntry, error) {
81156 entries, err := os.ReadDir(s.dir)
82157 if err != nil {
8383- return nil, fmt.Errorf("read spool dir: %v", err)
158158+ return nil, fmt.Errorf("read spool dir: %w", err)
84159 }
8516086161 var result []*QueueEntry
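The Exists comment references an orphan-reconciliation janitor in cmd/relay; a rough sketch of what such a pass could look like is below (not from the repo). `listQueuedMessageIDs` and `markOrphaned` are stand-ins for relaystore queries; only `Spool.Exists` comes from the diff.

```go
// Hypothetical janitor pass (sketch only). Imports assumed: context, fmt, log.
func reconcileOrphans(ctx context.Context, s *Spool,
	listQueuedMessageIDs func(context.Context) ([]int64, error),
	markOrphaned func(context.Context, int64) error) error {

	ids, err := listQueuedMessageIDs(ctx)
	if err != nil {
		return fmt.Errorf("list queued messages: %w", err)
	}
	for _, id := range ids {
		if s.Exists(id) {
			continue // spool file present, nothing to reconcile
		}
		// status=queued in the DB with no spool file is the signature of a
		// dropped Enqueue (#208); flag it instead of silently re-sending.
		if err := markOrphaned(ctx, id); err != nil {
			return fmt.Errorf("mark orphaned %d: %w", id, err)
		}
		log.Printf("spool.janitor: orphaned message id=%d", id)
	}
	return nil
}
```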
+112
internal/relay/spool_durability_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "os"
77+ "path/filepath"
88+ "testing"
99+)
1010+1111+// TestSpoolWrite_NoTmpResidue pins the observable side of the Write
1212+// contract: after a successful Write the .msg file is present and no
1313+// .tmp residue is left behind. We can't actually observe fsync from
1414+// userspace; if the implementation regresses to plain WriteFile +
1515+// Rename, this test still passes (fsync is invisible to readers).
1616+// The durability guarantee itself lives in Write's doc comment; the
1717+// supporting invariants below (no leftover .tmp, file readable)
1818+// guard the parts of the contract a test can observe.
1919+func TestSpoolWrite_NoTmpResidue(t *testing.T) {
2020+ dir := t.TempDir()
2121+ s := NewSpool(dir)
2222+ if err := s.Write(&QueueEntry{ID: 42, From: "a@b", To: "c@d", Data: []byte("hi")}); err != nil {
2323+ t.Fatalf("Write: %v", err)
2424+ }
2525+2626+ // .msg file present
2727+ if _, err := os.Stat(filepath.Join(dir, "42.msg")); err != nil {
2828+ t.Errorf(".msg file missing: %v", err)
2929+ }
3030+ // no .tmp residue
3131+ matches, _ := filepath.Glob(filepath.Join(dir, "*.tmp"))
3232+ if len(matches) != 0 {
3333+ t.Errorf("leftover .tmp files: %v", matches)
3434+ }
3535+}
3636+3737+// TestSpoolWrite_TmpRemovedOnRenameFailure verifies that if the
3838+// rename fails (because the target path is a directory), the .tmp
3939+// file is cleaned up rather than left behind to confuse future
4040+// LoadAll scans.
4141+func TestSpoolWrite_TmpRemovedOnRenameFailure(t *testing.T) {
4242+ dir := t.TempDir()
4343+ // Pre-create a *directory* at the target path so rename will fail.
4444+ if err := os.Mkdir(filepath.Join(dir, "1.msg"), 0755); err != nil {
4545+ t.Fatalf("mkdir: %v", err)
4646+ }
4747+ s := NewSpool(dir)
4848+ err := s.Write(&QueueEntry{ID: 1, From: "a@b", To: "c@d", Data: []byte("hi")})
4949+ if err == nil {
5050+ t.Fatal("expected error from rename onto directory")
5151+ }
5252+ // .tmp must not survive the failure.
5353+ if _, err := os.Stat(filepath.Join(dir, "1.msg.tmp")); !os.IsNotExist(err) {
5454+ t.Errorf("leftover .tmp after failed rename: stat err=%v", err)
5555+ }
5656+}
5757+5858+// TestSpoolExists reflects the contract used by the orphan
5959+// reconciliation janitor: only present files report true.
6060+func TestSpoolExists(t *testing.T) {
6161+ dir := t.TempDir()
6262+ s := NewSpool(dir)
6363+ if s.Exists(99) {
6464+ t.Error("Exists(99) returned true on empty spool")
6565+ }
6666+ if err := s.Write(&QueueEntry{ID: 99, From: "a@b", To: "c@d", Data: []byte("x")}); err != nil {
6767+ t.Fatalf("Write: %v", err)
6868+ }
6969+ if !s.Exists(99) {
7070+ t.Error("Exists(99) returned false after Write")
7171+ }
7272+ if err := s.Remove(99); err != nil {
7373+ t.Fatalf("Remove: %v", err)
7474+ }
7575+ if s.Exists(99) {
7676+ t.Error("Exists(99) returned true after Remove")
7777+ }
7878+}
7979+8080+// TestSpoolRemove_IdempotentOnMissing — the janitor calls Remove on
8181+// completion paths; missing-file is not an error.
8282+func TestSpoolRemove_IdempotentOnMissing(t *testing.T) {
8383+ dir := t.TempDir()
8484+ s := NewSpool(dir)
8585+ if err := s.Remove(123); err != nil {
8686+ t.Errorf("Remove on missing returned error: %v", err)
8787+ }
8888+}
8989+9090+// TestSpoolWriteRoundTrip confirms the data we wrote is what LoadAll
9191+// returns. Existing tests do this for the legacy code path; the
9292+// fsync rewrite must preserve byte-for-byte fidelity.
9393+func TestSpoolWriteRoundTrip(t *testing.T) {
9494+ dir := t.TempDir()
9595+ s := NewSpool(dir)
9696+ want := &QueueEntry{ID: 7, From: "from@x", To: "to@y", Data: []byte("hello world"), MemberDID: "did:plc:test", Attempts: 2}
9797+ if err := s.Write(want); err != nil {
9898+ t.Fatalf("Write: %v", err)
9999+ }
100100+ loaded, err := s.LoadAll()
101101+ if err != nil {
102102+ t.Fatalf("LoadAll: %v", err)
103103+ }
104104+ if len(loaded) != 1 {
105105+ t.Fatalf("LoadAll returned %d entries, want 1", len(loaded))
106106+ }
107107+ got := loaded[0]
108108+ if got.ID != want.ID || got.From != want.From || got.To != want.To ||
109109+ string(got.Data) != string(want.Data) || got.MemberDID != want.MemberDID || got.Attempts != want.Attempts {
110110+ t.Errorf("round-trip mismatch:\n got= %+v\n want= %+v", got, want)
111111+ }
112112+}
internal/relay/warmup.go
···1414// Bypasses rate limiting and suppression since these are operator-initiated
1515// sends to known seed addresses.
1616type WarmupSender struct {
1717- seedAddresses []string
1818- memberLookup func(ctx context.Context, did string) (*MemberWithDomains, error)
1919- queue *Queue
2020- operatorKeys *DKIMKeys
1717+ seedAddresses []string
1818+ fromLocalParts []string
1919+ memberLookup func(ctx context.Context, did string) (*MemberWithDomains, error)
2020+ queue *Queue
2121+ operatorKeys *DKIMKeys
2122 operatorDKIMDomain string
2222- relayDomain string
2323+ relayDomain string
23242425 insertMessage func(ctx context.Context, did, from, to, msgID string) (int64, error)
2526 incrSendCount func(ctx context.Context, did string)
···2829// WarmupConfig configures the warmup sender.
2930type WarmupConfig struct {
3031 SeedAddresses []string
3232+ FromLocalParts []string // local parts to rotate (default ["scott"])
3133 MemberLookup func(ctx context.Context, did string) (*MemberWithDomains, error)
3234 Queue *Queue
3335 OperatorKeys *DKIMKeys
···3840}
39414042func NewWarmupSender(cfg WarmupConfig) *WarmupSender {
4343+ fromParts := cfg.FromLocalParts
4444+ if len(fromParts) == 0 {
4545+ fromParts = []string{"scott"}
4646+ }
4147 return &WarmupSender{
4248 seedAddresses: cfg.SeedAddresses,
4949+ fromLocalParts: fromParts,
4350 memberLookup: cfg.MemberLookup,
4451 queue: cfg.Queue,
4552 operatorKeys: cfg.OperatorKeys,
···5966 Errors []string `json:"errors,omitempty"`
6067}
61686969+// SendOne sends a single warmup email to the given seed address on behalf of
7070+// the member DID. Template and From address are selected by recipientIdx to
7171+// ensure variety across recipients within a batch.
7272+func (w *WarmupSender) SendOne(ctx context.Context, did string, recipientIdx int) (*WarmupResult, error) {
7373+ if recipientIdx < 0 || recipientIdx >= len(w.seedAddresses) {
7474+ return nil, fmt.Errorf("recipient index %d out of range [0, %d)", recipientIdx, len(w.seedAddresses))
7575+ }
7676+7777+ member, err := w.memberLookup(ctx, did)
7878+ if err != nil {
7979+ return nil, fmt.Errorf("member lookup: %w", err)
8080+ }
8181+ if member == nil || len(member.Domains) == 0 {
8282+ return nil, fmt.Errorf("member %s not found or has no domains", did)
8383+ }
8484+8585+ domain := member.Domains[0]
8686+ to := w.seedAddresses[recipientIdx]
8787+ fromLocal := w.fromLocalParts[recipientIdx%len(w.fromLocalParts)]
8888+ from := fromLocal + "@" + domain.Domain
8989+9090+ templates := warmupTemplates()
9191+ tmpl := templates[recipientIdx%len(templates)]
9292+9393+ msgID := fmt.Sprintf("<%d.warmup@%s>", time.Now().UnixNano(), w.relayDomain)
9494+ msg := buildWarmupMessage(from, to, msgID, tmpl)
9595+9696+ result := &WarmupResult{}
9797+ if err := w.sendMessage(ctx, did, from, to, msgID, msg, domain); err != nil {
9898+ result.Failed = 1
9999+ result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", to, err))
100100+ } else {
101101+ result.Sent = 1
102102+ }
103103+ return result, nil
104104+}
105105+62106// SendBatch sends one warmup email to each seed address on behalf of the
6363-// given member DID. Returns the number sent and any per-recipient errors.
107107+// given member DID. Template and From address vary per recipient.
108108+// Returns the number sent and any per-recipient errors.
64109func (w *WarmupSender) SendBatch(ctx context.Context, did string) (*WarmupResult, error) {
65110 if len(w.seedAddresses) == 0 {
66111 return nil, fmt.Errorf("no warmup seed addresses configured")
···75120 }
7612177122 domain := member.Domains[0]
7878- from := "postmaster@" + domain.Domain
7979-123123+ templates := warmupTemplates()
80124 result := &WarmupResult{}
8181- for _, to := range w.seedAddresses {
8282- msgID := fmt.Sprintf("<%d.warmup@%s>", time.Now().UnixNano(), w.relayDomain)
8383- msg := buildWarmupMessage(from, to, msgID, domain.Domain)
841258585- verpFrom := VERPReturnPath(did, to, w.relayDomain)
126126+ for i, to := range w.seedAddresses {
127127+ fromLocal := w.fromLocalParts[i%len(w.fromLocalParts)]
128128+ from := fromLocal + "@" + domain.Domain
129129+ tmpl := templates[i%len(templates)]
861308787- raw := []byte(msg)
8888- stamped := append([]byte("X-Atmos-Member-Did: "+did+"\r\n"), raw...)
8989- stamped = PrependFeedbackID(stamped, "transactional", did, domain.Domain)
131131+ msgID := fmt.Sprintf("<%d.warmup@%s>", time.Now().UnixNano(), w.relayDomain)
132132+ msg := buildWarmupMessage(from, to, msgID, tmpl)
901339191- signer := NewDualDomainSigner(domain.DKIMKeys, w.operatorKeys, domain.Domain, w.operatorDKIMDomain)
9292- signed, err := signer.Sign(strings.NewReader(string(stamped)))
9393- if err != nil {
134134+ if err := w.sendMessage(ctx, did, from, to, msgID, msg, domain); err != nil {
94135 result.Failed++
9595- result.Errors = append(result.Errors, fmt.Sprintf("%s: DKIM sign: %v", to, err))
9696- continue
136136+ result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", to, err))
137137+ } else {
138138+ result.Sent++
97139 }
140140+ }
981419999- entryID := int64(0)
100100- if w.insertMessage != nil {
101101- id, err := w.insertMessage(ctx, did, from, to, msgID)
102102- if err != nil {
103103- log.Printf("warmup.insert_message: did=%s to=%s error=%v", did, to, err)
104104- } else {
105105- entryID = id
106106- }
107107- }
108108- if w.incrSendCount != nil {
109109- w.incrSendCount(ctx, did)
110110- }
142142+ return result, nil
143143+}
144144+145145+func (w *WarmupSender) sendMessage(ctx context.Context, did, from, to, msgID, msg string, domain DomainInfo) error {
146146+ verpFrom := VERPReturnPath(did, to, w.relayDomain)
111147112112- if err := w.queue.Enqueue(&QueueEntry{
113113- ID: entryID,
114114- From: verpFrom,
115115- To: to,
116116- Data: signed,
117117- MemberDID: did,
118118- }); err != nil {
119119- result.Failed++
120120- result.Errors = append(result.Errors, fmt.Sprintf("%s: enqueue: %v", to, err))
121121- continue
148148+ raw := []byte(msg)
149149+ stamped := append([]byte("X-Atmos-Member-Did: "+did+"\r\n"), raw...)
150150+ stamped = PrependFeedbackID(stamped, "transactional", did, domain.Domain)
151151+152152+ signer := NewDualDomainSigner(domain.DKIMKeys, w.operatorKeys, domain.Domain, w.operatorDKIMDomain)
153153+ signed, err := signer.Sign(strings.NewReader(string(stamped)))
154154+ if err != nil {
155155+ return fmt.Errorf("DKIM sign: %w", err)
156156+ }
157157+158158+ entryID := int64(0)
159159+ if w.insertMessage != nil {
160160+ id, err := w.insertMessage(ctx, did, from, to, msgID)
161161+ if err != nil {
162162+ log.Printf("warmup.insert_message: did=%s to=%s error=%v", did, to, err)
163163+ } else {
164164+ entryID = id
122165 }
166166+ }
167167+ if w.incrSendCount != nil {
168168+ w.incrSendCount(ctx, did)
169169+ }
123170124124- result.Sent++
125125- log.Printf("warmup.queued: did=%s to=%s msg_id=%s", did, to, msgID)
171171+ if err := w.queue.Enqueue(&QueueEntry{
172172+ ID: entryID,
173173+ From: verpFrom,
174174+ To: to,
175175+ Data: signed,
176176+ MemberDID: did,
177177+ }); err != nil {
178178+ return fmt.Errorf("enqueue: %w", err)
126179 }
127180128128- return result, nil
181181+ log.Printf("warmup.queued: did=%s from=%s to=%s msg_id=%s", did, from, to, msgID)
182182+ return nil
129183}
130184131185type warmupTemplate struct {
···133187 body string
134188}
135189136136-func warmupTemplates(domain string) []warmupTemplate {
190190+func warmupTemplates() []warmupTemplate {
137191 return []warmupTemplate{
138192 {
139139- subject: "Re: setting up email for " + domain,
140140- body: "Hi,\r\n\r\n" +
141141- "Just following up — the email configuration for " + domain + " is all set. DKIM signatures are being applied correctly and everything looks good on our end.\r\n\r\n" +
142142- "Let me know if you run into any issues or have questions about the setup.\r\n\r\n" +
143143- "Best,\r\n" +
193193+ subject: "Thursday lunch spot",
194194+ body: "Hey,\r\n\r\n" +
195195+ "Are you free Thursday? I was thinking we could try that new place on 4th. I heard they do a good cubano.\r\n\r\n" +
196196+ "Let me know — I can reserve a table if we go around noon.\r\n\r\n" +
197197+ "Scott",
198198+ },
199199+ {
200200+ subject: "Re: that article you sent",
201201+ body: "Just read through it — really interesting point about how federated systems handle trust differently than centralized ones. " +
202202+ "The section on cooperative infrastructure reminded me of some things we've been thinking about.\r\n\r\n" +
203203+ "Have you seen the follow-up post the author did? I'll dig up the link.\r\n\r\n" +
144204 "Scott",
145205 },
146206 {
147147- subject: "Quick note about " + domain,
207207+ subject: "Weekend plans?",
148208 body: "Hey,\r\n\r\n" +
149149- "Wanted to let you know that " + domain + " is fully configured and sending through the relay. The DKIM and SPF records are aligned, so messages should be landing in inboxes without any trouble.\r\n\r\n" +
150150- "The cooperative relay model means your domain benefits from shared reputation across all members, which is especially helpful for newer domains that haven't built up their own sending history yet.\r\n\r\n" +
151151- "Thanks,\r\n" +
209209+ "Any plans this weekend? I was going to do a hike if the weather holds up. The forecast looks decent but you never know around here.\r\n\r\n" +
210210+ "Also — I finally finished that book you recommended. The ending was not what I expected. We should talk about it.\r\n\r\n" +
211211+ "Scott",
212212+ },
213213+ {
214214+ subject: "quick favor",
215215+ body: "Hey, can you send me that recipe you mentioned last time? " +
216216+ "The one with the roasted peppers. I want to try making it this week.\r\n\r\n" +
217217+ "Thanks!\r\n" +
218218+ "Scott",
219219+ },
220220+ {
221221+ subject: "Re: meeting notes",
222222+ body: "Thanks for sending these over. I think the timeline in section 3 is a bit aggressive but everything else looks right to me.\r\n\r\n" +
223223+ "One thought — should we loop in the design team before we commit to the API contract? " +
224224+ "Might save us a round of changes later.\r\n\r\n" +
225225+ "Let me know what you think.\r\n\r\n" +
226226+ "Scott",
227227+ },
228228+ {
229229+ subject: "coffee machine recs",
230230+ body: "I'm finally replacing my old drip machine. Do you still like your Breville? " +
231231+ "I've been going back and forth between that and just getting a simple pour-over setup.\r\n\r\n" +
232232+ "Budget is flexible but I don't want something that takes 20 minutes to clean.\r\n\r\n" +
152233 "Scott",
153234 },
154235 {
155155- subject: domain + " is looking good",
156156- body: "Hi,\r\n\r\n" +
157157- "Everything is running well for " + domain + ". Wanted to drop a quick note to confirm that outbound messages are being signed and delivered as expected.\r\n\r\n" +
158158- "One thing worth mentioning — each message gets two DKIM signatures: one for your domain and one for the relay pool. This gives receiving mail servers two independent ways to verify authenticity, which generally helps with inbox placement.\r\n\r\n" +
159159- "Cheers,\r\n" +
236236+ subject: "Saw this and thought of you",
237237+ body: "There's a talk at the library next Tuesday about local history — the speaker is that author who wrote the book about the old rail lines. " +
238238+ "Starts at 7pm. Free admission.\r\n\r\n" +
239239+ "Want to go? I can drive.\r\n\r\n" +
240240+ "Scott",
241241+ },
242242+ {
243243+ subject: "Re: printer issue",
244244+ body: "Try power cycling it — unplug for 30 seconds, then plug back in. " +
245245+ "If that doesn't work, check if there's a firmware update. Mine had the same problem and updating fixed it.\r\n\r\n" +
246246+ "If it's still stuck after that let me know and I'll come take a look.\r\n\r\n" +
160247 "Scott",
161248 },
162249 {
163163- subject: "Checking in — " + domain,
164164- body: "Hey,\r\n\r\n" +
165165- "Just checking in on " + domain + ". The mail pipeline is healthy and I don't see any issues on our side.\r\n\r\n" +
166166- "If you've been seeing good deliverability, that's great — the shared IP reputation pool is working as intended. If anything looks off, just let me know and I can take a closer look at the logs.\r\n\r\n" +
167167- "Best,\r\n" +
250250+ subject: "Happy birthday!",
251251+ body: "Hope you have a great one today! Any big plans?\r\n\r\n" +
252252+ "We should get dinner sometime this week to celebrate. My treat.\r\n\r\n" +
168253 "Scott",
169254 },
170255 {
171171- subject: "All good with " + domain,
172172- body: "Hi,\r\n\r\n" +
173173- "Touching base to confirm " + domain + " is in good shape. The relay is processing your outbound mail normally, and authentication records are passing validation.\r\n\r\n" +
174174- "For context, Atmosphere Mail is a cooperative relay built for the AT Protocol ecosystem. The idea is that smaller self-hosted services can share IP reputation instead of each one starting from scratch with a cold IP address. Happy to answer any questions about how it works.\r\n\r\n" +
175175- "Thanks,\r\n" +
256256+ subject: "parking situation tomorrow",
257257+ body: "Heads up — they're doing construction on the south lot tomorrow so we'll need to use the garage on 2nd. " +
258258+ "I'd get there a bit early, it fills up fast.\r\n\r\n" +
259259+ "See you there.\r\n\r\n" +
176260 "Scott",
177261 },
178262 }
179263}
180264181181-func buildWarmupMessage(from, to, msgID, domain string) string {
182182- templates := warmupTemplates(domain)
183183- idx := int(time.Now().Unix()/60) % len(templates)
184184- t := templates[idx]
185185-265265+func buildWarmupMessage(from, to, msgID string, t warmupTemplate) string {
186266 return strings.Join([]string{
187267 "From: " + from,
188268 "To: " + to,
+240
internal/relay/warmup_scheduler.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relay
44+55+import (
66+ "context"
77+ "log"
88+ "math/rand/v2"
99+ "sync"
1010+ "time"
1111+)
1212+1313+// MemberWarmupCandidate carries the per-member info the scheduler needs
1414+// to make a fair selection (#219). DID is required; CreatedAt is used to
1515+// boost newly-enrolled members so they reach mailbox-provider visibility
1616+// faster than long-tenured ones who already have a sending history.
1717+type MemberWarmupCandidate struct {
1818+ DID string
1919+ CreatedAt time.Time
2020+}
2121+2222+// WarmupScheduler drips warmup sends across the day instead of firing
2323+// them all at once. Each tick sends one email to one seed address for
2424+// one member, then waits before the next. This produces the organic
2525+// send pattern that mailbox providers expect from real human senders.
2626+//
2727+// Selection is rotation-fair (every eligible member gets warmed up
2828+// before any one repeats) with a tiebreaker that prefers newly-enrolled
2929+// members so a long-tenured member can't crowd out new enrollees on
3030+// the first iteration through the pool. See #219.
3131+type WarmupScheduler struct {
3232+ sender *WarmupSender
3333+ listCandidates func(ctx context.Context) ([]MemberWarmupCandidate, error)
3434+ interval time.Duration
3535+ jitter time.Duration
3636+3737+ mu sync.Mutex
3838+ running bool
3939+ cancelFunc context.CancelFunc
4040+4141+ // lastWarmedUp tracks per-DID lastWarmupAt for fairness. Process-local;
4242+ // a restart resets the rotation, which at worst over-warms one member
4343+ // before the wheel re-balances. Persistence is a future enhancement
4444+ // covered by the spec note "Track 'last_warmup_at' per member" but is
4545+ // not required for the fairness invariant within a process lifetime.
4646+ lastMu sync.Mutex
4747+ lastWarmedUp map[string]time.Time
4848+ now func() time.Time
4949+}
5050+5151+// WarmupSchedulerConfig configures the background warmup scheduler.
5252+type WarmupSchedulerConfig struct {
5353+ Sender *WarmupSender
5454+5555+ // ListCandidates returns active warmup-eligible members with their
5656+ // enrollment timestamps. Preferred over ListDIDs because it lets the
5757+ // fairness algorithm boost newly-enrolled members.
5858+ ListCandidates func(ctx context.Context) ([]MemberWarmupCandidate, error)
5959+6060+ // ListDIDs is the legacy callback returning DIDs without timestamps.
6161+ // When ListCandidates is nil the scheduler falls back to ListDIDs;
6262+ // the resulting candidates have CreatedAt = zero, which makes them
6363+ // equivalent for the boost tiebreaker. Members still rotate via
6464+ // lastWarmupAt tracking.
6565+ ListDIDs func(ctx context.Context) ([]string, error)
6666+6767+ Interval time.Duration // base time between sends (default 20min)
6868+ Jitter time.Duration // max random jitter (default 10min)
6969+7070+ // Now overrides time.Now for tests; defaults to time.Now.
7171+ Now func() time.Time
7272+}
7373+7474+func NewWarmupScheduler(cfg WarmupSchedulerConfig) *WarmupScheduler {
7575+ interval := cfg.Interval
7676+ if interval == 0 {
7777+ interval = 20 * time.Minute
7878+ }
7979+ jitter := cfg.Jitter
8080+ if jitter == 0 {
8181+ jitter = 10 * time.Minute
8282+ }
8383+ now := cfg.Now
8484+ if now == nil {
8585+ now = time.Now
8686+ }
8787+8888+ listCandidates := cfg.ListCandidates
8989+ if listCandidates == nil && cfg.ListDIDs != nil {
9090+ legacy := cfg.ListDIDs
9191+ listCandidates = func(ctx context.Context) ([]MemberWarmupCandidate, error) {
9292+ dids, err := legacy(ctx)
9393+ if err != nil {
9494+ return nil, err
9595+ }
9696+ out := make([]MemberWarmupCandidate, len(dids))
9797+ for i, d := range dids {
9898+ out[i] = MemberWarmupCandidate{DID: d}
9999+ }
100100+ return out, nil
101101+ }
102102+ }
103103+104104+ return &WarmupScheduler{
105105+ sender: cfg.Sender,
106106+ listCandidates: listCandidates,
107107+ interval: interval,
108108+ jitter: jitter,
109109+ lastWarmedUp: map[string]time.Time{},
110110+ now: now,
111111+ }
112112+}
113113+114114+// Start begins the background warmup loop. Safe to call multiple times;
115115+// subsequent calls are no-ops if already running.
116116+func (s *WarmupScheduler) Start(ctx context.Context) {
117117+ s.mu.Lock()
118118+ defer s.mu.Unlock()
119119+ if s.running {
120120+ return
121121+ }
122122+ s.running = true
123123+ ctx, s.cancelFunc = context.WithCancel(ctx)
124124+ go s.loop(ctx)
125125+ log.Printf("warmup.scheduler: started interval=%s jitter=%s seeds=%d",
126126+ s.interval, s.jitter, s.sender.SeedCount())
127127+}
128128+129129+// Stop halts the background warmup loop.
130130+func (s *WarmupScheduler) Stop() {
131131+ s.mu.Lock()
132132+ defer s.mu.Unlock()
133133+ if !s.running {
134134+ return
135135+ }
136136+ s.cancelFunc()
137137+ s.running = false
138138+ log.Printf("warmup.scheduler: stopped")
139139+}
140140+141141+func (s *WarmupScheduler) loop(ctx context.Context) {
142142+ defer func() {
143143+ s.mu.Lock()
144144+ s.running = false
145145+ s.mu.Unlock()
146146+ }()
147147+148148+ for {
149149+ wait := s.interval + time.Duration(rand.Int64N(int64(s.jitter)))
150150+ select {
151151+ case <-ctx.Done():
152152+ return
153153+ case <-time.After(wait):
154154+ s.tick(ctx)
155155+ }
156156+ }
157157+}
158158+159159+func (s *WarmupScheduler) tick(ctx context.Context) {
160160+ candidates, err := s.listCandidates(ctx)
161161+ if err != nil {
162162+ log.Printf("warmup.scheduler: list members: %v", err)
163163+ return
164164+ }
165165+ if len(candidates) == 0 {
166166+ return
167167+ }
168168+169169+ seedCount := s.sender.SeedCount()
170170+ if seedCount == 0 {
171171+ return
172172+ }
173173+174174+ picked := s.SelectMember(candidates)
175175+ recipientIdx := rand.IntN(seedCount)
176176+177177+ result, err := s.sender.SendOne(ctx, picked.DID, recipientIdx)
178178+ if err != nil {
179179+ log.Printf("warmup.scheduler: did=%s error=%v", picked.DID, err)
180180+ return
181181+ }
182182+183183+ if result.Sent > 0 {
184184+ s.recordWarmup(picked.DID)
185185+ log.Printf("warmup.scheduler: did=%s seed=%d sent=1", picked.DID, recipientIdx)
186186+ }
187187+ if result.Failed > 0 {
188188+ log.Printf("warmup.scheduler: did=%s seed=%d failed=1 errors=%v", picked.DID, recipientIdx, result.Errors)
189189+ }
190190+}
191191+192192+// SelectMember picks the candidate most due for a warmup send. Exported
193193+// so tests can pin the fairness invariant directly.
194194+//
195195+// Algorithm: oldest lastWarmupAt wins (rotation fairness — every member
196196+// gets warmed up before any single one repeats). Tiebreaker: newest
197197+// CreatedAt wins (boost recent enrollees so a flood of pre-existing
198198+// members can't starve a new one). Members never warmed up have
199199+// lastWarmupAt = zero, which always sorts before any non-zero time, so
200200+// they're always picked before re-warming an already-warmed member.
201201+func (s *WarmupScheduler) SelectMember(candidates []MemberWarmupCandidate) MemberWarmupCandidate {
202202+ if len(candidates) == 0 {
203203+ return MemberWarmupCandidate{}
204204+ }
205205+206206+ s.lastMu.Lock()
207207+ defer s.lastMu.Unlock()
208208+209209+ best := candidates[0]
210210+ bestLast := s.lastWarmedUp[best.DID]
211211+212212+ for _, c := range candidates[1:] {
213213+ last := s.lastWarmedUp[c.DID]
214214+ switch {
215215+ case last.Before(bestLast):
216216+ best, bestLast = c, last
217217+ case last.Equal(bestLast):
218218+ if c.CreatedAt.After(best.CreatedAt) {
219219+ best, bestLast = c, last
220220+ }
221221+ }
222222+ }
223223+ return best
224224+}
225225+226226+// recordWarmup stamps the lastWarmupAt for a DID after a successful send.
227227+func (s *WarmupScheduler) recordWarmup(did string) {
228228+ s.lastMu.Lock()
229229+ defer s.lastMu.Unlock()
230230+ s.lastWarmedUp[did] = s.now()
231231+}
232232+233233+// LastWarmedUp returns the last-warmup timestamp for a DID. Returns the
234234+// zero time if the DID has never been warmed up by this scheduler.
235235+// Test/diagnostic helper.
236236+func (s *WarmupScheduler) LastWarmedUp(did string) time.Time {
237237+ s.lastMu.Lock()
238238+ defer s.lastMu.Unlock()
239239+ return s.lastWarmedUp[did]
240240+}
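
How the scheduler gets wired from `cmd/relay` isn't part of this diff; the sketch below shows one plausible shape, assuming a hypothetical `store.ListActiveMembers` helper and eliding imports. The config fields, candidate struct, and `Start` call are the ones defined above.

```go
// Sketch only: wiring the drip scheduler at relay startup.
// store.ListActiveMembers is an assumption; substitute whatever
// cmd/relay already uses to enumerate active members.
func startWarmup(ctx context.Context, sender *relay.WarmupSender, store *relaystore.Store) *relay.WarmupScheduler {
	sched := relay.NewWarmupScheduler(relay.WarmupSchedulerConfig{
		Sender: sender,
		ListCandidates: func(ctx context.Context) ([]relay.MemberWarmupCandidate, error) {
			members, err := store.ListActiveMembers(ctx) // hypothetical helper
			if err != nil {
				return nil, err
			}
			out := make([]relay.MemberWarmupCandidate, 0, len(members))
			for _, m := range members {
				out = append(out, relay.MemberWarmupCandidate{DID: m.DID, CreatedAt: m.CreatedAt})
			}
			return out, nil
		},
		// Interval and Jitter left zero: the constructor falls back to 20 min / 10 min.
	})
	sched.Start(ctx)
	return sched
}
```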
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relaystore
44+55+import (
66+ "context"
77+ "strings"
88+ "time"
99+)
1010+1111+// ListBypassAuditForTest exposes the bypass_audit table to tests
1212+// outside the relaystore package (the admin package's bypass tests
1313+// need to assert audit rows). Production code should not call this;
1414+// the ForTest suffix marks it as a test-only, unsupported API.
1515+func (s *Store) ListBypassAuditForTest(ctx context.Context, did string) ([]BypassAuditEntry, error) {
1616+ rows, err := s.db.QueryContext(ctx,
1717+ `SELECT id, did, action, reason, expires_at, created_at
1818+ FROM bypass_audit WHERE did = ? ORDER BY id ASC`, did,
1919+ )
2020+ if err != nil {
2121+ return nil, err
2222+ }
2323+ defer rows.Close()
2424+ var out []BypassAuditEntry
2525+ for rows.Next() {
2626+ var e BypassAuditEntry
2727+ var expStr, createdStr string
2828+ if err := rows.Scan(&e.ID, &e.DID, &e.Action, &e.Reason, &expStr, &createdStr); err != nil {
2929+ return nil, err
3030+ }
3131+ if expStr != "" {
3232+ if t, err := time.Parse(time.RFC3339Nano, expStr); err == nil {
3333+ e.ExpiresAt = t
3434+ }
3535+ }
3636+ if t, err := time.Parse(time.RFC3339Nano, createdStr); err == nil {
3737+ e.CreatedAt = t
3838+ }
3939+ out = append(out, e)
4040+ }
4141+ return out, rows.Err()
4242+}
4343+4444+// BusyRecorder is the narrow interface the Store needs to count
4545+// SQLITE_BUSY errors at hot-path writers, without taking a hard
4646+// dependency on Prometheus types in the relaystore package.
4747+// *relay.Metrics implements this; cmd/relay wires it via
4848+// SetBusyRecorder during construction.
4949+type BusyRecorder interface {
5050+ IncBusyError(op string)
5151+}
5252+5353+// SetBusyRecorder installs the busy-error recorder. Calling more
5454+// than once replaces the previous recorder; safe to call once
5555+// during wiring before the Store sees any traffic.
5656+func (s *Store) SetBusyRecorder(r BusyRecorder) {
5757+ s.busyRecorder = r
5858+}
5959+6060+// recordIfBusy is a small helper that callers use to classify a
6161+// freshly-returned error from a sql.DB call. Returns the error
6262+// unchanged so call sites can chain it inline.
6363+func (s *Store) recordIfBusy(op string, err error) error {
6464+ if err != nil && s.busyRecorder != nil && IsSQLiteBusy(err) {
6565+ s.busyRecorder.IncBusyError(op)
6666+ }
6767+ return err
6868+}
6969+7070+// PoolStats is a Store-level snapshot of *sql.DB pool counters
7171+// suitable for emitting as Prometheus gauges. Returned each time
7272+// SampleStats is called so a caller in cmd/relay can drive a
7373+// periodic update loop without exposing *sql.DB outside the
7474+// package.
7575+type PoolStats struct {
7676+ OpenConnections int
7777+ InUse int
7878+ Idle int
7979+ WaitCount int64
8080+ WaitDurationSecond float64
8181+}
8282+8383+// SampleStats reads sql.DB.Stats() and converts it into a
8484+// transport-friendly snapshot. Cheap to call (atomic loads under
8585+// the hood); cmd/relay should poll this every ~10s and forward
8686+// the values into the Prometheus gauges defined in
8787+// internal/relay/metrics.go.
8888+//
8989+// Why we expose sql.DB.Stats() rather than wrapping every Exec /
9090+// Query call site: the relay has 90+ DB call sites across the
9191+// store package, and SQLITE_BUSY errors that escape the 5s
9292+// busy_timeout PRAGMA are rare. The pool stats are a near-perfect
9393+// proxy: WaitCount climbing means contention is brewing, even
9494+// before any error escapes; InUse near MaxOpenConns means the
9595+// next caller will wait. Combined with BusyErrorClassify on the
9696+// hot writers, this gives operators a complete picture without
9797+// touching every callsite. Closes #210.
9898+func (s *Store) SampleStats() PoolStats {
9999+ st := s.db.Stats()
100100+ return PoolStats{
101101+ OpenConnections: st.OpenConnections,
102102+ InUse: st.InUse,
103103+ Idle: st.Idle,
104104+ WaitCount: st.WaitCount,
105105+ WaitDurationSecond: st.WaitDuration.Seconds(),
106106+ }
107107+}
108108+109109+// IsSQLiteBusy reports whether an error returned from modernc/sqlite
110110+// is a SQLITE_BUSY or locked condition. modernc does NOT export a
111111+// typed sentinel for these (the official driver wraps them as
112112+// *sqlite.Error but the value is unexported), so we fall back to
113113+// a robust substring match on the well-known reason strings.
114114+//
115115+// Used by store-level helpers to increment metrics.SQLiteBusyErrors
116116+// at the hot-path writers (InsertMessage, UpdateMessageStatus,
117117+// IncrementSendCount, RecordRateCount). False positives are
118118+// effectively impossible: these phrases are reserved by SQLite for
119119+// busy/locked conditions and don't appear in unrelated driver errors.
120120+func IsSQLiteBusy(err error) bool {
121121+ if err == nil {
122122+ return false
123123+ }
124124+ s := strings.ToLower(err.Error())
125125+ switch {
126126+ case strings.Contains(s, "database is locked"):
127127+ return true
128128+ case strings.Contains(s, "database table is locked"):
129129+ return true
130130+ case strings.Contains(s, "sqlite_busy"):
131131+ return true
132132+ case strings.Contains(s, "(5)"):
133133+ // modernc surfaces SQLITE_BUSY = 5 with a "(5)" suffix.
134134+ // Bare "(5)" matches too eagerly on its own; require a
135135+ // "busy"/"locked" keyword to also be present, avoiding false positives
136136+ // against e.g. constraint codes.
137137+ return strings.Contains(s, "busy") || strings.Contains(s, "locked")
138138+ }
139139+ return false
140140+}
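
The `SampleStats` comment above asks `cmd/relay` to poll roughly every 10 s and push the snapshot into the gauges in `internal/relay/metrics.go`. A minimal sketch of that loop, with imports elided; `SetDBPoolGauges` is a placeholder name for however the metrics type actually accepts the values.

```go
// Sketch: periodic pool-stats export. SetDBPoolGauges is a placeholder
// setter; the real gauges may well be updated field by field instead.
func pollPoolStats(ctx context.Context, store *relaystore.Store, m *relay.Metrics) {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			ps := store.SampleStats()
			m.SetDBPoolGauges(ps.OpenConnections, ps.InUse, ps.Idle, ps.WaitCount, ps.WaitDurationSecond)
		}
	}
}
```

Inside the store package itself, the hot writers chain the classifier inline, e.g. `return s.recordIfBusy("insert_message", err)` (the op label is illustrative).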
+151
internal/relaystore/observability_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relaystore
44+55+import (
66+ "context"
77+ "errors"
88+ "sync"
99+ "sync/atomic"
1010+ "testing"
1111+ "time"
1212+)
1313+1414+func TestIsSQLiteBusy(t *testing.T) {
1515+ cases := []struct {
1616+ name string
1717+ err error
1818+ want bool
1919+ }{
2020+ {"nil", nil, false},
2121+ {"unrelated", errors.New("foreign key constraint failed"), false},
2222+ {"locked plain", errors.New("database is locked"), true},
2323+ {"locked uppercase", errors.New("Database Is Locked"), true},
2424+ {"table locked", errors.New("database table is locked"), true},
2525+ {"sqlite_busy", errors.New("SQLITE_BUSY: cannot start a transaction"), true},
2626+ {"modernc paren-5 + busy", errors.New("database busy (5)"), true},
2727+ {"modernc paren-5 + locked", errors.New("locked (5)"), true},
2828+ {"paren-5 alone is NOT busy", errors.New("constraint code (5)"), false},
2929+ {"wrapped", errors.New("insert message: database is locked"), true},
3030+ }
3131+ for _, tc := range cases {
3232+ t.Run(tc.name, func(t *testing.T) {
3333+ if got := IsSQLiteBusy(tc.err); got != tc.want {
3434+ t.Errorf("IsSQLiteBusy(%q) = %v, want %v", tc.err, got, tc.want)
3535+ }
3636+ })
3737+ }
3838+}
3939+4040+// stubBusyRecorder counts IncBusyError calls so tests can verify
4141+// the store wiring without dragging Prometheus types into the test.
4242+type stubBusyRecorder struct {
4343+ mu sync.Mutex
4444+ calls map[string]int
4545+}
4646+4747+func newStubBusyRecorder() *stubBusyRecorder {
4848+ return &stubBusyRecorder{calls: map[string]int{}}
4949+}
5050+5151+func (s *stubBusyRecorder) IncBusyError(op string) {
5252+ s.mu.Lock()
5353+ s.calls[op]++
5454+ s.mu.Unlock()
5555+}
5656+5757+func (s *stubBusyRecorder) count(op string) int {
5858+ s.mu.Lock()
5959+ defer s.mu.Unlock()
6060+ return s.calls[op]
6161+}
6262+6363+// TestSampleStats_ReturnsZeroOnFreshStore confirms the cheap-path
6464+// invariant: SampleStats is safe to call before any traffic and
6565+// returns sane zero-ish values rather than panicking on an
6666+// uninitialized pool.
6767+func TestSampleStats_ReturnsZeroOnFreshStore(t *testing.T) {
6868+ s := testStore(t)
6969+ ps := s.SampleStats()
7070+ // OpenConnections may be 0 or 1 depending on whether testStore
7171+ // pre-pinged. Just assert the shape didn't panic + returns
7272+ // non-negative values.
7373+ if ps.OpenConnections < 0 || ps.InUse < 0 || ps.Idle < 0 || ps.WaitCount < 0 {
7474+ t.Fatalf("negative stats: %+v", ps)
7575+ }
7676+}
7777+7878+// TestSampleStats_TracksInUse pins a pooled connection and verifies
7979+// SampleStats observes InUse > 0 while the connection is held.
8080+// This confirms the gauge has any signal at all (not just hard-
8181+// zero) when contention is occurring.
8282+func TestSampleStats_TracksInUse(t *testing.T) {
8383+ s := testStore(t)
8484+ // Hold a connection open in a goroutine until the test
8585+ // has had a chance to sample.
8686+ release := make(chan struct{})
8787+ started := make(chan struct{})
8888+ var counted atomic.Int32
8989+ go func() {
9090+ // Pin a connection from the pool; hold it until released.
9191+ conn, err := s.db.Conn(context.Background())
9292+ if err != nil {
9393+ t.Errorf("get conn: %v", err)
9494+ return
9595+ }
9696+ defer conn.Close()
9797+ close(started)
9898+ <-release
9999+ counted.Store(1)
100100+ }()
101101+ <-started
102102+ defer close(release)
103103+104104+ // Allow the sql.DB pool to register the open connection.
105105+ deadline := time.Now().Add(time.Second)
106106+ for time.Now().Before(deadline) {
107107+ if s.SampleStats().InUse >= 1 {
108108+ return
109109+ }
110110+ time.Sleep(5 * time.Millisecond)
111111+ }
112112+ t.Errorf("SampleStats().InUse never reached >= 1; final: %+v", s.SampleStats())
113113+}
114114+115115+// TestStore_BusyRecorder_OptionalNilSafe confirms callers can use
116116+// the store without ever wiring a recorder. The recordIfBusy path
117117+// must short-circuit on nil rather than panicking.
118118+func TestStore_BusyRecorder_OptionalNilSafe(t *testing.T) {
119119+ s := testStore(t)
120120+ // No SetBusyRecorder call.
121121+ out := s.recordIfBusy("any", errors.New("database is locked"))
122122+ if out == nil || out.Error() != "database is locked" {
123123+ t.Errorf("recordIfBusy nil-recorder returned %v, want pass-through", out)
124124+ }
125125+}
126126+127127+// TestStore_RecordIfBusy_WiresClassifier confirms recordIfBusy
128128+// forwards busy errors and ignores non-busy errors.
129129+func TestStore_RecordIfBusy_WiresClassifier(t *testing.T) {
130130+ s := testStore(t)
131131+ rec := newStubBusyRecorder()
132132+ s.SetBusyRecorder(rec)
133133+134134+ // Busy: should count.
135135+ s.recordIfBusy("op1", errors.New("database is locked"))
136136+ if got := rec.count("op1"); got != 1 {
137137+ t.Errorf("op1 count = %d, want 1", got)
138138+ }
139139+140140+ // Non-busy: should NOT count.
141141+ s.recordIfBusy("op2", errors.New("constraint failure"))
142142+ if got := rec.count("op2"); got != 0 {
143143+ t.Errorf("op2 count = %d, want 0 (non-busy err shouldn't increment)", got)
144144+ }
145145+146146+ // Nil error: should NOT count.
147147+ s.recordIfBusy("op3", nil)
148148+ if got := rec.count("op3"); got != 0 {
149149+ t.Errorf("op3 count = %d, want 0 (nil err shouldn't increment)", got)
150150+ }
151151+}
+126
internal/relaystore/orphan_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relaystore
44+55+import (
66+ "context"
77+ "errors"
88+ "testing"
99+ "time"
1010+)
1111+1212+// TestUpdateMessageStatus_ReturnsErrMessageNotFound pins the orphan
1313+// detection contract: updating a non-existent ID surfaces a typed
1414+// error so the delivery callback can increment the orphan metric
1515+// instead of silently dropping the update (#208).
1616+func TestUpdateMessageStatus_ReturnsErrMessageNotFound(t *testing.T) {
1717+ s := testStore(t)
1818+ err := s.UpdateMessageStatus(context.Background(), 999_999, MsgSent, 250)
1919+ if !errors.Is(err, ErrMessageNotFound) {
2020+ t.Fatalf("expected ErrMessageNotFound, got %v", err)
2121+ }
2222+}
2323+2424+// TestUpdateMessageStatus_ExistingRowReturnsNil — happy path, the
2525+// orphan-detecting code must not break the normal delivery callback.
2626+func TestUpdateMessageStatus_ExistingRowReturnsNil(t *testing.T) {
2727+ s := testStore(t)
2828+ ctx := context.Background()
2929+ if err := s.InsertMember(ctx, &Member{
3030+ DID: "did:plc:orphan",
3131+ Status: StatusActive,
3232+ HourlyLimit: 100, DailyLimit: 1000,
3333+ CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
3434+ }); err != nil {
3535+ t.Fatal(err)
3636+ }
3737+ id, err := s.InsertMessage(ctx, &Message{
3838+ MemberDID: "did:plc:orphan", FromAddr: "a@b", ToAddr: "c@d",
3939+ MessageID: "<x@y>", Status: MsgQueued, CreatedAt: time.Now().UTC(),
4040+ })
4141+ if err != nil {
4242+ t.Fatalf("InsertMessage: %v", err)
4343+ }
4444+ if err := s.UpdateMessageStatus(ctx, id, MsgSent, 250); err != nil {
4545+ t.Errorf("update existing row: %v", err)
4646+ }
4747+}
4848+4949+// TestListQueuedMessageIDsOlderThan_FiltersByAgeAndStatus is the
5050+// janitor's contract: returns only rows that are status=queued AND
5151+// older than the cutoff. Excludes recent rows (would race with
5252+// just-Enqueued messages whose spool file is still landing) and
5353+// excludes terminal-state rows.
5454+func TestListQueuedMessageIDsOlderThan_FiltersByAgeAndStatus(t *testing.T) {
5555+ s := testStore(t)
5656+ ctx := context.Background()
5757+ if err := s.InsertMember(ctx, &Member{
5858+ DID: "did:plc:janitortest", Status: StatusActive, HourlyLimit: 100, DailyLimit: 1000,
5959+ CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
6060+ }); err != nil {
6161+ t.Fatal(err)
6262+ }
6363+6464+ now := time.Now().UTC()
6565+ mk := func(status string, age time.Duration) int64 {
6666+ t.Helper()
6767+ id, err := s.InsertMessage(ctx, &Message{
6868+ MemberDID: "did:plc:janitortest", FromAddr: "a@b", ToAddr: "c@d",
6969+ MessageID: "<id@x>", Status: status, CreatedAt: now.Add(-age),
7070+ })
7171+ if err != nil {
7272+ t.Fatal(err)
7373+ }
7474+ return id
7575+ }
7676+7777+ oldQueued := mk(MsgQueued, 10*time.Minute) // should appear
7878+ mk(MsgQueued, 30*time.Second) // too recent — should NOT appear
7979+ mk(MsgSent, 1*time.Hour) // wrong status — should NOT appear
8080+ mk(MsgBounced, 1*time.Hour) // wrong status — should NOT appear
8181+ oldQueued2 := mk(MsgQueued, 1*time.Hour) // should appear
8282+8383+ ids, err := s.ListQueuedMessageIDsOlderThan(ctx, 5*time.Minute, 100)
8484+ if err != nil {
8585+ t.Fatalf("ListQueuedMessageIDsOlderThan: %v", err)
8686+ }
8787+ got := map[int64]bool{}
8888+ for _, id := range ids {
8989+ got[id] = true
9090+ }
9191+ if !got[oldQueued] || !got[oldQueued2] {
9292+ t.Errorf("missing expected ids; got=%v want both %d and %d", ids, oldQueued, oldQueued2)
9393+ }
9494+ if len(ids) != 2 {
9595+ t.Errorf("returned %d ids, want exactly 2 — recent / non-queued rows leaked through", len(ids))
9696+ }
9797+}
9898+9999+// TestListQueuedMessageIDsOlderThan_RespectsLimit confirms the limit
100100+// is honored so the janitor can bound its work per pass.
101101+func TestListQueuedMessageIDsOlderThan_RespectsLimit(t *testing.T) {
102102+ s := testStore(t)
103103+ ctx := context.Background()
104104+ if err := s.InsertMember(ctx, &Member{
105105+ DID: "did:plc:janitorlimit", Status: StatusActive, HourlyLimit: 100, DailyLimit: 1000,
106106+ CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
107107+ }); err != nil {
108108+ t.Fatal(err)
109109+ }
110110+ old := time.Now().UTC().Add(-1 * time.Hour)
111111+ for i := 0; i < 7; i++ {
112112+ if _, err := s.InsertMessage(ctx, &Message{
113113+ MemberDID: "did:plc:janitorlimit", FromAddr: "a@b", ToAddr: "c@d",
114114+ MessageID: "<id@x>", Status: MsgQueued, CreatedAt: old,
115115+ }); err != nil {
116116+ t.Fatal(err)
117117+ }
118118+ }
119119+ ids, err := s.ListQueuedMessageIDsOlderThan(ctx, 5*time.Minute, 3)
120120+ if err != nil {
121121+ t.Fatalf("ListQueuedMessageIDsOlderThan: %v", err)
122122+ }
123123+ if len(ids) != 3 {
124124+ t.Errorf("returned %d ids, want 3 (limit)", len(ids))
125125+ }
126126+}
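
The hunks that follow switch `fmt.Errorf` wrapping from `%v` to `%w`, which keeps typed sentinels matchable with `errors.Is` further up the stack. Combined with the orphan contract pinned by the tests above, the delivery callback can look roughly like this sketch (the orphan counter is a placeholder, not an existing metric):

```go
// Sketch of the delivery-callback side of the orphan contract (#208).
if err := store.UpdateMessageStatus(ctx, entryID, relaystore.MsgSent, 250); err != nil {
	if errors.Is(err, relaystore.ErrMessageNotFound) {
		orphanUpdates.Inc() // placeholder counter for whatever metric the relay increments here
		log.Printf("delivery: status update for missing row id=%d", entryID)
	} else {
		log.Printf("delivery: update status id=%d: %v", entryID, err)
	}
}
```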
···9090 CREATE INDEX IF NOT EXISTS idx_relay_events_event_timestamp ON relay_events(event_timestamp DESC);
9191 `)
9292 if err != nil {
9393- return fmt.Errorf("create relay_events: %v", err)
9393+ return fmt.Errorf("create relay_events: %w", err)
9494 }
95959696 // content_fingerprint added after the table shipped — ADD COLUMN is
···101101 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('relay_events') WHERE name = 'content_fingerprint'`).Scan(&hasFingerprint)
102102 if hasFingerprint == 0 {
103103 if _, err := s.db.Exec(`ALTER TABLE relay_events ADD COLUMN content_fingerprint TEXT NOT NULL DEFAULT ''`); err != nil {
104104- return fmt.Errorf("add content_fingerprint column: %v", err)
104104+ return fmt.Errorf("add content_fingerprint column: %w", err)
105105 }
106106 }
107107 // Index the fingerprint for the primary read pattern: "show me every
···110110 // empty fingerprints (non-relay_attempt events) and we don't want them
111111 // bloating the index.
112112 if _, err := s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_relay_events_fingerprint ON relay_events(content_fingerprint, event_timestamp DESC) WHERE content_fingerprint != ''`); err != nil {
113113- return fmt.Errorf("create fingerprint index: %v", err)
113113+ return fmt.Errorf("create fingerprint index: %w", err)
114114 }
115115 return nil
116116}
···121121func (s *Store) InsertRelayEvent(ctx context.Context, e *RelayEvent) error {
122122 verdictsJSON, err := json.Marshal(defaultStrings(e.Verdicts))
123123 if err != nil {
124124- return fmt.Errorf("marshal verdicts: %v", err)
124124+ return fmt.Errorf("marshal verdicts: %w", err)
125125 }
126126 labelsJSON, err := json.Marshal(defaultStrings(e.LabelsApplied))
127127 if err != nil {
128128- return fmt.Errorf("marshal labels_applied: %v", err)
128128+ return fmt.Errorf("marshal labels_applied: %w", err)
129129 }
130130131131 var smtpCode any
···146146 string(verdictsJSON), string(labelsJSON), e.Raw,
147147 )
148148 if err != nil {
149149- return fmt.Errorf("insert relay event: %v", err)
149149+ return fmt.Errorf("insert relay event: %w", err)
150150 }
151151 return nil
152152}
···203203204204 rows, err := s.db.QueryContext(ctx, q, args...)
205205 if err != nil {
206206- return nil, fmt.Errorf("list relay events: %v", err)
206206+ return nil, fmt.Errorf("list relay events: %w", err)
207207 }
208208 defer rows.Close()
209209···226226 `SELECT MAX(kafka_offset) FROM relay_events`,
227227 ).Scan(&offset)
228228 if err != nil {
229229- return -1, fmt.Errorf("last kafka offset: %v", err)
229229+ return -1, fmt.Errorf("last kafka offset: %w", err)
230230 }
231231 if !offset.Valid {
232232 return -1, nil
···244244 formatTime(since),
245245 )
246246 if err != nil {
247247- return nil, fmt.Errorf("count events by action: %v", err)
247247+ return nil, fmt.Errorf("count events by action: %w", err)
248248 }
249249 defer rows.Close()
250250···253253 var name string
254254 var count int64
255255 if err := rows.Scan(&name, &count); err != nil {
256256- return nil, fmt.Errorf("scan action count: %v", err)
256256+ return nil, fmt.Errorf("scan action count: %w", err)
257257 }
258258 out[name] = count
259259 }
···279279 formatTime(since),
280280 )
281281 if err != nil {
282282- return nil, fmt.Errorf("count labels applied: %v", err)
282282+ return nil, fmt.Errorf("count labels applied: %w", err)
283283 }
284284 defer rows.Close()
285285···287287 for rows.Next() {
288288 var raw string
289289 if err := rows.Scan(&raw); err != nil {
290290- return nil, fmt.Errorf("scan labels_applied: %v", err)
290290+ return nil, fmt.Errorf("scan labels_applied: %w", err)
291291 }
292292 var labels []string
293293 if err := json.Unmarshal([]byte(raw), &labels); err != nil {
···331331 return nil, nil
332332 }
333333 if err != nil {
334334- return nil, fmt.Errorf("scan relay event: %v", err)
334334+ return nil, fmt.Errorf("scan relay event: %w", err)
335335 }
336336 e.IngestedAt = parseTime(ingestedAt)
337337 e.EventTimestamp = parseTime(eventTimestamp)
+98
internal/relaystore/schema_version.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relaystore
44+55+import (
66+ "database/sql"
77+ "errors"
88+ "fmt"
99+ "time"
1010+)
1111+1212+// CurrentSchemaVersion is the schema version this binary expects.
1313+//
1414+// Bump this whenever migrate() learns to apply a new structural change
1515+// (CREATE TABLE / ALTER TABLE / new index / migration of existing data).
1616+// The store records the bumped version at the end of a successful
1717+// migrate() run, and refuses to start the next time around if a *newer*
1818+// version has since been recorded by some later binary.
1919+//
2020+// Version history (kept here so the diff that bumps the constant carries
2121+// the rationale; the schema_version table also records `description`):
2222+//
2323+// 1 — baseline: everything that migrate() builds today (multi-domain,
2424+// messages, rate counters, suppressions, member_domains, attestation
2525+// flags, content fingerprints, relay events, bypass audit, etc.).
2626+// Every existing deployment lands here on first start.
2727+const CurrentSchemaVersion = 1
2828+2929+// ErrSchemaTooNew is returned by EnsureSchemaVersion when the database
3030+// has been written to by a newer binary than the one starting up. The
3131+// safe behavior is to refuse to start: an old binary missing knowledge
3232+// of newer columns or tables would otherwise silently INSERT defaults
3333+// and corrupt the data the newer binary persisted.
3434+var ErrSchemaTooNew = errors.New("database schema version is newer than this binary supports — refusing to start")
3535+3636+// EnsureSchemaVersion creates the schema_version tracking table if
3737+// missing and verifies the database isn't ahead of the binary.
3838+//
3939+// Returns ErrSchemaTooNew (wrapped) when MAX(version) > current. The
4040+// returned dbVersion is the highest version recorded in the table
4141+// (zero on a fresh database). It's exposed so the caller can decide
4242+// whether to skip work that's already been applied.
4343+func (s *Store) EnsureSchemaVersion() (dbVersion int, err error) {
4444+ if _, err := s.db.Exec(`
4545+ CREATE TABLE IF NOT EXISTS schema_version (
4646+ version INTEGER PRIMARY KEY,
4747+ description TEXT NOT NULL DEFAULT '',
4848+ applied_at TEXT NOT NULL,
4949+ binary_marker TEXT NOT NULL DEFAULT ''
5050+ )
5151+ `); err != nil {
5252+ return 0, fmt.Errorf("create schema_version: %w", err)
5353+ }
5454+5555+ var maxVersion sql.NullInt64
5656+ if err := s.db.QueryRow(`SELECT MAX(version) FROM schema_version`).Scan(&maxVersion); err != nil {
5757+ return 0, fmt.Errorf("read schema_version: %w", err)
5858+ }
5959+ if !maxVersion.Valid {
6060+ return 0, nil
6161+ }
6262+ dbVersion = int(maxVersion.Int64)
6363+ if dbVersion > CurrentSchemaVersion {
6464+ return dbVersion, fmt.Errorf("%w (db=%d, binary=%d)", ErrSchemaTooNew, dbVersion, CurrentSchemaVersion)
6565+ }
6666+ return dbVersion, nil
6767+}
6868+6969+// RecordSchemaVersion writes a row marking the current binary's schema
7070+// version as applied. Idempotent — INSERT OR IGNORE keeps the original
7171+// applied_at for unchanged versions so deployment history is retained.
7272+func (s *Store) RecordSchemaVersion(description, binaryMarker string) error {
7373+ _, err := s.db.Exec(
7474+ `INSERT OR IGNORE INTO schema_version (version, description, applied_at, binary_marker)
7575+ VALUES (?, ?, ?, ?)`,
7676+ CurrentSchemaVersion,
7777+ description,
7878+ time.Now().UTC().Format(time.RFC3339),
7979+ binaryMarker,
8080+ )
8181+ if err != nil {
8282+ return fmt.Errorf("record schema_version: %w", err)
8383+ }
8484+ return nil
8585+}
8686+8787+// SchemaVersion returns the highest version recorded in the database.
8888+// Zero on a fresh database.
8989+func (s *Store) SchemaVersion() (int, error) {
9090+ var v sql.NullInt64
9191+ if err := s.db.QueryRow(`SELECT MAX(version) FROM schema_version`).Scan(&v); err != nil {
9292+ return 0, fmt.Errorf("query schema_version: %w", err)
9393+ }
9494+ if !v.Valid {
9595+ return 0, nil
9696+ }
9797+ return int(v.Int64), nil
9898+}
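
At startup the guard surfaces through `relaystore.New`, as the tests that follow pin. A minimal sketch of how `cmd/relay` might handle it, assuming a fatal-log exit (wording illustrative):

```go
// Sketch: startup handling for a too-new database.
st, err := relaystore.New(dsn)
if err != nil {
	if errors.Is(err, relaystore.ErrSchemaTooNew) {
		log.Fatalf("refusing to start: %v (deploy the newer binary or restore the pre-upgrade backup)", err)
	}
	log.Fatalf("open relay store: %v", err)
}
defer st.Close()
```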
+116
internal/relaystore/schema_version_test.go
···11+// SPDX-License-Identifier: AGPL-3.0-or-later
22+33+package relaystore
44+55+import (
66+ "errors"
77+ "path/filepath"
88+ "testing"
99+)
1010+1111+func newTempStore(t *testing.T) *Store {
1212+ t.Helper()
1313+ dsn := "file:" + filepath.Join(t.TempDir(), "test.db") + "?_journal=WAL"
1414+ s, err := New(dsn)
1515+ if err != nil {
1616+ t.Fatalf("New: %v", err)
1717+ }
1818+ t.Cleanup(func() { s.Close() })
1919+ return s
2020+}
2121+2222+func TestSchemaVersion_FreshDBRecordsCurrent(t *testing.T) {
2323+ s := newTempStore(t)
2424+ v, err := s.SchemaVersion()
2525+ if err != nil {
2626+ t.Fatalf("SchemaVersion: %v", err)
2727+ }
2828+ if v != CurrentSchemaVersion {
2929+ t.Errorf("fresh DB schema_version = %d, want %d", v, CurrentSchemaVersion)
3030+ }
3131+}
3232+3333+func TestSchemaVersion_ReopenIsIdempotent(t *testing.T) {
3434+ // Reopening a DB that's already at the current version must NOT
3535+ // produce a duplicate row or error.
3636+ dir := t.TempDir()
3737+ dsn := "file:" + filepath.Join(dir, "test.db") + "?_journal=WAL"
3838+3939+ s1, err := New(dsn)
4040+ if err != nil {
4141+ t.Fatalf("New (1): %v", err)
4242+ }
4343+ s1.Close()
4444+4545+ s2, err := New(dsn)
4646+ if err != nil {
4747+ t.Fatalf("New (2): %v", err)
4848+ }
4949+ defer s2.Close()
5050+5151+ var rows int
5252+ if err := s2.db.QueryRow(`SELECT COUNT(*) FROM schema_version WHERE version = ?`, CurrentSchemaVersion).Scan(&rows); err != nil {
5353+ t.Fatalf("count: %v", err)
5454+ }
5555+ if rows != 1 {
5656+ t.Errorf("schema_version row count = %d, want 1 (INSERT OR IGNORE should dedupe)", rows)
5757+ }
5858+}
5959+6060+func TestSchemaVersion_RefusesNewerDB(t *testing.T) {
6161+ // Simulate "rollback to old binary": pre-populate schema_version
6262+ // with a version higher than CurrentSchemaVersion, then reopen.
6363+ dir := t.TempDir()
6464+ dsn := "file:" + filepath.Join(dir, "test.db") + "?_journal=WAL"
6565+6666+ s1, err := New(dsn)
6767+ if err != nil {
6868+ t.Fatalf("New (1): %v", err)
6969+ }
7070+ if _, err := s1.db.Exec(
7171+ `INSERT INTO schema_version (version, description, applied_at, binary_marker)
7272+ VALUES (?, 'future', '2099-01-01T00:00:00Z', 'future-binary')`,
7373+ CurrentSchemaVersion+5,
7474+ ); err != nil {
7575+ t.Fatalf("inject future version: %v", err)
7676+ }
7777+ s1.Close()
7878+7979+ _, err = New(dsn)
8080+ if err == nil {
8181+ t.Fatal("expected New to refuse a newer DB, got nil error")
8282+ }
8383+ if !errors.Is(err, ErrSchemaTooNew) {
8484+ t.Errorf("error chain missing ErrSchemaTooNew: %v", err)
8585+ }
8686+}
8787+8888+func TestSchemaVersion_TableExistsAfterMigrate(t *testing.T) {
8989+ s := newTempStore(t)
9090+ var count int
9191+ if err := s.db.QueryRow(
9292+ `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_version'`,
9393+ ).Scan(&count); err != nil {
9494+ t.Fatalf("query sqlite_master: %v", err)
9595+ }
9696+ if count != 1 {
9797+ t.Errorf("schema_version table not created: count=%d", count)
9898+ }
9999+}
100100+101101+func TestSchemaVersion_RecordIsIdempotentExplicit(t *testing.T) {
102102+ // Direct unit test on RecordSchemaVersion to pin INSERT OR IGNORE.
103103+ s := newTempStore(t)
104104+ for i := 0; i < 5; i++ {
105105+ if err := s.RecordSchemaVersion("retry", "marker-v2"); err != nil {
106106+ t.Fatalf("RecordSchemaVersion (%d): %v", i, err)
107107+ }
108108+ }
109109+ var rows int
110110+ if err := s.db.QueryRow(`SELECT COUNT(*) FROM schema_version`).Scan(&rows); err != nil {
111111+ t.Fatalf("count: %v", err)
112112+ }
113113+ if rows != 1 {
114114+ t.Errorf("expected exactly 1 row after 6 calls, got %d", rows)
115115+ }
116116+}
+433-117
internal/relaystore/store.go
···55import (
66 "context"
77 "database/sql"
88+ "errors"
89 "fmt"
910 "log"
1011 "strings"
···13141415 _ "modernc.org/sqlite"
1516)
1717+1818+// ErrMessageNotFound is returned by UpdateMessageStatus when the
1919+// targeted row does not exist. Caused by a delivery callback firing
2020+// for a spool-only entry whose DB row was never inserted (or was
2121+// purged early). Callers should log + increment a metric so the
2222+// orphan rate is visible — silently dropping these updates is the
2323+// safety hole closed by #208.
2424+var ErrMessageNotFound = errors.New("relaystore: message row not found")
16251726// Member status constants.
1827const (
···3645 MsgQueued = "queued"
3746 MsgSent = "sent"
3847 MsgBounced = "bounced"
4848+ // MsgFailed is the terminal state for messages we lost internally
4949+ // (orphan reconciliation, spool corruption). Distinct from
5050+ // MsgBounced so operators can distinguish receiver-side rejection
5151+ // from our own pipeline failure when reading the dashboard.
5252+ MsgFailed = "failed"
3953 MsgDeferred = "deferred"
4054)
4155···147161}
148162149163type Store struct {
150150- db *sql.DB
151151- rateMu sync.Mutex // serializes CheckAndIncrementRate to prevent TOCTOU
164164+ db *sql.DB
165165+ rateMu sync.Mutex // serializes CheckAndIncrementRate to prevent TOCTOU
166166+ busyRecorder BusyRecorder // optional; counts SQLITE_BUSY errors at hot writers (#210)
152167}
153168154169func New(dsn string) (*Store, error) {
155170 db, err := sql.Open("sqlite", dsn)
156171 if err != nil {
157157- return nil, fmt.Errorf("open sqlite: %v", err)
172172+ return nil, fmt.Errorf("open sqlite: %w", err)
158173 }
159174 if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
160175 db.Close()
161161- return nil, fmt.Errorf("set WAL mode: %v", err)
176176+ return nil, fmt.Errorf("set WAL mode: %w", err)
162177 }
163178 if _, err := db.Exec("PRAGMA busy_timeout = 5000"); err != nil {
164179 db.Close()
165165- return nil, fmt.Errorf("set busy timeout: %v", err)
180180+ return nil, fmt.Errorf("set busy timeout: %w", err)
166181 }
167182 if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil {
168183 db.Close()
169169- return nil, fmt.Errorf("enable foreign keys: %v", err)
184184+ return nil, fmt.Errorf("enable foreign keys: %w", err)
170185 }
171186 s := &Store{db: db}
172187 if err := s.migrate(); err != nil {
173188 db.Close()
174174- return nil, fmt.Errorf("migrate: %v", err)
189189+ return nil, fmt.Errorf("migrate: %w", err)
175190 }
176191 return s, nil
177192}
···185200}
186201187202func (s *Store) migrate() error {
203203+ // Schema-version guard: refuse to start if the DB has been written to
204204+ // by a newer binary than this one. Without this, a rollback to an older
205205+ // binary would silently run ALTER TABLEs and default INSERTs on a schema it doesn't
206206+ // understand and corrupt the data the newer binary persisted (#224).
207207+ if _, err := s.EnsureSchemaVersion(); err != nil {
208208+ return err
209209+ }
210210+188211 // Check if old schema exists (members table has 'domain' column)
189212 var hasDomainCol int
190213 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('members') WHERE name = 'domain'`).Scan(&hasDomainCol)
191214192215 if hasDomainCol > 0 {
193216 if err := s.migrateToMultiDomain(); err != nil {
194194- return fmt.Errorf("multi-domain migration: %v", err)
217217+ return fmt.Errorf("multi-domain migration: %w", err)
195218 }
196219 }
197220···260283 CREATE INDEX IF NOT EXISTS idx_feedback_events_member ON feedback_events(member_did);
261284262285 CREATE TABLE IF NOT EXISTS bypass_dids (
263263- did TEXT PRIMARY KEY
286286+ did TEXT PRIMARY KEY,
287287+ expires_at TEXT NOT NULL DEFAULT '',
288288+ reason TEXT NOT NULL DEFAULT '',
289289+ created_at TEXT NOT NULL DEFAULT ''
264290 );
265291292292+ -- bypass_audit retains an immutable log of every add/remove so
293293+ -- compromise or accidental mass-bypass can be reconstructed
294294+ -- after the fact. The active bypass set lives in bypass_dids;
295295+ -- this table is append-only.
296296+ CREATE TABLE IF NOT EXISTS bypass_audit (
297297+ id INTEGER PRIMARY KEY AUTOINCREMENT,
298298+ did TEXT NOT NULL,
299299+ action TEXT NOT NULL, -- 'add' or 'remove'
300300+ reason TEXT NOT NULL DEFAULT '',
301301+ expires_at TEXT NOT NULL DEFAULT '', -- only meaningful for 'add'
302302+ created_at TEXT NOT NULL
303303+ );
304304+ CREATE INDEX IF NOT EXISTS bypass_audit_created_at_idx
305305+ ON bypass_audit(created_at);
306306+266307 CREATE TABLE IF NOT EXISTS suppressions (
267308 member_did TEXT NOT NULL,
268309 recipient_addr TEXT NOT NULL,
···357398 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('member_domains') WHERE name = 'forward_to'`).Scan(&hasForwardTo)
358399 if hasForwardTo == 0 {
359400 if _, err := s.db.Exec(`ALTER TABLE member_domains ADD COLUMN forward_to TEXT NOT NULL DEFAULT ''`); err != nil {
360360- return fmt.Errorf("add forward_to column: %v", err)
401401+ return fmt.Errorf("add forward_to column: %w", err)
361402 }
362403 }
363404···369410 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('members') WHERE name = 'did_verified'`).Scan(&hasDIDVerified)
370411 if hasDIDVerified == 0 {
371412 if _, err := s.db.Exec(`ALTER TABLE members ADD COLUMN did_verified INTEGER NOT NULL DEFAULT 0`); err != nil {
372372- return fmt.Errorf("add did_verified column: %v", err)
413413+ return fmt.Errorf("add did_verified column: %w", err)
373414 }
374415 }
375416···380421 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('pending_enrollments') WHERE name = 'contact_email'`).Scan(&hasPendingContactEmail)
381422 if hasPendingContactEmail == 0 {
382423 if _, err := s.db.Exec(`ALTER TABLE pending_enrollments ADD COLUMN contact_email TEXT NOT NULL DEFAULT ''`); err != nil {
383383- return fmt.Errorf("add contact_email to pending_enrollments: %v", err)
424424+ return fmt.Errorf("add contact_email to pending_enrollments: %w", err)
384425 }
385426 }
386427···393434 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('member_domains') WHERE name = 'contact_email'`).Scan(&hasContactEmail)
394435 if hasContactEmail == 0 {
395436 if _, err := s.db.Exec(`ALTER TABLE member_domains ADD COLUMN contact_email TEXT NOT NULL DEFAULT ''`); err != nil {
396396- return fmt.Errorf("add contact_email column: %v", err)
437437+ return fmt.Errorf("add contact_email column: %w", err)
397438 }
398439 }
399440···406447 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('member_domains') WHERE name = 'attestation_rkey'`).Scan(&hasAttRkey)
407448 if hasAttRkey == 0 {
408449 if _, err := s.db.Exec(`ALTER TABLE member_domains ADD COLUMN attestation_rkey TEXT NOT NULL DEFAULT ''`); err != nil {
409409- return fmt.Errorf("add attestation_rkey column: %v", err)
450450+ return fmt.Errorf("add attestation_rkey column: %w", err)
410451 }
411452 }
412453 var hasAttAt int
413454 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('member_domains') WHERE name = 'attestation_published_at'`).Scan(&hasAttAt)
414455 if hasAttAt == 0 {
415456 if _, err := s.db.Exec(`ALTER TABLE member_domains ADD COLUMN attestation_published_at TEXT NOT NULL DEFAULT ''`); err != nil {
416416- return fmt.Errorf("add attestation_published_at column: %v", err)
457457+ return fmt.Errorf("add attestation_published_at column: %w", err)
417458 }
418459 }
419460···423464 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('members') WHERE name = 'terms_accepted_at'`).Scan(&hasTermsAcceptedAt)
424465 if hasTermsAcceptedAt == 0 {
425466 if _, err := s.db.Exec(`ALTER TABLE members ADD COLUMN terms_accepted_at TEXT NOT NULL DEFAULT ''`); err != nil {
426426- return fmt.Errorf("add terms_accepted_at column: %v", err)
467467+ return fmt.Errorf("add terms_accepted_at column: %w", err)
427468 }
428469 }
429470 var hasTermsVersion int
430471 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('members') WHERE name = 'terms_version'`).Scan(&hasTermsVersion)
431472 if hasTermsVersion == 0 {
432473 if _, err := s.db.Exec(`ALTER TABLE members ADD COLUMN terms_version TEXT NOT NULL DEFAULT ''`); err != nil {
433433- return fmt.Errorf("add terms_version column: %v", err)
474474+ return fmt.Errorf("add terms_version column: %w", err)
434475 }
435476 }
436477···440481 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('pending_enrollments') WHERE name = 'terms_accepted'`).Scan(&hasPendingTerms)
441482 if hasPendingTerms == 0 {
442483 if _, err := s.db.Exec(`ALTER TABLE pending_enrollments ADD COLUMN terms_accepted INTEGER NOT NULL DEFAULT 0`); err != nil {
443443- return fmt.Errorf("add terms_accepted to pending_enrollments: %v", err)
484484+ return fmt.Errorf("add terms_accepted to pending_enrollments: %w", err)
444485 }
445486 }
446487···453494 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('member_domains') WHERE name = 'email_verified'`).Scan(&hasEmailVerified)
454495 if hasEmailVerified == 0 {
455496 if _, err := s.db.Exec(`ALTER TABLE member_domains ADD COLUMN email_verified INTEGER DEFAULT 0`); err != nil {
456456- return fmt.Errorf("add email_verified column: %v", err)
497497+ return fmt.Errorf("add email_verified column: %w", err)
457498 }
458499 }
459500 var hasEmailVerifyToken int
460501 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('member_domains') WHERE name = 'email_verify_token'`).Scan(&hasEmailVerifyToken)
461502 if hasEmailVerifyToken == 0 {
462503 if _, err := s.db.Exec(`ALTER TABLE member_domains ADD COLUMN email_verify_token TEXT DEFAULT ''`); err != nil {
463463- return fmt.Errorf("add email_verify_token column: %v", err)
504504+ return fmt.Errorf("add email_verify_token column: %w", err)
464505 }
465506 }
466507 var hasEmailVerifyExpires int
467508 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('member_domains') WHERE name = 'email_verify_expires'`).Scan(&hasEmailVerifyExpires)
468509 if hasEmailVerifyExpires == 0 {
469510 if _, err := s.db.Exec(`ALTER TABLE member_domains ADD COLUMN email_verify_expires TEXT DEFAULT ''`); err != nil {
470470- return fmt.Errorf("add email_verify_expires column: %v", err)
511511+ return fmt.Errorf("add email_verify_expires column: %w", err)
471512 }
472513 }
473514···479520 _ = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name = 'content_fingerprint'`).Scan(&hasFP)
480521 if hasFP == 0 {
481522 if _, err := s.db.Exec(`ALTER TABLE messages ADD COLUMN content_fingerprint TEXT NOT NULL DEFAULT ''`); err != nil {
482482- return fmt.Errorf("add content_fingerprint column: %v", err)
523523+ return fmt.Errorf("add content_fingerprint column: %w", err)
483524 }
484525 if _, err := s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_messages_fingerprint ON messages(member_did, content_fingerprint, created_at)`); err != nil {
485485- return fmt.Errorf("create fingerprint index: %v", err)
526526+ return fmt.Errorf("create fingerprint index: %w", err)
486527 }
487528 }
488529···503544 if err := s.migratePendingNotifications(); err != nil {
504545 return err
505546 }
547547+ // Bypass-DID expiry + audit columns (#213). Existing deployments
548548+ // have a bypass_dids table without expires_at/reason/created_at —
549549+ // add them as defaults so we don't lose any active bypass on the
550550+ // migration. Old rows get expires_at='' which the new ListBypassDIDs
551551+ // treats as "permanent" (matching legacy behavior); operators are
552552+ // expected to re-add with expiry as part of the rollout runbook.
553553+ if err := s.migrateBypassExpiry(); err != nil {
554554+ return err
555555+ }
556556+557557+ // All structural changes applied — record the version so a future
558558+ // downgrade can detect that this binary already touched the DB.
559559+ if err := s.RecordSchemaVersion("baseline (multi-domain, full schema)", ""); err != nil {
560560+ return err
561561+ }
562562+ return nil
563563+}
564564+565565+// migrateBypassExpiry adds expires_at/reason/created_at to bypass_dids
566566+// on existing deployments. The bypass_audit table itself is created by
567567+// the CREATE TABLE IF NOT EXISTS block near the top of migrate(); this
568568+// function only covers the new expiry/reason/created_at columns on
569569+// bypass_dids.
570570+func (s *Store) migrateBypassExpiry() error {
571571+ type col struct{ name, sql string }
572572+ wanted := []col{
573573+ {"expires_at", `ALTER TABLE bypass_dids ADD COLUMN expires_at TEXT NOT NULL DEFAULT ''`},
574574+ {"reason", `ALTER TABLE bypass_dids ADD COLUMN reason TEXT NOT NULL DEFAULT ''`},
575575+ {"created_at", `ALTER TABLE bypass_dids ADD COLUMN created_at TEXT NOT NULL DEFAULT ''`},
576576+ }
577577+ for _, c := range wanted {
578578+ var n int
579579+ if err := s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('bypass_dids') WHERE name = ?`, c.name).Scan(&n); err != nil {
580580+ return fmt.Errorf("check bypass_dids.%s: %w", c.name, err)
581581+ }
582582+ if n == 0 {
583583+ if _, err := s.db.Exec(c.sql); err != nil {
584584+ return fmt.Errorf("add bypass_dids.%s: %w", c.name, err)
585585+ }
586586+ }
587587+ }
506588 return nil
507589}
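
// Illustrative only (not part of the change): a test-shaped sketch of what
// the migration above guarantees. The "sqlite3" driver name, the bare
// &Store{db: db} construction, and the testing/database/sql/context imports
// are assumptions about the test harness; the point is the two properties
// the comment promises: re-running the migration is a no-op, and legacy
// rows stay listed as permanent.
func TestMigrateBypassExpiry_sketch(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()
	s := &Store{db: db}
	// Pre-#213 schema: just the DID, no expiry metadata.
	if _, err := db.Exec(`CREATE TABLE bypass_dids (did TEXT PRIMARY KEY)`); err != nil {
		t.Fatal(err)
	}
	if _, err := db.Exec(`INSERT INTO bypass_dids (did) VALUES ('did:plc:legacy')`); err != nil {
		t.Fatal(err)
	}
	// Run twice: the pragma_table_info checks make the second pass a no-op.
	for i := 0; i < 2; i++ {
		if err := s.migrateBypassExpiry(); err != nil {
			t.Fatal(err)
		}
	}
	// The legacy row survives with expires_at='' and is still listed.
	dids, err := s.ListBypassDIDs(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	if len(dids) != 1 || dids[0] != "did:plc:legacy" {
		t.Fatalf("legacy bypass lost: %v", dids)
	}
}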
508590···511593func (s *Store) migrateToMultiDomain() error {
512594 // Must disable FK checks outside a transaction for SQLite
513595 if _, err := s.db.Exec("PRAGMA foreign_keys=OFF"); err != nil {
514514- return fmt.Errorf("disable FK: %v", err)
596596+ return fmt.Errorf("disable FK: %w", err)
515597 }
516598 defer s.db.Exec("PRAGMA foreign_keys=ON")
517599518600 tx, err := s.db.Begin()
519601 if err != nil {
520520- return fmt.Errorf("begin tx: %v", err)
602602+ return fmt.Errorf("begin tx: %w", err)
521603 }
522604 defer tx.Rollback()
523605···536618 SELECT domain, did, api_key_hash, dkim_rsa_privkey, dkim_ed_privkey, dkim_selector, created_at FROM members;
537619 `)
538620 if err != nil {
539539- return fmt.Errorf("create member_domains: %v", err)
621621+ return fmt.Errorf("create member_domains: %w", err)
540622 }
541623542624 // Recreate members without domain-level columns
···557639 ALTER TABLE members_v2 RENAME TO members;
558640 `)
559641 if err != nil {
560560- return fmt.Errorf("recreate members: %v", err)
642642+ return fmt.Errorf("recreate members: %w", err)
561643 }
562644563645 _, err = tx.Exec(`CREATE INDEX idx_member_domains_did ON member_domains(did)`)
564646 if err != nil {
565565- return fmt.Errorf("create domain index: %v", err)
647647+ return fmt.Errorf("create domain index: %w", err)
566648 }
567649568650 log.Printf("relaystore: migrated to multi-domain schema")
···581663 formatTime(m.CreatedAt), formatTime(m.UpdatedAt),
582664 )
583665 if err != nil {
584584- return fmt.Errorf("insert member: %v", err)
666666+ return fmt.Errorf("insert member: %w", err)
585667 }
586668 return nil
587669}
···592674func (s *Store) EnrollMember(ctx context.Context, member *Member, domain *MemberDomain) error {
593675 tx, err := s.db.BeginTx(ctx, nil)
594676 if err != nil {
595595- return fmt.Errorf("enroll begin tx: %v", err)
677677+ return fmt.Errorf("enroll begin tx: %w", err)
596678 }
597679 defer tx.Rollback()
598680···607689 formatTime(member.CreatedAt), formatTime(member.UpdatedAt),
608690 )
609691 if err != nil {
610610- return fmt.Errorf("enroll insert member: %v", err)
692692+ return fmt.Errorf("enroll insert member: %w", err)
611693 }
612694 }
613695···618700 formatTime(domain.CreatedAt),
619701 )
620702 if err != nil {
621621- return fmt.Errorf("enroll insert domain: %v", err)
703703+ return fmt.Errorf("enroll insert domain: %w", err)
622704 }
623705624706 return tx.Commit()
···641723 ORDER BY m.created_at ASC`,
642724 )
643725 if err != nil {
644644- return nil, fmt.Errorf("list members with domains: %v", err)
726726+ return nil, fmt.Errorf("list members with domains: %w", err)
645727 }
646728 defer rows.Close()
647729···656738 &termsAcceptedAt, &mwd.TermsVersion,
657739 &createdAt, &updatedAt, &domainCSV,
658740 ); err != nil {
659659- return nil, fmt.Errorf("scan member with domains: %v", err)
741741+ return nil, fmt.Errorf("scan member with domains: %w", err)
660742 }
661743 mwd.DIDVerified = didVerified != 0
662744 mwd.TermsAcceptedAt = parseTime(termsAcceptedAt)
···684766 FROM members ORDER BY created_at ASC`,
685767 )
686768 if err != nil {
687687- return nil, fmt.Errorf("list members: %v", err)
769769+ return nil, fmt.Errorf("list members: %w", err)
688770 }
689771 defer rows.Close()
690772···712794func (s *Store) DeleteMember(ctx context.Context, did string) error {
713795 tx, err := s.db.BeginTx(ctx, nil)
714796 if err != nil {
715715- return fmt.Errorf("delete begin tx: %v", err)
797797+ return fmt.Errorf("delete begin tx: %w", err)
716798 }
717799 defer tx.Rollback()
718800···724806 }
725807 for _, d := range deletes {
726808 if _, err := tx.ExecContext(ctx, "DELETE FROM "+d.table+" WHERE "+d.col+" = ?", did); err != nil {
727727- return fmt.Errorf("delete from %s: %v", d.table, err)
809809+ return fmt.Errorf("delete from %s: %w", d.table, err)
728810 }
729811 }
730812 if _, err := tx.ExecContext(ctx, "DELETE FROM member_domains WHERE did = ?", did); err != nil {
731731- return fmt.Errorf("delete member_domains: %v", err)
813813+ return fmt.Errorf("delete member_domains: %w", err)
732814 }
733815 if _, err := tx.ExecContext(ctx, "DELETE FROM members WHERE did = ?", did); err != nil {
734734- return fmt.Errorf("delete member: %v", err)
816816+ return fmt.Errorf("delete member: %w", err)
735817 }
736818737819 return tx.Commit()
···742824 `UPDATE members SET send_count = send_count + 1, updated_at = ? WHERE did = ?`,
743825 formatTime(time.Now().UTC()), did,
744826 )
745745- return err
827827+ return s.recordIfBusy("increment_send_count", err)
746828}
747829748830// scanner is satisfied by both *sql.Row and *sql.Rows.
···765847 return nil, nil
766848 }
767849 if err != nil {
768768- return nil, fmt.Errorf("scan member: %v", err)
850850+ return nil, fmt.Errorf("scan member: %w", err)
769851 }
770852771853 m.DIDVerified = didVerified != 0
···785867 formatTime(d.CreatedAt),
786868 )
787869 if err != nil {
788788- return fmt.Errorf("insert member domain: %v", err)
870870+ return fmt.Errorf("insert member domain: %w", err)
789871 }
790872 return nil
791873}
···802884 return nil, nil
803885 }
804886 if err != nil {
805805- return nil, fmt.Errorf("get member domain: %v", err)
887887+ return nil, fmt.Errorf("get member domain: %w", err)
806888 }
807889 d.EmailVerified = emailVerified != 0
808890 d.CreatedAt = parseTime(createdAt)
···815897 FROM member_domains WHERE did = ? ORDER BY created_at ASC`, did,
816898 )
817899 if err != nil {
818818- return nil, fmt.Errorf("list member domains: %v", err)
900900+ return nil, fmt.Errorf("list member domains: %w", err)
819901 }
820902 defer rows.Close()
821903···825907 var createdAt string
826908 var emailVerified int
827909 if err := rows.Scan(&d.Domain, &d.DID, &d.APIKeyHash, &d.DKIMRSAPriv, &d.DKIMEdPriv, &d.DKIMSelector, &d.ForwardTo, &d.ContactEmail, &emailVerified, &createdAt); err != nil {
828828- return nil, fmt.Errorf("scan member domain: %v", err)
910910+ return nil, fmt.Errorf("scan member domain: %w", err)
829911 }
830912 d.EmailVerified = emailVerified != 0
831913 d.CreatedAt = parseTime(createdAt)
···844926 hash, domain,
845927 )
846928 if err != nil {
847847- return fmt.Errorf("update api_key_hash: %v", err)
929929+ return fmt.Errorf("update api_key_hash: %w", err)
848930 }
849931 n, err := res.RowsAffected()
850932 if err != nil {
851851- return fmt.Errorf("update api_key_hash rows: %v", err)
933933+ return fmt.Errorf("update api_key_hash rows: %w", err)
852934 }
853935 if n == 0 {
854936 return fmt.Errorf("domain %q not registered", domain)
···867949 contactEmail, domain,
868950 )
869951 if err != nil {
870870- return fmt.Errorf("update contact_email: %v", err)
952952+ return fmt.Errorf("update contact_email: %w", err)
871953 }
872954 n, err := res.RowsAffected()
873955 if err != nil {
874874- return fmt.Errorf("update contact_email rows: %v", err)
956956+ return fmt.Errorf("update contact_email rows: %w", err)
875957 }
876958 if n == 0 {
877959 return fmt.Errorf("domain %q not registered", domain)
···890972 token, formatTime(expiresAt), domain,
891973 )
892974 if err != nil {
893893- return fmt.Errorf("set email verify token: %v", err)
975975+ return fmt.Errorf("set email verify token: %w", err)
894976 }
895977 n, err := res.RowsAffected()
896978 if err != nil {
897897- return fmt.Errorf("set email verify token rows: %v", err)
979979+ return fmt.Errorf("set email verify token rows: %w", err)
898980 }
899981 if n == 0 {
900982 return fmt.Errorf("domain %q not registered", domain)
···9201002 return "", fmt.Errorf("verification token not found")
9211003 }
9221004 if err != nil {
923923- return "", fmt.Errorf("verify email lookup: %v", err)
10051005+ return "", fmt.Errorf("verify email lookup: %w", err)
9241006 }
9251007 expiresAt := parseTime(expiresAtStr)
9261008 if !expiresAt.IsZero() && time.Now().UTC().After(expiresAt) {
···9361018 domain,
9371019 )
9381020 if err != nil {
939939- return "", fmt.Errorf("mark email verified: %v", err)
10211021+ return "", fmt.Errorf("mark email verified: %w", err)
9401022 }
9411023 return domain, nil
9421024}
···9531035 return false, nil
9541036 }
9551037 if err != nil {
956956- return false, fmt.Errorf("is email verified: %v", err)
10381038+ return false, fmt.Errorf("is email verified: %w", err)
9571039 }
9581040 return verified != 0, nil
9591041}
···9671049 domain,
9681050 )
9691051 if err != nil {
970970- return fmt.Errorf("reset email verification: %v", err)
10521052+ return fmt.Errorf("reset email verification: %w", err)
9711053 }
9721054 return nil
9731055}
···9951077 return nil, nil, nil
9961078 }
9971079 if err != nil {
998998- return nil, nil, fmt.Errorf("get member by domain: %v", err)
10801080+ return nil, nil, fmt.Errorf("get member by domain: %w", err)
9991081 }
1000108210011083 m.DIDVerified = didVerified != 0
···10161098 forwardTo, domain,
10171099 )
10181100 if err != nil {
10191019- return fmt.Errorf("set forward_to: %v", err)
11011101+ return fmt.Errorf("set forward_to: %w", err)
10201102 }
10211103 n, err := res.RowsAffected()
10221104 if err != nil {
10231023- return fmt.Errorf("set forward_to rows: %v", err)
11051105+ return fmt.Errorf("set forward_to rows: %w", err)
10241106 }
10251107 if n == 0 {
10261108 return fmt.Errorf("domain %q not registered", domain)
···10541136 formatTime(m.CreatedAt), formatTime(m.DeliveredAt), m.ContentFingerprint,
10551137 )
10561138 if err != nil {
10571057- return 0, fmt.Errorf("insert message: %v", err)
11391139+ return 0, fmt.Errorf("insert message: %w", s.recordIfBusy("insert_message", err))
10581140 }
10591141 return res.LastInsertId()
10601142}
···10801162 if status == MsgSent {
10811163 deliveredAt = formatTime(time.Now().UTC())
10821164 }
10831083- _, err := s.db.ExecContext(ctx,
11651165+ res, err := s.db.ExecContext(ctx,
10841166 `UPDATE messages SET status = ?, smtp_code = ?, delivered_at = ? WHERE id = ?`,
10851167 status, smtpCode, deliveredAt, id,
10861168 )
10871087- return err
11691169+ if err != nil {
11701170+ return s.recordIfBusy("update_message_status", err)
11711171+ }
11721172+ rows, err := res.RowsAffected()
11731173+ if err != nil {
11741174+ return fmt.Errorf("rows affected: %w", err)
11751175+ }
11761176+ if rows == 0 {
11771177+ return ErrMessageNotFound
11781178+ }
11791179+ return nil
11801180+}
11811181+11821182+// ListQueuedMessageIDsOlderThan returns message row IDs whose status
11831183+// is still "queued" and whose created_at is at least minAge old.
11841184+// Used by the orphan-reconciliation janitor to find rows whose spool
11851185+// file vanished (Enqueue failure mid-batch, manual spool wipe, FS
11861186+// corruption). minAge prevents racing with rows that were just
11871187+// inserted but haven't had their spool file landed yet.
11881188+func (s *Store) ListQueuedMessageIDsOlderThan(ctx context.Context, minAge time.Duration, limit int) ([]int64, error) {
11891189+ if limit <= 0 {
11901190+ limit = 100
11911191+ }
11921192+ cutoff := formatTime(time.Now().UTC().Add(-minAge))
11931193+ rows, err := s.db.QueryContext(ctx,
11941194+ `SELECT id FROM messages WHERE status = ? AND created_at < ? ORDER BY id ASC LIMIT ?`,
11951195+ MsgQueued, cutoff, limit,
11961196+ )
11971197+ if err != nil {
11981198+ return nil, fmt.Errorf("query queued: %w", err)
11991199+ }
12001200+ defer rows.Close()
12011201+ var ids []int64
12021202+ for rows.Next() {
12031203+ var id int64
12041204+ if err := rows.Scan(&id); err != nil {
12051205+ return nil, fmt.Errorf("scan queued id: %w", err)
12061206+ }
12071207+ ids = append(ids, id)
12081208+ }
12091209+ if err := rows.Err(); err != nil {
12101210+ return nil, fmt.Errorf("iter queued ids: %w", err)
12111211+ }
12121212+ return ids, nil
10881213}
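
// Sketch of the consumer side (assumptions flagged inline): list stale
// queued rows and check whether the spool file still exists. The "<id>.eml"
// layout under spoolDir is a guess at the spool naming, and the remediation
// step (what status an orphan moves to) is left to the real reconciler.
// Assumes os, errors, and path/filepath imports.
func findOrphanedQueuedMessages(ctx context.Context, s *Store, spoolDir string) ([]int64, error) {
	ids, err := s.ListQueuedMessageIDsOlderThan(ctx, 10*time.Minute, 100)
	if err != nil {
		return nil, err
	}
	var orphans []int64
	for _, id := range ids {
		path := filepath.Join(spoolDir, fmt.Sprintf("%d.eml", id)) // assumed spool layout
		if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
			orphans = append(orphans, id)
		}
	}
	return orphans, nil
}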
1089121410901215func scanMessage(sc scanner) (*Message, error) {
···10991224 return nil, nil
11001225 }
11011226 if err != nil {
11021102- return nil, fmt.Errorf("scan message: %v", err)
12271227+ return nil, fmt.Errorf("scan message: %w", err)
11031228 }
1104122911051230 m.CreatedAt = parseTime(createdAt)
···11501275 if err == sql.ErrNoRows {
11511276 current = 0
11521277 } else if err != nil {
11531153- return 0, fmt.Errorf("read counter: %v", err)
12781278+ return 0, fmt.Errorf("read counter: %w", err)
11541279 }
1155128011561281 if current+count > limit {
···11651290 did, windowType, formatTime(windowStart), count, count,
11661291 )
11671292 if err != nil {
11681168- return current, fmt.Errorf("increment counter: %v", err)
12931293+ return current, fmt.Errorf("increment counter: %w", err)
11691294 }
1170129511711296 return current, nil
···12021327 e.MemberDID, e.EventType, e.MessageID, e.Recipient, e.Details, formatTime(e.CreatedAt),
12031328 )
12041329 if err != nil {
12051205- return 0, fmt.Errorf("insert feedback event: %v", err)
13301330+ return 0, fmt.Errorf("insert feedback event: %w", err)
12061331 }
12071332 return res.LastInsertId()
12081333}
···12161341 memberDID, MsgSent, MsgBounced, formatTime(since),
12171342 ).Scan(&total)
12181343 if err != nil {
12191219- return 0, 0, fmt.Errorf("count terminal: %v", err)
13441344+ return 0, 0, fmt.Errorf("count terminal: %w", err)
12201345 }
1221134612221347 err = s.db.QueryRowContext(ctx,
···12241349 memberDID, MsgBounced, formatTime(since),
12251350 ).Scan(&bounced)
12261351 if err != nil {
12271227- return 0, 0, fmt.Errorf("count bounced: %v", err)
13521352+ return 0, 0, fmt.Errorf("count bounced: %w", err)
12281353 }
1229135412301355 return total, bounced, nil
12311356}
1232135713581358+// GetDailySendCounts returns per-day terminal (sent+bounced) message counts
13591359+// for the last n days, oldest-to-newest. Days with zero sends are included
13601360+// so callers get a fixed-length slice suitable for sparklines.
13611361+func (s *Store) GetDailySendCounts(ctx context.Context, memberDID string, days int) ([]int64, error) {
13621362+ if days <= 0 {
13631363+ days = 14
13641364+ }
13651365+	// Compute the inclusive cutoff in Go so SQLite parameter binding
13661366+	// works cleanly: the oldest day we care about starts (days-1) days
13671367+	// before today, computed in UTC to match how created_at is stored.
13681368+ cutoff := time.Now().UTC().AddDate(0, 0, -(days - 1)).Format("2006-01-02")
13691369+13701370+ rows, err := s.db.QueryContext(ctx,
13711371+ `SELECT date(created_at) as day, COUNT(*)
13721372+ FROM messages
13731373+ WHERE member_did = ? AND status IN (?, ?) AND date(created_at) >= ?
13741374+ GROUP BY day
13751375+ ORDER BY day ASC`,
13761376+ memberDID, MsgSent, MsgBounced, cutoff,
13771377+ )
13781378+ if err != nil {
13791379+ return nil, fmt.Errorf("daily send counts: %w", err)
13801380+ }
13811381+ defer rows.Close()
13821382+13831383+ counts := make(map[string]int64)
13841384+ for rows.Next() {
13851385+ var day string
13861386+ var c int64
13871387+ if err := rows.Scan(&day, &c); err != nil {
13881388+ return nil, fmt.Errorf("scan daily count: %w", err)
13891389+ }
13901390+ counts[day] = c
13911391+ }
13921392+ if err := rows.Err(); err != nil {
13931393+ return nil, fmt.Errorf("daily send counts rows: %w", err)
13941394+ }
13951395+13961396+ // Fill in zero days so the slice is exactly `days` long.
13971397+ out := make([]int64, days)
13981398+ now := time.Now().UTC()
13991399+ for i := 0; i < days; i++ {
14001400+ day := now.AddDate(0, 0, -(days-1-i)).Format("2006-01-02")
14011401+ out[i] = counts[day]
14021402+ }
14031403+ return out, nil
14041404+}
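
// Sketch of the dashboard-side consumer: rendering the fixed-length slice
// from GetDailySendCounts as a unicode sparkline. The block characters and
// scaling are illustrative; the real template helper may differ.
func sparkline(counts []int64) string {
	blocks := []rune("▁▂▃▄▅▆▇█")
	var max int64 = 1
	for _, c := range counts {
		if c > max {
			max = c
		}
	}
	out := make([]rune, len(counts))
	for i, c := range counts {
		out[i] = blocks[int(c*int64(len(blocks)-1)/max)]
	}
	return string(out)
}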
14051405+14061406+// GetComplaintCount returns the number of feedback_events with event_type
14071407+// 'complaint' for the member since the given time.
14081408+func (s *Store) GetComplaintCount(ctx context.Context, memberDID string, since time.Time) (int64, error) {
14091409+ var n int64
14101410+ err := s.db.QueryRowContext(ctx,
14111411+ `SELECT COUNT(*) FROM feedback_events
14121412+ WHERE member_did = ? AND event_type = ? AND created_at >= ?`,
14131413+ memberDID, "complaint", formatTime(since),
14141414+ ).Scan(&n)
14151415+ if err != nil {
14161416+ return 0, fmt.Errorf("count complaints: %w", err)
14171417+ }
14181418+ return n, nil
14191419+}
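
// Sketch: combining the two queries above into a rough 14-day complaint
// rate. The window arithmetic is approximate (calendar days vs. a rolling
// 14-day cutoff) and the function name is illustrative.
func complaintRate14d(ctx context.Context, s *Store, did string) (float64, error) {
	complaints, err := s.GetComplaintCount(ctx, did, time.Now().UTC().AddDate(0, 0, -14))
	if err != nil {
		return 0, err
	}
	daily, err := s.GetDailySendCounts(ctx, did, 14)
	if err != nil {
		return 0, err
	}
	var sent int64
	for _, c := range daily {
		sent += c
	}
	if sent == 0 {
		return 0, nil
	}
	return float64(complaints) / float64(sent), nil
}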
14201420+12331421// GetUniqueRecipientDomainsSince counts DISTINCT recipient domains a member
12341422// has sent to since the given time. Used by the DomainSpray detection rule —
12351423// legitimate transactional mail usually goes to a small handful of domains;
···12461434 memberDID, formatTime(since),
12471435 ).Scan(&n)
12481436 if err != nil {
12491249- return 0, fmt.Errorf("count unique recipient domains: %v", err)
14371437+ return 0, fmt.Errorf("count unique recipient domains: %w", err)
12501438 }
12511439 return n, nil
12521440}
···12661454 memberDID, formatTime(since),
12671455 ).Scan(&n)
12681456 if err != nil {
12691269- return 0, fmt.Errorf("count sends since: %v", err)
14571457+ return 0, fmt.Errorf("count sends since: %w", err)
12701458 }
12711459 return n, nil
12721460}
···12881476 memberDID, fingerprint, formatTime(since),
12891477 ).Scan(&n)
12901478 if err != nil {
12911291- return 0, fmt.Errorf("count same-content recipients: %v", err)
14791479+ return 0, fmt.Errorf("count same-content recipients: %w", err)
12921480 }
12931481 return n, nil
12941482}
···13041492 MsgSent, MsgBounced, formatTime(before),
13051493 )
13061494 if err != nil {
13071307- return 0, fmt.Errorf("purge old messages: %v", err)
14951495+ return 0, fmt.Errorf("purge old messages: %w", err)
13081496 }
13091497 return res.RowsAffected()
13101498}
1311149913121500// --- Bypass DIDs ---
1313150113141314-// InsertBypassDID adds a DID to the label bypass list. Idempotent.
13151315-func (s *Store) InsertBypassDID(ctx context.Context, did string) error {
13161316- _, err := s.db.ExecContext(ctx,
13171317- `INSERT OR IGNORE INTO bypass_dids (did) VALUES (?)`, did,
13181318- )
13191319- return err
15021502+// BypassEntry pairs a bypassed DID with its lifecycle metadata. Empty
15031503+// expiresAt means "permanent" — supported only for legacy entries
15041504+// migrated from the pre-#213 schema; new entries always carry an
15051505+// explicit expiry capped at 30 days.
15061506+type BypassEntry struct {
15071507+ DID string
15081508+ ExpiresAt time.Time // zero value = legacy permanent
15091509+ Reason string
15101510+ CreatedAt time.Time
15111511+}
15121512+15131513+// BypassAuditEntry is one append-only row in the bypass_audit table.
15141514+// Action is "add" or "remove". Used for incident reconstruction; not
15151515+// served on the dashboard.
15161516+type BypassAuditEntry struct {
15171517+ ID int64
15181518+ DID string
15191519+ Action string
15201520+ Reason string
15211521+ ExpiresAt time.Time
15221522+ CreatedAt time.Time
15231523+}
15241524+15251525+// InsertBypassDID adds a DID to the label bypass list and writes a
15261526+// matching audit row in the same transaction. expiresAt may be zero
15271527+// only for the legacy permanent path used by migration restoration;
15281528+// new admin-driven calls always pass a non-zero expiry. Idempotent
15291529+// in the bypass_dids set (upsert via ON CONFLICT DO UPDATE), but every
15301530+// call appends to bypass_audit so a re-issue is observable.
15311531+func (s *Store) InsertBypassDID(ctx context.Context, did string, expiresAt time.Time, reason string) error {
15321532+ now := time.Now().UTC()
15331533+ tx, err := s.db.BeginTx(ctx, nil)
15341534+ if err != nil {
15351535+ return fmt.Errorf("begin: %w", err)
15361536+ }
15371537+ defer tx.Rollback()
15381538+ if _, err := tx.ExecContext(ctx,
15391539+ `INSERT INTO bypass_dids (did, expires_at, reason, created_at)
15401540+ VALUES (?, ?, ?, ?)
15411541+ ON CONFLICT(did) DO UPDATE SET
15421542+ expires_at = excluded.expires_at,
15431543+ reason = excluded.reason,
15441544+ created_at = excluded.created_at`,
15451545+ did, formatTime(expiresAt), reason, formatTime(now),
15461546+ ); err != nil {
15471547+ return fmt.Errorf("insert bypass: %w", err)
15481548+ }
15491549+ if _, err := tx.ExecContext(ctx,
15501550+ `INSERT INTO bypass_audit (did, action, reason, expires_at, created_at)
15511551+ VALUES (?, 'add', ?, ?, ?)`,
15521552+ did, reason, formatTime(expiresAt), formatTime(now),
15531553+ ); err != nil {
15541554+ return fmt.Errorf("insert audit: %w", err)
15551555+ }
15561556+ return tx.Commit()
15571557+}
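
// Sketch of the admin-side caller implied by the doc comment: new bypass
// entries always get an explicit expiry, clamped to the 30-day cap.
// addBypass and maxBypassTTL are illustrative names, not the actual
// handler wiring.
func addBypass(ctx context.Context, s *Store, did, reason string, ttl time.Duration) error {
	const maxBypassTTL = 30 * 24 * time.Hour
	if reason == "" {
		return fmt.Errorf("bypass for %s requires a reason", did)
	}
	if ttl <= 0 || ttl > maxBypassTTL {
		ttl = maxBypassTTL
	}
	return s.InsertBypassDID(ctx, did, time.Now().UTC().Add(ttl), reason)
}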
15581558+15591559+// DeleteBypassDID removes a DID from the label bypass list and writes
15601560+// an audit row noting the removal. reason names the trigger ("manual",
15611561+// "expired", etc.) so post-hoc analysis can distinguish operator
15621562+// action from janitor cleanup.
15631563+func (s *Store) DeleteBypassDID(ctx context.Context, did, reason string) error {
15641564+ now := time.Now().UTC()
15651565+ tx, err := s.db.BeginTx(ctx, nil)
15661566+ if err != nil {
15671567+ return fmt.Errorf("begin: %w", err)
15681568+ }
15691569+ defer tx.Rollback()
15701570+ if _, err := tx.ExecContext(ctx, `DELETE FROM bypass_dids WHERE did = ?`, did); err != nil {
15711571+ return fmt.Errorf("delete bypass: %w", err)
15721572+ }
15731573+ if _, err := tx.ExecContext(ctx,
15741574+ `INSERT INTO bypass_audit (did, action, reason, expires_at, created_at)
15751575+ VALUES (?, 'remove', ?, '', ?)`,
15761576+ did, reason, formatTime(now),
15771577+ ); err != nil {
15781578+ return fmt.Errorf("insert audit: %w", err)
15791579+ }
15801580+ return tx.Commit()
13201581}
1321158213221322-// DeleteBypassDID removes a DID from the label bypass list.
13231323-func (s *Store) DeleteBypassDID(ctx context.Context, did string) error {
13241324- _, err := s.db.ExecContext(ctx,
13251325- `DELETE FROM bypass_dids WHERE did = ?`, did,
15831583+// PurgeExpiredBypassDIDs deletes bypass entries whose expires_at is
15841584+// non-empty and in the past, writing 'remove' audit rows with reason
15851585+// "expired" so post-hoc analysis can distinguish janitor evictions
15861586+// from operator removals. Returns the number of evicted DIDs.
15871587+//
15881588+// Legacy entries with expires_at='' are NOT touched — they were
15891589+// migrated from a permanent-bypass schema and removing them would be
15901590+// a behavior change the operator hasn't authorized. Convert legacy
15911591+// entries by re-adding with explicit expiry.
15921592+func (s *Store) PurgeExpiredBypassDIDs(ctx context.Context) (int, error) {
15931593+ now := time.Now().UTC()
15941594+ cutoff := formatTime(now)
15951595+ tx, err := s.db.BeginTx(ctx, nil)
15961596+ if err != nil {
15971597+ return 0, fmt.Errorf("begin: %w", err)
15981598+ }
15991599+ defer tx.Rollback()
16001600+ rows, err := tx.QueryContext(ctx,
16011601+ `SELECT did FROM bypass_dids WHERE expires_at != '' AND expires_at < ?`,
16021602+ cutoff,
13261603 )
13271327- return err
16041604+ if err != nil {
16051605+ return 0, fmt.Errorf("scan expired: %w", err)
16061606+ }
16071607+ var dids []string
16081608+ for rows.Next() {
16091609+ var d string
16101610+ if err := rows.Scan(&d); err != nil {
16111611+ rows.Close()
16121612+ return 0, fmt.Errorf("scan did: %w", err)
16131613+ }
16141614+ dids = append(dids, d)
16151615+ }
16161616+ rows.Close()
16171617+ if err := rows.Err(); err != nil {
16181618+ return 0, fmt.Errorf("iter expired: %w", err)
16191619+ }
16201620+ for _, d := range dids {
16211621+ if _, err := tx.ExecContext(ctx, `DELETE FROM bypass_dids WHERE did = ?`, d); err != nil {
16221622+ return 0, fmt.Errorf("delete %s: %w", d, err)
16231623+ }
16241624+ if _, err := tx.ExecContext(ctx,
16251625+ `INSERT INTO bypass_audit (did, action, reason, expires_at, created_at)
16261626+ VALUES (?, 'remove', 'expired', '', ?)`,
16271627+ d, cutoff,
16281628+ ); err != nil {
16291629+ return 0, fmt.Errorf("audit %s: %w", d, err)
16301630+ }
16311631+ }
16321632+ if err := tx.Commit(); err != nil {
16331633+ return 0, fmt.Errorf("commit: %w", err)
16341634+ }
16351635+ return len(dids), nil
13281636}
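
// Sketch of the periodic sweep that keeps the bypass list honest. The
// interval and logging are illustrative; in practice this would hang off
// whatever janitor loop already owns message purging.
func runBypassJanitor(ctx context.Context, s *Store, every time.Duration) {
	t := time.NewTicker(every)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			n, err := s.PurgeExpiredBypassDIDs(ctx)
			if err != nil {
				log.Printf("bypass janitor: %v", err)
				continue
			}
			if n > 0 {
				log.Printf("bypass janitor: evicted %d expired bypass DID(s)", n)
			}
		}
	}
}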
1329163713301330-// ListBypassDIDs returns all DIDs in the label bypass list.
16381638+// ListBypassDIDs returns all DIDs in the label bypass list, excluding
16391639+// entries whose expiry has already passed. Legacy entries with
16401640+// expires_at='' are always returned (grandfathered as permanent).
13311641func (s *Store) ListBypassDIDs(ctx context.Context) ([]string, error) {
13321332- rows, err := s.db.QueryContext(ctx, `SELECT did FROM bypass_dids ORDER BY did`)
16421642+ now := formatTime(time.Now().UTC())
16431643+ rows, err := s.db.QueryContext(ctx,
16441644+ `SELECT did FROM bypass_dids
16451645+ WHERE expires_at = '' OR expires_at >= ?
16461646+ ORDER BY did`,
16471647+ now,
16481648+ )
13331649 if err != nil {
13341650 return nil, err
13351651 }
···13691685 memberDID, strings.ToLower(recipient), source, formatTime(time.Now().UTC()),
13701686 )
13711687 if err != nil {
13721372- return fmt.Errorf("insert suppression: %v", err)
16881688+ return fmt.Errorf("insert suppression: %w", err)
13731689 }
13741690 return nil
13751691}
···13851701 return false, nil
13861702 }
13871703 if err != nil {
13881388- return false, fmt.Errorf("is suppressed: %v", err)
17041704+ return false, fmt.Errorf("is suppressed: %w", err)
13891705 }
13901706 return true, nil
13911707}
···14001716 memberDID,
14011717 )
14021718 if err != nil {
14031403- return nil, fmt.Errorf("list suppressions: %v", err)
17191719+ return nil, fmt.Errorf("list suppressions: %w", err)
14041720 }
14051721 defer rows.Close()
14061722···14251741 memberDID, strings.ToLower(recipient),
14261742 )
14271743 if err != nil {
14281428- return fmt.Errorf("delete suppression: %v", err)
17441744+ return fmt.Errorf("delete suppression: %w", err)
14291745 }
14301746 return nil
14311747}
···14361752 var st Stats
14371753 err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM members`).Scan(&st.Members)
14381754 if err != nil {
14391439- return st, fmt.Errorf("count members: %v", err)
17551755+ return st, fmt.Errorf("count members: %w", err)
14401756 }
14411757 err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM member_domains`).Scan(&st.Domains)
14421758 if err != nil {
14431443- return st, fmt.Errorf("count domains: %v", err)
17591759+ return st, fmt.Errorf("count domains: %w", err)
14441760 }
14451761 err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM messages`).Scan(&st.Messages)
14461762 if err != nil {
14471447- return st, fmt.Errorf("count messages: %v", err)
17631763+ return st, fmt.Errorf("count messages: %w", err)
14481764 }
14491765 err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM messages WHERE status = ?`, MsgBounced).Scan(&st.Bounces)
14501766 if err != nil {
14511451- return st, fmt.Errorf("count bounces: %v", err)
17671767+ return st, fmt.Errorf("count bounces: %w", err)
14521768 }
14531769 return st, nil
14541770}
···14611777 var active, suspended, pending int64
14621778 err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM members WHERE status = ?`, StatusActive).Scan(&active)
14631779 if err != nil {
14641464- return 0, 0, 0, fmt.Errorf("count active members: %v", err)
17801780+ return 0, 0, 0, fmt.Errorf("count active members: %w", err)
14651781 }
14661782 err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM members WHERE status = ?`, StatusSuspended).Scan(&suspended)
14671783 if err != nil {
14681468- return 0, 0, 0, fmt.Errorf("count suspended members: %v", err)
17841784+ return 0, 0, 0, fmt.Errorf("count suspended members: %w", err)
14691785 }
14701786 err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM members WHERE status = ?`, StatusPending).Scan(&pending)
14711787 if err != nil {
14721472- return 0, 0, 0, fmt.Errorf("count pending members: %v", err)
17881788+ return 0, 0, 0, fmt.Errorf("count pending members: %w", err)
14731789 }
14741790 return active, suspended, pending, nil
14751791}
···14861802 rkey, formatTime(at), domain,
14871803 )
14881804 if err != nil {
14891489- return fmt.Errorf("set attestation published: %v", err)
18051805+ return fmt.Errorf("set attestation published: %w", err)
14901806 }
14911807 n, err := res.RowsAffected()
14921808 if err != nil {
14931493- return fmt.Errorf("set attestation rows: %v", err)
18091809+ return fmt.Errorf("set attestation rows: %w", err)
14941810 }
14951811 if n == 0 {
14961812 return fmt.Errorf("domain %q not registered", domain)
···15191835 return nil, nil
15201836 }
15211837 if err != nil {
15221522- return nil, fmt.Errorf("get attestation state: %v", err)
18381838+ return nil, fmt.Errorf("get attestation state: %w", err)
15231839 }
15241840 return &AttestationState{RKey: rkey, PublishedAt: parseTime(publishedAt)}, nil
15251841}
···15671883 formatTime(r.ExpiresAt), formatTime(r.CreatedAt),
15681884 )
15691885 if err != nil {
15701570- return fmt.Errorf("save oauth auth request: %v", err)
18861886+ return fmt.Errorf("save oauth auth request: %w", err)
15711887 }
15721888 return nil
15731889}
···15951911 return nil, nil
15961912 }
15971913 if err != nil {
15981598- return nil, fmt.Errorf("get oauth auth request: %v", err)
19141914+ return nil, fmt.Errorf("get oauth auth request: %w", err)
15991915 }
16001916 r.ExpiresAt = parseTime(expiresAt)
16011917 r.CreatedAt = parseTime(createdAt)
···16221938 return "", fmt.Errorf("no pending request for request_uri")
16231939 }
16241940 if err != nil {
16251625- return "", fmt.Errorf("find state by request_uri: %v", err)
19411941+ return "", fmt.Errorf("find state by request_uri: %w", err)
16261942 }
16271943 return state, nil
16281944}
···16391955 accountDID, domain, string(attestation), formatTime(expiresAt), state,
16401956 )
16411957 if err != nil {
16421642- return fmt.Errorf("augment oauth auth request: %v", err)
19581958+ return fmt.Errorf("augment oauth auth request: %w", err)
16431959 }
16441960 n, err := res.RowsAffected()
16451961 if err != nil {
16461646- return fmt.Errorf("augment oauth rows: %v", err)
19621962+ return fmt.Errorf("augment oauth rows: %w", err)
16471963 }
16481964 if n == 0 {
16491965 return fmt.Errorf("no pending row for state")
···16581974 `DELETE FROM oauth_auth_requests WHERE state = ?`, state,
16591975 )
16601976 if err != nil {
16611661- return fmt.Errorf("delete oauth auth request: %v", err)
19771977+ return fmt.Errorf("delete oauth auth request: %w", err)
16621978 }
16631979 return nil
16641980}
···17112027 formatTime(sess.CreatedAt), formatTime(sess.UpdatedAt),
17122028 )
17132029 if err != nil {
17141714- return fmt.Errorf("save oauth session: %v", err)
20302030+ return fmt.Errorf("save oauth session: %w", err)
17152031 }
17162032 return nil
17172033}
···17392055 return nil, nil
17402056 }
17412057 if err != nil {
17421742- return nil, fmt.Errorf("get oauth session: %v", err)
20582058+ return nil, fmt.Errorf("get oauth session: %w", err)
17432059 }
17442060 if scopes != "" {
17452061 sess.Scopes = strings.Split(scopes, " ")
···17562072 did, sessionID,
17572073 )
17582074 if err != nil {
17591759- return fmt.Errorf("delete oauth session: %v", err)
20752075+ return fmt.Errorf("delete oauth session: %w", err)
17602076 }
17612077 return nil
17622078}
···17692085 formatTime(now),
17702086 )
17712087 if err != nil {
17721772- return 0, fmt.Errorf("cleanup expired oauth: %v", err)
20882088+ return 0, fmt.Errorf("cleanup expired oauth: %w", err)
17732089 }
17742090 return res.RowsAffected()
17752091}
···18092125 n.MemberDID, n.Action, n.Actor, n.Note, formatTime(reviewedAt),
18102126 )
18112127 if err != nil {
18121812- return 0, fmt.Errorf("insert member review note: %v", err)
21282128+ return 0, fmt.Errorf("insert member review note: %w", err)
18132129 }
18142130 return res.LastInsertId()
18152131}
···18232139 did,
18242140 )
18252141 if err != nil {
18261826- return nil, fmt.Errorf("list member review notes: %v", err)
21422142+ return nil, fmt.Errorf("list member review notes: %w", err)
18272143 }
18282144 defer rows.Close()
18292145···18322148 var n MemberReviewNote
18332149 var reviewedAt string
18342150 if err := rows.Scan(&n.ID, &n.MemberDID, &n.Action, &n.Actor, &n.Note, &reviewedAt); err != nil {
18351835- return nil, fmt.Errorf("scan review note: %v", err)
21512151+ return nil, fmt.Errorf("scan review note: %w", err)
18362152 }
18372153 n.ReviewedAt = parseTime(reviewedAt)
18382154 out = append(out, n)
···18522168 ReviewActionReactivated, formatTime(since),
18532169 )
18542170 if err != nil {
18551855- return nil, fmt.Errorf("list reactivated dids: %v", err)
21712171+ return nil, fmt.Errorf("list reactivated dids: %w", err)
18562172 }
18572173 defer rows.Close()
18582174···18832199 senderDID, formatTime(since),
18842200 ).Scan(&n)
18852201 if err != nil {
18861886- return 0, fmt.Errorf("count relay_rejected since: %v", err)
22022202+ return 0, fmt.Errorf("count relay_rejected since: %w", err)
18872203 }
18882204 return n, nil
18892205}
···19612277 formatTime(p.CreatedAt), formatTime(p.ExpiresAt),
19622278 )
19632279 if err != nil {
19641964- return fmt.Errorf("create pending enrollment: %v", err)
22802280+ return fmt.Errorf("create pending enrollment: %w", err)
19652281 }
19662282 return nil
19672283}
···19832299 return nil, nil
19842300 }
19852301 if err != nil {
19861986- return nil, fmt.Errorf("get pending enrollment: %v", err)
23022302+ return nil, fmt.Errorf("get pending enrollment: %w", err)
19872303 }
19882304 p.TermsAccepted = termsAccepted != 0
19892305 p.CreatedAt = parseTime(createdAt)
···19992315 `DELETE FROM pending_enrollments WHERE token = ?`, token,
20002316 )
20012317 if err != nil {
20022002- return fmt.Errorf("delete pending enrollment: %v", err)
23182318+ return fmt.Errorf("delete pending enrollment: %w", err)
20032319 }
20042320 return nil
20052321}
···20132329 formatTime(cutoff),
20142330 )
20152331 if err != nil {
20162016- return 0, fmt.Errorf("clean expired pending enrollments: %v", err)
23322332+ return 0, fmt.Errorf("clean expired pending enrollments: %w", err)
20172333 }
20182334 return res.RowsAffected()
20192335}