like malachite (atproto-lastfm-importer) but in go and bluer
go spotify tealfm lastfm atproto
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

sync/publish: refactor

karitham e90f6de3 9cc177a3

+796 -553
+1 -1
flake.nix
··· 17 17 let 18 18 lazuli = pkgs.buildGoModule rec { 19 19 name = "lazuli"; 20 - version = "0.1.6"; 20 + version = "0.1.7"; 21 21 src = pkgs.nix-gitignore.gitignoreSource [ "*.csv" "*.zip" "*.json" ] ./.; 22 22 vendorHash = "sha256-O6R8jC8Ms5gsY2FUmuL8lTGTODfMW1CsSWuWbN27zeY="; 23 23 ldflags = [
-418
sync/batch_test.go
··· 1 - package sync 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "errors" 7 - "maps" 8 - "net/http" 9 - "strings" 10 - "sync" 11 - "sync/atomic" 12 - "testing" 13 - "testing/synctest" 14 - "time" 15 - 16 - "github.com/bluesky-social/indigo/atproto/atclient" 17 - 18 - "tangled.org/karitham.dev/lazuli/atproto" 19 - "tangled.org/karitham.dev/lazuli/cache" 20 - ) 21 - 22 - type mockRoundTripper func(req *http.Request) (*http.Response, error) 23 - 24 - func (f mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 25 - return f(req) 26 - } 27 - 28 - // Mock Storage 29 - 30 - type mockStorage struct { 31 - cache.Storage 32 - unpublished map[string][]byte 33 - published map[string]bool 34 - failed map[string]string 35 - kv map[string]int 36 - mu sync.Mutex 37 - } 38 - 39 - func newMockStorage() *mockStorage { 40 - return &mockStorage{ 41 - unpublished: make(map[string][]byte), 42 - published: make(map[string]bool), 43 - failed: make(map[string]string), 44 - kv: make(map[string]int), 45 - } 46 - } 47 - 48 - func (m *mockStorage) SaveRecords(did string, records map[string][]byte) error { 49 - m.mu.Lock() 50 - defer m.mu.Unlock() 51 - maps.Copy(m.unpublished, records) 52 - return nil 53 - } 54 - 55 - func (m *mockStorage) IterateUnpublished(did string, fn func(key string, rec []byte) error) error { 56 - m.mu.Lock() 57 - // Copy to avoid deadlock if fn calls back 58 - keys := make([]string, 0, len(m.unpublished)) 59 - for k := range m.unpublished { 60 - keys = append(keys, k) 61 - } 62 - m.mu.Unlock() 63 - 64 - for _, k := range keys { 65 - m.mu.Lock() 66 - rec, ok := m.unpublished[k] 67 - m.mu.Unlock() 68 - if ok { 69 - if err := fn(k, rec); err != nil { 70 - return err 71 - } 72 - } 73 - } 74 - return nil 75 - } 76 - 77 - func (m *mockStorage) MarkPublished(did string, keys ...string) error { 78 - m.mu.Lock() 79 - defer m.mu.Unlock() 80 - for _, k := range keys { 81 - delete(m.unpublished, k) 82 - m.published[k] = true 83 - } 84 - return nil 85 - } 86 - 87 - func (m *mockStorage) MarkFailed(did string, keys []string, err string) error { 88 - m.mu.Lock() 89 - defer m.mu.Unlock() 90 - for _, k := range keys { 91 - m.failed[k] = err 92 - } 93 - return nil 94 - } 95 - 96 - func (m *mockStorage) Get(key string) (int, error) { 97 - m.mu.Lock() 98 - defer m.mu.Unlock() 99 - return m.kv[key], nil 100 - } 101 - 102 - func (m *mockStorage) IncrBy(key string, n int) (int, error) { 103 - m.mu.Lock() 104 - defer m.mu.Unlock() 105 - m.kv[key] += n 106 - return m.kv[key], nil 107 - } 108 - 109 - // Mock RateLimiter 110 - type mockLimiter struct { 111 - refunds int32 112 - } 113 - 114 - func (m *mockLimiter) AllowRead(ctx context.Context) (time.Time, error) { 115 - return time.Now(), nil 116 - } 117 - 118 - func (m *mockLimiter) AllowBulkWrite(ctx context.Context, n int) (time.Time, error) { 119 - return time.Now(), nil 120 - } 121 - 122 - func (m *mockLimiter) RefundBulkWrite(ctx context.Context, n int, chargedAt time.Time) { 123 - atomic.AddInt32(&m.refunds, 1) 124 - } 125 - 126 - func (m *mockLimiter) RefundRead(ctx context.Context, chargedAt time.Time) { 127 - atomic.AddInt32(&m.refunds, 1) 128 - } 129 - 130 - func (m *mockLimiter) Stats() (int, int, error) { 131 - return 0, 0, nil 132 - } 133 - 134 - func (m *mockLimiter) EstimatedWriteTime(n int) time.Duration { 135 - return 0 136 - } 137 - 138 - func (m *mockLimiter) RemainingQuota() (int, int, time.Duration) { 139 - return 10000, 35000, time.Hour 140 - } 141 - 142 - // Mock ATProtoClient 143 - type mockATProtoClient struct { 144 - applyWritesFunc func(ctx context.Context, collection string, records []PlayRecord) error 145 - listRecordsFunc func(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[PlayRecord], string, error) 146 - deleteRecordFunc func(ctx context.Context, collection, rkey string) error 147 - } 148 - 149 - func (m *mockATProtoClient) ApplyWrites(ctx context.Context, collection string, records []PlayRecord) error { 150 - if m.applyWritesFunc != nil { 151 - return m.applyWritesFunc(ctx, collection, records) 152 - } 153 - return nil 154 - } 155 - 156 - func (m *mockATProtoClient) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[PlayRecord], string, error) { 157 - if m.listRecordsFunc != nil { 158 - return m.listRecordsFunc(ctx, collection, limit, cursor) 159 - } 160 - return nil, "", nil 161 - } 162 - 163 - func (m *mockATProtoClient) DeleteRecord(ctx context.Context, collection, rkey string) error { 164 - if m.deleteRecordFunc != nil { 165 - return m.deleteRecordFunc(ctx, collection, rkey) 166 - } 167 - return nil 168 - } 169 - 170 - // Mock AuthClient 171 - type mockAuthClient struct { 172 - did string 173 - } 174 - 175 - func (m *mockAuthClient) APIClient() *atclient.APIClient { return nil } 176 - func (m *mockAuthClient) DID() string { return m.did } 177 - 178 - type timeoutError struct{} 179 - 180 - func (e timeoutError) Error() string { return "timeout" } 181 - func (e timeoutError) Timeout() bool { return true } 182 - func (e timeoutError) Temporary() bool { return true } 183 - 184 - func TestIsTransientError(t *testing.T) { 185 - tests := []struct { 186 - name string 187 - err error 188 - want bool 189 - }{ 190 - {"nil", nil, false}, 191 - {"generic error", errors.New("some error"), false}, 192 - {"API 400", &atclient.APIError{StatusCode: 400}, false}, 193 - {"API 429", &atclient.APIError{StatusCode: 429}, true}, 194 - {"API 500", &atclient.APIError{StatusCode: 500}, true}, 195 - {"API 503", &atclient.APIError{StatusCode: 503}, true}, 196 - {"net timeout", timeoutError{}, true}, 197 - {"net non-timeout", errors.New("network is down"), false}, 198 - } 199 - 200 - for _, tt := range tests { 201 - t.Run(tt.name, func(t *testing.T) { 202 - t.Parallel() 203 - if got := isTransientError(tt.err); got != tt.want { 204 - t.Errorf("isTransientError() = %v, want %v", got, tt.want) 205 - } 206 - }) 207 - } 208 - } 209 - 210 - func TestApplyWrites_RateClient(t *testing.T) { 211 - ctx := context.Background() 212 - clientAgent := "test-agent" 213 - 214 - tests := []struct { 215 - name string 216 - setupClient func() (*atproto.RateClient[PlayRecord], *mockLimiter) 217 - records []PlayRecord 218 - wantErr bool 219 - wantErrMsg string 220 - wantRefunds int32 221 - }{ 222 - { 223 - name: "empty records succeeds", 224 - setupClient: func() (*atproto.RateClient[PlayRecord], *mockLimiter) { 225 - limiter := &mockLimiter{} 226 - return atproto.NewRateClient[PlayRecord](nil, "did:example:123", limiter), limiter 227 - }, 228 - records: nil, 229 - wantErr: false, 230 - wantRefunds: 0, 231 - }, 232 - { 233 - name: "too many records fails", 234 - setupClient: func() (*atproto.RateClient[PlayRecord], *mockLimiter) { 235 - limiter := &mockLimiter{} 236 - return atproto.NewRateClient[PlayRecord](&atclient.APIClient{}, "did:example:123", limiter), limiter 237 - }, 238 - records: make([]PlayRecord, 201), 239 - wantErr: true, 240 - wantErrMsg: "too many records in one ApplyWrites call: 201 (max 200)", 241 - wantRefunds: 0, 242 - }, 243 - { 244 - name: "transient error refunds tokens", 245 - setupClient: func() (*atproto.RateClient[PlayRecord], *mockLimiter) { 246 - limiter := &mockLimiter{} 247 - apiClient := atclient.NewAPIClient("https://example.com") 248 - apiClient.Client.Transport = mockRoundTripper(func(req *http.Request) (*http.Response, error) { 249 - return &http.Response{ 250 - StatusCode: 503, 251 - Body: http.NoBody, 252 - }, nil 253 - }) 254 - return atproto.NewRateClient[PlayRecord](apiClient, "did:example:123", limiter), limiter 255 - }, 256 - records: []PlayRecord{{TrackName: "Song 1"}}, 257 - wantErr: true, 258 - wantRefunds: 1, 259 - }, 260 - { 261 - name: "non-transient error does NOT refund", 262 - setupClient: func() (*atproto.RateClient[PlayRecord], *mockLimiter) { 263 - limiter := &mockLimiter{} 264 - apiClient := atclient.NewAPIClient("https://example.com") 265 - apiClient.Client.Transport = mockRoundTripper(func(req *http.Request) (*http.Response, error) { 266 - return &http.Response{ 267 - StatusCode: 400, 268 - Body: http.NoBody, 269 - }, nil 270 - }) 271 - return atproto.NewRateClient[PlayRecord](apiClient, "did:example:123", limiter), limiter 272 - }, 273 - records: []PlayRecord{{TrackName: "Song 1"}}, 274 - wantErr: true, 275 - wantRefunds: 0, 276 - }, 277 - } 278 - 279 - for _, tt := range tests { 280 - t.Run(tt.name, func(t *testing.T) { 281 - client, limiter := tt.setupClient() 282 - err := client.ApplyWrites(ctx, "test", tt.records) 283 - 284 - if tt.wantErr { 285 - if err == nil { 286 - t.Error("expected error, got nil") 287 - } else if tt.wantErrMsg != "" && err.Error() != tt.wantErrMsg { 288 - t.Errorf("error msg = %q, want %q", err.Error(), tt.wantErrMsg) 289 - } 290 - } else if err != nil { 291 - t.Errorf("unexpected error: %v", err) 292 - } 293 - 294 - if got := atomic.LoadInt32(&limiter.refunds); got != tt.wantRefunds { 295 - t.Errorf("refunds = %d, want %d", got, tt.wantRefunds) 296 - } 297 - }) 298 - } 299 - _ = clientAgent 300 - } 301 - 302 - func TestPublishBatch(t *testing.T) { 303 - ctx := context.Background() 304 - did := "did:example:123" 305 - batch := []PlayRecord{{TrackName: "Song 1"}} 306 - clientAgent := "test-agent" 307 - 308 - t.Run("Success", func(t *testing.T) { 309 - storage := newMockStorage() 310 - client := &mockATProtoClient{} 311 - err := PublishBatch(ctx, client, did, batch, storage, clientAgent) 312 - if err != nil { 313 - t.Fatal(err) 314 - } 315 - if len(storage.published) != 1 { 316 - t.Errorf("expected 1 published record, got %d", len(storage.published)) 317 - } 318 - }) 319 - 320 - t.Run("ApplyWrites failure", func(t *testing.T) { 321 - storage := newMockStorage() 322 - expectedErr := errors.New("apply failed") 323 - client := &mockATProtoClient{ 324 - applyWritesFunc: func(ctx context.Context, collection string, records []PlayRecord) error { 325 - return expectedErr 326 - }, 327 - } 328 - err := PublishBatch(ctx, client, did, batch, storage, clientAgent) 329 - if !errors.Is(err, expectedErr) { 330 - t.Errorf("expected error %v, got %v", expectedErr, err) 331 - } 332 - if len(storage.published) != 0 { 333 - t.Error("expected 0 published records") 334 - } 335 - }) 336 - 337 - t.Run("Storage failure after ApplyWrites success", func(t *testing.T) { 338 - storage := &failingStorage{} 339 - client := &mockATProtoClient{} 340 - err := PublishBatch(ctx, client, did, batch, storage, clientAgent) 341 - if err == nil || !strings.Contains(err.Error(), "failed to save records") { 342 - t.Errorf("expected storage save error, got %v", err) 343 - } 344 - }) 345 - } 346 - 347 - type failingStorage struct { 348 - mockStorage 349 - } 350 - 351 - func (s *failingStorage) SaveRecords(did string, records map[string][]byte) error { 352 - return errors.New("failed to save records") 353 - } 354 - 355 - func TestPublish_Iterative(t *testing.T) { 356 - ctx := context.Background() 357 - did := "did:example:123" 358 - clientAgent := "test-agent" 359 - 360 - rec1, _ := json.Marshal(PlayRecord{TrackName: "Song 1"}) 361 - rec2, _ := json.Marshal(PlayRecord{TrackName: "Song 2"}) 362 - 363 - t.Run("Retry on transient error", func(t *testing.T) { 364 - synctest.Test(t, func(t *testing.T) { 365 - storage := newMockStorage() 366 - storage.SaveRecords(did, map[string][]byte{"k1": rec1, "k2": rec2}) 367 - 368 - var attempts int32 369 - client := &mockATProtoClient{ 370 - applyWritesFunc: func(ctx context.Context, collection string, records []PlayRecord) error { 371 - if atomic.AddInt32(&attempts, 1) <= 2 { 372 - return &atclient.APIError{StatusCode: 503} 373 - } 374 - return nil 375 - }, 376 - } 377 - 378 - res := Publish(ctx, &mockAuthClient{did: did}, PublishOptions{ 379 - BatchSize: 1, 380 - ATProtoClient: client, 381 - Storage: storage, 382 - ClientAgent: clientAgent, 383 - }) 384 - 385 - if res.SuccessCount != 2 { 386 - t.Errorf("expected 2 successes, got %d", res.SuccessCount) 387 - } 388 - if atomic.LoadInt32(&attempts) < 3 { 389 - t.Errorf("expected at least 3 attempts (2 fails + 1 success), got %d", attempts) 390 - } 391 - }) 392 - }) 393 - 394 - t.Run("Fail fast on non-transient error", func(t *testing.T) { 395 - storage := newMockStorage() 396 - storage.SaveRecords(did, map[string][]byte{"k1": rec1}) 397 - 398 - client := &mockATProtoClient{ 399 - applyWritesFunc: func(ctx context.Context, collection string, records []PlayRecord) error { 400 - return &atclient.APIError{StatusCode: 400} 401 - }, 402 - } 403 - 404 - res := Publish(ctx, &mockAuthClient{did: did}, PublishOptions{ 405 - BatchSize: 1, 406 - ATProtoClient: client, 407 - Storage: storage, 408 - ClientAgent: clientAgent, 409 - }) 410 - 411 - if res.SuccessCount != 0 { 412 - t.Errorf("expected 0 successes, got %d", res.SuccessCount) 413 - } 414 - if res.ErrorCount != 1 { 415 - t.Errorf("expected 1 error, got %d", res.ErrorCount) 416 - } 417 - }) 418 - }
+181 -134
sync/publish.go
··· 34 34 BaseRetryDelay = 2 * time.Second 35 35 ) 36 36 37 + var DefaultRetryPolicy = retrypolicy.NewBuilder[struct{}](). 38 + WithMaxRetries(10). 39 + WithBackoff(BaseRetryDelay, 5*time.Minute). 40 + HandleIf(func(_ struct{}, err error) bool { 41 + return atproto.IsTransientError(err) 42 + }). 43 + OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[struct{}]) { 44 + slog.Warn("batch failed with transient error, retrying", 45 + slog.Duration("retryDelay", e.Delay), 46 + ErrorAttr(e.LastError()), 47 + slog.Int("attempt", e.Attempts())) 48 + }). 49 + Build() 50 + 37 51 type ( 38 52 ATProtoClient = atproto.RepoClient[PlayRecord] 39 53 AuthClient = atproto.AuthClient ··· 53 67 RetryDelay time.Duration 54 68 } 55 69 56 - PublishResult struct { 70 + publishResult struct { 57 71 SuccessCount int `json:"successCount"` 58 72 ErrorCount int `json:"errorCount"` 59 73 Cancelled bool `json:"cancelled"` ··· 63 77 FirstRecordTime time.Time `json:"firstRecordTime"` 64 78 LastRecordTime time.Time `json:"lastRecordTime"` 65 79 } 80 + 81 + recordBatch struct { 82 + Records []PlayRecord 83 + Keys []string 84 + } 85 + 86 + batchProcessor struct { 87 + Client ATProtoClient 88 + Storage cache.Storage 89 + DID string 90 + ClientAgent string 91 + DryRun bool 92 + } 93 + 94 + batchResult struct { 95 + SuccessCount int 96 + ErrorCount int 97 + Duration time.Duration 98 + Errors []error 99 + } 66 100 ) 67 101 68 102 func NewRateLimiter(kv atproto.KVStore, maxPercent float32) RateLimiter { ··· 81 115 return atproto.NewClient(ctx, handle, password, opts...) 82 116 } 83 117 84 - func Publish(ctx context.Context, client AuthClient, opts PublishOptions) PublishResult { 85 - startTime := time.Now() 118 + // batchRecords iterates through storage and builds record batches 119 + func batchRecords(ctx context.Context, storage cache.Storage, did string, batchSize int) ([]recordBatch, error) { 120 + var batches []recordBatch 121 + var currentBatch recordBatch 86 122 87 - retryDelay := cmp.Or(opts.RetryDelay, BaseRetryDelay) 88 - batchSize := cmp.Or(opts.BatchSize, DefaultBatchSize) 123 + err := storage.IterateUnpublished(did, func(key string, rec []byte) error { 124 + select { 125 + case <-ctx.Done(): 126 + return ctx.Err() 127 + default: 128 + } 89 129 90 - atprotoClient, err := atproto.BuildClient(client, opts.ATProtoClient) 91 - if err != nil { 92 - return PublishResult{ 93 - SuccessCount: 0, 94 - ErrorCount: 0, 95 - Cancelled: false, 96 - Duration: time.Since(startTime), 97 - TotalRecords: 0, 130 + var record PlayRecord 131 + if err := json.Unmarshal(rec, &record); err != nil { 132 + slog.Error("malformed record in storage", slog.String("key", key), ErrorAttr(err)) 133 + if storage != nil { 134 + _ = storage.MarkFailed(did, []string{key}, "malformed record") 135 + } 136 + return nil // Skip malformed records 98 137 } 99 - } 138 + 139 + currentBatch.Records = append(currentBatch.Records, record) 140 + currentBatch.Keys = append(currentBatch.Keys, key) 100 141 101 - totalRecords := 0 102 - if opts.Storage != nil { 103 - _ = opts.Storage.IterateUnpublished(client.DID(), func(key string, rec []byte) error { 104 - totalRecords++ 105 - return nil 106 - }) 142 + if len(currentBatch.Records) >= batchSize { 143 + batches = append(batches, recordBatch{ 144 + Records: append([]PlayRecord{}, currentBatch.Records...), 145 + Keys: append([]string{}, currentBatch.Keys...), 146 + }) 147 + currentBatch = recordBatch{} 148 + } 149 + return nil 150 + }) 151 + if err != nil { 152 + return nil, err 107 153 } 108 154 109 - if totalRecords == 0 { 110 - return PublishResult{} 155 + // Add the last partial batch if it has records 156 + if len(currentBatch.Records) > 0 { 157 + batches = append(batches, currentBatch) 111 158 } 112 159 113 - slog.Info("starting iterative import", 114 - slog.Int("total_records", totalRecords), 115 - slog.Int("batch_size", batchSize), 116 - slog.Int("daily_write_limit", atproto.WriteLimitDay), 117 - slog.Int("daily_token_limit", atproto.GlobalLimitDay), 118 - slog.String("rate_limit", fmt.Sprintf("1 write per %.1fs", 86400.0/atproto.WriteLimitDay))) 160 + return batches, nil 161 + } 119 162 120 - tracker := NewProgressTracker(totalRecords, opts.Limiter) 121 - progressLog := defaultProgressLog(opts.ProgressLog) 122 - totalSuccess := 0 123 - totalErrors := 0 163 + // processBatch processes a single batch of records with retries 164 + func processBatch(ctx context.Context, batch recordBatch, processor batchProcessor) batchResult { 165 + if len(batch.Records) == 0 { 166 + return batchResult{} 167 + } 124 168 125 - var batch []PlayRecord 126 - var batchKeys []string 169 + start := time.Now() 127 170 128 - processBatch := func() error { 129 - if len(batch) == 0 { 130 - return nil 171 + if processor.DryRun { 172 + for _, r := range batch.Records { 173 + tid := syntax.NewTIDFromTime(r.PlayedTime.Time, 0) 174 + slog.Info("would publish record (dry run)", trackAttr(r), slog.String("rkey", string(tid))) 175 + } 176 + return batchResult{ 177 + SuccessCount: len(batch.Records), 178 + Duration: time.Since(start), 131 179 } 180 + } 132 181 133 - if opts.DryRun { 134 - for _, r := range batch { 135 - tid := syntax.NewTIDFromTime(r.PlayedTime.Time, 0) 136 - slog.Info("would publish record (dry run)", trackAttr(r), slog.String("rkey", string(tid))) 182 + err := failsafe.With(DefaultRetryPolicy).WithContext(ctx).Run(func() error { 183 + return PublishBatch(ctx, processor.Client, processor.DID, batch.Records, processor.Storage, processor.ClientAgent) 184 + }) 185 + if err != nil { 186 + slog.Error("batch failed after retries", 187 + ErrorAttr(err), 188 + slog.Int("count", len(batch.Records))) 189 + 190 + if processor.Storage != nil { 191 + if markErr := processor.Storage.MarkFailed(processor.DID, batch.Keys, err.Error()); markErr != nil { 192 + slog.Error("failed to mark records as failed", ErrorAttr(markErr)) 137 193 } 138 - totalSuccess += len(batch) 139 - tracker.Increment(len(batch)) 140 - slog.Debug("batch dry run completed", 141 - slog.Int("count", len(batch)), 142 - slog.Int("completed", tracker.Completed)) 143 - batch = batch[:0] 144 - batchKeys = batchKeys[:0] 145 - return nil 194 + } 195 + 196 + return batchResult{ 197 + ErrorCount: len(batch.Records), 198 + Duration: time.Since(start), 199 + Errors: []error{err}, 146 200 } 201 + } 147 202 148 - slog.Debug("processing batch", 149 - slog.Int("count", len(batch)), 150 - slog.Int("completed", tracker.Completed), 151 - slog.Int("total", tracker.Total)) 203 + return batchResult{ 204 + SuccessCount: len(batch.Records), 205 + Duration: time.Since(start), 206 + } 207 + } 152 208 153 - did := client.DID() 154 - retryPolicy := retrypolicy.NewBuilder[any](). 155 - WithMaxRetries(10). 156 - WithBackoff(retryDelay, 5*time.Minute). 157 - HandleIf(func(_ any, err error) bool { 158 - return isTransientError(err) 159 - }). 160 - OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[any]) { 161 - slog.Warn("batch failed with transient error, retrying", 162 - slog.Int("count", len(batch)), 163 - slog.Duration("retryDelay", e.Delay), 164 - ErrorAttr(e.LastError()), 165 - slog.Int("attempt", e.Attempts())) 166 - }). 167 - Build() 209 + // aggregate combines batch results into final publish result 210 + func aggregate(results []batchResult, startTime time.Time) publishResult { 211 + totalSuccess := 0 212 + totalErrors := 0 168 213 169 - err := failsafe.With(retryPolicy).WithContext(ctx).Run(func() error { 170 - return PublishBatch(ctx, atprotoClient, did, batch, opts.Storage, opts.ClientAgent) 171 - }) 172 - if err != nil { 173 - slog.Error("batch failed after retries", 174 - ErrorAttr(err), 175 - slog.Int("count", len(batch))) 214 + for _, result := range results { 215 + totalSuccess += result.SuccessCount 216 + totalErrors += result.ErrorCount 217 + } 218 + 219 + logResult(totalSuccess, totalErrors, startTime) 220 + return newPublishResult(totalSuccess, totalErrors, totalSuccess+totalErrors, startTime, false) 221 + } 222 + 223 + func Publish(ctx context.Context, client AuthClient, opts PublishOptions) publishResult { 224 + startTime := time.Now() 225 + batchSize := cmp.Or(opts.BatchSize, DefaultBatchSize) 176 226 177 - if opts.Storage != nil { 178 - if markErr := opts.Storage.MarkFailed(did, batchKeys, err.Error()); markErr != nil { 179 - slog.Error("failed to mark records as failed", ErrorAttr(markErr)) 180 - } 181 - } 227 + atprotoClient, err := atproto.BuildClient(client, opts.ATProtoClient) 228 + if err != nil { 229 + return errorResult(startTime) 230 + } 182 231 183 - totalErrors += len(batch) 184 - tracker.IncrementErrors(len(batch)) 232 + batches, err := batchRecords(ctx, opts.Storage, client.DID(), batchSize) 233 + if err != nil { 234 + cancelled := ctx.Err() != nil 235 + return newPublishResult(0, 0, 0, startTime, cancelled) 236 + } 185 237 186 - batch = batch[:0] 187 - batchKeys = batchKeys[:0] 188 - return nil 189 - } 238 + if len(batches) == 0 { 239 + return publishResult{} 240 + } 190 241 191 - totalSuccess += len(batch) 192 - tracker.Increment(len(batch)) 193 - slog.Debug("batch published", 194 - slog.Int("count", len(batch)), 195 - slog.Int("completed", tracker.Completed), 196 - slog.Int("total", tracker.Total)) 242 + slog.Info("starting iterative import", 243 + slog.Int("total_records", countTotalRecords(batches)), 244 + slog.Int("batch_size", batchSize), 245 + slog.Int("daily_write_limit", atproto.WriteLimitDay), 246 + slog.Int("daily_token_limit", atproto.GlobalLimitDay), 247 + slog.String("rate_limit", fmt.Sprintf("1 write per %.1fs", 86400.0/atproto.WriteLimitDay))) 197 248 198 - if tracker.ShouldLog() { 199 - progressLog(tracker.Report()) 200 - } 249 + tracker := NewProgressTracker(countTotalRecords(batches), opts.Limiter) 250 + progressLog := defaultProgressLog(opts.ProgressLog) 201 251 202 - batch = batch[:0] 203 - batchKeys = batchKeys[:0] 204 - return nil 252 + processor := batchProcessor{ 253 + Client: atprotoClient, 254 + Storage: opts.Storage, 255 + DID: client.DID(), 256 + ClientAgent: opts.ClientAgent, 257 + DryRun: opts.DryRun, 205 258 } 206 259 207 - err = opts.Storage.IterateUnpublished(client.DID(), func(key string, rec []byte) error { 260 + var results []batchResult 261 + for _, batch := range batches { 208 262 select { 209 263 case <-ctx.Done(): 210 - return ctx.Err() 264 + return aggregate(results, startTime) 211 265 default: 212 266 } 213 267 214 - var record PlayRecord 215 - if err := json.Unmarshal(rec, &record); err != nil { 216 - slog.Error("malformed record in storage", slog.String("key", key), ErrorAttr(err)) 217 - if opts.Storage != nil { 218 - _ = opts.Storage.MarkFailed(client.DID(), []string{key}, "malformed record") 219 - } 220 - totalErrors++ 221 - tracker.IncrementErrors(1) 222 - return nil 223 - } 268 + result := processBatch(ctx, batch, processor) 269 + results = append(results, result) 224 270 225 - batch = append(batch, record) 226 - batchKeys = append(batchKeys, key) 271 + // Update progress tracking 272 + tracker.Increment(result.SuccessCount + result.ErrorCount) 273 + tracker.IncrementErrors(result.ErrorCount) 227 274 228 - if len(batch) >= batchSize { 229 - if err := processBatch(); err != nil { 230 - return err 231 - } 275 + if tracker.ShouldLog() { 276 + progressLog(tracker.Report()) 232 277 } 233 - return nil 234 - }) 278 + } 235 279 236 - if err == nil && len(batch) > 0 { 237 - err = processBatch() 238 - } 280 + return aggregate(results, startTime) 281 + } 239 282 240 - cancelled := false 241 - if err != nil { 242 - slog.Error("import interrupted", ErrorAttr(err)) 243 - cancelled = true 283 + func countTotalRecords(batches []recordBatch) int { 284 + total := 0 285 + for _, batch := range batches { 286 + total += len(batch.Records) 244 287 } 245 - 246 - logResult(totalSuccess, totalErrors, startTime) 247 - return newPublishResult(totalSuccess, totalErrors, totalRecords, startTime, cancelled) 288 + return total 248 289 } 249 290 250 - func isTransientError(err error) bool { 251 - return atproto.IsTransientError(err) 291 + func errorResult(startTime time.Time) publishResult { 292 + return publishResult{ 293 + SuccessCount: 0, 294 + ErrorCount: 0, 295 + Cancelled: false, 296 + Duration: time.Since(startTime), 297 + TotalRecords: 0, 298 + } 252 299 } 253 300 254 301 func defaultProgressLog(f func(ProgressReport)) func(ProgressReport) { ··· 268 315 } 269 316 } 270 317 271 - func newPublishResult(success, errors, total int, start time.Time, cancelled bool) PublishResult { 272 - return PublishResult{ 318 + func newPublishResult(success, errors, total int, start time.Time, cancelled bool) publishResult { 319 + return publishResult{ 273 320 SuccessCount: success, 274 321 ErrorCount: errors, 275 322 Cancelled: cancelled, ··· 384 431 cursor string 385 432 } 386 433 387 - retryPolicy := retrypolicy.NewBuilder[fetchResult](). 434 + fetchRetryPolicy := retrypolicy.NewBuilder[fetchResult](). 388 435 WithMaxRetries(10). 389 436 WithBackoff(BaseRetryDelay, 5*time.Minute). 390 437 HandleIf(func(_ fetchResult, err error) bool { ··· 405 452 default: 406 453 } 407 454 408 - result, err := failsafe.With(retryPolicy). 455 + result, err := failsafe.With(fetchRetryPolicy). 409 456 WithContext(ctx). 410 457 Get(func() (fetchResult, error) { 411 458 recs, next, err := client.ListRecords(ctx, RecordType, batchSize, cursor)
+614
sync/publish_test.go
··· 1 + package sync 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "errors" 7 + "fmt" 8 + "maps" 9 + "sync" 10 + "testing" 11 + "testing/synctest" 12 + "time" 13 + 14 + "github.com/bluesky-social/indigo/atproto/atclient" 15 + 16 + "tangled.org/karitham.dev/lazuli/atproto" 17 + "tangled.org/karitham.dev/lazuli/cache" 18 + ) 19 + 20 + // Mock Storage 21 + type mockStorage struct { 22 + cache.Storage 23 + unpublished map[string][]byte 24 + published map[string]bool 25 + failed map[string]string 26 + kv map[string]int 27 + mu sync.Mutex 28 + } 29 + 30 + func newMockStorage() *mockStorage { 31 + return &mockStorage{ 32 + unpublished: make(map[string][]byte), 33 + published: make(map[string]bool), 34 + failed: make(map[string]string), 35 + kv: make(map[string]int), 36 + } 37 + } 38 + 39 + func (m *mockStorage) SaveRecords(did string, records map[string][]byte) error { 40 + m.mu.Lock() 41 + defer m.mu.Unlock() 42 + maps.Copy(m.unpublished, records) 43 + return nil 44 + } 45 + 46 + func (m *mockStorage) IterateUnpublished(did string, fn func(key string, rec []byte) error) error { 47 + m.mu.Lock() 48 + keys := make([]string, 0, len(m.unpublished)) 49 + for k := range m.unpublished { 50 + keys = append(keys, k) 51 + } 52 + m.mu.Unlock() 53 + 54 + for _, k := range keys { 55 + m.mu.Lock() 56 + rec, ok := m.unpublished[k] 57 + m.mu.Unlock() 58 + if ok { 59 + if err := fn(k, rec); err != nil { 60 + return err 61 + } 62 + } 63 + } 64 + return nil 65 + } 66 + 67 + func (m *mockStorage) MarkPublished(did string, keys ...string) error { 68 + m.mu.Lock() 69 + defer m.mu.Unlock() 70 + for _, k := range keys { 71 + delete(m.unpublished, k) 72 + m.published[k] = true 73 + } 74 + return nil 75 + } 76 + 77 + func (m *mockStorage) MarkFailed(did string, keys []string, err string) error { 78 + m.mu.Lock() 79 + defer m.mu.Unlock() 80 + for _, k := range keys { 81 + m.failed[k] = err 82 + } 83 + return nil 84 + } 85 + 86 + func (m *mockStorage) Get(key string) (int, error) { 87 + m.mu.Lock() 88 + defer m.mu.Unlock() 89 + return m.kv[key], nil 90 + } 91 + 92 + func (m *mockStorage) IncrBy(key string, n int) (int, error) { 93 + m.mu.Lock() 94 + defer m.mu.Unlock() 95 + m.kv[key] += n 96 + return m.kv[key], nil 97 + } 98 + 99 + // Mock ATProtoClient 100 + type mockATProtoClient struct { 101 + applyWritesFunc func(ctx context.Context, collection string, records []PlayRecord) error 102 + listRecordsFunc func(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[PlayRecord], string, error) 103 + deleteRecordFunc func(ctx context.Context, collection, rkey string) error 104 + } 105 + 106 + func (m *mockATProtoClient) ApplyWrites(ctx context.Context, collection string, records []PlayRecord) error { 107 + if m.applyWritesFunc != nil { 108 + return m.applyWritesFunc(ctx, collection, records) 109 + } 110 + return nil 111 + } 112 + 113 + func (m *mockATProtoClient) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[PlayRecord], string, error) { 114 + if m.listRecordsFunc != nil { 115 + return m.listRecordsFunc(ctx, collection, limit, cursor) 116 + } 117 + return nil, "", nil 118 + } 119 + 120 + func (m *mockATProtoClient) DeleteRecord(ctx context.Context, collection, rkey string) error { 121 + if m.deleteRecordFunc != nil { 122 + return m.deleteRecordFunc(ctx, collection, rkey) 123 + } 124 + return nil 125 + } 126 + 127 + // Mock AuthClient 128 + type mockAuthClient struct { 129 + did string 130 + } 131 + 132 + func (m *mockAuthClient) APIClient() *atclient.APIClient { return nil } 133 + func (m *mockAuthClient) DID() string { return m.did } 134 + 135 + type timeoutError struct{} 136 + 137 + func (e timeoutError) Error() string { return "timeout" } 138 + func (e timeoutError) Timeout() bool { return true } 139 + func (e timeoutError) Temporary() bool { return true } 140 + 141 + type failingStorage struct { 142 + *mockStorage 143 + } 144 + 145 + func newFailingStorage() *failingStorage { 146 + return &failingStorage{ 147 + mockStorage: newMockStorage(), 148 + } 149 + } 150 + 151 + func (s *failingStorage) SaveRecords(did string, records map[string][]byte) error { 152 + return errors.New("failed to save records") 153 + } 154 + 155 + func TestBuildRecordBatches(t *testing.T) { 156 + tests := []struct { 157 + name string 158 + records []PlayRecord 159 + batchSize int 160 + wantBatches int 161 + wantErr bool 162 + ctxCancel bool 163 + }{ 164 + { 165 + name: "empty storage", 166 + records: []PlayRecord{}, 167 + batchSize: 2, 168 + wantBatches: 0, 169 + wantErr: false, 170 + }, 171 + { 172 + name: "single batch", 173 + records: []PlayRecord{ 174 + {TrackName: "Song 1"}, 175 + {TrackName: "Song 2"}, 176 + }, 177 + batchSize: 5, 178 + wantBatches: 1, 179 + wantErr: false, 180 + }, 181 + { 182 + name: "multiple exact batches", 183 + records: []PlayRecord{ 184 + {TrackName: "Song 1"}, 185 + {TrackName: "Song 2"}, 186 + {TrackName: "Song 3"}, 187 + {TrackName: "Song 4"}, 188 + }, 189 + batchSize: 2, 190 + wantBatches: 2, 191 + wantErr: false, 192 + }, 193 + { 194 + name: "partial final batch", 195 + records: []PlayRecord{ 196 + {TrackName: "Song 1"}, 197 + {TrackName: "Song 2"}, 198 + {TrackName: "Song 3"}, 199 + }, 200 + batchSize: 2, 201 + wantBatches: 2, 202 + wantErr: false, 203 + }, 204 + { 205 + name: "context cancelled", 206 + records: []PlayRecord{ 207 + {TrackName: "Song 1"}, 208 + {TrackName: "Song 2"}, 209 + }, 210 + batchSize: 2, 211 + wantBatches: 0, 212 + wantErr: true, 213 + ctxCancel: true, 214 + }, 215 + { 216 + name: "malformed records skipped", 217 + records: []PlayRecord{ 218 + {TrackName: "Song 1"}, 219 + {TrackName: "Song 2"}, 220 + }, 221 + batchSize: 2, 222 + wantBatches: 1, 223 + wantErr: false, 224 + }, 225 + } 226 + 227 + for _, tt := range tests { 228 + t.Run(tt.name, func(t *testing.T) { 229 + t.Parallel() 230 + 231 + ctx := context.Background() 232 + if tt.ctxCancel { 233 + var cancel context.CancelFunc 234 + ctx, cancel = context.WithCancel(ctx) 235 + cancel() 236 + } 237 + 238 + storage := newMockStorage() 239 + did := "did:example:123" 240 + 241 + // Add records to storage 242 + for i, record := range tt.records { 243 + data, _ := json.Marshal(record) 244 + if tt.name == "malformed records skipped" && i == 1 { 245 + data = []byte("invalid json") 246 + } 247 + storage.unpublished[fmt.Sprintf("key%d", i)] = data 248 + } 249 + 250 + batches, err := batchRecords(ctx, storage, did, tt.batchSize) 251 + 252 + if (err != nil) != tt.wantErr { 253 + t.Errorf("BuildRecordBatches() error = %v, wantErr %v", err, tt.wantErr) 254 + return 255 + } 256 + 257 + if len(batches) != tt.wantBatches { 258 + t.Errorf("BuildRecordBatches() batches = %d, want %d", len(batches), tt.wantBatches) 259 + } 260 + 261 + if !tt.wantErr { 262 + totalRecords := 0 263 + for _, batch := range batches { 264 + totalRecords += len(batch.Records) 265 + } 266 + 267 + expectedRecords := len(tt.records) 268 + if tt.name == "malformed records skipped" { 269 + expectedRecords = len(tt.records) - 1 // Skip malformed record 270 + } 271 + 272 + if totalRecords != expectedRecords { 273 + t.Errorf("BuildRecordBatches() total records = %d, want %d", totalRecords, expectedRecords) 274 + } 275 + } 276 + }) 277 + } 278 + } 279 + 280 + func TestProcessBatch(t *testing.T) { 281 + tests := []struct { 282 + name string 283 + batch recordBatch 284 + processor batchProcessor 285 + wantSuccess int 286 + wantError int 287 + wantErr bool 288 + setupClient func() *mockATProtoClient 289 + setupStorage func() cache.Storage 290 + }{ 291 + { 292 + name: "empty batch", 293 + batch: recordBatch{ 294 + Records: []PlayRecord{}, 295 + Keys: []string{}, 296 + }, 297 + processor: batchProcessor{ 298 + Client: &mockATProtoClient{}, 299 + Storage: newMockStorage(), 300 + }, 301 + wantSuccess: 0, 302 + wantError: 0, 303 + wantErr: false, 304 + }, 305 + { 306 + name: "successful batch", 307 + batch: recordBatch{ 308 + Records: []PlayRecord{{TrackName: "Song 1"}, {TrackName: "Song 2"}}, 309 + Keys: []string{"key1", "key2"}, 310 + }, 311 + processor: batchProcessor{ 312 + Client: &mockATProtoClient{}, 313 + Storage: newMockStorage(), 314 + DID: "did:example:123", 315 + ClientAgent: "test-agent", 316 + }, 317 + wantSuccess: 2, 318 + wantError: 0, 319 + wantErr: false, 320 + }, 321 + { 322 + name: "dry run batch", 323 + batch: recordBatch{ 324 + Records: []PlayRecord{{TrackName: "Song 1"}, {TrackName: "Song 2"}}, 325 + Keys: []string{"key1", "key2"}, 326 + }, 327 + processor: batchProcessor{ 328 + Client: &mockATProtoClient{}, 329 + Storage: newMockStorage(), 330 + DID: "did:example:123", 331 + ClientAgent: "test-agent", 332 + DryRun: true, 333 + }, 334 + wantSuccess: 2, 335 + wantError: 0, 336 + wantErr: false, 337 + }, 338 + { 339 + name: "batch with apply writes failure", 340 + batch: recordBatch{ 341 + Records: []PlayRecord{{TrackName: "Song 1"}}, 342 + Keys: []string{"key1"}, 343 + }, 344 + processor: batchProcessor{ 345 + Client: func() *mockATProtoClient { 346 + return &mockATProtoClient{ 347 + applyWritesFunc: func(ctx context.Context, collection string, records []PlayRecord) error { 348 + return errors.New("apply writes failed") 349 + }, 350 + } 351 + }(), 352 + Storage: newMockStorage(), 353 + DID: "did:example:123", 354 + ClientAgent: "test-agent", 355 + }, 356 + wantSuccess: 0, 357 + wantError: 1, 358 + wantErr: false, 359 + }, 360 + { 361 + name: "batch with storage failure", 362 + batch: recordBatch{ 363 + Records: []PlayRecord{{TrackName: "Song 1"}}, 364 + Keys: []string{"key1"}, 365 + }, 366 + processor: batchProcessor{ 367 + Client: &mockATProtoClient{}, 368 + Storage: newFailingStorage(), 369 + DID: "did:example:123", 370 + ClientAgent: "test-agent", 371 + }, 372 + wantSuccess: 0, 373 + wantError: 1, 374 + wantErr: false, 375 + }, 376 + } 377 + 378 + for _, tt := range tests { 379 + t.Run(tt.name, func(t *testing.T) { 380 + t.Parallel() 381 + 382 + ctx := context.Background() 383 + result := processBatch(ctx, tt.batch, tt.processor) 384 + 385 + if result.SuccessCount != tt.wantSuccess { 386 + t.Errorf("ProcessBatch() success count = %d, want %d", result.SuccessCount, tt.wantSuccess) 387 + } 388 + 389 + if result.ErrorCount != tt.wantError { 390 + t.Errorf("ProcessBatch() error count = %d, want %d", result.ErrorCount, tt.wantError) 391 + } 392 + 393 + if tt.wantError > 0 && len(result.Errors) == 0 { 394 + t.Error("ProcessBatch() expected errors but got none") 395 + } 396 + }) 397 + } 398 + } 399 + 400 + func TestAggregateResults(t *testing.T) { 401 + tests := []struct { 402 + name string 403 + results []batchResult 404 + startTime time.Time 405 + wantSuccess int 406 + wantErrors int 407 + wantTotal int 408 + wantDuration bool 409 + wantRatePerMin bool 410 + }{ 411 + { 412 + name: "empty results", 413 + results: []batchResult{}, 414 + wantSuccess: 0, 415 + wantErrors: 0, 416 + wantTotal: 0, 417 + }, 418 + { 419 + name: "single successful result", 420 + results: []batchResult{ 421 + {SuccessCount: 5, ErrorCount: 0, Duration: time.Second}, 422 + }, 423 + wantSuccess: 5, 424 + wantErrors: 0, 425 + wantTotal: 5, 426 + wantDuration: true, 427 + wantRatePerMin: true, 428 + }, 429 + { 430 + name: "multiple mixed results", 431 + results: []batchResult{ 432 + {SuccessCount: 3, ErrorCount: 1, Duration: time.Second}, 433 + {SuccessCount: 2, ErrorCount: 0, Duration: time.Second}, 434 + {SuccessCount: 0, ErrorCount: 2, Duration: time.Second}, 435 + }, 436 + wantSuccess: 5, 437 + wantErrors: 3, 438 + wantTotal: 8, 439 + wantDuration: true, 440 + wantRatePerMin: true, 441 + }, 442 + { 443 + name: "all errors", 444 + results: []batchResult{ 445 + {SuccessCount: 0, ErrorCount: 3, Duration: time.Second}, 446 + {SuccessCount: 0, ErrorCount: 2, Duration: time.Second}, 447 + }, 448 + wantSuccess: 0, 449 + wantErrors: 5, 450 + wantTotal: 5, 451 + wantDuration: true, 452 + wantRatePerMin: false, // 0 success rate 453 + }, 454 + } 455 + 456 + for _, tt := range tests { 457 + t.Run(tt.name, func(t *testing.T) { 458 + t.Parallel() 459 + 460 + startTime := time.Now() 461 + if !tt.startTime.IsZero() { 462 + startTime = tt.startTime 463 + } 464 + 465 + result := aggregate(tt.results, startTime) 466 + 467 + if result.SuccessCount != tt.wantSuccess { 468 + t.Errorf("AggregateResults() success = %d, want %d", result.SuccessCount, tt.wantSuccess) 469 + } 470 + 471 + if result.ErrorCount != tt.wantErrors { 472 + t.Errorf("AggregateResults() errors = %d, want %d", result.ErrorCount, tt.wantErrors) 473 + } 474 + 475 + if result.TotalRecords != tt.wantTotal { 476 + t.Errorf("AggregateResults() total = %d, want %d", result.TotalRecords, tt.wantTotal) 477 + } 478 + 479 + if tt.wantDuration && result.Duration == 0 { 480 + t.Error("AggregateResults() expected non-zero duration") 481 + } 482 + 483 + if tt.wantRatePerMin && result.RecordsPerMinute == 0 && tt.wantSuccess > 0 { 484 + t.Error("AggregateResults() expected non-zero rate per minute") 485 + } 486 + }) 487 + } 488 + } 489 + 490 + func TestPublish(t *testing.T) { 491 + tests := []struct { 492 + name string 493 + opts PublishOptions 494 + records []PlayRecord 495 + setupClient func() *mockATProtoClient 496 + wantSuccess int 497 + wantErrors int 498 + wantCancelled bool 499 + }{ 500 + { 501 + name: "successful publish", 502 + opts: PublishOptions{ 503 + BatchSize: 2, 504 + Storage: newMockStorage(), 505 + ClientAgent: "test-agent", 506 + }, 507 + records: []PlayRecord{ 508 + {TrackName: "Song 1"}, 509 + {TrackName: "Song 2"}, 510 + }, 511 + wantSuccess: 2, 512 + wantErrors: 0, 513 + }, 514 + { 515 + name: "dry run publish", 516 + opts: PublishOptions{ 517 + BatchSize: 2, 518 + DryRun: true, 519 + Storage: newMockStorage(), 520 + ClientAgent: "test-agent", 521 + }, 522 + records: []PlayRecord{ 523 + {TrackName: "Song 1"}, 524 + {TrackName: "Song 2"}, 525 + }, 526 + wantSuccess: 2, 527 + wantErrors: 0, 528 + }, 529 + { 530 + name: "publish with client errors", 531 + opts: PublishOptions{ 532 + BatchSize: 1, 533 + Storage: newMockStorage(), 534 + ClientAgent: "test-agent", 535 + }, 536 + records: []PlayRecord{ 537 + {TrackName: "Song 1"}, 538 + {TrackName: "Song 2"}, 539 + }, 540 + setupClient: func() *mockATProtoClient { 541 + return &mockATProtoClient{ 542 + applyWritesFunc: func(ctx context.Context, collection string, records []PlayRecord) error { 543 + return &atclient.APIError{StatusCode: 400} // Non-transient error 544 + }, 545 + } 546 + }, 547 + wantSuccess: 0, 548 + wantErrors: 2, 549 + }, 550 + } 551 + 552 + for _, tt := range tests { 553 + t.Run(tt.name, func(t *testing.T) { 554 + synctest.Test(t, func(t *testing.T) { 555 + ctx := context.Background() 556 + did := "did:example:123" 557 + 558 + storage := tt.opts.Storage.(*mockStorage) 559 + // Add records to storage 560 + for i, record := range tt.records { 561 + data, _ := json.Marshal(record) 562 + storage.unpublished[fmt.Sprintf("key%d", i)] = data 563 + } 564 + 565 + client := &mockAuthClient{did: did} 566 + if tt.setupClient != nil { 567 + tt.opts.ATProtoClient = tt.setupClient() 568 + } else { 569 + tt.opts.ATProtoClient = &mockATProtoClient{} 570 + } 571 + 572 + result := Publish(ctx, client, tt.opts) 573 + 574 + if result.SuccessCount != tt.wantSuccess { 575 + t.Errorf("Publish() success = %d, want %d", result.SuccessCount, tt.wantSuccess) 576 + } 577 + 578 + if result.ErrorCount != tt.wantErrors { 579 + t.Errorf("Publish() errors = %d, want %d", result.ErrorCount, tt.wantErrors) 580 + } 581 + 582 + if result.Cancelled != tt.wantCancelled { 583 + t.Errorf("Publish() cancelled = %v, want %v", result.Cancelled, tt.wantCancelled) 584 + } 585 + }) 586 + }) 587 + } 588 + } 589 + 590 + func TestIsTransientError(t *testing.T) { 591 + tests := []struct { 592 + name string 593 + err error 594 + want bool 595 + }{ 596 + {"nil", nil, false}, 597 + {"generic error", errors.New("some error"), false}, 598 + {"API 400", &atclient.APIError{StatusCode: 400}, false}, 599 + {"API 429", &atclient.APIError{StatusCode: 429}, true}, 600 + {"API 500", &atclient.APIError{StatusCode: 500}, true}, 601 + {"API 503", &atclient.APIError{StatusCode: 503}, true}, 602 + {"net timeout", timeoutError{}, true}, 603 + {"net non-timeout", errors.New("network is down"), false}, 604 + } 605 + 606 + for _, tt := range tests { 607 + t.Run(tt.name, func(t *testing.T) { 608 + t.Parallel() 609 + if got := IsTransientError(tt.err); got != tt.want { 610 + t.Errorf("IsTransientError() = %v, want %v", got, tt.want) 611 + } 612 + }) 613 + } 614 + }