Like malachite (atproto-lastfm-importer), but in Go and bluer
go spotify tealfm lastfm atproto
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

cleanup more code

karitham 9cc177a3 abadd114

+148 -744
-4
atproto/client.go
··· 126 126 return result.DID, result.PDS, result.SigningKey, nil 127 127 } 128 128 129 - func hasJSONContent(header http.Header) bool { 130 - return len(header.Get("Content-Type")) > 0 && header.Get("Content-Type")[0:19] == "application/json" 131 - } 132 - 133 129 func ResolveIdentity(ctx context.Context, handle string, opts *ClientOptions) (resolvedIdentity, error) { 134 130 if handle == "" { 135 131 return resolvedIdentity{}, fmt.Errorf("handle cannot be empty")
+4 -4
atproto/repo_test.go
··· 897 897 } 898 898 } 899 899 900 - func TestRepoClient_Interfaces(t *testing.T) { 901 - var _ RepoClient[map[string]any] = (*RateClient[map[string]any])(nil) 902 - var _ RepoClient[map[string]any] = (*RepoClientFuncs[map[string]any])(nil) 903 - } 900 + var ( 901 + _ RepoClient[map[string]any] = (*RateClient[map[string]any])(nil) 902 + _ RepoClient[map[string]any] = (*RepoClientFuncs[map[string]any])(nil) 903 + ) 904 904 905 905 func TestRateClient_ListRecords_WithTypedRecords(t *testing.T) { 906 906 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+71 -32
main.go
··· 1 1 package main 2 2 3 3 import ( 4 + "archive/zip" 4 5 "context" 5 6 "encoding/json" 6 7 "fmt" 7 8 "io" 9 + "io/fs" 8 10 "log/slog" 9 11 "os" 10 12 "slices" ··· 475 477 reverse := cmd.Bool("reverse") 476 478 tolerance := cmd.Duration("tolerance") 477 479 478 - lastfmRecords, err := sync.ParseInput(ctx, lastfmPath, lastfm.Parser{}) 479 - if err != nil { 480 - return fmt.Errorf("parse lastfm: %w", err) 481 - } 482 - a.log.Info("Loaded Last.fm records", slog.Int("count", len(lastfmRecords))) 483 - 484 - spotifyRecords, err := sync.ParseInput(ctx, spotifyPath, spotify.Parser{}) 480 + records, _, err := loadRecordsMerge(ctx, lastfmPath, spotifyPath, tolerance) 485 481 if err != nil { 486 - return fmt.Errorf("parse spotify: %w", err) 482 + return fmt.Errorf("failed to deduplicate records: %w", err) 487 483 } 488 - a.log.Info("Loaded Spotify records", slog.Int("count", len(spotifyRecords))) 489 - 490 - mergedRecords := kway.Merge([][]sync.PlayRecord{lastfmRecords, spotifyRecords}, tolerance) 491 - 492 - a.log.Info( 493 - "Merged records", 494 - slog.Int("merged_total", len(mergedRecords)), 495 - slog.Int("duplicates_removed", len(lastfmRecords)+len(spotifyRecords)-len(mergedRecords)), 496 - ) 497 484 498 485 if reverse { 499 - slices.Reverse(mergedRecords) 486 + slices.Reverse(records) 500 487 } 501 488 502 - return a.outputRecords(mergedRecords, outputPath) 489 + return a.outputRecords(records, outputPath) 503 490 } 504 491 505 492 func (a *App) runImport(ctx context.Context, cmd *cli.Command) error { ··· 516 503 reverse := cmd.Bool("reverse") 517 504 fresh := cmd.Bool("fresh") 518 505 clearCache := cmd.Bool("clear-cache") 519 - batchSize := int(cmd.Int("batch-size")) 506 + batchSize := cmd.Int("batch-size") 520 507 tolerance := cmd.Duration("tolerance") 521 508 522 509 if clearCache { ··· 527 514 } 528 515 } 529 516 530 - records, totalCount, err := sync.LoadRecordsForImport(ctx, sync.ImportOptions{ 531 - LastFMPath: lastfmPath, 532 - SpotifyPath: spotifyPath, 533 
- Tolerance: tolerance, 534 - LastFMParser: lastfm.Parser{}, 535 - SpotifyParser: spotify.Parser{}, 536 - }) 517 + records, totalCount, err := loadRecordsMerge(ctx, lastfmPath, spotifyPath, tolerance) 537 518 if err != nil { 538 519 return fmt.Errorf("load records: %w", err) 539 520 } ··· 588 569 } 589 570 } 590 571 591 - cfg := sync.DefaultConfig 592 - cfg.BatchSize = batchSize 593 - 594 572 progressLog := a.createProgressLogger() 595 573 596 574 publishOpts := sync.PublishOptions{ 597 - BatchSize: cfg.BatchSize, 575 + BatchSize: batchSize, 598 576 DryRun: dryRun, 599 577 ATProtoClient: repoClient, 600 578 ProgressLog: progressLog, ··· 626 604 } 627 605 } 628 606 629 - if result.Errored() { 607 + if result.ErrorCount > 0 { 630 608 return fmt.Errorf("import completed with %d errors", result.ErrorCount) 631 609 } 632 610 ··· 948 926 Sources: cli.EnvVars(EnvYes), 949 927 }, 950 928 } 929 + 930 + type Parser interface { 931 + ParseFile(ctx context.Context, r io.Reader) ([]sync.PlayRecord, error) 932 + ParseFS(ctx context.Context, fsys fs.FS) ([]sync.PlayRecord, error) 933 + } 934 + 935 + func parseInput(ctx context.Context, path string, parser Parser) ([]sync.PlayRecord, error) { 936 + info, err := os.Stat(path) 937 + if err != nil { 938 + return nil, fmt.Errorf("stat path: %w", err) 939 + } 940 + 941 + if info.IsDir() { 942 + return parser.ParseFS(ctx, os.DirFS(path)) 943 + } 944 + 945 + if strings.HasSuffix(path, ".zip") { 946 + zf, err := zip.OpenReader(path) 947 + if err != nil { 948 + return nil, fmt.Errorf("open zip: %w", err) 949 + } 950 + 951 + defer zf.Close() 952 + 953 + return parser.ParseFS(ctx, zf) 954 + } 955 + 956 + file, err := os.Open(path) 957 + if err != nil { 958 + return nil, fmt.Errorf("open file: %w", err) 959 + } 960 + 961 + defer file.Close() 962 + 963 + return parser.ParseFile(ctx, file) 964 + } 965 + 966 + func loadRecordsMerge(ctx context.Context, lastFMPath, spotifyPath string, tolerance time.Duration) ([]sync.PlayRecord, int, error) { 967 
+ var lastfmRecords, spotifyRecords []sync.PlayRecord 968 + var err error 969 + 970 + if lastFMPath != "" { 971 + lastfmRecords, err = parseInput(ctx, lastFMPath, lastfm.Parser{}) 972 + if err != nil { 973 + return nil, 0, fmt.Errorf("parse lastfm: %w", err) 974 + } 975 + } 976 + 977 + if spotifyPath != "" { 978 + spotifyRecords, err = parseInput(ctx, spotifyPath, spotify.Parser{}) 979 + if err != nil { 980 + return nil, 0, fmt.Errorf("parse spotify: %w", err) 981 + } 982 + } 983 + 984 + totalInput := len(lastfmRecords) + len(spotifyRecords) 985 + 986 + mergedRecords := kway.Merge([][]sync.PlayRecord{lastfmRecords, spotifyRecords}, tolerance) 987 + 988 + return mergedRecords, totalInput, nil 989 + }
+26 -26
sync/batch_test.go
··· 10 10 "sync" 11 11 "sync/atomic" 12 12 "testing" 13 + "testing/synctest" 13 14 "time" 14 15 15 16 "github.com/bluesky-social/indigo/atproto/atclient" 17 + 16 18 "tangled.org/karitham.dev/lazuli/atproto" 17 19 "tangled.org/karitham.dev/lazuli/cache" 18 20 ) ··· 359 361 rec2, _ := json.Marshal(PlayRecord{TrackName: "Song 2"}) 360 362 361 363 t.Run("Retry on transient error", func(t *testing.T) { 362 - storage := newMockStorage() 363 - storage.SaveRecords(did, map[string][]byte{"k1": rec1, "k2": rec2}) 364 + synctest.Test(t, func(t *testing.T) { 365 + storage := newMockStorage() 366 + storage.SaveRecords(did, map[string][]byte{"k1": rec1, "k2": rec2}) 364 367 365 - var attempts int32 366 - client := &mockATProtoClient{ 367 - applyWritesFunc: func(ctx context.Context, collection string, records []PlayRecord) error { 368 - if atomic.AddInt32(&attempts, 1) <= 2 { 369 - return &atclient.APIError{StatusCode: 503} 370 - } 371 - return nil 372 - }, 373 - } 368 + var attempts int32 369 + client := &mockATProtoClient{ 370 + applyWritesFunc: func(ctx context.Context, collection string, records []PlayRecord) error { 371 + if atomic.AddInt32(&attempts, 1) <= 2 { 372 + return &atclient.APIError{StatusCode: 503} 373 + } 374 + return nil 375 + }, 376 + } 374 377 375 - oldBase := BaseRetryDelay 376 - BaseRetryDelay = time.Millisecond 377 - defer func() { BaseRetryDelay = oldBase }() 378 + res := Publish(ctx, &mockAuthClient{did: did}, PublishOptions{ 379 + BatchSize: 1, 380 + ATProtoClient: client, 381 + Storage: storage, 382 + ClientAgent: clientAgent, 383 + }) 378 384 379 - res := Publish(ctx, &mockAuthClient{did: did}, PublishOptions{ 380 - BatchSize: 1, 381 - ATProtoClient: client, 382 - Storage: storage, 383 - ClientAgent: clientAgent, 385 + if res.SuccessCount != 2 { 386 + t.Errorf("expected 2 successes, got %d", res.SuccessCount) 387 + } 388 + if atomic.LoadInt32(&attempts) < 3 { 389 + t.Errorf("expected at least 3 attempts (2 fails + 1 success), got %d", attempts) 390 + } 
384 391 }) 385 - 386 - if res.SuccessCount != 2 { 387 - t.Errorf("expected 2 successes, got %d", res.SuccessCount) 388 - } 389 - if atomic.LoadInt32(&attempts) < 3 { 390 - t.Errorf("expected at least 3 attempts (2 fails + 1 success), got %d", attempts) 391 - } 392 392 }) 393 393 394 394 t.Run("Fail fast on non-transient error", func(t *testing.T) {
-55
sync/config.go
··· 1 - package sync 2 - 3 - import ( 4 - "time" 5 - ) 6 - 7 - const ( 8 - RecordType = "fm.teal.alpha.feed.play" 9 - DefaultBatchSize = 20 10 - DefaultCrossSourceTolerance = 5 * time.Minute 11 - CrossSourceTolerance = DefaultCrossSourceTolerance 12 - CacheTTL = 24 * time.Hour 13 - CacheVersion = 1 14 - SlingshotResolverURL = "https://slingshot.microcosm.blue/xrpc/com.bad-example.identity.resolveMiniDoc" 15 - MaxRetryDelay = 15 * time.Minute 16 - MaxRetries = 1000 17 - ) 18 - 19 - var BaseRetryDelay = 2 * time.Second 20 - 21 - type Config struct { 22 - RecordType string `json:"recordType"` 23 - ClientAgent string `json:"clientAgent"` 24 - BatchSize int `json:"batchSize"` 25 - CrossSourceTolerance time.Duration `json:"crossSourceTolerance"` 26 - CacheTTL time.Duration `json:"cacheTTL"` 27 - CacheVersion int `json:"cacheVersion"` 28 - SlingshotResolverURL string `json:"slingshotResolverURL"` 29 - UserAgent string `json:"userAgent"` 30 - } 31 - 32 - var DefaultConfig = Config{ 33 - RecordType: RecordType, 34 - ClientAgent: DefaultClientAgent, 35 - BatchSize: DefaultBatchSize, 36 - CrossSourceTolerance: CrossSourceTolerance, 37 - CacheTTL: CacheTTL, 38 - CacheVersion: CacheVersion, 39 - SlingshotResolverURL: SlingshotResolverURL, 40 - } 41 - 42 - type PublishResult struct { 43 - SuccessCount int `json:"successCount"` 44 - ErrorCount int `json:"errorCount"` 45 - Cancelled bool `json:"cancelled"` 46 - Duration time.Duration `json:"duration"` 47 - TotalRecords int `json:"totalRecords"` 48 - RecordsPerMinute float64 `json:"recordsPerMinute"` 49 - FirstRecordTime time.Time `json:"firstRecordTime"` 50 - LastRecordTime time.Time `json:"lastRecordTime"` 51 - } 52 - 53 - func (r *PublishResult) Errored() bool { 54 - return r.ErrorCount > 0 55 - }
-84
sync/import.go
··· 1 - package sync 2 - 3 - import ( 4 - "archive/zip" 5 - "context" 6 - "fmt" 7 - "io" 8 - "io/fs" 9 - "os" 10 - "strings" 11 - "time" 12 - 13 - "tangled.org/karitham.dev/lazuli/kway" 14 - ) 15 - 16 - type Parser interface { 17 - ParseFile(ctx context.Context, r io.Reader) ([]PlayRecord, error) 18 - ParseFS(ctx context.Context, fsys fs.FS) ([]PlayRecord, error) 19 - } 20 - 21 - func ParseInput(ctx context.Context, path string, parser Parser) ([]PlayRecord, error) { 22 - if path == "" { 23 - return nil, nil 24 - } 25 - 26 - info, err := os.Stat(path) 27 - if err != nil { 28 - return nil, fmt.Errorf("stat path: %w", err) 29 - } 30 - 31 - if info.IsDir() { 32 - return parser.ParseFS(ctx, os.DirFS(path)) 33 - } 34 - 35 - if strings.HasSuffix(path, ".zip") { 36 - zf, err := zip.OpenReader(path) 37 - if err != nil { 38 - return nil, fmt.Errorf("open zip: %w", err) 39 - } 40 - defer zf.Close() 41 - return parser.ParseFS(ctx, zf) 42 - } 43 - 44 - file, err := os.Open(path) 45 - if err != nil { 46 - return nil, fmt.Errorf("open file: %w", err) 47 - } 48 - defer file.Close() 49 - return parser.ParseFile(ctx, file) 50 - } 51 - 52 - type ImportOptions struct { 53 - LastFMPath string 54 - SpotifyPath string 55 - Tolerance time.Duration 56 - 57 - LastFMParser Parser 58 - SpotifyParser Parser 59 - } 60 - 61 - func LoadRecordsForImport(ctx context.Context, opts ImportOptions) ([]PlayRecord, int, error) { 62 - var lastfmRecords, spotifyRecords []PlayRecord 63 - var err error 64 - 65 - if opts.LastFMPath != "" { 66 - lastfmRecords, err = ParseInput(ctx, opts.LastFMPath, opts.LastFMParser) 67 - if err != nil { 68 - return nil, 0, fmt.Errorf("parse lastfm: %w", err) 69 - } 70 - } 71 - 72 - if opts.SpotifyPath != "" { 73 - spotifyRecords, err = ParseInput(ctx, opts.SpotifyPath, opts.SpotifyParser) 74 - if err != nil { 75 - return nil, 0, fmt.Errorf("parse spotify: %w", err) 76 - } 77 - } 78 - 79 - totalInput := len(lastfmRecords) + len(spotifyRecords) 80 - 81 - mergedRecords := 
kway.Merge([][]PlayRecord{lastfmRecords, spotifyRecords}, opts.Tolerance) 82 - 83 - return mergedRecords, totalInput, nil 84 - }
-192
sync/import_test.go
··· 1 - package sync_test 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "fmt" 7 - "io" 8 - "io/fs" 9 - "testing" 10 - "time" 11 - 12 - "github.com/bluesky-social/indigo/atproto/atclient" 13 - 14 - "tangled.org/karitham.dev/lazuli/cache" 15 - "tangled.org/karitham.dev/lazuli/sync" 16 - ) 17 - 18 - type mockParser struct { 19 - records []sync.PlayRecord 20 - } 21 - 22 - func (m *mockParser) ParseFile(ctx context.Context, r io.Reader) ([]sync.PlayRecord, error) { 23 - return m.records, nil 24 - } 25 - 26 - func (m *mockParser) ParseFS(ctx context.Context, fsys fs.FS) ([]sync.PlayRecord, error) { 27 - return m.records, nil 28 - } 29 - 30 - type mockRepoClient struct { 31 - records []sync.RecordRef 32 - deleted []string 33 - applied []sync.PlayRecord 34 - } 35 - 36 - func (m *mockRepoClient) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]sync.RecordRef, string, error) { 37 - return m.records, "", nil 38 - } 39 - 40 - func (m *mockRepoClient) ApplyWrites(ctx context.Context, collection string, records []sync.PlayRecord) error { 41 - if len(records) > 200 { 42 - return fmt.Errorf("too many records") 43 - } 44 - m.applied = append(m.applied, records...) 
45 - return nil 46 - } 47 - 48 - func (m *mockRepoClient) DeleteRecord(ctx context.Context, collection, rkey string) error { 49 - m.deleted = append(m.deleted, rkey) 50 - return nil 51 - } 52 - 53 - type mockAuthClient struct { 54 - did string 55 - } 56 - 57 - func (m *mockAuthClient) APIClient() *atclient.APIClient { return nil } 58 - func (m *mockAuthClient) DID() string { return m.did } 59 - 60 - type mockKV struct { 61 - data map[string]int 62 - } 63 - 64 - func (m *mockKV) GetMulti(keys []string) (map[string]int, error) { 65 - out := make(map[string]int) 66 - for _, k := range keys { 67 - out[k] = m.data[k] 68 - } 69 - return out, nil 70 - } 71 - 72 - func (m *mockKV) IncrByMulti(counts map[string]int) error { 73 - for k, v := range counts { 74 - m.data[k] += v 75 - } 76 - return nil 77 - } 78 - 79 - func (m *mockKV) Get(key string) (int, error) { 80 - return m.data[key], nil 81 - } 82 - 83 - func (m *mockKV) Set(key string, val int) error { 84 - m.data[key] = val 85 - return nil 86 - } 87 - 88 - func (m *mockKV) IncrBy(key string, n int) (int, error) { 89 - m.data[key] += n 90 - return m.data[key], nil 91 - } 92 - 93 - func TestImportE2E(t *testing.T) { 94 - ctx := context.Background() 95 - did := "did:plc:test" 96 - 97 - // 1. Setup Storage 98 - storage, err := cache.NewBoltStorage() 99 - if err != nil { 100 - t.Fatal(err) 101 - } 102 - defer storage.Close() 103 - defer storage.ClearAll() 104 - 105 - // 2. Mock Data 106 - t1 := time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC) 107 - t2 := time.Date(2023, 1, 1, 12, 0, 5, 0, time.UTC) // Within tolerance (5s) 108 - t3 := time.Date(2023, 1, 2, 12, 0, 0, 0, time.UTC) // New record 109 - 110 - rec1 := sync.PlayRecord{TrackName: "Song A", PlayedTime: sync.Timestamp{Time: t1}} 111 - rec2 := sync.PlayRecord{TrackName: "Song A", PlayedTime: sync.Timestamp{Time: t2}} 112 - rec3 := sync.PlayRecord{TrackName: "Song B", PlayedTime: sync.Timestamp{Time: t3}} 113 - 114 - // 3. 
Load Records 115 - opts := sync.ImportOptions{ 116 - Tolerance: 10 * time.Second, 117 - LastFMParser: &mockParser{records: []sync.PlayRecord{rec1}}, 118 - SpotifyParser: &mockParser{records: []sync.PlayRecord{rec2, rec3}}, 119 - LastFMPath: "import_test.go", // Use existing file to pass Stat 120 - SpotifyPath: "import_test.go", 121 - } 122 - 123 - records, total, err := sync.LoadRecordsForImport(ctx, opts) 124 - if err != nil { 125 - t.Fatal(err) 126 - } 127 - 128 - if total != 3 { 129 - t.Errorf("expected 3 total records, got %d", total) 130 - } 131 - if len(records) != 2 { 132 - t.Errorf("expected 2 merged records, got %d", len(records)) 133 - } 134 - 135 - // 4. Save to storage (as if we just imported them) 136 - newEntries := make(map[string][]byte) 137 - for _, rec := range records { 138 - key := sync.CreateRecordKey(rec) 139 - val, _ := json.Marshal(rec) 140 - newEntries[key] = val 141 - } 142 - if err := storage.SaveRecords(did, newEntries); err != nil { 143 - t.Fatal(err) 144 - } 145 - 146 - // 5. Mock ATProto Client 147 - mockRepo := &mockRepoClient{ 148 - records: []sync.RecordRef{ 149 - {Value: rec1}, // Already exists on remote 150 - }, 151 - } 152 - 153 - // 6. Fetch Existing (Deduplicate) 154 - existing, err := sync.FetchExisting(ctx, mockRepo, did, storage, false) 155 - if err != nil { 156 - t.Fatal(err) 157 - } 158 - 159 - if len(existing) != 1 { 160 - t.Errorf("expected 1 existing record, got %d", len(existing)) 161 - } 162 - 163 - // 7. 
Publish 164 - kv := &mockKV{data: make(map[string]int)} 165 - limiter := sync.NewRateLimiter(kv, 1) 166 - publishOpts := sync.PublishOptions{ 167 - BatchSize: 10, 168 - ATProtoClient: mockRepo, 169 - Storage: storage, 170 - Limiter: limiter, 171 - ClientAgent: sync.DefaultClientAgent, 172 - } 173 - 174 - auth := &mockAuthClient{did: did} 175 - result := sync.Publish(ctx, auth, publishOpts) 176 - 177 - if result.SuccessCount != 1 { 178 - t.Errorf("expected 1 successful publish, got %d", result.SuccessCount) 179 - } 180 - if len(mockRepo.applied) != 1 { 181 - t.Errorf("expected 1 record applied to repo, got %d", len(mockRepo.applied)) 182 - } 183 - if mockRepo.applied[0].TrackName != "Song B" { 184 - t.Errorf("expected Song B to be published, got %s", mockRepo.applied[0].TrackName) 185 - } 186 - 187 - // 8. Verify storage state 188 - stats, _ := storage.Stats() 189 - if stats.UnpublishedCount != 0 { 190 - t.Errorf("expected 0 unpublished records, got %d", stats.UnpublishedCount) 191 - } 192 - }
+47 -32
sync/publish.go
··· 1 1 package sync 2 2 3 3 import ( 4 + "cmp" 4 5 "context" 5 6 "encoding/json" 6 7 "fmt" ··· 17 18 "tangled.org/karitham.dev/lazuli/cache" 18 19 ) 19 20 20 - type ( 21 - ATProtoClient = atproto.RepoClient[PlayRecord] 22 - AuthClient = atproto.AuthClient 23 - RateLimiter = atproto.RateLimiter 21 + const ( 22 + WriteLimitDay = atproto.WriteLimitDay 23 + GlobalLimitDay = atproto.GlobalLimitDay 24 + 25 + RecordType = "fm.teal.alpha.feed.play" 26 + DefaultBatchSize = 20 27 + DefaultCrossSourceTolerance = 5 * time.Minute 28 + CrossSourceTolerance = DefaultCrossSourceTolerance 29 + CacheTTL = 24 * time.Hour 30 + CacheVersion = 1 31 + SlingshotResolverURL = "https://slingshot.microcosm.blue/xrpc/com.bad-example.identity.resolveMiniDoc" 32 + MaxRetryDelay = 15 * time.Minute 33 + MaxRetries = 1000 34 + BaseRetryDelay = 2 * time.Second 24 35 ) 25 36 26 - var ( 27 - WriteLimitDay = atproto.WriteLimitDay 28 - GlobalLimitDay = atproto.GlobalLimitDay 37 + type ( 38 + ATProtoClient = atproto.RepoClient[PlayRecord] 39 + AuthClient = atproto.AuthClient 40 + RateLimiter = atproto.RateLimiter 41 + Client = atproto.Client 42 + RepoClient[T any] = atproto.RepoClient[T] 43 + RecordRef = atproto.RecordRef[PlayRecord] 44 + 45 + PublishOptions struct { 46 + BatchSize int 47 + DryRun bool 48 + ATProtoClient ATProtoClient 49 + ProgressLog func(ProgressReport) 50 + Storage cache.Storage 51 + Limiter RateLimiter 52 + ClientAgent string 53 + RetryDelay time.Duration 54 + } 55 + 56 + PublishResult struct { 57 + SuccessCount int `json:"successCount"` 58 + ErrorCount int `json:"errorCount"` 59 + Cancelled bool `json:"cancelled"` 60 + Duration time.Duration `json:"duration"` 61 + TotalRecords int `json:"totalRecords"` 62 + RecordsPerMinute float64 `json:"recordsPerMinute"` 63 + FirstRecordTime time.Time `json:"firstRecordTime"` 64 + LastRecordTime time.Time `json:"lastRecordTime"` 65 + } 29 66 ) 30 67 31 68 func NewRateLimiter(kv atproto.KVStore, maxPercent float32) RateLimiter { ··· 40 77 return 
atproto.IsTransientError(err) 41 78 } 42 79 43 - type Client = atproto.Client 44 - 45 80 func NewClient(ctx context.Context, handle, password string, opts ...func(*atproto.ClientOptions)) (*Client, error) { 46 81 return atproto.NewClient(ctx, handle, password, opts...) 47 82 } 48 83 49 - type RepoClient[T any] = atproto.RepoClient[T] 50 - 51 - type RecordRef = atproto.RecordRef[PlayRecord] 52 - 53 - type PublishOptions struct { 54 - BatchSize int 55 - DryRun bool 56 - ATProtoClient ATProtoClient 57 - ProgressLog func(ProgressReport) 58 - Storage cache.Storage 59 - Limiter RateLimiter 60 - ClientAgent string 61 - } 62 - 63 84 func Publish(ctx context.Context, client AuthClient, opts PublishOptions) PublishResult { 64 85 startTime := time.Now() 65 86 66 - batchSize := defaultBatchSize(opts.BatchSize) 87 + retryDelay := cmp.Or(opts.RetryDelay, BaseRetryDelay) 88 + batchSize := cmp.Or(opts.BatchSize, DefaultBatchSize) 67 89 68 90 atprotoClient, err := atproto.BuildClient(client, opts.ATProtoClient) 69 91 if err != nil { ··· 131 153 did := client.DID() 132 154 retryPolicy := retrypolicy.NewBuilder[any](). 133 155 WithMaxRetries(10). 134 - WithBackoff(BaseRetryDelay, 5*time.Minute). 156 + WithBackoff(retryDelay, 5*time.Minute). 135 157 HandleIf(func(_ any, err error) bool { 136 158 return isTransientError(err) 137 159 }). ··· 244 266 slog.Int("errors", pr.Errors), 245 267 ) 246 268 } 247 - } 248 - 249 - func defaultBatchSize(size int) int { 250 - if size > 0 { 251 - return size 252 - } 253 - return DefaultBatchSize 254 269 } 255 270 256 271 func newPublishResult(success, errors, total int, start time.Time, cancelled bool) PublishResult {
-315
sync/rate_test.go
··· 1 - package sync 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "fmt" 7 - "testing" 8 - "time" 9 - 10 - "github.com/bluesky-social/indigo/atproto/atclient" 11 - 12 - "tangled.org/karitham.dev/lazuli/atproto" 13 - ) 14 - 15 - type mockKV struct { 16 - data map[string]int 17 - } 18 - 19 - func (m *mockKV) GetMulti(keys []string) (map[string]int, error) { 20 - out := make(map[string]int) 21 - for _, k := range keys { 22 - out[k] = m.data[k] 23 - } 24 - return out, nil 25 - } 26 - 27 - func (m *mockKV) IncrByMulti(counts map[string]int) error { 28 - for k, v := range counts { 29 - m.data[k] += v 30 - } 31 - return nil 32 - } 33 - 34 - type testClock struct { 35 - now time.Time 36 - } 37 - 38 - func (m *testClock) Now() time.Time { return m.now } 39 - 40 - func TestRateLimiter_Refunds(t *testing.T) { 41 - kv := &mockKV{data: make(map[string]int)} 42 - clock := &testClock{now: time.Date(2026, 1, 22, 12, 0, 0, 0, time.UTC)} 43 - limiter := &testRateLimiter{ 44 - kv: kv, 45 - prefix: "quota", 46 - clock: clock, 47 - rlQuota: 1, 48 - } 49 - ctx := context.Background() 50 - 51 - chargedAt, _ := limiter.AllowBulkWrite(ctx, 10) 52 - limiter.RefundBulkWrite(ctx, 10, chargedAt) 53 - 54 - w, g, _ := limiter.Stats() 55 - if w != 0 || g != 0 { 56 - t.Errorf("BulkWrite refund failed: w=%d, g=%d", w, g) 57 - } 58 - 59 - chargedAt, _ = limiter.AllowRead(ctx) 60 - limiter.RefundRead(ctx, chargedAt) 61 - 62 - _, g, _ = limiter.Stats() 63 - if g != 0 { 64 - t.Errorf("Read refund failed: g=%d", g) 65 - } 66 - } 67 - 68 - type testRateLimiter struct { 69 - kv *mockKV 70 - prefix string 71 - clock *testClock 72 - rlQuota float32 73 - } 74 - 75 - func (l *testRateLimiter) Stats() (int, int, error) { 76 - wd, gd, _, _, _, _ := l.getKeys(l.clock.now) 77 - vals, err := l.kv.GetMulti([]string{wd, gd}) 78 - if err != nil { 79 - return 0, 0, err 80 - } 81 - return vals[wd], vals[gd], nil 82 - } 83 - 84 - func (l *testRateLimiter) AllowBulkWrite(ctx context.Context, n int) (time.Time, 
error) { 85 - wCost := n * atproto.WriteOnlyCost 86 - gCost := n * atproto.WriteGlobalCost 87 - 88 - now := l.clock.now 89 - wKeys, gKeys := l.getAllKeys(now) 90 - 91 - maxWait, err := l.checkQuota(now, wKeys, gKeys, wCost, gCost) 92 - if err != nil { 93 - return now, err 94 - } 95 - 96 - if maxWait > 0 { 97 - return now, context.DeadlineExceeded 98 - } 99 - 100 - err = l.charge(wKeys, gKeys, wCost, gCost) 101 - if err != nil { 102 - return now, err 103 - } 104 - return now, nil 105 - } 106 - 107 - func (l *testRateLimiter) AllowRead(ctx context.Context) (time.Time, error) { 108 - gCost := atproto.ReadGlobalCost 109 - 110 - now := l.clock.now 111 - _, gKeys := l.getAllKeys(now) 112 - 113 - maxWait, err := l.checkQuota(now, nil, gKeys, 0, gCost) 114 - if err != nil { 115 - return now, err 116 - } 117 - 118 - if maxWait > 0 { 119 - return now, context.DeadlineExceeded 120 - } 121 - 122 - err = l.charge(nil, gKeys, 0, gCost) 123 - if err != nil { 124 - return now, err 125 - } 126 - return now, nil 127 - } 128 - 129 - func (l *testRateLimiter) RefundBulkWrite(ctx context.Context, n int, chargedAt time.Time) { 130 - wKeys, gKeys := l.getAllKeys(chargedAt) 131 - wCost := n * atproto.WriteOnlyCost 132 - gCost := n * atproto.WriteGlobalCost 133 - 134 - updates := make(map[string]int, len(wKeys)+len(gKeys)) 135 - for _, k := range wKeys { 136 - updates[k] = -wCost 137 - } 138 - for _, k := range gKeys { 139 - updates[k] = -gCost 140 - } 141 - 142 - l.kv.IncrByMulti(updates) 143 - } 144 - 145 - func (l *testRateLimiter) RefundRead(ctx context.Context, chargedAt time.Time) { 146 - _, gKeys := l.getAllKeys(chargedAt) 147 - gCost := atproto.ReadGlobalCost 148 - 149 - updates := make(map[string]int, len(gKeys)) 150 - for _, k := range gKeys { 151 - updates[k] = -gCost 152 - } 153 - 154 - l.kv.IncrByMulti(updates) 155 - } 156 - 157 - func (l *testRateLimiter) getKeys(t time.Time) (string, string, string, string, string, string) { 158 - day := t.Format("2006-01-02") 159 - hour := 
t.Format("2006-01-02-15") 160 - minute := t.Format("2006-01-02-15-04") 161 - return fmt.Sprintf("%s:writes:d:%s", l.prefix, day), fmt.Sprintf("%s:global:d:%s", l.prefix, day), 162 - fmt.Sprintf("%s:writes:h:%s", l.prefix, hour), fmt.Sprintf("%s:global:h:%s", l.prefix, hour), 163 - fmt.Sprintf("%s:writes:m:%s", l.prefix, minute), fmt.Sprintf("%s:global:m:%s", l.prefix, minute) 164 - } 165 - 166 - func (l *testRateLimiter) getAllKeys(t time.Time) ([]string, []string) { 167 - wd, gd, wh, gh, wm, gm := l.getKeys(t) 168 - return []string{wm, wh, wd}, []string{gm, gh, gd} 169 - } 170 - 171 - func (l *testRateLimiter) checkQuota(now time.Time, wKeys, gKeys []string, wCost, gCost int) (time.Duration, error) { 172 - wLimits := []int{atproto.WriteLimitMinute, atproto.WriteLimitHour, atproto.WriteLimitDay} 173 - gLimits := []int{atproto.GlobalLimitMinute, atproto.GlobalLimitHour, atproto.GlobalLimitDay} 174 - 175 - allKeys := make([]string, 0, len(wKeys)+len(gKeys)) 176 - allKeys = append(allKeys, wKeys...) 177 - allKeys = append(allKeys, gKeys...) 
178 - 179 - if len(allKeys) == 0 { 180 - return 0, nil 181 - } 182 - 183 - values, err := l.kv.GetMulti(allKeys) 184 - if err != nil { 185 - return 0, err 186 - } 187 - 188 - maxWait := time.Duration(0) 189 - 190 - for i, k := range wKeys { 191 - curr := values[k] 192 - if curr+wCost > int(float32(wLimits[i])*l.rlQuota) { 193 - maxWait = max(l.untilNextWindow(now, i), maxWait) 194 - } 195 - } 196 - 197 - for i, k := range gKeys { 198 - curr := values[k] 199 - if curr+gCost > int(float32(gLimits[i])*l.rlQuota) { 200 - maxWait = max(l.untilNextWindow(now, i), maxWait) 201 - } 202 - } 203 - 204 - return maxWait, nil 205 - } 206 - 207 - func (l *testRateLimiter) charge(wKeys, gKeys []string, wCost, gCost int) error { 208 - updates := make(map[string]int, len(wKeys)+len(gKeys)) 209 - for _, k := range wKeys { 210 - updates[k] = wCost 211 - } 212 - for _, k := range gKeys { 213 - updates[k] = gCost 214 - } 215 - 216 - if len(updates) == 0 { 217 - return nil 218 - } 219 - 220 - return l.kv.IncrByMulti(updates) 221 - } 222 - 223 - func (l *testRateLimiter) untilNextWindow(now time.Time, tier int) time.Duration { 224 - switch tier { 225 - case 0: 226 - return now.Truncate(time.Minute).Add(time.Minute).Sub(now) 227 - case 1: 228 - return now.Truncate(time.Hour).Add(time.Hour).Sub(now) 229 - case 2: 230 - return now.Truncate(24 * time.Hour).Add(24 * time.Hour).Sub(now) 231 - default: 232 - return time.Minute 233 - } 234 - } 235 - 236 - func TestRateLimiter_Weighting(t *testing.T) { 237 - kv := &mockKV{data: make(map[string]int)} 238 - limiter := atproto.NewRateLimiter(kv, 1) 239 - ctx := context.Background() 240 - 241 - _, err := limiter.AllowRead(ctx) 242 - if err != nil { 243 - t.Fatal(err) 244 - } 245 - _, g, err := limiter.Stats() 246 - if err != nil { 247 - t.Fatal(err) 248 - } 249 - if g != 1 { 250 - t.Errorf("expected 1 global unit, got %d", g) 251 - } 252 - 253 - _, err = limiter.AllowBulkWrite(ctx, 1) 254 - if err != nil { 255 - t.Fatal(err) 256 - } 257 - w, g, err 
:= limiter.Stats() 258 - if err != nil { 259 - t.Fatal(err) 260 - } 261 - if w != 1 { 262 - t.Errorf("expected 1 write unit, got %d", w) 263 - } 264 - if g != 4 { 265 - t.Errorf("expected 4 global units, got %d", g) 266 - } 267 - 268 - _, err = limiter.AllowBulkWrite(ctx, 10) 269 - if err != nil { 270 - t.Fatal(err) 271 - } 272 - w, g, err = limiter.Stats() 273 - if err != nil { 274 - t.Fatal(err) 275 - } 276 - if w != 11 { 277 - t.Errorf("expected 11 write units, got %d", w) 278 - } 279 - if g != 34 { 280 - t.Errorf("expected 34 global units, got %d", g) 281 - } 282 - } 283 - 284 - func TestRetryExhaustionMarkFailed(t *testing.T) { 285 - ctx := context.Background() 286 - did := "did:example:123" 287 - clientAgent := "test-agent" 288 - storage := newMockStorage() 289 - rec1, _ := json.Marshal(PlayRecord{TrackName: "Song 1"}) 290 - storage.SaveRecords(did, map[string][]byte{"k1": rec1}) 291 - 292 - client := &mockATProtoClient{ 293 - applyWritesFunc: func(ctx context.Context, collection string, records []PlayRecord) error { 294 - return &atclient.APIError{StatusCode: 503} 295 - }, 296 - } 297 - 298 - oldBase := BaseRetryDelay 299 - BaseRetryDelay = time.Nanosecond 300 - defer func() { BaseRetryDelay = oldBase }() 301 - 302 - res := Publish(ctx, &mockAuthClient{did: did}, PublishOptions{ 303 - BatchSize: 1, 304 - ATProtoClient: client, 305 - Storage: storage, 306 - ClientAgent: clientAgent, 307 - }) 308 - 309 - if res.SuccessCount != 0 { 310 - t.Errorf("expected 0 successes, got %d", res.SuccessCount) 311 - } 312 - if res.ErrorCount != 1 { 313 - t.Errorf("expected 1 error, got %d", res.ErrorCount) 314 - } 315 - }