A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

try and add cursor=0 to subscribe

+843 -1351
+1
cmd/hold/main.go
··· 147 147 // Update status post to "online" after server starts 148 148 if holdPDS != nil { 149 149 ctx := context.Background() 150 + 150 151 if err := holdPDS.SetStatus(ctx, "online"); err != nil { 151 152 log.Printf("Warning: Failed to set status post to online: %v", err) 152 153 } else {
+2 -2
go.mod
··· 4 4 5 5 require ( 6 6 github.com/aws/aws-sdk-go v1.55.5 7 - github.com/bluesky-social/indigo v0.0.0-20251014222321-1e8718ae9f33 7 + github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb 8 8 github.com/distribution/distribution/v3 v3.0.0 9 9 github.com/distribution/reference v0.6.0 10 10 github.com/go-chi/chi/v5 v5.2.3 ··· 34 34 ) 35 35 36 36 require ( 37 + github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b // indirect 37 38 github.com/aymerick/douceur v0.2.0 // indirect 38 39 github.com/beorn7/perks v1.0.1 // indirect 39 40 github.com/bshuster-repo/logrus-logstash-hook v1.0.0 // indirect ··· 46 47 github.com/docker/go-metrics v0.0.1 // indirect 47 48 github.com/earthboundkid/versioninfo/v2 v2.24.1 // indirect 48 49 github.com/felixge/httpsnoop v1.0.4 // indirect 49 - github.com/go-chi/cors v1.2.2 // indirect 50 50 github.com/go-jose/go-jose/v4 v4.1.2 // indirect 51 51 github.com/go-logr/logr v1.4.2 // indirect 52 52 github.com/go-logr/stdr v1.2.2 // indirect
+6 -4
go.sum
··· 1 1 github.com/AdaLogics/go-fuzz-headers v0.0.0-20221103172237-443f56ff4ba8 h1:d+pBUmsteW5tM87xmVXHZ4+LibHRFn40SPAoZJOg2ak= 2 2 github.com/AdaLogics/go-fuzz-headers v0.0.0-20221103172237-443f56ff4ba8/go.mod h1:i9fr2JpcEcY/IHEvzCM3qXUZYOQHgR89dt4es1CgMhc= 3 3 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 4 + github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b h1:5/++qT1/z812ZqBvqQt6ToRswSuPZ/B33m6xVHRzADU= 5 + github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b/go.mod h1:4+EPqMRApwwE/6yo6CxiHoSnBzjRr3jsqer7frxP8y4= 4 6 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= 5 7 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= 6 8 github.com/alexbrainman/goissue34681 v0.0.0-20191006012335-3fc7a47baff5 h1:iW0a5ljuFxkLGPNem5Ui+KBjFJzKg4Fv2fnxe4dvzpM= ··· 18 20 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 19 21 github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= 20 22 github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= 21 - github.com/bluesky-social/indigo v0.0.0-20251014222321-1e8718ae9f33 h1:x06Y6VyYUCvqWl2AS4/3NBBbRf8wWNMd3YrI44NTHS8= 22 - github.com/bluesky-social/indigo v0.0.0-20251014222321-1e8718ae9f33/go.mod h1:GuGAU33qKulpZCZNPcUeIQ4RW6KzNvOy7s8MSUXbAng= 23 + github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb h1:zzyqB1W/itfdIA5cnOZ7IFCJ6QtqwOsXltmLunL4sHw= 24 + github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb/go.mod h1:GuGAU33qKulpZCZNPcUeIQ4RW6KzNvOy7s8MSUXbAng= 23 25 github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= 24 26 github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= 25 27 github.com/bshuster-repo/logrus-logstash-hook v1.0.0 h1:e+C0SB5R1pu//O4MQ3f9cFuPGoOVeF2fE4Og9otCc70= ··· 68 70 github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 69 71 github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= 70 72 github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= 71 - github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= 72 - github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= 73 73 github.com/go-jose/go-jose/v4 v4.1.2 h1:TK/7NqRQZfgAh+Td8AlsrvtPoUyiHh0LqVvokh+1vHI= 74 74 github.com/go-jose/go-jose/v4 v4.1.2/go.mod h1:22cg9HWM1pOlnRiY+9cQYJ9XHmya1bYW8OeDM6Ku6Oo= 75 75 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= ··· 80 80 github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 81 81 github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 82 82 github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 83 + github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= 84 + github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= 83 85 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= 84 86 github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= 85 87 github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
+54 -3
pkg/hold/oci/xrpc_test.go
··· 4 4 "bytes" 5 5 "context" 6 6 "encoding/json" 7 + "fmt" 7 8 "io" 8 9 "net/http" 9 10 "net/http/httptest" 11 + "os" 10 12 "path/filepath" 11 13 "strconv" 12 14 "testing" ··· 18 20 _ "github.com/distribution/distribution/v3/registry/storage/driver/filesystem" 19 21 ) 20 22 23 + // Shared test resources for OCI package 24 + var ( 25 + sharedTestKeyPath string 26 + sharedTestKey []byte 27 + ) 28 + 29 + // TestMain sets up shared resources for all OCI tests 30 + func TestMain(m *testing.M) { 31 + // Create a temporary directory for shared test key 32 + tmpDir, err := os.MkdirTemp("", "oci-test-shared-*") 33 + if err != nil { 34 + panic(fmt.Sprintf("Failed to create temp dir: %v", err)) 35 + } 36 + defer os.RemoveAll(tmpDir) 37 + 38 + // Generate one signing key to be reused across all tests 39 + sharedTestKeyPath = filepath.Join(tmpDir, "shared-signing-key") 40 + privateKey, err := pds.GenerateOrLoadKey(sharedTestKeyPath) 41 + if err != nil { 42 + panic(fmt.Sprintf("Failed to generate shared signing key: %v", err)) 43 + } 44 + 45 + // Store the key bytes so tests can copy them 46 + sharedTestKey = privateKey.Bytes() 47 + 48 + // Run tests 49 + code := m.Run() 50 + os.Exit(code) 51 + } 52 + 21 53 // Test setup helpers 22 54 23 55 // mockPDSClient implements pds.HTTPClient for testing ··· 54 86 } 55 87 56 88 // Create minimal PDS for DID/auth 57 - dbPath := filepath.Join(tmpDir, "pds.db") 89 + // Use in-memory database for speed 90 + dbPath := ":memory:" 58 91 keyPath := filepath.Join(tmpDir, "signing-key") 59 92 holdDID := "did:web:hold.example.com" 60 93 publicURL := "https://hold.example.com" 61 94 95 + // Copy shared signing key instead of generating a new one 96 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 97 + t.Fatalf("Failed to copy shared signing key: %v", err) 98 + } 99 + 62 100 holdPDS, err := pds.NewHoldPDS(ctx, holdDID, publicURL, dbPath, keyPath) 63 101 if err != nil { 64 102 t.Fatalf("Failed to create PDS: %v", err) 65 103 } 66 104 67 - // Bootstrap PDS 105 + // Bootstrap PDS, suppressing stdout to avoid log spam 68 106 ownerDID := "did:plc:owner123" 69 - if err := holdPDS.Bootstrap(ctx, nil, ownerDID, true, false, ""); err != nil { 107 + 108 + // Redirect stdout to suppress bootstrap logging 109 + oldStdout := os.Stdout 110 + r, w, _ := os.Pipe() 111 + os.Stdout = w 112 + 113 + err = holdPDS.Bootstrap(ctx, nil, ownerDID, true, false, "") 114 + 115 + // Restore stdout 116 + w.Close() 117 + os.Stdout = oldStdout 118 + io.ReadAll(r) // Drain the pipe 119 + 120 + if err != nil { 70 121 t.Fatalf("Failed to bootstrap PDS: %v", err) 71 122 } 72 123
-138
pkg/hold/pds/apppassword.go
··· 1 - package pds 2 - 3 - import ( 4 - "crypto/rand" 5 - "encoding/base32" 6 - "fmt" 7 - "strings" 8 - 9 - "golang.org/x/crypto/bcrypt" 10 - ) 11 - 12 - // GenerateAppPassword creates a random app password in the format: abcd-efgh-ijkl-mnop 13 - // Uses base32 encoding for readable characters (no ambiguous chars like 0/O, 1/l) 14 - func GenerateAppPassword() (string, error) { 15 - // Generate 20 random bytes (160 bits of entropy) 16 - // Base32 encoding gives us 32 characters, we'll format as 4 groups of 4 17 - randomBytes := make([]byte, 20) 18 - if _, err := rand.Read(randomBytes); err != nil { 19 - return "", fmt.Errorf("failed to generate random bytes: %w", err) 20 - } 21 - 22 - // Encode as base32 and lowercase (base32 alphabet: a-z, 2-7) 23 - encoded := base32.StdEncoding.EncodeToString(randomBytes) 24 - encoded = strings.ToLower(encoded) 25 - 26 - // Remove padding and take first 16 characters 27 - encoded = strings.TrimRight(encoded, "=") 28 - if len(encoded) > 16 { 29 - encoded = encoded[:16] 30 - } 31 - 32 - // Format as: xxxx-xxxx-xxxx-xxxx 33 - parts := []string{ 34 - encoded[0:4], 35 - encoded[4:8], 36 - encoded[8:12], 37 - encoded[12:16], 38 - } 39 - 40 - return strings.Join(parts, "-"), nil 41 - } 42 - 43 - // HashAppPassword hashes an app password using bcrypt 44 - // Cost is set to 12 for good security without excessive CPU usage 45 - func HashAppPassword(password string) (string, error) { 46 - hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) 47 - if err != nil { 48 - return "", fmt.Errorf("failed to hash password: %w", err) 49 - } 50 - return string(hash), nil 51 - } 52 - 53 - // ValidateAppPassword compares a plaintext password with a bcrypt hash 54 - func ValidateAppPassword(password, hash string) bool { 55 - err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) 56 - return err == nil 57 - } 58 - 59 - // CreateAppPassword generates and stores a new app password 60 - func (p *HoldPDS) CreateAppPassword(name string) (string, error) { 61 - // Generate random password 62 - password, err := GenerateAppPassword() 63 - if err != nil { 64 - return "", fmt.Errorf("failed to generate password: %w", err) 65 - } 66 - 67 - // Hash password 68 - hash, err := HashAppPassword(password) 69 - if err != nil { 70 - return "", fmt.Errorf("failed to hash password: %w", err) 71 - } 72 - 73 - // Store in database 74 - if err := p.authDB.CreateAppPassword(name, hash); err != nil { 75 - return "", fmt.Errorf("failed to store password: %w", err) 76 - } 77 - 78 - return password, nil 79 - } 80 - 81 - // ValidateAppPasswordByName checks if a password matches the stored hash for a given name 82 - func (p *HoldPDS) ValidateAppPasswordByName(name, password string) error { 83 - // Get app password from database 84 - ap, err := p.authDB.GetAppPassword(name) 85 - if err != nil { 86 - return fmt.Errorf("app password not found: %w", err) 87 - } 88 - 89 - // Validate password 90 - if !ValidateAppPassword(password, ap.PasswordHash) { 91 - return fmt.Errorf("invalid password") 92 - } 93 - 94 - // Update last used timestamp 95 - if err := p.authDB.UpdateLastUsed(name); err != nil { 96 - // Log but don't fail - this is not critical 97 - fmt.Printf("Warning: failed to update last used timestamp: %v\n", err) 98 - } 99 - 100 - return nil 101 - } 102 - 103 - // ValidateAnyAppPassword checks if a password matches any stored app password 104 - // Returns the name of the matching app password, or error if none match 105 - func (p *HoldPDS) ValidateAnyAppPassword(password string) (string, error) { 106 - // List all app passwords 107 - passwords, err := p.authDB.ListAppPasswords() 108 - if err != nil { 109 - return "", fmt.Errorf("failed to list app passwords: %w", err) 110 - } 111 - 112 - // Try each one 113 - for _, ap := range passwords { 114 - // Get full record with hash 115 - fullAP, err := p.authDB.GetAppPassword(ap.Name) 116 - if err != nil { 117 - continue 118 - } 119 - 120 - if ValidateAppPassword(password, fullAP.PasswordHash) { 121 - // Update last used 122 - p.authDB.UpdateLastUsed(ap.Name) 123 - return ap.Name, nil 124 - } 125 - } 126 - 127 - return "", fmt.Errorf("invalid app password") 128 - } 129 - 130 - // ListAppPasswords returns a list of app password names (without hashes) 131 - func (p *HoldPDS) ListAppPasswords() ([]AppPassword, error) { 132 - return p.authDB.ListAppPasswords() 133 - } 134 - 135 - // RevokeAppPassword deletes an app password 136 - func (p *HoldPDS) RevokeAppPassword(name string) error { 137 - return p.authDB.DeleteAppPassword(name) 138 - }
-30
pkg/hold/pds/auth.go
··· 528 528 529 529 return publicKey, nil 530 530 } 531 - 532 - // ValidateJWTAuth validates a request with a JWT access token from createSession 533 - // This is used for authenticated repo operations (createRecord, etc.) 534 - // Returns the validated user DID 535 - func ValidateJWTAuth(r *http.Request, pds *HoldPDS) (*ValidatedUser, error) { 536 - // Extract Authorization header 537 - authHeader := r.Header.Get("Authorization") 538 - if authHeader == "" { 539 - return nil, fmt.Errorf("missing Authorization header") 540 - } 541 - 542 - // Remove "Bearer " prefix 543 - accessToken := strings.TrimPrefix(authHeader, "Bearer ") 544 - if accessToken == authHeader { 545 - return nil, fmt.Errorf("invalid authorization header format (expected Bearer)") 546 - } 547 - 548 - // Validate access token 549 - claims, err := pds.ValidateAccessToken(accessToken) 550 - if err != nil { 551 - return nil, fmt.Errorf("invalid access token: %w", err) 552 - } 553 - 554 - return &ValidatedUser{ 555 - DID: claims.DID, 556 - Handle: claims.Handle, 557 - PDS: "", 558 - Authorized: true, 559 - }, nil 560 - }
+15 -75
pkg/hold/pds/auth_test.go
··· 506 506 507 507 // TestValidateBlobWriteAccess_ServiceToken_Owner tests owner write access via service token 508 508 func TestValidateBlobWriteAccess_ServiceToken_Owner(t *testing.T) { 509 - pds, ctx := setupTestPDS(t) 510 - 511 509 ownerDID := "did:plc:owner123" 512 510 holdDID := "did:web:hold01.atcr.io" 513 511 514 - // Bootstrap with owner 515 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 516 - if err != nil { 517 - t.Fatalf("Failed to bootstrap PDS: %v", err) 518 - } 512 + _, _ = setupTestPDSWithBootstrap(t, ownerDID, true, false) 519 513 520 514 // Create service token for owner 521 515 helper, err := NewServiceTokenTestHelper(ownerDID, holdDID) ··· 539 533 540 534 // TestValidateBlobWriteAccess_ServiceToken_CrewWithPermission tests crew write access via service token 541 535 func TestValidateBlobWriteAccess_ServiceToken_CrewWithPermission(t *testing.T) { 542 - pds, ctx := setupTestPDS(t) 543 - 544 536 ownerDID := "did:plc:owner123" 545 537 writerDID := "did:plc:writer123" 546 538 holdDID := "did:web:hold01.atcr.io" 547 539 548 - // Bootstrap 549 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 550 - if err != nil { 551 - t.Fatalf("Failed to bootstrap PDS: %v", err) 552 - } 540 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 553 541 554 542 // Add crew member with blob:write permission 555 - _, err = pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"}) 543 + _, err := pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"}) 556 544 if err != nil { 557 545 t.Fatalf("Failed to add crew member: %v", err) 558 546 } ··· 598 586 599 587 // TestValidateBlobWriteAccess_ServiceToken_CrewWithoutPermission tests that crew without permission is rejected 600 588 func TestValidateBlobWriteAccess_ServiceToken_CrewWithoutPermission(t *testing.T) { 601 - pds, ctx := setupTestPDS(t) 602 - 603 589 ownerDID := "did:plc:owner123" 604 590 readerDID := "did:plc:reader123" 605 591 holdDID := "did:web:hold01.atcr.io" 606 592 607 - // Bootstrap 608 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 609 - if err != nil { 610 - t.Fatalf("Failed to bootstrap PDS: %v", err) 611 - } 593 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 612 594 613 595 // Add crew member with blob:read permission only (no blob:write) 614 - _, err = pds.AddCrewMember(ctx, readerDID, "reader", []string{"blob:read"}) 596 + _, err := pds.AddCrewMember(ctx, readerDID, "reader", []string{"blob:read"}) 615 597 if err != nil { 616 598 t.Fatalf("Failed to add crew member: %v", err) 617 599 } ··· 645 627 646 628 // TestValidateBlobWriteAccess_Owner tests that the hold owner has write access 647 629 func TestValidateBlobWriteAccess_Owner(t *testing.T) { 648 - pds, ctx := setupTestPDS(t) 649 - 650 630 ownerDID := "did:plc:owner123" 651 631 652 - // Bootstrap with owner 653 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 654 - if err != nil { 655 - t.Fatalf("Failed to bootstrap PDS: %v", err) 656 - } 632 + pds, _ := setupTestPDSWithBootstrap(t, ownerDID, true, false) 657 633 658 634 // Create DPoP helper for owner 659 635 dpopHelper, err := NewDPoPTestHelper(ownerDID, "https://test-pds.example.com") ··· 691 667 692 668 // TestValidateBlobWriteAccess_CrewPermissions tests crew permission checking 693 669 func TestValidateBlobWriteAccess_CrewPermissions(t *testing.T) { 694 - pds, ctx := setupTestPDS(t) 695 - 696 670 ownerDID := "did:plc:owner123" 697 671 698 - // Bootstrap 699 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 700 - if err != nil { 701 - t.Fatalf("Failed to bootstrap PDS: %v", err) 702 - } 672 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 703 673 704 674 // Add crew member with blob:write permission 705 675 writerDID := "did:plc:writer123" 706 - _, err = pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"}) 676 + _, err := pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"}) 707 677 if err != nil { 708 678 t.Fatalf("Failed to add crew member: %v", err) 709 679 } ··· 764 734 765 735 // TestValidateBlobReadAccess_PublicHold tests public hold access 766 736 func TestValidateBlobReadAccess_PublicHold(t *testing.T) { 767 - pds, ctx := setupTestPDS(t) 768 - 769 737 ownerDID := "did:plc:owner123" 770 738 771 - // Bootstrap with public=true 772 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 773 - if err != nil { 774 - t.Fatalf("Failed to bootstrap PDS: %v", err) 775 - } 739 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 776 740 777 741 // Verify captain record has public=true 778 742 _, captain, err := pds.GetCaptainRecord(ctx) ··· 801 765 802 766 // TestValidateBlobReadAccess_PrivateHold tests private hold access 803 767 func TestValidateBlobReadAccess_PrivateHold(t *testing.T) { 804 - pds, ctx := setupTestPDS(t) 805 - 806 768 ownerDID := "did:plc:owner123" 807 769 808 - // Bootstrap with public=false 809 - err := pds.Bootstrap(ctx, nil, ownerDID, false, false, "") 810 - if err != nil { 811 - t.Fatalf("Failed to bootstrap PDS: %v", err) 812 - } 770 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, false, false) 813 771 814 772 // Update captain to be private 815 - _, err = pds.UpdateCaptainRecord(ctx, false, false) 773 + _, err := pds.UpdateCaptainRecord(ctx, false, false) 816 774 if err != nil { 817 775 t.Fatalf("Failed to update captain record: %v", err) 818 776 } ··· 843 801 844 802 // TestValidateOwnerOrCrewAdmin tests admin permission checking 845 803 func TestValidateOwnerOrCrewAdmin(t *testing.T) { 846 - pds, ctx := setupTestPDS(t) 847 - 848 804 ownerDID := "did:plc:owner123" 849 805 850 - // Bootstrap 851 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 852 - if err != nil { 853 - t.Fatalf("Failed to bootstrap PDS: %v", err) 854 - } 806 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 855 807 856 808 // Add crew member with crew:admin permission 857 809 adminDID := "did:plc:admin123" 858 - _, err = pds.AddCrewMember(ctx, adminDID, "admin", []string{"crew:admin", "blob:write", "blob:read"}) 810 + _, err := pds.AddCrewMember(ctx, adminDID, "admin", []string{"crew:admin", "blob:write", "blob:read"}) 859 811 if err != nil { 860 812 t.Fatalf("Failed to add crew admin: %v", err) 861 813 } ··· 911 863 912 864 // TestCrewPermissions tests various permission combinations 913 865 func TestCrewPermissions(t *testing.T) { 914 - pds, ctx := setupTestPDS(t) 915 - 916 866 ownerDID := "did:plc:owner123" 917 867 918 - // Bootstrap 919 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 920 - if err != nil { 921 - t.Fatalf("Failed to bootstrap PDS: %v", err) 922 - } 868 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 923 869 924 870 tests := []struct { 925 871 name string ··· 1035 981 1036 982 for _, tt := range tests { 1037 983 t.Run(tt.name, func(t *testing.T) { 1038 - pds, ctx := setupTestPDS(t) 1039 - 1040 984 ownerDID := "did:plc:owner123" 1041 985 1042 - // Bootstrap with specified settings 1043 - err := pds.Bootstrap(ctx, nil, ownerDID, tt.public, tt.allowAllCrew, "") 1044 - if err != nil { 1045 - t.Fatalf("Failed to bootstrap PDS: %v", err) 1046 - } 986 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, tt.public, tt.allowAllCrew) 1047 987 1048 988 // Verify captain record has expected settings 1049 989 _, captain, err := pds.GetCaptainRecord(ctx)
+34 -1
pkg/hold/pds/captain_test.go
··· 3 3 import ( 4 4 "bytes" 5 5 "context" 6 + "io" 7 + "os" 6 8 "path/filepath" 7 9 "strings" 8 10 "testing" ··· 17 19 ctx := context.Background() 18 20 tmpDir := t.TempDir() 19 21 20 - dbPath := filepath.Join(tmpDir, "pds.db") 22 + // Use in-memory database for speed 23 + dbPath := ":memory:" 21 24 keyPath := filepath.Join(tmpDir, "signing-key") 22 25 26 + // Copy shared signing key instead of generating a new one 27 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 28 + t.Fatalf("Failed to copy shared signing key: %v", err) 29 + } 30 + 23 31 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath) 24 32 if err != nil { 25 33 t.Fatalf("Failed to create test PDS: %v", err) ··· 30 38 err = pds.repomgr.InitNewActor(ctx, pds.uid, "", pds.did, "", "", "") 31 39 if err != nil { 32 40 t.Fatalf("Failed to initialize test repo: %v", err) 41 + } 42 + 43 + return pds, ctx 44 + } 45 + 46 + // setupTestPDSWithBootstrap creates a test PDS and bootstraps it with suppressed output 47 + // This is a convenience function for tests that need a fully initialized PDS 48 + func setupTestPDSWithBootstrap(t *testing.T, ownerDID string, public, allowAllCrew bool) (*HoldPDS, context.Context) { 49 + t.Helper() 50 + 51 + pds, ctx := setupTestPDS(t) 52 + 53 + // Bootstrap with suppressed output 54 + oldStdout := os.Stdout 55 + r, w, _ := os.Pipe() 56 + os.Stdout = w 57 + 58 + err := pds.Bootstrap(ctx, nil, ownerDID, public, allowAllCrew, "") 59 + 60 + w.Close() 61 + os.Stdout = oldStdout 62 + io.ReadAll(r) // Drain the pipe 63 + 64 + if err != nil { 65 + t.Fatalf("Failed to bootstrap PDS: %v", err) 33 66 } 34 67 35 68 return pds, ctx
-232
pkg/hold/pds/database.go
··· 1 - package pds 2 - 3 - import ( 4 - "database/sql" 5 - "fmt" 6 - "os" 7 - "path/filepath" 8 - "time" 9 - 10 - _ "github.com/mattn/go-sqlite3" 11 - ) 12 - 13 - // Database manages app passwords and sessions for the hold PDS 14 - type Database struct { 15 - db *sql.DB 16 - } 17 - 18 - // NewDatabase creates or opens a database for app passwords and sessions 19 - // dbPath should be the directory path (same as carstore) 20 - // It creates a separate "auth.db" file for authentication data 21 - func NewDatabase(dbPath string) (*Database, error) { 22 - // Ensure directory exists 23 - if err := os.MkdirAll(dbPath, 0755); err != nil { 24 - return nil, fmt.Errorf("failed to create database directory: %w", err) 25 - } 26 - 27 - // Create auth database file alongside carstore database 28 - authDBFile := filepath.Join(dbPath, "auth.db") 29 - 30 - db, err := sql.Open("sqlite3", authDBFile) 31 - if err != nil { 32 - return nil, fmt.Errorf("failed to open database: %w", err) 33 - } 34 - 35 - // Create tables 36 - if err := createTables(db); err != nil { 37 - db.Close() 38 - return nil, fmt.Errorf("failed to create tables: %w", err) 39 - } 40 - 41 - return &Database{db: db}, nil 42 - } 43 - 44 - // createTables creates the database schema 45 - func createTables(db *sql.DB) error { 46 - schema := ` 47 - CREATE TABLE IF NOT EXISTS app_passwords ( 48 - id INTEGER PRIMARY KEY AUTOINCREMENT, 49 - name TEXT NOT NULL UNIQUE, 50 - password_hash TEXT NOT NULL, 51 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 52 - last_used_at TIMESTAMP 53 - ); 54 - 55 - CREATE INDEX IF NOT EXISTS idx_app_passwords_name ON app_passwords(name); 56 - 57 - CREATE TABLE IF NOT EXISTS refresh_tokens ( 58 - id INTEGER PRIMARY KEY AUTOINCREMENT, 59 - token_hash TEXT NOT NULL UNIQUE, 60 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 61 - expires_at TIMESTAMP NOT NULL, 62 - last_used_at TIMESTAMP 63 - ); 64 - 65 - CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash); 66 - CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires ON refresh_tokens(expires_at); 67 - ` 68 - 69 - _, err := db.Exec(schema) 70 - return err 71 - } 72 - 73 - // Close closes the database connection 74 - func (d *Database) Close() error { 75 - return d.db.Close() 76 - } 77 - 78 - // AppPassword represents an app password record 79 - type AppPassword struct { 80 - ID int64 81 - Name string 82 - PasswordHash string 83 - CreatedAt time.Time 84 - LastUsedAt *time.Time 85 - } 86 - 87 - // CreateAppPassword stores a new app password 88 - func (d *Database) CreateAppPassword(name, passwordHash string) error { 89 - query := `INSERT INTO app_passwords (name, password_hash) VALUES (?, ?)` 90 - _, err := d.db.Exec(query, name, passwordHash) 91 - if err != nil { 92 - return fmt.Errorf("failed to create app password: %w", err) 93 - } 94 - return nil 95 - } 96 - 97 - // GetAppPassword retrieves an app password by name 98 - func (d *Database) GetAppPassword(name string) (*AppPassword, error) { 99 - query := `SELECT id, name, password_hash, created_at, last_used_at FROM app_passwords WHERE name = ?` 100 - 101 - var ap AppPassword 102 - var lastUsedAt sql.NullTime 103 - 104 - err := d.db.QueryRow(query, name).Scan( 105 - &ap.ID, 106 - &ap.Name, 107 - &ap.PasswordHash, 108 - &ap.CreatedAt, 109 - &lastUsedAt, 110 - ) 111 - 112 - if err == sql.ErrNoRows { 113 - return nil, fmt.Errorf("app password not found") 114 - } 115 - if err != nil { 116 - return nil, fmt.Errorf("failed to get app password: %w", err) 117 - } 118 - 119 - if lastUsedAt.Valid { 120 - ap.LastUsedAt = &lastUsedAt.Time 121 - } 122 - 123 - return &ap, nil 124 - } 125 - 126 - // ListAppPasswords returns all app passwords (without hashes) 127 - func (d *Database) ListAppPasswords() ([]AppPassword, error) { 128 - query := `SELECT id, name, created_at, last_used_at FROM app_passwords ORDER BY created_at DESC` 129 - 130 - rows, err := d.db.Query(query) 131 - if err != nil { 132 - return nil, fmt.Errorf("failed to list app passwords: %w", err) 133 - } 134 - defer rows.Close() 135 - 136 - var passwords []AppPassword 137 - for rows.Next() { 138 - var ap AppPassword 139 - var lastUsedAt sql.NullTime 140 - 141 - if err := rows.Scan(&ap.ID, &ap.Name, &ap.CreatedAt, &lastUsedAt); err != nil { 142 - return nil, fmt.Errorf("failed to scan row: %w", err) 143 - } 144 - 145 - if lastUsedAt.Valid { 146 - ap.LastUsedAt = &lastUsedAt.Time 147 - } 148 - 149 - passwords = append(passwords, ap) 150 - } 151 - 152 - return passwords, rows.Err() 153 - } 154 - 155 - // UpdateLastUsed updates the last used timestamp for an app password 156 - func (d *Database) UpdateLastUsed(name string) error { 157 - query := `UPDATE app_passwords SET last_used_at = CURRENT_TIMESTAMP WHERE name = ?` 158 - _, err := d.db.Exec(query, name) 159 - return err 160 - } 161 - 162 - // DeleteAppPassword removes an app password 163 - func (d *Database) DeleteAppPassword(name string) error { 164 - query := `DELETE FROM app_passwords WHERE name = ?` 165 - result, err := d.db.Exec(query, name) 166 - if err != nil { 167 - return fmt.Errorf("failed to delete app password: %w", err) 168 - } 169 - 170 - rows, err := result.RowsAffected() 171 - if err != nil { 172 - return fmt.Errorf("failed to check rows affected: %w", err) 173 - } 174 - 175 - if rows == 0 { 176 - return fmt.Errorf("app password not found") 177 - } 178 - 179 - return nil 180 - } 181 - 182 - // CreateRefreshToken stores a refresh token 183 - func (d *Database) CreateRefreshToken(tokenHash string, expiresAt time.Time) error { 184 - query := `INSERT INTO refresh_tokens (token_hash, expires_at) VALUES (?, ?)` 185 - _, err := d.db.Exec(query, tokenHash, expiresAt) 186 - if err != nil { 187 - return fmt.Errorf("failed to create refresh token: %w", err) 188 - } 189 - return nil 190 - } 191 - 192 - // ValidateRefreshToken checks if a refresh token exists and is not expired 193 - func (d *Database) ValidateRefreshToken(tokenHash string) (bool, error) { 194 - query := `SELECT expires_at FROM refresh_tokens WHERE token_hash = ?` 195 - 196 - var expiresAt time.Time 197 - err := d.db.QueryRow(query, tokenHash).Scan(&expiresAt) 198 - 199 - if err == sql.ErrNoRows { 200 - return false, nil 201 - } 202 - if err != nil { 203 - return false, fmt.Errorf("failed to validate refresh token: %w", err) 204 - } 205 - 206 - // Check if expired 207 - if time.Now().After(expiresAt) { 208 - // Delete expired token 209 - d.DeleteRefreshToken(tokenHash) 210 - return false, nil 211 - } 212 - 213 - // Update last used 214 - updateQuery := `UPDATE refresh_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token_hash = ?` 215 - d.db.Exec(updateQuery, tokenHash) 216 - 217 - return true, nil 218 - } 219 - 220 - // DeleteRefreshToken removes a refresh token 221 - func (d *Database) DeleteRefreshToken(tokenHash string) error { 222 - query := `DELETE FROM refresh_tokens WHERE token_hash = ?` 223 - _, err := d.db.Exec(query, tokenHash) 224 - return err 225 - } 226 - 227 - // CleanupExpiredTokens removes all expired refresh tokens 228 - func (d *Database) CleanupExpiredTokens() error { 229 - query := `DELETE FROM refresh_tokens WHERE expires_at < CURRENT_TIMESTAMP` 230 - _, err := d.db.Exec(query) 231 - return err 232 - }
+71 -16
pkg/hold/pds/events.go
··· 8 8 "time" 9 9 10 10 atproto "github.com/bluesky-social/indigo/api/atproto" 11 + "github.com/bluesky-social/indigo/events" 11 12 lexutil "github.com/bluesky-social/indigo/lex/util" 12 13 "github.com/gorilla/websocket" 14 + "github.com/ipfs/go-cid" 13 15 ) 14 16 15 17 // EventBroadcaster manages WebSocket connections and broadcasts repo events ··· 77 79 b.mu.Unlock() 78 80 79 81 // Send historical events if cursor is provided and < current seq 80 - if cursor > 0 && cursor < currentSeq { 82 + // cursor=0 means "replay all events from the beginning" 83 + // cursor >= 0 triggers backfill, negative cursor means "no backfill" 84 + if cursor >= 0 && cursor < currentSeq { 81 85 go b.backfillSubscriber(sub, cursor) 82 86 } 83 87 ··· 205 209 }() 206 210 207 211 for event := range sub.send { 208 - // Encode as CBOR 209 - cborBytes, err := encodeCBOR(event) 212 + // Create event header (ATProto firehose format) 213 + header := events.EventHeader{ 214 + Op: events.EvtKindMessage, 215 + MsgType: "#commit", 216 + } 217 + 218 + // Get a writer for this message 219 + wc, err := sub.conn.NextWriter(websocket.BinaryMessage) 210 220 if err != nil { 211 - log.Printf("Failed to encode event as CBOR: %v", err) 212 - continue 221 + log.Printf("Failed to get websocket writer: %v", err) 222 + return 213 223 } 214 224 215 - // Write CBOR message to WebSocket 216 - err = sub.conn.WriteMessage(websocket.BinaryMessage, cborBytes) 217 - if err != nil { 218 - log.Printf("Failed to write to websocket: %v", err) 225 + // Write header as CBOR 226 + if err := header.MarshalCBOR(wc); err != nil { 227 + log.Printf("Failed to write event header: %v", err) 228 + wc.Close() 229 + return 230 + } 231 + 232 + // Convert our RepoCommitEvent to indigo's SyncSubscribeRepos_Commit 233 + indigoEvent := convertToIndigoCommit(event) 234 + 235 + // Write the event as CBOR 236 + var obj lexutil.CBOR = indigoEvent 237 + if err := obj.MarshalCBOR(wc); err != nil { 238 + log.Printf("Failed to write event body: %v", err) 239 + wc.Close() 240 + return 241 + } 242 + 243 + // Close the writer to flush the message 244 + if err := wc.Close(); err != nil { 245 + log.Printf("Failed to close websocket writer: %v", err) 219 246 return 220 247 } 221 248 ··· 224 251 } 225 252 } 226 253 227 - // encodeCBOR encodes an event as CBOR 254 + // convertToIndigoCommit converts our RepoCommitEvent to indigo's SyncSubscribeRepos_Commit 255 + // which has proper CBOR marshaling methods generated 256 + func convertToIndigoCommit(event *RepoCommitEvent) *atproto.SyncSubscribeRepos_Commit { 257 + // Parse commit CID string to cid.Cid, then convert to LexLink 258 + commitCID, err := cid.Decode(event.Commit) 259 + if err != nil { 260 + log.Printf("Warning: failed to parse commit CID %s: %v", event.Commit, err) 261 + // Create an empty CID as fallback 262 + commitCID = cid.Undef 263 + } 264 + 265 + // Convert cid.Cid to LexLink 266 + commitLink := lexutil.LexLink(commitCID) 267 + 268 + // Convert blocks to LexBytes 269 + blocks := lexutil.LexBytes(event.Blocks) 270 + 271 + return &atproto.SyncSubscribeRepos_Commit{ 272 + Seq: event.Seq, 273 + Repo: event.Repo, 274 + Commit: commitLink, 275 + Rev: event.Rev, 276 + Since: event.Since, 277 + Blocks: blocks, 278 + Ops: event.Ops, 279 + Time: event.Time, 280 + Blobs: []lexutil.LexLink{}, // Empty for now, we don't track blob refs in our simplified model 281 + Rebase: false, // DEPRECATED field 282 + TooBig: false, // Not implementing tooBig for now 283 + } 284 + } 285 + 286 + // encodeCBOR encodes an event as CBOR (DEPRECATED - kept for tests) 228 287 func encodeCBOR(event *RepoCommitEvent) ([]byte, error) { 229 - // For now, use JSON encoding wrapped in CBOR envelope 230 - // In production, you'd use proper CBOR encoding 231 - // The atproto spec requires DAG-CBOR with specific header 232 - 233 - // Simple approach: encode as JSON for MVP 234 - // Real implementation needs proper CBOR-gen types 288 + // For backward compatibility with tests, encode as JSON 289 + // Production code uses convertToIndigoCommit + CBOR marshaling in handleSubscriber 235 290 return json.Marshal(event) 236 291 } 237 292
+164
pkg/hold/pds/events_test.go
··· 382 382 t.Errorf("Expected decoded seq=1, got %d", decoded.Seq) 383 383 } 384 384 } 385 + 386 + // TestSubscribe_CursorZeroBackfill tests that cursor=0 replays all events 387 + func TestSubscribe_CursorZeroBackfill(t *testing.T) { 388 + broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100) 389 + ctx := context.Background() 390 + 391 + testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke") 392 + 393 + // Broadcast 5 events before subscribing 394 + for i := 1; i <= 5; i++ { 395 + event := &RepoEvent{ 396 + NewRoot: testCID, 397 + Rev: "test-rev", 398 + RepoSlice: []byte("test CAR data"), 399 + Ops: []RepoOp{}, 400 + } 401 + broadcaster.Broadcast(ctx, event) 402 + } 403 + 404 + // Verify we have 5 events in history 405 + if broadcaster.eventSeq != 5 { 406 + t.Fatalf("Expected eventSeq=5, got %d", broadcaster.eventSeq) 407 + } 408 + 409 + // Create mock websocket connection (we won't actually use it) 410 + // We just need to verify backfillSubscriber is called 411 + // For this test, we'll check the history directly 412 + if len(broadcaster.eventHistory) != 5 { 413 + t.Errorf("Expected 5 events in history, got %d", len(broadcaster.eventHistory)) 414 + } 415 + 416 + // Verify all events have sequential sequence numbers 417 + for i, he := range broadcaster.eventHistory { 418 + expectedSeq := int64(i + 1) 419 + if he.Seq != expectedSeq { 420 + t.Errorf("Expected history[%d].Seq=%d, got %d", i, expectedSeq, he.Seq) 421 + } 422 + } 423 + 424 + // Test backfillSubscriber directly with cursor=0 425 + // Create a subscriber manually (conn not needed for backfill test) 426 + sub := &Subscriber{ 427 + conn: nil, // Not used in backfillSubscriber 428 + send: make(chan *RepoCommitEvent, 100), // Large buffer for testing 429 + cursor: 0, 430 + } 431 + 432 + // Run backfill in a goroutine 433 + go broadcaster.backfillSubscriber(sub, 0) 434 + 435 + // Wait for events to be sent 436 + time.Sleep(100 * time.Millisecond) 437 + 438 + // Should receive all 5 events 439 + receivedCount := len(sub.send) 440 + if receivedCount != 5 { 441 + t.Errorf("Expected to receive 5 events with cursor=0, got %d", receivedCount) 442 + } 443 + 444 + // Verify events are in order 445 + for i := 1; i <= 5; i++ { 446 + select { 447 + case event := <-sub.send: 448 + if event.Seq != int64(i) { 449 + t.Errorf("Expected event seq=%d, got %d", i, event.Seq) 450 + } 451 + default: 452 + t.Errorf("Expected event %d but channel was empty", i) 453 + } 454 + } 455 + } 456 + 457 + // TestSubscribe_MidCursorBackfill tests that cursor=N only gets events after N 458 + func TestSubscribe_MidCursorBackfill(t *testing.T) { 459 + broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100) 460 + ctx := context.Background() 461 + 462 + testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke") 463 + 464 + // Broadcast 10 events before subscribing 465 + for i := 1; i <= 10; i++ { 466 + event := &RepoEvent{ 467 + NewRoot: testCID, 468 + Rev: "test-rev", 469 + RepoSlice: []byte("test CAR data"), 470 + Ops: []RepoOp{}, 471 + } 472 + broadcaster.Broadcast(ctx, event) 473 + } 474 + 475 + // Test backfillSubscriber with cursor=5 (conn not needed for backfill test) 476 + sub := &Subscriber{ 477 + conn: nil, // Not used in backfillSubscriber 478 + send: make(chan *RepoCommitEvent, 100), // Large buffer for testing 479 + cursor: 5, 480 + } 481 + 482 + // Run backfill 483 + go broadcaster.backfillSubscriber(sub, 5) 484 + 485 + // Wait for events to be sent 486 + time.Sleep(100 * time.Millisecond) 487 + 488 + // Should receive events 6-10 (5 events after cursor=5) 489 + receivedCount := len(sub.send) 490 + if receivedCount != 5 { 491 + t.Errorf("Expected to receive 5 events with cursor=5, got %d", receivedCount) 492 + } 493 + 494 + // Verify events start at seq=6 495 + for i := 6; i <= 10; i++ { 496 + select { 497 + case event := <-sub.send: 498 + if event.Seq != int64(i) { 499 + t.Errorf("Expected event seq=%d, got %d", i, event.Seq) 500 + } 501 + default: 502 + t.Errorf("Expected event %d but channel was empty", i) 503 + } 504 + } 505 + } 506 + 507 + // TestSubscribe_NegativeCursorNoBackfill tests that negative cursor means no backfill 508 + func TestSubscribe_NegativeCursorNoBackfill(t *testing.T) { 509 + broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100) 510 + ctx := context.Background() 511 + 512 + testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke") 513 + 514 + // Broadcast 5 events before subscribing 515 + for i := 1; i <= 5; i++ { 516 + event := &RepoEvent{ 517 + NewRoot: testCID, 518 + Rev: "test-rev", 519 + RepoSlice: []byte("test CAR data"), 520 + Ops: []RepoOp{}, 521 + } 522 + broadcaster.Broadcast(ctx, event) 523 + } 524 + 525 + // Create subscriber with cursor=-1 (no backfill, conn not needed) 526 + sub := &Subscriber{ 527 + conn: nil, // Not used in this test 528 + send: make(chan *RepoCommitEvent, 100), 529 + cursor: -1, 530 + } 531 + 532 + // Subscribe should not trigger backfill 533 + broadcaster.mu.Lock() 534 + currentSeq := broadcaster.eventSeq 535 + broadcaster.mu.Unlock() 536 + 537 + // Check the condition: cursor >= 0 && cursor < currentSeq 538 + // For cursor=-1, this should be false 539 + shouldBackfill := -1 >= 0 && -1 < currentSeq 540 + if shouldBackfill { 541 + t.Error("Expected shouldBackfill=false for cursor=-1, but condition evaluated to true") 542 + } 543 + 544 + // Verify no events in send channel (no backfill happened) 545 + if len(sub.send) != 0 { 546 + t.Errorf("Expected 0 events with cursor=-1 (no backfill), got %d", len(sub.send)) 547 + } 548 + }
-249
pkg/hold/pds/jwt.go
··· 1 - package pds 2 - 3 - import ( 4 - "crypto/sha256" 5 - "encoding/base64" 6 - "encoding/hex" 7 - "encoding/json" 8 - "fmt" 9 - "time" 10 - ) 11 - 12 - // Session token types 13 - const ( 14 - TokenTypeAccess = "access" 15 - TokenTypeRefresh = "refresh" 16 - ) 17 - 18 - // Token expiration durations 19 - const ( 20 - AccessTokenDuration = 2 * time.Hour // Short-lived access token 21 - RefreshTokenDuration = 90 * 24 * time.Hour // Long-lived refresh token (90 days) 22 - ) 23 - 24 - // SessionClaims represents JWT claims for ATProto sessions 25 - type SessionClaims struct { 26 - DID string `json:"sub"` // Subject (DID) 27 - Issuer string `json:"iss"` // Issuer (PDS DID) 28 - Handle string `json:"handle,omitempty"` 29 - Scope string `json:"scope"` 30 - TokenType string `json:"token_type"` 31 - IssuedAt int64 `json:"iat"` // Unix timestamp 32 - ExpiresAt int64 `json:"exp"` // Unix timestamp 33 - } 34 - 35 - // IssueAccessToken creates a new access JWT for a session 36 - func (p *HoldPDS) IssueAccessToken(did, handle string) (string, error) { 37 - now := time.Now() 38 - claims := &SessionClaims{ 39 - DID: did, 40 - Issuer: p.did, 41 - Handle: handle, 42 - Scope: "com.atproto.access", 43 - TokenType: TokenTypeAccess, 44 - IssuedAt: now.Unix(), 45 - ExpiresAt: now.Add(AccessTokenDuration).Unix(), 46 - } 47 - 48 - return p.signJWT(claims) 49 - } 50 - 51 - // IssueRefreshToken creates a new refresh JWT for a session 52 - func (p *HoldPDS) IssueRefreshToken(did, handle string) (string, error) { 53 - now := time.Now() 54 - claims := &SessionClaims{ 55 - DID: did, 56 - Issuer: p.did, 57 - Handle: handle, 58 - Scope: "com.atproto.refresh", 59 - TokenType: TokenTypeRefresh, 60 - IssuedAt: now.Unix(), 61 - ExpiresAt: now.Add(RefreshTokenDuration).Unix(), 62 - } 63 - 64 - signedToken, err := p.signJWT(claims) 65 - if err != nil { 66 - return "", err 67 - } 68 - 69 - // Store refresh token hash in database for validation/revocation 70 - tokenHash := hashToken(signedToken) 71 - expiresAt := now.Add(RefreshTokenDuration) 72 - if err := p.authDB.CreateRefreshToken(tokenHash, expiresAt); err != nil { 73 - return "", fmt.Errorf("failed to store refresh token: %w", err) 74 - } 75 - 76 - return signedToken, nil 77 - } 78 - 79 - // ValidateAccessToken validates an access JWT and returns the claims 80 - func (p *HoldPDS) ValidateAccessToken(tokenString string) (*SessionClaims, error) { 81 - return p.validateToken(tokenString, TokenTypeAccess) 82 - } 83 - 84 - // ValidateRefreshToken validates a refresh JWT and returns the claims 85 - // Also checks the database to ensure the token hasn't been revoked 86 - func (p *HoldPDS) ValidateRefreshToken(tokenString string) (*SessionClaims, error) { 87 - // First validate signature and claims 88 - claims, err := p.validateToken(tokenString, TokenTypeRefresh) 89 - if err != nil { 90 - return nil, err 91 - } 92 - 93 - // Check if token is in database (not revoked) 94 - tokenHash := hashToken(tokenString) 95 - valid, err := p.authDB.ValidateRefreshToken(tokenHash) 96 - if err != nil { 97 - return nil, fmt.Errorf("failed to validate refresh token in database: %w", err) 98 - } 99 - if !valid { 100 - return nil, fmt.Errorf("refresh token has been revoked or expired") 101 - } 102 - 103 - return claims, nil 104 - } 105 - 106 - // validateToken validates a JWT token and returns the claims 107 - func (p *HoldPDS) validateToken(tokenString, expectedType string) (*SessionClaims, error) { 108 - // Split token into parts 109 - parts := splitJWT(tokenString) 110 - if len(parts) != 3 { 111 - return nil, fmt.Errorf("invalid JWT format") 112 - } 113 - 114 - // Decode header 115 - headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) 116 - if err != nil { 117 - return nil, fmt.Errorf("failed to decode header: %w", err) 118 - } 119 - 120 - var header map[string]interface{} 121 - if err := json.Unmarshal(headerBytes, &header); err != nil { 122 - return nil, fmt.Errorf("failed to parse header: %w", err) 123 - } 124 - 125 - // Verify algorithm 126 - alg, ok := header["alg"].(string) 127 - if !ok || alg != "ES256K" { 128 - return nil, fmt.Errorf("unsupported algorithm: %v", alg) 129 - } 130 - 131 - // Decode claims 132 - claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) 133 - if err != nil { 134 - return nil, fmt.Errorf("failed to decode claims: %w", err) 135 - } 136 - 137 - var claims SessionClaims 138 - if err := json.Unmarshal(claimsBytes, &claims); err != nil { 139 - return nil, fmt.Errorf("failed to parse claims: %w", err) 140 - } 141 - 142 - // Verify token type 143 - if claims.TokenType != expectedType { 144 - return nil, fmt.Errorf("invalid token type: expected %s, got %s", expectedType, claims.TokenType) 145 - } 146 - 147 - // Verify issuer 148 - if claims.Issuer != p.did { 149 - return nil, fmt.Errorf("invalid issuer: expected %s, got %s", p.did, claims.Issuer) 150 - } 151 - 152 - // Verify subject matches this hold 153 - if claims.DID != p.did { 154 - return nil, fmt.Errorf("invalid subject: expected %s, got %s", p.did, claims.DID) 155 - } 156 - 157 - // Verify expiration 158 - if time.Now().Unix() > claims.ExpiresAt { 159 - return nil, fmt.Errorf("token has expired") 160 - } 161 - 162 - // Verify signature 163 - signedData := []byte(parts[0] + "." + parts[1]) 164 - signature, err := base64.RawURLEncoding.DecodeString(parts[2]) 165 - if err != nil { 166 - return nil, fmt.Errorf("failed to decode signature: %w", err) 167 - } 168 - 169 - publicKey, err := p.signingKey.PublicKey() 170 - if err != nil { 171 - return nil, fmt.Errorf("failed to get public key: %w", err) 172 - } 173 - 174 - if err := publicKey.HashAndVerify(signedData, signature); err != nil { 175 - return nil, fmt.Errorf("signature verification failed: %w", err) 176 - } 177 - 178 - return &claims, nil 179 - } 180 - 181 - // RevokeRefreshToken revokes a refresh token by removing it from the database 182 - func (p *HoldPDS) RevokeRefreshToken(tokenString string) error { 183 - tokenHash := hashToken(tokenString) 184 - return p.authDB.DeleteRefreshToken(tokenHash) 185 - } 186 - 187 - // hashToken creates a SHA-256 hash of a token for storage 188 - func hashToken(token string) string { 189 - hash := sha256.Sum256([]byte(token)) 190 - return hex.EncodeToString(hash[:]) 191 - } 192 - 193 - // signJWT creates and signs a JWT using the hold's private key 194 - func (p *HoldPDS) signJWT(claims *SessionClaims) (string, error) { 195 - // Create header 196 - header := map[string]interface{}{ 197 - "typ": "JWT", 198 - "alg": "ES256K", 199 - } 200 - 201 - headerJSON, err := json.Marshal(header) 202 - if err != nil { 203 - return "", fmt.Errorf("failed to marshal header: %w", err) 204 - } 205 - 206 - // Create payload 207 - payloadJSON, err := json.Marshal(claims) 208 - if err != nil { 209 - return "", fmt.Errorf("failed to marshal claims: %w", err) 210 - } 211 - 212 - // Base64url encode header and payload 213 - headerEncoded := base64.RawURLEncoding.EncodeToString(headerJSON) 214 - payloadEncoded := base64.RawURLEncoding.EncodeToString(payloadJSON) 215 - 216 - // Create signing input 217 - signingInput := headerEncoded + "." + payloadEncoded 218 - 219 - // Sign with private key 220 - signature, err := p.signingKey.HashAndSign([]byte(signingInput)) 221 - if err != nil { 222 - return "", fmt.Errorf("failed to sign JWT: %w", err) 223 - } 224 - 225 - // Base64url encode signature 226 - signatureEncoded := base64.RawURLEncoding.EncodeToString(signature) 227 - 228 - // Combine into final JWT 229 - jwt := signingInput + "." + signatureEncoded 230 - 231 - return jwt, nil 232 - } 233 - 234 - // splitJWT splits a JWT string into its three parts 235 - func splitJWT(token string) []string { 236 - // JWT format: header.payload.signature 237 - parts := make([]string, 0, 3) 238 - start := 0 239 - for i := 0; i < len(token); i++ { 240 - if token[i] == '.' { 241 - parts = append(parts, token[start:i]) 242 - start = i + 1 243 - } 244 - } 245 - if start < len(token) { 246 - parts = append(parts, token[start:]) 247 - } 248 - return parts 249 - }
+21 -67
pkg/hold/pds/server.go
··· 36 36 dbPath string 37 37 uid models.Uid 38 38 signingKey *atcrypto.PrivateKeyK256 39 - authDB *Database // Authentication database for app passwords and sessions 40 39 } 41 40 42 41 // NewHoldPDS creates or opens a hold PDS with SQLite carstore 43 42 func NewHoldPDS(ctx context.Context, did, publicURL, dbPath, keyPath string) (*HoldPDS, error) { 44 - // Ensure directory exists 45 - dir := filepath.Dir(dbPath) 46 - if err := os.MkdirAll(dir, 0755); err != nil { 47 - return nil, fmt.Errorf("failed to create database directory: %w", err) 48 - } 49 - 50 43 // Generate or load signing key 51 44 signingKey, err := GenerateOrLoadKey(keyPath) 52 45 if err != nil { 53 46 return nil, fmt.Errorf("failed to initialize signing key: %w", err) 54 47 } 55 48 56 - // Create and open SQLite-backed carstore 57 - // dbPath is the directory, carstore creates and opens db.sqlite3 inside it 58 - sqlStore, err := carstore.NewSqliteStore(dbPath) 59 - if err != nil { 60 - return nil, fmt.Errorf("failed to create sqlite store: %w", err) 49 + // Create SQLite-backed carstore 50 + var sqlStore *carstore.SQLiteStore 51 + 52 + if dbPath == ":memory:" { 53 + // In-memory mode for tests: create carstore manually and open with :memory: 54 + sqlStore = new(carstore.SQLiteStore) 55 + if err := sqlStore.Open(":memory:"); err != nil { 56 + return nil, fmt.Errorf("failed to open in-memory sqlite store: %w", err) 57 + } 58 + } else { 59 + // File mode for production: create directory and use NewSqliteStore 60 + dir := filepath.Dir(dbPath) 61 + if err := os.MkdirAll(dir, 0755); err != nil { 62 + return nil, fmt.Errorf("failed to create database directory: %w", err) 63 + } 64 + 65 + // dbPath is the directory, carstore creates and opens db.sqlite3 inside it 66 + sqlStore, err = carstore.NewSqliteStore(dbPath) 67 + if err != nil { 68 + return nil, fmt.Errorf("failed to create sqlite store: %w", err) 69 + } 61 70 } 62 71 63 72 // Use SQLiteStore directly, not the CarStore() wrapper ··· 84 93 fmt.Printf("New hold repo - will be initialized in Bootstrap\n") 85 94 } 86 95 87 - // Create or open authentication database 88 - authDB, err := NewDatabase(dbPath) 89 - if err != nil { 90 - return nil, fmt.Errorf("failed to create auth database: %w", err) 91 - } 92 - 93 96 return &HoldPDS{ 94 97 did: did, 95 98 PublicURL: publicURL, ··· 98 101 dbPath: dbPath, 99 102 uid: uid, 100 103 signingKey: signingKey, 101 - authDB: authDB, 102 104 }, nil 103 105 } 104 106 ··· 184 186 } else { 185 187 fmt.Printf("✅ Bluesky profile record already exists, skipping\n") 186 188 } 187 - 188 - // Create Tangled profile record (idempotent - check if exists first) 189 - _, _, err = p.GetTangledProfileRecord(ctx) 190 - if err != nil { 191 - // Tangled profile doesn't exist, create it 192 - description := "ahoy from the cargo hold" 193 - links := []string{"https://atcr.io"} 194 - 195 - _, err = p.CreateTangledProfileRecord(ctx, links, description) 196 - if err != nil { 197 - return fmt.Errorf("failed to create tangled profile record: %w", err) 198 - } 199 - fmt.Printf("✅ Created Tangled profile record\n") 200 - } else { 201 - fmt.Printf("✅ Tangled profile record already exists, skipping\n") 202 - } 203 - } 204 - 205 - // Create bootstrap app password if none exist (one-time setup) 206 - passwords, err := p.authDB.ListAppPasswords() 207 - if err != nil { 208 - return fmt.Errorf("failed to list app passwords: %w", err) 209 - } 210 - 211 - if len(passwords) == 0 { 212 - // No app passwords exist, create one 213 - password, err := p.CreateAppPassword("bootstrap") 214 - if err != nil { 215 - return fmt.Errorf("failed to create bootstrap app password: %w", err) 216 - } 217 - 218 - fmt.Printf("\n") 219 - fmt.Printf("╔════════════════════════════════════════════════════════════════╗\n") 220 - fmt.Printf("║ 🔑 APP PASSWORD CREATED ║\n") 221 - fmt.Printf("╠════════════════════════════════════════════════════════════════╣\n") 222 - fmt.Printf("║ ║\n") 223 - fmt.Printf("║ Password: %-51s ║\n", password) 224 - fmt.Printf("║ ║\n") 225 - fmt.Printf("║ ⚠️ SAVE THIS PASSWORD - it will not be shown again ║\n") 226 - fmt.Printf("║ ║\n") 227 - fmt.Printf("║ Use this password to log into Bluesky app or CLI tools ║\n") 228 - fmt.Printf("║ PDS URL: %-51s ║\n", p.PublicURL) 229 - fmt.Printf("║ Username: %-50s ║\n", p.did) 230 - fmt.Printf("║ ║\n") 231 - fmt.Printf("╚════════════════════════════════════════════════════════════════╝\n") 232 - fmt.Printf("\n") 233 - } else { 234 - fmt.Printf("✅ App passwords already exist (count: %d), skipping auto-generation\n", len(passwords)) 235 189 } 236 190 237 191 return nil
-381
pkg/hold/pds/session.go
··· 1 - package pds 2 - 3 - import ( 4 - "encoding/json" 5 - "fmt" 6 - "io" 7 - "net/http" 8 - "strings" 9 - 10 - cbg "github.com/whyrusleeping/cbor-gen" 11 - "github.com/ipfs/go-cid" 12 - ) 13 - 14 - // CreateSessionRequest represents a session creation request 15 - type CreateSessionRequest struct { 16 - Identifier string `json:"identifier"` // DID or handle 17 - Password string `json:"password"` // App password 18 - } 19 - 20 - // CreateSessionResponse represents a successful session creation 21 - type CreateSessionResponse struct { 22 - AccessJwt string `json:"accessJwt"` 23 - RefreshJwt string `json:"refreshJwt"` 24 - Handle string `json:"handle"` 25 - DID string `json:"did"` 26 - DIDDoc map[string]interface{} `json:"didDoc,omitempty"` // Optional DID document 27 - Email string `json:"email,omitempty"` // Optional, not used for holds 28 - Active *bool `json:"active,omitempty"` // Optional account status 29 - Status string `json:"status,omitempty"` // Optional account status 30 - } 31 - 32 - // SessionInfo represents session information 33 - type SessionInfo struct { 34 - Handle string `json:"handle"` 35 - DID string `json:"did"` 36 - Email string `json:"email,omitempty"` 37 - } 38 - 39 - // HandleCreateSession handles com.atproto.server.createSession 40 - func (h *XRPCHandler) HandleCreateSession(w http.ResponseWriter, r *http.Request) { 41 - // Parse request 42 - var req CreateSessionRequest 43 - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 44 - http.Error(w, fmt.Sprintf("invalid request body: %v", err), http.StatusBadRequest) 45 - return 46 - } 47 - 48 - // Validate required fields 49 - if req.Identifier == "" || req.Password == "" { 50 - http.Error(w, "identifier and password are required", http.StatusBadRequest) 51 - return 52 - } 53 - 54 - // Validate identifier matches this hold's DID or handle 55 - holdDID := h.pds.DID() 56 - // For did:web, handle is the domain part without "did:web:" prefix 57 - // e.g., "did:web:hold01.atcr.io" -> "hold01.atcr.io" 58 - holdHandle := strings.TrimPrefix(holdDID, "did:web:") 59 - 60 - // Normalize the identifier (strip "at://" prefix if present) 61 - identifier := strings.TrimPrefix(req.Identifier, "at://") 62 - 63 - // Accept any of: 64 - // 1. Full DID: "did:web:hold01.atcr.io" 65 - // 2. Handle (domain): "hold01.atcr.io" 66 - // 3. Either with "at://" prefix 67 - isValidIdentifier := identifier == holdDID || identifier == holdHandle 68 - 69 - if !isValidIdentifier { 70 - fmt.Printf("Invalid identifier: got %q, expected DID %q or handle %q\n", req.Identifier, holdDID, holdHandle) 71 - http.Error(w, "invalid identifier", http.StatusUnauthorized) 72 - return 73 - } 74 - 75 - // Validate app password 76 - _, err := h.pds.ValidateAnyAppPassword(req.Password) 77 - if err != nil { 78 - http.Error(w, "invalid password", http.StatusUnauthorized) 79 - return 80 - } 81 - 82 - // Issue access and refresh tokens 83 - accessToken, err := h.pds.IssueAccessToken(holdDID, holdHandle) 84 - if err != nil { 85 - http.Error(w, fmt.Sprintf("failed to issue access token: %v", err), http.StatusInternalServerError) 86 - return 87 - } 88 - 89 - refreshToken, err := h.pds.IssueRefreshToken(holdDID, holdHandle) 90 - if err != nil { 91 - http.Error(w, fmt.Sprintf("failed to issue refresh token: %v", err), http.StatusInternalServerError) 92 - return 93 - } 94 - 95 - // Return session response 96 - active := true 97 - response := CreateSessionResponse{ 98 - AccessJwt: accessToken, 99 - RefreshJwt: refreshToken, 100 - Handle: holdHandle, 101 - DID: holdDID, 102 - Active: &active, // Account is active 103 - } 104 - 105 - w.Header().Set("Content-Type", "application/json") 106 - json.NewEncoder(w).Encode(response) 107 - } 108 - 109 - // HandleRefreshSession handles com.atproto.server.refreshSession 110 - func (h *XRPCHandler) HandleRefreshSession(w http.ResponseWriter, r *http.Request) { 111 - // Extract refresh token from Authorization header 112 - authHeader := r.Header.Get("Authorization") 113 - if authHeader == "" { 114 - http.Error(w, "authorization header required", http.StatusUnauthorized) 115 - return 116 - } 117 - 118 - // Remove "Bearer " prefix 119 - refreshToken := strings.TrimPrefix(authHeader, "Bearer ") 120 - if refreshToken == authHeader { 121 - http.Error(w, "invalid authorization header format", http.StatusUnauthorized) 122 - return 123 - } 124 - 125 - // Validate refresh token 126 - claims, err := h.pds.ValidateRefreshToken(refreshToken) 127 - if err != nil { 128 - http.Error(w, fmt.Sprintf("invalid refresh token: %v", err), http.StatusUnauthorized) 129 - return 130 - } 131 - 132 - // Issue new access token (and optionally new refresh token) 133 - accessToken, err := h.pds.IssueAccessToken(claims.DID, claims.Handle) 134 - if err != nil { 135 - http.Error(w, fmt.Sprintf("failed to issue access token: %v", err), http.StatusInternalServerError) 136 - return 137 - } 138 - 139 - // Issue new refresh token (rotate refresh tokens for security) 140 - newRefreshToken, err := h.pds.IssueRefreshToken(claims.DID, claims.Handle) 141 - if err != nil { 142 - http.Error(w, fmt.Sprintf("failed to issue refresh token: %v", err), http.StatusInternalServerError) 143 - return 144 - } 145 - 146 - // Revoke old refresh token 147 - if err := h.pds.RevokeRefreshToken(refreshToken); err != nil { 148 - // Log but don't fail - new tokens are already issued 149 - fmt.Printf("Warning: failed to revoke old refresh token: %v\n", err) 150 - } 151 - 152 - // Return new tokens 153 - active := true 154 - response := CreateSessionResponse{ 155 - AccessJwt: accessToken, 156 - RefreshJwt: newRefreshToken, 157 - Handle: claims.Handle, 158 - DID: claims.DID, 159 - Active: &active, // Account is active 160 - } 161 - 162 - w.Header().Set("Content-Type", "application/json") 163 - json.NewEncoder(w).Encode(response) 164 - } 165 - 166 - // HandleGetSession handles com.atproto.server.getSession 167 - func (h *XRPCHandler) HandleGetSession(w http.ResponseWriter, r *http.Request) { 168 - // Extract access token from Authorization header 169 - authHeader := r.Header.Get("Authorization") 170 - if authHeader == "" { 171 - http.Error(w, "authorization header required", http.StatusUnauthorized) 172 - return 173 - } 174 - 175 - // Remove "Bearer " prefix 176 - accessToken := strings.TrimPrefix(authHeader, "Bearer ") 177 - if accessToken == authHeader { 178 - http.Error(w, "invalid authorization header format", http.StatusUnauthorized) 179 - return 180 - } 181 - 182 - // Validate access token 183 - claims, err := h.pds.ValidateAccessToken(accessToken) 184 - if err != nil { 185 - http.Error(w, fmt.Sprintf("invalid access token: %v", err), http.StatusUnauthorized) 186 - return 187 - } 188 - 189 - // Return session info 190 - response := SessionInfo{ 191 - Handle: claims.Handle, 192 - DID: claims.DID, 193 - } 194 - 195 - w.Header().Set("Content-Type", "application/json") 196 - json.NewEncoder(w).Encode(response) 197 - } 198 - 199 - // CreateRecordRequest represents a record creation request 200 - type CreateRecordRequest struct { 201 - Repo string `json:"repo"` // DID of the repository 202 - Collection string `json:"collection"` // Collection name (e.g., "app.bsky.feed.post") 203 - Rkey string `json:"rkey,omitempty"` // Optional record key (TID generated if not provided) 204 - Validate *bool `json:"validate,omitempty"` // Optional validation flag 205 - Record interface{} `json:"record"` // The record value (JSON object) 206 - } 207 - 208 - // CreateRecordResponse represents a successful record creation 209 - type CreateRecordResponse struct { 210 - URI string `json:"uri"` // at://did/collection/rkey 211 - CID string `json:"cid"` // Record CID 212 - } 213 - 214 - // RawRecord wraps a record value and implements CBORMarshaler 215 - // This allows us to accept any JSON record and marshal it to CBOR 216 - type RawRecord struct { 217 - Value map[string]interface{} 218 - } 219 - 220 - // MarshalCBOR implements CBORMarshaler for RawRecord 221 - func (r *RawRecord) MarshalCBOR(w io.Writer) error { 222 - // Write CBOR map header 223 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajMap, uint64(len(r.Value))); err != nil { 224 - return err 225 - } 226 - 227 - // Write each key-value pair 228 - for key, val := range r.Value { 229 - // Write key as text string 230 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(key))); err != nil { 231 - return err 232 - } 233 - if _, err := w.Write([]byte(key)); err != nil { 234 - return err 235 - } 236 - 237 - // Write value (simplified - handles common types) 238 - if err := writeValue(w, val); err != nil { 239 - return err 240 - } 241 - } 242 - 243 - return nil 244 - } 245 - 246 - // writeValue writes a value to CBOR (helper for RawRecord) 247 - func writeValue(w io.Writer, val interface{}) error { 248 - switch v := val.(type) { 249 - case string: 250 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(v))); err != nil { 251 - return err 252 - } 253 - _, err := w.Write([]byte(v)) 254 - return err 255 - case int64: 256 - return cbg.CborWriteHeader(w, cbg.MajUnsignedInt, uint64(v)) 257 - case float64: 258 - // Write as unsigned int for now (simplified) 259 - return cbg.CborWriteHeader(w, cbg.MajUnsignedInt, uint64(v)) 260 - case bool: 261 - if v { 262 - return cbg.WriteBool(w, true) 263 - } 264 - return cbg.WriteBool(w, false) 265 - case map[string]interface{}: 266 - rec := &RawRecord{Value: v} 267 - return rec.MarshalCBOR(w) 268 - case []interface{}: 269 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(v))); err != nil { 270 - return err 271 - } 272 - for _, item := range v { 273 - if err := writeValue(w, item); err != nil { 274 - return err 275 - } 276 - } 277 - return nil 278 - default: 279 - // For unknown types, convert to JSON then write as string 280 - jsonBytes, err := json.Marshal(v) 281 - if err != nil { 282 - return err 283 - } 284 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(jsonBytes))); err != nil { 285 - return err 286 - } 287 - _, err = w.Write(jsonBytes) 288 - return err 289 - } 290 - } 291 - 292 - // HandleCreateRecord handles com.atproto.repo.createRecord 293 - func (h *XRPCHandler) HandleCreateRecord(w http.ResponseWriter, r *http.Request) { 294 - // Validate JWT authentication 295 - user, err := ValidateJWTAuth(r, h.pds) 296 - if err != nil { 297 - http.Error(w, fmt.Sprintf("authentication required: %v", err), http.StatusUnauthorized) 298 - return 299 - } 300 - 301 - // Parse request 302 - var req CreateRecordRequest 303 - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 304 - http.Error(w, fmt.Sprintf("invalid request body: %v", err), http.StatusBadRequest) 305 - return 306 - } 307 - 308 - // Validate required fields 309 - if req.Repo == "" || req.Collection == "" || req.Record == nil { 310 - http.Error(w, "repo, collection, and record are required", http.StatusBadRequest) 311 - return 312 - } 313 - 314 - // Verify repo matches authenticated user 315 - if req.Repo != user.DID { 316 - http.Error(w, "repo must match authenticated user DID", http.StatusForbidden) 317 - return 318 - } 319 - 320 - // Verify repo matches this hold's DID 321 - if req.Repo != h.pds.DID() { 322 - http.Error(w, "invalid repo (must be this hold's DID)", http.StatusBadRequest) 323 - return 324 - } 325 - 326 - // Convert record from JSON to CBOR-marshalable format 327 - recordMap, ok := req.Record.(map[string]interface{}) 328 - if !ok { 329 - http.Error(w, "record must be a JSON object", http.StatusBadRequest) 330 - return 331 - } 332 - 333 - // Wrap in RawRecord which implements CBORMarshaler 334 - recordValue := &RawRecord{Value: recordMap} 335 - 336 - // Create record using repomgr 337 - var recordPath string 338 - var recordCID cid.Cid 339 - 340 - if req.Rkey != "" { 341 - // Use PutRecord if rkey is specified 342 - recordPath, recordCID, err = h.pds.repomgr.PutRecord( 343 - r.Context(), 344 - h.pds.uid, 345 - req.Collection, 346 - req.Rkey, 347 - recordValue, 348 - ) 349 - } else { 350 - // Use CreateRecord if no rkey (auto-generates TID) 351 - recordPath, recordCID, err = h.pds.repomgr.CreateRecord( 352 - r.Context(), 353 - h.pds.uid, 354 - req.Collection, 355 - recordValue, 356 - ) 357 - } 358 - 359 - if err != nil { 360 - http.Error(w, fmt.Sprintf("failed to create record: %v", err), http.StatusInternalServerError) 361 - return 362 - } 363 - 364 - // Extract rkey from path (format: "collection/rkey") 365 - parts := strings.Split(recordPath, "/") 366 - if len(parts) < 2 { 367 - http.Error(w, "invalid record path returned", http.StatusInternalServerError) 368 - return 369 - } 370 - actualRkey := parts[len(parts)-1] 371 - 372 - // Return success response 373 - response := CreateRecordResponse{ 374 - URI: fmt.Sprintf("at://%s/%s/%s", h.pds.DID(), req.Collection, actualRkey), 375 - CID: recordCID.String(), 376 - } 377 - 378 - w.Header().Set("Content-Type", "application/json") 379 - w.WriteHeader(http.StatusCreated) 380 - json.NewEncoder(w).Encode(response) 381 - }
+11 -57
pkg/hold/pds/status.go
··· 6 6 "time" 7 7 8 8 bsky "github.com/bluesky-social/indigo/api/bsky" 9 - "github.com/ipfs/go-cid" 10 9 ) 11 10 12 11 const ( 13 - // StatusPostRkey is the fixed rkey for the status post (singleton) 14 - StatusPostRkey = "status" 15 - 16 12 // StatusPostCollection is the collection name for Bluesky posts 17 13 StatusPostCollection = "app.bsky.feed.post" 18 14 ) 19 15 20 - // SetStatus creates or updates the hold's status post on Bluesky 16 + // SetStatus creates a new status post on Bluesky 21 17 // status should be "online" or "offline" 18 + // Each call creates a unique post with a TID-based rkey 22 19 func (p *HoldPDS) SetStatus(ctx context.Context, status string) error { 23 20 // Format the post text with emoji indicator 24 21 emoji := "🟢" ··· 27 24 } 28 25 text := fmt.Sprintf("%s Current status: %s", emoji, status) 29 26 30 - // Check if status post already exists 31 - _, existingPost, err := p.GetStatusPost(ctx) 32 - if err != nil { 33 - // Post doesn't exist, create it 34 - return p.createStatusPost(ctx, text) 35 - } 36 - 37 - // Post exists, update it 38 - // We need to preserve the original CreatedAt timestamp 39 - return p.updateStatusPost(ctx, text, existingPost.CreatedAt) 40 - } 41 - 42 - // GetStatusPost retrieves the status post if it exists 43 - func (p *HoldPDS) GetStatusPost(ctx context.Context) (cid.Cid, *bsky.FeedPost, error) { 44 - // Use repomgr.GetRecord 45 - recordCID, val, err := p.repomgr.GetRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, cid.Undef) 46 - if err != nil { 47 - return cid.Undef, nil, fmt.Errorf("failed to get status post: %w", err) 48 - } 49 - 50 - // Type assert to bsky.FeedPost 51 - post, ok := val.(*bsky.FeedPost) 52 - if !ok { 53 - return cid.Undef, nil, fmt.Errorf("unexpected type for status post: %T", val) 54 - } 55 - 56 - return recordCID, post, nil 27 + // Create the post with a unique TID 28 + return p.createStatusPost(ctx, text) 57 29 } 58 30 59 - // createStatusPost creates a new status post (first time) 31 + // createStatusPost creates a new status post with a TID-based rkey 60 32 func (p *HoldPDS) createStatusPost(ctx context.Context, text string) error { 61 33 // Create post struct 62 - now := time.Now().Format(time.RFC3339) 34 + now := time.Now() 63 35 post := &bsky.FeedPost{ 64 36 LexiconTypeID: "app.bsky.feed.post", 65 37 Text: text, 66 - CreatedAt: now, 38 + CreatedAt: now.Format(time.RFC3339), 67 39 } 68 40 69 - // Use repomgr.PutRecord - creates with explicit rkey, fails if already exists 70 - recordPath, recordCID, err := p.repomgr.PutRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, post) 41 + // Use repomgr.CreateRecord to create the post with auto-generated TID 42 + // CreateRecord automatically generates a unique TID using the repo's clock 43 + rkey, recordCID, err := p.repomgr.CreateRecord(ctx, p.uid, StatusPostCollection, post) 71 44 if err != nil { 72 45 return fmt.Errorf("failed to create status post: %w", err) 73 46 } 74 47 75 - fmt.Printf("Created status post at %s, cid: %s, text: %s\n", recordPath, recordCID, text) 76 - return nil 77 - } 78 - 79 - // updateStatusPost updates an existing status post 80 - func (p *HoldPDS) updateStatusPost(ctx context.Context, text string, createdAt string) error { 81 - // Create updated post struct with original CreatedAt 82 - post := &bsky.FeedPost{ 83 - LexiconTypeID: "app.bsky.feed.post", 84 - Text: text, 85 - CreatedAt: createdAt, // Preserve original creation time 86 - } 87 - 88 - // Use repomgr.UpdateRecord 89 - recordCID, err := p.repomgr.UpdateRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, post) 90 - if err != nil { 91 - return fmt.Errorf("failed to update status post: %w", err) 92 - } 93 - 94 - fmt.Printf("Updated status post, cid: %s, text: %s\n", recordCID, text) 48 + fmt.Printf("Created status post at %s/%s (rkey: %s), cid: %s, text: %s\n", StatusPostCollection, rkey, rkey, recordCID, text) 95 49 return nil 96 50 }
+201 -64
pkg/hold/pds/status_test.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 6 + "fmt" 7 + "net/http" 8 + "net/http/httptest" 5 9 "os" 6 10 "path/filepath" 7 11 "testing" 12 + "time" 8 13 14 + "atcr.io/pkg/atproto" 15 + "atcr.io/pkg/s3" 9 16 bsky "github.com/bluesky-social/indigo/api/bsky" 10 17 ) 11 18 19 + // Shared test resources (used across all test files in package) 20 + var ( 21 + sharedTestKeyPath string 22 + sharedTestKey []byte 23 + sharedPDS *HoldPDS // Shared bootstrapped PDS for read-only tests 24 + sharedHandler *XRPCHandler // Shared handler for read-only tests 25 + sharedCtx context.Context // Shared context 26 + ) 27 + 12 28 func TestStatusPost(t *testing.T) { 13 - // Create temporary directory for test database 29 + // Create temporary directory for test 14 30 tmpDir := t.TempDir() 15 - dbPath := filepath.Join(tmpDir, "test.db") 31 + // Use in-memory database for speed 32 + dbPath := ":memory:" 16 33 keyPath := filepath.Join(tmpDir, "test.key") 34 + 35 + // Copy shared signing key 36 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 37 + t.Fatalf("Failed to copy shared signing key: %v", err) 38 + } 17 39 18 40 // Create test PDS 19 41 ctx := context.Background() ··· 31 53 t.Fatalf("Failed to initialize repo: %v", err) 32 54 } 33 55 56 + // Create handler for XRPC endpoints 57 + handler := NewXRPCHandler(holdPDS, s3.S3Service{}, nil, nil, &mockPDSClient{}) 58 + 59 + // Helper function to list posts via XRPC 60 + listPosts := func() ([]map[string]any, error) { 61 + req := makeXRPCGetRequest(atproto.RepoListRecords, map[string]string{ 62 + "repo": did, 63 + "collection": StatusPostCollection, 64 + "limit": "100", 65 + "reverse": "true", // Most recent first 66 + }) 67 + w := httptest.NewRecorder() 68 + handler.HandleListRecords(w, req) 69 + 70 + if w.Code != http.StatusOK { 71 + return nil, fmt.Errorf("unexpected status code: %d, body: %s", w.Code, w.Body.String()) 72 + } 73 + 74 + var result map[string]any 75 + if err := json.NewDecoder(w.Body).Decode(&result); err != nil { 76 + return nil, fmt.Errorf("failed to decode response: %w", err) 77 + } 78 + 79 + records, ok := result["records"].([]any) 80 + if !ok { 81 + return nil, fmt.Errorf("expected records array, got %T", result["records"]) 82 + } 83 + 84 + posts := make([]map[string]any, len(records)) 85 + for i, rec := range records { 86 + post, ok := rec.(map[string]any) 87 + if !ok { 88 + return nil, fmt.Errorf("expected record map, got %T", rec) 89 + } 90 + posts[i] = post 91 + } 92 + return posts, nil 93 + } 94 + 34 95 t.Run("CreateStatusPost", func(t *testing.T) { 35 96 // Set status to online (creates new post) 36 97 err := holdPDS.SetStatus(ctx, "online") ··· 38 99 t.Fatalf("Failed to set status to online: %v", err) 39 100 } 40 101 41 - // Verify post was created 42 - _, post, err := holdPDS.GetStatusPost(ctx) 102 + // List posts 103 + posts, err := listPosts() 43 104 if err != nil { 44 - t.Fatalf("Failed to get status post: %v", err) 105 + t.Fatalf("Failed to list posts: %v", err) 106 + } 107 + 108 + if len(posts) == 0 { 109 + t.Fatal("Expected at least one status post, got 0") 110 + } 111 + 112 + // Get the latest post 113 + post := posts[0] 114 + 115 + value, ok := post["value"].(map[string]any) 116 + if !ok { 117 + t.Fatalf("Expected value map, got %T", post["value"]) 45 118 } 46 119 47 - if post.Text != "🟢 Current status: online" { 48 - t.Errorf("Expected text '🟢 Current status: online', got '%s'", post.Text) 120 + text, ok := value["text"].(string) 121 + if !ok { 122 + t.Fatalf("Expected text string, got %T", value["text"]) 49 123 } 50 124 51 - if post.LexiconTypeID != "app.bsky.feed.post" { 52 - t.Errorf("Expected LexiconTypeID 'app.bsky.feed.post', got '%s'", post.LexiconTypeID) 125 + if text != "🟢 Current status: online" { 126 + t.Errorf("Expected text '🟢 Current status: online', got '%s'", text) 53 127 } 54 128 55 - if post.CreatedAt == "" { 56 - t.Error("CreatedAt should not be empty") 129 + // Verify TID-based rkey (extract from URI) 130 + uri, ok := post["uri"].(string) 131 + if !ok { 132 + t.Fatalf("Expected uri string, got %T", post["uri"]) 133 + } 134 + // URI format: at://did:web:test.example.com/app.bsky.feed.post/3m3c4... 135 + // We just check that it contains the collection 136 + if !contains(uri, StatusPostCollection) { 137 + t.Errorf("Expected URI to contain collection %s, got %s", StatusPostCollection, uri) 57 138 } 58 139 }) 59 140 60 - t.Run("UpdateStatusPost", func(t *testing.T) { 61 - // Get the original post to check CreatedAt preservation 62 - _, originalPost, err := holdPDS.GetStatusPost(ctx) 141 + t.Run("CreateMultiplePosts", func(t *testing.T) { 142 + // Create multiple status posts 143 + err := holdPDS.SetStatus(ctx, "offline") 63 144 if err != nil { 64 - t.Fatalf("Failed to get original status post: %v", err) 145 + t.Fatalf("Failed to set status to offline: %v", err) 65 146 } 66 147 67 - // Set status to offline (updates existing post) 68 - err = holdPDS.SetStatus(ctx, "offline") 148 + // Wait a moment to ensure different timestamp 149 + time.Sleep(10 * time.Millisecond) 150 + 151 + err = holdPDS.SetStatus(ctx, "online") 69 152 if err != nil { 70 - t.Fatalf("Failed to set status to offline: %v", err) 153 + t.Fatalf("Failed to set status to online again: %v", err) 71 154 } 72 155 73 - // Verify post was updated 74 - _, post, err := holdPDS.GetStatusPost(ctx) 156 + // List all posts - should have at least 3 now (1 from previous test + 2 from this test) 157 + posts, err := listPosts() 75 158 if err != nil { 76 - t.Fatalf("Failed to get updated status post: %v", err) 159 + t.Fatalf("Failed to list posts: %v", err) 77 160 } 78 161 79 - if post.Text != "🔴 Current status: offline" { 80 - t.Errorf("Expected text '🔴 Current status: offline', got '%s'", post.Text) 162 + if len(posts) < 3 { 163 + t.Errorf("Expected at least 3 status posts, got %d", len(posts)) 164 + } 165 + 166 + // Verify each post has a unique URI 167 + uris := make(map[string]bool) 168 + for _, post := range posts { 169 + uri, ok := post["uri"].(string) 170 + if !ok { 171 + t.Errorf("Expected uri string, got %T", post["uri"]) 172 + continue 173 + } 174 + if uris[uri] { 175 + t.Errorf("Duplicate URI found: %s", uri) 176 + } 177 + uris[uri] = true 81 178 } 82 179 83 - // Verify CreatedAt was preserved 84 - if post.CreatedAt != originalPost.CreatedAt { 85 - t.Errorf("CreatedAt should be preserved. Expected '%s', got '%s'", originalPost.CreatedAt, post.CreatedAt) 180 + // Verify the latest post is online 181 + latestPost := posts[0] 182 + value, ok := latestPost["value"].(map[string]any) 183 + if !ok { 184 + t.Fatalf("Expected value map, got %T", latestPost["value"]) 185 + } 186 + text, ok := value["text"].(string) 187 + if !ok { 188 + t.Fatalf("Expected text string, got %T", value["text"]) 189 + } 190 + if text != "🟢 Current status: online" { 191 + t.Errorf("Expected latest post text '🟢 Current status: online', got '%s'", text) 86 192 } 87 193 }) 88 194 89 - t.Run("ToggleStatus", func(t *testing.T) { 90 - // Toggle back to online 91 - err := holdPDS.SetStatus(ctx, "online") 195 + t.Run("OfflineStatus", func(t *testing.T) { 196 + // Create offline status post 197 + err := holdPDS.SetStatus(ctx, "offline") 92 198 if err != nil { 93 - t.Fatalf("Failed to set status to online: %v", err) 199 + t.Fatalf("Failed to set status to offline: %v", err) 94 200 } 95 201 96 - _, post, err := holdPDS.GetStatusPost(ctx) 202 + // Get the latest post 203 + posts, err := listPosts() 97 204 if err != nil { 98 - t.Fatalf("Failed to get status post: %v", err) 205 + t.Fatalf("Failed to list posts: %v", err) 206 + } 207 + 208 + if len(posts) == 0 { 209 + t.Fatal("Expected at least one status post, got 0") 210 + } 211 + 212 + latestPost := posts[0] 213 + value, ok := latestPost["value"].(map[string]any) 214 + if !ok { 215 + t.Fatalf("Expected value map, got %T", latestPost["value"]) 216 + } 217 + text, ok := value["text"].(string) 218 + if !ok { 219 + t.Fatalf("Expected text string, got %T", value["text"]) 99 220 } 100 221 101 - if post.Text != "🟢 Current status: online" { 102 - t.Errorf("Expected text '🟢 Current status: online', got '%s'", post.Text) 222 + if text != "🔴 Current status: offline" { 223 + t.Errorf("Expected text '🔴 Current status: offline', got '%s'", text) 103 224 } 104 225 }) 105 226 } 106 227 107 228 func TestStatusPostCollection(t *testing.T) { 108 - // Verify constants 229 + // Verify constant 109 230 if StatusPostCollection != "app.bsky.feed.post" { 110 231 t.Errorf("Expected StatusPostCollection 'app.bsky.feed.post', got '%s'", StatusPostCollection) 111 232 } 112 - 113 - if StatusPostRkey != "status" { 114 - t.Errorf("Expected StatusPostRkey 'status', got '%s'", StatusPostRkey) 115 - } 116 233 } 117 234 118 - func TestGetStatusPostNotExists(t *testing.T) { 119 - // Create temporary directory for test database 120 - tmpDir := t.TempDir() 121 - dbPath := filepath.Join(tmpDir, "test.db") 122 - keyPath := filepath.Join(tmpDir, "test.key") 235 + // Helper function to check if a string contains a substring 236 + func contains(s, substr string) bool { 237 + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr)) 238 + } 123 239 124 - // Create test PDS 125 - ctx := context.Background() 126 - did := "did:web:test2.example.com" 127 - publicURL := "https://test2.example.com" 128 - 129 - holdPDS, err := NewHoldPDS(ctx, did, publicURL, dbPath, keyPath) 130 - if err != nil { 131 - t.Fatalf("Failed to create test PDS: %v", err) 240 + func findSubstring(s, substr string) bool { 241 + for i := 0; i <= len(s)-len(substr); i++ { 242 + if s[i:i+len(substr)] == substr { 243 + return true 244 + } 132 245 } 133 - 134 - // Initialize empty repo 135 - err = holdPDS.repomgr.InitNewActor(ctx, holdPDS.uid, "", did, "", "", "") 136 - if err != nil { 137 - t.Fatalf("Failed to initialize repo: %v", err) 138 - } 139 - 140 - // Try to get status post that doesn't exist 141 - _, _, err = holdPDS.GetStatusPost(ctx) 142 - if err == nil { 143 - t.Error("Expected error when getting non-existent status post, got nil") 144 - } 246 + return false 145 247 } 146 248 147 249 func init() { ··· 154 256 155 257 // Cleanup function to remove test files 156 258 func TestMain(m *testing.M) { 259 + // Create a temporary directory for shared test key 260 + tmpDir, err := os.MkdirTemp("", "pds-test-shared-*") 261 + if err != nil { 262 + panic(fmt.Sprintf("Failed to create temp dir: %v", err)) 263 + } 264 + defer os.RemoveAll(tmpDir) 265 + 266 + // Generate one signing key to be reused across all tests in the package 267 + sharedTestKeyPath = filepath.Join(tmpDir, "shared-signing-key") 268 + privateKey, err := GenerateOrLoadKey(sharedTestKeyPath) 269 + if err != nil { 270 + panic(fmt.Sprintf("Failed to generate shared signing key: %v", err)) 271 + } 272 + 273 + // Store the key bytes so tests can copy them 274 + sharedTestKey = privateKey.Bytes() 275 + 276 + // Create one shared, bootstrapped PDS for read-only tests 277 + // Use in-memory database for speed 278 + sharedCtx = context.Background() 279 + sharedPDS, err = NewHoldPDS(sharedCtx, "did:web:hold.example.com", "https://hold.example.com", ":memory:", sharedTestKeyPath) 280 + if err != nil { 281 + panic(fmt.Sprintf("Failed to create shared PDS: %v", err)) 282 + } 283 + 284 + // Bootstrap once 285 + ownerDID := "did:plc:testowner123" 286 + err = sharedPDS.Bootstrap(sharedCtx, nil, ownerDID, true, false, "") 287 + if err != nil { 288 + panic(fmt.Sprintf("Failed to bootstrap shared PDS: %v", err)) 289 + } 290 + 291 + // Create shared handler 292 + sharedHandler = NewXRPCHandler(sharedPDS, s3.S3Service{}, nil, nil, &mockPDSClient{}) 293 + 157 294 // Run tests 158 295 code := m.Run() 159 296
+6 -29
pkg/hold/pds/xrpc.go
··· 92 92 93 93 // Handle OPTIONS preflight 94 94 if r.Method == "OPTIONS" { 95 - w.WriteHeader(http.StatusNoContent) 95 + w.WriteHeader(http.StatusOK) 96 96 return 97 97 } 98 98 ··· 150 150 r.Get("/xrpc/_health", h.HandleHealth) 151 151 r.Get(atproto.ServerDescribeServer, h.HandleDescribeServer) 152 152 153 - // Session management (public - creates sessions) 154 - r.Post(atproto.ServerCreateSession, h.HandleCreateSession) 155 - r.Post(atproto.ServerRefreshSession, h.HandleRefreshSession) 156 - r.Get(atproto.ServerGetSession, h.HandleGetSession) 157 - 158 153 // Repository metadata 159 154 r.Get(atproto.RepoDescribeRepo, h.HandleDescribeRepo) 160 155 r.Get(atproto.RepoGetRecord, h.HandleGetRecord) ··· 186 181 // Write endpoints (owner/crew admin auth) 187 182 r.Group(func(r chi.Router) { 188 183 r.Use(h.requireOwnerOrCrewAdmin) 189 - 190 184 r.Post(atproto.RepoDeleteRecord, h.HandleDeleteRecord) 191 185 r.Post(atproto.RepoUploadBlob, h.HandleUploadBlob) 192 186 }) ··· 194 188 // Auth-only endpoints (DPoP auth) 195 189 r.Group(func(r chi.Router) { 196 190 r.Use(h.requireAuth) 197 - 198 191 r.Post(atproto.HoldRequestCrew, h.HandleRequestCrew) 199 - }) 200 - 201 - // JWT-authenticated endpoints (JWT auth from createSession) 202 - // Note: JWT auth is validated inside each handler 203 - r.Group(func(r chi.Router) { 204 - r.Post("/xrpc/com.atproto.repo.createRecord", h.HandleCreateRecord) 205 192 }) 206 193 } 207 194 ··· 861 848 } 862 849 863 850 // Get optional cursor parameter for backfill 864 - var cursor int64 = 0 851 + // Default to -1 (no backfill, only stream new events) 852 + // cursor=0 means "replay all events from the beginning" 853 + var cursor int64 = -1 865 854 if cursorStr := r.URL.Query().Get("cursor"); cursorStr != "" { 866 855 var err error 867 856 cursor, err = strconv.ParseInt(cursorStr, 10, 64) ··· 879 868 } 880 869 881 870 // Subscribe to events 882 - sub := h.broadcaster.Subscribe(conn, cursor) 883 - 884 871 // The broadcaster's handleSubscriber goroutine will manage this connection 885 - // We just need to keep reading to detect client disconnects 886 - go func() { 887 - defer h.broadcaster.Unsubscribe(sub) 888 - for { 889 - // Read messages from client (mostly just to detect disconnect) 890 - _, _, err := conn.ReadMessage() 891 - if err != nil { 892 - // Client disconnected 893 - break 894 - } 895 - } 896 - }() 872 + // and handle cleanup when the client disconnects 873 + h.broadcaster.Subscribe(conn, cursor) 897 874 } 898 875 899 876 // HandleUploadBlob handles blob uploads with support for multipart operations
+257 -3
pkg/hold/pds/xrpc_test.go
··· 12 12 "path/filepath" 13 13 "strings" 14 14 "testing" 15 + "time" 15 16 16 17 "atcr.io/pkg/atproto" 17 18 "atcr.io/pkg/s3" 19 + indigoAtproto "github.com/bluesky-social/indigo/api/atproto" 20 + "github.com/bluesky-social/indigo/events" 18 21 "github.com/distribution/distribution/v3/registry/storage/driver/factory" 19 22 _ "github.com/distribution/distribution/v3/registry/storage/driver/filesystem" 20 23 "github.com/go-chi/chi/v5" 24 + "github.com/gorilla/websocket" 25 + "github.com/ipfs/go-cid" 21 26 ) 22 27 23 28 // Test helpers ··· 30 35 ctx := context.Background() 31 36 tmpDir := t.TempDir() 32 37 33 - dbPath := filepath.Join(tmpDir, "pds.db") 38 + // Use in-memory database for speed 39 + dbPath := ":memory:" 34 40 keyPath := filepath.Join(tmpDir, "signing-key") 41 + 42 + // Copy shared signing key instead of generating a new one 43 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 44 + t.Fatalf("Failed to copy shared signing key: %v", err) 45 + } 35 46 36 47 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath) 37 48 if err != nil { ··· 115 126 } 116 127 117 128 return result 129 + } 130 + 131 + // decodeFirehoseMessage decodes an ATProto firehose message (header + CBOR body) 132 + func decodeFirehoseMessage(t *testing.T, message []byte) (*events.EventHeader, *indigoAtproto.SyncSubscribeRepos_Commit) { 133 + t.Helper() 134 + 135 + reader := bytes.NewReader(message) 136 + 137 + // Decode header 138 + var header events.EventHeader 139 + if err := header.UnmarshalCBOR(reader); err != nil { 140 + t.Fatalf("Failed to decode event header: %v", err) 141 + } 142 + 143 + // Verify it's a commit event 144 + if header.MsgType != "#commit" { 145 + t.Fatalf("Expected #commit event, got %s", header.MsgType) 146 + } 147 + 148 + // Decode commit event 149 + var commit indigoAtproto.SyncSubscribeRepos_Commit 150 + if err := commit.UnmarshalCBOR(reader); err != nil { 151 + t.Fatalf("Failed to decode commit event: %v", err) 152 + } 153 + 154 + return &header, &commit 118 155 } 119 156 120 157 // assertCARResponse validates CAR file response ··· 1331 1368 ctx := context.Background() 1332 1369 tmpDir := t.TempDir() 1333 1370 1334 - dbPath := filepath.Join(tmpDir, "pds.db") 1371 + // Use in-memory database for speed 1372 + dbPath := ":memory:" 1335 1373 keyPath := filepath.Join(tmpDir, "signing-key") 1374 + 1375 + // Copy shared signing key instead of generating a new one 1376 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 1377 + t.Fatalf("Failed to copy shared signing key: %v", err) 1378 + } 1336 1379 1337 1380 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath) 1338 1381 if err != nil { ··· 1686 1729 w := httptest.NewRecorder() 1687 1730 1688 1731 // Wrap with CORS middleware (chi-style) 1689 - corsHandler := handler.corsMiddleware(http.HandlerFunc(handler.HandleGetBlob)) 1732 + corsHandler := handler.CORSMiddleware()(http.HandlerFunc(handler.HandleGetBlob)) 1690 1733 corsHandler.ServeHTTP(w, req) 1691 1734 1692 1735 // Verify CORS headers are present ··· 1719 1762 1720 1763 // Create chi router and register handlers 1721 1764 r := chi.NewRouter() 1765 + r.Use(handler.CORSMiddleware()) // Apply CORS middleware 1722 1766 handler.RegisterHandlers(r) 1723 1767 1724 1768 tests := []struct { ··· 2009 2053 }) 2010 2054 } 2011 2055 } 2056 + 2057 + // TestHandleSubscribeRepos tests the WebSocket firehose endpoint 2058 + func TestHandleSubscribeRepos(t *testing.T) { 2059 + handler, ctx := setupTestXRPCHandler(t) 2060 + 2061 + // Create EventBroadcaster 2062 + broadcaster := NewEventBroadcaster(handler.pds.DID(), 100) 2063 + handler.broadcaster = broadcaster 2064 + 2065 + // Set up test HTTP server 2066 + r := chi.NewRouter() 2067 + handler.RegisterHandlers(r) 2068 + server := httptest.NewServer(r) 2069 + defer server.Close() 2070 + 2071 + // Broadcast some events before connecting 2072 + testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke") 2073 + for i := 1; i <= 3; i++ { 2074 + event := &RepoEvent{ 2075 + NewRoot: testCID, 2076 + Rev: fmt.Sprintf("rev-%d", i), 2077 + RepoSlice: []byte(fmt.Sprintf("CAR data %d", i)), 2078 + Ops: []RepoOp{}, 2079 + } 2080 + broadcaster.Broadcast(ctx, event) 2081 + } 2082 + 2083 + // Verify events were stored 2084 + if broadcaster.eventSeq != 3 { 2085 + t.Fatalf("Expected eventSeq=3, got %d", broadcaster.eventSeq) 2086 + } 2087 + 2088 + t.Run("cursor=0 replays all events", func(t *testing.T) { 2089 + // Connect to WebSocket with cursor=0 2090 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=0" 2091 + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) 2092 + if err != nil { 2093 + t.Fatalf("Failed to connect to WebSocket: %v", err) 2094 + } 2095 + defer conn.Close() 2096 + 2097 + // Should receive the 3 historical events 2098 + for i := 0; i < 3; i++ { 2099 + messageType, message, err := conn.ReadMessage() 2100 + if err != nil { 2101 + t.Fatalf("Failed to read message: %v", err) 2102 + } 2103 + 2104 + if messageType != websocket.BinaryMessage { 2105 + t.Errorf("Expected binary message, got type %d", messageType) 2106 + } 2107 + 2108 + // Decode CBOR message (header + commit) 2109 + header, commit := decodeFirehoseMessage(t, message) 2110 + 2111 + // Verify header 2112 + if header.MsgType != "#commit" { 2113 + t.Errorf("Expected MsgType=#commit, got %s", header.MsgType) 2114 + } 2115 + 2116 + // Verify commit fields 2117 + expectedSeq := int64(i + 1) 2118 + if commit.Seq != expectedSeq { 2119 + t.Errorf("Expected seq=%d, got %d", expectedSeq, commit.Seq) 2120 + } 2121 + if commit.Repo != handler.pds.DID() { 2122 + t.Errorf("Expected repo=%s, got %s", handler.pds.DID(), commit.Repo) 2123 + } 2124 + } 2125 + }) 2126 + 2127 + t.Run("cursor=2 only replays events after 2", func(t *testing.T) { 2128 + // Connect with cursor=2 2129 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=2" 2130 + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) 2131 + if err != nil { 2132 + t.Fatalf("Failed to connect to WebSocket: %v", err) 2133 + } 2134 + defer conn.Close() 2135 + 2136 + // Should only receive event 3 (after cursor=2) 2137 + messageType, message, err := conn.ReadMessage() 2138 + if err != nil { 2139 + t.Fatalf("Failed to read message: %v", err) 2140 + } 2141 + 2142 + if messageType != websocket.BinaryMessage { 2143 + t.Errorf("Expected binary message, got type %d", messageType) 2144 + } 2145 + 2146 + header, commit := decodeFirehoseMessage(t, message) 2147 + if header.MsgType != "#commit" { 2148 + t.Errorf("Expected MsgType=#commit, got %s", header.MsgType) 2149 + } 2150 + 2151 + if commit.Seq != 3 { 2152 + t.Errorf("Expected seq=3, got %d", commit.Seq) 2153 + } 2154 + 2155 + // Verify no more events (use timeout) 2156 + conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 2157 + _, _, err = conn.ReadMessage() 2158 + if err == nil { 2159 + t.Error("Expected no more events, but received another message") 2160 + } 2161 + }) 2162 + 2163 + t.Run("no cursor streams only new events", func(t *testing.T) { 2164 + // Connect without cursor (should not get backfill) 2165 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos" 2166 + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) 2167 + if err != nil { 2168 + t.Fatalf("Failed to connect to WebSocket: %v", err) 2169 + } 2170 + defer conn.Close() 2171 + 2172 + // Verify no historical events by broadcasting immediately and checking 2173 + // that we only receive the new event (not historical ones) 2174 + // Give subscriber time to register first 2175 + time.Sleep(100 * time.Millisecond) 2176 + 2177 + // Broadcast a new event (seq 4) 2178 + newEvent := &RepoEvent{ 2179 + NewRoot: testCID, 2180 + Rev: "rev-4", 2181 + RepoSlice: []byte("CAR data 4"), 2182 + Ops: []RepoOp{}, 2183 + } 2184 + broadcaster.Broadcast(ctx, newEvent) 2185 + 2186 + // Should receive ONLY the new event (seq 4), not historical events 1-3 2187 + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) 2188 + messageType, message, err := conn.ReadMessage() 2189 + if err != nil { 2190 + t.Fatalf("Failed to read new event: %v", err) 2191 + } 2192 + 2193 + if messageType != websocket.BinaryMessage { 2194 + t.Errorf("Expected binary message, got type %d", messageType) 2195 + } 2196 + 2197 + header, commit := decodeFirehoseMessage(t, message) 2198 + if header.MsgType != "#commit" { 2199 + t.Errorf("Expected MsgType=#commit, got %s", header.MsgType) 2200 + } 2201 + 2202 + // Key assertion: should be seq 4 (new event), not seq 1 (historical backfill) 2203 + if commit.Seq != 4 { 2204 + t.Errorf("Expected seq=4 for new event (no backfill), got %d", commit.Seq) 2205 + } 2206 + 2207 + // Verify no more messages (no historical backfill) 2208 + conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 2209 + _, _, err = conn.ReadMessage() 2210 + if err == nil { 2211 + t.Error("Expected no more events, but received another message (possible backfill leak)") 2212 + } 2213 + }) 2214 + 2215 + t.Run("real-time event delivery", func(t *testing.T) { 2216 + // Connect with cursor=0 to get all events first 2217 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=0" 2218 + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) 2219 + if err != nil { 2220 + t.Fatalf("Failed to connect to WebSocket: %v", err) 2221 + } 2222 + defer conn.Close() 2223 + 2224 + // Read and discard the 4 historical events (seq 1-4) 2225 + for i := 0; i < 4; i++ { 2226 + _, _, err := conn.ReadMessage() 2227 + if err != nil { 2228 + t.Fatalf("Failed to read historical event %d: %v", i+1, err) 2229 + } 2230 + } 2231 + 2232 + // Broadcast 2 new events 2233 + for i := 5; i <= 6; i++ { 2234 + newEvent := &RepoEvent{ 2235 + NewRoot: testCID, 2236 + Rev: fmt.Sprintf("rev-%d", i), 2237 + RepoSlice: []byte(fmt.Sprintf("CAR data %d", i)), 2238 + Ops: []RepoOp{}, 2239 + } 2240 + broadcaster.Broadcast(ctx, newEvent) 2241 + } 2242 + 2243 + // Should receive both new events 2244 + for expectedSeq := 5; expectedSeq <= 6; expectedSeq++ { 2245 + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) 2246 + messageType, message, err := conn.ReadMessage() 2247 + if err != nil { 2248 + t.Fatalf("Failed to read event seq=%d: %v", expectedSeq, err) 2249 + } 2250 + 2251 + if messageType != websocket.BinaryMessage { 2252 + t.Errorf("Expected binary message, got type %d", messageType) 2253 + } 2254 + 2255 + header, commit := decodeFirehoseMessage(t, message) 2256 + if header.MsgType != "#commit" { 2257 + t.Errorf("Expected MsgType=#commit, got %s", header.MsgType) 2258 + } 2259 + 2260 + if commit.Seq != int64(expectedSeq) { 2261 + t.Errorf("Expected seq=%d, got %d", expectedSeq, commit.Seq) 2262 + } 2263 + } 2264 + }) 2265 + }