···147147 // Update status post to "online" after server starts
148148 if holdPDS != nil {
149149 ctx := context.Background()
150150+150151 if err := holdPDS.SetStatus(ctx, "online"); err != nil {
151152 log.Printf("Warning: Failed to set status post to online: %v", err)
152153 } else {
···33import (
44 "bytes"
55 "context"
66+ "io"
77+ "os"
68 "path/filepath"
79 "strings"
810 "testing"
···1719 ctx := context.Background()
1820 tmpDir := t.TempDir()
19212020- dbPath := filepath.Join(tmpDir, "pds.db")
2222+ // Use in-memory database for speed
2323+ dbPath := ":memory:"
2124 keyPath := filepath.Join(tmpDir, "signing-key")
22252626+ // Copy shared signing key instead of generating a new one
2727+ if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil {
2828+ t.Fatalf("Failed to copy shared signing key: %v", err)
2929+ }
3030+2331 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath)
2432 if err != nil {
2533 t.Fatalf("Failed to create test PDS: %v", err)
···3038 err = pds.repomgr.InitNewActor(ctx, pds.uid, "", pds.did, "", "", "")
3139 if err != nil {
3240 t.Fatalf("Failed to initialize test repo: %v", err)
4141+ }
4242+4343+ return pds, ctx
4444+}
4545+4646+// setupTestPDSWithBootstrap creates a test PDS and bootstraps it with suppressed output
4747+// This is a convenience function for tests that need a fully initialized PDS
4848+func setupTestPDSWithBootstrap(t *testing.T, ownerDID string, public, allowAllCrew bool) (*HoldPDS, context.Context) {
4949+ t.Helper()
5050+5151+ pds, ctx := setupTestPDS(t)
5252+5353+ // Bootstrap with suppressed output
5454+ oldStdout := os.Stdout
5555+ r, w, _ := os.Pipe()
5656+ os.Stdout = w
5757+5858+ err := pds.Bootstrap(ctx, nil, ownerDID, public, allowAllCrew, "")
5959+6060+ w.Close()
6161+ os.Stdout = oldStdout
6262+ io.ReadAll(r) // Drain the pipe
6363+6464+ if err != nil {
6565+ t.Fatalf("Failed to bootstrap PDS: %v", err)
3366 }
34673568 return pds, ctx
-232
pkg/hold/pds/database.go
···11-package pds
22-33-import (
44- "database/sql"
55- "fmt"
66- "os"
77- "path/filepath"
88- "time"
99-1010- _ "github.com/mattn/go-sqlite3"
1111-)
1212-1313-// Database manages app passwords and sessions for the hold PDS
1414-type Database struct {
1515- db *sql.DB
1616-}
1717-1818-// NewDatabase creates or opens a database for app passwords and sessions
1919-// dbPath should be the directory path (same as carstore)
2020-// It creates a separate "auth.db" file for authentication data
2121-func NewDatabase(dbPath string) (*Database, error) {
2222- // Ensure directory exists
2323- if err := os.MkdirAll(dbPath, 0755); err != nil {
2424- return nil, fmt.Errorf("failed to create database directory: %w", err)
2525- }
2626-2727- // Create auth database file alongside carstore database
2828- authDBFile := filepath.Join(dbPath, "auth.db")
2929-3030- db, err := sql.Open("sqlite3", authDBFile)
3131- if err != nil {
3232- return nil, fmt.Errorf("failed to open database: %w", err)
3333- }
3434-3535- // Create tables
3636- if err := createTables(db); err != nil {
3737- db.Close()
3838- return nil, fmt.Errorf("failed to create tables: %w", err)
3939- }
4040-4141- return &Database{db: db}, nil
4242-}
4343-4444-// createTables creates the database schema
4545-func createTables(db *sql.DB) error {
4646- schema := `
4747- CREATE TABLE IF NOT EXISTS app_passwords (
4848- id INTEGER PRIMARY KEY AUTOINCREMENT,
4949- name TEXT NOT NULL UNIQUE,
5050- password_hash TEXT NOT NULL,
5151- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
5252- last_used_at TIMESTAMP
5353- );
5454-5555- CREATE INDEX IF NOT EXISTS idx_app_passwords_name ON app_passwords(name);
5656-5757- CREATE TABLE IF NOT EXISTS refresh_tokens (
5858- id INTEGER PRIMARY KEY AUTOINCREMENT,
5959- token_hash TEXT NOT NULL UNIQUE,
6060- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
6161- expires_at TIMESTAMP NOT NULL,
6262- last_used_at TIMESTAMP
6363- );
6464-6565- CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash);
6666- CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires ON refresh_tokens(expires_at);
6767- `
6868-6969- _, err := db.Exec(schema)
7070- return err
7171-}
7272-7373-// Close closes the database connection
7474-func (d *Database) Close() error {
7575- return d.db.Close()
7676-}
7777-7878-// AppPassword represents an app password record
7979-type AppPassword struct {
8080- ID int64
8181- Name string
8282- PasswordHash string
8383- CreatedAt time.Time
8484- LastUsedAt *time.Time
8585-}
8686-8787-// CreateAppPassword stores a new app password
8888-func (d *Database) CreateAppPassword(name, passwordHash string) error {
8989- query := `INSERT INTO app_passwords (name, password_hash) VALUES (?, ?)`
9090- _, err := d.db.Exec(query, name, passwordHash)
9191- if err != nil {
9292- return fmt.Errorf("failed to create app password: %w", err)
9393- }
9494- return nil
9595-}
9696-9797-// GetAppPassword retrieves an app password by name
9898-func (d *Database) GetAppPassword(name string) (*AppPassword, error) {
9999- query := `SELECT id, name, password_hash, created_at, last_used_at FROM app_passwords WHERE name = ?`
100100-101101- var ap AppPassword
102102- var lastUsedAt sql.NullTime
103103-104104- err := d.db.QueryRow(query, name).Scan(
105105- &ap.ID,
106106- &ap.Name,
107107- &ap.PasswordHash,
108108- &ap.CreatedAt,
109109- &lastUsedAt,
110110- )
111111-112112- if err == sql.ErrNoRows {
113113- return nil, fmt.Errorf("app password not found")
114114- }
115115- if err != nil {
116116- return nil, fmt.Errorf("failed to get app password: %w", err)
117117- }
118118-119119- if lastUsedAt.Valid {
120120- ap.LastUsedAt = &lastUsedAt.Time
121121- }
122122-123123- return &ap, nil
124124-}
125125-126126-// ListAppPasswords returns all app passwords (without hashes)
127127-func (d *Database) ListAppPasswords() ([]AppPassword, error) {
128128- query := `SELECT id, name, created_at, last_used_at FROM app_passwords ORDER BY created_at DESC`
129129-130130- rows, err := d.db.Query(query)
131131- if err != nil {
132132- return nil, fmt.Errorf("failed to list app passwords: %w", err)
133133- }
134134- defer rows.Close()
135135-136136- var passwords []AppPassword
137137- for rows.Next() {
138138- var ap AppPassword
139139- var lastUsedAt sql.NullTime
140140-141141- if err := rows.Scan(&ap.ID, &ap.Name, &ap.CreatedAt, &lastUsedAt); err != nil {
142142- return nil, fmt.Errorf("failed to scan row: %w", err)
143143- }
144144-145145- if lastUsedAt.Valid {
146146- ap.LastUsedAt = &lastUsedAt.Time
147147- }
148148-149149- passwords = append(passwords, ap)
150150- }
151151-152152- return passwords, rows.Err()
153153-}
154154-155155-// UpdateLastUsed updates the last used timestamp for an app password
156156-func (d *Database) UpdateLastUsed(name string) error {
157157- query := `UPDATE app_passwords SET last_used_at = CURRENT_TIMESTAMP WHERE name = ?`
158158- _, err := d.db.Exec(query, name)
159159- return err
160160-}
161161-162162-// DeleteAppPassword removes an app password
163163-func (d *Database) DeleteAppPassword(name string) error {
164164- query := `DELETE FROM app_passwords WHERE name = ?`
165165- result, err := d.db.Exec(query, name)
166166- if err != nil {
167167- return fmt.Errorf("failed to delete app password: %w", err)
168168- }
169169-170170- rows, err := result.RowsAffected()
171171- if err != nil {
172172- return fmt.Errorf("failed to check rows affected: %w", err)
173173- }
174174-175175- if rows == 0 {
176176- return fmt.Errorf("app password not found")
177177- }
178178-179179- return nil
180180-}
181181-182182-// CreateRefreshToken stores a refresh token
183183-func (d *Database) CreateRefreshToken(tokenHash string, expiresAt time.Time) error {
184184- query := `INSERT INTO refresh_tokens (token_hash, expires_at) VALUES (?, ?)`
185185- _, err := d.db.Exec(query, tokenHash, expiresAt)
186186- if err != nil {
187187- return fmt.Errorf("failed to create refresh token: %w", err)
188188- }
189189- return nil
190190-}
191191-192192-// ValidateRefreshToken checks if a refresh token exists and is not expired
193193-func (d *Database) ValidateRefreshToken(tokenHash string) (bool, error) {
194194- query := `SELECT expires_at FROM refresh_tokens WHERE token_hash = ?`
195195-196196- var expiresAt time.Time
197197- err := d.db.QueryRow(query, tokenHash).Scan(&expiresAt)
198198-199199- if err == sql.ErrNoRows {
200200- return false, nil
201201- }
202202- if err != nil {
203203- return false, fmt.Errorf("failed to validate refresh token: %w", err)
204204- }
205205-206206- // Check if expired
207207- if time.Now().After(expiresAt) {
208208- // Delete expired token
209209- d.DeleteRefreshToken(tokenHash)
210210- return false, nil
211211- }
212212-213213- // Update last used
214214- updateQuery := `UPDATE refresh_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token_hash = ?`
215215- d.db.Exec(updateQuery, tokenHash)
216216-217217- return true, nil
218218-}
219219-220220-// DeleteRefreshToken removes a refresh token
221221-func (d *Database) DeleteRefreshToken(tokenHash string) error {
222222- query := `DELETE FROM refresh_tokens WHERE token_hash = ?`
223223- _, err := d.db.Exec(query, tokenHash)
224224- return err
225225-}
226226-227227-// CleanupExpiredTokens removes all expired refresh tokens
228228-func (d *Database) CleanupExpiredTokens() error {
229229- query := `DELETE FROM refresh_tokens WHERE expires_at < CURRENT_TIMESTAMP`
230230- _, err := d.db.Exec(query)
231231- return err
232232-}
+71-16
pkg/hold/pds/events.go
···88 "time"
991010 atproto "github.com/bluesky-social/indigo/api/atproto"
1111+ "github.com/bluesky-social/indigo/events"
1112 lexutil "github.com/bluesky-social/indigo/lex/util"
1213 "github.com/gorilla/websocket"
1414+ "github.com/ipfs/go-cid"
1315)
14161517// EventBroadcaster manages WebSocket connections and broadcasts repo events
···7779 b.mu.Unlock()
78807981 // Send historical events if cursor is provided and < current seq
8080- if cursor > 0 && cursor < currentSeq {
8282+ // cursor=0 means "replay all events from the beginning"
8383+ // cursor >= 0 triggers backfill, negative cursor means "no backfill"
8484+ if cursor >= 0 && cursor < currentSeq {
8185 go b.backfillSubscriber(sub, cursor)
8286 }
8387···205209 }()
206210207211 for event := range sub.send {
208208- // Encode as CBOR
209209- cborBytes, err := encodeCBOR(event)
212212+ // Create event header (ATProto firehose format)
213213+ header := events.EventHeader{
214214+ Op: events.EvtKindMessage,
215215+ MsgType: "#commit",
216216+ }
217217+218218+ // Get a writer for this message
219219+ wc, err := sub.conn.NextWriter(websocket.BinaryMessage)
210220 if err != nil {
211211- log.Printf("Failed to encode event as CBOR: %v", err)
212212- continue
221221+ log.Printf("Failed to get websocket writer: %v", err)
222222+ return
213223 }
214224215215- // Write CBOR message to WebSocket
216216- err = sub.conn.WriteMessage(websocket.BinaryMessage, cborBytes)
217217- if err != nil {
218218- log.Printf("Failed to write to websocket: %v", err)
225225+ // Write header as CBOR
226226+ if err := header.MarshalCBOR(wc); err != nil {
227227+ log.Printf("Failed to write event header: %v", err)
228228+ wc.Close()
229229+ return
230230+ }
231231+232232+ // Convert our RepoCommitEvent to indigo's SyncSubscribeRepos_Commit
233233+ indigoEvent := convertToIndigoCommit(event)
234234+235235+ // Write the event as CBOR
236236+ var obj lexutil.CBOR = indigoEvent
237237+ if err := obj.MarshalCBOR(wc); err != nil {
238238+ log.Printf("Failed to write event body: %v", err)
239239+ wc.Close()
240240+ return
241241+ }
242242+243243+ // Close the writer to flush the message
244244+ if err := wc.Close(); err != nil {
245245+ log.Printf("Failed to close websocket writer: %v", err)
219246 return
220247 }
221248···224251 }
225252}
226253227227-// encodeCBOR encodes an event as CBOR
254254+// convertToIndigoCommit converts our RepoCommitEvent to indigo's SyncSubscribeRepos_Commit
255255+// which has proper CBOR marshaling methods generated
256256+func convertToIndigoCommit(event *RepoCommitEvent) *atproto.SyncSubscribeRepos_Commit {
257257+ // Parse commit CID string to cid.Cid, then convert to LexLink
258258+ commitCID, err := cid.Decode(event.Commit)
259259+ if err != nil {
260260+ log.Printf("Warning: failed to parse commit CID %s: %v", event.Commit, err)
261261+ // Create an empty CID as fallback
262262+ commitCID = cid.Undef
263263+ }
264264+265265+ // Convert cid.Cid to LexLink
266266+ commitLink := lexutil.LexLink(commitCID)
267267+268268+ // Convert blocks to LexBytes
269269+ blocks := lexutil.LexBytes(event.Blocks)
270270+271271+ return &atproto.SyncSubscribeRepos_Commit{
272272+ Seq: event.Seq,
273273+ Repo: event.Repo,
274274+ Commit: commitLink,
275275+ Rev: event.Rev,
276276+ Since: event.Since,
277277+ Blocks: blocks,
278278+ Ops: event.Ops,
279279+ Time: event.Time,
280280+ Blobs: []lexutil.LexLink{}, // Empty for now, we don't track blob refs in our simplified model
281281+ Rebase: false, // DEPRECATED field
282282+ TooBig: false, // Not implementing tooBig for now
283283+ }
284284+}
285285+286286+// encodeCBOR encodes an event as CBOR (DEPRECATED - kept for tests)
228287func encodeCBOR(event *RepoCommitEvent) ([]byte, error) {
229229- // For now, use JSON encoding wrapped in CBOR envelope
230230- // In production, you'd use proper CBOR encoding
231231- // The atproto spec requires DAG-CBOR with specific header
232232-233233- // Simple approach: encode as JSON for MVP
234234- // Real implementation needs proper CBOR-gen types
288288+ // For backward compatibility with tests, encode as JSON
289289+ // Production code uses convertToIndigoCommit + CBOR marshaling in handleSubscriber
235290 return json.Marshal(event)
236291}
237292
+164
pkg/hold/pds/events_test.go
···382382 t.Errorf("Expected decoded seq=1, got %d", decoded.Seq)
383383 }
384384}
385385+386386+// TestSubscribe_CursorZeroBackfill tests that cursor=0 replays all events
387387+func TestSubscribe_CursorZeroBackfill(t *testing.T) {
388388+ broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100)
389389+ ctx := context.Background()
390390+391391+ testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke")
392392+393393+ // Broadcast 5 events before subscribing
394394+ for i := 1; i <= 5; i++ {
395395+ event := &RepoEvent{
396396+ NewRoot: testCID,
397397+ Rev: "test-rev",
398398+ RepoSlice: []byte("test CAR data"),
399399+ Ops: []RepoOp{},
400400+ }
401401+ broadcaster.Broadcast(ctx, event)
402402+ }
403403+404404+ // Verify we have 5 events in history
405405+ if broadcaster.eventSeq != 5 {
406406+ t.Fatalf("Expected eventSeq=5, got %d", broadcaster.eventSeq)
407407+ }
408408+409409+ // Create mock websocket connection (we won't actually use it)
410410+ // We just need to verify backfillSubscriber is called
411411+ // For this test, we'll check the history directly
412412+ if len(broadcaster.eventHistory) != 5 {
413413+ t.Errorf("Expected 5 events in history, got %d", len(broadcaster.eventHistory))
414414+ }
415415+416416+ // Verify all events have sequential sequence numbers
417417+ for i, he := range broadcaster.eventHistory {
418418+ expectedSeq := int64(i + 1)
419419+ if he.Seq != expectedSeq {
420420+ t.Errorf("Expected history[%d].Seq=%d, got %d", i, expectedSeq, he.Seq)
421421+ }
422422+ }
423423+424424+ // Test backfillSubscriber directly with cursor=0
425425+ // Create a subscriber manually (conn not needed for backfill test)
426426+ sub := &Subscriber{
427427+ conn: nil, // Not used in backfillSubscriber
428428+ send: make(chan *RepoCommitEvent, 100), // Large buffer for testing
429429+ cursor: 0,
430430+ }
431431+432432+ // Run backfill in a goroutine
433433+ go broadcaster.backfillSubscriber(sub, 0)
434434+435435+ // Wait for events to be sent
436436+ time.Sleep(100 * time.Millisecond)
437437+438438+ // Should receive all 5 events
439439+ receivedCount := len(sub.send)
440440+ if receivedCount != 5 {
441441+ t.Errorf("Expected to receive 5 events with cursor=0, got %d", receivedCount)
442442+ }
443443+444444+ // Verify events are in order
445445+ for i := 1; i <= 5; i++ {
446446+ select {
447447+ case event := <-sub.send:
448448+ if event.Seq != int64(i) {
449449+ t.Errorf("Expected event seq=%d, got %d", i, event.Seq)
450450+ }
451451+ default:
452452+ t.Errorf("Expected event %d but channel was empty", i)
453453+ }
454454+ }
455455+}
456456+457457+// TestSubscribe_MidCursorBackfill tests that cursor=N only gets events after N
458458+func TestSubscribe_MidCursorBackfill(t *testing.T) {
459459+ broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100)
460460+ ctx := context.Background()
461461+462462+ testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke")
463463+464464+ // Broadcast 10 events before subscribing
465465+ for i := 1; i <= 10; i++ {
466466+ event := &RepoEvent{
467467+ NewRoot: testCID,
468468+ Rev: "test-rev",
469469+ RepoSlice: []byte("test CAR data"),
470470+ Ops: []RepoOp{},
471471+ }
472472+ broadcaster.Broadcast(ctx, event)
473473+ }
474474+475475+ // Test backfillSubscriber with cursor=5 (conn not needed for backfill test)
476476+ sub := &Subscriber{
477477+ conn: nil, // Not used in backfillSubscriber
478478+ send: make(chan *RepoCommitEvent, 100), // Large buffer for testing
479479+ cursor: 5,
480480+ }
481481+482482+ // Run backfill
483483+ go broadcaster.backfillSubscriber(sub, 5)
484484+485485+ // Wait for events to be sent
486486+ time.Sleep(100 * time.Millisecond)
487487+488488+ // Should receive events 6-10 (5 events after cursor=5)
489489+ receivedCount := len(sub.send)
490490+ if receivedCount != 5 {
491491+ t.Errorf("Expected to receive 5 events with cursor=5, got %d", receivedCount)
492492+ }
493493+494494+ // Verify events start at seq=6
495495+ for i := 6; i <= 10; i++ {
496496+ select {
497497+ case event := <-sub.send:
498498+ if event.Seq != int64(i) {
499499+ t.Errorf("Expected event seq=%d, got %d", i, event.Seq)
500500+ }
501501+ default:
502502+ t.Errorf("Expected event %d but channel was empty", i)
503503+ }
504504+ }
505505+}
506506+507507+// TestSubscribe_NegativeCursorNoBackfill tests that negative cursor means no backfill
508508+func TestSubscribe_NegativeCursorNoBackfill(t *testing.T) {
509509+ broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100)
510510+ ctx := context.Background()
511511+512512+ testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke")
513513+514514+ // Broadcast 5 events before subscribing
515515+ for i := 1; i <= 5; i++ {
516516+ event := &RepoEvent{
517517+ NewRoot: testCID,
518518+ Rev: "test-rev",
519519+ RepoSlice: []byte("test CAR data"),
520520+ Ops: []RepoOp{},
521521+ }
522522+ broadcaster.Broadcast(ctx, event)
523523+ }
524524+525525+ // Create subscriber with cursor=-1 (no backfill, conn not needed)
526526+ sub := &Subscriber{
527527+ conn: nil, // Not used in this test
528528+ send: make(chan *RepoCommitEvent, 100),
529529+ cursor: -1,
530530+ }
531531+532532+ // Subscribe should not trigger backfill
533533+ broadcaster.mu.Lock()
534534+ currentSeq := broadcaster.eventSeq
535535+ broadcaster.mu.Unlock()
536536+537537+ // Check the condition: cursor >= 0 && cursor < currentSeq
538538+ // For cursor=-1, this should be false
539539+ shouldBackfill := -1 >= 0 && -1 < currentSeq
540540+ if shouldBackfill {
541541+ t.Error("Expected shouldBackfill=false for cursor=-1, but condition evaluated to true")
542542+ }
543543+544544+ // Verify no events in send channel (no backfill happened)
545545+ if len(sub.send) != 0 {
546546+ t.Errorf("Expected 0 events with cursor=-1 (no backfill), got %d", len(sub.send))
547547+ }
548548+}
-249
pkg/hold/pds/jwt.go
···11-package pds
22-33-import (
44- "crypto/sha256"
55- "encoding/base64"
66- "encoding/hex"
77- "encoding/json"
88- "fmt"
99- "time"
1010-)
1111-1212-// Session token types
1313-const (
1414- TokenTypeAccess = "access"
1515- TokenTypeRefresh = "refresh"
1616-)
1717-1818-// Token expiration durations
1919-const (
2020- AccessTokenDuration = 2 * time.Hour // Short-lived access token
2121- RefreshTokenDuration = 90 * 24 * time.Hour // Long-lived refresh token (90 days)
2222-)
2323-2424-// SessionClaims represents JWT claims for ATProto sessions
2525-type SessionClaims struct {
2626- DID string `json:"sub"` // Subject (DID)
2727- Issuer string `json:"iss"` // Issuer (PDS DID)
2828- Handle string `json:"handle,omitempty"`
2929- Scope string `json:"scope"`
3030- TokenType string `json:"token_type"`
3131- IssuedAt int64 `json:"iat"` // Unix timestamp
3232- ExpiresAt int64 `json:"exp"` // Unix timestamp
3333-}
3434-3535-// IssueAccessToken creates a new access JWT for a session
3636-func (p *HoldPDS) IssueAccessToken(did, handle string) (string, error) {
3737- now := time.Now()
3838- claims := &SessionClaims{
3939- DID: did,
4040- Issuer: p.did,
4141- Handle: handle,
4242- Scope: "com.atproto.access",
4343- TokenType: TokenTypeAccess,
4444- IssuedAt: now.Unix(),
4545- ExpiresAt: now.Add(AccessTokenDuration).Unix(),
4646- }
4747-4848- return p.signJWT(claims)
4949-}
5050-5151-// IssueRefreshToken creates a new refresh JWT for a session
5252-func (p *HoldPDS) IssueRefreshToken(did, handle string) (string, error) {
5353- now := time.Now()
5454- claims := &SessionClaims{
5555- DID: did,
5656- Issuer: p.did,
5757- Handle: handle,
5858- Scope: "com.atproto.refresh",
5959- TokenType: TokenTypeRefresh,
6060- IssuedAt: now.Unix(),
6161- ExpiresAt: now.Add(RefreshTokenDuration).Unix(),
6262- }
6363-6464- signedToken, err := p.signJWT(claims)
6565- if err != nil {
6666- return "", err
6767- }
6868-6969- // Store refresh token hash in database for validation/revocation
7070- tokenHash := hashToken(signedToken)
7171- expiresAt := now.Add(RefreshTokenDuration)
7272- if err := p.authDB.CreateRefreshToken(tokenHash, expiresAt); err != nil {
7373- return "", fmt.Errorf("failed to store refresh token: %w", err)
7474- }
7575-7676- return signedToken, nil
7777-}
7878-7979-// ValidateAccessToken validates an access JWT and returns the claims
8080-func (p *HoldPDS) ValidateAccessToken(tokenString string) (*SessionClaims, error) {
8181- return p.validateToken(tokenString, TokenTypeAccess)
8282-}
8383-8484-// ValidateRefreshToken validates a refresh JWT and returns the claims
8585-// Also checks the database to ensure the token hasn't been revoked
8686-func (p *HoldPDS) ValidateRefreshToken(tokenString string) (*SessionClaims, error) {
8787- // First validate signature and claims
8888- claims, err := p.validateToken(tokenString, TokenTypeRefresh)
8989- if err != nil {
9090- return nil, err
9191- }
9292-9393- // Check if token is in database (not revoked)
9494- tokenHash := hashToken(tokenString)
9595- valid, err := p.authDB.ValidateRefreshToken(tokenHash)
9696- if err != nil {
9797- return nil, fmt.Errorf("failed to validate refresh token in database: %w", err)
9898- }
9999- if !valid {
100100- return nil, fmt.Errorf("refresh token has been revoked or expired")
101101- }
102102-103103- return claims, nil
104104-}
105105-106106-// validateToken validates a JWT token and returns the claims
107107-func (p *HoldPDS) validateToken(tokenString, expectedType string) (*SessionClaims, error) {
108108- // Split token into parts
109109- parts := splitJWT(tokenString)
110110- if len(parts) != 3 {
111111- return nil, fmt.Errorf("invalid JWT format")
112112- }
113113-114114- // Decode header
115115- headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
116116- if err != nil {
117117- return nil, fmt.Errorf("failed to decode header: %w", err)
118118- }
119119-120120- var header map[string]interface{}
121121- if err := json.Unmarshal(headerBytes, &header); err != nil {
122122- return nil, fmt.Errorf("failed to parse header: %w", err)
123123- }
124124-125125- // Verify algorithm
126126- alg, ok := header["alg"].(string)
127127- if !ok || alg != "ES256K" {
128128- return nil, fmt.Errorf("unsupported algorithm: %v", alg)
129129- }
130130-131131- // Decode claims
132132- claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
133133- if err != nil {
134134- return nil, fmt.Errorf("failed to decode claims: %w", err)
135135- }
136136-137137- var claims SessionClaims
138138- if err := json.Unmarshal(claimsBytes, &claims); err != nil {
139139- return nil, fmt.Errorf("failed to parse claims: %w", err)
140140- }
141141-142142- // Verify token type
143143- if claims.TokenType != expectedType {
144144- return nil, fmt.Errorf("invalid token type: expected %s, got %s", expectedType, claims.TokenType)
145145- }
146146-147147- // Verify issuer
148148- if claims.Issuer != p.did {
149149- return nil, fmt.Errorf("invalid issuer: expected %s, got %s", p.did, claims.Issuer)
150150- }
151151-152152- // Verify subject matches this hold
153153- if claims.DID != p.did {
154154- return nil, fmt.Errorf("invalid subject: expected %s, got %s", p.did, claims.DID)
155155- }
156156-157157- // Verify expiration
158158- if time.Now().Unix() > claims.ExpiresAt {
159159- return nil, fmt.Errorf("token has expired")
160160- }
161161-162162- // Verify signature
163163- signedData := []byte(parts[0] + "." + parts[1])
164164- signature, err := base64.RawURLEncoding.DecodeString(parts[2])
165165- if err != nil {
166166- return nil, fmt.Errorf("failed to decode signature: %w", err)
167167- }
168168-169169- publicKey, err := p.signingKey.PublicKey()
170170- if err != nil {
171171- return nil, fmt.Errorf("failed to get public key: %w", err)
172172- }
173173-174174- if err := publicKey.HashAndVerify(signedData, signature); err != nil {
175175- return nil, fmt.Errorf("signature verification failed: %w", err)
176176- }
177177-178178- return &claims, nil
179179-}
180180-181181-// RevokeRefreshToken revokes a refresh token by removing it from the database
182182-func (p *HoldPDS) RevokeRefreshToken(tokenString string) error {
183183- tokenHash := hashToken(tokenString)
184184- return p.authDB.DeleteRefreshToken(tokenHash)
185185-}
186186-187187-// hashToken creates a SHA-256 hash of a token for storage
188188-func hashToken(token string) string {
189189- hash := sha256.Sum256([]byte(token))
190190- return hex.EncodeToString(hash[:])
191191-}
192192-193193-// signJWT creates and signs a JWT using the hold's private key
194194-func (p *HoldPDS) signJWT(claims *SessionClaims) (string, error) {
195195- // Create header
196196- header := map[string]interface{}{
197197- "typ": "JWT",
198198- "alg": "ES256K",
199199- }
200200-201201- headerJSON, err := json.Marshal(header)
202202- if err != nil {
203203- return "", fmt.Errorf("failed to marshal header: %w", err)
204204- }
205205-206206- // Create payload
207207- payloadJSON, err := json.Marshal(claims)
208208- if err != nil {
209209- return "", fmt.Errorf("failed to marshal claims: %w", err)
210210- }
211211-212212- // Base64url encode header and payload
213213- headerEncoded := base64.RawURLEncoding.EncodeToString(headerJSON)
214214- payloadEncoded := base64.RawURLEncoding.EncodeToString(payloadJSON)
215215-216216- // Create signing input
217217- signingInput := headerEncoded + "." + payloadEncoded
218218-219219- // Sign with private key
220220- signature, err := p.signingKey.HashAndSign([]byte(signingInput))
221221- if err != nil {
222222- return "", fmt.Errorf("failed to sign JWT: %w", err)
223223- }
224224-225225- // Base64url encode signature
226226- signatureEncoded := base64.RawURLEncoding.EncodeToString(signature)
227227-228228- // Combine into final JWT
229229- jwt := signingInput + "." + signatureEncoded
230230-231231- return jwt, nil
232232-}
233233-234234-// splitJWT splits a JWT string into its three parts
235235-func splitJWT(token string) []string {
236236- // JWT format: header.payload.signature
237237- parts := make([]string, 0, 3)
238238- start := 0
239239- for i := 0; i < len(token); i++ {
240240- if token[i] == '.' {
241241- parts = append(parts, token[start:i])
242242- start = i + 1
243243- }
244244- }
245245- if start < len(token) {
246246- parts = append(parts, token[start:])
247247- }
248248- return parts
249249-}
+21-67
pkg/hold/pds/server.go
···3636 dbPath string
3737 uid models.Uid
3838 signingKey *atcrypto.PrivateKeyK256
3939- authDB *Database // Authentication database for app passwords and sessions
4039}
41404241// NewHoldPDS creates or opens a hold PDS with SQLite carstore
4342func NewHoldPDS(ctx context.Context, did, publicURL, dbPath, keyPath string) (*HoldPDS, error) {
4444- // Ensure directory exists
4545- dir := filepath.Dir(dbPath)
4646- if err := os.MkdirAll(dir, 0755); err != nil {
4747- return nil, fmt.Errorf("failed to create database directory: %w", err)
4848- }
4949-5043 // Generate or load signing key
5144 signingKey, err := GenerateOrLoadKey(keyPath)
5245 if err != nil {
5346 return nil, fmt.Errorf("failed to initialize signing key: %w", err)
5447 }
55485656- // Create and open SQLite-backed carstore
5757- // dbPath is the directory, carstore creates and opens db.sqlite3 inside it
5858- sqlStore, err := carstore.NewSqliteStore(dbPath)
5959- if err != nil {
6060- return nil, fmt.Errorf("failed to create sqlite store: %w", err)
4949+ // Create SQLite-backed carstore
5050+ var sqlStore *carstore.SQLiteStore
5151+5252+ if dbPath == ":memory:" {
5353+ // In-memory mode for tests: create carstore manually and open with :memory:
5454+ sqlStore = new(carstore.SQLiteStore)
5555+ if err := sqlStore.Open(":memory:"); err != nil {
5656+ return nil, fmt.Errorf("failed to open in-memory sqlite store: %w", err)
5757+ }
5858+ } else {
5959+ // File mode for production: create directory and use NewSqliteStore
6060+ dir := filepath.Dir(dbPath)
6161+ if err := os.MkdirAll(dir, 0755); err != nil {
6262+ return nil, fmt.Errorf("failed to create database directory: %w", err)
6363+ }
6464+6565+ // dbPath is the directory, carstore creates and opens db.sqlite3 inside it
6666+ sqlStore, err = carstore.NewSqliteStore(dbPath)
6767+ if err != nil {
6868+ return nil, fmt.Errorf("failed to create sqlite store: %w", err)
6969+ }
6170 }
62716372 // Use SQLiteStore directly, not the CarStore() wrapper
···8493 fmt.Printf("New hold repo - will be initialized in Bootstrap\n")
8594 }
86958787- // Create or open authentication database
8888- authDB, err := NewDatabase(dbPath)
8989- if err != nil {
9090- return nil, fmt.Errorf("failed to create auth database: %w", err)
9191- }
9292-9396 return &HoldPDS{
9497 did: did,
9598 PublicURL: publicURL,
···98101 dbPath: dbPath,
99102 uid: uid,
100103 signingKey: signingKey,
101101- authDB: authDB,
102104 }, nil
103105}
104106···184186 } else {
185187 fmt.Printf("✅ Bluesky profile record already exists, skipping\n")
186188 }
187187-188188- // Create Tangled profile record (idempotent - check if exists first)
189189- _, _, err = p.GetTangledProfileRecord(ctx)
190190- if err != nil {
191191- // Tangled profile doesn't exist, create it
192192- description := "ahoy from the cargo hold"
193193- links := []string{"https://atcr.io"}
194194-195195- _, err = p.CreateTangledProfileRecord(ctx, links, description)
196196- if err != nil {
197197- return fmt.Errorf("failed to create tangled profile record: %w", err)
198198- }
199199- fmt.Printf("✅ Created Tangled profile record\n")
200200- } else {
201201- fmt.Printf("✅ Tangled profile record already exists, skipping\n")
202202- }
203203- }
204204-205205- // Create bootstrap app password if none exist (one-time setup)
206206- passwords, err := p.authDB.ListAppPasswords()
207207- if err != nil {
208208- return fmt.Errorf("failed to list app passwords: %w", err)
209209- }
210210-211211- if len(passwords) == 0 {
212212- // No app passwords exist, create one
213213- password, err := p.CreateAppPassword("bootstrap")
214214- if err != nil {
215215- return fmt.Errorf("failed to create bootstrap app password: %w", err)
216216- }
217217-218218- fmt.Printf("\n")
219219- fmt.Printf("╔════════════════════════════════════════════════════════════════╗\n")
220220- fmt.Printf("║ 🔑 APP PASSWORD CREATED ║\n")
221221- fmt.Printf("╠════════════════════════════════════════════════════════════════╣\n")
222222- fmt.Printf("║ ║\n")
223223- fmt.Printf("║ Password: %-51s ║\n", password)
224224- fmt.Printf("║ ║\n")
225225- fmt.Printf("║ ⚠️ SAVE THIS PASSWORD - it will not be shown again ║\n")
226226- fmt.Printf("║ ║\n")
227227- fmt.Printf("║ Use this password to log into Bluesky app or CLI tools ║\n")
228228- fmt.Printf("║ PDS URL: %-51s ║\n", p.PublicURL)
229229- fmt.Printf("║ Username: %-50s ║\n", p.did)
230230- fmt.Printf("║ ║\n")
231231- fmt.Printf("╚════════════════════════════════════════════════════════════════╝\n")
232232- fmt.Printf("\n")
233233- } else {
234234- fmt.Printf("✅ App passwords already exist (count: %d), skipping auto-generation\n", len(passwords))
235189 }
236190237191 return nil
-381
pkg/hold/pds/session.go
···11-package pds
22-33-import (
44- "encoding/json"
55- "fmt"
66- "io"
77- "net/http"
88- "strings"
99-1010- cbg "github.com/whyrusleeping/cbor-gen"
1111- "github.com/ipfs/go-cid"
1212-)
1313-1414-// CreateSessionRequest represents a session creation request
1515-type CreateSessionRequest struct {
1616- Identifier string `json:"identifier"` // DID or handle
1717- Password string `json:"password"` // App password
1818-}
1919-2020-// CreateSessionResponse represents a successful session creation
2121-type CreateSessionResponse struct {
2222- AccessJwt string `json:"accessJwt"`
2323- RefreshJwt string `json:"refreshJwt"`
2424- Handle string `json:"handle"`
2525- DID string `json:"did"`
2626- DIDDoc map[string]interface{} `json:"didDoc,omitempty"` // Optional DID document
2727- Email string `json:"email,omitempty"` // Optional, not used for holds
2828- Active *bool `json:"active,omitempty"` // Optional account status
2929- Status string `json:"status,omitempty"` // Optional account status
3030-}
3131-3232-// SessionInfo represents session information
3333-type SessionInfo struct {
3434- Handle string `json:"handle"`
3535- DID string `json:"did"`
3636- Email string `json:"email,omitempty"`
3737-}
3838-3939-// HandleCreateSession handles com.atproto.server.createSession
4040-func (h *XRPCHandler) HandleCreateSession(w http.ResponseWriter, r *http.Request) {
4141- // Parse request
4242- var req CreateSessionRequest
4343- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
4444- http.Error(w, fmt.Sprintf("invalid request body: %v", err), http.StatusBadRequest)
4545- return
4646- }
4747-4848- // Validate required fields
4949- if req.Identifier == "" || req.Password == "" {
5050- http.Error(w, "identifier and password are required", http.StatusBadRequest)
5151- return
5252- }
5353-5454- // Validate identifier matches this hold's DID or handle
5555- holdDID := h.pds.DID()
5656- // For did:web, handle is the domain part without "did:web:" prefix
5757- // e.g., "did:web:hold01.atcr.io" -> "hold01.atcr.io"
5858- holdHandle := strings.TrimPrefix(holdDID, "did:web:")
5959-6060- // Normalize the identifier (strip "at://" prefix if present)
6161- identifier := strings.TrimPrefix(req.Identifier, "at://")
6262-6363- // Accept any of:
6464- // 1. Full DID: "did:web:hold01.atcr.io"
6565- // 2. Handle (domain): "hold01.atcr.io"
6666- // 3. Either with "at://" prefix
6767- isValidIdentifier := identifier == holdDID || identifier == holdHandle
6868-6969- if !isValidIdentifier {
7070- fmt.Printf("Invalid identifier: got %q, expected DID %q or handle %q\n", req.Identifier, holdDID, holdHandle)
7171- http.Error(w, "invalid identifier", http.StatusUnauthorized)
7272- return
7373- }
7474-7575- // Validate app password
7676- _, err := h.pds.ValidateAnyAppPassword(req.Password)
7777- if err != nil {
7878- http.Error(w, "invalid password", http.StatusUnauthorized)
7979- return
8080- }
8181-8282- // Issue access and refresh tokens
8383- accessToken, err := h.pds.IssueAccessToken(holdDID, holdHandle)
8484- if err != nil {
8585- http.Error(w, fmt.Sprintf("failed to issue access token: %v", err), http.StatusInternalServerError)
8686- return
8787- }
8888-8989- refreshToken, err := h.pds.IssueRefreshToken(holdDID, holdHandle)
9090- if err != nil {
9191- http.Error(w, fmt.Sprintf("failed to issue refresh token: %v", err), http.StatusInternalServerError)
9292- return
9393- }
9494-9595- // Return session response
9696- active := true
9797- response := CreateSessionResponse{
9898- AccessJwt: accessToken,
9999- RefreshJwt: refreshToken,
100100- Handle: holdHandle,
101101- DID: holdDID,
102102- Active: &active, // Account is active
103103- }
104104-105105- w.Header().Set("Content-Type", "application/json")
106106- json.NewEncoder(w).Encode(response)
107107-}
108108-109109-// HandleRefreshSession handles com.atproto.server.refreshSession
110110-func (h *XRPCHandler) HandleRefreshSession(w http.ResponseWriter, r *http.Request) {
111111- // Extract refresh token from Authorization header
112112- authHeader := r.Header.Get("Authorization")
113113- if authHeader == "" {
114114- http.Error(w, "authorization header required", http.StatusUnauthorized)
115115- return
116116- }
117117-118118- // Remove "Bearer " prefix
119119- refreshToken := strings.TrimPrefix(authHeader, "Bearer ")
120120- if refreshToken == authHeader {
121121- http.Error(w, "invalid authorization header format", http.StatusUnauthorized)
122122- return
123123- }
124124-125125- // Validate refresh token
126126- claims, err := h.pds.ValidateRefreshToken(refreshToken)
127127- if err != nil {
128128- http.Error(w, fmt.Sprintf("invalid refresh token: %v", err), http.StatusUnauthorized)
129129- return
130130- }
131131-132132- // Issue new access token (and optionally new refresh token)
133133- accessToken, err := h.pds.IssueAccessToken(claims.DID, claims.Handle)
134134- if err != nil {
135135- http.Error(w, fmt.Sprintf("failed to issue access token: %v", err), http.StatusInternalServerError)
136136- return
137137- }
138138-139139- // Issue new refresh token (rotate refresh tokens for security)
140140- newRefreshToken, err := h.pds.IssueRefreshToken(claims.DID, claims.Handle)
141141- if err != nil {
142142- http.Error(w, fmt.Sprintf("failed to issue refresh token: %v", err), http.StatusInternalServerError)
143143- return
144144- }
145145-146146- // Revoke old refresh token
147147- if err := h.pds.RevokeRefreshToken(refreshToken); err != nil {
148148- // Log but don't fail - new tokens are already issued
149149- fmt.Printf("Warning: failed to revoke old refresh token: %v\n", err)
150150- }
151151-152152- // Return new tokens
153153- active := true
154154- response := CreateSessionResponse{
155155- AccessJwt: accessToken,
156156- RefreshJwt: newRefreshToken,
157157- Handle: claims.Handle,
158158- DID: claims.DID,
159159- Active: &active, // Account is active
160160- }
161161-162162- w.Header().Set("Content-Type", "application/json")
163163- json.NewEncoder(w).Encode(response)
164164-}
165165-166166-// HandleGetSession handles com.atproto.server.getSession
167167-func (h *XRPCHandler) HandleGetSession(w http.ResponseWriter, r *http.Request) {
168168- // Extract access token from Authorization header
169169- authHeader := r.Header.Get("Authorization")
170170- if authHeader == "" {
171171- http.Error(w, "authorization header required", http.StatusUnauthorized)
172172- return
173173- }
174174-175175- // Remove "Bearer " prefix
176176- accessToken := strings.TrimPrefix(authHeader, "Bearer ")
177177- if accessToken == authHeader {
178178- http.Error(w, "invalid authorization header format", http.StatusUnauthorized)
179179- return
180180- }
181181-182182- // Validate access token
183183- claims, err := h.pds.ValidateAccessToken(accessToken)
184184- if err != nil {
185185- http.Error(w, fmt.Sprintf("invalid access token: %v", err), http.StatusUnauthorized)
186186- return
187187- }
188188-189189- // Return session info
190190- response := SessionInfo{
191191- Handle: claims.Handle,
192192- DID: claims.DID,
193193- }
194194-195195- w.Header().Set("Content-Type", "application/json")
196196- json.NewEncoder(w).Encode(response)
197197-}
198198-199199-// CreateRecordRequest represents a record creation request
200200-type CreateRecordRequest struct {
201201- Repo string `json:"repo"` // DID of the repository
202202- Collection string `json:"collection"` // Collection name (e.g., "app.bsky.feed.post")
203203- Rkey string `json:"rkey,omitempty"` // Optional record key (TID generated if not provided)
204204- Validate *bool `json:"validate,omitempty"` // Optional validation flag
205205- Record interface{} `json:"record"` // The record value (JSON object)
206206-}
207207-208208-// CreateRecordResponse represents a successful record creation
209209-type CreateRecordResponse struct {
210210- URI string `json:"uri"` // at://did/collection/rkey
211211- CID string `json:"cid"` // Record CID
212212-}
213213-214214-// RawRecord wraps a record value and implements CBORMarshaler
215215-// This allows us to accept any JSON record and marshal it to CBOR
216216-type RawRecord struct {
217217- Value map[string]interface{}
218218-}
219219-220220-// MarshalCBOR implements CBORMarshaler for RawRecord
221221-func (r *RawRecord) MarshalCBOR(w io.Writer) error {
222222- // Write CBOR map header
223223- if err := cbg.WriteMajorTypeHeader(w, cbg.MajMap, uint64(len(r.Value))); err != nil {
224224- return err
225225- }
226226-227227- // Write each key-value pair
228228- for key, val := range r.Value {
229229- // Write key as text string
230230- if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(key))); err != nil {
231231- return err
232232- }
233233- if _, err := w.Write([]byte(key)); err != nil {
234234- return err
235235- }
236236-237237- // Write value (simplified - handles common types)
238238- if err := writeValue(w, val); err != nil {
239239- return err
240240- }
241241- }
242242-243243- return nil
244244-}
245245-246246-// writeValue writes a value to CBOR (helper for RawRecord)
247247-func writeValue(w io.Writer, val interface{}) error {
248248- switch v := val.(type) {
249249- case string:
250250- if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(v))); err != nil {
251251- return err
252252- }
253253- _, err := w.Write([]byte(v))
254254- return err
255255- case int64:
256256- return cbg.CborWriteHeader(w, cbg.MajUnsignedInt, uint64(v))
257257- case float64:
258258- // Write as unsigned int for now (simplified)
259259- return cbg.CborWriteHeader(w, cbg.MajUnsignedInt, uint64(v))
260260- case bool:
261261- if v {
262262- return cbg.WriteBool(w, true)
263263- }
264264- return cbg.WriteBool(w, false)
265265- case map[string]interface{}:
266266- rec := &RawRecord{Value: v}
267267- return rec.MarshalCBOR(w)
268268- case []interface{}:
269269- if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(v))); err != nil {
270270- return err
271271- }
272272- for _, item := range v {
273273- if err := writeValue(w, item); err != nil {
274274- return err
275275- }
276276- }
277277- return nil
278278- default:
279279- // For unknown types, convert to JSON then write as string
280280- jsonBytes, err := json.Marshal(v)
281281- if err != nil {
282282- return err
283283- }
284284- if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(jsonBytes))); err != nil {
285285- return err
286286- }
287287- _, err = w.Write(jsonBytes)
288288- return err
289289- }
290290-}
291291-292292-// HandleCreateRecord handles com.atproto.repo.createRecord
293293-func (h *XRPCHandler) HandleCreateRecord(w http.ResponseWriter, r *http.Request) {
294294- // Validate JWT authentication
295295- user, err := ValidateJWTAuth(r, h.pds)
296296- if err != nil {
297297- http.Error(w, fmt.Sprintf("authentication required: %v", err), http.StatusUnauthorized)
298298- return
299299- }
300300-301301- // Parse request
302302- var req CreateRecordRequest
303303- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
304304- http.Error(w, fmt.Sprintf("invalid request body: %v", err), http.StatusBadRequest)
305305- return
306306- }
307307-308308- // Validate required fields
309309- if req.Repo == "" || req.Collection == "" || req.Record == nil {
310310- http.Error(w, "repo, collection, and record are required", http.StatusBadRequest)
311311- return
312312- }
313313-314314- // Verify repo matches authenticated user
315315- if req.Repo != user.DID {
316316- http.Error(w, "repo must match authenticated user DID", http.StatusForbidden)
317317- return
318318- }
319319-320320- // Verify repo matches this hold's DID
321321- if req.Repo != h.pds.DID() {
322322- http.Error(w, "invalid repo (must be this hold's DID)", http.StatusBadRequest)
323323- return
324324- }
325325-326326- // Convert record from JSON to CBOR-marshalable format
327327- recordMap, ok := req.Record.(map[string]interface{})
328328- if !ok {
329329- http.Error(w, "record must be a JSON object", http.StatusBadRequest)
330330- return
331331- }
332332-333333- // Wrap in RawRecord which implements CBORMarshaler
334334- recordValue := &RawRecord{Value: recordMap}
335335-336336- // Create record using repomgr
337337- var recordPath string
338338- var recordCID cid.Cid
339339-340340- if req.Rkey != "" {
341341- // Use PutRecord if rkey is specified
342342- recordPath, recordCID, err = h.pds.repomgr.PutRecord(
343343- r.Context(),
344344- h.pds.uid,
345345- req.Collection,
346346- req.Rkey,
347347- recordValue,
348348- )
349349- } else {
350350- // Use CreateRecord if no rkey (auto-generates TID)
351351- recordPath, recordCID, err = h.pds.repomgr.CreateRecord(
352352- r.Context(),
353353- h.pds.uid,
354354- req.Collection,
355355- recordValue,
356356- )
357357- }
358358-359359- if err != nil {
360360- http.Error(w, fmt.Sprintf("failed to create record: %v", err), http.StatusInternalServerError)
361361- return
362362- }
363363-364364- // Extract rkey from path (format: "collection/rkey")
365365- parts := strings.Split(recordPath, "/")
366366- if len(parts) < 2 {
367367- http.Error(w, "invalid record path returned", http.StatusInternalServerError)
368368- return
369369- }
370370- actualRkey := parts[len(parts)-1]
371371-372372- // Return success response
373373- response := CreateRecordResponse{
374374- URI: fmt.Sprintf("at://%s/%s/%s", h.pds.DID(), req.Collection, actualRkey),
375375- CID: recordCID.String(),
376376- }
377377-378378- w.Header().Set("Content-Type", "application/json")
379379- w.WriteHeader(http.StatusCreated)
380380- json.NewEncoder(w).Encode(response)
381381-}
+11-57
pkg/hold/pds/status.go
···66 "time"
7788 bsky "github.com/bluesky-social/indigo/api/bsky"
99- "github.com/ipfs/go-cid"
109)
11101211const (
1313- // StatusPostRkey is the fixed rkey for the status post (singleton)
1414- StatusPostRkey = "status"
1515-1612 // StatusPostCollection is the collection name for Bluesky posts
1713 StatusPostCollection = "app.bsky.feed.post"
1814)
19152020-// SetStatus creates or updates the hold's status post on Bluesky
1616+// SetStatus creates a new status post on Bluesky
2117// status should be "online" or "offline"
1818+// Each call creates a unique post with a TID-based rkey
2219func (p *HoldPDS) SetStatus(ctx context.Context, status string) error {
2320 // Format the post text with emoji indicator
2421 emoji := "🟢"
···2724 }
2825 text := fmt.Sprintf("%s Current status: %s", emoji, status)
29263030- // Check if status post already exists
3131- _, existingPost, err := p.GetStatusPost(ctx)
3232- if err != nil {
3333- // Post doesn't exist, create it
3434- return p.createStatusPost(ctx, text)
3535- }
3636-3737- // Post exists, update it
3838- // We need to preserve the original CreatedAt timestamp
3939- return p.updateStatusPost(ctx, text, existingPost.CreatedAt)
4040-}
4141-4242-// GetStatusPost retrieves the status post if it exists
4343-func (p *HoldPDS) GetStatusPost(ctx context.Context) (cid.Cid, *bsky.FeedPost, error) {
4444- // Use repomgr.GetRecord
4545- recordCID, val, err := p.repomgr.GetRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, cid.Undef)
4646- if err != nil {
4747- return cid.Undef, nil, fmt.Errorf("failed to get status post: %w", err)
4848- }
4949-5050- // Type assert to bsky.FeedPost
5151- post, ok := val.(*bsky.FeedPost)
5252- if !ok {
5353- return cid.Undef, nil, fmt.Errorf("unexpected type for status post: %T", val)
5454- }
5555-5656- return recordCID, post, nil
2727+ // Create the post with a unique TID
2828+ return p.createStatusPost(ctx, text)
5729}
58305959-// createStatusPost creates a new status post (first time)
3131+// createStatusPost creates a new status post with a TID-based rkey
6032func (p *HoldPDS) createStatusPost(ctx context.Context, text string) error {
6133 // Create post struct
6262- now := time.Now().Format(time.RFC3339)
3434+ now := time.Now()
6335 post := &bsky.FeedPost{
6436 LexiconTypeID: "app.bsky.feed.post",
6537 Text: text,
6666- CreatedAt: now,
3838+ CreatedAt: now.Format(time.RFC3339),
6739 }
68406969- // Use repomgr.PutRecord - creates with explicit rkey, fails if already exists
7070- recordPath, recordCID, err := p.repomgr.PutRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, post)
4141+ // Use repomgr.CreateRecord to create the post with auto-generated TID
4242+ // CreateRecord automatically generates a unique TID using the repo's clock
4343+ rkey, recordCID, err := p.repomgr.CreateRecord(ctx, p.uid, StatusPostCollection, post)
7144 if err != nil {
7245 return fmt.Errorf("failed to create status post: %w", err)
7346 }
74477575- fmt.Printf("Created status post at %s, cid: %s, text: %s\n", recordPath, recordCID, text)
7676- return nil
7777-}
7878-7979-// updateStatusPost updates an existing status post
8080-func (p *HoldPDS) updateStatusPost(ctx context.Context, text string, createdAt string) error {
8181- // Create updated post struct with original CreatedAt
8282- post := &bsky.FeedPost{
8383- LexiconTypeID: "app.bsky.feed.post",
8484- Text: text,
8585- CreatedAt: createdAt, // Preserve original creation time
8686- }
8787-8888- // Use repomgr.UpdateRecord
8989- recordCID, err := p.repomgr.UpdateRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, post)
9090- if err != nil {
9191- return fmt.Errorf("failed to update status post: %w", err)
9292- }
9393-9494- fmt.Printf("Updated status post, cid: %s, text: %s\n", recordCID, text)
4848+ fmt.Printf("Created status post at %s/%s (rkey: %s), cid: %s, text: %s\n", StatusPostCollection, rkey, rkey, recordCID, text)
9549 return nil
9650}
+201-64
pkg/hold/pds/status_test.go
···2233import (
44 "context"
55+ "encoding/json"
66+ "fmt"
77+ "net/http"
88+ "net/http/httptest"
59 "os"
610 "path/filepath"
711 "testing"
1212+ "time"
8131414+ "atcr.io/pkg/atproto"
1515+ "atcr.io/pkg/s3"
916 bsky "github.com/bluesky-social/indigo/api/bsky"
1017)
11181919+// Shared test resources (used across all test files in package)
2020+var (
2121+ sharedTestKeyPath string
2222+ sharedTestKey []byte
2323+ sharedPDS *HoldPDS // Shared bootstrapped PDS for read-only tests
2424+ sharedHandler *XRPCHandler // Shared handler for read-only tests
2525+ sharedCtx context.Context // Shared context
2626+)
2727+1228func TestStatusPost(t *testing.T) {
1313- // Create temporary directory for test database
2929+ // Create temporary directory for test
1430 tmpDir := t.TempDir()
1515- dbPath := filepath.Join(tmpDir, "test.db")
3131+ // Use in-memory database for speed
3232+ dbPath := ":memory:"
1633 keyPath := filepath.Join(tmpDir, "test.key")
3434+3535+ // Copy shared signing key
3636+ if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil {
3737+ t.Fatalf("Failed to copy shared signing key: %v", err)
3838+ }
17391840 // Create test PDS
1941 ctx := context.Background()
···3153 t.Fatalf("Failed to initialize repo: %v", err)
3254 }
33555656+ // Create handler for XRPC endpoints
5757+ handler := NewXRPCHandler(holdPDS, s3.S3Service{}, nil, nil, &mockPDSClient{})
5858+5959+ // Helper function to list posts via XRPC
6060+ listPosts := func() ([]map[string]any, error) {
6161+ req := makeXRPCGetRequest(atproto.RepoListRecords, map[string]string{
6262+ "repo": did,
6363+ "collection": StatusPostCollection,
6464+ "limit": "100",
6565+ "reverse": "true", // Most recent first
6666+ })
6767+ w := httptest.NewRecorder()
6868+ handler.HandleListRecords(w, req)
6969+7070+ if w.Code != http.StatusOK {
7171+ return nil, fmt.Errorf("unexpected status code: %d, body: %s", w.Code, w.Body.String())
7272+ }
7373+7474+ var result map[string]any
7575+ if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
7676+ return nil, fmt.Errorf("failed to decode response: %w", err)
7777+ }
7878+7979+ records, ok := result["records"].([]any)
8080+ if !ok {
8181+ return nil, fmt.Errorf("expected records array, got %T", result["records"])
8282+ }
8383+8484+ posts := make([]map[string]any, len(records))
8585+ for i, rec := range records {
8686+ post, ok := rec.(map[string]any)
8787+ if !ok {
8888+ return nil, fmt.Errorf("expected record map, got %T", rec)
8989+ }
9090+ posts[i] = post
9191+ }
9292+ return posts, nil
9393+ }
9494+3495 t.Run("CreateStatusPost", func(t *testing.T) {
3596 // Set status to online (creates new post)
3697 err := holdPDS.SetStatus(ctx, "online")
···3899 t.Fatalf("Failed to set status to online: %v", err)
39100 }
401014141- // Verify post was created
4242- _, post, err := holdPDS.GetStatusPost(ctx)
102102+ // List posts
103103+ posts, err := listPosts()
43104 if err != nil {
4444- t.Fatalf("Failed to get status post: %v", err)
105105+ t.Fatalf("Failed to list posts: %v", err)
106106+ }
107107+108108+ if len(posts) == 0 {
109109+ t.Fatal("Expected at least one status post, got 0")
110110+ }
111111+112112+ // Get the latest post
113113+ post := posts[0]
114114+115115+ value, ok := post["value"].(map[string]any)
116116+ if !ok {
117117+ t.Fatalf("Expected value map, got %T", post["value"])
45118 }
461194747- if post.Text != "🟢 Current status: online" {
4848- t.Errorf("Expected text '🟢 Current status: online', got '%s'", post.Text)
120120+ text, ok := value["text"].(string)
121121+ if !ok {
122122+ t.Fatalf("Expected text string, got %T", value["text"])
49123 }
501245151- if post.LexiconTypeID != "app.bsky.feed.post" {
5252- t.Errorf("Expected LexiconTypeID 'app.bsky.feed.post', got '%s'", post.LexiconTypeID)
125125+ if text != "🟢 Current status: online" {
126126+ t.Errorf("Expected text '🟢 Current status: online', got '%s'", text)
53127 }
541285555- if post.CreatedAt == "" {
5656- t.Error("CreatedAt should not be empty")
129129+ // Verify TID-based rkey (extract from URI)
130130+ uri, ok := post["uri"].(string)
131131+ if !ok {
132132+ t.Fatalf("Expected uri string, got %T", post["uri"])
133133+ }
134134+ // URI format: at://did:web:test.example.com/app.bsky.feed.post/3m3c4...
135135+ // We just check that it contains the collection
136136+ if !contains(uri, StatusPostCollection) {
137137+ t.Errorf("Expected URI to contain collection %s, got %s", StatusPostCollection, uri)
57138 }
58139 })
591406060- t.Run("UpdateStatusPost", func(t *testing.T) {
6161- // Get the original post to check CreatedAt preservation
6262- _, originalPost, err := holdPDS.GetStatusPost(ctx)
141141+ t.Run("CreateMultiplePosts", func(t *testing.T) {
142142+ // Create multiple status posts
143143+ err := holdPDS.SetStatus(ctx, "offline")
63144 if err != nil {
6464- t.Fatalf("Failed to get original status post: %v", err)
145145+ t.Fatalf("Failed to set status to offline: %v", err)
65146 }
661476767- // Set status to offline (updates existing post)
6868- err = holdPDS.SetStatus(ctx, "offline")
148148+ // Wait a moment to ensure different timestamp
149149+ time.Sleep(10 * time.Millisecond)
150150+151151+ err = holdPDS.SetStatus(ctx, "online")
69152 if err != nil {
7070- t.Fatalf("Failed to set status to offline: %v", err)
153153+ t.Fatalf("Failed to set status to online again: %v", err)
71154 }
721557373- // Verify post was updated
7474- _, post, err := holdPDS.GetStatusPost(ctx)
156156+ // List all posts - should have at least 3 now (1 from previous test + 2 from this test)
157157+ posts, err := listPosts()
75158 if err != nil {
7676- t.Fatalf("Failed to get updated status post: %v", err)
159159+ t.Fatalf("Failed to list posts: %v", err)
77160 }
781617979- if post.Text != "🔴 Current status: offline" {
8080- t.Errorf("Expected text '🔴 Current status: offline', got '%s'", post.Text)
162162+ if len(posts) < 3 {
163163+ t.Errorf("Expected at least 3 status posts, got %d", len(posts))
164164+ }
165165+166166+ // Verify each post has a unique URI
167167+ uris := make(map[string]bool)
168168+ for _, post := range posts {
169169+ uri, ok := post["uri"].(string)
170170+ if !ok {
171171+ t.Errorf("Expected uri string, got %T", post["uri"])
172172+ continue
173173+ }
174174+ if uris[uri] {
175175+ t.Errorf("Duplicate URI found: %s", uri)
176176+ }
177177+ uris[uri] = true
81178 }
821798383- // Verify CreatedAt was preserved
8484- if post.CreatedAt != originalPost.CreatedAt {
8585- t.Errorf("CreatedAt should be preserved. Expected '%s', got '%s'", originalPost.CreatedAt, post.CreatedAt)
180180+ // Verify the latest post is online
181181+ latestPost := posts[0]
182182+ value, ok := latestPost["value"].(map[string]any)
183183+ if !ok {
184184+ t.Fatalf("Expected value map, got %T", latestPost["value"])
185185+ }
186186+ text, ok := value["text"].(string)
187187+ if !ok {
188188+ t.Fatalf("Expected text string, got %T", value["text"])
189189+ }
190190+ if text != "🟢 Current status: online" {
191191+ t.Errorf("Expected latest post text '🟢 Current status: online', got '%s'", text)
86192 }
87193 })
881948989- t.Run("ToggleStatus", func(t *testing.T) {
9090- // Toggle back to online
9191- err := holdPDS.SetStatus(ctx, "online")
195195+ t.Run("OfflineStatus", func(t *testing.T) {
196196+ // Create offline status post
197197+ err := holdPDS.SetStatus(ctx, "offline")
92198 if err != nil {
9393- t.Fatalf("Failed to set status to online: %v", err)
199199+ t.Fatalf("Failed to set status to offline: %v", err)
94200 }
952019696- _, post, err := holdPDS.GetStatusPost(ctx)
202202+ // Get the latest post
203203+ posts, err := listPosts()
97204 if err != nil {
9898- t.Fatalf("Failed to get status post: %v", err)
205205+ t.Fatalf("Failed to list posts: %v", err)
206206+ }
207207+208208+ if len(posts) == 0 {
209209+ t.Fatal("Expected at least one status post, got 0")
210210+ }
211211+212212+ latestPost := posts[0]
213213+ value, ok := latestPost["value"].(map[string]any)
214214+ if !ok {
215215+ t.Fatalf("Expected value map, got %T", latestPost["value"])
216216+ }
217217+ text, ok := value["text"].(string)
218218+ if !ok {
219219+ t.Fatalf("Expected text string, got %T", value["text"])
99220 }
100221101101- if post.Text != "🟢 Current status: online" {
102102- t.Errorf("Expected text '🟢 Current status: online', got '%s'", post.Text)
222222+ if text != "🔴 Current status: offline" {
223223+ t.Errorf("Expected text '🔴 Current status: offline', got '%s'", text)
103224 }
104225 })
105226}
106227107228func TestStatusPostCollection(t *testing.T) {
108108- // Verify constants
229229+ // Verify constant
109230 if StatusPostCollection != "app.bsky.feed.post" {
110231 t.Errorf("Expected StatusPostCollection 'app.bsky.feed.post', got '%s'", StatusPostCollection)
111232 }
112112-113113- if StatusPostRkey != "status" {
114114- t.Errorf("Expected StatusPostRkey 'status', got '%s'", StatusPostRkey)
115115- }
116233}
117234118118-func TestGetStatusPostNotExists(t *testing.T) {
119119- // Create temporary directory for test database
120120- tmpDir := t.TempDir()
121121- dbPath := filepath.Join(tmpDir, "test.db")
122122- keyPath := filepath.Join(tmpDir, "test.key")
235235+// Helper function to check if a string contains a substring
236236+func contains(s, substr string) bool {
237237+ return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
238238+}
123239124124- // Create test PDS
125125- ctx := context.Background()
126126- did := "did:web:test2.example.com"
127127- publicURL := "https://test2.example.com"
128128-129129- holdPDS, err := NewHoldPDS(ctx, did, publicURL, dbPath, keyPath)
130130- if err != nil {
131131- t.Fatalf("Failed to create test PDS: %v", err)
240240+func findSubstring(s, substr string) bool {
241241+ for i := 0; i <= len(s)-len(substr); i++ {
242242+ if s[i:i+len(substr)] == substr {
243243+ return true
244244+ }
132245 }
133133-134134- // Initialize empty repo
135135- err = holdPDS.repomgr.InitNewActor(ctx, holdPDS.uid, "", did, "", "", "")
136136- if err != nil {
137137- t.Fatalf("Failed to initialize repo: %v", err)
138138- }
139139-140140- // Try to get status post that doesn't exist
141141- _, _, err = holdPDS.GetStatusPost(ctx)
142142- if err == nil {
143143- t.Error("Expected error when getting non-existent status post, got nil")
144144- }
246246+ return false
145247}
146248147249func init() {
···154256155257// Cleanup function to remove test files
156258func TestMain(m *testing.M) {
259259+ // Create a temporary directory for shared test key
260260+ tmpDir, err := os.MkdirTemp("", "pds-test-shared-*")
261261+ if err != nil {
262262+ panic(fmt.Sprintf("Failed to create temp dir: %v", err))
263263+ }
264264+ defer os.RemoveAll(tmpDir)
265265+266266+ // Generate one signing key to be reused across all tests in the package
267267+ sharedTestKeyPath = filepath.Join(tmpDir, "shared-signing-key")
268268+ privateKey, err := GenerateOrLoadKey(sharedTestKeyPath)
269269+ if err != nil {
270270+ panic(fmt.Sprintf("Failed to generate shared signing key: %v", err))
271271+ }
272272+273273+ // Store the key bytes so tests can copy them
274274+ sharedTestKey = privateKey.Bytes()
275275+276276+ // Create one shared, bootstrapped PDS for read-only tests
277277+ // Use in-memory database for speed
278278+ sharedCtx = context.Background()
279279+ sharedPDS, err = NewHoldPDS(sharedCtx, "did:web:hold.example.com", "https://hold.example.com", ":memory:", sharedTestKeyPath)
280280+ if err != nil {
281281+ panic(fmt.Sprintf("Failed to create shared PDS: %v", err))
282282+ }
283283+284284+ // Bootstrap once
285285+ ownerDID := "did:plc:testowner123"
286286+ err = sharedPDS.Bootstrap(sharedCtx, nil, ownerDID, true, false, "")
287287+ if err != nil {
288288+ panic(fmt.Sprintf("Failed to bootstrap shared PDS: %v", err))
289289+ }
290290+291291+ // Create shared handler
292292+ sharedHandler = NewXRPCHandler(sharedPDS, s3.S3Service{}, nil, nil, &mockPDSClient{})
293293+157294 // Run tests
158295 code := m.Run()
159296
+6-29
pkg/hold/pds/xrpc.go
···92929393 // Handle OPTIONS preflight
9494 if r.Method == "OPTIONS" {
9595- w.WriteHeader(http.StatusNoContent)
9595+ w.WriteHeader(http.StatusOK)
9696 return
9797 }
9898···150150 r.Get("/xrpc/_health", h.HandleHealth)
151151 r.Get(atproto.ServerDescribeServer, h.HandleDescribeServer)
152152153153- // Session management (public - creates sessions)
154154- r.Post(atproto.ServerCreateSession, h.HandleCreateSession)
155155- r.Post(atproto.ServerRefreshSession, h.HandleRefreshSession)
156156- r.Get(atproto.ServerGetSession, h.HandleGetSession)
157157-158153 // Repository metadata
159154 r.Get(atproto.RepoDescribeRepo, h.HandleDescribeRepo)
160155 r.Get(atproto.RepoGetRecord, h.HandleGetRecord)
···186181 // Write endpoints (owner/crew admin auth)
187182 r.Group(func(r chi.Router) {
188183 r.Use(h.requireOwnerOrCrewAdmin)
189189-190184 r.Post(atproto.RepoDeleteRecord, h.HandleDeleteRecord)
191185 r.Post(atproto.RepoUploadBlob, h.HandleUploadBlob)
192186 })
···194188 // Auth-only endpoints (DPoP auth)
195189 r.Group(func(r chi.Router) {
196190 r.Use(h.requireAuth)
197197-198191 r.Post(atproto.HoldRequestCrew, h.HandleRequestCrew)
199199- })
200200-201201- // JWT-authenticated endpoints (JWT auth from createSession)
202202- // Note: JWT auth is validated inside each handler
203203- r.Group(func(r chi.Router) {
204204- r.Post("/xrpc/com.atproto.repo.createRecord", h.HandleCreateRecord)
205192 })
206193}
207194···861848 }
862849863850 // Get optional cursor parameter for backfill
864864- var cursor int64 = 0
851851+ // Default to -1 (no backfill, only stream new events)
852852+ // cursor=0 means "replay all events from the beginning"
853853+ var cursor int64 = -1
865854 if cursorStr := r.URL.Query().Get("cursor"); cursorStr != "" {
866855 var err error
867856 cursor, err = strconv.ParseInt(cursorStr, 10, 64)
···879868 }
880869881870 // Subscribe to events
882882- sub := h.broadcaster.Subscribe(conn, cursor)
883883-884871 // The broadcaster's handleSubscriber goroutine will manage this connection
885885- // We just need to keep reading to detect client disconnects
886886- go func() {
887887- defer h.broadcaster.Unsubscribe(sub)
888888- for {
889889- // Read messages from client (mostly just to detect disconnect)
890890- _, _, err := conn.ReadMessage()
891891- if err != nil {
892892- // Client disconnected
893893- break
894894- }
895895- }
896896- }()
872872+ // and handle cleanup when the client disconnects
873873+ h.broadcaster.Subscribe(conn, cursor)
897874}
898875899876// HandleUploadBlob handles blob uploads with support for multipart operations
+257-3
pkg/hold/pds/xrpc_test.go
···1212 "path/filepath"
1313 "strings"
1414 "testing"
1515+ "time"
15161617 "atcr.io/pkg/atproto"
1718 "atcr.io/pkg/s3"
1919+ indigoAtproto "github.com/bluesky-social/indigo/api/atproto"
2020+ "github.com/bluesky-social/indigo/events"
1821 "github.com/distribution/distribution/v3/registry/storage/driver/factory"
1922 _ "github.com/distribution/distribution/v3/registry/storage/driver/filesystem"
2023 "github.com/go-chi/chi/v5"
2424+ "github.com/gorilla/websocket"
2525+ "github.com/ipfs/go-cid"
2126)
22272328// Test helpers
···3035 ctx := context.Background()
3136 tmpDir := t.TempDir()
32373333- dbPath := filepath.Join(tmpDir, "pds.db")
3838+ // Use in-memory database for speed
3939+ dbPath := ":memory:"
3440 keyPath := filepath.Join(tmpDir, "signing-key")
4141+4242+ // Copy shared signing key instead of generating a new one
4343+ if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil {
4444+ t.Fatalf("Failed to copy shared signing key: %v", err)
4545+ }
35463647 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath)
3748 if err != nil {
···115126 }
116127117128 return result
129129+}
130130+131131+// decodeFirehoseMessage decodes an ATProto firehose message (header + CBOR body)
132132+func decodeFirehoseMessage(t *testing.T, message []byte) (*events.EventHeader, *indigoAtproto.SyncSubscribeRepos_Commit) {
133133+ t.Helper()
134134+135135+ reader := bytes.NewReader(message)
136136+137137+ // Decode header
138138+ var header events.EventHeader
139139+ if err := header.UnmarshalCBOR(reader); err != nil {
140140+ t.Fatalf("Failed to decode event header: %v", err)
141141+ }
142142+143143+ // Verify it's a commit event
144144+ if header.MsgType != "#commit" {
145145+ t.Fatalf("Expected #commit event, got %s", header.MsgType)
146146+ }
147147+148148+ // Decode commit event
149149+ var commit indigoAtproto.SyncSubscribeRepos_Commit
150150+ if err := commit.UnmarshalCBOR(reader); err != nil {
151151+ t.Fatalf("Failed to decode commit event: %v", err)
152152+ }
153153+154154+ return &header, &commit
118155}
119156120157// assertCARResponse validates CAR file response
···13311368 ctx := context.Background()
13321369 tmpDir := t.TempDir()
1333137013341334- dbPath := filepath.Join(tmpDir, "pds.db")
13711371+ // Use in-memory database for speed
13721372+ dbPath := ":memory:"
13351373 keyPath := filepath.Join(tmpDir, "signing-key")
13741374+13751375+ // Copy shared signing key instead of generating a new one
13761376+ if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil {
13771377+ t.Fatalf("Failed to copy shared signing key: %v", err)
13781378+ }
1336137913371380 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath)
13381381 if err != nil {
···16861729 w := httptest.NewRecorder()
1687173016881731 // Wrap with CORS middleware (chi-style)
16891689- corsHandler := handler.corsMiddleware(http.HandlerFunc(handler.HandleGetBlob))
17321732+ corsHandler := handler.CORSMiddleware()(http.HandlerFunc(handler.HandleGetBlob))
16901733 corsHandler.ServeHTTP(w, req)
1691173416921735 // Verify CORS headers are present
···1719176217201763 // Create chi router and register handlers
17211764 r := chi.NewRouter()
17651765+ r.Use(handler.CORSMiddleware()) // Apply CORS middleware
17221766 handler.RegisterHandlers(r)
1723176717241768 tests := []struct {
···20092053 })
20102054 }
20112055}
20562056+20572057+// TestHandleSubscribeRepos tests the WebSocket firehose endpoint
20582058+func TestHandleSubscribeRepos(t *testing.T) {
20592059+ handler, ctx := setupTestXRPCHandler(t)
20602060+20612061+ // Create EventBroadcaster
20622062+ broadcaster := NewEventBroadcaster(handler.pds.DID(), 100)
20632063+ handler.broadcaster = broadcaster
20642064+20652065+ // Set up test HTTP server
20662066+ r := chi.NewRouter()
20672067+ handler.RegisterHandlers(r)
20682068+ server := httptest.NewServer(r)
20692069+ defer server.Close()
20702070+20712071+ // Broadcast some events before connecting
20722072+ testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke")
20732073+ for i := 1; i <= 3; i++ {
20742074+ event := &RepoEvent{
20752075+ NewRoot: testCID,
20762076+ Rev: fmt.Sprintf("rev-%d", i),
20772077+ RepoSlice: []byte(fmt.Sprintf("CAR data %d", i)),
20782078+ Ops: []RepoOp{},
20792079+ }
20802080+ broadcaster.Broadcast(ctx, event)
20812081+ }
20822082+20832083+ // Verify events were stored
20842084+ if broadcaster.eventSeq != 3 {
20852085+ t.Fatalf("Expected eventSeq=3, got %d", broadcaster.eventSeq)
20862086+ }
20872087+20882088+ t.Run("cursor=0 replays all events", func(t *testing.T) {
20892089+ // Connect to WebSocket with cursor=0
20902090+ wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=0"
20912091+ conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
20922092+ if err != nil {
20932093+ t.Fatalf("Failed to connect to WebSocket: %v", err)
20942094+ }
20952095+ defer conn.Close()
20962096+20972097+ // Should receive the 3 historical events
20982098+ for i := 0; i < 3; i++ {
20992099+ messageType, message, err := conn.ReadMessage()
21002100+ if err != nil {
21012101+ t.Fatalf("Failed to read message: %v", err)
21022102+ }
21032103+21042104+ if messageType != websocket.BinaryMessage {
21052105+ t.Errorf("Expected binary message, got type %d", messageType)
21062106+ }
21072107+21082108+ // Decode CBOR message (header + commit)
21092109+ header, commit := decodeFirehoseMessage(t, message)
21102110+21112111+ // Verify header
21122112+ if header.MsgType != "#commit" {
21132113+ t.Errorf("Expected MsgType=#commit, got %s", header.MsgType)
21142114+ }
21152115+21162116+ // Verify commit fields
21172117+ expectedSeq := int64(i + 1)
21182118+ if commit.Seq != expectedSeq {
21192119+ t.Errorf("Expected seq=%d, got %d", expectedSeq, commit.Seq)
21202120+ }
21212121+ if commit.Repo != handler.pds.DID() {
21222122+ t.Errorf("Expected repo=%s, got %s", handler.pds.DID(), commit.Repo)
21232123+ }
21242124+ }
21252125+ })
21262126+21272127+ t.Run("cursor=2 only replays events after 2", func(t *testing.T) {
21282128+ // Connect with cursor=2
21292129+ wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=2"
21302130+ conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
21312131+ if err != nil {
21322132+ t.Fatalf("Failed to connect to WebSocket: %v", err)
21332133+ }
21342134+ defer conn.Close()
21352135+21362136+ // Should only receive event 3 (after cursor=2)
21372137+ messageType, message, err := conn.ReadMessage()
21382138+ if err != nil {
21392139+ t.Fatalf("Failed to read message: %v", err)
21402140+ }
21412141+21422142+ if messageType != websocket.BinaryMessage {
21432143+ t.Errorf("Expected binary message, got type %d", messageType)
21442144+ }
21452145+21462146+ header, commit := decodeFirehoseMessage(t, message)
21472147+ if header.MsgType != "#commit" {
21482148+ t.Errorf("Expected MsgType=#commit, got %s", header.MsgType)
21492149+ }
21502150+21512151+ if commit.Seq != 3 {
21522152+ t.Errorf("Expected seq=3, got %d", commit.Seq)
21532153+ }
21542154+21552155+ // Verify no more events (use timeout)
21562156+ conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
21572157+ _, _, err = conn.ReadMessage()
21582158+ if err == nil {
21592159+ t.Error("Expected no more events, but received another message")
21602160+ }
21612161+ })
21622162+21632163+ t.Run("no cursor streams only new events", func(t *testing.T) {
21642164+ // Connect without cursor (should not get backfill)
21652165+ wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos"
21662166+ conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
21672167+ if err != nil {
21682168+ t.Fatalf("Failed to connect to WebSocket: %v", err)
21692169+ }
21702170+ defer conn.Close()
21712171+21722172+ // Verify no historical events by broadcasting immediately and checking
21732173+ // that we only receive the new event (not historical ones)
21742174+ // Give subscriber time to register first
21752175+ time.Sleep(100 * time.Millisecond)
21762176+21772177+ // Broadcast a new event (seq 4)
21782178+ newEvent := &RepoEvent{
21792179+ NewRoot: testCID,
21802180+ Rev: "rev-4",
21812181+ RepoSlice: []byte("CAR data 4"),
21822182+ Ops: []RepoOp{},
21832183+ }
21842184+ broadcaster.Broadcast(ctx, newEvent)
21852185+21862186+ // Should receive ONLY the new event (seq 4), not historical events 1-3
21872187+ conn.SetReadDeadline(time.Now().Add(1 * time.Second))
21882188+ messageType, message, err := conn.ReadMessage()
21892189+ if err != nil {
21902190+ t.Fatalf("Failed to read new event: %v", err)
21912191+ }
21922192+21932193+ if messageType != websocket.BinaryMessage {
21942194+ t.Errorf("Expected binary message, got type %d", messageType)
21952195+ }
21962196+21972197+ header, commit := decodeFirehoseMessage(t, message)
21982198+ if header.MsgType != "#commit" {
21992199+ t.Errorf("Expected MsgType=#commit, got %s", header.MsgType)
22002200+ }
22012201+22022202+ // Key assertion: should be seq 4 (new event), not seq 1 (historical backfill)
22032203+ if commit.Seq != 4 {
22042204+ t.Errorf("Expected seq=4 for new event (no backfill), got %d", commit.Seq)
22052205+ }
22062206+22072207+ // Verify no more messages (no historical backfill)
22082208+ conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
22092209+ _, _, err = conn.ReadMessage()
22102210+ if err == nil {
22112211+ t.Error("Expected no more events, but received another message (possible backfill leak)")
22122212+ }
22132213+ })
22142214+22152215+ t.Run("real-time event delivery", func(t *testing.T) {
22162216+ // Connect with cursor=0 to get all events first
22172217+ wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=0"
22182218+ conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
22192219+ if err != nil {
22202220+ t.Fatalf("Failed to connect to WebSocket: %v", err)
22212221+ }
22222222+ defer conn.Close()
22232223+22242224+ // Read and discard the 4 historical events (seq 1-4)
22252225+ for i := 0; i < 4; i++ {
22262226+ _, _, err := conn.ReadMessage()
22272227+ if err != nil {
22282228+ t.Fatalf("Failed to read historical event %d: %v", i+1, err)
22292229+ }
22302230+ }
22312231+22322232+ // Broadcast 2 new events
22332233+ for i := 5; i <= 6; i++ {
22342234+ newEvent := &RepoEvent{
22352235+ NewRoot: testCID,
22362236+ Rev: fmt.Sprintf("rev-%d", i),
22372237+ RepoSlice: []byte(fmt.Sprintf("CAR data %d", i)),
22382238+ Ops: []RepoOp{},
22392239+ }
22402240+ broadcaster.Broadcast(ctx, newEvent)
22412241+ }
22422242+22432243+ // Should receive both new events
22442244+ for expectedSeq := 5; expectedSeq <= 6; expectedSeq++ {
22452245+ conn.SetReadDeadline(time.Now().Add(1 * time.Second))
22462246+ messageType, message, err := conn.ReadMessage()
22472247+ if err != nil {
22482248+ t.Fatalf("Failed to read event seq=%d: %v", expectedSeq, err)
22492249+ }
22502250+22512251+ if messageType != websocket.BinaryMessage {
22522252+ t.Errorf("Expected binary message, got type %d", messageType)
22532253+ }
22542254+22552255+ header, commit := decodeFirehoseMessage(t, message)
22562256+ if header.MsgType != "#commit" {
22572257+ t.Errorf("Expected MsgType=#commit, got %s", header.MsgType)
22582258+ }
22592259+22602260+ if commit.Seq != int64(expectedSeq) {
22612261+ t.Errorf("Expected seq=%d, got %d", expectedSeq, commit.Seq)
22622262+ }
22632263+ }
22642264+ })
22652265+}