···1616 return fmt.Sprintf("https://imgs.blue/%s/%s", did, cid)
1717}
18181919+// activeTakedownClause returns a SQL fragment ready to drop into a `WHERE NOT
2020+// EXISTS (...)` filter for excluding rows whose `(did, repository)` pair is currently
2121+// taken down. The `alias` argument is the outer table alias (e.g. "m" for manifests,
2222+// "lm" for latest_manifests) and must already be in scope at the use site.
2323+//
2424+// Mirrors the semantics of `IsTakenDown` (defined in labels.go) so listings stay
2525+// consistent with the per-repo page check: a label only counts as active when it has
2626+// neg=0, no newer neg=1 row with the same (src, uri, val), and a non-expired `exp`.
2727+// Without these clauses listings hide a repo forever once you've ever taken it down,
2828+// even after a reversal.
2929+func activeTakedownClause(alias string) string {
3030+ return `NOT EXISTS (
3131+ SELECT 1 FROM labels l1
3232+ WHERE l1.subject_did = ` + alias + `.did
3333+ AND (l1.subject_repo = ` + alias + `.repository OR l1.subject_repo = '')
3434+ AND l1.val = '!takedown' AND l1.neg = 0
3535+ AND NOT EXISTS (
3636+ SELECT 1 FROM labels l2
3737+ WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val
3838+ AND l2.neg = 1 AND l2.id > l1.id
3939+ )
4040+ AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP)
4141+ )`
4242+}
4343+1944// accessibleHoldsSubquery returns SQL that evaluates to the set of hold DIDs
2045// the viewer is allowed to see in listings. Requires the viewerDID to be
2146// passed twice as query arguments (once for the owner_did check and once
···107132 WHERE ra.did = lm.did AND ra.repository = lm.repository
108133 AND ra.value LIKE ? ESCAPE '\'
109134 ))
110110- AND NOT EXISTS (
111111- SELECT 1 FROM labels
112112- WHERE (subject_did = lm.did AND (subject_repo = lm.repository OR subject_repo = ''))
113113- AND val = '!takedown' AND neg = 0
114114- )
135135+ AND ` + activeTakedownClause("lm") + `
115136 ),
116137 repo_stats AS (
117138 SELECT
···21222143 JOIN users u ON m.did = u.did
21232144 LEFT JOIN repository_stats rs ON m.did = rs.did AND m.repository = rs.repository
21242145 LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository
21252125- WHERE NOT EXISTS (
21262126- SELECT 1 FROM labels
21272127- WHERE (subject_did = m.did AND (subject_repo = m.repository OR subject_repo = ''))
21282128- AND val = '!takedown' AND neg = 0
21292129- )
21462146+ WHERE ` + activeTakedownClause("m") + `
21302147 ORDER BY ` + orderBy + `
21312148 LIMIT ?
21322149 `
···22052222 JOIN users u ON m.did = u.did
22062223 LEFT JOIN repository_stats rs ON m.did = rs.did AND m.repository = rs.repository
22072224 LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository
22082208- WHERE NOT EXISTS (
22092209- SELECT 1 FROM labels
22102210- WHERE (subject_did = m.did AND (subject_repo = m.repository OR subject_repo = ''))
22112211- AND val = '!takedown' AND neg = 0
22122212- )
22252225+ WHERE ` + activeTakedownClause("m") + `
22132226 ORDER BY MAX(rs.last_push, m.created_at) DESC
22142227 `
22152228
+8-9
pkg/appview/db/readonly.go
···3939 } else {
4040 roDSN += "?mode=ro"
4141 }
4242- readOnlyDB, err := sql.Open("libsql", roDSN)
4242+ // Wrap with busyTimeoutConnector so every pooled read-only connection
4343+ // gets PRAGMA busy_timeout. Without this, reads return SQLITE_BUSY
4444+ // immediately when a write is in progress on the read-write connection
4545+ // (busy_timeout is per-connection, so a one-shot PRAGMA only configures
4646+ // whichever conn served it).
4747+ roBase, err := openLibsqlLocalConnector(roDSN)
4348 if err != nil {
4444- slog.Warn("Failed to open read-only database connection", "error", err)
4949+ slog.Warn("Failed to open read-only database connector", "error", err)
4550 return nil, nil, nil
4651 }
4747-4848- // busy_timeout is per-connection — without this, reads return SQLITE_BUSY
4949- // immediately when a write is in progress on the read-write connection.
5050- var busyTimeout int
5151- if err := readOnlyDB.QueryRow("PRAGMA busy_timeout = 5000").Scan(&busyTimeout); err != nil {
5252- slog.Warn("Failed to set busy_timeout on read-only connection", "error", err)
5353- }
5252+ readOnlyDB := sql.OpenDB(&busyTimeoutConnector{base: roBase, timeoutMs: 5000})
54535554 slog.Info("UI database initialized", "mode", "readonly", "path", dbPath)
5655
+71-15
pkg/appview/db/schema.go
···55package db
6677import (
88+ "context"
89 "database/sql"
1010+ "database/sql/driver"
911 "embed"
1012 "fmt"
1113 "io/fs"
···5557 db = sql.OpenDB(connector)
5658 slog.Info("Database opened in embedded replica mode", "path", path, "sync_url", cfg.SyncURL)
5759 } else {
5858- // Local-only mode: plain file via libsql driver
5959- // Paths starting with "file:" or ":memory:" are already valid libsql URIs
6060+ // Local-only mode: plain file via libsql driver, wrapped so every new
6161+ // connection gets PRAGMA busy_timeout. SQLite's busy_timeout is
6262+ // per-connection, so a one-shot db.Exec only configures whichever
6363+ // pooled conn served the call — leaving the rest to fail SQLITE_BUSY
6464+ // instantly on any write contention with the jetstream/backfill workers.
6565+ // Paths starting with "file:" or ":memory:" are already valid libsql URIs.
6066 dsn := path
6167 if !strings.HasPrefix(path, "file:") && !strings.HasPrefix(path, ":memory:") {
6268 dsn = "file:" + path
6369 }
6464- var err error
6565- db, err = sql.Open("libsql", dsn)
7070+ baseConnector, err := openLibsqlLocalConnector(dsn)
6671 if err != nil {
6772 return nil, err
6873 }
7474+ db = sql.OpenDB(&busyTimeoutConnector{base: baseConnector, timeoutMs: 5000})
6975 slog.Info("Database opened in local-only mode", "path", path)
7076 }
71777272- // In local-only mode, configure WAL and busy_timeout locally.
7373- // In embedded replica mode, the remote server manages these settings
7474- // and PRAGMA assignments are rejected as "unsupported statement"
7575- // (observed with Bunny Database; Turso may behave similarly).
7878+ // In local-only mode, set WAL mode (database-wide setting, persists
7979+ // across connections — single call is sufficient unlike busy_timeout).
8080+ // In embedded replica mode, the remote server manages this and the
8181+ // PRAGMA is rejected as "unsupported statement" (observed with Bunny;
8282+ // Turso may behave similarly).
7683 if cfg.SyncURL == "" {
7777- // Enable WAL mode for concurrent read/write access
7884 var journalMode string
7985 if err := db.QueryRow("PRAGMA journal_mode = WAL").Scan(&journalMode); err != nil {
8080- return nil, err
8181- }
8282-8383- // Retry on lock instead of failing immediately (5s timeout)
8484- var busyTimeout int
8585- if err := db.QueryRow("PRAGMA busy_timeout = 5000").Scan(&busyTimeout); err != nil {
8686 return nil, err
8787 }
8888 }
···377377378378 return version, name, nil
379379}
380380+381381+// openLibsqlLocalConnector returns a driver.Connector for a local libsql DSN.
382382+// go-libsql exports NewEmbeddedReplicaConnector for replica mode but no public
383383+// constructor for local files, so we obtain the driver via a probe sql.Open
384384+// (which is lazy and opens no connection) and ask it for a Connector.
385385+func openLibsqlLocalConnector(dsn string) (driver.Connector, error) {
386386+ probe, err := sql.Open("libsql", dsn)
387387+ if err != nil {
388388+ return nil, fmt.Errorf("probe libsql driver: %w", err)
389389+ }
390390+ drv := probe.Driver()
391391+ _ = probe.Close()
392392+393393+ dctx, ok := drv.(driver.DriverContext)
394394+ if !ok {
395395+ return nil, fmt.Errorf("libsql driver does not implement driver.DriverContext")
396396+ }
397397+ return dctx.OpenConnector(dsn)
398398+}
399399+400400+// busyTimeoutConnector wraps a driver.Connector and runs PRAGMA busy_timeout
401401+// on every newly opened connection. SQLite's busy_timeout is per-connection,
402402+// so this is the only way to ensure every conn in the pool waits on lock
403403+// contention instead of returning SQLITE_BUSY immediately.
404404+type busyTimeoutConnector struct {
405405+ base driver.Connector
406406+ timeoutMs int
407407+}
408408+409409+func (c *busyTimeoutConnector) Connect(ctx context.Context) (driver.Conn, error) {
410410+ conn, err := c.base.Connect(ctx)
411411+ if err != nil {
412412+ return nil, err
413413+ }
414414+415415+ // libsql treats PRAGMA assignments as queries that return a row, so we
416416+ // must use QueryerContext rather than ExecerContext.
417417+ queryer, ok := conn.(driver.QueryerContext)
418418+ if !ok {
419419+ _ = conn.Close()
420420+ return nil, fmt.Errorf("libsql conn does not support QueryerContext")
421421+ }
422422+423423+ rows, err := queryer.QueryContext(ctx, fmt.Sprintf("PRAGMA busy_timeout = %d", c.timeoutMs), nil)
424424+ if err != nil {
425425+ _ = conn.Close()
426426+ return nil, fmt.Errorf("set busy_timeout on new conn: %w", err)
427427+ }
428428+ _ = rows.Close()
429429+430430+ return conn, nil
431431+}
432432+433433+func (c *busyTimeoutConnector) Driver() driver.Driver {
434434+ return c.base.Driver()
435435+}
···33package labeler
4455import (
66+ "bytes"
67 "database/sql"
77- "encoding/json"
88+ "errors"
89 "fmt"
910 "log/slog"
1011 "net/url"
···13141415 "atcr.io/pkg/appview/db"
15161717+ comatproto "github.com/bluesky-social/indigo/api/atproto"
1818+ "github.com/bluesky-social/indigo/events"
1619 "github.com/gorilla/websocket"
1720)
1818-1919-// LabelsMessage is the wire format for subscribeLabels events.
2020-type LabelsMessage struct {
2121- Seq int64 `json:"seq"`
2222- Labels []LabelEvent `json:"labels"`
2323-}
2424-2525-// LabelEvent is a single label from the labeler.
2626-type LabelEvent struct {
2727- Src string `json:"src"`
2828- URI string `json:"uri"`
2929- CID string `json:"cid,omitempty"`
3030- Val string `json:"val"`
3131- Neg bool `json:"neg"`
3232- Cts string `json:"cts"`
3333- Exp string `json:"exp,omitempty"`
3434-}
35213622// Subscriber connects to a labeler's subscribeLabels endpoint
3723// and mirrors labels into the appview database.
···121107 default:
122108 }
123109124124- var msg LabelsMessage
125125- if err := conn.ReadJSON(&msg); err != nil {
110110+ mt, payload, err := conn.ReadMessage()
111111+ if err != nil {
126112 return fmt.Errorf("read error: %w", err)
127113 }
114114+ // Per the ATProto event-stream spec each frame is a binary message; reject text.
115115+ if mt != websocket.BinaryMessage {
116116+ slog.Warn("Ignoring non-binary frame from labeler", "type", mt)
117117+ continue
118118+ }
128119129129- for _, le := range msg.Labels {
120120+ seq, labels, err := decodeFrame(payload)
121121+ if err != nil {
122122+ if errors.Is(err, errInfoFrame) {
123123+ continue // already logged inside decodeFrame
124124+ }
125125+ return fmt.Errorf("decode frame: %w", err)
126126+ }
127127+128128+ for _, le := range labels {
130129 cts, _ := time.Parse(time.RFC3339, le.Cts)
131131- did, repo := extractSubjectFromURI(le.URI)
130130+ did, repo := extractSubjectFromURI(le.Uri)
132131133132 label := &db.Label{
134133 Src: le.Src,
135135- URI: le.URI,
134134+ URI: le.Uri,
136135 Val: le.Val,
137137- Neg: le.Neg,
136136+ Neg: le.Neg != nil && *le.Neg,
138137 Cts: cts,
139138 SubjectDID: did,
140139 SubjectRepo: repo,
141141- Seq: msg.Seq,
140140+ Seq: seq,
142141 }
143142144143 if err := db.UpsertLabel(s.database, label); err != nil {
145145- slog.Warn("Failed to upsert label", "uri", le.URI, "error", err)
144144+ slog.Warn("Failed to upsert label", "uri", le.Uri, "error", err)
146145 continue
147146 }
148147149149- slog.Info("Mirrored label",
150150- "uri", le.URI,
148148+ // "Mirrored label X" reads as an apply; reversals are a different action
149149+ // from the operator's POV (and a different SQL effect — the NOT EXISTS
150150+ // negation clause kicks in), so log them distinctly.
151151+ msg := "Mirrored label"
152152+ if label.Neg {
153153+ msg = "Mirrored label reversal"
154154+ }
155155+ slog.Info(msg,
156156+ "uri", le.Uri,
151157 "val", le.Val,
152152- "neg", le.Neg,
158158+ "neg", label.Neg,
153159 "subject_did", did,
154160 "subject_repo", repo,
155161 )
156162 }
163163+ }
164164+}
165165+166166+// errInfoFrame is returned by decodeFrame when the frame is informational and the
167167+// caller should just continue to the next message.
168168+var errInfoFrame = errors.New("labeler: info frame")
169169+170170+// decodeFrame parses a single subscribeLabels binary frame. ATProto event-stream framing
171171+// is two concatenated CBOR objects: a {op,t} header and a body. We dispatch on the
172172+// header op/t pair and return the labels body for op=1, t="#labels". For #info frames
173173+// we log and signal errInfoFrame so the caller skips. Error frames (op=-1) become Go
174174+// errors so the run loop reconnects with backoff.
175175+func decodeFrame(payload []byte) (int64, []*comatproto.LabelDefs_Label, error) {
176176+ r := bytes.NewReader(payload)
177177+ var header events.EventHeader
178178+ if err := header.UnmarshalCBOR(r); err != nil {
179179+ return 0, nil, fmt.Errorf("unmarshal header: %w", err)
180180+ }
181181+182182+ switch {
183183+ case header.Op == events.EvtKindErrorFrame:
184184+ var ef events.ErrorFrame
185185+ if err := ef.UnmarshalCBOR(r); err != nil {
186186+ return 0, nil, fmt.Errorf("unmarshal error frame: %w", err)
187187+ }
188188+ return 0, nil, fmt.Errorf("labeler error frame: %s — %s", ef.Error, ef.Message)
189189+190190+ case header.Op == events.EvtKindMessage && header.MsgType == "#labels":
191191+ var body comatproto.LabelSubscribeLabels_Labels
192192+ if err := body.UnmarshalCBOR(r); err != nil {
193193+ return 0, nil, fmt.Errorf("unmarshal labels body: %w", err)
194194+ }
195195+ return body.Seq, body.Labels, nil
196196+197197+ case header.Op == events.EvtKindMessage && header.MsgType == "#info":
198198+ var info comatproto.LabelSubscribeLabels_Info
199199+ if err := info.UnmarshalCBOR(r); err != nil {
200200+ return 0, nil, fmt.Errorf("unmarshal info body: %w", err)
201201+ }
202202+ message := ""
203203+ if info.Message != nil {
204204+ message = *info.Message
205205+ }
206206+ slog.Info("Labeler info frame", "name", info.Name, "message", message)
207207+ return 0, nil, errInfoFrame
208208+209209+ default:
210210+ return 0, nil, fmt.Errorf("unexpected frame op=%d t=%q", header.Op, header.MsgType)
157211 }
158212}
159213···228282 labelerURL := ParseLabelerURL(labelerDIDOrURL)
229283 return NewSubscriber(labelerURL, database)
230284}
231231-232232-// DecodeLabelsFromJSON decodes a JSON-encoded labels message.
233233-func DecodeLabelsFromJSON(data []byte) (*LabelsMessage, error) {
234234- var msg LabelsMessage
235235- if err := json.Unmarshal(data, &msg); err != nil {
236236- return nil, err
237237- }
238238- return &msg, nil
239239-}
+208
pkg/atproto/did/cmd.go
···11+package did
22+33+import (
44+ "context"
55+ "fmt"
66+77+ "github.com/bluesky-social/indigo/atproto/atcrypto"
88+ didplc "github.com/did-method-plc/go-didplc"
99+)
1010+1111+// AddRotationKeyOptions configures a rotation-key insert operation.
1212+type AddRotationKeyOptions struct {
1313+ // DID is the resolved did:plc identifier of the service.
1414+ DID string
1515+1616+ // PLCDirectoryURL is the PLC directory endpoint (defaults to https://plc.directory if empty).
1717+ PLCDirectoryURL string
1818+1919+ // RotationKey is the currently-authorized rotation key used to sign the update op.
2020+ RotationKey atcrypto.PrivateKey
2121+2222+ // SigningKey is the local k256 verification key — its public part goes into the
2323+ // new op's VerificationMethods so we don't accidentally drop it during the update.
2424+ SigningKey *atcrypto.PrivateKeyK256
2525+2626+ // VerificationKeyName is the fragment under which SigningKey is registered
2727+ // (e.g. "atproto" for a PDS, "atproto_label" for a labeler).
2828+ VerificationKeyName string
2929+3030+ // NewKey is the rotation key to add. If nil, a fresh K-256 key is generated and
3131+ // returned in the result so the caller can print/persist it.
3232+ NewKey atcrypto.PrivateKeyExportable
3333+3434+ // Prepend places the new key at the highest priority position. When false the key
3535+ // is appended at the lowest priority — only set false when the operator explicitly
3636+ // asks for it.
3737+ Prepend bool
3838+}
3939+4040+// AddRotationKeyResult describes the outcome of an AddRotationKey call.
4141+type AddRotationKeyResult struct {
4242+ NewKey atcrypto.PrivateKeyExportable
4343+ NewKeyDIDKey string
4444+ Generated bool
4545+ AlreadyPresent bool
4646+ ExistingAt int
4747+ InsertedAt int
4848+ TotalKeys int
4949+}
5050+5151+// AddRotationKey fetches the current PLC op log, inserts NewKey (generating one if nil),
5252+// signs the update with RotationKey, and submits it. Caller is responsible for printing
5353+// the generated key material — this function returns it on the result so prints can
5454+// happen in the binary's own format.
5555+func AddRotationKey(ctx context.Context, opt AddRotationKeyOptions) (*AddRotationKeyResult, error) {
5656+ if opt.DID == "" {
5757+ return nil, fmt.Errorf("plc: DID is required")
5858+ }
5959+ if opt.RotationKey == nil {
6060+ return nil, fmt.Errorf("plc: rotation key is required to sign updates")
6161+ }
6262+ if opt.SigningKey == nil {
6363+ return nil, fmt.Errorf("plc: signing key is required (becomes verificationMethods.%s)", opt.VerificationKeyName)
6464+ }
6565+ if opt.VerificationKeyName == "" {
6666+ return nil, fmt.Errorf("plc: VerificationKeyName is required")
6767+ }
6868+6969+ directory := opt.PLCDirectoryURL
7070+ if directory == "" {
7171+ directory = "https://plc.directory"
7272+ }
7373+ client := &didplc.Client{DirectoryURL: directory}
7474+7575+ res := &AddRotationKeyResult{NewKey: opt.NewKey}
7676+ if res.NewKey == nil {
7777+ raw, err := atcrypto.GeneratePrivateKeyK256()
7878+ if err != nil {
7979+ return nil, fmt.Errorf("plc: failed to generate rotation key: %w", err)
8080+ }
8181+ res.NewKey = raw
8282+ res.Generated = true
8383+ }
8484+ newPub, err := res.NewKey.PublicKey()
8585+ if err != nil {
8686+ return nil, fmt.Errorf("plc: failed to derive new public key: %w", err)
8787+ }
8888+ res.NewKeyDIDKey = newPub.DIDKey()
8989+9090+ opLog, err := client.OpLog(ctx, opt.DID)
9191+ if err != nil {
9292+ return nil, fmt.Errorf("plc: failed to fetch op log for %s: %w", opt.DID, err)
9393+ }
9494+ if len(opLog) == 0 {
9595+ return nil, fmt.Errorf("plc: empty op log for %s", opt.DID)
9696+ }
9797+ lastEntry := opLog[len(opLog)-1]
9898+ lastOp := lastEntry.Regular
9999+ if lastOp == nil {
100100+ return nil, fmt.Errorf("plc: last operation is not a regular op")
101101+ }
102102+103103+ for i, k := range lastOp.RotationKeys {
104104+ if k == res.NewKeyDIDKey {
105105+ res.AlreadyPresent = true
106106+ res.ExistingAt = i
107107+ res.TotalKeys = len(lastOp.RotationKeys)
108108+ return res, nil
109109+ }
110110+ }
111111+112112+ rotationKeys := make([]string, 0, len(lastOp.RotationKeys)+1)
113113+ if opt.Prepend {
114114+ rotationKeys = append(rotationKeys, res.NewKeyDIDKey)
115115+ rotationKeys = append(rotationKeys, lastOp.RotationKeys...)
116116+ res.InsertedAt = 0
117117+ } else {
118118+ rotationKeys = append(rotationKeys, lastOp.RotationKeys...)
119119+ rotationKeys = append(rotationKeys, res.NewKeyDIDKey)
120120+ res.InsertedAt = len(rotationKeys) - 1
121121+ }
122122+ res.TotalKeys = len(rotationKeys)
123123+124124+ sigPub, err := opt.SigningKey.PublicKey()
125125+ if err != nil {
126126+ return nil, fmt.Errorf("plc: failed to derive signing public key: %w", err)
127127+ }
128128+ prevCID := lastEntry.AsOperation().CID().String()
129129+130130+ op := &didplc.RegularOp{
131131+ Type: "plc_operation",
132132+ RotationKeys: rotationKeys,
133133+ VerificationMethods: map[string]string{
134134+ opt.VerificationKeyName: sigPub.DIDKey(),
135135+ },
136136+ AlsoKnownAs: lastOp.AlsoKnownAs,
137137+ Services: lastOp.Services,
138138+ Prev: &prevCID,
139139+ }
140140+ if err := op.Sign(opt.RotationKey); err != nil {
141141+ return nil, fmt.Errorf("plc: failed to sign update: %w", err)
142142+ }
143143+ if err := client.Submit(ctx, opt.DID, op); err != nil {
144144+ return nil, fmt.Errorf("plc: failed to submit update: %w", err)
145145+ }
146146+ return res, nil
147147+}
148148+149149+// ListRotationKeysOptions configures a list-rotation-keys read.
150150+type ListRotationKeysOptions struct {
151151+ DID string
152152+ PLCDirectoryURL string
153153+ LocalRotationKey atcrypto.PrivateKey // optional — used to compute the LOCAL marker
154154+}
155155+156156+// ListRotationKeysResult holds the priority-ordered rotation keys plus the local
157157+// rotation key's did:key form (if provided), so callers can mark and warn appropriately.
158158+type ListRotationKeysResult struct {
159159+ DID string
160160+ Directory string
161161+ Keys []string
162162+ LocalDIDKey string
163163+ LocalPresent bool
164164+}
165165+166166+// ListRotationKeys fetches the current PLC op and returns its rotation keys in priority order.
167167+func ListRotationKeys(ctx context.Context, opt ListRotationKeysOptions) (*ListRotationKeysResult, error) {
168168+ if opt.DID == "" {
169169+ return nil, fmt.Errorf("plc: DID is required")
170170+ }
171171+ directory := opt.PLCDirectoryURL
172172+ if directory == "" {
173173+ directory = "https://plc.directory"
174174+ }
175175+ client := &didplc.Client{DirectoryURL: directory}
176176+177177+ opLog, err := client.OpLog(ctx, opt.DID)
178178+ if err != nil {
179179+ return nil, fmt.Errorf("plc: failed to fetch op log for %s: %w", opt.DID, err)
180180+ }
181181+ if len(opLog) == 0 {
182182+ return nil, fmt.Errorf("plc: empty op log for %s", opt.DID)
183183+ }
184184+ lastOp := opLog[len(opLog)-1].Regular
185185+ if lastOp == nil {
186186+ return nil, fmt.Errorf("plc: last operation is not a regular op")
187187+ }
188188+189189+ res := &ListRotationKeysResult{
190190+ DID: opt.DID,
191191+ Directory: directory,
192192+ Keys: append([]string(nil), lastOp.RotationKeys...),
193193+ }
194194+ if opt.LocalRotationKey != nil {
195195+ pub, err := opt.LocalRotationKey.PublicKey()
196196+ if err != nil {
197197+ return nil, fmt.Errorf("plc: failed to derive local rotation public key: %w", err)
198198+ }
199199+ res.LocalDIDKey = pub.DIDKey()
200200+ for _, k := range res.Keys {
201201+ if k == res.LocalDIDKey {
202202+ res.LocalPresent = true
203203+ break
204204+ }
205205+ }
206206+ }
207207+ return res, nil
208208+}
+304
pkg/atproto/did/cmd_test.go
···11+package did
22+33+import (
44+ "context"
55+ "strings"
66+ "testing"
77+88+ "github.com/bluesky-social/indigo/atproto/atcrypto"
99+)
1010+1111+// TestAddRotationKey_AppendNew confirms a fresh key is appended at the lowest priority
1212+// when Prepend is false.
1313+func TestAddRotationKey_AppendNew(t *testing.T) {
1414+ ctx := context.Background()
1515+1616+ serverRot := generateK256(t)
1717+ signing := generateK256(t)
1818+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing)
1919+ defer fake.Close()
2020+2121+ newKey := generateK256(t)
2222+ res, err := AddRotationKey(ctx, AddRotationKeyOptions{
2323+ DID: fake.did,
2424+ PLCDirectoryURL: fake.URL(),
2525+ RotationKey: serverRot,
2626+ SigningKey: signing,
2727+ VerificationKeyName: "atproto",
2828+ NewKey: newKey,
2929+ Prepend: false,
3030+ })
3131+ if err != nil {
3232+ t.Fatalf("AddRotationKey: %v", err)
3333+ }
3434+ if res.AlreadyPresent {
3535+ t.Fatal("AlreadyPresent should be false")
3636+ }
3737+ if res.Generated {
3838+ t.Error("Generated should be false when NewKey provided")
3939+ }
4040+ if res.TotalKeys != 2 {
4141+ t.Errorf("TotalKeys: got %d want 2", res.TotalKeys)
4242+ }
4343+ if res.InsertedAt != 1 {
4444+ t.Errorf("InsertedAt: got %d want 1 (appended)", res.InsertedAt)
4545+ }
4646+4747+ if len(fake.submitted) != 1 {
4848+ t.Fatalf("expected one update submission, got %d", len(fake.submitted))
4949+ }
5050+ got := fake.submitted[0]
5151+ newPub, _ := newKey.PublicKey()
5252+ if got.RotationKeys[len(got.RotationKeys)-1] != newPub.DIDKey() {
5353+ t.Errorf("appended key not at last position: %v", got.RotationKeys)
5454+ }
5555+}
5656+5757+// TestAddRotationKey_Prepend confirms Prepend=true puts the new key at index 0
5858+// (highest priority position).
5959+func TestAddRotationKey_Prepend(t *testing.T) {
6060+ ctx := context.Background()
6161+6262+ serverRot := generateK256(t)
6363+ signing := generateK256(t)
6464+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing)
6565+ defer fake.Close()
6666+6767+ newKey := generateK256(t)
6868+ res, err := AddRotationKey(ctx, AddRotationKeyOptions{
6969+ DID: fake.did,
7070+ PLCDirectoryURL: fake.URL(),
7171+ RotationKey: serverRot,
7272+ SigningKey: signing,
7373+ VerificationKeyName: "atproto",
7474+ NewKey: newKey,
7575+ Prepend: true,
7676+ })
7777+ if err != nil {
7878+ t.Fatalf("AddRotationKey: %v", err)
7979+ }
8080+ if res.InsertedAt != 0 {
8181+ t.Errorf("InsertedAt: got %d want 0", res.InsertedAt)
8282+ }
8383+ if len(fake.submitted) != 1 {
8484+ t.Fatalf("expected one update, got %d", len(fake.submitted))
8585+ }
8686+ newPub, _ := newKey.PublicKey()
8787+ if fake.submitted[0].RotationKeys[0] != newPub.DIDKey() {
8888+ t.Errorf("prepended key not at first position: %v", fake.submitted[0].RotationKeys)
8989+ }
9090+}
9191+9292+// TestAddRotationKey_GeneratesWhenNil confirms the helper generates a fresh key
9393+// and reports it via Result.
9494+func TestAddRotationKey_GeneratesWhenNil(t *testing.T) {
9595+ ctx := context.Background()
9696+9797+ serverRot := generateK256(t)
9898+ signing := generateK256(t)
9999+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing)
100100+ defer fake.Close()
101101+102102+ res, err := AddRotationKey(ctx, AddRotationKeyOptions{
103103+ DID: fake.did,
104104+ PLCDirectoryURL: fake.URL(),
105105+ RotationKey: serverRot,
106106+ SigningKey: signing,
107107+ VerificationKeyName: "atproto",
108108+ NewKey: nil,
109109+ })
110110+ if err != nil {
111111+ t.Fatalf("AddRotationKey: %v", err)
112112+ }
113113+ if !res.Generated {
114114+ t.Error("Generated should be true when NewKey is nil")
115115+ }
116116+ if res.NewKey == nil {
117117+ t.Fatal("NewKey on result should not be nil")
118118+ }
119119+ if res.NewKeyDIDKey == "" {
120120+ t.Error("NewKeyDIDKey should be populated")
121121+ }
122122+}
123123+124124+// TestAddRotationKey_AlreadyPresent confirms a no-op when the key is already in the list.
125125+func TestAddRotationKey_AlreadyPresent(t *testing.T) {
126126+ ctx := context.Background()
127127+128128+ serverRot := generateK256(t)
129129+ signing := generateK256(t)
130130+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing)
131131+ defer fake.Close()
132132+133133+ res, err := AddRotationKey(ctx, AddRotationKeyOptions{
134134+ DID: fake.did,
135135+ PLCDirectoryURL: fake.URL(),
136136+ RotationKey: serverRot,
137137+ SigningKey: signing,
138138+ VerificationKeyName: "atproto",
139139+ NewKey: serverRot,
140140+ })
141141+ if err != nil {
142142+ t.Fatalf("AddRotationKey: %v", err)
143143+ }
144144+ if !res.AlreadyPresent {
145145+ t.Error("AlreadyPresent should be true")
146146+ }
147147+ if res.ExistingAt != 0 {
148148+ t.Errorf("ExistingAt: got %d want 0", res.ExistingAt)
149149+ }
150150+ if len(fake.submitted) != 0 {
151151+ t.Errorf("no submission expected when key already present, got %d", len(fake.submitted))
152152+ }
153153+}
154154+155155+// TestAddRotationKey_ValidationErrors covers the early-return guard clauses in AddRotationKey.
156156+func TestAddRotationKey_ValidationErrors(t *testing.T) {
157157+ ctx := context.Background()
158158+ signing := generateK256(t)
159159+ rot := generateK256(t)
160160+161161+ cases := []struct {
162162+ name string
163163+ opt AddRotationKeyOptions
164164+ wantSub string
165165+ }{
166166+ {
167167+ name: "missing DID",
168168+ opt: AddRotationKeyOptions{RotationKey: rot, SigningKey: signing, VerificationKeyName: "atproto"},
169169+ wantSub: "DID is required",
170170+ },
171171+ {
172172+ name: "missing rotation key",
173173+ opt: AddRotationKeyOptions{DID: "did:plc:abc", SigningKey: signing, VerificationKeyName: "atproto"},
174174+ wantSub: "rotation key is required",
175175+ },
176176+ {
177177+ name: "missing signing key",
178178+ opt: AddRotationKeyOptions{DID: "did:plc:abc", RotationKey: rot, VerificationKeyName: "atproto"},
179179+ wantSub: "signing key is required",
180180+ },
181181+ {
182182+ name: "missing verification key name",
183183+ opt: AddRotationKeyOptions{DID: "did:plc:abc", RotationKey: rot, SigningKey: signing},
184184+ wantSub: "VerificationKeyName is required",
185185+ },
186186+ }
187187+ for _, tc := range cases {
188188+ t.Run(tc.name, func(t *testing.T) {
189189+ _, err := AddRotationKey(ctx, tc.opt)
190190+ if err == nil {
191191+ t.Fatal("expected error, got nil")
192192+ }
193193+ if !strings.Contains(err.Error(), tc.wantSub) {
194194+ t.Errorf("error: got %q want substring %q", err.Error(), tc.wantSub)
195195+ }
196196+ })
197197+ }
198198+}
199199+200200+// TestListRotationKeys returns the priority-ordered keys from the latest op.
201201+func TestListRotationKeys(t *testing.T) {
202202+ ctx := context.Background()
203203+204204+ rot1 := generateK256(t)
205205+ rot2 := generateK256(t)
206206+ signing := generateK256(t)
207207+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{rot1, rot2}, rot1, signing)
208208+ defer fake.Close()
209209+210210+ res, err := ListRotationKeys(ctx, ListRotationKeysOptions{
211211+ DID: fake.did,
212212+ PLCDirectoryURL: fake.URL(),
213213+ })
214214+ if err != nil {
215215+ t.Fatalf("ListRotationKeys: %v", err)
216216+ }
217217+ if res.DID != fake.did {
218218+ t.Errorf("DID: got %s want %s", res.DID, fake.did)
219219+ }
220220+ if res.Directory != fake.URL() {
221221+ t.Errorf("Directory: got %s want %s", res.Directory, fake.URL())
222222+ }
223223+ if len(res.Keys) != 2 {
224224+ t.Fatalf("Keys length: got %d want 2", len(res.Keys))
225225+ }
226226+ pub1, _ := rot1.PublicKey()
227227+ pub2, _ := rot2.PublicKey()
228228+ if res.Keys[0] != pub1.DIDKey() {
229229+ t.Errorf("Keys[0]: got %s want %s", res.Keys[0], pub1.DIDKey())
230230+ }
231231+ if res.Keys[1] != pub2.DIDKey() {
232232+ t.Errorf("Keys[1]: got %s want %s", res.Keys[1], pub2.DIDKey())
233233+ }
234234+ if res.LocalDIDKey != "" {
235235+ t.Errorf("LocalDIDKey should be empty when no LocalRotationKey provided, got %s", res.LocalDIDKey)
236236+ }
237237+ if res.LocalPresent {
238238+ t.Error("LocalPresent should be false when no LocalRotationKey provided")
239239+ }
240240+}
241241+242242+// TestListRotationKeys_LocalPresent confirms LocalPresent flips to true when the local
243243+// key matches one in the published list.
244244+func TestListRotationKeys_LocalPresent(t *testing.T) {
245245+ ctx := context.Background()
246246+247247+ serverRot := generateK256(t)
248248+ signing := generateK256(t)
249249+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing)
250250+ defer fake.Close()
251251+252252+ res, err := ListRotationKeys(ctx, ListRotationKeysOptions{
253253+ DID: fake.did,
254254+ PLCDirectoryURL: fake.URL(),
255255+ LocalRotationKey: serverRot,
256256+ })
257257+ if err != nil {
258258+ t.Fatalf("ListRotationKeys: %v", err)
259259+ }
260260+ if !res.LocalPresent {
261261+ t.Error("LocalPresent should be true")
262262+ }
263263+ pub, _ := serverRot.PublicKey()
264264+ if res.LocalDIDKey != pub.DIDKey() {
265265+ t.Errorf("LocalDIDKey: got %s want %s", res.LocalDIDKey, pub.DIDKey())
266266+ }
267267+}
268268+269269+// TestListRotationKeys_LocalNotPresent flags a rotated-out local key.
270270+func TestListRotationKeys_LocalNotPresent(t *testing.T) {
271271+ ctx := context.Background()
272272+273273+ serverRot := generateK256(t)
274274+ signing := generateK256(t)
275275+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing)
276276+ defer fake.Close()
277277+278278+ stranger := generateK256(t)
279279+ res, err := ListRotationKeys(ctx, ListRotationKeysOptions{
280280+ DID: fake.did,
281281+ PLCDirectoryURL: fake.URL(),
282282+ LocalRotationKey: stranger,
283283+ })
284284+ if err != nil {
285285+ t.Fatalf("ListRotationKeys: %v", err)
286286+ }
287287+ if res.LocalPresent {
288288+ t.Error("LocalPresent should be false for a stranger key")
289289+ }
290290+ if res.LocalDIDKey == "" {
291291+ t.Error("LocalDIDKey should still be populated even when not present")
292292+ }
293293+}
294294+295295+// TestListRotationKeys_MissingDID surfaces the early-return validation.
296296+func TestListRotationKeys_MissingDID(t *testing.T) {
297297+ _, err := ListRotationKeys(context.Background(), ListRotationKeysOptions{})
298298+ if err == nil {
299299+ t.Fatal("expected error for missing DID")
300300+ }
301301+ if !strings.Contains(err.Error(), "DID is required") {
302302+ t.Errorf("error: got %q", err.Error())
303303+ }
304304+}
+243
pkg/atproto/did/did.go
···11+// Package did provides shared did:web and did:plc identity management for ATCR services.
22+//
33+// Both the hold and labeler services declare an ATProto identity with a signing key
44+// and one or more service endpoints. This package generalizes the genesis/update/load
55+// flow so callers only have to specify their verification key fragment name and the
66+// service entries they want to register.
77+package did
88+99+import (
1010+ "context"
1111+ "encoding/json"
1212+ "fmt"
1313+ "log/slog"
1414+ "net/url"
1515+ "os"
1616+ "path/filepath"
1717+ "strings"
1818+1919+ "atcr.io/pkg/auth/oauth"
2020+ "github.com/bluesky-social/indigo/atproto/atcrypto"
2121+)
2222+2323+// Service is a service entry in a DID document or PLC operation.
2424+type Service struct {
2525+ Type string
2626+ Endpoint string
2727+}
2828+2929+// Config configures DID identity loading or creation.
3030+type Config struct {
3131+ // Method is "web" or "plc".
3232+ Method string
3333+3434+ // PublicURL is the externally reachable URL of the service.
3535+ PublicURL string
3636+3737+ // DBPath is a directory used to persist did.txt for did:plc identities.
3838+ DBPath string
3939+4040+ // SigningKeyPath is the on-disk path for the K-256 signing key (will be generated if missing).
4141+ SigningKeyPath string
4242+4343+ // RotationKey is a multibase-encoded private key used to sign PLC operations (optional).
4444+ // If empty for did:plc, a new rotation key is generated and logged once for the operator.
4545+ RotationKey string
4646+4747+ // PLCDirectoryURL is the PLC directory endpoint.
4848+ PLCDirectoryURL string
4949+5050+ // DID overrides the persisted DID (used for adoption/recovery of an existing did:plc).
5151+ DID string
5252+5353+ // VerificationKeyName is the fragment used in the DID document and PLC operation
5454+ // for the signing key (e.g. "atproto" for a PDS, "atproto_label" for a labeler).
5555+ VerificationKeyName string
5656+5757+ // Services lists service entries keyed by service id (e.g. "atproto_pds", "atproto_labeler").
5858+ Services map[string]Service
5959+}
6060+6161+// LoadOrCreate returns the service's DID. did:web is derived deterministically from
6262+// PublicURL; did:plc is loaded from disk or created and registered with the PLC directory.
6363+func LoadOrCreate(ctx context.Context, cfg Config) (string, error) {
6464+ if cfg.Method != "plc" {
6565+ return GenerateDIDFromURL(cfg.PublicURL), nil
6666+ }
6767+6868+ if cfg.VerificationKeyName == "" {
6969+ return "", fmt.Errorf("did: VerificationKeyName is required for did:plc")
7070+ }
7171+ if len(cfg.Services) == 0 {
7272+ return "", fmt.Errorf("did: at least one service entry is required for did:plc")
7373+ }
7474+7575+ didPath := filepath.Join(cfg.DBPath, "did.txt")
7676+7777+ var d string
7878+ if cfg.DID != "" {
7979+ if !strings.HasPrefix(cfg.DID, "did:plc:") {
8080+ return "", fmt.Errorf("did: DID must be a did:plc identifier, got %q", cfg.DID)
8181+ }
8282+ d = cfg.DID
8383+ slog.Info("Using DID from config (adoption/recovery)", "did", d)
8484+ } else if data, err := os.ReadFile(didPath); err == nil {
8585+ val := strings.TrimSpace(string(data))
8686+ if strings.HasPrefix(val, "did:plc:") {
8787+ d = val
8888+ slog.Info("Loaded existing did:plc identity", "did", d)
8989+ }
9090+ }
9191+9292+ if d != "" {
9393+ if err := os.MkdirAll(filepath.Dir(didPath), 0755); err != nil {
9494+ return "", fmt.Errorf("did: failed to create did.txt directory: %w", err)
9595+ }
9696+ if err := os.WriteFile(didPath, []byte(d+"\n"), 0600); err != nil {
9797+ return "", fmt.Errorf("did: failed to write did.txt: %w", err)
9898+ }
9999+100100+ signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath)
101101+ if err != nil {
102102+ return "", fmt.Errorf("did: failed to load signing key: %w", err)
103103+ }
104104+ rotationKey, _ := parseOptionalMultibaseKey(cfg.RotationKey)
105105+106106+ if err := EnsureCurrent(ctx, d, rotationKey, signingKey, cfg); err != nil {
107107+ slog.Warn("Failed to verify PLC identity is current (will retry on next restart)",
108108+ "did", d, "error", err)
109109+ }
110110+ return d, nil
111111+ }
112112+113113+ slog.Info("Creating new did:plc identity")
114114+115115+ signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath)
116116+ if err != nil {
117117+ return "", fmt.Errorf("did: failed to load signing key: %w", err)
118118+ }
119119+120120+ var rotationKey atcrypto.PrivateKeyExportable
121121+ if cfg.RotationKey != "" {
122122+ rotationKey, err = parseOptionalMultibaseKey(cfg.RotationKey)
123123+ if err != nil {
124124+ return "", fmt.Errorf("did: failed to parse rotation_key: %w", err)
125125+ }
126126+ } else {
127127+ rawKey, genErr := atcrypto.GeneratePrivateKeyK256()
128128+ if genErr != nil {
129129+ return "", fmt.Errorf("did: failed to generate rotation key: %w", genErr)
130130+ }
131131+ rotationKey = rawKey
132132+ slog.Warn("Generated new rotation key — save this in your config as rotation_key",
133133+ "rotation_key", rawKey.Multibase())
134134+ }
135135+136136+ d, err = CreateIdentity(ctx, rotationKey, signingKey, cfg)
137137+ if err != nil {
138138+ return "", fmt.Errorf("did: failed to create PLC identity: %w", err)
139139+ }
140140+141141+ if err := os.MkdirAll(filepath.Dir(didPath), 0755); err != nil {
142142+ return "", fmt.Errorf("did: failed to create did.txt directory: %w", err)
143143+ }
144144+ if err := os.WriteFile(didPath, []byte(d+"\n"), 0600); err != nil {
145145+ return "", fmt.Errorf("did: failed to write did.txt: %w", err)
146146+ }
147147+148148+ slog.Info("Created did:plc identity", "did", d, "plc_directory", cfg.PLCDirectoryURL)
149149+ slog.Warn("Back up your rotation_key. It is only needed for DID updates (URL changes, key rotation).")
150150+ return d, nil
151151+}
152152+153153+// DIDDocument is the JSON shape we serve for did:web identities.
154154+type DIDDocument struct {
155155+ Context []string `json:"@context"`
156156+ ID string `json:"id"`
157157+ AlsoKnownAs []string `json:"alsoKnownAs,omitempty"`
158158+ VerificationMethod []VerificationMethod `json:"verificationMethod"`
159159+ Authentication []string `json:"authentication,omitempty"`
160160+ AssertionMethod []string `json:"assertionMethod,omitempty"`
161161+ Service []DIDService `json:"service,omitempty"`
162162+}
163163+164164+// VerificationMethod is a public key entry in a DID document.
165165+type VerificationMethod struct {
166166+ ID string `json:"id"`
167167+ Type string `json:"type"`
168168+ Controller string `json:"controller"`
169169+ PublicKeyMultibase string `json:"publicKeyMultibase"`
170170+}
171171+172172+// DIDService is a service entry in a DID document.
173173+type DIDService struct {
174174+ ID string `json:"id"`
175175+ Type string `json:"type"`
176176+ ServiceEndpoint string `json:"serviceEndpoint"`
177177+}
178178+179179+// BuildDIDDocument constructs a DID document for a did:web identity. The verification
180180+// method fragment matches verificationKeyName (e.g. "#atproto" or "#atproto_label");
181181+// pass "" to default to "atproto". Authentication is only added for the standard
182182+// "atproto" key per the bsky/PDS pattern.
183183+func BuildDIDDocument(did, publicURL string, signingKey *atcrypto.PrivateKeyK256, verificationKeyName string, services map[string]Service) (*DIDDocument, error) {
184184+ host, err := hostWithPort(publicURL)
185185+ if err != nil {
186186+ return nil, err
187187+ }
188188+ pub, err := signingKey.PublicKey()
189189+ if err != nil {
190190+ return nil, fmt.Errorf("did: failed to get public key: %w", err)
191191+ }
192192+193193+ keyName := verificationKeyName
194194+ if keyName == "" {
195195+ keyName = "atproto"
196196+ }
197197+198198+ doc := &DIDDocument{
199199+ Context: []string{
200200+ "https://www.w3.org/ns/did/v1",
201201+ "https://w3id.org/security/multikey/v1",
202202+ "https://w3id.org/security/suites/secp256k1-2019/v1",
203203+ },
204204+ ID: did,
205205+ AlsoKnownAs: []string{"at://" + host},
206206+ VerificationMethod: []VerificationMethod{
207207+ {
208208+ ID: fmt.Sprintf("%s#%s", did, keyName),
209209+ Type: "Multikey",
210210+ Controller: did,
211211+ PublicKeyMultibase: pub.Multibase(),
212212+ },
213213+ },
214214+ }
215215+ if keyName == "atproto" {
216216+ doc.Authentication = []string{fmt.Sprintf("%s#atproto", did)}
217217+ }
218218+ for id, svc := range services {
219219+ doc.Service = append(doc.Service, DIDService{
220220+ ID: "#" + id,
221221+ Type: svc.Type,
222222+ ServiceEndpoint: svc.Endpoint,
223223+ })
224224+ }
225225+ return doc, nil
226226+}
227227+228228+// MarshalDIDDocument is a convenience for serving a DID doc as indented JSON.
229229+func MarshalDIDDocument(doc *DIDDocument) ([]byte, error) {
230230+ return json.MarshalIndent(doc, "", " ")
231231+}
232232+233233+func hostWithPort(publicURL string) (string, error) {
234234+ u, err := url.Parse(publicURL)
235235+ if err != nil {
236236+ return "", fmt.Errorf("did: failed to parse public URL: %w", err)
237237+ }
238238+ host := u.Hostname()
239239+ if port := u.Port(); port != "" && port != "80" && port != "443" {
240240+ host = host + ":" + port
241241+ }
242242+ return host, nil
243243+}
···11+package did
22+33+import (
44+ "context"
55+ "fmt"
66+ "log/slog"
77+88+ "github.com/bluesky-social/indigo/atproto/atcrypto"
99+ didplc "github.com/did-method-plc/go-didplc"
1010+)
1111+1212+// CreateIdentity builds a genesis PLC operation with the configured verification key and
1313+// services, signs it with the rotation key, and submits it to the PLC directory.
1414+func CreateIdentity(ctx context.Context, rotationKey atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256, cfg Config) (string, error) {
1515+ rotPub, err := rotationKey.PublicKey()
1616+ if err != nil {
1717+ return "", fmt.Errorf("did: failed to get rotation public key: %w", err)
1818+ }
1919+ sigPub, err := signingKey.PublicKey()
2020+ if err != nil {
2121+ return "", fmt.Errorf("did: failed to get signing public key: %w", err)
2222+ }
2323+2424+ host, err := hostWithPort(cfg.PublicURL)
2525+ if err != nil {
2626+ return "", err
2727+ }
2828+2929+ op := &didplc.RegularOp{
3030+ Type: "plc_operation",
3131+ RotationKeys: []string{rotPub.DIDKey()},
3232+ VerificationMethods: map[string]string{
3333+ cfg.VerificationKeyName: sigPub.DIDKey(),
3434+ },
3535+ AlsoKnownAs: []string{"at://" + host},
3636+ Services: toOpServices(cfg.Services),
3737+ Prev: nil,
3838+ }
3939+ if err := op.Sign(rotationKey); err != nil {
4040+ return "", fmt.Errorf("did: failed to sign genesis operation: %w", err)
4141+ }
4242+4343+ d, err := op.DID()
4444+ if err != nil {
4545+ return "", fmt.Errorf("did: failed to compute DID from genesis: %w", err)
4646+ }
4747+4848+ client := &didplc.Client{DirectoryURL: cfg.PLCDirectoryURL}
4949+ if err := client.Submit(ctx, d, op); err != nil {
5050+ return "", fmt.Errorf("did: failed to submit genesis operation: %w", err)
5151+ }
5252+ return d, nil
5353+}
5454+5555+// EnsureCurrent reconciles the published DID document with the local config; if the local
5656+// signing key, public URL, or service set differs, an update operation is signed and submitted.
5757+// Without a rotation key, mismatches log a warning but are not fatal.
5858+func EnsureCurrent(ctx context.Context, did string, rotationKey atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256, cfg Config) error {
5959+ client := &didplc.Client{DirectoryURL: cfg.PLCDirectoryURL}
6060+6161+ opLog, err := client.OpLog(ctx, did)
6262+ if err != nil {
6363+ return fmt.Errorf("did: failed to fetch op log for %s: %w", did, err)
6464+ }
6565+ if len(opLog) == 0 {
6666+ return fmt.Errorf("did: empty op log for %s", did)
6767+ }
6868+ lastEntry := opLog[len(opLog)-1]
6969+ lastOp := lastEntry.Regular
7070+ if lastOp == nil {
7171+ slog.Warn("Last PLC operation is not a regular op, skipping auto-update", "did", did)
7272+ return nil
7373+ }
7474+7575+ sigPub, err := signingKey.PublicKey()
7676+ if err != nil {
7777+ return fmt.Errorf("did: failed to get signing public key: %w", err)
7878+ }
7979+ localKey := sigPub.DIDKey()
8080+ plcKey := lastOp.VerificationMethods[cfg.VerificationKeyName]
8181+ keyMatch := localKey == plcKey
8282+8383+ servicesMatch := true
8484+ for name, svc := range cfg.Services {
8585+ plcSvc, ok := lastOp.Services[name]
8686+ if !ok || plcSvc.Type != svc.Type || plcSvc.Endpoint != svc.Endpoint {
8787+ servicesMatch = false
8888+ break
8989+ }
9090+ }
9191+9292+ if keyMatch && servicesMatch {
9393+ slog.Info("PLC identity is current", "did", did)
9494+ return nil
9595+ }
9696+9797+ slog.Info("PLC identity needs update",
9898+ "did", did, "signing_key_changed", !keyMatch, "services_changed", !servicesMatch)
9999+100100+ if rotationKey == nil {
101101+ slog.Warn("PLC document doesn't match local state but no rotation key available. Provide rotation key to auto-update.",
102102+ "did", did, "signing_key_changed", !keyMatch, "services_changed", !servicesMatch)
103103+ return nil
104104+ }
105105+106106+ rotPub, err := rotationKey.PublicKey()
107107+ if err != nil {
108108+ return fmt.Errorf("did: failed to get rotation public key: %w", err)
109109+ }
110110+ localRotKey := rotPub.DIDKey()
111111+112112+ // Verify the local rotation key still has authority on the PLC document.
113113+ // If it's been rotated out (possibly maliciously), refuse to submit — PLC would
114114+ // reject anyway, and silent failure here would be confusing.
115115+ localRotKeyPresent := false
116116+ for _, k := range lastOp.RotationKeys {
117117+ if k == localRotKey {
118118+ localRotKeyPresent = true
119119+ break
120120+ }
121121+ }
122122+ if !localRotKeyPresent {
123123+ slog.Warn("Local rotation key not present in PLC document — refusing to update. Possible compromise or out-of-band rotation. Recover with offline key if available.",
124124+ "did", did, "local_rotation_key", localRotKey, "plc_rotation_keys", lastOp.RotationKeys)
125125+ return nil
126126+ }
127127+128128+ host, err := hostWithPort(cfg.PublicURL)
129129+ if err != nil {
130130+ return err
131131+ }
132132+ prevCID := lastEntry.AsOperation().CID().String()
133133+134134+ op := &didplc.RegularOp{
135135+ Type: "plc_operation",
136136+ RotationKeys: lastOp.RotationKeys,
137137+ VerificationMethods: map[string]string{
138138+ cfg.VerificationKeyName: localKey,
139139+ },
140140+ AlsoKnownAs: []string{"at://" + host},
141141+ Services: toOpServices(cfg.Services),
142142+ Prev: &prevCID,
143143+ }
144144+ if err := op.Sign(rotationKey); err != nil {
145145+ return fmt.Errorf("did: failed to sign update operation: %w", err)
146146+ }
147147+ if err := client.Submit(ctx, did, op); err != nil {
148148+ return fmt.Errorf("did: failed to submit update: %w", err)
149149+ }
150150+ slog.Info("Updated PLC identity",
151151+ "did", did, "signing_key_rotated", !keyMatch, "services_changed", !servicesMatch)
152152+ return nil
153153+}
154154+155155+func toOpServices(in map[string]Service) map[string]didplc.OpService {
156156+ out := make(map[string]didplc.OpService, len(in))
157157+ for name, svc := range in {
158158+ out[name] = didplc.OpService{Type: svc.Type, Endpoint: svc.Endpoint}
159159+ }
160160+ return out
161161+}
162162+163163+func parseOptionalMultibaseKey(encoded string) (atcrypto.PrivateKeyExportable, error) {
164164+ if encoded == "" {
165165+ return nil, nil
166166+ }
167167+ key, err := atcrypto.ParsePrivateMultibase(encoded)
168168+ if err != nil {
169169+ return nil, fmt.Errorf("did: failed to parse multibase key: %w", err)
170170+ }
171171+ return key, nil
172172+}
+243
pkg/atproto/did/plc_test.go
···11+package did
22+33+import (
44+ "context"
55+ "encoding/json"
66+ "io"
77+ "net/http"
88+ "net/http/httptest"
99+ "os"
1010+ "path/filepath"
1111+ "strings"
1212+ "testing"
1313+1414+ "github.com/bluesky-social/indigo/atproto/atcrypto"
1515+ didplc "github.com/did-method-plc/go-didplc"
1616+)
1717+1818+// fakePLC stands up an httptest server that serves a single op log entry and
1919+// captures any submitted update op for inspection.
2020+type fakePLC struct {
2121+ server *httptest.Server
2222+ did string
2323+ logEntries []didplc.OpEnum
2424+ submitted []didplc.RegularOp
2525+}
2626+2727+func (f *fakePLC) URL() string { return f.server.URL }
2828+2929+func (f *fakePLC) Close() { f.server.Close() }
3030+3131+// newFakePLC creates a fake PLC directory pre-loaded with a single signed
3232+// genesis op containing the given rotation keys (in priority order). Returns
3333+// the fake server and the resulting did:plc DID derived from the genesis op.
3434+func newFakePLC(t *testing.T, rotationKeys []*atcrypto.PrivateKeyK256, signer atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256) *fakePLC {
3535+ t.Helper()
3636+3737+ rotationDIDKeys := make([]string, 0, len(rotationKeys))
3838+ for _, k := range rotationKeys {
3939+ pub, err := k.PublicKey()
4040+ if err != nil {
4141+ t.Fatalf("rotation key public: %v", err)
4242+ }
4343+ rotationDIDKeys = append(rotationDIDKeys, pub.DIDKey())
4444+ }
4545+4646+ sigPub, err := signingKey.PublicKey()
4747+ if err != nil {
4848+ t.Fatalf("signing key public: %v", err)
4949+ }
5050+5151+ op := &didplc.RegularOp{
5252+ Type: "plc_operation",
5353+ RotationKeys: rotationDIDKeys,
5454+ VerificationMethods: map[string]string{
5555+ "atproto": sigPub.DIDKey(),
5656+ },
5757+ AlsoKnownAs: []string{"at://example.test"},
5858+ Services: map[string]didplc.OpService{
5959+ "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.test"},
6060+ },
6161+ Prev: nil,
6262+ }
6363+ if err := op.Sign(signer); err != nil {
6464+ t.Fatalf("sign genesis: %v", err)
6565+ }
6666+ did, err := op.DID()
6767+ if err != nil {
6868+ t.Fatalf("compute DID: %v", err)
6969+ }
7070+7171+ f := &fakePLC{
7272+ did: did,
7373+ logEntries: []didplc.OpEnum{{Regular: op}},
7474+ }
7575+7676+ mux := http.NewServeMux()
7777+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
7878+ if r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/log") {
7979+ w.Header().Set("Content-Type", "application/json")
8080+ _ = json.NewEncoder(w).Encode(f.logEntries)
8181+ return
8282+ }
8383+ if r.Method == http.MethodPost && r.URL.Path == "/"+did {
8484+ body, err := io.ReadAll(r.Body)
8585+ if err != nil {
8686+ http.Error(w, err.Error(), http.StatusBadRequest)
8787+ return
8888+ }
8989+ var op didplc.RegularOp
9090+ if err := json.Unmarshal(body, &op); err != nil {
9191+ http.Error(w, err.Error(), http.StatusBadRequest)
9292+ return
9393+ }
9494+ f.submitted = append(f.submitted, op)
9595+ w.WriteHeader(http.StatusOK)
9696+ return
9797+ }
9898+ http.NotFound(w, r)
9999+ })
100100+ f.server = httptest.NewServer(mux)
101101+ return f
102102+}
103103+104104+// generateK256 returns a fresh K-256 keypair, failing the test on error.
105105+func generateK256(t *testing.T) *atcrypto.PrivateKeyK256 {
106106+ t.Helper()
107107+ k, err := atcrypto.GeneratePrivateKeyK256()
108108+ if err != nil {
109109+ t.Fatalf("generate K-256: %v", err)
110110+ }
111111+ return k
112112+}
113113+114114+// writeSigningKey persists a signing key to a temp file and returns its path.
115115+// Used so EnsureCurrent can load it via oauth.GenerateOrLoadPDSKey.
116116+func writeSigningKey(t *testing.T, dir string, key *atcrypto.PrivateKeyK256) string {
117117+ t.Helper()
118118+ path := filepath.Join(dir, "signing.key")
119119+ if err := os.WriteFile(path, key.Bytes(), 0600); err != nil {
120120+ t.Fatalf("write signing key: %v", err)
121121+ }
122122+ return path
123123+}
124124+125125+func TestEnsureCurrent_PreservesRotationKeys(t *testing.T) {
126126+ ctx := context.Background()
127127+ tmp := t.TempDir()
128128+129129+ // Server-side rotation key (the one stored in database.rotation_key) and an
130130+ // "offline" recovery key that lives only in PLC. The genesis op lists offline
131131+ // FIRST (highest priority).
132132+ serverRot := generateK256(t)
133133+ offlineRot := generateK256(t)
134134+135135+ // Original signing key used to build genesis; the local signing key on disk
136136+ // will be different to force EnsureCurrent into the update path.
137137+ originalSigning := generateK256(t)
138138+ localSigning := generateK256(t)
139139+ writeSigningKey(t, tmp, localSigning)
140140+141141+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{offlineRot, serverRot}, serverRot, originalSigning)
142142+ defer fake.Close()
143143+144144+ cfg := Config{
145145+ PublicURL: "https://example.test",
146146+ PLCDirectoryURL: fake.URL(),
147147+ VerificationKeyName: "atproto",
148148+ Services: map[string]Service{
149149+ "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.test"},
150150+ },
151151+ }
152152+153153+ if err := EnsureCurrent(ctx, fake.did, serverRot, localSigning, cfg); err != nil {
154154+ t.Fatalf("EnsureCurrent: %v", err)
155155+ }
156156+157157+ if len(fake.submitted) != 1 {
158158+ t.Fatalf("expected exactly one update op submitted, got %d", len(fake.submitted))
159159+ }
160160+ got := fake.submitted[0]
161161+162162+ offlinePub, _ := offlineRot.PublicKey()
163163+ serverPub, _ := serverRot.PublicKey()
164164+ want := []string{offlinePub.DIDKey(), serverPub.DIDKey()}
165165+166166+ if len(got.RotationKeys) != len(want) {
167167+ t.Fatalf("rotation keys length: got %d want %d", len(got.RotationKeys), len(want))
168168+ }
169169+ for i := range want {
170170+ if got.RotationKeys[i] != want[i] {
171171+ t.Errorf("rotation key [%d]: got %s want %s", i, got.RotationKeys[i], want[i])
172172+ }
173173+ }
174174+175175+ // Verify signing key was actually rotated (sanity check we hit the update path).
176176+ localPub, _ := localSigning.PublicKey()
177177+ if got.VerificationMethods["atproto"] != localPub.DIDKey() {
178178+ t.Errorf("expected signing key to update to local key %s, got %s",
179179+ localPub.DIDKey(), got.VerificationMethods["atproto"])
180180+ }
181181+}
182182+183183+func TestEnsureCurrent_RefusesUpdateWhenLocalKeyMissing(t *testing.T) {
184184+ ctx := context.Background()
185185+ tmp := t.TempDir()
186186+187187+ // Genesis lists only the offline key. The local server has been rotated out.
188188+ offlineRot := generateK256(t)
189189+ localRot := generateK256(t) // not in PLC
190190+191191+ originalSigning := generateK256(t)
192192+ localSigning := generateK256(t)
193193+ writeSigningKey(t, tmp, localSigning)
194194+195195+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{offlineRot}, offlineRot, originalSigning)
196196+ defer fake.Close()
197197+198198+ cfg := Config{
199199+ PublicURL: "https://example.test",
200200+ PLCDirectoryURL: fake.URL(),
201201+ VerificationKeyName: "atproto",
202202+ Services: map[string]Service{
203203+ "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.test"},
204204+ },
205205+ }
206206+207207+ // Local signing key has drifted, which would normally trigger an update.
208208+ if err := EnsureCurrent(ctx, fake.did, localRot, localSigning, cfg); err != nil {
209209+ t.Fatalf("EnsureCurrent returned error: %v", err)
210210+ }
211211+212212+ if len(fake.submitted) != 0 {
213213+ t.Fatalf("expected no update submission when local rotation key isn't in PLC list, got %d", len(fake.submitted))
214214+ }
215215+}
216216+217217+func TestEnsureCurrent_NoOpWhenCurrent(t *testing.T) {
218218+ ctx := context.Background()
219219+ tmp := t.TempDir()
220220+221221+ serverRot := generateK256(t)
222222+ signing := generateK256(t)
223223+ writeSigningKey(t, tmp, signing)
224224+225225+ fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing)
226226+ defer fake.Close()
227227+228228+ cfg := Config{
229229+ PublicURL: "https://example.test",
230230+ PLCDirectoryURL: fake.URL(),
231231+ VerificationKeyName: "atproto",
232232+ Services: map[string]Service{
233233+ "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.test"},
234234+ },
235235+ }
236236+237237+ if err := EnsureCurrent(ctx, fake.did, serverRot, signing, cfg); err != nil {
238238+ t.Fatalf("EnsureCurrent: %v", err)
239239+ }
240240+ if len(fake.submitted) != 0 {
241241+ t.Fatalf("expected no update when state is current, got %d submitted", len(fake.submitted))
242242+ }
243243+}
+24
pkg/atproto/did/web.go
···11+package did
22+33+import (
44+ "fmt"
55+ "net/url"
66+)
77+88+// GenerateDIDFromURL computes a did:web identifier from a public URL.
99+// Per the did:web spec, ports are percent-encoded (`:` → `%3A`).
1010+func GenerateDIDFromURL(publicURL string) string {
1111+ u, err := url.Parse(publicURL)
1212+ if err != nil {
1313+ return fmt.Sprintf("did:web:%s", publicURL)
1414+ }
1515+ hostname := u.Hostname()
1616+ if hostname == "" {
1717+ hostname = "localhost"
1818+ }
1919+ port := u.Port()
2020+ if port != "" && port != "80" && port != "443" {
2121+ return fmt.Sprintf("did:web:%s%%3A%s", hostname, port)
2222+ }
2323+ return fmt.Sprintf("did:web:%s", hostname)
2424+}
+38
pkg/atproto/did/web_test.go
···11+package did
22+33+import "testing"
44+55+// TestGenerateDIDFromURL covers host extraction and port encoding for did:web.
66+func TestGenerateDIDFromURL(t *testing.T) {
77+ cases := []struct {
88+ name string
99+ publicURL string
1010+ want string
1111+ }{
1212+ {"https no port", "https://hold.example.com", "did:web:hold.example.com"},
1313+ {"http no port", "http://hold.example.com", "did:web:hold.example.com"},
1414+ {"https port 443 stripped", "https://hold.example.com:443", "did:web:hold.example.com"},
1515+ {"http port 80 stripped", "http://hold.example.com:80", "did:web:hold.example.com"},
1616+ {"non-standard port encoded", "https://hold.example.com:8443", "did:web:hold.example.com%3A8443"},
1717+ {"localhost with port", "http://localhost:3000", "did:web:localhost%3A3000"},
1818+ {"trailing path ignored", "https://hold.example.com/foo/bar", "did:web:hold.example.com"},
1919+ {"subdomain preserved", "https://api.hold.example.com", "did:web:api.hold.example.com"},
2020+ }
2121+ for _, tc := range cases {
2222+ t.Run(tc.name, func(t *testing.T) {
2323+ got := GenerateDIDFromURL(tc.publicURL)
2424+ if got != tc.want {
2525+ t.Errorf("GenerateDIDFromURL(%q): got %s want %s", tc.publicURL, got, tc.want)
2626+ }
2727+ })
2828+ }
2929+}
3030+3131+// TestGenerateDIDFromURL_EmptyHost confirms a URL without a hostname falls back to localhost.
3232+// (url.Parse on a bare path does not error, so the fallback is the only signal we have.)
3333+func TestGenerateDIDFromURL_EmptyHost(t *testing.T) {
3434+ got := GenerateDIDFromURL("")
3535+ if got != "did:web:localhost" {
3636+ t.Errorf("empty URL: got %s want did:web:localhost", got)
3737+ }
3838+}
+4-4
pkg/billing/billing.go
···644644 }
645645646646 var (
647647- minQuota int64 = -1
648648- maxQuota int64
649649- scanCount int
650650- totalHolds int
647647+ minQuota int64 = -1
648648+ maxQuota int64
649649+ scanCount int
650650+ totalHolds int
651651 )
652652653653 for _, cached := range m.holdTierCache {
+19
pkg/hold/config.go
···15151616 "github.com/spf13/viper"
17171818+ "atcr.io/pkg/atproto/did"
1819 "atcr.io/pkg/config"
1920 "atcr.io/pkg/hold/gc"
2121+ "atcr.io/pkg/hold/pds"
2022 "atcr.io/pkg/hold/quota"
2123)
2224···5658// ConfigPath returns the path to the YAML configuration file used to load this config.
5759// Subsystems (e.g. billing) use this to re-read the same file for extended fields.
5860func (c *Config) ConfigPath() string { return c.configPath }
6161+6262+// DIDConfig builds the did.Config used to load or create the hold's identity.
6363+// The verification key fragment and service set are hold-specific (atproto PDS +
6464+// AtcrHoldService), so they're filled in here rather than at every callsite.
6565+func (c *Config) DIDConfig() did.Config {
6666+ return did.Config{
6767+ DID: c.Database.DID,
6868+ Method: c.Database.DIDMethod,
6969+ PublicURL: c.Server.PublicURL,
7070+ DBPath: c.Database.Path,
7171+ SigningKeyPath: c.Database.KeyPath,
7272+ RotationKey: c.Database.RotationKey,
7373+ PLCDirectoryURL: c.Database.PLCDirectoryURL,
7474+ VerificationKeyName: "atproto",
7575+ Services: pds.HoldServices(c.Server.PublicURL),
7676+ }
7777+}
59786079// AdminConfig defines admin panel settings
6180type AdminConfig struct {
-435
pkg/hold/pds/did.go
···11-package pds
22-33-import (
44- "context"
55- "encoding/json"
66- "fmt"
77- "log/slog"
88- "net/url"
99- "os"
1010- "path/filepath"
1111- "strings"
1212-1313- "atcr.io/pkg/auth/oauth"
1414- "github.com/bluesky-social/indigo/atproto/atcrypto"
1515- didplc "github.com/did-method-plc/go-didplc"
1616-)
1717-1818-// DIDDocument represents a did:web document
1919-type DIDDocument struct {
2020- Context []string `json:"@context"`
2121- ID string `json:"id"`
2222- AlsoKnownAs []string `json:"alsoKnownAs,omitempty"`
2323- VerificationMethod []VerificationMethod `json:"verificationMethod"`
2424- Authentication []string `json:"authentication,omitempty"`
2525- AssertionMethod []string `json:"assertionMethod,omitempty"`
2626- Service []Service `json:"service,omitempty"`
2727-}
2828-2929-// VerificationMethod represents a public key in a DID document
3030-type VerificationMethod struct {
3131- ID string `json:"id"`
3232- Type string `json:"type"`
3333- Controller string `json:"controller"`
3434- PublicKeyMultibase string `json:"publicKeyMultibase"`
3535-}
3636-3737-// Service represents a service endpoint in a DID document
3838-type Service struct {
3939- ID string `json:"id"`
4040- Type string `json:"type"`
4141- ServiceEndpoint string `json:"serviceEndpoint"`
4242-}
4343-4444-// GenerateDIDDocument creates a DID document for the hold's identity.
4545-// It uses the hold's stored DID (which may be did:web or did:plc).
4646-func (p *HoldPDS) GenerateDIDDocument(publicURL string) (*DIDDocument, error) {
4747- did := p.did
4848-4949- // Parse URL for alsoKnownAs
5050- u, err := url.Parse(publicURL)
5151- if err != nil {
5252- return nil, fmt.Errorf("failed to parse public URL: %w", err)
5353- }
5454- host := u.Hostname()
5555- if port := u.Port(); port != "" && port != "80" && port != "443" {
5656- host = fmt.Sprintf("%s:%s", host, port)
5757- }
5858-5959- // Get public key in multibase format using indigo's crypto
6060- pubKey, err := p.signingKey.PublicKey()
6161- if err != nil {
6262- return nil, fmt.Errorf("failed to get public key: %w", err)
6363- }
6464- publicKeyMultibase := pubKey.Multibase()
6565-6666- doc := &DIDDocument{
6767- Context: []string{
6868- "https://www.w3.org/ns/did/v1",
6969- "https://w3id.org/security/multikey/v1",
7070- "https://w3id.org/security/suites/secp256k1-2019/v1",
7171- },
7272- ID: did,
7373- AlsoKnownAs: []string{
7474- fmt.Sprintf("at://%s", host),
7575- },
7676- VerificationMethod: []VerificationMethod{
7777- {
7878- ID: fmt.Sprintf("%s#atproto", did),
7979- Type: "Multikey",
8080- Controller: did,
8181- PublicKeyMultibase: publicKeyMultibase,
8282- },
8383- },
8484- Authentication: []string{
8585- fmt.Sprintf("%s#atproto", did),
8686- },
8787- Service: []Service{
8888- {
8989- ID: "#atproto_pds",
9090- Type: "AtprotoPersonalDataServer",
9191- ServiceEndpoint: publicURL,
9292- },
9393- {
9494- ID: "#atcr_hold",
9595- Type: "AtcrHoldService",
9696- ServiceEndpoint: publicURL,
9797- },
9898- },
9999- }
100100-101101- return doc, nil
102102-}
103103-104104-// MarshalDIDDocument converts a DID document to JSON using the stored public URL
105105-func (p *HoldPDS) MarshalDIDDocument() ([]byte, error) {
106106- doc, err := p.GenerateDIDDocument(p.PublicURL)
107107- if err != nil {
108108- return nil, err
109109- }
110110-111111- return json.MarshalIndent(doc, "", " ")
112112-}
113113-114114-// DIDConfig holds parameters for DID creation/loading.
115115-type DIDConfig struct {
116116- DID string // Explicit DID for adoption/recovery (optional)
117117- DIDMethod string // "web" or "plc"
118118- PublicURL string
119119- DBPath string
120120- SigningKeyPath string
121121- RotationKey string // Multibase-encoded private key, K-256 or P-256 (optional)
122122- PLCDirectoryURL string
123123-}
124124-125125-// LoadOrCreateDID returns the hold's DID, either by deriving it from the URL (did:web)
126126-// or by loading/creating a did:plc identity registered with the PLC directory.
127127-//
128128-// For did:plc, the priority is: config DID > did.txt > create new.
129129-// When an existing DID is found (config or did.txt), EnsurePLCCurrent is called
130130-// to auto-update the PLC directory if the signing key or URL has changed.
131131-func LoadOrCreateDID(ctx context.Context, cfg DIDConfig) (string, error) {
132132- if cfg.DIDMethod != "plc" {
133133- return GenerateDIDFromURL(cfg.PublicURL), nil
134134- }
135135-136136- didPath := filepath.Join(cfg.DBPath, "did.txt")
137137-138138- // Priority: config DID > did.txt > create new
139139- var did string
140140- if cfg.DID != "" {
141141- if !strings.HasPrefix(cfg.DID, "did:plc:") {
142142- return "", fmt.Errorf("database.did must be a did:plc identifier, got %q", cfg.DID)
143143- }
144144- did = cfg.DID
145145- slog.Info("Using DID from config (adoption/recovery)", "did", did)
146146- } else if data, err := os.ReadFile(didPath); err == nil {
147147- d := strings.TrimSpace(string(data))
148148- if strings.HasPrefix(d, "did:plc:") {
149149- did = d
150150- slog.Info("Loaded existing did:plc identity", "did", did)
151151- }
152152- }
153153-154154- if did != "" {
155155- // Persist to did.txt (may be from config on first adoption)
156156- if err := os.MkdirAll(filepath.Dir(didPath), 0755); err != nil {
157157- return "", fmt.Errorf("failed to create directory for did.txt: %w", err)
158158- }
159159- if err := os.WriteFile(didPath, []byte(did+"\n"), 0600); err != nil {
160160- return "", fmt.Errorf("failed to write did.txt: %w", err)
161161- }
162162-163163- // Load signing key (generate if missing — recovery case)
164164- signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath)
165165- if err != nil {
166166- return "", fmt.Errorf("failed to load signing key: %w", err)
167167- }
168168-169169- // Try to parse rotation key (optional — may not be configured)
170170- rotationKey, _ := parseOptionalMultibaseKey(cfg.RotationKey)
171171-172172- if err := EnsurePLCCurrent(ctx, did, rotationKey, signingKey, cfg.PublicURL, cfg.PLCDirectoryURL); err != nil {
173173- slog.Warn("Failed to verify PLC identity is current (will retry on next restart)",
174174- "did", did,
175175- "error", err,
176176- )
177177- }
178178-179179- return did, nil
180180- }
181181-182182- // No existing DID — create new genesis operation
183183- slog.Info("Creating new did:plc identity")
184184-185185- // Load or generate signing key
186186- signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath)
187187- if err != nil {
188188- return "", fmt.Errorf("failed to load signing key: %w", err)
189189- }
190190-191191- // Parse or generate rotation key
192192- var rotationKey atcrypto.PrivateKeyExportable
193193- if cfg.RotationKey != "" {
194194- rotationKey, err = parseOptionalMultibaseKey(cfg.RotationKey)
195195- if err != nil {
196196- return "", fmt.Errorf("failed to parse rotation_key: %w", err)
197197- }
198198- } else {
199199- // Generate a new rotation key — user must save the multibase output
200200- rawKey, genErr := atcrypto.GeneratePrivateKeyK256()
201201- if genErr != nil {
202202- return "", fmt.Errorf("failed to generate rotation key: %w", genErr)
203203- }
204204- rotationKey = rawKey
205205- slog.Warn("Generated new rotation key — save this in your config as database.rotation_key",
206206- "rotation_key", rawKey.Multibase(),
207207- )
208208- }
209209-210210- did, err = CreatePLCIdentity(ctx, rotationKey, signingKey, cfg.PublicURL, cfg.PLCDirectoryURL)
211211- if err != nil {
212212- return "", fmt.Errorf("failed to create PLC identity: %w", err)
213213- }
214214-215215- // Persist DID
216216- if err := os.MkdirAll(filepath.Dir(didPath), 0755); err != nil {
217217- return "", fmt.Errorf("failed to create directory for did.txt: %w", err)
218218- }
219219- if err := os.WriteFile(didPath, []byte(did+"\n"), 0600); err != nil {
220220- return "", fmt.Errorf("failed to write did.txt: %w", err)
221221- }
222222-223223- slog.Info("Created did:plc identity",
224224- "did", did,
225225- "plc_directory", cfg.PLCDirectoryURL,
226226- )
227227- slog.Warn("Back up your rotation_key. It is only needed for DID updates (URL changes, key rotation).")
228228-229229- return did, nil
230230-}
231231-232232-// parseOptionalMultibaseKey parses a multibase-encoded private key string (K-256 or P-256).
233233-// Returns nil, nil if the input is empty (key not configured).
234234-func parseOptionalMultibaseKey(encoded string) (atcrypto.PrivateKeyExportable, error) {
235235- if encoded == "" {
236236- return nil, nil
237237- }
238238- key, err := atcrypto.ParsePrivateMultibase(encoded)
239239- if err != nil {
240240- return nil, fmt.Errorf("failed to parse rotation key multibase string: %w", err)
241241- }
242242- return key, nil
243243-}
244244-245245-// EnsurePLCCurrent checks the PLC directory for the given DID and updates it
246246-// if the local signing key or public URL doesn't match what's registered.
247247-// If rotationKey is nil, mismatches are logged as warnings but not fatal.
248248-func EnsurePLCCurrent(ctx context.Context, did string, rotationKey atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256, publicURL, plcDirectoryURL string) error {
249249- client := &didplc.Client{DirectoryURL: plcDirectoryURL}
250250-251251- // Fetch current op log
252252- opLog, err := client.OpLog(ctx, did)
253253- if err != nil {
254254- return fmt.Errorf("failed to fetch PLC op log for %s: %w", did, err)
255255- }
256256- if len(opLog) == 0 {
257257- return fmt.Errorf("empty op log for %s", did)
258258- }
259259-260260- lastEntry := opLog[len(opLog)-1]
261261- lastOp := lastEntry.Regular
262262- if lastOp == nil {
263263- // Last op is not a regular op (could be legacy or tombstone) — skip update
264264- slog.Warn("Last PLC operation is not a regular op, skipping auto-update", "did", did)
265265- return nil
266266- }
267267-268268- // Compare local state vs PLC state
269269- sigPub, err := signingKey.PublicKey()
270270- if err != nil {
271271- return fmt.Errorf("failed to get signing public key: %w", err)
272272- }
273273- localVerificationKey := sigPub.DIDKey()
274274- plcVerificationKey := lastOp.VerificationMethods["atproto"]
275275-276276- localEndpoint := publicURL
277277- var plcEndpoint string
278278- if svc, ok := lastOp.Services["atproto_pds"]; ok {
279279- plcEndpoint = svc.Endpoint
280280- }
281281-282282- keyMatch := localVerificationKey == plcVerificationKey
283283- endpointMatch := localEndpoint == plcEndpoint
284284-285285- if keyMatch && endpointMatch {
286286- slog.Info("PLC identity is current", "did", did)
287287- return nil
288288- }
289289-290290- slog.Info("PLC identity needs update",
291291- "did", did,
292292- "signing_key_changed", !keyMatch,
293293- "endpoint_changed", !endpointMatch,
294294- )
295295-296296- if rotationKey == nil {
297297- slog.Warn("PLC document doesn't match local state but no rotation key available. Provide rotation key to auto-update PLC directory.",
298298- "did", did,
299299- "signing_key_changed", !keyMatch,
300300- "endpoint_changed", !endpointMatch,
301301- )
302302- return nil
303303- }
304304-305305- // Build update operation
306306- rotPub, err := rotationKey.PublicKey()
307307- if err != nil {
308308- return fmt.Errorf("failed to get rotation public key: %w", err)
309309- }
310310-311311- // Extract hostname for alsoKnownAs
312312- u, err := url.Parse(publicURL)
313313- if err != nil {
314314- return fmt.Errorf("failed to parse public URL: %w", err)
315315- }
316316- host := u.Hostname()
317317- if port := u.Port(); port != "" && port != "80" && port != "443" {
318318- host = host + ":" + port
319319- }
320320-321321- prevCID := lastEntry.AsOperation().CID().String()
322322-323323- op := &didplc.RegularOp{
324324- Type: "plc_operation",
325325- RotationKeys: []string{rotPub.DIDKey()},
326326- VerificationMethods: map[string]string{
327327- "atproto": localVerificationKey,
328328- },
329329- AlsoKnownAs: []string{"at://" + host},
330330- Services: map[string]didplc.OpService{
331331- "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: publicURL},
332332- "atcr_hold": {Type: "AtcrHoldService", Endpoint: publicURL},
333333- },
334334- Prev: &prevCID,
335335- }
336336-337337- if err := op.Sign(rotationKey); err != nil {
338338- return fmt.Errorf("failed to sign PLC update operation: %w", err)
339339- }
340340-341341- if err := client.Submit(ctx, did, op); err != nil {
342342- return fmt.Errorf("failed to submit PLC update: %w", err)
343343- }
344344-345345- slog.Info("Updated PLC identity",
346346- "did", did,
347347- "signing_key_rotated", !keyMatch,
348348- "endpoint_changed", !endpointMatch,
349349- )
350350-351351- return nil
352352-}
353353-354354-// CreatePLCIdentity creates a new did:plc identity by building a genesis operation,
355355-// signing it with the rotation key, and submitting it to the PLC directory.
356356-func CreatePLCIdentity(ctx context.Context, rotationKey atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256, publicURL, plcDirectoryURL string) (string, error) {
357357- rotPub, err := rotationKey.PublicKey()
358358- if err != nil {
359359- return "", fmt.Errorf("failed to get rotation public key: %w", err)
360360- }
361361-362362- sigPub, err := signingKey.PublicKey()
363363- if err != nil {
364364- return "", fmt.Errorf("failed to get signing public key: %w", err)
365365- }
366366-367367- // Extract hostname for alsoKnownAs
368368- u, err := url.Parse(publicURL)
369369- if err != nil {
370370- return "", fmt.Errorf("failed to parse public URL: %w", err)
371371- }
372372- host := u.Hostname()
373373- if port := u.Port(); port != "" && port != "80" && port != "443" {
374374- host = host + ":" + port
375375- }
376376-377377- op := &didplc.RegularOp{
378378- Type: "plc_operation",
379379- RotationKeys: []string{rotPub.DIDKey()},
380380- VerificationMethods: map[string]string{
381381- "atproto": sigPub.DIDKey(),
382382- },
383383- AlsoKnownAs: []string{"at://" + host},
384384- Services: map[string]didplc.OpService{
385385- "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: publicURL},
386386- "atcr_hold": {Type: "AtcrHoldService", Endpoint: publicURL},
387387- },
388388- Prev: nil,
389389- }
390390-391391- if err := op.Sign(rotationKey); err != nil {
392392- return "", fmt.Errorf("failed to sign PLC genesis operation: %w", err)
393393- }
394394-395395- did, err := op.DID()
396396- if err != nil {
397397- return "", fmt.Errorf("failed to compute DID from genesis operation: %w", err)
398398- }
399399-400400- client := &didplc.Client{DirectoryURL: plcDirectoryURL}
401401- if err := client.Submit(ctx, did, op); err != nil {
402402- return "", fmt.Errorf("failed to submit genesis operation to PLC directory: %w", err)
403403- }
404404-405405- return did, nil
406406-}
407407-408408-// GenerateDIDFromURL creates a did:web identifier from a public URL.
409409-// Per the did:web spec, ports are percent-encoded: the colon becomes %3A.
410410-// Example: "http://hold1.example.com:8080" -> "did:web:hold1.example.com%3A8080"
411411-func GenerateDIDFromURL(publicURL string) string {
412412- // Parse URL
413413- u, err := url.Parse(publicURL)
414414- if err != nil {
415415- // Fallback: assume it's just a hostname
416416- return fmt.Sprintf("did:web:%s", publicURL)
417417- }
418418-419419- // Get hostname
420420- hostname := u.Hostname()
421421- if hostname == "" {
422422- hostname = "localhost"
423423- }
424424-425425- // Get port
426426- port := u.Port()
427427-428428- // Include port in DID if it's non-standard (not 80 for http, not 443 for https)
429429- // Per did:web spec, the colon is percent-encoded as %3A
430430- if port != "" && port != "80" && port != "443" {
431431- return fmt.Sprintf("did:web:%s%%3A%s", hostname, port)
432432- }
433433-434434- return fmt.Sprintf("did:web:%s", hostname)
435435-}
-274
pkg/hold/pds/did_test.go
···11-package pds
22-33-import (
44- "context"
55- "encoding/json"
66- "path/filepath"
77- "testing"
88-)
99-1010-// TestGenerateDIDFromURL tests DID generation from various URL formats
1111-func TestGenerateDIDFromURL(t *testing.T) {
1212- tests := []struct {
1313- name string
1414- publicURL string
1515- expectedDID string
1616- }{
1717- {
1818- name: "standard HTTP with standard port",
1919- publicURL: "http://hold.example.com",
2020- expectedDID: "did:web:hold.example.com",
2121- },
2222- {
2323- name: "standard HTTPS with standard port",
2424- publicURL: "https://hold.example.com",
2525- expectedDID: "did:web:hold.example.com",
2626- },
2727- {
2828- name: "HTTP with non-standard port",
2929- publicURL: "http://hold.example.com:8080",
3030- expectedDID: "did:web:hold.example.com%3A8080",
3131- },
3232- {
3333- name: "HTTPS with non-standard port",
3434- publicURL: "https://hold.example.com:8443",
3535- expectedDID: "did:web:hold.example.com%3A8443",
3636- },
3737- {
3838- name: "localhost with port",
3939- publicURL: "http://localhost:8080",
4040- expectedDID: "did:web:localhost%3A8080",
4141- },
4242- {
4343- name: "HTTP with explicit port 80",
4444- publicURL: "http://hold.example.com:80",
4545- expectedDID: "did:web:hold.example.com",
4646- },
4747- {
4848- name: "HTTPS with explicit port 443",
4949- publicURL: "https://hold.example.com:443",
5050- expectedDID: "did:web:hold.example.com",
5151- },
5252- {
5353- name: "subdomain",
5454- publicURL: "https://hold1.atcr.io",
5555- expectedDID: "did:web:hold1.atcr.io",
5656- },
5757- }
5858-5959- for _, tt := range tests {
6060- t.Run(tt.name, func(t *testing.T) {
6161- did := GenerateDIDFromURL(tt.publicURL)
6262- if did != tt.expectedDID {
6363- t.Errorf("Expected DID %s, got %s", tt.expectedDID, did)
6464- }
6565- })
6666- }
6767-}
6868-6969-// TestGenerateDIDFromURL_InvalidURL tests handling of invalid URLs
7070-func TestGenerateDIDFromURL_InvalidURL(t *testing.T) {
7171- // Invalid URLs get parsed with empty hostname, which defaults to localhost
7272- did := GenerateDIDFromURL("not a url")
7373- if did != "did:web:localhost" {
7474- t.Errorf("Expected did:web:localhost for invalid URL, got %s", did)
7575- }
7676-}
7777-7878-// TestGenerateDIDDocument tests DID document generation
7979-func TestGenerateDIDDocument(t *testing.T) {
8080- ctx := context.Background()
8181- tmpDir := t.TempDir()
8282-8383- dbPath := filepath.Join(tmpDir, "pds.db")
8484- keyPath := filepath.Join(tmpDir, "signing-key")
8585- publicURL := "https://hold.example.com"
8686-8787- pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", publicURL, "https://atcr.io", dbPath, keyPath, false)
8888- if err != nil {
8989- t.Fatalf("Failed to create PDS: %v", err)
9090- }
9191-9292- doc, err := pds.GenerateDIDDocument(publicURL)
9393- if err != nil {
9494- t.Fatalf("Failed to generate DID document: %v", err)
9595- }
9696-9797- // Verify required fields
9898- if doc.ID != "did:web:hold.example.com" {
9999- t.Errorf("Expected DID did:web:hold.example.com, got %s", doc.ID)
100100- }
101101-102102- // Verify context
103103- if len(doc.Context) != 3 {
104104- t.Errorf("Expected 3 context entries, got %d", len(doc.Context))
105105- }
106106-107107- expectedContexts := []string{
108108- "https://www.w3.org/ns/did/v1",
109109- "https://w3id.org/security/multikey/v1",
110110- "https://w3id.org/security/suites/secp256k1-2019/v1",
111111- }
112112- for i, expected := range expectedContexts {
113113- if doc.Context[i] != expected {
114114- t.Errorf("Expected context[%d] = %s, got %s", i, expected, doc.Context[i])
115115- }
116116- }
117117-118118- // Verify alsoKnownAs
119119- if len(doc.AlsoKnownAs) != 1 || doc.AlsoKnownAs[0] != "at://hold.example.com" {
120120- t.Errorf("Expected alsoKnownAs=['at://hold.example.com'], got %v", doc.AlsoKnownAs)
121121- }
122122-123123- // Verify verification method
124124- if len(doc.VerificationMethod) != 1 {
125125- t.Fatalf("Expected 1 verification method, got %d", len(doc.VerificationMethod))
126126- }
127127-128128- vm := doc.VerificationMethod[0]
129129- if vm.ID != "did:web:hold.example.com#atproto" {
130130- t.Errorf("Expected verification method ID did:web:hold.example.com#atproto, got %s", vm.ID)
131131- }
132132- if vm.Type != "Multikey" {
133133- t.Errorf("Expected type Multikey, got %s", vm.Type)
134134- }
135135- if vm.Controller != "did:web:hold.example.com" {
136136- t.Errorf("Expected controller did:web:hold.example.com, got %s", vm.Controller)
137137- }
138138- if vm.PublicKeyMultibase == "" {
139139- t.Error("Expected non-empty publicKeyMultibase")
140140- }
141141-142142- // Verify authentication
143143- if len(doc.Authentication) != 1 || doc.Authentication[0] != "did:web:hold.example.com#atproto" {
144144- t.Errorf("Expected authentication=['did:web:hold.example.com#atproto'], got %v", doc.Authentication)
145145- }
146146-147147- // Verify services
148148- if len(doc.Service) != 2 {
149149- t.Fatalf("Expected 2 services, got %d", len(doc.Service))
150150- }
151151-152152- // Check PDS service
153153- pdsService := doc.Service[0]
154154- if pdsService.ID != "#atproto_pds" {
155155- t.Errorf("Expected service ID #atproto_pds, got %s", pdsService.ID)
156156- }
157157- if pdsService.Type != "AtprotoPersonalDataServer" {
158158- t.Errorf("Expected service type AtprotoPersonalDataServer, got %s", pdsService.Type)
159159- }
160160- if pdsService.ServiceEndpoint != publicURL {
161161- t.Errorf("Expected service endpoint %s, got %s", publicURL, pdsService.ServiceEndpoint)
162162- }
163163-164164- // Check hold service
165165- holdService := doc.Service[1]
166166- if holdService.ID != "#atcr_hold" {
167167- t.Errorf("Expected service ID #atcr_hold, got %s", holdService.ID)
168168- }
169169- if holdService.Type != "AtcrHoldService" {
170170- t.Errorf("Expected service type AtcrHoldService, got %s", holdService.Type)
171171- }
172172- if holdService.ServiceEndpoint != publicURL {
173173- t.Errorf("Expected service endpoint %s, got %s", publicURL, holdService.ServiceEndpoint)
174174- }
175175-}
176176-177177-// TestGenerateDIDDocument_WithPort tests DID document with non-standard port
178178-func TestGenerateDIDDocument_WithPort(t *testing.T) {
179179- ctx := context.Background()
180180- tmpDir := t.TempDir()
181181-182182- dbPath := filepath.Join(tmpDir, "pds.db")
183183- keyPath := filepath.Join(tmpDir, "signing-key")
184184- publicURL := "https://hold.example.com:8443"
185185-186186- pds, err := NewHoldPDS(ctx, "did:web:hold.example.com%3A8443", publicURL, "https://atcr.io", dbPath, keyPath, false)
187187- if err != nil {
188188- t.Fatalf("Failed to create PDS: %v", err)
189189- }
190190-191191- doc, err := pds.GenerateDIDDocument(publicURL)
192192- if err != nil {
193193- t.Fatalf("Failed to generate DID document: %v", err)
194194- }
195195-196196- // Verify DID includes percent-encoded port
197197- if doc.ID != "did:web:hold.example.com%3A8443" {
198198- t.Errorf("Expected DID did:web:hold.example.com%%3A8443, got %s", doc.ID)
199199- }
200200-201201- // Verify alsoKnownAs includes port
202202- if doc.AlsoKnownAs[0] != "at://hold.example.com:8443" {
203203- t.Errorf("Expected alsoKnownAs with port, got %s", doc.AlsoKnownAs[0])
204204- }
205205-}
206206-207207-// TestMarshalDIDDocument tests DID document JSON marshaling
208208-func TestMarshalDIDDocument(t *testing.T) {
209209- ctx := context.Background()
210210- tmpDir := t.TempDir()
211211-212212- dbPath := filepath.Join(tmpDir, "pds.db")
213213- keyPath := filepath.Join(tmpDir, "signing-key")
214214- publicURL := "https://hold.example.com"
215215-216216- pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", publicURL, "https://atcr.io", dbPath, keyPath, false)
217217- if err != nil {
218218- t.Fatalf("Failed to create PDS: %v", err)
219219- }
220220-221221- jsonBytes, err := pds.MarshalDIDDocument()
222222- if err != nil {
223223- t.Fatalf("Failed to marshal DID document: %v", err)
224224- }
225225-226226- // Verify it's valid JSON
227227- var doc map[string]any
228228- if err := json.Unmarshal(jsonBytes, &doc); err != nil {
229229- t.Fatalf("Failed to unmarshal DID document JSON: %v", err)
230230- }
231231-232232- // Verify required fields
233233- if id, ok := doc["id"].(string); !ok || id != "did:web:hold.example.com" {
234234- t.Errorf("Expected id='did:web:hold.example.com', got %v", doc["id"])
235235- }
236236-237237- if _, ok := doc["@context"]; !ok {
238238- t.Error("Expected @context field in JSON")
239239- }
240240-241241- if _, ok := doc["verificationMethod"]; !ok {
242242- t.Error("Expected verificationMethod field in JSON")
243243- }
244244-245245- if _, ok := doc["service"]; !ok {
246246- t.Error("Expected service field in JSON")
247247- }
248248-249249- // Verify pretty-printed (has indentation)
250250- if len(jsonBytes) < 100 {
251251- t.Error("Expected pretty-printed JSON to be reasonably sized")
252252- }
253253-}
254254-255255-// TestGenerateDIDDocument_InvalidURL tests error handling
256256-func TestGenerateDIDDocument_InvalidURL(t *testing.T) {
257257- ctx := context.Background()
258258- tmpDir := t.TempDir()
259259-260260- dbPath := filepath.Join(tmpDir, "pds.db")
261261- keyPath := filepath.Join(tmpDir, "signing-key")
262262- publicURL := "https://hold.example.com"
263263-264264- pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", publicURL, "https://atcr.io", dbPath, keyPath, false)
265265- if err != nil {
266266- t.Fatalf("Failed to create PDS: %v", err)
267267- }
268268-269269- // Try to generate DID document with invalid URL
270270- _, err = pds.GenerateDIDDocument("ht!tp://invalid url")
271271- if err == nil {
272272- t.Error("Expected error for invalid URL, got nil")
273273- }
274274-}
-11
pkg/hold/pds/export.go
···11-package pds
22-33-import (
44- "context"
55- "io"
66-)
77-88-// ExportToCAR streams the hold's repo as a CAR file to the writer.
99-func (p *HoldPDS) ExportToCAR(ctx context.Context, w io.Writer) error {
1010- return p.repomgr.ReadRepo(ctx, p.uid, "", w)
1111-}
+12
pkg/hold/pds/server.go
pkg/hold/pds/hold_pds.go
···1111 "strings"
12121313 "atcr.io/pkg/atproto"
1414+ "atcr.io/pkg/atproto/did"
1415 "atcr.io/pkg/auth/oauth"
1516 holddb "atcr.io/pkg/hold/db"
1617 "atcr.io/pkg/s3"
···3435 lexutil.RegisterType(atproto.DailyStatsCollection, &atproto.DailyStatsRecord{})
3536 lexutil.RegisterType(atproto.ScanCollection, &atproto.ScanRecord{})
3637 lexutil.RegisterType(atproto.ImageConfigCollection, &atproto.ImageConfigRecord{})
3838+}
3939+4040+// HoldServices returns the service entries the hold publishes in its DID document
4141+// and PLC operations: an atproto PDS endpoint plus the ATCR hold service endpoint.
4242+// Single source of truth, used by both boot-time identity loading and DID-document
4343+// serving.
4444+func HoldServices(publicURL string) map[string]did.Service {
4545+ return map[string]did.Service{
4646+ "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: publicURL},
4747+ "atcr_hold": {Type: "AtcrHoldService", Endpoint: publicURL},
4848+ }
3749}
38503951// HoldPDS is a minimal ATProto PDS implementation for a hold service
···33import (
44 "database/sql"
55 "fmt"
66+ "io"
77+ "log/slog"
68 "os"
79 "path/filepath"
1010+ "strings"
811 "time"
9121010- _ "github.com/tursodatabase/go-libsql"
1313+ "github.com/bluesky-social/indigo/atproto/atcrypto"
1414+ "github.com/bluesky-social/indigo/atproto/labeling"
1515+ "github.com/tursodatabase/go-libsql"
1116)
1717+1818+// LabelVersion is the ATProto label format version (currently 1).
1919+const LabelVersion int64 = labeling.ATPROTO_LABEL_VERSION
12201321const schema = `
1422CREATE TABLE IF NOT EXISTS labels (
···2028 neg BOOLEAN NOT NULL DEFAULT 0,
2129 cts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
2230 exp TIMESTAMP,
3131+ ver INTEGER NOT NULL DEFAULT 1,
3232+ sig BLOB NOT NULL,
2333 subject_did TEXT NOT NULL,
2424- subject_repo TEXT NOT NULL DEFAULT '',
2525- UNIQUE(src, uri, val, neg)
3434+ subject_repo TEXT NOT NULL DEFAULT ''
2635);
2736CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo);
2837CREATE INDEX IF NOT EXISTS idx_labels_cts ON labels(cts DESC);
3838+CREATE INDEX IF NOT EXISTS idx_labels_uri ON labels(uri);
2939`
30403131-// Label represents an ATProto label (com.atproto.label.defs#label).
4141+// Label represents an ATProto label record stored locally. Its on-the-wire representation
4242+// is produced by ToLabeling() which round-trips through indigo's labeling package so the
4343+// signature stays valid byte-for-byte.
3244type Label struct {
3345 ID int64
3446 Src string
···3850 Neg bool
3951 Cts time.Time
4052 Exp *time.Time
5353+ Ver int64
5454+ Sig []byte
4155 SubjectDID string
4256 SubjectRepo string
4357}
44584545-// OpenDB opens or creates the labeler database.
4646-func OpenDB(dbPath string) (*sql.DB, error) {
5959+// LibsqlSync configures optional embedded-replica sync to a remote libSQL database.
6060+// SyncURL empty means local-only mode.
6161+type LibsqlSync struct {
6262+ SyncURL string
6363+ AuthToken string
6464+ SyncInterval time.Duration
6565+}
6666+6767+// LabelerDB wraps the *sql.DB plus its libsql connector (when in embedded-replica mode)
6868+// so the caller can release file locks on shutdown.
6969+type LabelerDB struct {
7070+ DB *sql.DB
7171+ connector io.Closer
7272+}
7373+7474+// Close closes the database and the libsql connector (if any). The connector close is
7575+// what releases file locks; without it a subsequent local-only open errors with
7676+// "database is locked" — the same gotcha the hold ran into.
7777+func (l *LabelerDB) Close() error {
7878+ var dbErr, connErr error
7979+ if l.DB != nil {
8080+ dbErr = l.DB.Close()
8181+ }
8282+ if l.connector != nil {
8383+ connErr = l.connector.Close()
8484+ }
8585+ if dbErr != nil {
8686+ return dbErr
8787+ }
8888+ return connErr
8989+}
9090+9191+// OpenDB opens or creates the labeler database. When sync.SyncURL is set, the DB runs
9292+// in embedded-replica mode (writes go to the remote, frames replicate to the local
9393+// file); otherwise it's a plain local libSQL file. Schema is applied either way.
9494+func OpenDB(dbPath string, sync LibsqlSync) (*LabelerDB, error) {
4795 if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
4896 return nil, fmt.Errorf("failed to create db directory: %w", err)
4997 }
50985151- db, err := sql.Open("libsql", "file:"+dbPath)
5252- if err != nil {
5353- return nil, fmt.Errorf("failed to open database: %w", err)
9999+ var (
100100+ db *sql.DB
101101+ connector io.Closer
102102+ )
103103+104104+ if sync.SyncURL != "" {
105105+ opts := []libsql.Option{libsql.WithAuthToken(sync.AuthToken)}
106106+ if sync.SyncInterval > 0 {
107107+ opts = append(opts, libsql.WithSyncInterval(sync.SyncInterval))
108108+ }
109109+ conn, err := libsql.NewEmbeddedReplicaConnector(dbPath, sync.SyncURL, opts...)
110110+ if err != nil {
111111+ return nil, fmt.Errorf("failed to create libsql embedded replica connector: %w", err)
112112+ }
113113+ db = sql.OpenDB(conn)
114114+ connector = conn
115115+ slog.Info("Labeler database opened in embedded replica mode", "path", dbPath, "sync_url", sync.SyncURL)
116116+ } else {
117117+ dsn := dbPath
118118+ if !strings.HasPrefix(dsn, "file:") && !strings.HasPrefix(dsn, ":memory:") {
119119+ dsn = "file:" + dsn
120120+ }
121121+ var err error
122122+ db, err = sql.Open("libsql", dsn)
123123+ if err != nil {
124124+ return nil, fmt.Errorf("failed to open database: %w", err)
125125+ }
126126+ slog.Info("Labeler database opened in local-only mode", "path", dbPath)
127127+ }
128128+129129+ // Local PRAGMAs only — Bunny rejects PRAGMA assignments forwarded over the
130130+ // replication protocol (same caveat as pkg/hold/db).
131131+ if sync.SyncURL == "" {
132132+ var journalMode string
133133+ if err := db.QueryRow("PRAGMA journal_mode = WAL").Scan(&journalMode); err != nil {
134134+ _ = closeIfNonNil(db, connector)
135135+ return nil, fmt.Errorf("failed to set journal mode: %w", err)
136136+ }
137137+ var busyTimeout int
138138+ if err := db.QueryRow("PRAGMA busy_timeout = 5000").Scan(&busyTimeout); err != nil {
139139+ _ = closeIfNonNil(db, connector)
140140+ return nil, fmt.Errorf("failed to set busy_timeout: %w", err)
141141+ }
54142 }
551435656- // Apply schema
57144 for _, stmt := range splitStatements(schema) {
58145 if _, err := db.Exec(stmt); err != nil {
146146+ _ = closeIfNonNil(db, connector)
59147 return nil, fmt.Errorf("failed to apply schema: %w", err)
60148 }
61149 }
150150+ return &LabelerDB{DB: db, connector: connector}, nil
151151+}
621526363- return db, nil
153153+// closeIfNonNil is the defensive cleanup for the failure path on OpenDB so we don't
154154+// leave file locks dangling if schema application fails.
155155+func closeIfNonNil(db *sql.DB, connector io.Closer) error {
156156+ if db != nil {
157157+ _ = db.Close()
158158+ }
159159+ if connector != nil {
160160+ return connector.Close()
161161+ }
162162+ return nil
64163}
651646666-// splitStatements splits SQL by semicolons (go-libsql doesn't support multi-statement exec).
67165func splitStatements(sql string) []string {
6868- var stmts []string
6969- for _, s := range splitOnSemicolon(sql) {
7070- s = trimSpace(s)
166166+ parts := strings.Split(sql, ";")
167167+ out := make([]string, 0, len(parts))
168168+ for _, s := range parts {
169169+ s = strings.TrimSpace(s)
71170 if s != "" {
7272- stmts = append(stmts, s)
171171+ out = append(out, s)
73172 }
74173 }
7575- return stmts
174174+ return out
76175}
771767878-func splitOnSemicolon(s string) []string {
7979- var parts []string
8080- start := 0
8181- for i := 0; i < len(s); i++ {
8282- if s[i] == ';' {
8383- parts = append(parts, s[start:i])
8484- start = i + 1
8585- }
177177+// ToLabeling converts the row into indigo's label struct (deterministic CBOR shape).
178178+func (l *Label) ToLabeling() labeling.Label {
179179+ out := labeling.Label{
180180+ CreatedAt: l.Cts.UTC().Format(time.RFC3339),
181181+ SourceDID: l.Src,
182182+ URI: l.URI,
183183+ Val: l.Val,
184184+ Version: l.Ver,
86185 }
8787- if start < len(s) {
8888- parts = append(parts, s[start:])
186186+ if l.CID != "" {
187187+ s := l.CID
188188+ out.CID = &s
89189 }
9090- return parts
190190+ if l.Exp != nil {
191191+ s := l.Exp.UTC().Format(time.RFC3339)
192192+ out.ExpiresAt = &s
193193+ }
194194+ if l.Neg {
195195+ t := true
196196+ out.Negated = &t
197197+ }
198198+ if len(l.Sig) > 0 {
199199+ out.Sig = l.Sig
200200+ }
201201+ return out
91202}
922039393-func trimSpace(s string) string {
9494- // Simple trim that handles newlines and spaces
9595- i := 0
9696- for i < len(s) && (s[i] == ' ' || s[i] == '\t' || s[i] == '\n' || s[i] == '\r') {
9797- i++
204204+// Sign computes a k256 signature over the deterministic CBOR encoding of the label
205205+// (without the sig field) and stores it on the row.
206206+func (l *Label) Sign(key *atcrypto.PrivateKeyK256) error {
207207+ if l.Ver == 0 {
208208+ l.Ver = LabelVersion
209209+ }
210210+ if l.Cts.IsZero() {
211211+ l.Cts = time.Now().UTC()
98212 }
9999- j := len(s)
100100- for j > i && (s[j-1] == ' ' || s[j-1] == '\t' || s[j-1] == '\n' || s[j-1] == '\r') {
101101- j--
213213+ pre := l.ToLabeling()
214214+ pre.Sig = nil
215215+ if err := pre.Sign(key); err != nil {
216216+ return fmt.Errorf("failed to sign label: %w", err)
102217 }
103103- return s[i:j]
218218+ l.Sig = pre.Sig
219219+ return nil
104220}
105221106106-// CreateLabel inserts a new label into the database.
222222+// CreateLabel inserts a freshly signed label and returns its sequence id.
223223+// Caller must Sign() first — CreateLabel rejects rows missing a signature.
107224func CreateLabel(db *sql.DB, l *Label) (int64, error) {
225225+ if len(l.Sig) == 0 {
226226+ return 0, fmt.Errorf("refusing to insert unsigned label")
227227+ }
228228+ if l.Ver == 0 {
229229+ l.Ver = LabelVersion
230230+ }
231231+ var expStr *string
232232+ if l.Exp != nil {
233233+ s := l.Exp.UTC().Format(time.RFC3339)
234234+ expStr = &s
235235+ }
108236 result, err := db.Exec(
109109- `INSERT INTO labels (src, uri, cid, val, neg, cts, exp, subject_did, subject_repo)
110110- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
111111- ON CONFLICT(src, uri, val, neg) DO UPDATE SET cts = excluded.cts`,
112112- l.Src, l.URI, l.CID, l.Val, l.Neg, l.Cts.UTC().Format(time.RFC3339), l.Exp,
237237+ `INSERT INTO labels (src, uri, cid, val, neg, cts, exp, ver, sig, subject_did, subject_repo)
238238+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
239239+ l.Src, l.URI, nullableString(l.CID), l.Val, l.Neg,
240240+ l.Cts.UTC().Format(time.RFC3339), expStr, l.Ver, l.Sig,
113241 l.SubjectDID, l.SubjectRepo,
114242 )
115243 if err != nil {
116116- return 0, fmt.Errorf("failed to create label: %w", err)
244244+ return 0, fmt.Errorf("failed to insert label: %w", err)
245245+ }
246246+ id, err := result.LastInsertId()
247247+ if err != nil {
248248+ return 0, err
117249 }
118118- return result.LastInsertId()
250250+ l.ID = id
251251+ return id, nil
119252}
120253121121-// NegateLabel creates a negation label to reverse a previous label.
122122-func NegateLabel(db *sql.DB, src, uri, val string, subjectDID, subjectRepo string) error {
123123- _, err := db.Exec(
124124- `INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
125125- VALUES (?, ?, ?, 1, ?, ?, ?)`,
126126- src, uri, val, time.Now().UTC().Format(time.RFC3339), subjectDID, subjectRepo,
127127- )
128128- return err
254254+func nullableString(s string) any {
255255+ if s == "" {
256256+ return nil
257257+ }
258258+ return s
129259}
130260131261// GetLabelsSince returns labels with id > cursor, ordered by id ascending.
132262func GetLabelsSince(db *sql.DB, cursor int64, limit int) ([]Label, error) {
133263 rows, err := db.Query(
134134- `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, subject_did, subject_repo
264264+ `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo
135265 FROM labels WHERE id > ? ORDER BY id ASC LIMIT ?`,
136266 cursor, limit,
137267 )
···139269 return nil, err
140270 }
141271 defer rows.Close()
272272+ return scanLabels(rows)
273273+}
142274143143- return scanLabels(rows)
275275+// LatestSeq returns the highest sequence id in the database, or 0 if empty.
276276+func LatestSeq(db *sql.DB) (int64, error) {
277277+ var seq sql.NullInt64
278278+ if err := db.QueryRow(`SELECT MAX(id) FROM labels`).Scan(&seq); err != nil {
279279+ return 0, err
280280+ }
281281+ if !seq.Valid {
282282+ return 0, nil
283283+ }
284284+ return seq.Int64, nil
144285}
145286146287// ListActiveTakedowns returns active (non-negated) takedown labels.
···161302 }
162303163304 rows, err := db.Query(
164164- `SELECT l1.id, l1.src, l1.uri, COALESCE(l1.cid, ''), l1.val, l1.neg, l1.cts, l1.exp, l1.subject_did, l1.subject_repo
305305+ `SELECT l1.id, l1.src, l1.uri, COALESCE(l1.cid, ''), l1.val, l1.neg, l1.cts, l1.exp, l1.ver, l1.sig, l1.subject_did, l1.subject_repo
165306 FROM labels l1
166307 WHERE l1.val = '!takedown' AND l1.neg = 0
167308 AND NOT EXISTS (
···182323 return labels, total, err
183324}
184325185185-// GetLabelsForRepo returns all active labels for a specific DID + repository.
326326+// GetLabelsForRepo returns all labels for a specific DID + repository.
186327func GetLabelsForRepo(db *sql.DB, did, repo string) ([]Label, error) {
187328 rows, err := db.Query(
188188- `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, subject_did, subject_repo
329329+ `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo
189330 FROM labels
190331 WHERE subject_did = ? AND subject_repo = ?
191332 ORDER BY cts DESC`,
···198339 return scanLabels(rows)
199340}
200341201201-// NegateRepoLabels creates negation labels for all active takedown labels on a (DID, repo) pair.
202202-func NegateRepoLabels(db *sql.DB, src, did, repo string) error {
342342+// newNegationLabel constructs an unsigned negation label awaiting Sign().
343343+func newNegationLabel(src, uri, val, did, repo string) *Label {
344344+ return &Label{
345345+ Src: src,
346346+ URI: uri,
347347+ Val: val,
348348+ Neg: true,
349349+ Cts: time.Now().UTC(),
350350+ SubjectDID: did,
351351+ SubjectRepo: repo,
352352+ }
353353+}
354354+355355+// NegateRepoLabels signs+inserts negation labels for all active takedown labels on (DID, repo).
356356+func NegateRepoLabels(db *sql.DB, key *atcrypto.PrivateKeyK256, src, did, repo string) ([]Label, error) {
203357 rows, err := db.Query(
204358 `SELECT uri FROM labels
205359 WHERE subject_did = ? AND subject_repo = ? AND val = '!takedown' AND neg = 0`,
206360 did, repo,
207361 )
208362 if err != nil {
209209- return err
363363+ return nil, err
210364 }
211211-212365 var uris []string
213366 for rows.Next() {
214367 var uri string
215368 if err := rows.Scan(&uri); err != nil {
216369 rows.Close()
217217- return err
370370+ return nil, err
218371 }
219372 uris = append(uris, uri)
220373 }
221374 rows.Close()
222375 if err := rows.Err(); err != nil {
223223- return err
376376+ return nil, err
224377 }
225378226226- now := time.Now().UTC().Format(time.RFC3339)
379379+ out := make([]Label, 0, len(uris))
227380 for _, uri := range uris {
228228- if _, err := db.Exec(
229229- `INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
230230- VALUES (?, ?, '!takedown', 1, ?, ?, ?)`,
231231- src, uri, now, did, repo,
232232- ); err != nil {
233233- return err
381381+ neg := newNegationLabel(src, uri, "!takedown", did, repo)
382382+ if err := neg.Sign(key); err != nil {
383383+ return out, err
384384+ }
385385+ if _, err := CreateLabel(db, neg); err != nil {
386386+ return out, err
234387 }
388388+ out = append(out, *neg)
235389 }
236236- return nil
390390+ return out, nil
237391}
238392239239-// NegateUserLabels creates negation labels for all active takedown labels on a DID (user-level).
240240-func NegateUserLabels(db *sql.DB, src, did string) error {
393393+// NegateUserLabels signs+inserts negation labels for all active takedown labels on a DID.
394394+func NegateUserLabels(db *sql.DB, key *atcrypto.PrivateKeyK256, src, did string) ([]Label, error) {
241395 rows, err := db.Query(
242396 `SELECT uri, subject_repo FROM labels
243397 WHERE subject_did = ? AND val = '!takedown' AND neg = 0`,
244398 did,
245399 )
246400 if err != nil {
247247- return err
401401+ return nil, err
248402 }
249249-250403 type uriRepo struct {
251404 uri string
252405 repo string
···256409 var e uriRepo
257410 if err := rows.Scan(&e.uri, &e.repo); err != nil {
258411 rows.Close()
259259- return err
412412+ return nil, err
260413 }
261414 entries = append(entries, e)
262415 }
263416 rows.Close()
264417 if err := rows.Err(); err != nil {
265265- return err
418418+ return nil, err
266419 }
267420268268- now := time.Now().UTC().Format(time.RFC3339)
421421+ out := make([]Label, 0, len(entries))
269422 for _, e := range entries {
270270- if _, err := db.Exec(
271271- `INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
272272- VALUES (?, ?, '!takedown', 1, ?, ?, ?)`,
273273- src, e.uri, now, did, e.repo,
274274- ); err != nil {
275275- return err
423423+ neg := newNegationLabel(src, e.uri, "!takedown", did, e.repo)
424424+ if err := neg.Sign(key); err != nil {
425425+ return out, err
426426+ }
427427+ if _, err := CreateLabel(db, neg); err != nil {
428428+ return out, err
276429 }
430430+ out = append(out, *neg)
277431 }
278278- return nil
432432+ return out, nil
279433}
280434281435func scanLabels(rows *sql.Rows) ([]Label, error) {
···284438 var l Label
285439 var cts string
286440 var exp *string
287287- if err := rows.Scan(&l.ID, &l.Src, &l.URI, &l.CID, &l.Val, &l.Neg, &cts, &exp, &l.SubjectDID, &l.SubjectRepo); err != nil {
441441+ if err := rows.Scan(&l.ID, &l.Src, &l.URI, &l.CID, &l.Val, &l.Neg, &cts, &exp, &l.Ver, &l.Sig, &l.SubjectDID, &l.SubjectRepo); err != nil {
288442 return nil, err
289443 }
290444 if t, err := time.Parse(time.RFC3339, cts); err == nil {
+141-180
pkg/labeler/db_test.go
···11package labeler
2233import (
44+ "database/sql"
45 "os"
56 "path/filepath"
67 "testing"
78 "time"
99+1010+ "github.com/bluesky-social/indigo/atproto/atcrypto"
811)
9121313+func newTestKey(t *testing.T) *atcrypto.PrivateKeyK256 {
1414+ t.Helper()
1515+ k, err := atcrypto.GeneratePrivateKeyK256()
1616+ if err != nil {
1717+ t.Fatalf("generate key: %v", err)
1818+ }
1919+ return k
2020+}
2121+2222+// openTestDB opens a fresh local-only labeler DB and registers cleanup. Returns the
2323+// raw *sql.DB so existing tests can keep using it; the wrapper lifecycle is handled
2424+// here so tests don't have to know about the embedded-replica machinery.
2525+func openTestDB(t *testing.T, path string) *sql.DB {
2626+ t.Helper()
2727+ storage, err := OpenDB(path, LibsqlSync{})
2828+ if err != nil {
2929+ t.Fatalf("OpenDB: %v", err)
3030+ }
3131+ t.Cleanup(func() { _ = storage.Close() })
3232+ return storage.DB
3333+}
3434+3535+// signAndCreate is a helper that signs the label and inserts it; it returns the row id.
3636+func signAndCreate(t *testing.T, db *sql.DB, key *atcrypto.PrivateKeyK256, l *Label) int64 {
3737+ t.Helper()
3838+ if err := l.Sign(key); err != nil {
3939+ t.Fatalf("sign: %v", err)
4040+ }
4141+ id, err := CreateLabel(db, l)
4242+ if err != nil {
4343+ t.Fatalf("create: %v", err)
4444+ }
4545+ return id
4646+}
4747+1048func TestOpenDB(t *testing.T) {
1149 dir := t.TempDir()
1250 dbPath := filepath.Join(dir, "subdir", "test.db")
13511414- db, err := OpenDB(dbPath)
5252+ storage, err := OpenDB(dbPath, LibsqlSync{})
1553 if err != nil {
1654 t.Fatalf("OpenDB failed: %v", err)
1755 }
1818- defer db.Close()
5656+ defer storage.Close()
19572020- // Verify directory was created
2158 if _, err := os.Stat(filepath.Dir(dbPath)); os.IsNotExist(err) {
2259 t.Error("expected directory to be created")
2360 }
24612525- // Verify tables exist
2662 var count int
2727- err = db.QueryRow("SELECT COUNT(*) FROM labels").Scan(&count)
2828- if err != nil {
6363+ if err := storage.DB.QueryRow("SELECT COUNT(*) FROM labels").Scan(&count); err != nil {
2964 t.Fatalf("failed to query labels table: %v", err)
3065 }
3166 if count != 0 {
···35703671func TestCreateLabel(t *testing.T) {
3772 dir := t.TempDir()
3838- db, err := OpenDB(filepath.Join(dir, "test.db"))
3939- if err != nil {
4040- t.Fatal(err)
4141- }
4242- defer db.Close()
7373+ db := openTestDB(t, filepath.Join(dir, "test.db"))
7474+ key := newTestKey(t)
43754476 now := time.Now().UTC().Truncate(time.Second)
4577 label := &Label{
4646- Src: "did:web:labeler.atcr.io",
7878+ Src: "did:plc:labeler-1",
4779 URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
4880 Val: "!takedown",
4981 Cts: now,
5082 SubjectDID: "did:plc:abc",
5183 SubjectRepo: "myimage",
5284 }
5353-5454- id, err := CreateLabel(db, label)
5555- if err != nil {
5656- t.Fatalf("CreateLabel failed: %v", err)
5757- }
8585+ id := signAndCreate(t, db, key, label)
5886 if id <= 0 {
5987 t.Errorf("expected positive id, got %d", id)
6088 }
8989+ if len(label.Sig) == 0 {
9090+ t.Error("expected signature populated by Sign()")
9191+ }
61926262- // Verify it was stored
6393 labels, err := GetLabelsSince(db, 0, 10)
6494 if err != nil {
6595 t.Fatal(err)
···6797 if len(labels) != 1 {
6898 t.Fatalf("expected 1 label, got %d", len(labels))
6999 }
7070- if labels[0].Src != "did:web:labeler.atcr.io" {
7171- t.Errorf("expected src did:web:labeler.atcr.io, got %s", labels[0].Src)
7272- }
7373- if labels[0].Val != "!takedown" {
7474- t.Errorf("expected val !takedown, got %s", labels[0].Val)
100100+ if labels[0].Src != label.Src {
101101+ t.Errorf("Src = %s, want %s", labels[0].Src, label.Src)
75102 }
76103 if labels[0].SubjectDID != "did:plc:abc" {
7777- t.Errorf("expected subject_did did:plc:abc, got %s", labels[0].SubjectDID)
104104+ t.Errorf("SubjectDID = %s", labels[0].SubjectDID)
78105 }
79106 if labels[0].SubjectRepo != "myimage" {
8080- t.Errorf("expected subject_repo myimage, got %s", labels[0].SubjectRepo)
107107+ t.Errorf("SubjectRepo = %s", labels[0].SubjectRepo)
108108+ }
109109+ if labels[0].Ver != LabelVersion {
110110+ t.Errorf("Ver = %d, want %d", labels[0].Ver, LabelVersion)
111111+ }
112112+ if len(labels[0].Sig) == 0 {
113113+ t.Error("expected stored sig to be populated")
81114 }
82115}
831168484-func TestCreateLabel_Upsert(t *testing.T) {
117117+func TestCreateLabel_RejectsUnsigned(t *testing.T) {
85118 dir := t.TempDir()
8686- db, err := OpenDB(filepath.Join(dir, "test.db"))
8787- if err != nil {
8888- t.Fatal(err)
8989- }
9090- defer db.Close()
119119+ db := openTestDB(t, filepath.Join(dir, "test.db"))
911209292- now := time.Now().UTC()
93121 label := &Label{
9494- Src: "did:web:labeler.atcr.io",
9595- URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
9696- Val: "!takedown",
9797- Cts: now,
9898- SubjectDID: "did:plc:abc",
9999- SubjectRepo: "myimage",
100100- }
101101-102102- // First insert
103103- _, err = CreateLabel(db, label)
104104- if err != nil {
105105- t.Fatal(err)
106106- }
107107-108108- // Same (src, uri, val) - should upsert, not error
109109- label.Cts = now.Add(time.Hour)
110110- _, err = CreateLabel(db, label)
111111- if err != nil {
112112- t.Fatalf("upsert should not fail: %v", err)
113113- }
114114-115115- // Should still be 1 label
116116- labels, err := GetLabelsSince(db, 0, 10)
117117- if err != nil {
118118- t.Fatal(err)
122122+ Src: "did:plc:labeler-1", URI: "at://did:plc:abc",
123123+ Val: "!takedown", Cts: time.Now().UTC(),
124124+ SubjectDID: "did:plc:abc",
119125 }
120120- if len(labels) != 1 {
121121- t.Errorf("expected 1 label after upsert, got %d", len(labels))
126126+ if _, err := CreateLabel(db, label); err == nil {
127127+ t.Fatal("expected CreateLabel to reject an unsigned label")
122128 }
123129}
124130125125-func TestNegateLabel(t *testing.T) {
126126- dir := t.TempDir()
127127- db, err := OpenDB(filepath.Join(dir, "test.db"))
128128- if err != nil {
129129- t.Fatal(err)
131131+func TestSignAndVerify(t *testing.T) {
132132+ key := newTestKey(t)
133133+ label := &Label{
134134+ Src: "did:plc:labeler-1",
135135+ URI: "at://did:plc:abc",
136136+ Val: "!takedown",
137137+ Cts: time.Now().UTC(),
138138+ Ver: LabelVersion,
130139 }
131131- defer db.Close()
132132-133133- src := "did:web:labeler.atcr.io"
134134- now := time.Now().UTC()
135135-136136- // Create a label
137137- _, err = CreateLabel(db, &Label{
138138- Src: src, URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
139139- Val: "!takedown", Cts: now,
140140- SubjectDID: "did:plc:abc", SubjectRepo: "myimage",
141141- })
142142- if err != nil {
140140+ if err := label.Sign(key); err != nil {
143141 t.Fatal(err)
144142 }
145143146146- // Negate it
147147- err = NegateLabel(db, src, "at://did:plc:abc/io.atcr.manifest/sha256-123", "!takedown", "did:plc:abc", "myimage")
148148- if err != nil {
149149- t.Fatalf("NegateLabel failed: %v", err)
150150- }
151151-152152- // Should have 2 labels now (original + negation)
153153- labels, err := GetLabelsSince(db, 0, 10)
144144+ pub, err := key.PublicKey()
154145 if err != nil {
155146 t.Fatal(err)
156147 }
157157- if len(labels) != 2 {
158158- t.Fatalf("expected 2 labels, got %d", len(labels))
159159- }
160160-161161- // The negation label should have neg=true
162162- negLabel := labels[1]
163163- if !negLabel.Neg {
164164- t.Error("expected negation label to have neg=true")
148148+ wire := label.ToLabeling()
149149+ if err := wire.VerifySignature(pub); err != nil {
150150+ t.Fatalf("signature did not verify: %v", err)
165151 }
166152}
167153168154func TestListActiveTakedowns(t *testing.T) {
169155 dir := t.TempDir()
170170- db, err := OpenDB(filepath.Join(dir, "test.db"))
171171- if err != nil {
172172- t.Fatal(err)
173173- }
174174- defer db.Close()
156156+ db := openTestDB(t, filepath.Join(dir, "test.db"))
157157+ key := newTestKey(t)
175158176176- src := "did:web:labeler.atcr.io"
159159+ src := "did:plc:labeler-1"
177160 now := time.Now().UTC()
178161179179- // Create 3 labels
180162 for i, repo := range []string{"repo1", "repo2", "repo3"} {
181181- _, err = CreateLabel(db, &Label{
163163+ signAndCreate(t, db, key, &Label{
182164 Src: src, URI: "at://did:plc:abc/io.atcr.repo/" + repo,
183165 Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute),
184166 SubjectDID: "did:plc:abc", SubjectRepo: repo,
185167 })
186186- if err != nil {
187187- t.Fatal(err)
188188- }
189168 }
190169191191- // All 3 should be active
192170 labels, total, err := ListActiveTakedowns(db, 10, 0)
193171 if err != nil {
194172 t.Fatal(err)
195173 }
196196- if total != 3 {
197197- t.Errorf("expected 3 active takedowns, got %d", total)
198198- }
199199- if len(labels) != 3 {
200200- t.Errorf("expected 3 labels returned, got %d", len(labels))
174174+ if total != 3 || len(labels) != 3 {
175175+ t.Errorf("expected 3 active takedowns, got total=%d returned=%d", total, len(labels))
201176 }
202177203203- // Negate one
204204- err = NegateLabel(db, src, "at://did:plc:abc/io.atcr.repo/repo2", "!takedown", "did:plc:abc", "repo2")
205205- if err != nil {
178178+ if _, err := NegateRepoLabels(db, key, src, "did:plc:abc", "repo2"); err != nil {
206179 t.Fatal(err)
207180 }
208181209209- // Should be 2 active
210182 _, total, err = ListActiveTakedowns(db, 10, 0)
211183 if err != nil {
212184 t.Fatal(err)
···218190219191func TestNegateRepoLabels(t *testing.T) {
220192 dir := t.TempDir()
221221- db, err := OpenDB(filepath.Join(dir, "test.db"))
222222- if err != nil {
223223- t.Fatal(err)
224224- }
225225- defer db.Close()
193193+ db := openTestDB(t, filepath.Join(dir, "test.db"))
194194+ key := newTestKey(t)
226195227227- src := "did:web:labeler.atcr.io"
196196+ src := "did:plc:labeler-1"
228197 now := time.Now().UTC()
229198 did := "did:plc:abc"
230199231231- // Create multiple labels for same repo
232200 uris := []string{
233201 "at://did:plc:abc/io.atcr.manifest/sha256-111",
234202 "at://did:plc:abc/io.atcr.manifest/sha256-222",
235203 "at://did:plc:abc/io.atcr.tag/myimage-latest",
236204 }
237205 for _, uri := range uris {
238238- _, err = CreateLabel(db, &Label{
206206+ signAndCreate(t, db, key, &Label{
239207 Src: src, URI: uri, Val: "!takedown", Cts: now,
240208 SubjectDID: did, SubjectRepo: "myimage",
241209 })
242242- if err != nil {
243243- t.Fatal(err)
244244- }
245210 }
246211247247- // Negate all labels for the repo
248248- err = NegateRepoLabels(db, src, did, "myimage")
212212+ negs, err := NegateRepoLabels(db, key, src, did, "myimage")
249213 if err != nil {
250214 t.Fatal(err)
251215 }
216216+ if len(negs) != len(uris) {
217217+ t.Errorf("expected %d negation labels, got %d", len(uris), len(negs))
218218+ }
252219253253- // Should have 0 active takedowns
254220 _, total, err := ListActiveTakedowns(db, 10, 0)
255221 if err != nil {
256222 t.Fatal(err)
···262228263229func TestNegateUserLabels(t *testing.T) {
264230 dir := t.TempDir()
265265- db, err := OpenDB(filepath.Join(dir, "test.db"))
266266- if err != nil {
267267- t.Fatal(err)
268268- }
269269- defer db.Close()
231231+ db := openTestDB(t, filepath.Join(dir, "test.db"))
232232+ key := newTestKey(t)
270233271271- src := "did:web:labeler.atcr.io"
234234+ src := "did:plc:labeler-1"
272235 now := time.Now().UTC()
273236 did := "did:plc:abc"
274237275275- // Create labels for different repos + a user-level label
276276- _, err = CreateLabel(db, &Label{
238238+ signAndCreate(t, db, key, &Label{
277239 Src: src, URI: "at://did:plc:abc", Val: "!takedown", Cts: now,
278278- SubjectDID: did, SubjectRepo: "",
240240+ SubjectDID: did,
279241 })
280280- if err != nil {
281281- t.Fatal(err)
282282- }
283283- _, err = CreateLabel(db, &Label{
242242+ signAndCreate(t, db, key, &Label{
284243 Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1", Val: "!takedown", Cts: now,
285244 SubjectDID: did, SubjectRepo: "repo1",
286245 })
287287- if err != nil {
288288- t.Fatal(err)
289289- }
290246291291- // Negate all labels for the user
292292- err = NegateUserLabels(db, src, did)
247247+ negs, err := NegateUserLabels(db, key, src, did)
293248 if err != nil {
294249 t.Fatal(err)
295250 }
251251+ if len(negs) != 2 {
252252+ t.Errorf("expected 2 negation labels, got %d", len(negs))
253253+ }
296254297297- // Should have 0 active
298255 _, total, err := ListActiveTakedowns(db, 10, 0)
299256 if err != nil {
300257 t.Fatal(err)
···306263307264func TestGetLabelsSince(t *testing.T) {
308265 dir := t.TempDir()
309309- db, err := OpenDB(filepath.Join(dir, "test.db"))
310310- if err != nil {
311311- t.Fatal(err)
312312- }
313313- defer db.Close()
266266+ db := openTestDB(t, filepath.Join(dir, "test.db"))
267267+ key := newTestKey(t)
314268315315- src := "did:web:labeler.atcr.io"
269269+ src := "did:plc:labeler-1"
316270 now := time.Now().UTC()
317271318318- // Create 5 labels
319272 for i := 0; i < 5; i++ {
320320- _, err = CreateLabel(db, &Label{
273273+ signAndCreate(t, db, key, &Label{
321274 Src: src, URI: "at://did:plc:abc/io.atcr.manifest/" + string(rune('a'+i)),
322275 Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute),
323276 SubjectDID: "did:plc:abc", SubjectRepo: "repo",
324277 })
325325- if err != nil {
326326- t.Fatal(err)
327327- }
328278 }
329279330330- // Get all since 0
331280 labels, err := GetLabelsSince(db, 0, 10)
332281 if err != nil {
333282 t.Fatal(err)
···336285 t.Errorf("expected 5 labels, got %d", len(labels))
337286 }
338287339339- // Get since cursor (skip first 3)
340340- if len(labels) >= 3 {
341341- cursor := labels[2].ID
342342- after, err := GetLabelsSince(db, cursor, 10)
343343- if err != nil {
344344- t.Fatal(err)
345345- }
346346- if len(after) != 2 {
347347- t.Errorf("expected 2 labels after cursor %d, got %d", cursor, len(after))
348348- }
288288+ cursor := labels[2].ID
289289+ after, err := GetLabelsSince(db, cursor, 10)
290290+ if err != nil {
291291+ t.Fatal(err)
292292+ }
293293+ if len(after) != 2 {
294294+ t.Errorf("expected 2 labels after cursor %d, got %d", cursor, len(after))
349295 }
350296351351- // Get with limit
352297 limited, err := GetLabelsSince(db, 0, 2)
353298 if err != nil {
354299 t.Fatal(err)
···358303 }
359304}
360305361361-func TestGetLabelsForRepo(t *testing.T) {
306306+func TestLatestSeq(t *testing.T) {
362307 dir := t.TempDir()
363363- db, err := OpenDB(filepath.Join(dir, "test.db"))
308308+ db := openTestDB(t, filepath.Join(dir, "test.db"))
309309+ key := newTestKey(t)
310310+311311+ if seq, err := LatestSeq(db); err != nil || seq != 0 {
312312+ t.Fatalf("expected empty seq=0, got %d (err=%v)", seq, err)
313313+ }
314314+315315+ id := signAndCreate(t, db, key, &Label{
316316+ Src: "did:plc:labeler-1", URI: "at://did:plc:abc",
317317+ Val: "!takedown", Cts: time.Now().UTC(),
318318+ SubjectDID: "did:plc:abc",
319319+ })
320320+ seq, err := LatestSeq(db)
364321 if err != nil {
365322 t.Fatal(err)
366323 }
367367- defer db.Close()
324324+ if seq != id {
325325+ t.Errorf("LatestSeq = %d, want %d", seq, id)
326326+ }
327327+}
368328369369- src := "did:web:labeler.atcr.io"
329329+func TestGetLabelsForRepo(t *testing.T) {
330330+ dir := t.TempDir()
331331+ db := openTestDB(t, filepath.Join(dir, "test.db"))
332332+ key := newTestKey(t)
333333+334334+ src := "did:plc:labeler-1"
370335 now := time.Now().UTC()
371336372372- // Labels for different repos
373373- _, _ = CreateLabel(db, &Label{
337337+ signAndCreate(t, db, key, &Label{
374338 Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1",
375339 Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo1",
376340 })
377377- _, _ = CreateLabel(db, &Label{
341341+ signAndCreate(t, db, key, &Label{
378342 Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo2",
379343 Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo2",
380344 })
381381- _, _ = CreateLabel(db, &Label{
345345+ signAndCreate(t, db, key, &Label{
382346 Src: src, URI: "at://did:plc:def/io.atcr.repo/repo1",
383347 Val: "!takedown", Cts: now, SubjectDID: "did:plc:def", SubjectRepo: "repo1",
384348 })
385349386386- // Get labels for specific did+repo
387350 labels, err := GetLabelsForRepo(db, "did:plc:abc", "repo1")
388351 if err != nil {
389352 t.Fatal(err)
···392355 t.Errorf("expected 1 label for did:plc:abc/repo1, got %d", len(labels))
393356 }
394357395395- // Different user same repo
396358 labels, err = GetLabelsForRepo(db, "did:plc:def", "repo1")
397359 if err != nil {
398360 t.Fatal(err)
···401363 t.Errorf("expected 1 label for did:plc:def/repo1, got %d", len(labels))
402364 }
403365404404- // No labels
405366 labels, err = GetLabelsForRepo(db, "did:plc:xyz", "repo1")
406367 if err != nil {
407368 t.Fatal(err)
+3-4
pkg/labeler/handlers.go
···1313// Auth handlers
14141515func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
1616- // If already logged in, redirect to dashboard
1716 if token, ok := getSessionCookie(r); ok {
1818- if session := s.auth.getSession(token); session != nil && session.DID == s.config.Labeler.OwnerDID {
1717+ if session := s.auth.GetSession(token); session != nil && session.DID == s.config.Labeler.OwnerDID {
1918 http.Redirect(w, r, "/", http.StatusFound)
2019 return
2120 }
···9998 return
10099 }
101100102102- token, err := s.auth.createSession(did, handle)
101101+ token, _, err := s.auth.CreateSession(did, handle, r.UserAgent(), clientIPPrefix(r))
103102 if err != nil {
104103 http.Error(w, "Failed to create session", http.StatusInternalServerError)
105104 return
···111110112111func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
113112 if token, ok := getSessionCookie(r); ok {
114114- s.auth.deleteSession(token)
113113+ s.auth.DeleteSession(token)
115114 }
116115 clearSessionCookie(w)
117116 http.Redirect(w, r, "/auth/login", http.StatusFound)
+75
pkg/labeler/hub.go
···11+package labeler
22+33+import (
44+ "sync"
55+)
66+77+// hubSubscriber is one connected subscribeLabels client. The hub fans out new labels
88+// to each subscriber's bounded channel; if a slow client fills the buffer, the hub
99+// drops them rather than blocking the writer.
1010+type hubSubscriber struct {
1111+ ch chan *Label
1212+ closed bool
1313+}
1414+1515+// Hub broadcasts newly-inserted labels to all live subscribeLabels clients.
1616+type Hub struct {
1717+ mu sync.Mutex
1818+ subs map[*hubSubscriber]struct{}
1919+}
2020+2121+// NewHub returns an empty hub ready to accept subscribers.
2222+func NewHub() *Hub {
2323+ return &Hub{subs: make(map[*hubSubscriber]struct{})}
2424+}
2525+2626+// subscribe registers a new subscriber and returns its event channel + a cancel func.
2727+// The buffer size bounds backpressure tolerance per client.
2828+func (h *Hub) subscribe(buffer int) (*hubSubscriber, func()) {
2929+ s := &hubSubscriber{ch: make(chan *Label, buffer)}
3030+ h.mu.Lock()
3131+ h.subs[s] = struct{}{}
3232+ h.mu.Unlock()
3333+ return s, func() { h.unsubscribe(s) }
3434+}
3535+3636+func (h *Hub) unsubscribe(s *hubSubscriber) {
3737+ h.mu.Lock()
3838+ defer h.mu.Unlock()
3939+ if _, ok := h.subs[s]; !ok {
4040+ return
4141+ }
4242+ delete(h.subs, s)
4343+ if !s.closed {
4444+ s.closed = true
4545+ close(s.ch)
4646+ }
4747+}
4848+4949+// Broadcast sends a copy of the label to every live subscriber. Subscribers whose
5050+// buffer is full are evicted on the spot rather than slowing down the writer.
5151+func (h *Hub) Broadcast(l *Label) {
5252+ if l == nil {
5353+ return
5454+ }
5555+ h.mu.Lock()
5656+ dead := make([]*hubSubscriber, 0)
5757+ for s := range h.subs {
5858+ select {
5959+ case s.ch <- l:
6060+ default:
6161+ dead = append(dead, s)
6262+ }
6363+ }
6464+ h.mu.Unlock()
6565+ for _, s := range dead {
6666+ h.unsubscribe(s)
6767+ }
6868+}
6969+7070+// Len returns the number of live subscribers (mostly for tests / metrics).
7171+func (h *Hub) Len() int {
7272+ h.mu.Lock()
7373+ defer h.mu.Unlock()
7474+ return len(h.subs)
7575+}
+38-21
pkg/labeler/identity.go
···11package labeler
2233import (
44+ "context"
45 "encoding/json"
56 "fmt"
67 "net/http"
88+99+ "atcr.io/pkg/atproto/did"
1010+ "atcr.io/pkg/auth/oauth"
1111+ "github.com/bluesky-social/indigo/atproto/atcrypto"
712)
81399-// DIDDocument represents a did:web DID document.
1010-type DIDDocument struct {
1111- Context []string `json:"@context"`
1212- ID string `json:"id"`
1313- Service []DIDService `json:"service,omitempty"`
1414+// labelerServices returns the service entries the labeler publishes in its DID document
1515+// and PLC operations: a single AtprotoLabeler endpoint at #atproto_labeler.
1616+func labelerServices(publicURL string) map[string]did.Service {
1717+ return map[string]did.Service{
1818+ "atproto_labeler": {Type: "AtprotoLabeler", Endpoint: publicURL},
1919+ }
1420}
15211616-// DIDService represents a service entry in a DID document.
1717-type DIDService struct {
1818- ID string `json:"id"`
1919- Type string `json:"type"`
2020- ServiceEndpoint string `json:"serviceEndpoint"`
2222+// LoadIdentity resolves the labeler's DID and loads its k256 signing key.
2323+// For did:plc this calls into the shared PLC package (loading or creating); for did:web
2424+// the DID is derived from PublicURL and the signing key is generated on disk if missing.
2525+func LoadIdentity(ctx context.Context, cfg *Config) (string, *atcrypto.PrivateKeyK256, error) {
2626+ labelerDID, err := did.LoadOrCreate(ctx, did.Config{
2727+ Method: cfg.Labeler.DIDMethod,
2828+ PublicURL: cfg.PublicURL(),
2929+ DBPath: cfg.Labeler.DataDir,
3030+ SigningKeyPath: cfg.SigningKeyPath(),
3131+ RotationKey: cfg.Labeler.RotationKey,
3232+ PLCDirectoryURL: cfg.PLCDirectoryURL(),
3333+ DID: cfg.Labeler.DID,
3434+ VerificationKeyName: "atproto_label",
3535+ Services: labelerServices(cfg.PublicURL()),
3636+ })
3737+ if err != nil {
3838+ return "", nil, fmt.Errorf("labeler: failed to resolve DID: %w", err)
3939+ }
4040+ signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath())
4141+ if err != nil {
4242+ return "", nil, fmt.Errorf("labeler: failed to load signing key: %w", err)
4343+ }
4444+ return labelerDID, signingKey, nil
2145}
22462347func (s *Server) handleDIDDocument(w http.ResponseWriter, r *http.Request) {
2424- doc := DIDDocument{
2525- Context: []string{"https://www.w3.org/ns/did/v1"},
2626- ID: s.config.DID(),
2727- Service: []DIDService{
2828- {
2929- ID: "#atproto_labeler",
3030- Type: "AtprotoLabeler",
3131- ServiceEndpoint: s.config.PublicURL(),
3232- },
3333- },
4848+ doc, err := did.BuildDIDDocument(s.did, s.config.PublicURL(), s.signingKey, "atproto_label", labelerServices(s.config.PublicURL()))
4949+ if err != nil {
5050+ http.Error(w, "failed to build DID document", http.StatusInternalServerError)
5151+ return
3452 }
3535-3653 w.Header().Set("Content-Type", "application/json")
3754 _ = json.NewEncoder(w).Encode(doc)
3855}
+50-16
pkg/labeler/server.go
···55 "database/sql"
66 "fmt"
77 "log/slog"
88+ "net"
89 "net/http"
910 "net/url"
1011 "os"
1112 "os/signal"
1212- "strings"
1313 "syscall"
1414+ "time"
14151516 "atcr.io/pkg/atproto"
1717+ "github.com/bluesky-social/indigo/atproto/atcrypto"
1618 indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
1719 "github.com/go-chi/chi/v5"
1820)
19212022// Server is the labeler HTTP server.
2123type Server struct {
2222- config *Config
2323- db *sql.DB
2424- router chi.Router
2525- clientApp *indigooauth.ClientApp
2626- auth *Auth
2424+ config *Config
2525+ storage *LabelerDB
2626+ db *sql.DB
2727+ router chi.Router
2828+ clientApp *indigooauth.ClientApp
2929+ auth *Auth
3030+ did string
3131+ signingKey *atcrypto.PrivateKeyK256
3232+ hub *Hub
2733}
28342935// NewServer creates a new labeler server.
3036func NewServer(cfg *Config) (*Server, error) {
3131- db, err := OpenDB(cfg.Labeler.DBPath)
3737+ storage, err := OpenDB(cfg.DBPath(), LibsqlSync{
3838+ SyncURL: cfg.Labeler.LibsqlSyncURL,
3939+ AuthToken: cfg.Labeler.LibsqlAuthToken,
4040+ SyncInterval: cfg.Labeler.LibsqlSyncInterval,
4141+ })
3242 if err != nil {
3343 return nil, fmt.Errorf("failed to open database: %w", err)
4444+ }
4545+4646+ ctx := context.Background()
4747+ did, signingKey, err := LoadIdentity(ctx, cfg)
4848+ if err != nil {
4949+ _ = storage.Close()
5050+ return nil, err
3451 }
35523653 publicURL := cfg.PublicURL()
···6885 auth := NewAuth(cfg.Labeler.OwnerDID)
69867087 s := &Server{
7171- config: cfg,
7272- db: db,
7373- clientApp: clientApp,
7474- auth: auth,
8888+ config: cfg,
8989+ storage: storage,
9090+ db: storage.DB,
9191+ clientApp: clientApp,
9292+ auth: auth,
9393+ did: did,
9494+ signingKey: signingKey,
9595+ hub: NewHub(),
7596 }
76977798 s.setupRoutes()
···97118 r.Get("/xrpc/com.atproto.label.subscribeLabels", s.handleSubscribeLabels)
98119 r.Get("/xrpc/com.atproto.label.queryLabels", s.handleQueryLabels)
99120100100- // Protected routes (require owner)
121121+ // Protected routes (require owner). CSRF is enforced for state-mutating
122122+ // methods inside the same group, so it sees the session on the context.
101123 r.Group(func(r chi.Router) {
102124 r.Use(s.auth.RequireOwner)
125125+ r.Use(s.auth.RequireCSRF)
103126104127 r.Get("/", s.handleDashboard)
105128 r.Get("/takedown", s.handleTakedownForm)
···115138 slog.Info("Starting labeler service",
116139 "addr", s.config.Labeler.Addr,
117140 "public_url", s.config.PublicURL(),
118118- "did", s.config.DID(),
141141+ "did", s.did,
119142 "owner", s.config.Labeler.OwnerDID,
120143 )
121144···140163 }
141164 case <-ctx.Done():
142165 slog.Info("Shutting down labeler service")
143143- shutdownCtx, cancel := context.WithTimeout(context.Background(), 5000000000) // 5s
166166+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
144167 defer cancel()
145168 if err := srv.Shutdown(shutdownCtx); err != nil {
146169 return fmt.Errorf("shutdown error: %w", err)
147170 }
148171 }
149172150150- s.db.Close()
173173+ if err := s.storage.Close(); err != nil {
174174+ slog.Warn("Error closing labeler database", "error", err)
175175+ }
151176 return nil
152177}
153178179179+// isLocalhost returns true when the host is reachable only from the local machine /
180180+// docker host — anything that an external PDS can't reach. Matches the hold's policy:
181181+// any IP literal counts (covers 127.0.0.1, 192.168.*, 172.16-31.*, 10.*, ::1, etc.) plus
182182+// the literal "localhost". When this is true, OAuth uses indigo's `NewLocalhostConfig`
183183+// which sets a `http://localhost`-form client_id that PDSes accept under the loopback
184184+// exception — so the PDS never has to fetch the client metadata URL we publish.
154185func isLocalhost(host string) bool {
155155- return host == "localhost" || host == "127.0.0.1" || strings.HasPrefix(host, "192.168.")
186186+ if host == "localhost" {
187187+ return true
188188+ }
189189+ return net.ParseIP(host) != nil
156190}
+30
pkg/labeler/server_test.go
···11+package labeler
22+33+import "testing"
44+55+// TestIsLocalhost covers the OAuth-mode decision: any IP-literal host (including
66+// docker-compose private addresses like 172.28.0.x and the standard 127.0.0.1) plus
77+// the literal "localhost" routes through the loopback OAuth path so PDSes don't have
88+// to fetch our published client metadata. Domain names go through the public-client
99+// path and require the metadata endpoint to be reachable from the PDS.
1010+func TestIsLocalhost(t *testing.T) {
1111+ tests := []struct {
1212+ host string
1313+ want bool
1414+ }{
1515+ {"localhost", true},
1616+ {"127.0.0.1", true},
1717+ {"::1", true},
1818+ {"192.168.1.10", true},
1919+ {"172.28.0.4", true}, // docker-compose private network
2020+ {"10.0.0.5", true}, // RFC 1918
2121+ {"labeler.atcr.io", false},
2222+ {"labeler.example.com", false},
2323+ {"", false},
2424+ }
2525+ for _, tt := range tests {
2626+ if got := isLocalhost(tt.host); got != tt.want {
2727+ t.Errorf("isLocalhost(%q) = %v, want %v", tt.host, got, tt.want)
2828+ }
2929+ }
3030+}
+244-94
pkg/labeler/subscribe.go
···11package labeler
2233import (
44+ "bytes"
55+ "database/sql"
46 "encoding/json"
77+ "errors"
88+ "fmt"
59 "log/slog"
610 "net/http"
711 "strconv"
88- "time"
1212+ "strings"
9131414+ comatproto "github.com/bluesky-social/indigo/api/atproto"
1515+ "github.com/bluesky-social/indigo/events"
1016 "github.com/gorilla/websocket"
1717+ cbg "github.com/whyrusleeping/cbor-gen"
1818+)
1919+2020+const (
2121+ subscriberBuffer = 64
2222+ backfillPageLimit = 200
1123)
12241325var upgrader = websocket.Upgrader{
2626+ // CheckOrigin is permissive: the firehose is a public stream by design and ATProto
2727+ // consumers are not browsers, so the same-origin policy doesn't apply to them anyway.
1428 CheckOrigin: func(r *http.Request) bool { return true },
1529}
16301717-// LabelsMessage is the ATProto subscribeLabels wire format.
1818-type LabelsMessage struct {
1919- Seq int64 `json:"seq"`
2020- Labels []LabelOutput `json:"labels"`
3131+// frameLabels builds the binary frame for a labels event: CBOR-encoded
3232+// {op:1, t:"#labels"} header concatenated with CBOR-encoded {seq, labels:[...]} body.
3333+func frameLabels(seq int64, labels []*comatproto.LabelDefs_Label) ([]byte, error) {
3434+ var buf bytes.Buffer
3535+ w := cbg.NewCborWriter(&buf)
3636+3737+ header := events.EventHeader{Op: events.EvtKindMessage, MsgType: "#labels"}
3838+ if err := header.MarshalCBOR(w); err != nil {
3939+ return nil, fmt.Errorf("marshal header: %w", err)
4040+ }
4141+ body := comatproto.LabelSubscribeLabels_Labels{Seq: seq, Labels: labels}
4242+ if err := body.MarshalCBOR(w); err != nil {
4343+ return nil, fmt.Errorf("marshal body: %w", err)
4444+ }
4545+ return buf.Bytes(), nil
2146}
22472323-// LabelOutput is the ATProto label format for subscribeLabels/queryLabels output.
2424-type LabelOutput struct {
2525- Src string `json:"src"`
2626- URI string `json:"uri"`
2727- CID string `json:"cid,omitempty"`
2828- Val string `json:"val"`
2929- Neg bool `json:"neg"`
3030- Cts string `json:"cts"`
3131- Exp string `json:"exp,omitempty"`
4848+// frameInfo builds the binary frame for an info event: header {op:1, t:"#info"} plus body.
4949+func frameInfo(name, message string) ([]byte, error) {
5050+ var buf bytes.Buffer
5151+ w := cbg.NewCborWriter(&buf)
5252+5353+ header := events.EventHeader{Op: events.EvtKindMessage, MsgType: "#info"}
5454+ if err := header.MarshalCBOR(w); err != nil {
5555+ return nil, err
5656+ }
5757+ body := comatproto.LabelSubscribeLabels_Info{Name: name}
5858+ if message != "" {
5959+ body.Message = &message
6060+ }
6161+ if err := body.MarshalCBOR(w); err != nil {
6262+ return nil, err
6363+ }
6464+ return buf.Bytes(), nil
3265}
33663434-func labelToOutput(l Label) LabelOutput {
3535- out := LabelOutput{
3636- Src: l.Src,
3737- URI: l.URI,
3838- CID: l.CID,
3939- Val: l.Val,
4040- Neg: l.Neg,
4141- Cts: l.Cts.UTC().Format(time.RFC3339),
6767+// frameError builds an error frame: header {op:-1} plus {error, message}.
6868+func frameError(name, message string) ([]byte, error) {
6969+ var buf bytes.Buffer
7070+ w := cbg.NewCborWriter(&buf)
7171+7272+ header := events.EventHeader{Op: events.EvtKindErrorFrame}
7373+ if err := header.MarshalCBOR(w); err != nil {
7474+ return nil, err
4275 }
4343- if l.Exp != nil {
4444- out.Exp = l.Exp.UTC().Format(time.RFC3339)
7676+ body := events.ErrorFrame{Error: name, Message: message}
7777+ if err := body.MarshalCBOR(w); err != nil {
7878+ return nil, err
4579 }
4646- return out
8080+ return buf.Bytes(), nil
8181+}
8282+8383+// labelToLexicon converts a stored row into the indigo lexicon type used in the wire format.
8484+func labelToLexicon(l *Label) *comatproto.LabelDefs_Label {
8585+ tmp := l.ToLabeling()
8686+ lex := tmp.ToLexicon()
8787+ return &lex
4788}
48894949-// handleSubscribeLabels implements com.atproto.label.subscribeLabels (WebSocket).
9090+// handleSubscribeLabels implements com.atproto.label.subscribeLabels.
9191+//
9292+// Wire format: each WebSocket binary message is two concatenated CBOR objects (header
9393+// + body) matching the firehose convention. Backfill pages historical labels since the
9494+// cursor, then the connection joins the broadcast hub for live deliveries.
5095func (s *Server) handleSubscribeLabels(w http.ResponseWriter, r *http.Request) {
5196 cursorStr := r.URL.Query().Get("cursor")
5297 var cursor int64
5398 if cursorStr != "" {
5454- var err error
5555- cursor, err = strconv.ParseInt(cursorStr, 10, 64)
9999+ v, err := strconv.ParseInt(cursorStr, 10, 64)
56100 if err != nil {
57101 http.Error(w, "invalid cursor", http.StatusBadRequest)
58102 return
59103 }
104104+ cursor = v
60105 }
6110662107 conn, err := upgrader.Upgrade(w, r, nil)
···6811369114 slog.Info("subscribeLabels client connected", "cursor", cursor)
701157171- // Send historical labels since cursor
7272- labels, err := GetLabelsSince(s.db, cursor, 1000)
116116+ latest, err := LatestSeq(s.db)
73117 if err != nil {
7474- slog.Error("Failed to get labels", "error", err)
118118+ slog.Error("Failed to read latest seq", "error", err)
75119 return
76120 }
121121+ if cursor > latest {
122122+ if frame, ferr := frameError("FutureCursor", "cursor is in the future"); ferr == nil {
123123+ _ = conn.WriteMessage(websocket.BinaryMessage, frame)
124124+ }
125125+ return
126126+ }
127127+128128+ // Subscribe to the broadcast hub BEFORE backfilling so we don't lose events
129129+ // that arrive while we're streaming the historical tail.
130130+ sub, cancel := s.hub.subscribe(subscriberBuffer)
131131+ defer cancel()
771327878- for _, l := range labels {
7979- msg := LabelsMessage{
8080- Seq: l.ID,
8181- Labels: []LabelOutput{labelToOutput(l)},
133133+ if cursor > 0 {
134134+ if frame, ferr := frameInfo("OutdatedCursor", "starting backfill from cursor"); ferr == nil {
135135+ if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
136136+ return
137137+ }
82138 }
8383- if err := conn.WriteJSON(msg); err != nil {
139139+ }
140140+141141+ // Backfill historical labels in pages until we catch up.
142142+ for {
143143+ labels, err := GetLabelsSince(s.db, cursor, backfillPageLimit)
144144+ if err != nil {
145145+ slog.Error("Failed to read labels for backfill", "error", err)
84146 return
85147 }
8686- cursor = l.ID
148148+ if len(labels) == 0 {
149149+ break
150150+ }
151151+ for i := range labels {
152152+ frame, ferr := frameLabels(labels[i].ID, []*comatproto.LabelDefs_Label{labelToLexicon(&labels[i])})
153153+ if ferr != nil {
154154+ slog.Error("Failed to encode label frame", "error", ferr)
155155+ return
156156+ }
157157+ if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
158158+ return
159159+ }
160160+ cursor = labels[i].ID
161161+ }
162162+ if len(labels) < backfillPageLimit {
163163+ break
164164+ }
87165 }
881668989- // Poll for new labels
9090- ticker := time.NewTicker(5 * time.Second)
9191- defer ticker.Stop()
9292-9393- // Read pump (detect client disconnect)
167167+ // Live delivery: a goroutine monitors the read side so we notice client disconnects;
168168+ // the main loop pulls from the hub and writes frames until either side closes.
94169 done := make(chan struct{})
95170 go func() {
96171 defer close(done)
97172 for {
9898- if _, _, err := conn.ReadMessage(); err != nil {
173173+ if _, _, rerr := conn.ReadMessage(); rerr != nil {
99174 return
100175 }
101176 }
···105180 select {
106181 case <-done:
107182 return
108108- case <-ticker.C:
109109- labels, err := GetLabelsSince(s.db, cursor, 100)
110110- if err != nil {
111111- slog.Error("Failed to poll labels", "error", err)
112112- continue
183183+ case lbl, ok := <-sub.ch:
184184+ if !ok {
185185+ return
113186 }
114114- for _, l := range labels {
115115- msg := LabelsMessage{
116116- Seq: l.ID,
117117- Labels: []LabelOutput{labelToOutput(l)},
118118- }
119119- if err := conn.WriteJSON(msg); err != nil {
120120- return
121121- }
122122- cursor = l.ID
187187+ if lbl.ID <= cursor {
188188+ continue // already delivered during backfill
123189 }
190190+ frame, ferr := frameLabels(lbl.ID, []*comatproto.LabelDefs_Label{labelToLexicon(lbl)})
191191+ if ferr != nil {
192192+ slog.Error("Failed to encode label frame", "error", ferr)
193193+ return
194194+ }
195195+ if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
196196+ return
197197+ }
198198+ cursor = lbl.ID
124199 }
125200 }
126201}
127202128128-// handleQueryLabels implements com.atproto.label.queryLabels (HTTP GET).
203203+// queryLabelsResponse mirrors the lexicon JSON shape for queryLabels.
204204+type queryLabelsResponse struct {
205205+ Cursor string `json:"cursor,omitempty"`
206206+ Labels []*comatproto.LabelDefs_Label `json:"labels"`
207207+}
208208+209209+// handleQueryLabels implements com.atproto.label.queryLabels.
210210+//
211211+// Filters (uriPatterns, sources) are applied in SQL so the LIMIT cap operates on the
212212+// filtered result, not the raw scan. URI patterns support a single trailing `*` glob
213213+// (LIKE), with `%` and `_` escaped to remain literal.
129214func (s *Server) handleQueryLabels(w http.ResponseWriter, r *http.Request) {
130130- uriPatterns := r.URL.Query()["uriPatterns"]
131131- cursorStr := r.URL.Query().Get("cursor")
132132- limitStr := r.URL.Query().Get("limit")
215215+ q := r.URL.Query()
216216+ patterns := q["uriPatterns"]
217217+ sources := q["sources"]
133218134219 var cursor int64
135135- if cursorStr != "" {
136136- cursor, _ = strconv.ParseInt(cursorStr, 10, 64)
220220+ if cs := q.Get("cursor"); cs != "" {
221221+ v, err := strconv.ParseInt(cs, 10, 64)
222222+ if err != nil {
223223+ http.Error(w, "invalid cursor", http.StatusBadRequest)
224224+ return
225225+ }
226226+ cursor = v
137227 }
228228+138229 limit := 50
139139- if limitStr != "" {
140140- if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 250 {
230230+ if ls := q.Get("limit"); ls != "" {
231231+ if l, err := strconv.Atoi(ls); err == nil && l > 0 && l <= 250 {
141232 limit = l
142233 }
143234 }
144235145145- labels, err := GetLabelsSince(s.db, cursor, limit)
236236+ rows, err := queryLabelsSQL(s.db, patterns, sources, cursor, limit)
146237 if err != nil {
147147- http.Error(w, "failed to query labels", http.StatusInternalServerError)
238238+ if errors.Is(err, errInvalidPattern) {
239239+ http.Error(w, "invalid uriPattern: wildcard '*' must be at the end", http.StatusBadRequest)
240240+ return
241241+ }
242242+ slog.Error("queryLabels failed", "error", err)
243243+ http.Error(w, "internal error", http.StatusInternalServerError)
148244 return
149245 }
150246151151- // Filter by URI patterns if provided
152152- var filtered []LabelOutput
153153- for _, l := range labels {
154154- if len(uriPatterns) == 0 || matchesAnyPattern(l.URI, uriPatterns) {
155155- filtered = append(filtered, labelToOutput(l))
247247+ out := &queryLabelsResponse{Labels: make([]*comatproto.LabelDefs_Label, 0, len(rows))}
248248+ for i := range rows {
249249+ out.Labels = append(out.Labels, labelToLexicon(&rows[i]))
250250+ }
251251+ if len(rows) > 0 {
252252+ out.Cursor = strconv.FormatInt(rows[len(rows)-1].ID, 10)
253253+ }
254254+255255+ w.Header().Set("Content-Type", "application/json")
256256+ _ = json.NewEncoder(w).Encode(out)
257257+}
258258+259259+var errInvalidPattern = errors.New("invalid uriPattern")
260260+261261+// queryLabelsSQL builds the WHERE clause from filter args and runs the query. All
262262+// filtering happens in SQL so LIMIT operates on already-filtered rows.
263263+func queryLabelsSQL(db *sql.DB, patterns, sources []string, cursor int64, limit int) ([]Label, error) {
264264+ var (
265265+ where []string
266266+ args []any
267267+ )
268268+ where = append(where, "id > ?")
269269+ args = append(args, cursor)
270270+271271+ if len(patterns) > 0 {
272272+ var ors []string
273273+ var matchAll bool
274274+ for _, p := range patterns {
275275+ if p == "" {
276276+ continue
277277+ }
278278+ if p == "*" {
279279+ matchAll = true
280280+ break
281281+ }
282282+ like, err := patternToLike(p)
283283+ if err != nil {
284284+ return nil, err
285285+ }
286286+ if strings.ContainsAny(like, `%_\`) {
287287+ ors = append(ors, "uri LIKE ? ESCAPE '\\'")
288288+ } else {
289289+ ors = append(ors, "uri = ?")
290290+ }
291291+ args = append(args, like)
292292+ }
293293+ if !matchAll && len(ors) > 0 {
294294+ where = append(where, "("+strings.Join(ors, " OR ")+")")
156295 }
157296 }
158297159159- var nextCursor string
160160- if len(labels) > 0 {
161161- nextCursor = strconv.FormatInt(labels[len(labels)-1].ID, 10)
298298+ if len(sources) > 0 {
299299+ placeholders := strings.Repeat("?,", len(sources))
300300+ placeholders = placeholders[:len(placeholders)-1]
301301+ where = append(where, "src IN ("+placeholders+")")
302302+ for _, s := range sources {
303303+ args = append(args, s)
304304+ }
162305 }
163306164164- resp := struct {
165165- Cursor string `json:"cursor,omitempty"`
166166- Labels []LabelOutput `json:"labels"`
167167- }{
168168- Cursor: nextCursor,
169169- Labels: filtered,
307307+ args = append(args, limit)
308308+ q := `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo
309309+ FROM labels WHERE ` + strings.Join(where, " AND ") +
310310+ ` ORDER BY id ASC LIMIT ?`
311311+312312+ rows, err := db.Query(q, args...)
313313+ if err != nil {
314314+ return nil, err
170315 }
171171- if resp.Labels == nil {
172172- resp.Labels = []LabelOutput{}
173173- }
174174-175175- w.Header().Set("Content-Type", "application/json")
176176- _ = json.NewEncoder(w).Encode(resp)
316316+ defer rows.Close()
317317+ return scanLabels(rows)
177318}
178319179179-func matchesAnyPattern(uri string, patterns []string) bool {
180180- for _, p := range patterns {
181181- // Simple prefix matching (ATProto spec allows glob-like patterns)
182182- if p == uri || (len(p) > 0 && p[len(p)-1] == '*' && len(uri) >= len(p)-1 && uri[:len(p)-1] == p[:len(p)-1]) {
183183- return true
184184- }
320320+// patternToLike converts a uriPattern into a SQLite LIKE expression. The only
321321+// wildcard supported is a trailing `*`, which becomes `%`. Literal `%`, `_`, and `\`
322322+// in the input are escaped via the LIKE ESCAPE clause used at query time.
323323+func patternToLike(p string) (string, error) {
324324+ if idx := strings.Index(p, "*"); idx >= 0 && idx != len(p)-1 {
325325+ return "", errInvalidPattern
326326+ }
327327+ literal := p
328328+ suffix := ""
329329+ if strings.HasSuffix(p, "*") {
330330+ literal = p[:len(p)-1]
331331+ suffix = "%"
185332 }
186186- return false
333333+ literal = strings.ReplaceAll(literal, `\`, `\\`)
334334+ literal = strings.ReplaceAll(literal, `%`, `\%`)
335335+ literal = strings.ReplaceAll(literal, `_`, `\_`)
336336+ return literal + suffix, nil
187337}