A container registry that uses the AT Protocol for manifest storage and S3 for blob storage. atcr.io
docker container atproto go
81
fork

Configure Feed

Select the types of activity you want to include in your feed.

try and add cursor=0 to subscribe

+843 -1351
+1
cmd/hold/main.go
··· 147 147 // Update status post to "online" after server starts 148 148 if holdPDS != nil { 149 149 ctx := context.Background() 150 + 150 151 if err := holdPDS.SetStatus(ctx, "online"); err != nil { 151 152 log.Printf("Warning: Failed to set status post to online: %v", err) 152 153 } else {
+2 -2
go.mod
··· 4 4 5 5 require ( 6 6 github.com/aws/aws-sdk-go v1.55.5 7 - github.com/bluesky-social/indigo v0.0.0-20251014222321-1e8718ae9f33 7 + github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb 8 8 github.com/distribution/distribution/v3 v3.0.0 9 9 github.com/distribution/reference v0.6.0 10 10 github.com/go-chi/chi/v5 v5.2.3 ··· 34 34 ) 35 35 36 36 require ( 37 + github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b // indirect 37 38 github.com/aymerick/douceur v0.2.0 // indirect 38 39 github.com/beorn7/perks v1.0.1 // indirect 39 40 github.com/bshuster-repo/logrus-logstash-hook v1.0.0 // indirect ··· 46 47 github.com/docker/go-metrics v0.0.1 // indirect 47 48 github.com/earthboundkid/versioninfo/v2 v2.24.1 // indirect 48 49 github.com/felixge/httpsnoop v1.0.4 // indirect 49 - github.com/go-chi/cors v1.2.2 // indirect 50 50 github.com/go-jose/go-jose/v4 v4.1.2 // indirect 51 51 github.com/go-logr/logr v1.4.2 // indirect 52 52 github.com/go-logr/stdr v1.2.2 // indirect
+6 -4
go.sum
··· 1 1 github.com/AdaLogics/go-fuzz-headers v0.0.0-20221103172237-443f56ff4ba8 h1:d+pBUmsteW5tM87xmVXHZ4+LibHRFn40SPAoZJOg2ak= 2 2 github.com/AdaLogics/go-fuzz-headers v0.0.0-20221103172237-443f56ff4ba8/go.mod h1:i9fr2JpcEcY/IHEvzCM3qXUZYOQHgR89dt4es1CgMhc= 3 3 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 4 + github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b h1:5/++qT1/z812ZqBvqQt6ToRswSuPZ/B33m6xVHRzADU= 5 + github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b/go.mod h1:4+EPqMRApwwE/6yo6CxiHoSnBzjRr3jsqer7frxP8y4= 4 6 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= 5 7 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= 6 8 github.com/alexbrainman/goissue34681 v0.0.0-20191006012335-3fc7a47baff5 h1:iW0a5ljuFxkLGPNem5Ui+KBjFJzKg4Fv2fnxe4dvzpM= ··· 18 20 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 19 21 github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= 20 22 github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= 21 - github.com/bluesky-social/indigo v0.0.0-20251014222321-1e8718ae9f33 h1:x06Y6VyYUCvqWl2AS4/3NBBbRf8wWNMd3YrI44NTHS8= 22 - github.com/bluesky-social/indigo v0.0.0-20251014222321-1e8718ae9f33/go.mod h1:GuGAU33qKulpZCZNPcUeIQ4RW6KzNvOy7s8MSUXbAng= 23 + github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb h1:zzyqB1W/itfdIA5cnOZ7IFCJ6QtqwOsXltmLunL4sHw= 24 + github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb/go.mod h1:GuGAU33qKulpZCZNPcUeIQ4RW6KzNvOy7s8MSUXbAng= 23 25 github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= 24 26 github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= 25 27 github.com/bshuster-repo/logrus-logstash-hook v1.0.0 h1:e+C0SB5R1pu//O4MQ3f9cFuPGoOVeF2fE4Og9otCc70= ··· 68 70 github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 69 71 github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= 70 72 github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= 71 - github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= 72 - github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= 73 73 github.com/go-jose/go-jose/v4 v4.1.2 h1:TK/7NqRQZfgAh+Td8AlsrvtPoUyiHh0LqVvokh+1vHI= 74 74 github.com/go-jose/go-jose/v4 v4.1.2/go.mod h1:22cg9HWM1pOlnRiY+9cQYJ9XHmya1bYW8OeDM6Ku6Oo= 75 75 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= ··· 80 80 github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 81 81 github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 82 82 github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 83 + github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= 84 + github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= 83 85 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= 84 86 github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= 85 87 github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
+54 -3
pkg/hold/oci/xrpc_test.go
··· 4 4 "bytes" 5 5 "context" 6 6 "encoding/json" 7 + "fmt" 7 8 "io" 8 9 "net/http" 9 10 "net/http/httptest" 11 + "os" 10 12 "path/filepath" 11 13 "strconv" 12 14 "testing" ··· 18 20 _ "github.com/distribution/distribution/v3/registry/storage/driver/filesystem" 19 21 ) 20 22 23 + // Shared test resources for OCI package 24 + var ( 25 + sharedTestKeyPath string 26 + sharedTestKey []byte 27 + ) 28 + 29 + // TestMain sets up shared resources for all OCI tests 30 + func TestMain(m *testing.M) { 31 + // Create a temporary directory for shared test key 32 + tmpDir, err := os.MkdirTemp("", "oci-test-shared-*") 33 + if err != nil { 34 + panic(fmt.Sprintf("Failed to create temp dir: %v", err)) 35 + } 36 + defer os.RemoveAll(tmpDir) 37 + 38 + // Generate one signing key to be reused across all tests 39 + sharedTestKeyPath = filepath.Join(tmpDir, "shared-signing-key") 40 + privateKey, err := pds.GenerateOrLoadKey(sharedTestKeyPath) 41 + if err != nil { 42 + panic(fmt.Sprintf("Failed to generate shared signing key: %v", err)) 43 + } 44 + 45 + // Store the key bytes so tests can copy them 46 + sharedTestKey = privateKey.Bytes() 47 + 48 + // Run tests 49 + code := m.Run() 50 + os.Exit(code) 51 + } 52 + 21 53 // Test setup helpers 22 54 23 55 // mockPDSClient implements pds.HTTPClient for testing ··· 54 86 } 55 87 56 88 // Create minimal PDS for DID/auth 57 - dbPath := filepath.Join(tmpDir, "pds.db") 89 + // Use in-memory database for speed 90 + dbPath := ":memory:" 58 91 keyPath := filepath.Join(tmpDir, "signing-key") 59 92 holdDID := "did:web:hold.example.com" 60 93 publicURL := "https://hold.example.com" 61 94 95 + // Copy shared signing key instead of generating a new one 96 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 97 + t.Fatalf("Failed to copy shared signing key: %v", err) 98 + } 99 + 62 100 holdPDS, err := pds.NewHoldPDS(ctx, holdDID, publicURL, dbPath, keyPath) 63 101 if err != nil { 64 102 t.Fatalf("Failed to create PDS: %v", err) 65 103 } 66 104 67 - // Bootstrap PDS 105 + // Bootstrap PDS, suppressing stdout to avoid log spam 68 106 ownerDID := "did:plc:owner123" 69 - if err := holdPDS.Bootstrap(ctx, nil, ownerDID, true, false, ""); err != nil { 107 + 108 + // Redirect stdout to suppress bootstrap logging 109 + oldStdout := os.Stdout 110 + r, w, _ := os.Pipe() 111 + os.Stdout = w 112 + 113 + err = holdPDS.Bootstrap(ctx, nil, ownerDID, true, false, "") 114 + 115 + // Restore stdout 116 + w.Close() 117 + os.Stdout = oldStdout 118 + io.ReadAll(r) // Drain the pipe 119 + 120 + if err != nil { 70 121 t.Fatalf("Failed to bootstrap PDS: %v", err) 71 122 } 72 123
-138
pkg/hold/pds/apppassword.go
··· 1 - package pds 2 - 3 - import ( 4 - "crypto/rand" 5 - "encoding/base32" 6 - "fmt" 7 - "strings" 8 - 9 - "golang.org/x/crypto/bcrypt" 10 - ) 11 - 12 - // GenerateAppPassword creates a random app password in the format: abcd-efgh-ijkl-mnop 13 - // Uses base32 encoding for readable characters (no ambiguous chars like 0/O, 1/l) 14 - func GenerateAppPassword() (string, error) { 15 - // Generate 20 random bytes (160 bits of entropy) 16 - // Base32 encoding gives us 32 characters, we'll format as 4 groups of 4 17 - randomBytes := make([]byte, 20) 18 - if _, err := rand.Read(randomBytes); err != nil { 19 - return "", fmt.Errorf("failed to generate random bytes: %w", err) 20 - } 21 - 22 - // Encode as base32 and lowercase (base32 alphabet: a-z, 2-7) 23 - encoded := base32.StdEncoding.EncodeToString(randomBytes) 24 - encoded = strings.ToLower(encoded) 25 - 26 - // Remove padding and take first 16 characters 27 - encoded = strings.TrimRight(encoded, "=") 28 - if len(encoded) > 16 { 29 - encoded = encoded[:16] 30 - } 31 - 32 - // Format as: xxxx-xxxx-xxxx-xxxx 33 - parts := []string{ 34 - encoded[0:4], 35 - encoded[4:8], 36 - encoded[8:12], 37 - encoded[12:16], 38 - } 39 - 40 - return strings.Join(parts, "-"), nil 41 - } 42 - 43 - // HashAppPassword hashes an app password using bcrypt 44 - // Cost is set to 12 for good security without excessive CPU usage 45 - func HashAppPassword(password string) (string, error) { 46 - hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) 47 - if err != nil { 48 - return "", fmt.Errorf("failed to hash password: %w", err) 49 - } 50 - return string(hash), nil 51 - } 52 - 53 - // ValidateAppPassword compares a plaintext password with a bcrypt hash 54 - func ValidateAppPassword(password, hash string) bool { 55 - err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) 56 - return err == nil 57 - } 58 - 59 - // CreateAppPassword generates and stores a new app password 60 - func (p *HoldPDS) CreateAppPassword(name string) (string, error) { 61 - // Generate random password 62 - password, err := GenerateAppPassword() 63 - if err != nil { 64 - return "", fmt.Errorf("failed to generate password: %w", err) 65 - } 66 - 67 - // Hash password 68 - hash, err := HashAppPassword(password) 69 - if err != nil { 70 - return "", fmt.Errorf("failed to hash password: %w", err) 71 - } 72 - 73 - // Store in database 74 - if err := p.authDB.CreateAppPassword(name, hash); err != nil { 75 - return "", fmt.Errorf("failed to store password: %w", err) 76 - } 77 - 78 - return password, nil 79 - } 80 - 81 - // ValidateAppPasswordByName checks if a password matches the stored hash for a given name 82 - func (p *HoldPDS) ValidateAppPasswordByName(name, password string) error { 83 - // Get app password from database 84 - ap, err := p.authDB.GetAppPassword(name) 85 - if err != nil { 86 - return fmt.Errorf("app password not found: %w", err) 87 - } 88 - 89 - // Validate password 90 - if !ValidateAppPassword(password, ap.PasswordHash) { 91 - return fmt.Errorf("invalid password") 92 - } 93 - 94 - // Update last used timestamp 95 - if err := p.authDB.UpdateLastUsed(name); err != nil { 96 - // Log but don't fail - this is not critical 97 - fmt.Printf("Warning: failed to update last used timestamp: %v\n", err) 98 - } 99 - 100 - return nil 101 - } 102 - 103 - // ValidateAnyAppPassword checks if a password matches any stored app password 104 - // Returns the name of the matching app password, or error if none match 105 - func (p *HoldPDS) ValidateAnyAppPassword(password string) (string, error) { 106 - // List all app passwords 107 - passwords, err := p.authDB.ListAppPasswords() 108 - if err != nil { 109 - return "", fmt.Errorf("failed to list app passwords: %w", err) 110 - } 111 - 112 - // Try each one 113 - for _, ap := range passwords { 114 - // Get full record with hash 115 - fullAP, err := p.authDB.GetAppPassword(ap.Name) 116 - if err != nil { 117 - continue 118 - } 119 - 120 - if ValidateAppPassword(password, fullAP.PasswordHash) { 121 - // Update last used 122 - p.authDB.UpdateLastUsed(ap.Name) 123 - return ap.Name, nil 124 - } 125 - } 126 - 127 - return "", fmt.Errorf("invalid app password") 128 - } 129 - 130 - // ListAppPasswords returns a list of app password names (without hashes) 131 - func (p *HoldPDS) ListAppPasswords() ([]AppPassword, error) { 132 - return p.authDB.ListAppPasswords() 133 - } 134 - 135 - // RevokeAppPassword deletes an app password 136 - func (p *HoldPDS) RevokeAppPassword(name string) error { 137 - return p.authDB.DeleteAppPassword(name) 138 - }
-30
pkg/hold/pds/auth.go
··· 528 528 529 529 return publicKey, nil 530 530 } 531 - 532 - // ValidateJWTAuth validates a request with a JWT access token from createSession 533 - // This is used for authenticated repo operations (createRecord, etc.) 534 - // Returns the validated user DID 535 - func ValidateJWTAuth(r *http.Request, pds *HoldPDS) (*ValidatedUser, error) { 536 - // Extract Authorization header 537 - authHeader := r.Header.Get("Authorization") 538 - if authHeader == "" { 539 - return nil, fmt.Errorf("missing Authorization header") 540 - } 541 - 542 - // Remove "Bearer " prefix 543 - accessToken := strings.TrimPrefix(authHeader, "Bearer ") 544 - if accessToken == authHeader { 545 - return nil, fmt.Errorf("invalid authorization header format (expected Bearer)") 546 - } 547 - 548 - // Validate access token 549 - claims, err := pds.ValidateAccessToken(accessToken) 550 - if err != nil { 551 - return nil, fmt.Errorf("invalid access token: %w", err) 552 - } 553 - 554 - return &ValidatedUser{ 555 - DID: claims.DID, 556 - Handle: claims.Handle, 557 - PDS: "", 558 - Authorized: true, 559 - }, nil 560 - }
+15 -75
pkg/hold/pds/auth_test.go
··· 506 506 507 507 // TestValidateBlobWriteAccess_ServiceToken_Owner tests owner write access via service token 508 508 func TestValidateBlobWriteAccess_ServiceToken_Owner(t *testing.T) { 509 - pds, ctx := setupTestPDS(t) 510 - 511 509 ownerDID := "did:plc:owner123" 512 510 holdDID := "did:web:hold01.atcr.io" 513 511 514 - // Bootstrap with owner 515 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 516 - if err != nil { 517 - t.Fatalf("Failed to bootstrap PDS: %v", err) 518 - } 512 + _, _ = setupTestPDSWithBootstrap(t, ownerDID, true, false) 519 513 520 514 // Create service token for owner 521 515 helper, err := NewServiceTokenTestHelper(ownerDID, holdDID) ··· 539 533 540 534 // TestValidateBlobWriteAccess_ServiceToken_CrewWithPermission tests crew write access via service token 541 535 func TestValidateBlobWriteAccess_ServiceToken_CrewWithPermission(t *testing.T) { 542 - pds, ctx := setupTestPDS(t) 543 - 544 536 ownerDID := "did:plc:owner123" 545 537 writerDID := "did:plc:writer123" 546 538 holdDID := "did:web:hold01.atcr.io" 547 539 548 - // Bootstrap 549 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 550 - if err != nil { 551 - t.Fatalf("Failed to bootstrap PDS: %v", err) 552 - } 540 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 553 541 554 542 // Add crew member with blob:write permission 555 - _, err = pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"}) 543 + _, err := pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"}) 556 544 if err != nil { 557 545 t.Fatalf("Failed to add crew member: %v", err) 558 546 } ··· 598 586 599 587 // TestValidateBlobWriteAccess_ServiceToken_CrewWithoutPermission tests that crew without permission is rejected 600 588 func TestValidateBlobWriteAccess_ServiceToken_CrewWithoutPermission(t *testing.T) { 601 - pds, ctx := setupTestPDS(t) 602 - 603 589 ownerDID := "did:plc:owner123" 604 590 readerDID := "did:plc:reader123" 605 591 holdDID := "did:web:hold01.atcr.io" 606 592 607 - // Bootstrap 608 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 609 - if err != nil { 610 - t.Fatalf("Failed to bootstrap PDS: %v", err) 611 - } 593 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 612 594 613 595 // Add crew member with blob:read permission only (no blob:write) 614 - _, err = pds.AddCrewMember(ctx, readerDID, "reader", []string{"blob:read"}) 596 + _, err := pds.AddCrewMember(ctx, readerDID, "reader", []string{"blob:read"}) 615 597 if err != nil { 616 598 t.Fatalf("Failed to add crew member: %v", err) 617 599 } ··· 645 627 646 628 // TestValidateBlobWriteAccess_Owner tests that the hold owner has write access 647 629 func TestValidateBlobWriteAccess_Owner(t *testing.T) { 648 - pds, ctx := setupTestPDS(t) 649 - 650 630 ownerDID := "did:plc:owner123" 651 631 652 - // Bootstrap with owner 653 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 654 - if err != nil { 655 - t.Fatalf("Failed to bootstrap PDS: %v", err) 656 - } 632 + pds, _ := setupTestPDSWithBootstrap(t, ownerDID, true, false) 657 633 658 634 // Create DPoP helper for owner 659 635 dpopHelper, err := NewDPoPTestHelper(ownerDID, "https://test-pds.example.com") ··· 691 667 692 668 // TestValidateBlobWriteAccess_CrewPermissions tests crew permission checking 693 669 func TestValidateBlobWriteAccess_CrewPermissions(t *testing.T) { 694 - pds, ctx := setupTestPDS(t) 695 - 696 670 ownerDID := "did:plc:owner123" 697 671 698 - // Bootstrap 699 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 700 - if err != nil { 701 - t.Fatalf("Failed to bootstrap PDS: %v", err) 702 - } 672 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 703 673 704 674 // Add crew member with blob:write permission 705 675 writerDID := "did:plc:writer123" 706 - _, err = pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"}) 676 + _, err := pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"}) 707 677 if err != nil { 708 678 t.Fatalf("Failed to add crew member: %v", err) 709 679 } ··· 764 734 765 735 // TestValidateBlobReadAccess_PublicHold tests public hold access 766 736 func TestValidateBlobReadAccess_PublicHold(t *testing.T) { 767 - pds, ctx := setupTestPDS(t) 768 - 769 737 ownerDID := "did:plc:owner123" 770 738 771 - // Bootstrap with public=true 772 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 773 - if err != nil { 774 - t.Fatalf("Failed to bootstrap PDS: %v", err) 775 - } 739 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 776 740 777 741 // Verify captain record has public=true 778 742 _, captain, err := pds.GetCaptainRecord(ctx) ··· 801 765 802 766 // TestValidateBlobReadAccess_PrivateHold tests private hold access 803 767 func TestValidateBlobReadAccess_PrivateHold(t *testing.T) { 804 - pds, ctx := setupTestPDS(t) 805 - 806 768 ownerDID := "did:plc:owner123" 807 769 808 - // Bootstrap with public=false 809 - err := pds.Bootstrap(ctx, nil, ownerDID, false, false, "") 810 - if err != nil { 811 - t.Fatalf("Failed to bootstrap PDS: %v", err) 812 - } 770 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, false, false) 813 771 814 772 // Update captain to be private 815 - _, err = pds.UpdateCaptainRecord(ctx, false, false) 773 + _, err := pds.UpdateCaptainRecord(ctx, false, false) 816 774 if err != nil { 817 775 t.Fatalf("Failed to update captain record: %v", err) 818 776 } ··· 843 801 844 802 // TestValidateOwnerOrCrewAdmin tests admin permission checking 845 803 func TestValidateOwnerOrCrewAdmin(t *testing.T) { 846 - pds, ctx := setupTestPDS(t) 847 - 848 804 ownerDID := "did:plc:owner123" 849 805 850 - // Bootstrap 851 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 852 - if err != nil { 853 - t.Fatalf("Failed to bootstrap PDS: %v", err) 854 - } 806 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 855 807 856 808 // Add crew member with crew:admin permission 857 809 adminDID := "did:plc:admin123" 858 - _, err = pds.AddCrewMember(ctx, adminDID, "admin", []string{"crew:admin", "blob:write", "blob:read"}) 810 + _, err := pds.AddCrewMember(ctx, adminDID, "admin", []string{"crew:admin", "blob:write", "blob:read"}) 859 811 if err != nil { 860 812 t.Fatalf("Failed to add crew admin: %v", err) 861 813 } ··· 911 863 912 864 // TestCrewPermissions tests various permission combinations 913 865 func TestCrewPermissions(t *testing.T) { 914 - pds, ctx := setupTestPDS(t) 915 - 916 866 ownerDID := "did:plc:owner123" 917 867 918 - // Bootstrap 919 - err := pds.Bootstrap(ctx, nil, ownerDID, true, false, "") 920 - if err != nil { 921 - t.Fatalf("Failed to bootstrap PDS: %v", err) 922 - } 868 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, true, false) 923 869 924 870 tests := []struct { 925 871 name string ··· 1035 981 1036 982 for _, tt := range tests { 1037 983 t.Run(tt.name, func(t *testing.T) { 1038 - pds, ctx := setupTestPDS(t) 1039 - 1040 984 ownerDID := "did:plc:owner123" 1041 985 1042 - // Bootstrap with specified settings 1043 - err := pds.Bootstrap(ctx, nil, ownerDID, tt.public, tt.allowAllCrew, "") 1044 - if err != nil { 1045 - t.Fatalf("Failed to bootstrap PDS: %v", err) 1046 - } 986 + pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, tt.public, tt.allowAllCrew) 1047 987 1048 988 // Verify captain record has expected settings 1049 989 _, captain, err := pds.GetCaptainRecord(ctx)
+34 -1
pkg/hold/pds/captain_test.go
··· 3 3 import ( 4 4 "bytes" 5 5 "context" 6 + "io" 7 + "os" 6 8 "path/filepath" 7 9 "strings" 8 10 "testing" ··· 17 19 ctx := context.Background() 18 20 tmpDir := t.TempDir() 19 21 20 - dbPath := filepath.Join(tmpDir, "pds.db") 22 + // Use in-memory database for speed 23 + dbPath := ":memory:" 21 24 keyPath := filepath.Join(tmpDir, "signing-key") 22 25 26 + // Copy shared signing key instead of generating a new one 27 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 28 + t.Fatalf("Failed to copy shared signing key: %v", err) 29 + } 30 + 23 31 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath) 24 32 if err != nil { 25 33 t.Fatalf("Failed to create test PDS: %v", err) ··· 30 38 err = pds.repomgr.InitNewActor(ctx, pds.uid, "", pds.did, "", "", "") 31 39 if err != nil { 32 40 t.Fatalf("Failed to initialize test repo: %v", err) 41 + } 42 + 43 + return pds, ctx 44 + } 45 + 46 + // setupTestPDSWithBootstrap creates a test PDS and bootstraps it with suppressed output 47 + // This is a convenience function for tests that need a fully initialized PDS 48 + func setupTestPDSWithBootstrap(t *testing.T, ownerDID string, public, allowAllCrew bool) (*HoldPDS, context.Context) { 49 + t.Helper() 50 + 51 + pds, ctx := setupTestPDS(t) 52 + 53 + // Bootstrap with suppressed output 54 + oldStdout := os.Stdout 55 + r, w, _ := os.Pipe() 56 + os.Stdout = w 57 + 58 + err := pds.Bootstrap(ctx, nil, ownerDID, public, allowAllCrew, "") 59 + 60 + w.Close() 61 + os.Stdout = oldStdout 62 + io.ReadAll(r) // Drain the pipe 63 + 64 + if err != nil { 65 + t.Fatalf("Failed to bootstrap PDS: %v", err) 33 66 } 34 67 35 68 return pds, ctx
-232
pkg/hold/pds/database.go
··· 1 - package pds 2 - 3 - import ( 4 - "database/sql" 5 - "fmt" 6 - "os" 7 - "path/filepath" 8 - "time" 9 - 10 - _ "github.com/mattn/go-sqlite3" 11 - ) 12 - 13 - // Database manages app passwords and sessions for the hold PDS 14 - type Database struct { 15 - db *sql.DB 16 - } 17 - 18 - // NewDatabase creates or opens a database for app passwords and sessions 19 - // dbPath should be the directory path (same as carstore) 20 - // It creates a separate "auth.db" file for authentication data 21 - func NewDatabase(dbPath string) (*Database, error) { 22 - // Ensure directory exists 23 - if err := os.MkdirAll(dbPath, 0755); err != nil { 24 - return nil, fmt.Errorf("failed to create database directory: %w", err) 25 - } 26 - 27 - // Create auth database file alongside carstore database 28 - authDBFile := filepath.Join(dbPath, "auth.db") 29 - 30 - db, err := sql.Open("sqlite3", authDBFile) 31 - if err != nil { 32 - return nil, fmt.Errorf("failed to open database: %w", err) 33 - } 34 - 35 - // Create tables 36 - if err := createTables(db); err != nil { 37 - db.Close() 38 - return nil, fmt.Errorf("failed to create tables: %w", err) 39 - } 40 - 41 - return &Database{db: db}, nil 42 - } 43 - 44 - // createTables creates the database schema 45 - func createTables(db *sql.DB) error { 46 - schema := ` 47 - CREATE TABLE IF NOT EXISTS app_passwords ( 48 - id INTEGER PRIMARY KEY AUTOINCREMENT, 49 - name TEXT NOT NULL UNIQUE, 50 - password_hash TEXT NOT NULL, 51 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 52 - last_used_at TIMESTAMP 53 - ); 54 - 55 - CREATE INDEX IF NOT EXISTS idx_app_passwords_name ON app_passwords(name); 56 - 57 - CREATE TABLE IF NOT EXISTS refresh_tokens ( 58 - id INTEGER PRIMARY KEY AUTOINCREMENT, 59 - token_hash TEXT NOT NULL UNIQUE, 60 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 61 - expires_at TIMESTAMP NOT NULL, 62 - last_used_at TIMESTAMP 63 - ); 64 - 65 - CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash); 66 - CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires ON refresh_tokens(expires_at); 67 - ` 68 - 69 - _, err := db.Exec(schema) 70 - return err 71 - } 72 - 73 - // Close closes the database connection 74 - func (d *Database) Close() error { 75 - return d.db.Close() 76 - } 77 - 78 - // AppPassword represents an app password record 79 - type AppPassword struct { 80 - ID int64 81 - Name string 82 - PasswordHash string 83 - CreatedAt time.Time 84 - LastUsedAt *time.Time 85 - } 86 - 87 - // CreateAppPassword stores a new app password 88 - func (d *Database) CreateAppPassword(name, passwordHash string) error { 89 - query := `INSERT INTO app_passwords (name, password_hash) VALUES (?, ?)` 90 - _, err := d.db.Exec(query, name, passwordHash) 91 - if err != nil { 92 - return fmt.Errorf("failed to create app password: %w", err) 93 - } 94 - return nil 95 - } 96 - 97 - // GetAppPassword retrieves an app password by name 98 - func (d *Database) GetAppPassword(name string) (*AppPassword, error) { 99 - query := `SELECT id, name, password_hash, created_at, last_used_at FROM app_passwords WHERE name = ?` 100 - 101 - var ap AppPassword 102 - var lastUsedAt sql.NullTime 103 - 104 - err := d.db.QueryRow(query, name).Scan( 105 - &ap.ID, 106 - &ap.Name, 107 - &ap.PasswordHash, 108 - &ap.CreatedAt, 109 - &lastUsedAt, 110 - ) 111 - 112 - if err == sql.ErrNoRows { 113 - return nil, fmt.Errorf("app password not found") 114 - } 115 - if err != nil { 116 - return nil, fmt.Errorf("failed to get app password: %w", err) 117 - } 118 - 119 - if lastUsedAt.Valid { 120 - ap.LastUsedAt = &lastUsedAt.Time 121 - } 122 - 123 - return &ap, nil 124 - } 125 - 126 - // ListAppPasswords returns all app passwords (without hashes) 127 - func (d *Database) ListAppPasswords() ([]AppPassword, error) { 128 - query := `SELECT id, name, created_at, last_used_at FROM app_passwords ORDER BY created_at DESC` 129 - 130 - rows, err := d.db.Query(query) 131 - if err != nil { 132 - return nil, fmt.Errorf("failed to list app passwords: %w", err) 133 - } 134 - defer rows.Close() 135 - 136 - var passwords []AppPassword 137 - for rows.Next() { 138 - var ap AppPassword 139 - var lastUsedAt sql.NullTime 140 - 141 - if err := rows.Scan(&ap.ID, &ap.Name, &ap.CreatedAt, &lastUsedAt); err != nil { 142 - return nil, fmt.Errorf("failed to scan row: %w", err) 143 - } 144 - 145 - if lastUsedAt.Valid { 146 - ap.LastUsedAt = &lastUsedAt.Time 147 - } 148 - 149 - passwords = append(passwords, ap) 150 - } 151 - 152 - return passwords, rows.Err() 153 - } 154 - 155 - // UpdateLastUsed updates the last used timestamp for an app password 156 - func (d *Database) UpdateLastUsed(name string) error { 157 - query := `UPDATE app_passwords SET last_used_at = CURRENT_TIMESTAMP WHERE name = ?` 158 - _, err := d.db.Exec(query, name) 159 - return err 160 - } 161 - 162 - // DeleteAppPassword removes an app password 163 - func (d *Database) DeleteAppPassword(name string) error { 164 - query := `DELETE FROM app_passwords WHERE name = ?` 165 - result, err := d.db.Exec(query, name) 166 - if err != nil { 167 - return fmt.Errorf("failed to delete app password: %w", err) 168 - } 169 - 170 - rows, err := result.RowsAffected() 171 - if err != nil { 172 - return fmt.Errorf("failed to check rows affected: %w", err) 173 - } 174 - 175 - if rows == 0 { 176 - return fmt.Errorf("app password not found") 177 - } 178 - 179 - return nil 180 - } 181 - 182 - // CreateRefreshToken stores a refresh token 183 - func (d *Database) CreateRefreshToken(tokenHash string, expiresAt time.Time) error { 184 - query := `INSERT INTO refresh_tokens (token_hash, expires_at) VALUES (?, ?)` 185 - _, err := d.db.Exec(query, tokenHash, expiresAt) 186 - if err != nil { 187 - return fmt.Errorf("failed to create refresh token: %w", err) 188 - } 189 - return nil 190 - } 191 - 192 - // ValidateRefreshToken checks if a refresh token exists and is not expired 193 - func (d *Database) ValidateRefreshToken(tokenHash string) (bool, error) { 194 - query := `SELECT expires_at FROM refresh_tokens WHERE token_hash = ?` 195 - 196 - var expiresAt time.Time 197 - err := d.db.QueryRow(query, tokenHash).Scan(&expiresAt) 198 - 199 - if err == sql.ErrNoRows { 200 - return false, nil 201 - } 202 - if err != nil { 203 - return false, fmt.Errorf("failed to validate refresh token: %w", err) 204 - } 205 - 206 - // Check if expired 207 - if time.Now().After(expiresAt) { 208 - // Delete expired token 209 - d.DeleteRefreshToken(tokenHash) 210 - return false, nil 211 - } 212 - 213 - // Update last used 214 - updateQuery := `UPDATE refresh_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token_hash = ?` 215 - d.db.Exec(updateQuery, tokenHash) 216 - 217 - return true, nil 218 - } 219 - 220 - // DeleteRefreshToken removes a refresh token 221 - func (d *Database) DeleteRefreshToken(tokenHash string) error { 222 - query := `DELETE FROM refresh_tokens WHERE token_hash = ?` 223 - _, err := d.db.Exec(query, tokenHash) 224 - return err 225 - } 226 - 227 - // CleanupExpiredTokens removes all expired refresh tokens 228 - func (d *Database) CleanupExpiredTokens() error { 229 - query := `DELETE FROM refresh_tokens WHERE expires_at < CURRENT_TIMESTAMP` 230 - _, err := d.db.Exec(query) 231 - return err 232 - }
+71 -16
pkg/hold/pds/events.go
··· 8 8 "time" 9 9 10 10 atproto "github.com/bluesky-social/indigo/api/atproto" 11 + "github.com/bluesky-social/indigo/events" 11 12 lexutil "github.com/bluesky-social/indigo/lex/util" 12 13 "github.com/gorilla/websocket" 14 + "github.com/ipfs/go-cid" 13 15 ) 14 16 15 17 // EventBroadcaster manages WebSocket connections and broadcasts repo events ··· 77 79 b.mu.Unlock() 78 80 79 81 // Send historical events if cursor is provided and < current seq 80 - if cursor > 0 && cursor < currentSeq { 82 + // cursor=0 means "replay all events from the beginning" 83 + // cursor >= 0 triggers backfill, negative cursor means "no backfill" 84 + if cursor >= 0 && cursor < currentSeq { 81 85 go b.backfillSubscriber(sub, cursor) 82 86 } 83 87 ··· 205 209 }() 206 210 207 211 for event := range sub.send { 208 - // Encode as CBOR 209 - cborBytes, err := encodeCBOR(event) 212 + // Create event header (ATProto firehose format) 213 + header := events.EventHeader{ 214 + Op: events.EvtKindMessage, 215 + MsgType: "#commit", 216 + } 217 + 218 + // Get a writer for this message 219 + wc, err := sub.conn.NextWriter(websocket.BinaryMessage) 210 220 if err != nil { 211 - log.Printf("Failed to encode event as CBOR: %v", err) 212 - continue 221 + log.Printf("Failed to get websocket writer: %v", err) 222 + return 213 223 } 214 224 215 - // Write CBOR message to WebSocket 216 - err = sub.conn.WriteMessage(websocket.BinaryMessage, cborBytes) 217 - if err != nil { 218 - log.Printf("Failed to write to websocket: %v", err) 225 + // Write header as CBOR 226 + if err := header.MarshalCBOR(wc); err != nil { 227 + log.Printf("Failed to write event header: %v", err) 228 + wc.Close() 229 + return 230 + } 231 + 232 + // Convert our RepoCommitEvent to indigo's SyncSubscribeRepos_Commit 233 + indigoEvent := convertToIndigoCommit(event) 234 + 235 + // Write the event as CBOR 236 + var obj lexutil.CBOR = indigoEvent 237 + if err := obj.MarshalCBOR(wc); err != nil { 238 + log.Printf("Failed to write event body: %v", err) 239 + wc.Close() 240 + return 241 + } 242 + 243 + // Close the writer to flush the message 244 + if err := wc.Close(); err != nil { 245 + log.Printf("Failed to close websocket writer: %v", err) 219 246 return 220 247 } 221 248 ··· 224 251 } 225 252 } 226 253 227 - // encodeCBOR encodes an event as CBOR 254 + // convertToIndigoCommit converts our RepoCommitEvent to indigo's SyncSubscribeRepos_Commit 255 + // which has proper CBOR marshaling methods generated 256 + func convertToIndigoCommit(event *RepoCommitEvent) *atproto.SyncSubscribeRepos_Commit { 257 + // Parse commit CID string to cid.Cid, then convert to LexLink 258 + commitCID, err := cid.Decode(event.Commit) 259 + if err != nil { 260 + log.Printf("Warning: failed to parse commit CID %s: %v", event.Commit, err) 261 + // Create an empty CID as fallback 262 + commitCID = cid.Undef 263 + } 264 + 265 + // Convert cid.Cid to LexLink 266 + commitLink := lexutil.LexLink(commitCID) 267 + 268 + // Convert blocks to LexBytes 269 + blocks := lexutil.LexBytes(event.Blocks) 270 + 271 + return &atproto.SyncSubscribeRepos_Commit{ 272 + Seq: event.Seq, 273 + Repo: event.Repo, 274 + Commit: commitLink, 275 + Rev: event.Rev, 276 + Since: event.Since, 277 + Blocks: blocks, 278 + Ops: event.Ops, 279 + Time: event.Time, 280 + Blobs: []lexutil.LexLink{}, // Empty for now, we don't track blob refs in our simplified model 281 + Rebase: false, // DEPRECATED field 282 + TooBig: false, // Not implementing tooBig for now 283 + } 284 + } 285 + 286 + // encodeCBOR encodes an event as CBOR (DEPRECATED - kept for tests) 228 287 func encodeCBOR(event *RepoCommitEvent) ([]byte, error) { 229 - // For now, use JSON encoding wrapped in CBOR envelope 230 - // In production, you'd use proper CBOR encoding 231 - // The atproto spec requires DAG-CBOR with specific header 232 - 233 - // Simple approach: encode as JSON for MVP 234 - // Real implementation needs proper CBOR-gen types 288 + // For backward compatibility with tests, encode as JSON 289 + // Production code uses convertToIndigoCommit + CBOR marshaling in handleSubscriber 235 290 return json.Marshal(event) 236 291 } 237 292
+164
pkg/hold/pds/events_test.go
··· 382 382 t.Errorf("Expected decoded seq=1, got %d", decoded.Seq) 383 383 } 384 384 } 385 + 386 + // TestSubscribe_CursorZeroBackfill tests that cursor=0 replays all events 387 + func TestSubscribe_CursorZeroBackfill(t *testing.T) { 388 + broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100) 389 + ctx := context.Background() 390 + 391 + testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke") 392 + 393 + // Broadcast 5 events before subscribing 394 + for i := 1; i <= 5; i++ { 395 + event := &RepoEvent{ 396 + NewRoot: testCID, 397 + Rev: "test-rev", 398 + RepoSlice: []byte("test CAR data"), 399 + Ops: []RepoOp{}, 400 + } 401 + broadcaster.Broadcast(ctx, event) 402 + } 403 + 404 + // Verify we have 5 events in history 405 + if broadcaster.eventSeq != 5 { 406 + t.Fatalf("Expected eventSeq=5, got %d", broadcaster.eventSeq) 407 + } 408 + 409 + // Create mock websocket connection (we won't actually use it) 410 + // We just need to verify backfillSubscriber is called 411 + // For this test, we'll check the history directly 412 + if len(broadcaster.eventHistory) != 5 { 413 + t.Errorf("Expected 5 events in history, got %d", len(broadcaster.eventHistory)) 414 + } 415 + 416 + // Verify all events have sequential sequence numbers 417 + for i, he := range broadcaster.eventHistory { 418 + expectedSeq := int64(i + 1) 419 + if he.Seq != expectedSeq { 420 + t.Errorf("Expected history[%d].Seq=%d, got %d", i, expectedSeq, he.Seq) 421 + } 422 + } 423 + 424 + // Test backfillSubscriber directly with cursor=0 425 + // Create a subscriber manually (conn not needed for backfill test) 426 + sub := &Subscriber{ 427 + conn: nil, // Not used in backfillSubscriber 428 + send: make(chan *RepoCommitEvent, 100), // Large buffer for testing 429 + cursor: 0, 430 + } 431 + 432 + // Run backfill in a goroutine 433 + go broadcaster.backfillSubscriber(sub, 0) 434 + 435 + // Wait for events to be sent 436 + time.Sleep(100 * time.Millisecond) 437 + 438 + // Should receive all 5 events 439 + receivedCount := len(sub.send) 440 + if receivedCount != 5 { 441 + t.Errorf("Expected to receive 5 events with cursor=0, got %d", receivedCount) 442 + } 443 + 444 + // Verify events are in order 445 + for i := 1; i <= 5; i++ { 446 + select { 447 + case event := <-sub.send: 448 + if event.Seq != int64(i) { 449 + t.Errorf("Expected event seq=%d, got %d", i, event.Seq) 450 + } 451 + default: 452 + t.Errorf("Expected event %d but channel was empty", i) 453 + } 454 + } 455 + } 456 + 457 + // TestSubscribe_MidCursorBackfill tests that cursor=N only gets events after N 458 + func TestSubscribe_MidCursorBackfill(t *testing.T) { 459 + broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100) 460 + ctx := context.Background() 461 + 462 + testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke") 463 + 464 + // Broadcast 10 events before subscribing 465 + for i := 1; i <= 10; i++ { 466 + event := &RepoEvent{ 467 + NewRoot: testCID, 468 + Rev: "test-rev", 469 + RepoSlice: []byte("test CAR data"), 470 + Ops: []RepoOp{}, 471 + } 472 + broadcaster.Broadcast(ctx, event) 473 + } 474 + 475 + // Test backfillSubscriber with cursor=5 (conn not needed for backfill test) 476 + sub := &Subscriber{ 477 + conn: nil, // Not used in backfillSubscriber 478 + send: make(chan *RepoCommitEvent, 100), // Large buffer for testing 479 + cursor: 5, 480 + } 481 + 482 + // Run backfill 483 + go broadcaster.backfillSubscriber(sub, 5) 484 + 485 + // Wait for events to be sent 486 + time.Sleep(100 * time.Millisecond) 487 + 488 + // Should receive events 6-10 (5 events after cursor=5) 489 + receivedCount := len(sub.send) 490 + if receivedCount != 5 { 491 + t.Errorf("Expected to receive 5 events with cursor=5, got %d", receivedCount) 492 + } 493 + 494 + // Verify events start at seq=6 495 + for i := 6; i <= 10; i++ { 496 + select { 497 + case event := <-sub.send: 498 + if event.Seq != int64(i) { 499 + t.Errorf("Expected event seq=%d, got %d", i, event.Seq) 500 + } 501 + default: 502 + t.Errorf("Expected event %d but channel was empty", i) 503 + } 504 + } 505 + } 506 + 507 + // TestSubscribe_NegativeCursorNoBackfill tests that negative cursor means no backfill 508 + func TestSubscribe_NegativeCursorNoBackfill(t *testing.T) { 509 + broadcaster := NewEventBroadcaster("did:web:hold.example.com", 100) 510 + ctx := context.Background() 511 + 512 + testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke") 513 + 514 + // Broadcast 5 events before subscribing 515 + for i := 1; i <= 5; i++ { 516 + event := &RepoEvent{ 517 + NewRoot: testCID, 518 + Rev: "test-rev", 519 + RepoSlice: []byte("test CAR data"), 520 + Ops: []RepoOp{}, 521 + } 522 + broadcaster.Broadcast(ctx, event) 523 + } 524 + 525 + // Create subscriber with cursor=-1 (no backfill, conn not needed) 526 + sub := &Subscriber{ 527 + conn: nil, // Not used in this test 528 + send: make(chan *RepoCommitEvent, 100), 529 + cursor: -1, 530 + } 531 + 532 + // Subscribe should not trigger backfill 533 + broadcaster.mu.Lock() 534 + currentSeq := broadcaster.eventSeq 535 + broadcaster.mu.Unlock() 536 + 537 + // Check the condition: cursor >= 0 && cursor < currentSeq 538 + // For cursor=-1, this should be false 539 + shouldBackfill := -1 >= 0 && -1 < currentSeq 540 + if shouldBackfill { 541 + t.Error("Expected shouldBackfill=false for cursor=-1, but condition evaluated to true") 542 + } 543 + 544 + // Verify no events in send channel (no backfill happened) 545 + if len(sub.send) != 0 { 546 + t.Errorf("Expected 0 events with cursor=-1 (no backfill), got %d", len(sub.send)) 547 + } 548 + }
-249
pkg/hold/pds/jwt.go
··· 1 - package pds 2 - 3 - import ( 4 - "crypto/sha256" 5 - "encoding/base64" 6 - "encoding/hex" 7 - "encoding/json" 8 - "fmt" 9 - "time" 10 - ) 11 - 12 - // Session token types 13 - const ( 14 - TokenTypeAccess = "access" 15 - TokenTypeRefresh = "refresh" 16 - ) 17 - 18 - // Token expiration durations 19 - const ( 20 - AccessTokenDuration = 2 * time.Hour // Short-lived access token 21 - RefreshTokenDuration = 90 * 24 * time.Hour // Long-lived refresh token (90 days) 22 - ) 23 - 24 - // SessionClaims represents JWT claims for ATProto sessions 25 - type SessionClaims struct { 26 - DID string `json:"sub"` // Subject (DID) 27 - Issuer string `json:"iss"` // Issuer (PDS DID) 28 - Handle string `json:"handle,omitempty"` 29 - Scope string `json:"scope"` 30 - TokenType string `json:"token_type"` 31 - IssuedAt int64 `json:"iat"` // Unix timestamp 32 - ExpiresAt int64 `json:"exp"` // Unix timestamp 33 - } 34 - 35 - // IssueAccessToken creates a new access JWT for a session 36 - func (p *HoldPDS) IssueAccessToken(did, handle string) (string, error) { 37 - now := time.Now() 38 - claims := &SessionClaims{ 39 - DID: did, 40 - Issuer: p.did, 41 - Handle: handle, 42 - Scope: "com.atproto.access", 43 - TokenType: TokenTypeAccess, 44 - IssuedAt: now.Unix(), 45 - ExpiresAt: now.Add(AccessTokenDuration).Unix(), 46 - } 47 - 48 - return p.signJWT(claims) 49 - } 50 - 51 - // IssueRefreshToken creates a new refresh JWT for a session 52 - func (p *HoldPDS) IssueRefreshToken(did, handle string) (string, error) { 53 - now := time.Now() 54 - claims := &SessionClaims{ 55 - DID: did, 56 - Issuer: p.did, 57 - Handle: handle, 58 - Scope: "com.atproto.refresh", 59 - TokenType: TokenTypeRefresh, 60 - IssuedAt: now.Unix(), 61 - ExpiresAt: now.Add(RefreshTokenDuration).Unix(), 62 - } 63 - 64 - signedToken, err := p.signJWT(claims) 65 - if err != nil { 66 - return "", err 67 - } 68 - 69 - // Store refresh token hash in database for validation/revocation 70 - tokenHash := hashToken(signedToken) 71 - expiresAt := now.Add(RefreshTokenDuration) 72 - if err := p.authDB.CreateRefreshToken(tokenHash, expiresAt); err != nil { 73 - return "", fmt.Errorf("failed to store refresh token: %w", err) 74 - } 75 - 76 - return signedToken, nil 77 - } 78 - 79 - // ValidateAccessToken validates an access JWT and returns the claims 80 - func (p *HoldPDS) ValidateAccessToken(tokenString string) (*SessionClaims, error) { 81 - return p.validateToken(tokenString, TokenTypeAccess) 82 - } 83 - 84 - // ValidateRefreshToken validates a refresh JWT and returns the claims 85 - // Also checks the database to ensure the token hasn't been revoked 86 - func (p *HoldPDS) ValidateRefreshToken(tokenString string) (*SessionClaims, error) { 87 - // First validate signature and claims 88 - claims, err := p.validateToken(tokenString, TokenTypeRefresh) 89 - if err != nil { 90 - return nil, err 91 - } 92 - 93 - // Check if token is in database (not revoked) 94 - tokenHash := hashToken(tokenString) 95 - valid, err := p.authDB.ValidateRefreshToken(tokenHash) 96 - if err != nil { 97 - return nil, fmt.Errorf("failed to validate refresh token in database: %w", err) 98 - } 99 - if !valid { 100 - return nil, fmt.Errorf("refresh token has been revoked or expired") 101 - } 102 - 103 - return claims, nil 104 - } 105 - 106 - // validateToken validates a JWT token and returns the claims 107 - func (p *HoldPDS) validateToken(tokenString, expectedType string) (*SessionClaims, error) { 108 - // Split token into parts 109 - parts := splitJWT(tokenString) 110 - if len(parts) != 3 { 111 - return nil, fmt.Errorf("invalid JWT format") 112 - } 113 - 114 - // Decode header 115 - headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) 116 - if err != nil { 117 - return nil, fmt.Errorf("failed to decode header: %w", err) 118 - } 119 - 120 - var header map[string]interface{} 121 - if err := json.Unmarshal(headerBytes, &header); err != nil { 122 - return nil, fmt.Errorf("failed to parse header: %w", err) 123 - } 124 - 125 - // Verify algorithm 126 - alg, ok := header["alg"].(string) 127 - if !ok || alg != "ES256K" { 128 - return nil, fmt.Errorf("unsupported algorithm: %v", alg) 129 - } 130 - 131 - // Decode claims 132 - claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) 133 - if err != nil { 134 - return nil, fmt.Errorf("failed to decode claims: %w", err) 135 - } 136 - 137 - var claims SessionClaims 138 - if err := json.Unmarshal(claimsBytes, &claims); err != nil { 139 - return nil, fmt.Errorf("failed to parse claims: %w", err) 140 - } 141 - 142 - // Verify token type 143 - if claims.TokenType != expectedType { 144 - return nil, fmt.Errorf("invalid token type: expected %s, got %s", expectedType, claims.TokenType) 145 - } 146 - 147 - // Verify issuer 148 - if claims.Issuer != p.did { 149 - return nil, fmt.Errorf("invalid issuer: expected %s, got %s", p.did, claims.Issuer) 150 - } 151 - 152 - // Verify subject matches this hold 153 - if claims.DID != p.did { 154 - return nil, fmt.Errorf("invalid subject: expected %s, got %s", p.did, claims.DID) 155 - } 156 - 157 - // Verify expiration 158 - if time.Now().Unix() > claims.ExpiresAt { 159 - return nil, fmt.Errorf("token has expired") 160 - } 161 - 162 - // Verify signature 163 - signedData := []byte(parts[0] + "." + parts[1]) 164 - signature, err := base64.RawURLEncoding.DecodeString(parts[2]) 165 - if err != nil { 166 - return nil, fmt.Errorf("failed to decode signature: %w", err) 167 - } 168 - 169 - publicKey, err := p.signingKey.PublicKey() 170 - if err != nil { 171 - return nil, fmt.Errorf("failed to get public key: %w", err) 172 - } 173 - 174 - if err := publicKey.HashAndVerify(signedData, signature); err != nil { 175 - return nil, fmt.Errorf("signature verification failed: %w", err) 176 - } 177 - 178 - return &claims, nil 179 - } 180 - 181 - // RevokeRefreshToken revokes a refresh token by removing it from the database 182 - func (p *HoldPDS) RevokeRefreshToken(tokenString string) error { 183 - tokenHash := hashToken(tokenString) 184 - return p.authDB.DeleteRefreshToken(tokenHash) 185 - } 186 - 187 - // hashToken creates a SHA-256 hash of a token for storage 188 - func hashToken(token string) string { 189 - hash := sha256.Sum256([]byte(token)) 190 - return hex.EncodeToString(hash[:]) 191 - } 192 - 193 - // signJWT creates and signs a JWT using the hold's private key 194 - func (p *HoldPDS) signJWT(claims *SessionClaims) (string, error) { 195 - // Create header 196 - header := map[string]interface{}{ 197 - "typ": "JWT", 198 - "alg": "ES256K", 199 - } 200 - 201 - headerJSON, err := json.Marshal(header) 202 - if err != nil { 203 - return "", fmt.Errorf("failed to marshal header: %w", err) 204 - } 205 - 206 - // Create payload 207 - payloadJSON, err := json.Marshal(claims) 208 - if err != nil { 209 - return "", fmt.Errorf("failed to marshal claims: %w", err) 210 - } 211 - 212 - // Base64url encode header and payload 213 - headerEncoded := base64.RawURLEncoding.EncodeToString(headerJSON) 214 - payloadEncoded := base64.RawURLEncoding.EncodeToString(payloadJSON) 215 - 216 - // Create signing input 217 - signingInput := headerEncoded + "." + payloadEncoded 218 - 219 - // Sign with private key 220 - signature, err := p.signingKey.HashAndSign([]byte(signingInput)) 221 - if err != nil { 222 - return "", fmt.Errorf("failed to sign JWT: %w", err) 223 - } 224 - 225 - // Base64url encode signature 226 - signatureEncoded := base64.RawURLEncoding.EncodeToString(signature) 227 - 228 - // Combine into final JWT 229 - jwt := signingInput + "." + signatureEncoded 230 - 231 - return jwt, nil 232 - } 233 - 234 - // splitJWT splits a JWT string into its three parts 235 - func splitJWT(token string) []string { 236 - // JWT format: header.payload.signature 237 - parts := make([]string, 0, 3) 238 - start := 0 239 - for i := 0; i < len(token); i++ { 240 - if token[i] == '.' { 241 - parts = append(parts, token[start:i]) 242 - start = i + 1 243 - } 244 - } 245 - if start < len(token) { 246 - parts = append(parts, token[start:]) 247 - } 248 - return parts 249 - }
+21 -67
pkg/hold/pds/server.go
··· 36 36 dbPath string 37 37 uid models.Uid 38 38 signingKey *atcrypto.PrivateKeyK256 39 - authDB *Database // Authentication database for app passwords and sessions 40 39 } 41 40 42 41 // NewHoldPDS creates or opens a hold PDS with SQLite carstore 43 42 func NewHoldPDS(ctx context.Context, did, publicURL, dbPath, keyPath string) (*HoldPDS, error) { 44 - // Ensure directory exists 45 - dir := filepath.Dir(dbPath) 46 - if err := os.MkdirAll(dir, 0755); err != nil { 47 - return nil, fmt.Errorf("failed to create database directory: %w", err) 48 - } 49 - 50 43 // Generate or load signing key 51 44 signingKey, err := GenerateOrLoadKey(keyPath) 52 45 if err != nil { 53 46 return nil, fmt.Errorf("failed to initialize signing key: %w", err) 54 47 } 55 48 56 - // Create and open SQLite-backed carstore 57 - // dbPath is the directory, carstore creates and opens db.sqlite3 inside it 58 - sqlStore, err := carstore.NewSqliteStore(dbPath) 59 - if err != nil { 60 - return nil, fmt.Errorf("failed to create sqlite store: %w", err) 49 + // Create SQLite-backed carstore 50 + var sqlStore *carstore.SQLiteStore 51 + 52 + if dbPath == ":memory:" { 53 + // In-memory mode for tests: create carstore manually and open with :memory: 54 + sqlStore = new(carstore.SQLiteStore) 55 + if err := sqlStore.Open(":memory:"); err != nil { 56 + return nil, fmt.Errorf("failed to open in-memory sqlite store: %w", err) 57 + } 58 + } else { 59 + // File mode for production: create directory and use NewSqliteStore 60 + dir := filepath.Dir(dbPath) 61 + if err := os.MkdirAll(dir, 0755); err != nil { 62 + return nil, fmt.Errorf("failed to create database directory: %w", err) 63 + } 64 + 65 + // dbPath is the directory, carstore creates and opens db.sqlite3 inside it 66 + sqlStore, err = carstore.NewSqliteStore(dbPath) 67 + if err != nil { 68 + return nil, fmt.Errorf("failed to create sqlite store: %w", err) 69 + } 61 70 } 62 71 63 72 // Use SQLiteStore directly, not the CarStore() wrapper ··· 84 93 fmt.Printf("New hold repo - will be initialized in Bootstrap\n") 85 94 } 86 95 87 - // Create or open authentication database 88 - authDB, err := NewDatabase(dbPath) 89 - if err != nil { 90 - return nil, fmt.Errorf("failed to create auth database: %w", err) 91 - } 92 - 93 96 return &HoldPDS{ 94 97 did: did, 95 98 PublicURL: publicURL, ··· 98 101 dbPath: dbPath, 99 102 uid: uid, 100 103 signingKey: signingKey, 101 - authDB: authDB, 102 104 }, nil 103 105 } 104 106 ··· 184 186 } else { 185 187 fmt.Printf("✅ Bluesky profile record already exists, skipping\n") 186 188 } 187 - 188 - // Create Tangled profile record (idempotent - check if exists first) 189 - _, _, err = p.GetTangledProfileRecord(ctx) 190 - if err != nil { 191 - // Tangled profile doesn't exist, create it 192 - description := "ahoy from the cargo hold" 193 - links := []string{"https://atcr.io"} 194 - 195 - _, err = p.CreateTangledProfileRecord(ctx, links, description) 196 - if err != nil { 197 - return fmt.Errorf("failed to create tangled profile record: %w", err) 198 - } 199 - fmt.Printf("✅ Created Tangled profile record\n") 200 - } else { 201 - fmt.Printf("✅ Tangled profile record already exists, skipping\n") 202 - } 203 - } 204 - 205 - // Create bootstrap app password if none exist (one-time setup) 206 - passwords, err := p.authDB.ListAppPasswords() 207 - if err != nil { 208 - return fmt.Errorf("failed to list app passwords: %w", err) 209 - } 210 - 211 - if len(passwords) == 0 { 212 - // No app passwords exist, create one 213 - password, err := p.CreateAppPassword("bootstrap") 214 - if err != nil { 215 - return fmt.Errorf("failed to create bootstrap app password: %w", err) 216 - } 217 - 218 - fmt.Printf("\n") 219 - fmt.Printf("╔════════════════════════════════════════════════════════════════╗\n") 220 - fmt.Printf("║ 🔑 APP PASSWORD CREATED ║\n") 221 - fmt.Printf("╠════════════════════════════════════════════════════════════════╣\n") 222 - fmt.Printf("║ ║\n") 223 - fmt.Printf("║ Password: %-51s ║\n", password) 224 - fmt.Printf("║ ║\n") 225 - fmt.Printf("║ ⚠️ SAVE THIS PASSWORD - it will not be shown again ║\n") 226 - fmt.Printf("║ ║\n") 227 - fmt.Printf("║ Use this password to log into Bluesky app or CLI tools ║\n") 228 - fmt.Printf("║ PDS URL: %-51s ║\n", p.PublicURL) 229 - fmt.Printf("║ Username: %-50s ║\n", p.did) 230 - fmt.Printf("║ ║\n") 231 - fmt.Printf("╚════════════════════════════════════════════════════════════════╝\n") 232 - fmt.Printf("\n") 233 - } else { 234 - fmt.Printf("✅ App passwords already exist (count: %d), skipping auto-generation\n", len(passwords)) 235 189 } 236 190 237 191 return nil
-381
pkg/hold/pds/session.go
··· 1 - package pds 2 - 3 - import ( 4 - "encoding/json" 5 - "fmt" 6 - "io" 7 - "net/http" 8 - "strings" 9 - 10 - cbg "github.com/whyrusleeping/cbor-gen" 11 - "github.com/ipfs/go-cid" 12 - ) 13 - 14 - // CreateSessionRequest represents a session creation request 15 - type CreateSessionRequest struct { 16 - Identifier string `json:"identifier"` // DID or handle 17 - Password string `json:"password"` // App password 18 - } 19 - 20 - // CreateSessionResponse represents a successful session creation 21 - type CreateSessionResponse struct { 22 - AccessJwt string `json:"accessJwt"` 23 - RefreshJwt string `json:"refreshJwt"` 24 - Handle string `json:"handle"` 25 - DID string `json:"did"` 26 - DIDDoc map[string]interface{} `json:"didDoc,omitempty"` // Optional DID document 27 - Email string `json:"email,omitempty"` // Optional, not used for holds 28 - Active *bool `json:"active,omitempty"` // Optional account status 29 - Status string `json:"status,omitempty"` // Optional account status 30 - } 31 - 32 - // SessionInfo represents session information 33 - type SessionInfo struct { 34 - Handle string `json:"handle"` 35 - DID string `json:"did"` 36 - Email string `json:"email,omitempty"` 37 - } 38 - 39 - // HandleCreateSession handles com.atproto.server.createSession 40 - func (h *XRPCHandler) HandleCreateSession(w http.ResponseWriter, r *http.Request) { 41 - // Parse request 42 - var req CreateSessionRequest 43 - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 44 - http.Error(w, fmt.Sprintf("invalid request body: %v", err), http.StatusBadRequest) 45 - return 46 - } 47 - 48 - // Validate required fields 49 - if req.Identifier == "" || req.Password == "" { 50 - http.Error(w, "identifier and password are required", http.StatusBadRequest) 51 - return 52 - } 53 - 54 - // Validate identifier matches this hold's DID or handle 55 - holdDID := h.pds.DID() 56 - // For did:web, handle is the domain part without "did:web:" prefix 57 - // e.g., "did:web:hold01.atcr.io" -> "hold01.atcr.io" 58 - holdHandle := strings.TrimPrefix(holdDID, "did:web:") 59 - 60 - // Normalize the identifier (strip "at://" prefix if present) 61 - identifier := strings.TrimPrefix(req.Identifier, "at://") 62 - 63 - // Accept any of: 64 - // 1. Full DID: "did:web:hold01.atcr.io" 65 - // 2. Handle (domain): "hold01.atcr.io" 66 - // 3. Either with "at://" prefix 67 - isValidIdentifier := identifier == holdDID || identifier == holdHandle 68 - 69 - if !isValidIdentifier { 70 - fmt.Printf("Invalid identifier: got %q, expected DID %q or handle %q\n", req.Identifier, holdDID, holdHandle) 71 - http.Error(w, "invalid identifier", http.StatusUnauthorized) 72 - return 73 - } 74 - 75 - // Validate app password 76 - _, err := h.pds.ValidateAnyAppPassword(req.Password) 77 - if err != nil { 78 - http.Error(w, "invalid password", http.StatusUnauthorized) 79 - return 80 - } 81 - 82 - // Issue access and refresh tokens 83 - accessToken, err := h.pds.IssueAccessToken(holdDID, holdHandle) 84 - if err != nil { 85 - http.Error(w, fmt.Sprintf("failed to issue access token: %v", err), http.StatusInternalServerError) 86 - return 87 - } 88 - 89 - refreshToken, err := h.pds.IssueRefreshToken(holdDID, holdHandle) 90 - if err != nil { 91 - http.Error(w, fmt.Sprintf("failed to issue refresh token: %v", err), http.StatusInternalServerError) 92 - return 93 - } 94 - 95 - // Return session response 96 - active := true 97 - response := CreateSessionResponse{ 98 - AccessJwt: accessToken, 99 - RefreshJwt: refreshToken, 100 - Handle: holdHandle, 101 - DID: holdDID, 102 - Active: &active, // Account is active 103 - } 104 - 105 - w.Header().Set("Content-Type", "application/json") 106 - json.NewEncoder(w).Encode(response) 107 - } 108 - 109 - // HandleRefreshSession handles com.atproto.server.refreshSession 110 - func (h *XRPCHandler) HandleRefreshSession(w http.ResponseWriter, r *http.Request) { 111 - // Extract refresh token from Authorization header 112 - authHeader := r.Header.Get("Authorization") 113 - if authHeader == "" { 114 - http.Error(w, "authorization header required", http.StatusUnauthorized) 115 - return 116 - } 117 - 118 - // Remove "Bearer " prefix 119 - refreshToken := strings.TrimPrefix(authHeader, "Bearer ") 120 - if refreshToken == authHeader { 121 - http.Error(w, "invalid authorization header format", http.StatusUnauthorized) 122 - return 123 - } 124 - 125 - // Validate refresh token 126 - claims, err := h.pds.ValidateRefreshToken(refreshToken) 127 - if err != nil { 128 - http.Error(w, fmt.Sprintf("invalid refresh token: %v", err), http.StatusUnauthorized) 129 - return 130 - } 131 - 132 - // Issue new access token (and optionally new refresh token) 133 - accessToken, err := h.pds.IssueAccessToken(claims.DID, claims.Handle) 134 - if err != nil { 135 - http.Error(w, fmt.Sprintf("failed to issue access token: %v", err), http.StatusInternalServerError) 136 - return 137 - } 138 - 139 - // Issue new refresh token (rotate refresh tokens for security) 140 - newRefreshToken, err := h.pds.IssueRefreshToken(claims.DID, claims.Handle) 141 - if err != nil { 142 - http.Error(w, fmt.Sprintf("failed to issue refresh token: %v", err), http.StatusInternalServerError) 143 - return 144 - } 145 - 146 - // Revoke old refresh token 147 - if err := h.pds.RevokeRefreshToken(refreshToken); err != nil { 148 - // Log but don't fail - new tokens are already issued 149 - fmt.Printf("Warning: failed to revoke old refresh token: %v\n", err) 150 - } 151 - 152 - // Return new tokens 153 - active := true 154 - response := CreateSessionResponse{ 155 - AccessJwt: accessToken, 156 - RefreshJwt: newRefreshToken, 157 - Handle: claims.Handle, 158 - DID: claims.DID, 159 - Active: &active, // Account is active 160 - } 161 - 162 - w.Header().Set("Content-Type", "application/json") 163 - json.NewEncoder(w).Encode(response) 164 - } 165 - 166 - // HandleGetSession handles com.atproto.server.getSession 167 - func (h *XRPCHandler) HandleGetSession(w http.ResponseWriter, r *http.Request) { 168 - // Extract access token from Authorization header 169 - authHeader := r.Header.Get("Authorization") 170 - if authHeader == "" { 171 - http.Error(w, "authorization header required", http.StatusUnauthorized) 172 - return 173 - } 174 - 175 - // Remove "Bearer " prefix 176 - accessToken := strings.TrimPrefix(authHeader, "Bearer ") 177 - if accessToken == authHeader { 178 - http.Error(w, "invalid authorization header format", http.StatusUnauthorized) 179 - return 180 - } 181 - 182 - // Validate access token 183 - claims, err := h.pds.ValidateAccessToken(accessToken) 184 - if err != nil { 185 - http.Error(w, fmt.Sprintf("invalid access token: %v", err), http.StatusUnauthorized) 186 - return 187 - } 188 - 189 - // Return session info 190 - response := SessionInfo{ 191 - Handle: claims.Handle, 192 - DID: claims.DID, 193 - } 194 - 195 - w.Header().Set("Content-Type", "application/json") 196 - json.NewEncoder(w).Encode(response) 197 - } 198 - 199 - // CreateRecordRequest represents a record creation request 200 - type CreateRecordRequest struct { 201 - Repo string `json:"repo"` // DID of the repository 202 - Collection string `json:"collection"` // Collection name (e.g., "app.bsky.feed.post") 203 - Rkey string `json:"rkey,omitempty"` // Optional record key (TID generated if not provided) 204 - Validate *bool `json:"validate,omitempty"` // Optional validation flag 205 - Record interface{} `json:"record"` // The record value (JSON object) 206 - } 207 - 208 - // CreateRecordResponse represents a successful record creation 209 - type CreateRecordResponse struct { 210 - URI string `json:"uri"` // at://did/collection/rkey 211 - CID string `json:"cid"` // Record CID 212 - } 213 - 214 - // RawRecord wraps a record value and implements CBORMarshaler 215 - // This allows us to accept any JSON record and marshal it to CBOR 216 - type RawRecord struct { 217 - Value map[string]interface{} 218 - } 219 - 220 - // MarshalCBOR implements CBORMarshaler for RawRecord 221 - func (r *RawRecord) MarshalCBOR(w io.Writer) error { 222 - // Write CBOR map header 223 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajMap, uint64(len(r.Value))); err != nil { 224 - return err 225 - } 226 - 227 - // Write each key-value pair 228 - for key, val := range r.Value { 229 - // Write key as text string 230 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(key))); err != nil { 231 - return err 232 - } 233 - if _, err := w.Write([]byte(key)); err != nil { 234 - return err 235 - } 236 - 237 - // Write value (simplified - handles common types) 238 - if err := writeValue(w, val); err != nil { 239 - return err 240 - } 241 - } 242 - 243 - return nil 244 - } 245 - 246 - // writeValue writes a value to CBOR (helper for RawRecord) 247 - func writeValue(w io.Writer, val interface{}) error { 248 - switch v := val.(type) { 249 - case string: 250 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(v))); err != nil { 251 - return err 252 - } 253 - _, err := w.Write([]byte(v)) 254 - return err 255 - case int64: 256 - return cbg.CborWriteHeader(w, cbg.MajUnsignedInt, uint64(v)) 257 - case float64: 258 - // Write as unsigned int for now (simplified) 259 - return cbg.CborWriteHeader(w, cbg.MajUnsignedInt, uint64(v)) 260 - case bool: 261 - if v { 262 - return cbg.WriteBool(w, true) 263 - } 264 - return cbg.WriteBool(w, false) 265 - case map[string]interface{}: 266 - rec := &RawRecord{Value: v} 267 - return rec.MarshalCBOR(w) 268 - case []interface{}: 269 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(v))); err != nil { 270 - return err 271 - } 272 - for _, item := range v { 273 - if err := writeValue(w, item); err != nil { 274 - return err 275 - } 276 - } 277 - return nil 278 - default: 279 - // For unknown types, convert to JSON then write as string 280 - jsonBytes, err := json.Marshal(v) 281 - if err != nil { 282 - return err 283 - } 284 - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(jsonBytes))); err != nil { 285 - return err 286 - } 287 - _, err = w.Write(jsonBytes) 288 - return err 289 - } 290 - } 291 - 292 - // HandleCreateRecord handles com.atproto.repo.createRecord 293 - func (h *XRPCHandler) HandleCreateRecord(w http.ResponseWriter, r *http.Request) { 294 - // Validate JWT authentication 295 - user, err := ValidateJWTAuth(r, h.pds) 296 - if err != nil { 297 - http.Error(w, fmt.Sprintf("authentication required: %v", err), http.StatusUnauthorized) 298 - return 299 - } 300 - 301 - // Parse request 302 - var req CreateRecordRequest 303 - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 304 - http.Error(w, fmt.Sprintf("invalid request body: %v", err), http.StatusBadRequest) 305 - return 306 - } 307 - 308 - // Validate required fields 309 - if req.Repo == "" || req.Collection == "" || req.Record == nil { 310 - http.Error(w, "repo, collection, and record are required", http.StatusBadRequest) 311 - return 312 - } 313 - 314 - // Verify repo matches authenticated user 315 - if req.Repo != user.DID { 316 - http.Error(w, "repo must match authenticated user DID", http.StatusForbidden) 317 - return 318 - } 319 - 320 - // Verify repo matches this hold's DID 321 - if req.Repo != h.pds.DID() { 322 - http.Error(w, "invalid repo (must be this hold's DID)", http.StatusBadRequest) 323 - return 324 - } 325 - 326 - // Convert record from JSON to CBOR-marshalable format 327 - recordMap, ok := req.Record.(map[string]interface{}) 328 - if !ok { 329 - http.Error(w, "record must be a JSON object", http.StatusBadRequest) 330 - return 331 - } 332 - 333 - // Wrap in RawRecord which implements CBORMarshaler 334 - recordValue := &RawRecord{Value: recordMap} 335 - 336 - // Create record using repomgr 337 - var recordPath string 338 - var recordCID cid.Cid 339 - 340 - if req.Rkey != "" { 341 - // Use PutRecord if rkey is specified 342 - recordPath, recordCID, err = h.pds.repomgr.PutRecord( 343 - r.Context(), 344 - h.pds.uid, 345 - req.Collection, 346 - req.Rkey, 347 - recordValue, 348 - ) 349 - } else { 350 - // Use CreateRecord if no rkey (auto-generates TID) 351 - recordPath, recordCID, err = h.pds.repomgr.CreateRecord( 352 - r.Context(), 353 - h.pds.uid, 354 - req.Collection, 355 - recordValue, 356 - ) 357 - } 358 - 359 - if err != nil { 360 - http.Error(w, fmt.Sprintf("failed to create record: %v", err), http.StatusInternalServerError) 361 - return 362 - } 363 - 364 - // Extract rkey from path (format: "collection/rkey") 365 - parts := strings.Split(recordPath, "/") 366 - if len(parts) < 2 { 367 - http.Error(w, "invalid record path returned", http.StatusInternalServerError) 368 - return 369 - } 370 - actualRkey := parts[len(parts)-1] 371 - 372 - // Return success response 373 - response := CreateRecordResponse{ 374 - URI: fmt.Sprintf("at://%s/%s/%s", h.pds.DID(), req.Collection, actualRkey), 375 - CID: recordCID.String(), 376 - } 377 - 378 - w.Header().Set("Content-Type", "application/json") 379 - w.WriteHeader(http.StatusCreated) 380 - json.NewEncoder(w).Encode(response) 381 - }
+11 -57
pkg/hold/pds/status.go
··· 6 6 "time" 7 7 8 8 bsky "github.com/bluesky-social/indigo/api/bsky" 9 - "github.com/ipfs/go-cid" 10 9 ) 11 10 12 11 const ( 13 - // StatusPostRkey is the fixed rkey for the status post (singleton) 14 - StatusPostRkey = "status" 15 - 16 12 // StatusPostCollection is the collection name for Bluesky posts 17 13 StatusPostCollection = "app.bsky.feed.post" 18 14 ) 19 15 20 - // SetStatus creates or updates the hold's status post on Bluesky 16 + // SetStatus creates a new status post on Bluesky 21 17 // status should be "online" or "offline" 18 + // Each call creates a unique post with a TID-based rkey 22 19 func (p *HoldPDS) SetStatus(ctx context.Context, status string) error { 23 20 // Format the post text with emoji indicator 24 21 emoji := "🟢" ··· 27 24 } 28 25 text := fmt.Sprintf("%s Current status: %s", emoji, status) 29 26 30 - // Check if status post already exists 31 - _, existingPost, err := p.GetStatusPost(ctx) 32 - if err != nil { 33 - // Post doesn't exist, create it 34 - return p.createStatusPost(ctx, text) 35 - } 36 - 37 - // Post exists, update it 38 - // We need to preserve the original CreatedAt timestamp 39 - return p.updateStatusPost(ctx, text, existingPost.CreatedAt) 40 - } 41 - 42 - // GetStatusPost retrieves the status post if it exists 43 - func (p *HoldPDS) GetStatusPost(ctx context.Context) (cid.Cid, *bsky.FeedPost, error) { 44 - // Use repomgr.GetRecord 45 - recordCID, val, err := p.repomgr.GetRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, cid.Undef) 46 - if err != nil { 47 - return cid.Undef, nil, fmt.Errorf("failed to get status post: %w", err) 48 - } 49 - 50 - // Type assert to bsky.FeedPost 51 - post, ok := val.(*bsky.FeedPost) 52 - if !ok { 53 - return cid.Undef, nil, fmt.Errorf("unexpected type for status post: %T", val) 54 - } 55 - 56 - return recordCID, post, nil 27 + // Create the post with a unique TID 28 + return p.createStatusPost(ctx, text) 57 29 } 58 30 59 - // createStatusPost creates a new status post (first time) 31 + // createStatusPost creates a new status post with a TID-based rkey 60 32 func (p *HoldPDS) createStatusPost(ctx context.Context, text string) error { 61 33 // Create post struct 62 - now := time.Now().Format(time.RFC3339) 34 + now := time.Now() 63 35 post := &bsky.FeedPost{ 64 36 LexiconTypeID: "app.bsky.feed.post", 65 37 Text: text, 66 - CreatedAt: now, 38 + CreatedAt: now.Format(time.RFC3339), 67 39 } 68 40 69 - // Use repomgr.PutRecord - creates with explicit rkey, fails if already exists 70 - recordPath, recordCID, err := p.repomgr.PutRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, post) 41 + // Use repomgr.CreateRecord to create the post with auto-generated TID 42 + // CreateRecord automatically generates a unique TID using the repo's clock 43 + rkey, recordCID, err := p.repomgr.CreateRecord(ctx, p.uid, StatusPostCollection, post) 71 44 if err != nil { 72 45 return fmt.Errorf("failed to create status post: %w", err) 73 46 } 74 47 75 - fmt.Printf("Created status post at %s, cid: %s, text: %s\n", recordPath, recordCID, text) 76 - return nil 77 - } 78 - 79 - // updateStatusPost updates an existing status post 80 - func (p *HoldPDS) updateStatusPost(ctx context.Context, text string, createdAt string) error { 81 - // Create updated post struct with original CreatedAt 82 - post := &bsky.FeedPost{ 83 - LexiconTypeID: "app.bsky.feed.post", 84 - Text: text, 85 - CreatedAt: createdAt, // Preserve original creation time 86 - } 87 - 88 - // Use repomgr.UpdateRecord 89 - recordCID, err := p.repomgr.UpdateRecord(ctx, p.uid, StatusPostCollection, StatusPostRkey, post) 90 - if err != nil { 91 - return fmt.Errorf("failed to update status post: %w", err) 92 - } 93 - 94 - fmt.Printf("Updated status post, cid: %s, text: %s\n", recordCID, text) 48 + fmt.Printf("Created status post at %s/%s (rkey: %s), cid: %s, text: %s\n", StatusPostCollection, rkey, rkey, recordCID, text) 95 49 return nil 96 50 }
+201 -64
pkg/hold/pds/status_test.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 6 + "fmt" 7 + "net/http" 8 + "net/http/httptest" 5 9 "os" 6 10 "path/filepath" 7 11 "testing" 12 + "time" 8 13 14 + "atcr.io/pkg/atproto" 15 + "atcr.io/pkg/s3" 9 16 bsky "github.com/bluesky-social/indigo/api/bsky" 10 17 ) 11 18 19 + // Shared test resources (used across all test files in package) 20 + var ( 21 + sharedTestKeyPath string 22 + sharedTestKey []byte 23 + sharedPDS *HoldPDS // Shared bootstrapped PDS for read-only tests 24 + sharedHandler *XRPCHandler // Shared handler for read-only tests 25 + sharedCtx context.Context // Shared context 26 + ) 27 + 12 28 func TestStatusPost(t *testing.T) { 13 - // Create temporary directory for test database 29 + // Create temporary directory for test 14 30 tmpDir := t.TempDir() 15 - dbPath := filepath.Join(tmpDir, "test.db") 31 + // Use in-memory database for speed 32 + dbPath := ":memory:" 16 33 keyPath := filepath.Join(tmpDir, "test.key") 34 + 35 + // Copy shared signing key 36 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 37 + t.Fatalf("Failed to copy shared signing key: %v", err) 38 + } 17 39 18 40 // Create test PDS 19 41 ctx := context.Background() ··· 31 53 t.Fatalf("Failed to initialize repo: %v", err) 32 54 } 33 55 56 + // Create handler for XRPC endpoints 57 + handler := NewXRPCHandler(holdPDS, s3.S3Service{}, nil, nil, &mockPDSClient{}) 58 + 59 + // Helper function to list posts via XRPC 60 + listPosts := func() ([]map[string]any, error) { 61 + req := makeXRPCGetRequest(atproto.RepoListRecords, map[string]string{ 62 + "repo": did, 63 + "collection": StatusPostCollection, 64 + "limit": "100", 65 + "reverse": "true", // Most recent first 66 + }) 67 + w := httptest.NewRecorder() 68 + handler.HandleListRecords(w, req) 69 + 70 + if w.Code != http.StatusOK { 71 + return nil, fmt.Errorf("unexpected status code: %d, body: %s", w.Code, w.Body.String()) 72 + } 73 + 74 + var result map[string]any 75 + if err := json.NewDecoder(w.Body).Decode(&result); err != nil { 76 + return nil, fmt.Errorf("failed to decode response: %w", err) 77 + } 78 + 79 + records, ok := result["records"].([]any) 80 + if !ok { 81 + return nil, fmt.Errorf("expected records array, got %T", result["records"]) 82 + } 83 + 84 + posts := make([]map[string]any, len(records)) 85 + for i, rec := range records { 86 + post, ok := rec.(map[string]any) 87 + if !ok { 88 + return nil, fmt.Errorf("expected record map, got %T", rec) 89 + } 90 + posts[i] = post 91 + } 92 + return posts, nil 93 + } 94 + 34 95 t.Run("CreateStatusPost", func(t *testing.T) { 35 96 // Set status to online (creates new post) 36 97 err := holdPDS.SetStatus(ctx, "online") ··· 38 99 t.Fatalf("Failed to set status to online: %v", err) 39 100 } 40 101 41 - // Verify post was created 42 - _, post, err := holdPDS.GetStatusPost(ctx) 102 + // List posts 103 + posts, err := listPosts() 43 104 if err != nil { 44 - t.Fatalf("Failed to get status post: %v", err) 105 + t.Fatalf("Failed to list posts: %v", err) 106 + } 107 + 108 + if len(posts) == 0 { 109 + t.Fatal("Expected at least one status post, got 0") 110 + } 111 + 112 + // Get the latest post 113 + post := posts[0] 114 + 115 + value, ok := post["value"].(map[string]any) 116 + if !ok { 117 + t.Fatalf("Expected value map, got %T", post["value"]) 45 118 } 46 119 47 - if post.Text != "🟢 Current status: online" { 48 - t.Errorf("Expected text '🟢 Current status: online', got '%s'", post.Text) 120 + text, ok := value["text"].(string) 121 + if !ok { 122 + t.Fatalf("Expected text string, got %T", value["text"]) 49 123 } 50 124 51 - if post.LexiconTypeID != "app.bsky.feed.post" { 52 - t.Errorf("Expected LexiconTypeID 'app.bsky.feed.post', got '%s'", post.LexiconTypeID) 125 + if text != "🟢 Current status: online" { 126 + t.Errorf("Expected text '🟢 Current status: online', got '%s'", text) 53 127 } 54 128 55 - if post.CreatedAt == "" { 56 - t.Error("CreatedAt should not be empty") 129 + // Verify TID-based rkey (extract from URI) 130 + uri, ok := post["uri"].(string) 131 + if !ok { 132 + t.Fatalf("Expected uri string, got %T", post["uri"]) 133 + } 134 + // URI format: at://did:web:test.example.com/app.bsky.feed.post/3m3c4... 135 + // We just check that it contains the collection 136 + if !contains(uri, StatusPostCollection) { 137 + t.Errorf("Expected URI to contain collection %s, got %s", StatusPostCollection, uri) 57 138 } 58 139 }) 59 140 60 - t.Run("UpdateStatusPost", func(t *testing.T) { 61 - // Get the original post to check CreatedAt preservation 62 - _, originalPost, err := holdPDS.GetStatusPost(ctx) 141 + t.Run("CreateMultiplePosts", func(t *testing.T) { 142 + // Create multiple status posts 143 + err := holdPDS.SetStatus(ctx, "offline") 63 144 if err != nil { 64 - t.Fatalf("Failed to get original status post: %v", err) 145 + t.Fatalf("Failed to set status to offline: %v", err) 65 146 } 66 147 67 - // Set status to offline (updates existing post) 68 - err = holdPDS.SetStatus(ctx, "offline") 148 + // Wait a moment to ensure different timestamp 149 + time.Sleep(10 * time.Millisecond) 150 + 151 + err = holdPDS.SetStatus(ctx, "online") 69 152 if err != nil { 70 - t.Fatalf("Failed to set status to offline: %v", err) 153 + t.Fatalf("Failed to set status to online again: %v", err) 71 154 } 72 155 73 - // Verify post was updated 74 - _, post, err := holdPDS.GetStatusPost(ctx) 156 + // List all posts - should have at least 3 now (1 from previous test + 2 from this test) 157 + posts, err := listPosts() 75 158 if err != nil { 76 - t.Fatalf("Failed to get updated status post: %v", err) 159 + t.Fatalf("Failed to list posts: %v", err) 77 160 } 78 161 79 - if post.Text != "🔴 Current status: offline" { 80 - t.Errorf("Expected text '🔴 Current status: offline', got '%s'", post.Text) 162 + if len(posts) < 3 { 163 + t.Errorf("Expected at least 3 status posts, got %d", len(posts)) 164 + } 165 + 166 + // Verify each post has a unique URI 167 + uris := make(map[string]bool) 168 + for _, post := range posts { 169 + uri, ok := post["uri"].(string) 170 + if !ok { 171 + t.Errorf("Expected uri string, got %T", post["uri"]) 172 + continue 173 + } 174 + if uris[uri] { 175 + t.Errorf("Duplicate URI found: %s", uri) 176 + } 177 + uris[uri] = true 81 178 } 82 179 83 - // Verify CreatedAt was preserved 84 - if post.CreatedAt != originalPost.CreatedAt { 85 - t.Errorf("CreatedAt should be preserved. Expected '%s', got '%s'", originalPost.CreatedAt, post.CreatedAt) 180 + // Verify the latest post is online 181 + latestPost := posts[0] 182 + value, ok := latestPost["value"].(map[string]any) 183 + if !ok { 184 + t.Fatalf("Expected value map, got %T", latestPost["value"]) 185 + } 186 + text, ok := value["text"].(string) 187 + if !ok { 188 + t.Fatalf("Expected text string, got %T", value["text"]) 189 + } 190 + if text != "🟢 Current status: online" { 191 + t.Errorf("Expected latest post text '🟢 Current status: online', got '%s'", text) 86 192 } 87 193 }) 88 194 89 - t.Run("ToggleStatus", func(t *testing.T) { 90 - // Toggle back to online 91 - err := holdPDS.SetStatus(ctx, "online") 195 + t.Run("OfflineStatus", func(t *testing.T) { 196 + // Create offline status post 197 + err := holdPDS.SetStatus(ctx, "offline") 92 198 if err != nil { 93 - t.Fatalf("Failed to set status to online: %v", err) 199 + t.Fatalf("Failed to set status to offline: %v", err) 94 200 } 95 201 96 - _, post, err := holdPDS.GetStatusPost(ctx) 202 + // Get the latest post 203 + posts, err := listPosts() 97 204 if err != nil { 98 - t.Fatalf("Failed to get status post: %v", err) 205 + t.Fatalf("Failed to list posts: %v", err) 206 + } 207 + 208 + if len(posts) == 0 { 209 + t.Fatal("Expected at least one status post, got 0") 210 + } 211 + 212 + latestPost := posts[0] 213 + value, ok := latestPost["value"].(map[string]any) 214 + if !ok { 215 + t.Fatalf("Expected value map, got %T", latestPost["value"]) 216 + } 217 + text, ok := value["text"].(string) 218 + if !ok { 219 + t.Fatalf("Expected text string, got %T", value["text"]) 99 220 } 100 221 101 - if post.Text != "🟢 Current status: online" { 102 - t.Errorf("Expected text '🟢 Current status: online', got '%s'", post.Text) 222 + if text != "🔴 Current status: offline" { 223 + t.Errorf("Expected text '🔴 Current status: offline', got '%s'", text) 103 224 } 104 225 }) 105 226 } 106 227 107 228 func TestStatusPostCollection(t *testing.T) { 108 - // Verify constants 229 + // Verify constant 109 230 if StatusPostCollection != "app.bsky.feed.post" { 110 231 t.Errorf("Expected StatusPostCollection 'app.bsky.feed.post', got '%s'", StatusPostCollection) 111 232 } 112 - 113 - if StatusPostRkey != "status" { 114 - t.Errorf("Expected StatusPostRkey 'status', got '%s'", StatusPostRkey) 115 - } 116 233 } 117 234 118 - func TestGetStatusPostNotExists(t *testing.T) { 119 - // Create temporary directory for test database 120 - tmpDir := t.TempDir() 121 - dbPath := filepath.Join(tmpDir, "test.db") 122 - keyPath := filepath.Join(tmpDir, "test.key") 235 + // Helper function to check if a string contains a substring 236 + func contains(s, substr string) bool { 237 + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr)) 238 + } 123 239 124 - // Create test PDS 125 - ctx := context.Background() 126 - did := "did:web:test2.example.com" 127 - publicURL := "https://test2.example.com" 128 - 129 - holdPDS, err := NewHoldPDS(ctx, did, publicURL, dbPath, keyPath) 130 - if err != nil { 131 - t.Fatalf("Failed to create test PDS: %v", err) 240 + func findSubstring(s, substr string) bool { 241 + for i := 0; i <= len(s)-len(substr); i++ { 242 + if s[i:i+len(substr)] == substr { 243 + return true 244 + } 132 245 } 133 - 134 - // Initialize empty repo 135 - err = holdPDS.repomgr.InitNewActor(ctx, holdPDS.uid, "", did, "", "", "") 136 - if err != nil { 137 - t.Fatalf("Failed to initialize repo: %v", err) 138 - } 139 - 140 - // Try to get status post that doesn't exist 141 - _, _, err = holdPDS.GetStatusPost(ctx) 142 - if err == nil { 143 - t.Error("Expected error when getting non-existent status post, got nil") 144 - } 246 + return false 145 247 } 146 248 147 249 func init() { ··· 154 256 155 257 // Cleanup function to remove test files 156 258 func TestMain(m *testing.M) { 259 + // Create a temporary directory for shared test key 260 + tmpDir, err := os.MkdirTemp("", "pds-test-shared-*") 261 + if err != nil { 262 + panic(fmt.Sprintf("Failed to create temp dir: %v", err)) 263 + } 264 + defer os.RemoveAll(tmpDir) 265 + 266 + // Generate one signing key to be reused across all tests in the package 267 + sharedTestKeyPath = filepath.Join(tmpDir, "shared-signing-key") 268 + privateKey, err := GenerateOrLoadKey(sharedTestKeyPath) 269 + if err != nil { 270 + panic(fmt.Sprintf("Failed to generate shared signing key: %v", err)) 271 + } 272 + 273 + // Store the key bytes so tests can copy them 274 + sharedTestKey = privateKey.Bytes() 275 + 276 + // Create one shared, bootstrapped PDS for read-only tests 277 + // Use in-memory database for speed 278 + sharedCtx = context.Background() 279 + sharedPDS, err = NewHoldPDS(sharedCtx, "did:web:hold.example.com", "https://hold.example.com", ":memory:", sharedTestKeyPath) 280 + if err != nil { 281 + panic(fmt.Sprintf("Failed to create shared PDS: %v", err)) 282 + } 283 + 284 + // Bootstrap once 285 + ownerDID := "did:plc:testowner123" 286 + err = sharedPDS.Bootstrap(sharedCtx, nil, ownerDID, true, false, "") 287 + if err != nil { 288 + panic(fmt.Sprintf("Failed to bootstrap shared PDS: %v", err)) 289 + } 290 + 291 + // Create shared handler 292 + sharedHandler = NewXRPCHandler(sharedPDS, s3.S3Service{}, nil, nil, &mockPDSClient{}) 293 + 157 294 // Run tests 158 295 code := m.Run() 159 296
+6 -29
pkg/hold/pds/xrpc.go
··· 92 92 93 93 // Handle OPTIONS preflight 94 94 if r.Method == "OPTIONS" { 95 - w.WriteHeader(http.StatusNoContent) 95 + w.WriteHeader(http.StatusOK) 96 96 return 97 97 } 98 98 ··· 150 150 r.Get("/xrpc/_health", h.HandleHealth) 151 151 r.Get(atproto.ServerDescribeServer, h.HandleDescribeServer) 152 152 153 - // Session management (public - creates sessions) 154 - r.Post(atproto.ServerCreateSession, h.HandleCreateSession) 155 - r.Post(atproto.ServerRefreshSession, h.HandleRefreshSession) 156 - r.Get(atproto.ServerGetSession, h.HandleGetSession) 157 - 158 153 // Repository metadata 159 154 r.Get(atproto.RepoDescribeRepo, h.HandleDescribeRepo) 160 155 r.Get(atproto.RepoGetRecord, h.HandleGetRecord) ··· 186 181 // Write endpoints (owner/crew admin auth) 187 182 r.Group(func(r chi.Router) { 188 183 r.Use(h.requireOwnerOrCrewAdmin) 189 - 190 184 r.Post(atproto.RepoDeleteRecord, h.HandleDeleteRecord) 191 185 r.Post(atproto.RepoUploadBlob, h.HandleUploadBlob) 192 186 }) ··· 194 188 // Auth-only endpoints (DPoP auth) 195 189 r.Group(func(r chi.Router) { 196 190 r.Use(h.requireAuth) 197 - 198 191 r.Post(atproto.HoldRequestCrew, h.HandleRequestCrew) 199 - }) 200 - 201 - // JWT-authenticated endpoints (JWT auth from createSession) 202 - // Note: JWT auth is validated inside each handler 203 - r.Group(func(r chi.Router) { 204 - r.Post("/xrpc/com.atproto.repo.createRecord", h.HandleCreateRecord) 205 192 }) 206 193 } 207 194 ··· 861 848 } 862 849 863 850 // Get optional cursor parameter for backfill 864 - var cursor int64 = 0 851 + // Default to -1 (no backfill, only stream new events) 852 + // cursor=0 means "replay all events from the beginning" 853 + var cursor int64 = -1 865 854 if cursorStr := r.URL.Query().Get("cursor"); cursorStr != "" { 866 855 var err error 867 856 cursor, err = strconv.ParseInt(cursorStr, 10, 64) ··· 879 868 } 880 869 881 870 // Subscribe to events 882 - sub := h.broadcaster.Subscribe(conn, cursor) 883 - 884 871 // The broadcaster's handleSubscriber goroutine will manage this connection 885 - // We just need to keep reading to detect client disconnects 886 - go func() { 887 - defer h.broadcaster.Unsubscribe(sub) 888 - for { 889 - // Read messages from client (mostly just to detect disconnect) 890 - _, _, err := conn.ReadMessage() 891 - if err != nil { 892 - // Client disconnected 893 - break 894 - } 895 - } 896 - }() 872 + // and handle cleanup when the client disconnects 873 + h.broadcaster.Subscribe(conn, cursor) 897 874 } 898 875 899 876 // HandleUploadBlob handles blob uploads with support for multipart operations
+257 -3
pkg/hold/pds/xrpc_test.go
··· 12 12 "path/filepath" 13 13 "strings" 14 14 "testing" 15 + "time" 15 16 16 17 "atcr.io/pkg/atproto" 17 18 "atcr.io/pkg/s3" 19 + indigoAtproto "github.com/bluesky-social/indigo/api/atproto" 20 + "github.com/bluesky-social/indigo/events" 18 21 "github.com/distribution/distribution/v3/registry/storage/driver/factory" 19 22 _ "github.com/distribution/distribution/v3/registry/storage/driver/filesystem" 20 23 "github.com/go-chi/chi/v5" 24 + "github.com/gorilla/websocket" 25 + "github.com/ipfs/go-cid" 21 26 ) 22 27 23 28 // Test helpers ··· 30 35 ctx := context.Background() 31 36 tmpDir := t.TempDir() 32 37 33 - dbPath := filepath.Join(tmpDir, "pds.db") 38 + // Use in-memory database for speed 39 + dbPath := ":memory:" 34 40 keyPath := filepath.Join(tmpDir, "signing-key") 41 + 42 + // Copy shared signing key instead of generating a new one 43 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 44 + t.Fatalf("Failed to copy shared signing key: %v", err) 45 + } 35 46 36 47 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath) 37 48 if err != nil { ··· 115 126 } 116 127 117 128 return result 129 + } 130 + 131 + // decodeFirehoseMessage decodes an ATProto firehose message (header + CBOR body) 132 + func decodeFirehoseMessage(t *testing.T, message []byte) (*events.EventHeader, *indigoAtproto.SyncSubscribeRepos_Commit) { 133 + t.Helper() 134 + 135 + reader := bytes.NewReader(message) 136 + 137 + // Decode header 138 + var header events.EventHeader 139 + if err := header.UnmarshalCBOR(reader); err != nil { 140 + t.Fatalf("Failed to decode event header: %v", err) 141 + } 142 + 143 + // Verify it's a commit event 144 + if header.MsgType != "#commit" { 145 + t.Fatalf("Expected #commit event, got %s", header.MsgType) 146 + } 147 + 148 + // Decode commit event 149 + var commit indigoAtproto.SyncSubscribeRepos_Commit 150 + if err := commit.UnmarshalCBOR(reader); err != nil { 151 + t.Fatalf("Failed to decode commit event: %v", err) 152 + } 153 + 154 + return &header, &commit 118 155 } 119 156 120 157 // assertCARResponse validates CAR file response ··· 1331 1368 ctx := context.Background() 1332 1369 tmpDir := t.TempDir() 1333 1370 1334 - dbPath := filepath.Join(tmpDir, "pds.db") 1371 + // Use in-memory database for speed 1372 + dbPath := ":memory:" 1335 1373 keyPath := filepath.Join(tmpDir, "signing-key") 1374 + 1375 + // Copy shared signing key instead of generating a new one 1376 + if err := os.WriteFile(keyPath, sharedTestKey, 0600); err != nil { 1377 + t.Fatalf("Failed to copy shared signing key: %v", err) 1378 + } 1336 1379 1337 1380 pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", "https://hold.example.com", dbPath, keyPath) 1338 1381 if err != nil { ··· 1686 1729 w := httptest.NewRecorder() 1687 1730 1688 1731 // Wrap with CORS middleware (chi-style) 1689 - corsHandler := handler.corsMiddleware(http.HandlerFunc(handler.HandleGetBlob)) 1732 + corsHandler := handler.CORSMiddleware()(http.HandlerFunc(handler.HandleGetBlob)) 1690 1733 corsHandler.ServeHTTP(w, req) 1691 1734 1692 1735 // Verify CORS headers are present ··· 1719 1762 1720 1763 // Create chi router and register handlers 1721 1764 r := chi.NewRouter() 1765 + r.Use(handler.CORSMiddleware()) // Apply CORS middleware 1722 1766 handler.RegisterHandlers(r) 1723 1767 1724 1768 tests := []struct { ··· 2009 2053 }) 2010 2054 } 2011 2055 } 2056 + 2057 + // TestHandleSubscribeRepos tests the WebSocket firehose endpoint 2058 + func TestHandleSubscribeRepos(t *testing.T) { 2059 + handler, ctx := setupTestXRPCHandler(t) 2060 + 2061 + // Create EventBroadcaster 2062 + broadcaster := NewEventBroadcaster(handler.pds.DID(), 100) 2063 + handler.broadcaster = broadcaster 2064 + 2065 + // Set up test HTTP server 2066 + r := chi.NewRouter() 2067 + handler.RegisterHandlers(r) 2068 + server := httptest.NewServer(r) 2069 + defer server.Close() 2070 + 2071 + // Broadcast some events before connecting 2072 + testCID, _ := cid.Decode("bafyreib2rxk3rkhh5ylyxj3x3gathxt3s32qvwj2lf3qg4kmzr6b7teqke") 2073 + for i := 1; i <= 3; i++ { 2074 + event := &RepoEvent{ 2075 + NewRoot: testCID, 2076 + Rev: fmt.Sprintf("rev-%d", i), 2077 + RepoSlice: []byte(fmt.Sprintf("CAR data %d", i)), 2078 + Ops: []RepoOp{}, 2079 + } 2080 + broadcaster.Broadcast(ctx, event) 2081 + } 2082 + 2083 + // Verify events were stored 2084 + if broadcaster.eventSeq != 3 { 2085 + t.Fatalf("Expected eventSeq=3, got %d", broadcaster.eventSeq) 2086 + } 2087 + 2088 + t.Run("cursor=0 replays all events", func(t *testing.T) { 2089 + // Connect to WebSocket with cursor=0 2090 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=0" 2091 + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) 2092 + if err != nil { 2093 + t.Fatalf("Failed to connect to WebSocket: %v", err) 2094 + } 2095 + defer conn.Close() 2096 + 2097 + // Should receive the 3 historical events 2098 + for i := 0; i < 3; i++ { 2099 + messageType, message, err := conn.ReadMessage() 2100 + if err != nil { 2101 + t.Fatalf("Failed to read message: %v", err) 2102 + } 2103 + 2104 + if messageType != websocket.BinaryMessage { 2105 + t.Errorf("Expected binary message, got type %d", messageType) 2106 + } 2107 + 2108 + // Decode CBOR message (header + commit) 2109 + header, commit := decodeFirehoseMessage(t, message) 2110 + 2111 + // Verify header 2112 + if header.MsgType != "#commit" { 2113 + t.Errorf("Expected MsgType=#commit, got %s", header.MsgType) 2114 + } 2115 + 2116 + // Verify commit fields 2117 + expectedSeq := int64(i + 1) 2118 + if commit.Seq != expectedSeq { 2119 + t.Errorf("Expected seq=%d, got %d", expectedSeq, commit.Seq) 2120 + } 2121 + if commit.Repo != handler.pds.DID() { 2122 + t.Errorf("Expected repo=%s, got %s", handler.pds.DID(), commit.Repo) 2123 + } 2124 + } 2125 + }) 2126 + 2127 + t.Run("cursor=2 only replays events after 2", func(t *testing.T) { 2128 + // Connect with cursor=2 2129 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=2" 2130 + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) 2131 + if err != nil { 2132 + t.Fatalf("Failed to connect to WebSocket: %v", err) 2133 + } 2134 + defer conn.Close() 2135 + 2136 + // Should only receive event 3 (after cursor=2) 2137 + messageType, message, err := conn.ReadMessage() 2138 + if err != nil { 2139 + t.Fatalf("Failed to read message: %v", err) 2140 + } 2141 + 2142 + if messageType != websocket.BinaryMessage { 2143 + t.Errorf("Expected binary message, got type %d", messageType) 2144 + } 2145 + 2146 + header, commit := decodeFirehoseMessage(t, message) 2147 + if header.MsgType != "#commit" { 2148 + t.Errorf("Expected MsgType=#commit, got %s", header.MsgType) 2149 + } 2150 + 2151 + if commit.Seq != 3 { 2152 + t.Errorf("Expected seq=3, got %d", commit.Seq) 2153 + } 2154 + 2155 + // Verify no more events (use timeout) 2156 + conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 2157 + _, _, err = conn.ReadMessage() 2158 + if err == nil { 2159 + t.Error("Expected no more events, but received another message") 2160 + } 2161 + }) 2162 + 2163 + t.Run("no cursor streams only new events", func(t *testing.T) { 2164 + // Connect without cursor (should not get backfill) 2165 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos" 2166 + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) 2167 + if err != nil { 2168 + t.Fatalf("Failed to connect to WebSocket: %v", err) 2169 + } 2170 + defer conn.Close() 2171 + 2172 + // Verify no historical events by broadcasting immediately and checking 2173 + // that we only receive the new event (not historical ones) 2174 + // Give subscriber time to register first 2175 + time.Sleep(100 * time.Millisecond) 2176 + 2177 + // Broadcast a new event (seq 4) 2178 + newEvent := &RepoEvent{ 2179 + NewRoot: testCID, 2180 + Rev: "rev-4", 2181 + RepoSlice: []byte("CAR data 4"), 2182 + Ops: []RepoOp{}, 2183 + } 2184 + broadcaster.Broadcast(ctx, newEvent) 2185 + 2186 + // Should receive ONLY the new event (seq 4), not historical events 1-3 2187 + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) 2188 + messageType, message, err := conn.ReadMessage() 2189 + if err != nil { 2190 + t.Fatalf("Failed to read new event: %v", err) 2191 + } 2192 + 2193 + if messageType != websocket.BinaryMessage { 2194 + t.Errorf("Expected binary message, got type %d", messageType) 2195 + } 2196 + 2197 + header, commit := decodeFirehoseMessage(t, message) 2198 + if header.MsgType != "#commit" { 2199 + t.Errorf("Expected MsgType=#commit, got %s", header.MsgType) 2200 + } 2201 + 2202 + // Key assertion: should be seq 4 (new event), not seq 1 (historical backfill) 2203 + if commit.Seq != 4 { 2204 + t.Errorf("Expected seq=4 for new event (no backfill), got %d", commit.Seq) 2205 + } 2206 + 2207 + // Verify no more messages (no historical backfill) 2208 + conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 2209 + _, _, err = conn.ReadMessage() 2210 + if err == nil { 2211 + t.Error("Expected no more events, but received another message (possible backfill leak)") 2212 + } 2213 + }) 2214 + 2215 + t.Run("real-time event delivery", func(t *testing.T) { 2216 + // Connect with cursor=0 to get all events first 2217 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/xrpc/com.atproto.sync.subscribeRepos?cursor=0" 2218 + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) 2219 + if err != nil { 2220 + t.Fatalf("Failed to connect to WebSocket: %v", err) 2221 + } 2222 + defer conn.Close() 2223 + 2224 + // Read and discard the 4 historical events (seq 1-4) 2225 + for i := 0; i < 4; i++ { 2226 + _, _, err := conn.ReadMessage() 2227 + if err != nil { 2228 + t.Fatalf("Failed to read historical event %d: %v", i+1, err) 2229 + } 2230 + } 2231 + 2232 + // Broadcast 2 new events 2233 + for i := 5; i <= 6; i++ { 2234 + newEvent := &RepoEvent{ 2235 + NewRoot: testCID, 2236 + Rev: fmt.Sprintf("rev-%d", i), 2237 + RepoSlice: []byte(fmt.Sprintf("CAR data %d", i)), 2238 + Ops: []RepoOp{}, 2239 + } 2240 + broadcaster.Broadcast(ctx, newEvent) 2241 + } 2242 + 2243 + // Should receive both new events 2244 + for expectedSeq := 5; expectedSeq <= 6; expectedSeq++ { 2245 + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) 2246 + messageType, message, err := conn.ReadMessage() 2247 + if err != nil { 2248 + t.Fatalf("Failed to read event seq=%d: %v", expectedSeq, err) 2249 + } 2250 + 2251 + if messageType != websocket.BinaryMessage { 2252 + t.Errorf("Expected binary message, got type %d", messageType) 2253 + } 2254 + 2255 + header, commit := decodeFirehoseMessage(t, message) 2256 + if header.MsgType != "#commit" { 2257 + t.Errorf("Expected MsgType=#commit, got %s", header.MsgType) 2258 + } 2259 + 2260 + if commit.Seq != int64(expectedSeq) { 2261 + t.Errorf("Expected seq=%d, got %d", expectedSeq, commit.Seq) 2262 + } 2263 + } 2264 + }) 2265 + }