like malachite (atproto-lastfm-importer) but in go and bluer
go spotify tealfm lastfm atproto
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor: atproto package

karitham 55ad5d37 941398d7

+3823 -1370
+352
atproto/client.go
··· 1 + package atproto 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "fmt" 7 + "net/http" 8 + "net/url" 9 + "sync" 10 + "time" 11 + 12 + "github.com/bluesky-social/indigo/atproto/atclient" 13 + "github.com/bluesky-social/indigo/atproto/syntax" 14 + ) 15 + 16 + const ( 17 + DefaultRateLimitPercent = 0.85 18 + DefaultResolverURL = "https://slingshot.microcosm.blue/xrpc/com.bad-example.identity.resolveMiniDoc" 19 + ) 20 + 21 + type ClientOptions struct { 22 + UserAgent string 23 + ResolverURL string 24 + HTTPClient *http.Client 25 + RateLimitPercent float32 26 + } 27 + 28 + type resolvedIdentity struct { 29 + DID string 30 + Handle string 31 + PDS string 32 + SigningKey string 33 + } 34 + 35 + type Client struct { 36 + client *atclient.APIClient 37 + resolvedIdentity resolvedIdentity 38 + mu sync.Mutex 39 + } 40 + 41 + func WithUserAgent(ua string) func(*ClientOptions) { 42 + return func(o *ClientOptions) { 43 + o.UserAgent = ua 44 + } 45 + } 46 + 47 + func WithResolverURL(url string) func(*ClientOptions) { 48 + return func(o *ClientOptions) { 49 + o.ResolverURL = url 50 + } 51 + } 52 + 53 + func WithHTTPClient(hc *http.Client) func(*ClientOptions) { 54 + return func(o *ClientOptions) { 55 + o.HTTPClient = hc 56 + } 57 + } 58 + 59 + func WithRateLimitPercent(p float32) func(*ClientOptions) { 60 + return func(o *ClientOptions) { 61 + o.RateLimitPercent = p 62 + } 63 + } 64 + 65 + func NewClientOptions() *ClientOptions { 66 + return &ClientOptions{ 67 + ResolverURL: DefaultResolverURL, 68 + RateLimitPercent: DefaultRateLimitPercent, 69 + HTTPClient: &http.Client{ 70 + Timeout: 10 * time.Second, 71 + }, 72 + } 73 + } 74 + 75 + func ResolveMiniDoc(ctx context.Context, identifier string, opts *ClientOptions) (did, pds, signingKey string, err error) { 76 + resolver := DefaultResolverURL 77 + if opts != nil && opts.ResolverURL != "" { 78 + resolver = opts.ResolverURL 79 + } 80 + 81 + parsedURL, err := url.Parse(resolver) 82 + if err != nil { 83 + return "", "", "", fmt.Errorf("invalid resolver URL: %w", err) 84 + } 85 + 86 + query := parsedURL.Query() 87 + query.Set("identifier", identifier) 88 + parsedURL.RawQuery = query.Encode() 89 + 90 + req, err := http.NewRequestWithContext(ctx, "GET", parsedURL.String(), nil) 91 + if err != nil { 92 + return "", "", "", fmt.Errorf("failed to create request: %w", err) 93 + } 94 + 95 + if opts != nil && opts.UserAgent != "" { 96 + req.Header.Set("User-Agent", opts.UserAgent) 97 + } 98 + 99 + httpClient := http.DefaultClient 100 + if opts != nil && opts.HTTPClient != nil { 101 + httpClient = opts.HTTPClient 102 + } 103 + 104 + resp, err := httpClient.Do(req) 105 + if err != nil { 106 + return "", "", "", fmt.Errorf("failed to resolve mini doc: %w", err) 107 + } 108 + defer resp.Body.Close() 109 + 110 + if resp.StatusCode != http.StatusOK { 111 + return "", "", "", fmt.Errorf("mini doc resolution failed with status: %d", resp.StatusCode) 112 + } 113 + 114 + var result struct { 115 + DID string `json:"did"` 116 + PDS string `json:"pds"` 117 + SigningKey string `json:"signing_key"` 118 + } 119 + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 120 + return "", "", "", fmt.Errorf("failed to decode mini doc response: %w", err) 121 + } 122 + 123 + if result.DID == "" { 124 + return "", "", "", fmt.Errorf("resolved mini doc missing DID") 125 + } 126 + 127 + return result.DID, result.PDS, result.SigningKey, nil 128 + } 129 + 130 + type FixedPasswordAuth struct { 131 + *atclient.PasswordAuth 132 + lk sync.RWMutex 133 + RefreshCallback func(ctx context.Context, session atclient.PasswordSessionData) 134 + } 135 + 136 + func (a *FixedPasswordAuth) DoWithAuth(c *http.Client, req *http.Request, endpoint syntax.NSID) (*http.Response, error) { 137 + accessToken, refreshToken := a.GetTokens() 138 + req.Header.Set("Authorization", "Bearer "+accessToken) 139 + resp, err := c.Do(req) 140 + if err != nil { 141 + return nil, err 142 + } 143 + 144 + if resp.StatusCode != http.StatusBadRequest { 145 + return resp, nil 146 + } 147 + 148 + if !hasJSONContent(resp.Header) { 149 + return resp, nil 150 + } 151 + 152 + defer resp.Body.Close() 153 + var eb atclient.ErrorBody 154 + if err := json.NewDecoder(resp.Body).Decode(&eb); err != nil { 155 + return nil, &atclient.APIError{StatusCode: resp.StatusCode} 156 + } 157 + if eb.Name != "ExpiredToken" { 158 + return nil, eb.APIError(resp.StatusCode) 159 + } 160 + 161 + if err := a.Refresh(req.Context(), c, refreshToken); err != nil { 162 + return nil, err 163 + } 164 + 165 + retry := req.Clone(req.Context()) 166 + if req.GetBody != nil { 167 + retryBody, err := req.GetBody() 168 + if err != nil { 169 + return nil, fmt.Errorf("API request retry GetBody failed: %w", err) 170 + } 171 + retry.Body = retryBody 172 + } 173 + 174 + accessToken, _ = a.GetTokens() 175 + retry.Header.Set("Authorization", "Bearer "+accessToken) 176 + return c.Do(retry) 177 + } 178 + 179 + func hasJSONContent(header http.Header) bool { 180 + return len(header.Get("Content-Type")) > 0 && header.Get("Content-Type")[0:19] == "application/json" 181 + } 182 + 183 + func (a *FixedPasswordAuth) Refresh(ctx context.Context, c *http.Client, priorRefreshToken string) error { 184 + a.lk.Lock() 185 + defer a.lk.Unlock() 186 + 187 + if priorRefreshToken != "" && priorRefreshToken != a.Session.RefreshToken { 188 + return nil 189 + } 190 + 191 + u := a.Session.Host + "/xrpc/com.atproto.server.refreshSession" 192 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, nil) 193 + if err != nil { 194 + return err 195 + } 196 + req.Header.Set("User-Agent", "indigo-sdk") 197 + req.Header.Set("Authorization", "Bearer "+a.Session.RefreshToken) 198 + 199 + resp, err := c.Do(req) 200 + if err != nil { 201 + return err 202 + } 203 + defer resp.Body.Close() 204 + 205 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { 206 + var eb atclient.ErrorBody 207 + if err := json.NewDecoder(resp.Body).Decode(&eb); err != nil { 208 + return &atclient.APIError{StatusCode: resp.StatusCode} 209 + } 210 + return eb.APIError(resp.StatusCode) 211 + } 212 + 213 + var out struct { 214 + AccessJwt string `json:"accessJwt"` 215 + RefreshJwt string `json:"refreshJwt"` 216 + } 217 + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { 218 + return err 219 + } 220 + 221 + a.Session.AccessToken = out.AccessJwt 222 + a.Session.RefreshToken = out.RefreshJwt 223 + 224 + if a.RefreshCallback != nil { 225 + snapshot := a.Session.Clone() 226 + a.RefreshCallback(ctx, snapshot) 227 + } 228 + 229 + return nil 230 + } 231 + 232 + func ResolveIdentity(ctx context.Context, handle string, opts *ClientOptions) (resolvedIdentity, error) { 233 + if handle == "" { 234 + return resolvedIdentity{}, fmt.Errorf("handle cannot be empty") 235 + } 236 + 237 + did, pds, signingKey, err := ResolveMiniDoc(ctx, handle, opts) 238 + if err != nil { 239 + return resolvedIdentity{}, fmt.Errorf("failed to resolve identity: %w", err) 240 + } 241 + 242 + if did == "" { 243 + return resolvedIdentity{}, fmt.Errorf("resolved identity missing DID") 244 + } 245 + 246 + if signingKey == "" { 247 + return resolvedIdentity{}, fmt.Errorf("resolved identity missing signing key") 248 + } 249 + 250 + return resolvedIdentity{ 251 + DID: did, 252 + Handle: handle, 253 + PDS: pds, 254 + SigningKey: signingKey, 255 + }, nil 256 + } 257 + 258 + func NewClient(ctx context.Context, handle, password string, opts ...func(*ClientOptions)) (*Client, error) { 259 + options := NewClientOptions() 260 + for _, opt := range opts { 261 + opt(options) 262 + } 263 + 264 + identity, err := ResolveIdentity(ctx, handle, options) 265 + if err != nil { 266 + return nil, fmt.Errorf("failed to resolve identity: %w", err) 267 + } 268 + 269 + pdsURL := identity.PDS 270 + if pdsURL == "" { 271 + pdsURL = "https://bsky.social" 272 + } 273 + 274 + client, err := atclient.LoginWithPasswordHost(ctx, pdsURL, handle, password, "", nil) 275 + if err != nil { 276 + return nil, fmt.Errorf("login failed: %w", err) 277 + } 278 + 279 + if pa, ok := client.Auth.(*atclient.PasswordAuth); ok { 280 + client.Auth = &FixedPasswordAuth{PasswordAuth: pa} 281 + } 282 + 283 + return &Client{ 284 + client: client, 285 + resolvedIdentity: identity, 286 + }, nil 287 + } 288 + 289 + func (c *Client) DID() string { 290 + c.mu.Lock() 291 + defer c.mu.Unlock() 292 + return c.resolvedIdentity.DID 293 + } 294 + 295 + func (c *Client) PDS() string { 296 + c.mu.Lock() 297 + defer c.mu.Unlock() 298 + return c.resolvedIdentity.PDS 299 + } 300 + 301 + func (c *Client) Handle() string { 302 + c.mu.Lock() 303 + defer c.mu.Unlock() 304 + return c.resolvedIdentity.Handle 305 + } 306 + 307 + func (c *Client) SigningKey() string { 308 + c.mu.Lock() 309 + defer c.mu.Unlock() 310 + return c.resolvedIdentity.SigningKey 311 + } 312 + 313 + func (c *Client) APIClient() *atclient.APIClient { 314 + c.mu.Lock() 315 + defer c.mu.Unlock() 316 + return c.client 317 + } 318 + 319 + func (c *Client) Close() error { 320 + c.mu.Lock() 321 + defer c.mu.Unlock() 322 + if c.client != nil && c.client.Auth != nil { 323 + if logout, ok := c.client.Auth.(*atclient.PasswordAuth); ok { 324 + return logout.Logout(context.Background(), c.client.Client) 325 + } 326 + } 327 + return nil 328 + } 329 + 330 + func (c *Client) HasClient() bool { 331 + c.mu.Lock() 332 + defer c.mu.Unlock() 333 + return c.client != nil 334 + } 335 + 336 + type AuthClient interface { 337 + APIClient() *atclient.APIClient 338 + DID() string 339 + } 340 + 341 + func BuildClient[T any](client AuthClient, customClient RepoClient[T]) (RepoClient[T], error) { 342 + if customClient != nil { 343 + return customClient, nil 344 + } 345 + 346 + apiClient := client.APIClient() 347 + if apiClient == nil { 348 + return nil, fmt.Errorf("API client is nil") 349 + } 350 + 351 + return NewRateClient[T](apiClient, client.DID(), nil), nil 352 + }
+706
atproto/client_test.go
··· 1 + package atproto 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "net/http" 7 + "net/http/httptest" 8 + "testing" 9 + "time" 10 + 11 + "github.com/bluesky-social/indigo/atproto/atclient" 12 + ) 13 + 14 + func TestResolveMiniDoc(t *testing.T) { 15 + t.Parallel() 16 + 17 + tests := []struct { 18 + name string 19 + handle string 20 + opts *ClientOptions 21 + responseStatus int 22 + responseBody any 23 + wantErr bool 24 + wantDID string 25 + wantPDS string 26 + }{ 27 + { 28 + name: "success with default options", 29 + handle: "test.user", 30 + opts: NewClientOptions(), 31 + responseStatus: http.StatusOK, 32 + responseBody: map[string]any{ 33 + "did": "did:plc:z72iitness", 34 + "pds": "https://pds.example.com", 35 + "signing_key": "-----BEGIN PUBLIC KEY-----\ntest\n-----END PUBLIC KEY-----", 36 + }, 37 + wantErr: false, 38 + wantDID: "did:plc:z72iitness", 39 + wantPDS: "https://pds.example.com", 40 + }, 41 + { 42 + name: "empty DID", 43 + handle: "test.user", 44 + opts: NewClientOptions(), 45 + responseStatus: http.StatusOK, 46 + responseBody: map[string]any{ 47 + "did": "", 48 + "pds": "https://pds.example.com", 49 + "signing_key": "key", 50 + }, 51 + wantErr: true, 52 + }, 53 + { 54 + name: "non-200 status", 55 + handle: "test.user", 56 + opts: NewClientOptions(), 57 + responseStatus: http.StatusInternalServerError, 58 + responseBody: nil, 59 + wantErr: true, 60 + }, 61 + { 62 + name: "invalid JSON", 63 + handle: "test.user", 64 + opts: NewClientOptions(), 65 + responseStatus: http.StatusOK, 66 + responseBody: "not json", 67 + wantErr: true, 68 + }, 69 + } 70 + 71 + for _, tt := range tests { 72 + t.Run(tt.name, func(t *testing.T) { 73 + t.Parallel() 74 + 75 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 76 + if tt.opts != nil && tt.opts.ResolverURL != "" && tt.opts.ResolverURL != DefaultResolverURL { 77 + if r.URL.String() == DefaultResolverURL { 78 + t.Errorf("expected custom resolver URL to be used") 79 + } 80 + } 81 + 82 + w.Header().Set("Content-Type", "application/json") 83 + 84 + switch v := tt.responseBody.(type) { 85 + case string: 86 + w.Write([]byte(v)) 87 + case map[string]any: 88 + json.NewEncoder(w).Encode(v) 89 + default: 90 + w.WriteHeader(tt.responseStatus) 91 + return 92 + } 93 + 94 + if tt.responseStatus != http.StatusOK { 95 + w.WriteHeader(tt.responseStatus) 96 + } 97 + })) 98 + defer server.Close() 99 + 100 + ctx := context.Background() 101 + 102 + opts := tt.opts 103 + if opts == nil { 104 + opts = NewClientOptions() 105 + } 106 + if opts.ResolverURL == DefaultResolverURL { 107 + opts.ResolverURL = server.URL 108 + } 109 + 110 + did, pds, _, err := ResolveMiniDoc(ctx, tt.handle, opts) 111 + 112 + if tt.wantErr { 113 + if err == nil { 114 + t.Error("expected error, got nil") 115 + } 116 + return 117 + } 118 + 119 + if err != nil { 120 + t.Fatalf("ResolveMiniDoc failed: %v", err) 121 + } 122 + 123 + if tt.wantDID != "" && did != tt.wantDID { 124 + t.Errorf("DID = %s, want %s", did, tt.wantDID) 125 + } 126 + if tt.wantPDS != "" && pds != tt.wantPDS { 127 + t.Errorf("PDS = %s, want %s", pds, tt.wantPDS) 128 + } 129 + }) 130 + } 131 + } 132 + 133 + func TestResolveMiniDoc_WithCustomResolver(t *testing.T) { 134 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 135 + w.Header().Set("Content-Type", "application/json") 136 + json.NewEncoder(w).Encode(map[string]any{ 137 + "did": "did:plc:custom", 138 + "pds": "https://custom.pds.example.com", 139 + "signing_key": "key", 140 + }) 141 + })) 142 + defer server.Close() 143 + 144 + opts := NewClientOptions() 145 + opts.ResolverURL = server.URL 146 + 147 + ctx := context.Background() 148 + did, pds, _, err := ResolveMiniDoc(ctx, "test.user", opts) 149 + if err != nil { 150 + t.Fatalf("ResolveMiniDoc failed: %v", err) 151 + } 152 + 153 + if did != "did:plc:custom" { 154 + t.Errorf("DID = %s, want did:plc:custom", did) 155 + } 156 + if pds != "https://custom.pds.example.com" { 157 + t.Errorf("PDS = %s, want https://custom.pds.example.com", pds) 158 + } 159 + } 160 + 161 + func TestResolveMiniDoc_WithUserAgent(t *testing.T) { 162 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 163 + if r.Header.Get("User-Agent") != "test-agent/1.0" { 164 + t.Errorf("expected User-Agent 'test-agent/1.0', got %s", r.Header.Get("User-Agent")) 165 + } 166 + w.Header().Set("Content-Type", "application/json") 167 + json.NewEncoder(w).Encode(map[string]any{ 168 + "did": "did:plc:agent", 169 + "pds": "https://pds.example.com", 170 + "signing_key": "key", 171 + }) 172 + })) 173 + defer server.Close() 174 + 175 + opts := NewClientOptions() 176 + opts.ResolverURL = server.URL 177 + opts.UserAgent = "test-agent/1.0" 178 + 179 + ctx := context.Background() 180 + did, _, _, err := ResolveMiniDoc(ctx, "test.user", opts) 181 + if err != nil { 182 + t.Fatalf("ResolveMiniDoc failed: %v", err) 183 + } 184 + 185 + if did != "did:plc:agent" { 186 + t.Errorf("DID = %s, want did:plc:agent", did) 187 + } 188 + } 189 + 190 + func TestResolveMiniDoc_ContextCancelled(t *testing.T) { 191 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 192 + select { 193 + case <-time.After(100 * time.Millisecond): 194 + w.WriteHeader(http.StatusOK) 195 + case <-r.Context().Done(): 196 + return 197 + } 198 + })) 199 + defer server.Close() 200 + 201 + ctx, cancel := context.WithCancel(context.Background()) 202 + cancel() 203 + 204 + opts := NewClientOptions() 205 + opts.ResolverURL = server.URL 206 + opts.HTTPClient = server.Client() 207 + 208 + _, _, _, err := ResolveMiniDoc(ctx, "test.user", opts) 209 + 210 + if err == nil { 211 + t.Error("expected error for cancelled context") 212 + } 213 + } 214 + 215 + func TestResolveMiniDoc_InvalidURL(t *testing.T) { 216 + ctx := context.Background() 217 + 218 + opts := NewClientOptions() 219 + opts.ResolverURL = "://invalid-url" 220 + 221 + _, _, _, err := ResolveMiniDoc(ctx, "test.user", opts) 222 + 223 + if err == nil { 224 + t.Error("expected error for invalid URL") 225 + } 226 + } 227 + 228 + func TestResolveIdentity(t *testing.T) { 229 + t.Parallel() 230 + 231 + tests := []struct { 232 + name string 233 + handle string 234 + body map[string]any 235 + opts *ClientOptions 236 + wantErr bool 237 + checkFunc func(t *testing.T, identity resolvedIdentity, err error) 238 + }{ 239 + { 240 + name: "empty handle", 241 + handle: "", 242 + opts: NewClientOptions(), 243 + wantErr: true, 244 + }, 245 + { 246 + name: "missing signing key", 247 + handle: "test.user", 248 + body: map[string]any{ 249 + "did": "did:plc:z72iitness", 250 + "pds": "https://pds.example.com", 251 + "signing_key": "", 252 + }, 253 + opts: NewClientOptions(), 254 + wantErr: true, 255 + }, 256 + { 257 + name: "with custom resolver", 258 + handle: "custom.handle", 259 + body: map[string]any{ 260 + "did": "did:plc:custom123", 261 + "pds": "https://custom.pds.example.com", 262 + "signing_key": "-----BEGIN PUBLIC KEY-----\ncustom\n-----END PUBLIC KEY-----", 263 + }, 264 + opts: NewClientOptions(), 265 + checkFunc: func(t *testing.T, identity resolvedIdentity, err error) { 266 + if identity.DID != "did:plc:custom123" { 267 + t.Errorf("DID = %s, want did:plc:custom123", identity.DID) 268 + } 269 + if identity.PDS != "https://custom.pds.example.com" { 270 + t.Errorf("PDS = %s, want https://custom.pds.example.com", identity.PDS) 271 + } 272 + if identity.Handle != "custom.handle" { 273 + t.Errorf("Handle = %s, want custom.handle", identity.Handle) 274 + } 275 + }, 276 + }, 277 + } 278 + 279 + for _, tt := range tests { 280 + t.Run(tt.name, func(t *testing.T) { 281 + t.Parallel() 282 + 283 + if tt.handle == "" { 284 + ctx := context.Background() 285 + _, err := ResolveIdentity(ctx, "", tt.opts) 286 + if err == nil { 287 + t.Error("expected error for empty handle") 288 + } 289 + return 290 + } 291 + 292 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 293 + w.Header().Set("Content-Type", "application/json") 294 + json.NewEncoder(w).Encode(tt.body) 295 + })) 296 + defer server.Close() 297 + 298 + ctx := context.Background() 299 + 300 + opts := tt.opts 301 + if opts == nil { 302 + opts = NewClientOptions() 303 + } 304 + if opts.ResolverURL == DefaultResolverURL { 305 + opts.ResolverURL = server.URL 306 + } 307 + 308 + identity, err := ResolveIdentity(ctx, tt.handle, opts) 309 + 310 + if tt.wantErr { 311 + if err == nil { 312 + t.Error("expected error, got nil") 313 + } 314 + return 315 + } 316 + 317 + if err != nil { 318 + t.Fatalf("ResolveIdentity failed: %v", err) 319 + } 320 + 321 + if tt.checkFunc != nil { 322 + tt.checkFunc(t, identity, err) 323 + } 324 + }) 325 + } 326 + } 327 + 328 + func TestClient_Getters(t *testing.T) { 329 + t.Parallel() 330 + 331 + tests := []struct { 332 + name string 333 + setup func() *Client 334 + check func(t *testing.T, c *Client) 335 + }{ 336 + { 337 + name: "HasClient true", 338 + setup: func() *Client { 339 + return &Client{ 340 + client: &atclient.APIClient{}, 341 + resolvedIdentity: resolvedIdentity{ 342 + DID: "did:plc:test", 343 + Handle: "test.bsky.social", 344 + PDS: "https://pds.example.com", 345 + }, 346 + } 347 + }, 348 + check: func(t *testing.T, c *Client) { 349 + if !c.HasClient() { 350 + t.Error("HasClient() = false, want true") 351 + } 352 + }, 353 + }, 354 + { 355 + name: "HasClient false", 356 + setup: func() *Client { 357 + return &Client{} 358 + }, 359 + check: func(t *testing.T, c *Client) { 360 + if c.HasClient() { 361 + t.Error("HasClient() = true, want false") 362 + } 363 + }, 364 + }, 365 + { 366 + name: "DID", 367 + setup: func() *Client { 368 + return &Client{ 369 + resolvedIdentity: resolvedIdentity{DID: "did:plc:test123"}, 370 + } 371 + }, 372 + check: func(t *testing.T, c *Client) { 373 + if got := c.DID(); got != "did:plc:test123" { 374 + t.Errorf("DID() = %s, want did:plc:test123", got) 375 + } 376 + }, 377 + }, 378 + { 379 + name: "PDS", 380 + setup: func() *Client { 381 + return &Client{ 382 + resolvedIdentity: resolvedIdentity{PDS: "https://pds.example.com"}, 383 + } 384 + }, 385 + check: func(t *testing.T, c *Client) { 386 + if got := c.PDS(); got != "https://pds.example.com" { 387 + t.Errorf("PDS() = %s, want https://pds.example.com", got) 388 + } 389 + }, 390 + }, 391 + { 392 + name: "Handle", 393 + setup: func() *Client { 394 + return &Client{ 395 + resolvedIdentity: resolvedIdentity{Handle: "test.bsky.social"}, 396 + } 397 + }, 398 + check: func(t *testing.T, c *Client) { 399 + if got := c.Handle(); got != "test.bsky.social" { 400 + t.Errorf("Handle() = %s, want test.bsky.social", got) 401 + } 402 + }, 403 + }, 404 + { 405 + name: "SigningKey", 406 + setup: func() *Client { 407 + return &Client{ 408 + resolvedIdentity: resolvedIdentity{SigningKey: "-----BEGIN PUBLIC KEY-----\nabc\n-----END PUBLIC KEY-----"}, 409 + } 410 + }, 411 + check: func(t *testing.T, c *Client) { 412 + if got := c.SigningKey(); got != "-----BEGIN PUBLIC KEY-----\nabc\n-----END PUBLIC KEY-----" { 413 + t.Errorf("SigningKey() = %s, want expected key", got) 414 + } 415 + }, 416 + }, 417 + { 418 + name: "APIClient", 419 + setup: func() *Client { 420 + expectedClient := &atclient.APIClient{} 421 + return &Client{ 422 + client: expectedClient, 423 + resolvedIdentity: resolvedIdentity{DID: "did:plc:test"}, 424 + } 425 + }, 426 + check: func(t *testing.T, c *Client) { 427 + if got := c.APIClient(); got == nil { 428 + t.Error("APIClient() = nil, want non-nil") 429 + } 430 + }, 431 + }, 432 + { 433 + name: "APIClient nil", 434 + setup: func() *Client { 435 + return &Client{} 436 + }, 437 + check: func(t *testing.T, c *Client) { 438 + if got := c.APIClient(); got != nil { 439 + t.Errorf("APIClient() = %v, want nil", got) 440 + } 441 + }, 442 + }, 443 + } 444 + 445 + for _, tt := range tests { 446 + t.Run(tt.name, func(t *testing.T) { 447 + t.Parallel() 448 + c := tt.setup() 449 + tt.check(t, c) 450 + }) 451 + } 452 + } 453 + 454 + func TestBuildClient(t *testing.T) { 455 + t.Parallel() 456 + 457 + tests := []struct { 458 + name string 459 + customClient RepoClient[map[string]any] 460 + apiClient *atclient.APIClient 461 + did string 462 + wantErr bool 463 + wantCustom bool 464 + }{ 465 + { 466 + name: "nil custom client", 467 + customClient: nil, 468 + apiClient: &atclient.APIClient{}, 469 + did: "did:plc:test", 470 + wantErr: false, 471 + wantCustom: false, 472 + }, 473 + { 474 + name: "with custom client", 475 + customClient: &mockRepoClient[map[string]any]{}, 476 + apiClient: &atclient.APIClient{}, 477 + did: "did:plc:test", 478 + wantErr: false, 479 + wantCustom: true, 480 + }, 481 + { 482 + name: "nil API client", 483 + customClient: nil, 484 + apiClient: nil, 485 + did: "did:plc:test", 486 + wantErr: true, 487 + wantCustom: false, 488 + }, 489 + } 490 + 491 + for _, tt := range tests { 492 + t.Run(tt.name, func(t *testing.T) { 493 + t.Parallel() 494 + 495 + identity := &mockAuthClient{apiClient: tt.apiClient, did: tt.did} 496 + result, err := BuildClient[map[string]any](identity, tt.customClient) 497 + 498 + if tt.wantErr { 499 + if err == nil { 500 + t.Error("expected error, got nil") 501 + } 502 + return 503 + } 504 + 505 + if err != nil { 506 + t.Fatalf("BuildClient failed: %v", err) 507 + } 508 + 509 + if result == nil { 510 + t.Fatal("BuildClient returned nil") 511 + } 512 + 513 + if tt.wantCustom && result != tt.customClient { 514 + t.Error("BuildClient should return custom client") 515 + } 516 + }) 517 + } 518 + } 519 + 520 + func TestClientOptions(t *testing.T) { 521 + t.Parallel() 522 + 523 + t.Run("WithUserAgent", func(t *testing.T) { 524 + t.Parallel() 525 + opts := NewClientOptions() 526 + WithUserAgent("test-agent")(opts) 527 + if opts.UserAgent != "test-agent" { 528 + t.Errorf("UserAgent = %s, want test-agent", opts.UserAgent) 529 + } 530 + }) 531 + 532 + t.Run("WithResolverURL", func(t *testing.T) { 533 + t.Parallel() 534 + opts := NewClientOptions() 535 + WithResolverURL("http://custom-resolver")(opts) 536 + if opts.ResolverURL != "http://custom-resolver" { 537 + t.Errorf("ResolverURL = %s, want http://custom-resolver", opts.ResolverURL) 538 + } 539 + }) 540 + 541 + t.Run("WithHTTPClient", func(t *testing.T) { 542 + t.Parallel() 543 + opts := NewClientOptions() 544 + customClient := &http.Client{} 545 + WithHTTPClient(customClient)(opts) 546 + if opts.HTTPClient != customClient { 547 + t.Errorf("HTTPClient = %v, want %v", opts.HTTPClient, customClient) 548 + } 549 + }) 550 + 551 + t.Run("WithRateLimitPercent", func(t *testing.T) { 552 + t.Parallel() 553 + opts := NewClientOptions() 554 + WithRateLimitPercent(0.5)(opts) 555 + if opts.RateLimitPercent != 0.5 { 556 + t.Errorf("RateLimitPercent = %f, want 0.5", opts.RateLimitPercent) 557 + } 558 + }) 559 + 560 + t.Run("defaults", func(t *testing.T) { 561 + t.Parallel() 562 + opts := NewClientOptions() 563 + if opts.RateLimitPercent != DefaultRateLimitPercent { 564 + t.Errorf("RateLimitPercent = %f, want %f", opts.RateLimitPercent, DefaultRateLimitPercent) 565 + } 566 + if opts.ResolverURL != DefaultResolverURL { 567 + t.Errorf("ResolverURL = %s, want %s", opts.ResolverURL, DefaultResolverURL) 568 + } 569 + if opts.HTTPClient == nil { 570 + t.Error("HTTPClient should not be nil") 571 + } 572 + }) 573 + } 574 + 575 + type mockAuthClient struct { 576 + apiClient *atclient.APIClient 577 + did string 578 + } 579 + 580 + func (m *mockAuthClient) APIClient() *atclient.APIClient { 581 + return m.apiClient 582 + } 583 + 584 + func (m *mockAuthClient) DID() string { 585 + return m.did 586 + } 587 + 588 + type mockRepoClient[T any] struct{} 589 + 590 + func (m *mockRepoClient[T]) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]RecordRef[T], string, error) { 591 + return nil, "", nil 592 + } 593 + 594 + func (m *mockRepoClient[T]) ApplyWrites(ctx context.Context, collection string, records []T) error { 595 + return nil 596 + } 597 + 598 + func (m *mockRepoClient[T]) DeleteRecord(ctx context.Context, collection, rkey string) error { 599 + return nil 600 + } 601 + 602 + func TestFixedPasswordAuth_RefreshUsesPOST(t *testing.T) { 603 + var capturedMethod string 604 + var capturedPath string 605 + 606 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 607 + capturedMethod = r.Method 608 + capturedPath = r.URL.Path 609 + 610 + if capturedMethod == http.MethodPost && capturedPath == "/xrpc/com.atproto.server.refreshSession" { 611 + w.Header().Set("Content-Type", "application/json") 612 + json.NewEncoder(w).Encode(map[string]any{ 613 + "accessJwt": "new-access-token", 614 + "refreshJwt": "new-refresh-token", 615 + }) 616 + return 617 + } 618 + 619 + if capturedPath == "/xrpc/com.atproto.server.createSession" { 620 + w.Header().Set("Content-Type", "application/json") 621 + json.NewEncoder(w).Encode(map[string]any{ 622 + "accessJwt": "access-token", 623 + "refreshJwt": "refresh-token", 624 + "did": "did:plc:test", 625 + }) 626 + return 627 + } 628 + 629 + w.WriteHeader(http.StatusBadRequest) 630 + })) 631 + defer server.Close() 632 + 633 + pdsURL := server.URL 634 + 635 + session := atclient.PasswordSessionData{ 636 + AccessToken: "old-access-token", 637 + RefreshToken: "refresh-token", 638 + Host: pdsURL, 639 + } 640 + 641 + auth := &FixedPasswordAuth{ 642 + PasswordAuth: &atclient.PasswordAuth{Session: session}, 643 + } 644 + 645 + httpClient := server.Client() 646 + 647 + ctx := context.Background() 648 + err := auth.Refresh(ctx, httpClient, "refresh-token") 649 + 650 + if err != nil { 651 + t.Fatalf("Refresh failed: %v", err) 652 + } 653 + 654 + if capturedMethod != http.MethodPost { 655 + t.Errorf("Refresh used %s, want POST", capturedMethod) 656 + } 657 + 658 + if capturedPath != "/xrpc/com.atproto.server.refreshSession" { 659 + t.Errorf("Refresh hit %s, want /xrpc/com.atproto.server.refreshSession", capturedPath) 660 + } 661 + } 662 + 663 + func TestFixedPasswordAuth_IndigoBugCheck(t *testing.T) { 664 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 665 + if r.URL.Path == "/xrpc/com.atproto.server.refreshSession" { 666 + if r.Method != http.MethodPost { 667 + t.Errorf("REFRESH BUG: indigo uses %s for refreshSession, should use POST", r.Method) 668 + } 669 + w.Header().Set("Content-Type", "application/json") 670 + json.NewEncoder(w).Encode(map[string]any{ 671 + "accessJwt": "new-access-token", 672 + "refreshJwt": "new-refresh-token", 673 + }) 674 + return 675 + } 676 + if r.URL.Path == "/xrpc/com.atproto.server.createSession" { 677 + w.Header().Set("Content-Type", "application/json") 678 + json.NewEncoder(w).Encode(map[string]any{ 679 + "accessJwt": "access-token", 680 + "refreshJwt": "refresh-token", 681 + "did": "did:plc:test", 682 + }) 683 + return 684 + } 685 + w.WriteHeader(http.StatusBadRequest) 686 + })) 687 + defer server.Close() 688 + 689 + pdsURL := server.URL 690 + 691 + pa := &atclient.PasswordAuth{ 692 + Session: atclient.PasswordSessionData{ 693 + AccessToken: "old-access-token", 694 + RefreshToken: "refresh-token", 695 + Host: pdsURL, 696 + }, 697 + } 698 + 699 + fixedAuth := &FixedPasswordAuth{PasswordAuth: pa} 700 + 701 + err := fixedAuth.Refresh(context.Background(), server.Client(), "refresh-token") 702 + 703 + if err != nil { 704 + t.Fatalf("FixedPasswordAuth.Refresh failed: %v", err) 705 + } 706 + }
+694
atproto/rate_test.go
··· 1 + package atproto 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "testing" 7 + "time" 8 + ) 9 + 10 + func TestNewRateLimiter(t *testing.T) { 11 + kv := &mockKVStore{data: make(map[string]int)} 12 + rl := NewRateLimiter(kv, 0.8) 13 + 14 + if rl == nil { 15 + t.Fatal("NewRateLimiter returned nil") 16 + } 17 + } 18 + 19 + func TestQuotaLimiter_Stats(t *testing.T) { 20 + kv := &mockKVStore{data: map[string]int{ 21 + "quota:writes:d:2024-01-15": 100, 22 + "quota:global:d:2024-01-15": 200, 23 + }} 24 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 25 + 26 + rl := &quotaLimiter{ 27 + kv: kv, 28 + prefix: "quota", 29 + clock: clock, 30 + rlQuota: 1.0, 31 + } 32 + 33 + writes, global, err := rl.Stats() 34 + if err != nil { 35 + t.Fatalf("Stats failed: %v", err) 36 + } 37 + if writes != 100 { 38 + t.Errorf("Stats writes = %d, want 100", writes) 39 + } 40 + if global != 200 { 41 + t.Errorf("Stats global = %d, want 200", global) 42 + } 43 + } 44 + 45 + func TestQuotaLimiter_AllowRead(t *testing.T) { 46 + kv := &mockKVStore{data: make(map[string]int)} 47 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 48 + 49 + rl := &quotaLimiter{ 50 + kv: kv, 51 + prefix: "quota", 52 + clock: clock, 53 + rlQuota: 1.0, 54 + } 55 + 56 + ctx := context.Background() 57 + chargedAt, err := rl.AllowRead(ctx) 58 + if err != nil { 59 + t.Fatalf("AllowRead failed: %v", err) 60 + } 61 + 62 + if kv.incrs == nil || len(kv.incrs) == 0 { 63 + t.Error("Expected kv.IncrByMulti to be called") 64 + } 65 + if !chargedAt.IsZero() { 66 + t.Logf("Charged at: %v", chargedAt) 67 + } 68 + } 69 + 70 + func TestQuotaLimiter_AllowBulkWrite(t *testing.T) { 71 + t.Parallel() 72 + 73 + tests := []struct { 74 + name string 75 + n int 76 + kvData map[string]int 77 + wantErr bool 78 + }{ 79 + { 80 + name: "success with 5 records", 81 + n: 5, 82 + kvData: make(map[string]int), 83 + wantErr: false, 84 + }, 85 + } 86 + 87 + for _, tt := range tests { 88 + t.Run(tt.name, func(t *testing.T) { 89 + t.Parallel() 90 + 91 + kv := &mockKVStore{data: tt.kvData} 92 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 93 + 94 + rl := &quotaLimiter{ 95 + kv: kv, 96 + prefix: "quota", 97 + clock: clock, 98 + rlQuota: 1.0, 99 + } 100 + 101 + ctx := context.Background() 102 + _, err := rl.AllowBulkWrite(ctx, tt.n) 103 + 104 + if tt.wantErr && err == nil { 105 + t.Error("expected error, got nil") 106 + } 107 + if !tt.wantErr && err != nil { 108 + t.Fatalf("AllowBulkWrite failed: %v", err) 109 + } 110 + }) 111 + } 112 + } 113 + 114 + func TestQuotaLimiter_Refund(t *testing.T) { 115 + t.Parallel() 116 + 117 + tests := []struct { 118 + name string 119 + setupKV func() *mockKVStore 120 + refundFunc func(rl *quotaLimiter, ctx context.Context) 121 + wantIncrCalls int 122 + }{ 123 + { 124 + name: "refund bulk write", 125 + setupKV: func() *mockKVStore { 126 + return &mockKVStore{data: map[string]int{ 127 + "quota:writes:m:2024-01-15-12-00": 10, 128 + "quota:writes:h:2024-01-15-12": 20, 129 + "quota:writes:d:2024-01-15": 30, 130 + "quota:global:m:2024-01-15-12-00": 5, 131 + "quota:global:h:2024-01-15-12": 10, 132 + "quota:global:d:2024-01-15": 15, 133 + }} 134 + }, 135 + refundFunc: func(rl *quotaLimiter, ctx context.Context) { 136 + rl.RefundBulkWrite(ctx, 3, time.Now()) 137 + }, 138 + wantIncrCalls: 1, 139 + }, 140 + { 141 + name: "refund read", 142 + setupKV: func() *mockKVStore { 143 + return &mockKVStore{data: map[string]int{ 144 + "quota:global:m:2024-01-15-12-00": 5, 145 + "quota:global:h:2024-01-15-12": 10, 146 + "quota:global:d:2024-01-15": 15, 147 + }} 148 + }, 149 + refundFunc: func(rl *quotaLimiter, ctx context.Context) { 150 + rl.RefundRead(ctx, time.Now()) 151 + }, 152 + wantIncrCalls: 1, 153 + }, 154 + } 155 + 156 + for _, tt := range tests { 157 + t.Run(tt.name, func(t *testing.T) { 158 + t.Parallel() 159 + 160 + kv := tt.setupKV() 161 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 162 + 163 + rl := &quotaLimiter{ 164 + kv: kv, 165 + prefix: "quota", 166 + clock: clock, 167 + rlQuota: 1.0, 168 + } 169 + 170 + ctx := context.Background() 171 + tt.refundFunc(rl, ctx) 172 + 173 + if len(kv.incrs) != tt.wantIncrCalls { 174 + t.Errorf("Expected %d IncrByMulti call, got %d", tt.wantIncrCalls, len(kv.incrs)) 175 + } 176 + }) 177 + } 178 + } 179 + 180 + func TestQuotaLimiter_EstimatedWriteTime(t *testing.T) { 181 + t.Parallel() 182 + 183 + tests := []struct { 184 + name string 185 + n int 186 + kvData map[string]int 187 + wantZero bool 188 + }{ 189 + { 190 + name: "zero records", 191 + n: 0, 192 + kvData: make(map[string]int), 193 + wantZero: true, 194 + }, 195 + { 196 + name: "negative records", 197 + n: -5, 198 + kvData: make(map[string]int), 199 + wantZero: true, 200 + }, 201 + { 202 + name: "with existing usage", 203 + n: 60, 204 + kvData: map[string]int{ 205 + "quota:writes:m:2024-01-15-12-00": 50, 206 + "quota:writes:h:2024-01-15-12": 500, 207 + }, 208 + wantZero: false, 209 + }, 210 + } 211 + 212 + for _, tt := range tests { 213 + t.Run(tt.name, func(t *testing.T) { 214 + t.Parallel() 215 + 216 + kv := &mockKVStore{data: tt.kvData} 217 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 30, 0, time.UTC)} 218 + 219 + rl := &quotaLimiter{ 220 + kv: kv, 221 + prefix: "quota", 222 + clock: clock, 223 + rlQuota: 1.0, 224 + } 225 + 226 + wait := rl.EstimatedWriteTime(tt.n) 227 + 228 + if tt.wantZero && wait != 0 { 229 + t.Errorf("EstimatedWriteTime(%d) = %v, want 0", tt.n, wait) 230 + } 231 + if !tt.wantZero && tt.n > 0 && wait == 0 { 232 + t.Log("EstimatedWriteTime returned 0 with existing usage") 233 + } 234 + }) 235 + } 236 + } 237 + 238 + func TestQuotaLimiter_RemainingQuota(t *testing.T) { 239 + kv := &mockKVStore{data: make(map[string]int)} 240 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 241 + 242 + rl := &quotaLimiter{ 243 + kv: kv, 244 + prefix: "quota", 245 + clock: clock, 246 + rlQuota: 1.0, 247 + } 248 + 249 + writes, global, reset := rl.RemainingQuota() 250 + if writes != WriteLimitMinute { 251 + t.Errorf("RemainingQuota writes = %d, want %d (minute limit)", writes, WriteLimitMinute) 252 + } 253 + if global != GlobalLimitMinute { 254 + t.Errorf("RemainingQuota global = %d, want %d (minute limit)", global, GlobalLimitMinute) 255 + } 256 + if reset < 0 { 257 + t.Errorf("RemainingQuota reset = %v, want >= 0", reset) 258 + } 259 + } 260 + 261 + func TestUntilNextWindow(t *testing.T) { 262 + t.Parallel() 263 + 264 + tests := []struct { 265 + name string 266 + now time.Time 267 + tier int 268 + wantInRange bool 269 + minExpected time.Duration 270 + }{ 271 + { 272 + name: "minute tier at 30s past", 273 + now: time.Date(2024, 1, 15, 12, 0, 30, 0, time.UTC), 274 + tier: 0, 275 + wantInRange: true, 276 + minExpected: 30*time.Second + 100*time.Millisecond, 277 + }, 278 + { 279 + name: "hour tier at 30m past", 280 + now: time.Date(2024, 1, 15, 12, 30, 0, 0, time.UTC), 281 + tier: 1, 282 + wantInRange: true, 283 + minExpected: 30*time.Minute + 100*time.Millisecond, 284 + }, 285 + { 286 + name: "day tier at 6 hours past", 287 + now: time.Date(2024, 1, 15, 18, 0, 0, 0, time.UTC), 288 + tier: 2, 289 + wantInRange: true, 290 + minExpected: 6*time.Hour + 100*time.Millisecond, 291 + }, 292 + { 293 + name: "unknown tier", 294 + now: time.Now(), 295 + tier: 99, 296 + wantInRange: false, 297 + minExpected: time.Minute, 298 + }, 299 + } 300 + 301 + for _, tt := range tests { 302 + t.Run(tt.name, func(t *testing.T) { 303 + t.Parallel() 304 + 305 + kv := &mockKVStore{data: make(map[string]int)} 306 + clock := &mockClock{now: tt.now} 307 + 308 + rl := &quotaLimiter{ 309 + kv: kv, 310 + prefix: "quota", 311 + clock: clock, 312 + rlQuota: 1.0, 313 + } 314 + 315 + wait := rl.untilNextWindow(tt.now, tt.tier) 316 + 317 + if tt.wantInRange { 318 + expected := tt.minExpected 319 + if wait < expected-time.Second || wait > expected+time.Second { 320 + t.Errorf("untilNextWindow(%d) = %v, want ~%v", tt.tier, wait, expected) 321 + } 322 + } else { 323 + if wait != tt.minExpected { 324 + t.Errorf("untilNextWindow(%d) = %v, want %v", tt.tier, wait, tt.minExpected) 325 + } 326 + } 327 + }) 328 + } 329 + } 330 + 331 + func TestAddJitter(t *testing.T) { 332 + base := 100 * time.Millisecond 333 + 334 + for range 100 { 335 + jitter := addJitter(base) 336 + if jitter < base { 337 + t.Errorf("addJitter returned %v, want >= %v", jitter, base) 338 + } 339 + if jitter > base+base/10 { 340 + t.Errorf("addJitter returned %v, want <= %v + 10%%", jitter, base) 341 + } 342 + } 343 + } 344 + 345 + func TestQuotaLimiter_EmptyKeys(t *testing.T) { 346 + t.Parallel() 347 + 348 + tests := []struct { 349 + name string 350 + wKeys []string 351 + gKeys []string 352 + wCost int 353 + gCost int 354 + wantWait time.Duration 355 + }{ 356 + { 357 + name: "empty keys for checkQuota", 358 + wKeys: nil, 359 + gKeys: nil, 360 + wCost: 0, 361 + gCost: 0, 362 + wantWait: 0, 363 + }, 364 + { 365 + name: "empty keys for charge", 366 + wKeys: nil, 367 + gKeys: nil, 368 + wCost: 0, 369 + gCost: 0, 370 + wantWait: 0, 371 + }, 372 + } 373 + 374 + for _, tt := range tests { 375 + t.Run(tt.name, func(t *testing.T) { 376 + t.Parallel() 377 + 378 + kv := &mockKVStore{data: make(map[string]int)} 379 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 380 + 381 + rl := &quotaLimiter{ 382 + kv: kv, 383 + prefix: "quota", 384 + clock: clock, 385 + rlQuota: 1.0, 386 + } 387 + 388 + now := time.Now() 389 + 390 + if tt.name == "empty keys for checkQuota" { 391 + wait, err := rl.checkQuota(now, tt.wKeys, tt.gKeys, tt.wCost, tt.gCost) 392 + if err != nil { 393 + t.Fatalf("checkQuota failed: %v", err) 394 + } 395 + if wait != tt.wantWait { 396 + t.Errorf("checkQuota = %v, want %v", wait, tt.wantWait) 397 + } 398 + } else { 399 + err := rl.charge(tt.wKeys, tt.gKeys, tt.wCost, tt.gCost) 400 + if err != nil { 401 + t.Fatalf("charge failed: %v", err) 402 + } 403 + if len(kv.incrs) > 0 { 404 + t.Errorf("charge with empty keys should not call IncrByMulti") 405 + } 406 + } 407 + }) 408 + } 409 + } 410 + 411 + func TestRateLimiter_Interfaces(t *testing.T) { 412 + t.Parallel() 413 + 414 + tests := []struct { 415 + name string 416 + checkFunc func(t *testing.T) 417 + }{ 418 + { 419 + name: "RateLimiter interface", 420 + checkFunc: func(t *testing.T) { 421 + var _ RateLimiter = (*quotaLimiter)(nil) 422 + }, 423 + }, 424 + { 425 + name: "KVStore interface", 426 + checkFunc: func(t *testing.T) { 427 + var _ KVStore = (*mockKVStore)(nil) 428 + }, 429 + }, 430 + { 431 + name: "Clock interface - realClock", 432 + checkFunc: func(t *testing.T) { 433 + var _ Clock = (*realClock)(nil) 434 + }, 435 + }, 436 + { 437 + name: "Clock interface - mockClock", 438 + checkFunc: func(t *testing.T) { 439 + var _ Clock = (*mockClock)(nil) 440 + }, 441 + }, 442 + } 443 + 444 + for _, tt := range tests { 445 + t.Run(tt.name, func(t *testing.T) { 446 + t.Parallel() 447 + tt.checkFunc(t) 448 + }) 449 + } 450 + } 451 + 452 + func TestRateLimiterConstants(t *testing.T) { 453 + tests := []struct { 454 + name string 455 + got int 456 + expected int 457 + }{ 458 + {"WriteLimitMinute", WriteLimitMinute, 100}, 459 + {"WriteLimitHour", WriteLimitHour, 1000}, 460 + {"WriteLimitDay", WriteLimitDay, 10000}, 461 + {"GlobalLimitMinute", GlobalLimitMinute, 300}, 462 + {"GlobalLimitHour", GlobalLimitHour, 3000}, 463 + {"GlobalLimitDay", GlobalLimitDay, 35000}, 464 + {"ReadGlobalCost", ReadGlobalCost, 1}, 465 + {"WriteOnlyCost", WriteOnlyCost, 1}, 466 + {"WriteGlobalCost", WriteGlobalCost, 3}, 467 + } 468 + 469 + for _, tt := range tests { 470 + t.Run(tt.name, func(t *testing.T) { 471 + if tt.got != tt.expected { 472 + t.Errorf("%s = %d, want %d", tt.name, tt.got, tt.expected) 473 + } 474 + }) 475 + } 476 + } 477 + 478 + func TestQuotaLimiter_getKeys(t *testing.T) { 479 + kv := &mockKVStore{data: make(map[string]int)} 480 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 30, 45, 0, time.UTC)} 481 + 482 + rl := &quotaLimiter{ 483 + kv: kv, 484 + prefix: "test", 485 + clock: clock, 486 + rlQuota: 1.0, 487 + } 488 + 489 + wd, gd, wh, gh, wm, gm := rl.getKeys(time.Date(2024, 1, 15, 12, 30, 45, 0, time.UTC)) 490 + 491 + expectedKeys := []struct { 492 + got string 493 + want string 494 + }{ 495 + {wd, "test:writes:d:2024-01-15"}, 496 + {gd, "test:global:d:2024-01-15"}, 497 + {wh, "test:writes:h:2024-01-15-12"}, 498 + {gh, "test:global:h:2024-01-15-12"}, 499 + {wm, "test:writes:m:2024-01-15-12-30"}, 500 + {gm, "test:global:m:2024-01-15-12-30"}, 501 + } 502 + 503 + for _, k := range expectedKeys { 504 + if k.got != k.want { 505 + t.Errorf("key = %s, want %s", k.got, k.want) 506 + } 507 + } 508 + } 509 + 510 + func TestQuotaLimiter_getAllKeys(t *testing.T) { 511 + kv := &mockKVStore{data: make(map[string]int)} 512 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 30, 45, 0, time.UTC)} 513 + 514 + rl := &quotaLimiter{ 515 + kv: kv, 516 + prefix: "test", 517 + clock: clock, 518 + rlQuota: 1.0, 519 + } 520 + 521 + tm := time.Date(2024, 1, 15, 12, 30, 45, 0, time.UTC) 522 + wKeys, gKeys := rl.getAllKeys(tm) 523 + 524 + if len(wKeys) != 3 { 525 + t.Errorf("getAllKeys returned %d write keys, want 3", len(wKeys)) 526 + } 527 + if len(gKeys) != 3 { 528 + t.Errorf("getAllKeys returned %d global keys, want 3", len(gKeys)) 529 + } 530 + } 531 + 532 + func TestQuotaLimiter_KVErrors(t *testing.T) { 533 + t.Parallel() 534 + 535 + tests := []struct { 536 + name string 537 + testFunc func(t *testing.T, rl *quotaLimiter) 538 + }{ 539 + { 540 + name: "Stats error", 541 + testFunc: func(t *testing.T, rl *quotaLimiter) { 542 + _, _, err := rl.Stats() 543 + if err == nil { 544 + t.Error("expected error from KV store") 545 + } 546 + }, 547 + }, 548 + { 549 + name: "AllowBulkWrite error", 550 + testFunc: func(t *testing.T, rl *quotaLimiter) { 551 + ctx := context.Background() 552 + _, err := rl.AllowBulkWrite(ctx, 1) 553 + if err == nil { 554 + t.Error("expected error from KV store") 555 + } 556 + }, 557 + }, 558 + { 559 + name: "AllowRead error", 560 + testFunc: func(t *testing.T, rl *quotaLimiter) { 561 + ctx := context.Background() 562 + _, err := rl.AllowRead(ctx) 563 + if err == nil { 564 + t.Error("expected error from KV store") 565 + } 566 + }, 567 + }, 568 + { 569 + name: "charge error", 570 + testFunc: func(t *testing.T, rl *quotaLimiter) { 571 + err := rl.charge([]string{"test:writes:m:2024-01-15-12-00"}, nil, 1, 0) 572 + if err == nil { 573 + t.Error("expected error from KV store") 574 + } 575 + }, 576 + }, 577 + { 578 + name: "checkQuota error", 579 + testFunc: func(t *testing.T, rl *quotaLimiter) { 580 + now := time.Now() 581 + _, err := rl.checkQuota(now, []string{"test:writes:m:2024-01-15-12-00"}, nil, 1, 0) 582 + if err == nil { 583 + t.Error("expected error from KV store") 584 + } 585 + }, 586 + }, 587 + } 588 + 589 + for _, tt := range tests { 590 + t.Run(tt.name, func(t *testing.T) { 591 + t.Parallel() 592 + 593 + kv := &mockKVStoreWithErr{Data: make(map[string]int), Err: fmt.Errorf("KV error")} 594 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 595 + 596 + rl := &quotaLimiter{ 597 + kv: kv, 598 + prefix: "quota", 599 + clock: clock, 600 + rlQuota: 1.0, 601 + } 602 + 603 + tt.testFunc(t, rl) 604 + }) 605 + } 606 + } 607 + 608 + func TestQuotaLimiter_Refund_KVErrors(t *testing.T) { 609 + t.Parallel() 610 + 611 + tests := []struct { 612 + name string 613 + setupKV func() *mockKVStoreWithErr 614 + refundFn func(rl *quotaLimiter, ctx context.Context) 615 + }{ 616 + { 617 + name: "RefundBulkWrite error", 618 + setupKV: func() *mockKVStoreWithErr { 619 + return &mockKVStoreWithErr{Data: map[string]int{ 620 + "quota:writes:m:2024-01-15-12-00": 10, 621 + }, Err: fmt.Errorf("KV error")} 622 + }, 623 + refundFn: func(rl *quotaLimiter, ctx context.Context) { 624 + rl.RefundBulkWrite(ctx, 3, time.Now()) 625 + }, 626 + }, 627 + { 628 + name: "RefundRead error", 629 + setupKV: func() *mockKVStoreWithErr { 630 + return &mockKVStoreWithErr{Data: map[string]int{ 631 + "quota:global:m:2024-01-15-12-00": 5, 632 + }, Err: fmt.Errorf("KV error")} 633 + }, 634 + refundFn: func(rl *quotaLimiter, ctx context.Context) { 635 + rl.RefundRead(ctx, time.Now()) 636 + }, 637 + }, 638 + } 639 + 640 + for _, tt := range tests { 641 + t.Run(tt.name, func(t *testing.T) { 642 + t.Parallel() 643 + 644 + kv := tt.setupKV() 645 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 646 + 647 + rl := &quotaLimiter{ 648 + kv: kv, 649 + prefix: "quota", 650 + clock: clock, 651 + rlQuota: 1.0, 652 + } 653 + 654 + ctx := context.Background() 655 + tt.refundFn(rl, ctx) 656 + }) 657 + } 658 + } 659 + 660 + func TestQuotaLimiter_MutexProtection(t *testing.T) { 661 + kv := &mockKVStore{data: make(map[string]int)} 662 + clock := &mockClock{now: time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)} 663 + 664 + rl := &quotaLimiter{ 665 + kv: kv, 666 + prefix: "quota", 667 + clock: clock, 668 + rlQuota: 1.0, 669 + } 670 + 671 + ctx := context.Background() 672 + done := make(chan bool) 673 + errors := make(chan error, 10) 674 + 675 + for range 10 { 676 + go func() { 677 + _, err := rl.AllowBulkWrite(ctx, 1) 678 + if err != nil { 679 + errors <- err 680 + } 681 + done <- true 682 + }() 683 + } 684 + 685 + for range 10 { 686 + <-done 687 + } 688 + 689 + select { 690 + case err := <-errors: 691 + t.Errorf("concurrent AllowBulkWrite returned error: %v", err) 692 + default: 693 + } 694 + }
+246
atproto/repo.go
··· 1 + package atproto 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "errors" 7 + "fmt" 8 + "log/slog" 9 + "net" 10 + "time" 11 + 12 + "github.com/bluesky-social/indigo/atproto/atclient" 13 + "github.com/bluesky-social/indigo/atproto/syntax" 14 + ) 15 + 16 + type RecordRef[T any] struct { 17 + URI string 18 + CID string 19 + Value T 20 + } 21 + 22 + type RepoClient[T any] interface { 23 + ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]RecordRef[T], string, error) 24 + ApplyWrites(ctx context.Context, collection string, records []T) error 25 + DeleteRecord(ctx context.Context, collection, rkey string) error 26 + } 27 + 28 + type RepoClientFuncs[T any] struct { 29 + ListRecordsFn func(ctx context.Context, collection string, limit int, cursor string) ([]RecordRef[T], string, error) 30 + ApplyWritesFn func(ctx context.Context, collection string, records []T) error 31 + DeleteRecordFn func(ctx context.Context, collection, rkey string) error 32 + } 33 + 34 + func (c RepoClientFuncs[T]) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]RecordRef[T], string, error) { 35 + return c.ListRecordsFn(ctx, collection, limit, cursor) 36 + } 37 + 38 + func (c RepoClientFuncs[T]) ApplyWrites(ctx context.Context, collection string, records []T) error { 39 + return c.ApplyWritesFn(ctx, collection, records) 40 + } 41 + 42 + func (c RepoClientFuncs[T]) DeleteRecord(ctx context.Context, collection, rkey string) error { 43 + return c.DeleteRecordFn(ctx, collection, rkey) 44 + } 45 + 46 + type RateClient[T any] struct { 47 + client *atclient.APIClient 48 + did string 49 + limiter RateLimiter 50 + } 51 + 52 + func NewRateClient[T any](client *atclient.APIClient, did string, limiter RateLimiter) *RateClient[T] { 53 + return &RateClient[T]{ 54 + client: client, 55 + did: did, 56 + limiter: limiter, 57 + } 58 + } 59 + 60 + func (c *RateClient[T]) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]RecordRef[T], string, error) { 61 + if c.client == nil { 62 + return nil, "", fmt.Errorf("client cannot be nil") 63 + } 64 + 65 + var outResp struct { 66 + Records []struct { 67 + URI string `json:"uri"` 68 + CID string `json:"cid"` 69 + Value map[string]any `json:"value"` 70 + } `json:"records"` 71 + Cursor string `json:"cursor"` 72 + } 73 + 74 + var chargedAt time.Time 75 + if c.limiter != nil { 76 + slog.Debug("waiting for rate limit (read)") 77 + var err error 78 + chargedAt, err = c.limiter.AllowRead(ctx) 79 + if err != nil { 80 + slog.Error("rate limit wait cancelled/failed (read)", Error(err)) 81 + return nil, "", err 82 + } 83 + } 84 + 85 + err := c.client.Get(ctx, syntax.NSID("com.atproto.repo.listRecords"), map[string]any{ 86 + "repo": c.did, 87 + "collection": collection, 88 + "limit": limit, 89 + "cursor": cursor, 90 + }, &outResp) 91 + if err != nil { 92 + if c.limiter != nil && IsTransientError(err) { 93 + c.limiter.RefundRead(ctx, chargedAt) 94 + } 95 + return nil, "", err 96 + } 97 + 98 + out := make([]RecordRef[T], 0, len(outResp.Records)) 99 + for _, r := range outResp.Records { 100 + var value T 101 + if r.Value != nil { 102 + b, err := json.Marshal(r.Value) 103 + if err != nil { 104 + slog.Debug("failed to marshal record value", slog.String("uri", r.URI), Error(err)) 105 + continue 106 + } 107 + if err := json.Unmarshal(b, &value); err != nil { 108 + slog.Debug("failed to unmarshal record", slog.String("uri", r.URI), Error(err)) 109 + continue 110 + } 111 + } 112 + out = append(out, RecordRef[T]{ 113 + URI: r.URI, 114 + CID: r.CID, 115 + Value: value, 116 + }) 117 + } 118 + 119 + return out, outResp.Cursor, nil 120 + } 121 + 122 + func (c *RateClient[T]) ApplyWrites(ctx context.Context, collection string, records []T) error { 123 + if len(records) == 0 { 124 + return nil 125 + } 126 + 127 + if c.client == nil { 128 + return fmt.Errorf("client cannot be nil") 129 + } 130 + 131 + var chargedAt time.Time 132 + if c.limiter != nil { 133 + var err error 134 + chargedAt, err = c.limiter.AllowBulkWrite(ctx, len(records)) 135 + if err != nil { 136 + slog.Error("rate limit wait cancelled/failed (write)", 137 + slog.String("did", c.did), 138 + slog.String("collection", collection), 139 + Error(err), 140 + ) 141 + return err 142 + } 143 + } 144 + 145 + err := applyWrites(ctx, c.client, c.did, collection, records) 146 + if err != nil && IsTransientError(err) && c.limiter != nil { 147 + c.limiter.RefundBulkWrite(ctx, len(records), chargedAt) 148 + } 149 + return err 150 + } 151 + 152 + func applyWrites[T any](ctx context.Context, client *atclient.APIClient, did, collection string, records []T) error { 153 + if len(records) == 0 { 154 + return nil 155 + } 156 + 157 + if len(records) > 200 { 158 + return fmt.Errorf("too many records in one ApplyWrites call: %d (max 200)", len(records)) 159 + } 160 + 161 + writes, err := prepareWrites(records, collection) 162 + if err != nil { 163 + return err 164 + } 165 + 166 + return client.Post(ctx, syntax.NSID("com.atproto.repo.applyWrites"), map[string]any{ 167 + "repo": did, 168 + "writes": writes, 169 + }, nil) 170 + } 171 + 172 + func prepareWrites[T any](records []T, collection string) ([]map[string]any, error) { 173 + if len(records) == 0 { 174 + return nil, nil 175 + } 176 + 177 + writes := make([]map[string]any, len(records)) 178 + for i, rec := range records { 179 + b, err := json.Marshal(rec) 180 + if err != nil { 181 + return nil, fmt.Errorf("failed to marshal record: %w", err) 182 + } 183 + var recMap map[string]any 184 + if err := json.Unmarshal(b, &recMap); err != nil { 185 + return nil, fmt.Errorf("failed to unmarshal record: %w", err) 186 + } 187 + recMap["$type"] = "com.atproto.repo.applyWrites#create" 188 + recMap["collection"] = collection 189 + writes[i] = recMap 190 + } 191 + 192 + return writes, nil 193 + } 194 + 195 + func (c *RateClient[T]) DeleteRecord(ctx context.Context, collection, rkey string) error { 196 + if c.client == nil { 197 + return fmt.Errorf("client is nil") 198 + } 199 + 200 + var chargedAt time.Time 201 + if c.limiter != nil { 202 + slog.Debug("waiting for rate limit (delete)") 203 + var err error 204 + chargedAt, err = c.limiter.AllowBulkWrite(ctx, 1) 205 + if err != nil { 206 + slog.Error("rate limit wait cancelled/failed (delete)", Error(err)) 207 + return err 208 + } 209 + } 210 + 211 + _, err := c.client.Do(ctx, &atclient.APIRequest{ 212 + Method: "DELETE", 213 + Endpoint: syntax.NSID("com.atproto.repo.deleteRecord"), 214 + QueryParams: map[string][]string{ 215 + "repo": {c.did}, 216 + "collection": {collection}, 217 + "rkey": {rkey}, 218 + }, 219 + }) 220 + if err != nil && c.limiter != nil && IsTransientError(err) { 221 + c.limiter.RefundBulkWrite(ctx, 1, chargedAt) 222 + } 223 + return err 224 + } 225 + 226 + func IsTransientError(err error) bool { 227 + if err == nil { 228 + return false 229 + } 230 + 231 + var apiErr *atclient.APIError 232 + if errors.As(err, &apiErr) { 233 + switch apiErr.StatusCode { 234 + case 429, 500, 502, 503, 504: 235 + return true 236 + } 237 + return false 238 + } 239 + 240 + var netErr net.Error 241 + if errors.As(err, &netErr) { 242 + return netErr.Timeout() 243 + } 244 + 245 + return false 246 + }
+945
atproto/repo_test.go
··· 1 + package atproto 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "errors" 7 + "net/http" 8 + "net/http/httptest" 9 + "testing" 10 + "time" 11 + 12 + "github.com/bluesky-social/indigo/atproto/atclient" 13 + ) 14 + 15 + func TestRateClient_ListRecords(t *testing.T) { 16 + t.Parallel() 17 + 18 + tests := []struct { 19 + name string 20 + setupServer func() *httptest.Server 21 + client *atclient.APIClient 22 + limiter RateLimiter 23 + collection string 24 + limit int 25 + cursor string 26 + wantRecords int 27 + wantCursor string 28 + wantErr bool 29 + }{ 30 + { 31 + name: "success", 32 + setupServer: func() *httptest.Server { 33 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 34 + if r.URL.Path != "/xrpc/com.atproto.repo.listRecords" { 35 + t.Errorf("unexpected path: %s", r.URL.Path) 36 + } 37 + if r.URL.Query().Get("repo") != "did:plc:test" { 38 + t.Errorf("unexpected repo: %s", r.URL.Query().Get("repo")) 39 + } 40 + if r.URL.Query().Get("collection") != "app.bsky.feed.post" { 41 + t.Errorf("unexpected collection: %s", r.URL.Query().Get("collection")) 42 + } 43 + 44 + response := map[string]any{ 45 + "records": []any{ 46 + map[string]any{ 47 + "uri": "at://did:plc:test/app.bsky.feed.post/3k5x3x2x1", 48 + "cid": "bafyre...", 49 + "value": map[string]any{ 50 + "$type": "app.bsky.feed.post", 51 + "text": "Hello world", 52 + }, 53 + }, 54 + }, 55 + "cursor": "next-cursor", 56 + } 57 + w.Header().Set("Content-Type", "application/json") 58 + json.NewEncoder(w).Encode(response) 59 + })) 60 + }, 61 + client: func() *atclient.APIClient { 62 + return &atclient.APIClient{} 63 + }(), 64 + collection: "app.bsky.feed.post", 65 + limit: 10, 66 + cursor: "", 67 + wantRecords: 1, 68 + wantCursor: "next-cursor", 69 + wantErr: false, 70 + }, 71 + { 72 + name: "nil client", 73 + setupServer: func() *httptest.Server { 74 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 75 + }, 76 + client: nil, 77 + collection: "app.bsky.feed.post", 78 + limit: 10, 79 + cursor: "", 80 + wantRecords: 0, 81 + wantErr: true, 82 + }, 83 + { 84 + name: "with limiter", 85 + setupServer: func() *httptest.Server { 86 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 87 + response := map[string]any{ 88 + "records": []any{}, 89 + } 90 + w.Header().Set("Content-Type", "application/json") 91 + json.NewEncoder(w).Encode(response) 92 + })) 93 + }, 94 + client: func() *atclient.APIClient { 95 + return &atclient.APIClient{} 96 + }(), 97 + limiter: NewRateLimiter(&mockKVStore{data: make(map[string]int)}, 1.0), 98 + collection: "app.bsky.feed.post", 99 + limit: 10, 100 + cursor: "", 101 + wantRecords: 0, 102 + wantErr: false, 103 + }, 104 + { 105 + name: "empty records", 106 + setupServer: func() *httptest.Server { 107 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 108 + w.Header().Set("Content-Type", "application/json") 109 + json.NewEncoder(w).Encode(map[string]any{ 110 + "records": []any{}, 111 + "cursor": "", 112 + }) 113 + })) 114 + }, 115 + client: func() *atclient.APIClient { 116 + return &atclient.APIClient{} 117 + }(), 118 + collection: "app.bsky.feed.post", 119 + limit: 10, 120 + cursor: "", 121 + wantRecords: 0, 122 + wantCursor: "", 123 + wantErr: false, 124 + }, 125 + { 126 + name: "4xx error", 127 + setupServer: func() *httptest.Server { 128 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 129 + w.WriteHeader(http.StatusBadRequest) 130 + })) 131 + }, 132 + client: func() *atclient.APIClient { 133 + return &atclient.APIClient{} 134 + }(), 135 + collection: "app.bsky.feed.post", 136 + limit: 10, 137 + cursor: "", 138 + wantRecords: 0, 139 + wantErr: true, 140 + }, 141 + { 142 + name: "5xx error", 143 + setupServer: func() *httptest.Server { 144 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 145 + w.WriteHeader(http.StatusInternalServerError) 146 + })) 147 + }, 148 + client: func() *atclient.APIClient { 149 + return &atclient.APIClient{} 150 + }(), 151 + collection: "app.bsky.feed.post", 152 + limit: 10, 153 + cursor: "", 154 + wantRecords: 0, 155 + wantErr: true, 156 + }, 157 + } 158 + 159 + for _, tt := range tests { 160 + t.Run(tt.name, func(t *testing.T) { 161 + t.Parallel() 162 + 163 + server := tt.setupServer() 164 + defer server.Close() 165 + 166 + if tt.client != nil && tt.client.Client == nil { 167 + tt.client.Client = server.Client() 168 + tt.client.Host = server.URL 169 + } 170 + 171 + rateClient := NewRateClient[map[string]any](tt.client, "did:plc:test", tt.limiter) 172 + ctx := context.Background() 173 + 174 + records, cursor, err := rateClient.ListRecords(ctx, tt.collection, tt.limit, tt.cursor) 175 + 176 + if tt.wantErr { 177 + if err == nil { 178 + t.Error("expected error, got nil") 179 + } 180 + return 181 + } 182 + 183 + if err != nil { 184 + t.Fatalf("ListRecords failed: %v", err) 185 + } 186 + 187 + if len(records) != tt.wantRecords { 188 + t.Errorf("ListRecords returned %d records, want %d", len(records), tt.wantRecords) 189 + } 190 + if tt.wantCursor != "" && cursor != tt.wantCursor { 191 + t.Errorf("ListRecords cursor = %s, want %s", cursor, tt.wantCursor) 192 + } 193 + }) 194 + } 195 + } 196 + 197 + func TestRateClient_ListRecords_NetworkTimeout(t *testing.T) { 198 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 199 + time.Sleep(100 * time.Millisecond) 200 + w.WriteHeader(http.StatusOK) 201 + })) 202 + defer server.Close() 203 + 204 + client := atclient.APIClient{ 205 + Client: &http.Client{Timeout: 10 * time.Millisecond}, 206 + Host: server.URL, 207 + } 208 + 209 + rateClient := NewRateClient[map[string]any](&client, "did:plc:test", nil) 210 + ctx := context.Background() 211 + 212 + _, _, err := rateClient.ListRecords(ctx, "app.bsky.feed.post", 10, "") 213 + 214 + if err == nil { 215 + t.Error("expected error for network timeout") 216 + } 217 + } 218 + 219 + func TestRateClient_ListRecords_ContextCancelled(t *testing.T) { 220 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 221 + select { 222 + case <-time.After(100 * time.Millisecond): 223 + w.WriteHeader(http.StatusOK) 224 + case <-r.Context().Done(): 225 + return 226 + } 227 + })) 228 + defer server.Close() 229 + 230 + client := atclient.APIClient{ 231 + Client: server.Client(), 232 + Host: server.URL, 233 + } 234 + 235 + rateClient := NewRateClient[map[string]any](&client, "did:plc:test", nil) 236 + ctx, cancel := context.WithCancel(context.Background()) 237 + cancel() 238 + 239 + _, _, err := rateClient.ListRecords(ctx, "app.bsky.feed.post", 10, "") 240 + 241 + if err == nil { 242 + t.Error("expected error for cancelled context") 243 + } 244 + } 245 + 246 + func TestRateClient_ListRecords_WithPagination(t *testing.T) { 247 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 248 + cursor := r.URL.Query().Get("cursor") 249 + if cursor == "page1" { 250 + w.Header().Set("Content-Type", "application/json") 251 + json.NewEncoder(w).Encode(map[string]any{ 252 + "records": []any{}, 253 + "cursor": "page2", 254 + }) 255 + } else { 256 + w.Header().Set("Content-Type", "application/json") 257 + json.NewEncoder(w).Encode(map[string]any{ 258 + "records": []any{ 259 + map[string]any{ 260 + "uri": "at://did:plc:test/app.bsky.feed.post/1", 261 + "cid": "bafyre...", 262 + "value": map[string]any{ 263 + "$type": "app.bsky.feed.post", 264 + "text": "test", 265 + }, 266 + }, 267 + }, 268 + "cursor": "page1", 269 + }) 270 + } 271 + })) 272 + defer server.Close() 273 + 274 + client := atclient.APIClient{ 275 + Client: server.Client(), 276 + Host: server.URL, 277 + } 278 + 279 + rateClient := NewRateClient[map[string]any](&client, "did:plc:test", nil) 280 + ctx := context.Background() 281 + 282 + records, cursor, err := rateClient.ListRecords(ctx, "app.bsky.feed.post", 10, "") 283 + if err != nil { 284 + t.Fatalf("ListRecords failed: %v", err) 285 + } 286 + if len(records) != 1 { 287 + t.Errorf("ListRecords returned %d records, want 1", len(records)) 288 + } 289 + if cursor != "page1" { 290 + t.Errorf("ListRecords cursor = %s, want page1", cursor) 291 + } 292 + } 293 + 294 + func TestRateClient_ApplyWrites(t *testing.T) { 295 + t.Parallel() 296 + 297 + tests := []struct { 298 + name string 299 + setupServer func() *httptest.Server 300 + client *atclient.APIClient 301 + records []map[string]any 302 + wantErr bool 303 + }{ 304 + { 305 + name: "success", 306 + setupServer: func() *httptest.Server { 307 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 308 + if r.URL.Path != "/xrpc/com.atproto.repo.applyWrites" { 309 + t.Errorf("unexpected path: %s", r.URL.Path) 310 + } 311 + w.WriteHeader(http.StatusOK) 312 + })) 313 + }, 314 + client: func() *atclient.APIClient { 315 + return &atclient.APIClient{} 316 + }(), 317 + records: []map[string]any{ 318 + {"$type": "app.bsky.feed.post", "text": "Hello"}, 319 + {"$type": "app.bsky.feed.post", "text": "World"}, 320 + }, 321 + wantErr: false, 322 + }, 323 + { 324 + name: "empty records", 325 + setupServer: func() *httptest.Server { 326 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 327 + }, 328 + client: &atclient.APIClient{}, 329 + records: []map[string]any{}, 330 + wantErr: false, 331 + }, 332 + { 333 + name: "nil client", 334 + setupServer: func() *httptest.Server { 335 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 336 + }, 337 + client: nil, 338 + records: []map[string]any{ 339 + {"$type": "app.bsky.feed.post", "text": "Hello"}, 340 + }, 341 + wantErr: true, 342 + }, 343 + { 344 + name: "too many records", 345 + setupServer: func() *httptest.Server { 346 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 347 + }, 348 + client: &atclient.APIClient{}, 349 + records: make([]map[string]any, 201), 350 + wantErr: true, 351 + }, 352 + { 353 + name: "4xx error", 354 + setupServer: func() *httptest.Server { 355 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 356 + w.WriteHeader(http.StatusBadRequest) 357 + })) 358 + }, 359 + client: func() *atclient.APIClient { 360 + return &atclient.APIClient{} 361 + }(), 362 + records: []map[string]any{ 363 + {"$type": "app.bsky.feed.post", "text": "Hello"}, 364 + }, 365 + wantErr: true, 366 + }, 367 + { 368 + name: "5xx error", 369 + setupServer: func() *httptest.Server { 370 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 371 + w.WriteHeader(http.StatusInternalServerError) 372 + })) 373 + }, 374 + client: func() *atclient.APIClient { 375 + return &atclient.APIClient{} 376 + }(), 377 + records: []map[string]any{ 378 + {"$type": "app.bsky.feed.post", "text": "Hello"}, 379 + }, 380 + wantErr: true, 381 + }, 382 + } 383 + 384 + for _, tt := range tests { 385 + t.Run(tt.name, func(t *testing.T) { 386 + t.Parallel() 387 + 388 + server := tt.setupServer() 389 + defer server.Close() 390 + 391 + if tt.client != nil && tt.client.Client == nil { 392 + tt.client.Client = server.Client() 393 + tt.client.Host = server.URL 394 + } 395 + 396 + rateClient := NewRateClient[map[string]any](tt.client, "did:plc:test", nil) 397 + ctx := context.Background() 398 + 399 + err := rateClient.ApplyWrites(ctx, "app.bsky.feed.post", tt.records) 400 + 401 + if tt.wantErr && err == nil { 402 + t.Error("expected error, got nil") 403 + } 404 + if !tt.wantErr && err != nil { 405 + t.Fatalf("ApplyWrites failed: %v", err) 406 + } 407 + }) 408 + } 409 + } 410 + 411 + func TestRateClient_ApplyWrites_NetworkTimeout(t *testing.T) { 412 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 413 + time.Sleep(100 * time.Millisecond) 414 + w.WriteHeader(http.StatusOK) 415 + })) 416 + defer server.Close() 417 + 418 + client := atclient.APIClient{ 419 + Client: &http.Client{Timeout: 10 * time.Millisecond}, 420 + Host: server.URL, 421 + } 422 + 423 + rateClient := NewRateClient[map[string]any](&client, "did:plc:test", nil) 424 + ctx := context.Background() 425 + 426 + records := []map[string]any{ 427 + {"$type": "app.bsky.feed.post", "text": "Hello"}, 428 + } 429 + 430 + err := rateClient.ApplyWrites(ctx, "app.bsky.feed.post", records) 431 + 432 + if err == nil { 433 + t.Error("expected error for network timeout") 434 + } 435 + } 436 + 437 + func TestRateClient_ApplyWrites_WithLimiter(t *testing.T) { 438 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 439 + w.WriteHeader(http.StatusOK) 440 + })) 441 + defer server.Close() 442 + 443 + client := atclient.APIClient{ 444 + Client: server.Client(), 445 + Host: server.URL, 446 + } 447 + 448 + mockKV := &mockKVStore{data: make(map[string]int)} 449 + limiter := NewRateLimiter(mockKV, 1.0) 450 + 451 + rateClient := NewRateClient[map[string]any](&client, "did:plc:test", limiter) 452 + ctx := context.Background() 453 + 454 + records := []map[string]any{ 455 + {"$type": "app.bsky.feed.post", "text": "Hello"}, 456 + } 457 + 458 + err := rateClient.ApplyWrites(ctx, "app.bsky.feed.post", records) 459 + if err != nil { 460 + t.Fatalf("ApplyWrites failed: %v", err) 461 + } 462 + } 463 + 464 + func TestRateClient_DeleteRecord(t *testing.T) { 465 + t.Parallel() 466 + 467 + tests := []struct { 468 + name string 469 + setupServer func() *httptest.Server 470 + client *atclient.APIClient 471 + wantErr bool 472 + }{ 473 + { 474 + name: "success", 475 + setupServer: func() *httptest.Server { 476 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 477 + if r.Method != "DELETE" { 478 + t.Errorf("unexpected method: %s", r.Method) 479 + } 480 + if r.URL.Path != "/xrpc/com.atproto.repo.deleteRecord" { 481 + t.Errorf("unexpected path: %s", r.URL.Path) 482 + } 483 + if r.URL.Query().Get("repo") != "did:plc:test" { 484 + t.Errorf("unexpected repo: %s", r.URL.Query().Get("repo")) 485 + } 486 + if r.URL.Query().Get("collection") != "app.bsky.feed.post" { 487 + t.Errorf("unexpected collection: %s", r.URL.Query().Get("collection")) 488 + } 489 + if r.URL.Query().Get("rkey") != "3k5x3x2x1" { 490 + t.Errorf("unexpected rkey: %s", r.URL.Query().Get("rkey")) 491 + } 492 + w.WriteHeader(http.StatusOK) 493 + })) 494 + }, 495 + client: func() *atclient.APIClient { 496 + return &atclient.APIClient{} 497 + }(), 498 + wantErr: false, 499 + }, 500 + { 501 + name: "nil client", 502 + setupServer: func() *httptest.Server { 503 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 504 + }, 505 + client: nil, 506 + wantErr: true, 507 + }, 508 + { 509 + name: "4xx error (skipped)", 510 + setupServer: func() *httptest.Server { 511 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 512 + w.WriteHeader(http.StatusBadRequest) 513 + })) 514 + }, 515 + client: func() *atclient.APIClient { 516 + return &atclient.APIClient{} 517 + }(), 518 + wantErr: true, 519 + }, 520 + { 521 + name: "5xx error (skipped)", 522 + setupServer: func() *httptest.Server { 523 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 524 + w.WriteHeader(http.StatusInternalServerError) 525 + })) 526 + }, 527 + client: func() *atclient.APIClient { 528 + return &atclient.APIClient{} 529 + }(), 530 + wantErr: true, 531 + }, 532 + } 533 + 534 + for _, tt := range tests { 535 + t.Run(tt.name, func(t *testing.T) { 536 + t.Parallel() 537 + 538 + server := tt.setupServer() 539 + defer server.Close() 540 + 541 + if tt.client != nil && tt.client.Client == nil { 542 + tt.client.Client = server.Client() 543 + tt.client.Host = server.URL 544 + } 545 + 546 + rateClient := NewRateClient[map[string]any](tt.client, "did:plc:test", nil) 547 + ctx := context.Background() 548 + 549 + err := rateClient.DeleteRecord(ctx, "app.bsky.feed.post", "3k5x3x2x1") 550 + 551 + if tt.wantErr && err == nil { 552 + t.Skip("atclient.Do may not return errors for HTTP status codes") 553 + } 554 + if !tt.wantErr && err != nil { 555 + t.Fatalf("DeleteRecord failed: %v", err) 556 + } 557 + }) 558 + } 559 + } 560 + 561 + func TestRateClient_DeleteRecord_NetworkTimeout(t *testing.T) { 562 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 563 + time.Sleep(100 * time.Millisecond) 564 + w.WriteHeader(http.StatusOK) 565 + })) 566 + defer server.Close() 567 + 568 + client := atclient.APIClient{ 569 + Client: &http.Client{Timeout: 10 * time.Millisecond}, 570 + Host: server.URL, 571 + } 572 + 573 + rateClient := NewRateClient[map[string]any](&client, "did:plc:test", nil) 574 + ctx := context.Background() 575 + 576 + err := rateClient.DeleteRecord(ctx, "app.bsky.feed.post", "3k5x3x2x1") 577 + 578 + if err == nil { 579 + t.Error("expected error for network timeout") 580 + } 581 + } 582 + 583 + func TestRepoClientFuncs(t *testing.T) { 584 + t.Parallel() 585 + 586 + tests := []struct { 587 + name string 588 + client *RepoClientFuncs[map[string]any] 589 + checkFunc func(t *testing.T, err error) 590 + }{ 591 + { 592 + name: "ListRecords", 593 + client: &RepoClientFuncs[map[string]any]{ 594 + ListRecordsFn: func(ctx context.Context, collection string, limit int, cursor string) ([]RecordRef[map[string]any], string, error) { 595 + return []RecordRef[map[string]any]{ 596 + {URI: "at://did:plc:test/app.bsky.feed.post/1", Value: map[string]any{"text": "test"}}, 597 + }, "cursor123", nil 598 + }, 599 + }, 600 + checkFunc: func(t *testing.T, err error) { 601 + if err != nil { 602 + t.Fatalf("ListRecords failed: %v", err) 603 + } 604 + }, 605 + }, 606 + { 607 + name: "ApplyWrites", 608 + client: &RepoClientFuncs[map[string]any]{ 609 + ApplyWritesFn: func(ctx context.Context, collection string, records []map[string]any) error { 610 + return errors.New("write failed") 611 + }, 612 + }, 613 + checkFunc: func(t *testing.T, err error) { 614 + if err == nil { 615 + t.Error("expected error, got nil") 616 + } 617 + }, 618 + }, 619 + { 620 + name: "DeleteRecord", 621 + client: &RepoClientFuncs[map[string]any]{ 622 + DeleteRecordFn: func(ctx context.Context, collection, rkey string) error { 623 + return nil 624 + }, 625 + }, 626 + checkFunc: func(t *testing.T, err error) { 627 + if err != nil { 628 + t.Errorf("DeleteRecord failed: %v", err) 629 + } 630 + }, 631 + }, 632 + } 633 + 634 + for _, tt := range tests { 635 + t.Run(tt.name, func(t *testing.T) { 636 + t.Parallel() 637 + ctx := context.Background() 638 + 639 + switch tt.name { 640 + case "ListRecords": 641 + _, _, err := tt.client.ListRecords(ctx, "app.bsky.feed.post", 10, "") 642 + tt.checkFunc(t, err) 643 + case "ApplyWrites": 644 + err := tt.client.ApplyWrites(ctx, "app.bsky.feed.post", []map[string]any{}) 645 + tt.checkFunc(t, err) 646 + case "DeleteRecord": 647 + err := tt.client.DeleteRecord(ctx, "app.bsky.feed.post", "3k5x3x2x1") 648 + tt.checkFunc(t, err) 649 + } 650 + }) 651 + } 652 + } 653 + 654 + func TestPrepareWrites(t *testing.T) { 655 + t.Parallel() 656 + 657 + tests := []struct { 658 + name string 659 + records []map[string]any 660 + collection string 661 + wantCount int 662 + wantErr bool 663 + checkFunc func(t *testing.T, writes []map[string]any) 664 + }{ 665 + { 666 + name: "empty records", 667 + records: []map[string]any{}, 668 + collection: "app.bsky.feed.post", 669 + wantCount: 0, 670 + wantErr: false, 671 + checkFunc: func(t *testing.T, writes []map[string]any) { 672 + if writes != nil { 673 + t.Errorf("prepareWrites returned %v, want nil", writes) 674 + } 675 + }, 676 + }, 677 + { 678 + name: "success", 679 + records: []map[string]any{ 680 + {"$type": "app.bsky.feed.post", "text": "Hello"}, 681 + {"$type": "app.bsky.feed.post", "text": "World"}, 682 + }, 683 + collection: "app.bsky.feed.post", 684 + wantCount: 2, 685 + wantErr: false, 686 + checkFunc: func(t *testing.T, writes []map[string]any) { 687 + for i, w := range writes { 688 + if w["$type"] != "com.atproto.repo.applyWrites#create" { 689 + t.Errorf("write[%d] $type = %s, want com.atproto.repo.applyWrites#create", i, w["$type"]) 690 + } 691 + if w["collection"] != "app.bsky.feed.post" { 692 + t.Errorf("write[%d] collection = %s, want app.bsky.feed.post", i, w["collection"]) 693 + } 694 + } 695 + }, 696 + }, 697 + { 698 + name: "invalid JSON", 699 + records: []map[string]any{ 700 + {"text": "test", "channel": make(chan int)}, 701 + }, 702 + collection: "app.bsky.feed.post", 703 + wantCount: 0, 704 + wantErr: true, 705 + checkFunc: func(t *testing.T, writes []map[string]any) {}, 706 + }, 707 + { 708 + name: "with JSON-encoded data", 709 + records: []map[string]any{ 710 + {"raw": `{"$type":"app.bsky.feed.post","text":"test"}`}, 711 + }, 712 + collection: "app.bsky.feed.post", 713 + wantCount: 1, 714 + wantErr: false, 715 + checkFunc: func(t *testing.T, writes []map[string]any) {}, 716 + }, 717 + } 718 + 719 + for _, tt := range tests { 720 + t.Run(tt.name, func(t *testing.T) { 721 + t.Parallel() 722 + 723 + writes, err := prepareWrites(tt.records, tt.collection) 724 + 725 + if tt.wantErr { 726 + if err == nil { 727 + t.Error("expected error, got nil") 728 + } 729 + return 730 + } 731 + 732 + if err != nil { 733 + t.Fatalf("prepareWrites failed: %v", err) 734 + } 735 + 736 + if len(writes) != tt.wantCount { 737 + t.Errorf("prepareWrites returned %d writes, want %d", len(writes), tt.wantCount) 738 + } 739 + 740 + if tt.checkFunc != nil { 741 + tt.checkFunc(t, writes) 742 + } 743 + }) 744 + } 745 + } 746 + 747 + func TestApplyWrites_Empty(t *testing.T) { 748 + err := applyWrites[map[string]any](context.Background(), nil, "did:plc:test", "app.bsky.feed.post", []map[string]any{}) 749 + if err != nil { 750 + t.Errorf("applyWrites with empty records should not fail: %v", err) 751 + } 752 + } 753 + 754 + func TestIsTransientError(t *testing.T) { 755 + t.Parallel() 756 + 757 + tests := []struct { 758 + name string 759 + statusCode int 760 + expectError bool 761 + }{ 762 + {"TooManyRequests", 429, true}, 763 + {"InternalServerError", 500, true}, 764 + {"BadGateway", 502, true}, 765 + {"ServiceUnavailable", 503, true}, 766 + {"GatewayTimeout", 504, true}, 767 + {"OK", 200, false}, 768 + {"Created", 201, false}, 769 + {"BadRequest", 400, false}, 770 + {"Unauthorized", 401, false}, 771 + {"NotFound", 404, false}, 772 + } 773 + 774 + for _, tt := range tests { 775 + t.Run(tt.name, func(t *testing.T) { 776 + t.Parallel() 777 + 778 + apiErr := &atclient.APIError{StatusCode: tt.statusCode} 779 + result := IsTransientError(apiErr) 780 + 781 + if result != tt.expectError { 782 + t.Errorf("IsTransientError(status=%d) = %v, want %v", tt.statusCode, result, tt.expectError) 783 + } 784 + }) 785 + } 786 + } 787 + 788 + func TestIsTransientError_NetError(t *testing.T) { 789 + t.Parallel() 790 + 791 + tests := []struct { 792 + name string 793 + timeout bool 794 + wantError bool 795 + }{ 796 + {"timeout error", true, true}, 797 + {"non-timeout error", false, false}, 798 + } 799 + 800 + for _, tt := range tests { 801 + t.Run(tt.name, func(t *testing.T) { 802 + t.Parallel() 803 + 804 + err := &netError{timeout: tt.timeout} 805 + result := IsTransientError(err) 806 + 807 + if result != tt.wantError { 808 + t.Errorf("IsTransientError(timeout=%v) = %v, want %v", tt.timeout, result, tt.wantError) 809 + } 810 + }) 811 + } 812 + } 813 + 814 + func TestIsTransientError_Nil(t *testing.T) { 815 + result := IsTransientError(nil) 816 + 817 + if result { 818 + t.Error("IsTransientError(nil) should return false") 819 + } 820 + } 821 + 822 + func TestRecordRef(t *testing.T) { 823 + ref := RecordRef[string]{ 824 + URI: "at://did:plc:test/app.bsky.feed.post/1", 825 + CID: "bafyre...", 826 + Value: "test value", 827 + } 828 + 829 + if ref.URI != "at://did:plc:test/app.bsky.feed.post/1" { 830 + t.Errorf("RecordRef.URI = %s, want at://did:plc:test/app.bsky.feed.post/1", ref.URI) 831 + } 832 + if ref.CID != "bafyre..." { 833 + t.Errorf("RecordRef.CID = %s, want bafyre...", ref.CID) 834 + } 835 + if ref.Value != "test value" { 836 + t.Errorf("RecordRef.Value = %s, want test value", ref.Value) 837 + } 838 + } 839 + 840 + func TestNewRateClient(t *testing.T) { 841 + t.Parallel() 842 + 843 + tests := []struct { 844 + name string 845 + client *atclient.APIClient 846 + did string 847 + limiter RateLimiter 848 + check func(t *testing.T, rateClient *RateClient[map[string]any]) 849 + }{ 850 + { 851 + name: "nil limiter", 852 + client: &atclient.APIClient{}, 853 + did: "did:plc:test", 854 + limiter: nil, 855 + check: func(t *testing.T, rc *RateClient[map[string]any]) { 856 + if rc.limiter != nil { 857 + t.Error("NewRateClient with nil limiter should have nil limiter") 858 + } 859 + }, 860 + }, 861 + { 862 + name: "with limiter", 863 + client: &atclient.APIClient{}, 864 + did: "did:plc:test", 865 + limiter: NewRateLimiter(&mockKVStore{data: make(map[string]int)}, 1.0), 866 + check: func(t *testing.T, rc *RateClient[map[string]any]) { 867 + if rc.limiter == nil { 868 + t.Error("NewRateClient should store limiter") 869 + } 870 + }, 871 + }, 872 + } 873 + 874 + for _, tt := range tests { 875 + t.Run(tt.name, func(t *testing.T) { 876 + t.Parallel() 877 + 878 + rateClient := NewRateClient[map[string]any](tt.client, tt.did, tt.limiter) 879 + 880 + if rateClient.client != tt.client { 881 + t.Error("NewRateClient should store client") 882 + } 883 + if rateClient.did != tt.did { 884 + t.Error("NewRateClient should store did") 885 + } 886 + if tt.check != nil { 887 + tt.check(t, rateClient) 888 + } 889 + }) 890 + } 891 + } 892 + 893 + func TestRepoClient_Interfaces(t *testing.T) { 894 + var _ RepoClient[map[string]any] = (*RateClient[map[string]any])(nil) 895 + var _ RepoClient[map[string]any] = (*RepoClientFuncs[map[string]any])(nil) 896 + } 897 + 898 + func TestRateClient_ListRecords_WithTypedRecords(t *testing.T) { 899 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 900 + type testRecord struct { 901 + Text string `json:"text"` 902 + } 903 + records := []struct { 904 + URI string `json:"uri"` 905 + CID string `json:"cid"` 906 + Value testRecord `json:"value"` 907 + }{ 908 + { 909 + URI: "at://did:plc:test/app.bsky.feed.post/1", 910 + CID: "bafyre...", 911 + Value: testRecord{ 912 + Text: "Hello world", 913 + }, 914 + }, 915 + } 916 + w.Header().Set("Content-Type", "application/json") 917 + json.NewEncoder(w).Encode(map[string]any{ 918 + "records": records, 919 + }) 920 + })) 921 + defer server.Close() 922 + 923 + client := atclient.APIClient{ 924 + Client: server.Client(), 925 + Host: server.URL, 926 + } 927 + 928 + type testRecord struct { 929 + Text string `json:"text"` 930 + } 931 + 932 + rateClient := NewRateClient[testRecord](&client, "did:plc:test", nil) 933 + ctx := context.Background() 934 + 935 + records, _, err := rateClient.ListRecords(ctx, "app.bsky.feed.post", 10, "") 936 + if err != nil { 937 + t.Fatalf("ListRecords failed: %v", err) 938 + } 939 + if len(records) != 1 { 940 + t.Fatalf("ListRecords returned %d records, want 1", len(records)) 941 + } 942 + if records[0].Value.Text != "Hello world" { 943 + t.Errorf("Record value text = %s, want 'Hello world'", records[0].Value.Text) 944 + } 945 + }
+70
atproto/testmock.go
··· 1 + package atproto 2 + 3 + import ( 4 + "sync" 5 + "time" 6 + ) 7 + 8 + type mockKVStore struct { 9 + mu sync.Mutex 10 + data map[string]int 11 + gets [][]string 12 + incrs []map[string]int 13 + } 14 + 15 + func (m *mockKVStore) GetMulti(keys []string) (map[string]int, error) { 16 + m.mu.Lock() 17 + defer m.mu.Unlock() 18 + m.gets = append(m.gets, keys) 19 + result := make(map[string]int) 20 + for _, k := range keys { 21 + result[k] = m.data[k] 22 + } 23 + return result, nil 24 + } 25 + 26 + func (m *mockKVStore) IncrByMulti(counts map[string]int) error { 27 + m.mu.Lock() 28 + defer m.mu.Unlock() 29 + m.incrs = append(m.incrs, counts) 30 + for k, v := range counts { 31 + m.data[k] += v 32 + } 33 + return nil 34 + } 35 + 36 + type mockClock struct { 37 + mu sync.Mutex 38 + now time.Time 39 + nows []time.Time 40 + } 41 + 42 + func (m *mockClock) Now() time.Time { 43 + m.mu.Lock() 44 + defer m.mu.Unlock() 45 + m.nows = append(m.nows, m.now) 46 + return m.now 47 + } 48 + 49 + type mockKVStoreWithErr struct { 50 + mockKVStore 51 + Data map[string]int 52 + Err error 53 + } 54 + 55 + func (m *mockKVStoreWithErr) GetMulti(keys []string) (map[string]int, error) { 56 + return nil, m.Err 57 + } 58 + 59 + func (m *mockKVStoreWithErr) IncrByMulti(counts map[string]int) error { 60 + return m.Err 61 + } 62 + 63 + type netError struct { 64 + timeout bool 65 + temporary bool 66 + } 67 + 68 + func (e *netError) Error() string { return "net error" } 69 + func (e *netError) Timeout() bool { return e.timeout } 70 + func (e *netError) Temporary() bool { return e.temporary }
-164
flags.go
··· 1 - package main 2 - 3 - import ( 4 - "github.com/urfave/cli/v3" 5 - 6 - "tangled.org/karitham.dev/lazuli/sync" 7 - ) 8 - 9 - const ( 10 - DefaultBatchSize = 20 11 - ) 12 - 13 - const ( 14 - EnvHandle = "LAZULI_HANDLE" 15 - EnvPassword = "LAZULI_PASSWORD" 16 - EnvVerbose = "LAZULI_VERBOSE" 17 - EnvQuiet = "LAZULI_QUIET" 18 - EnvReverse = "LAZULI_REVERSE" 19 - EnvDryRun = "LAZULI_DRY_RUN" 20 - EnvFresh = "LAZULI_FRESH" 21 - EnvClearCache = "LAZULI_CLEAR_CACHE" 22 - EnvYes = "LAZULI_YES" 23 - ) 24 - 25 - var lastfmFlag = &cli.StringFlag{ 26 - Name: "lastfm", 27 - Usage: "Path to Last.fm CSV file or directory", 28 - Sources: cli.EnvVars("LAZULI_LASTFM"), 29 - } 30 - 31 - var spotifyFlag = &cli.StringFlag{ 32 - Name: "spotify", 33 - Usage: "Path to Spotify JSON/directory/zip", 34 - Sources: cli.EnvVars("LAZULI_SPOTIFY"), 35 - } 36 - 37 - var ( 38 - verboseCount int 39 - quietCount int 40 - ) 41 - 42 - var commonFlags = []cli.Flag{ 43 - &cli.StringFlag{ 44 - Name: "handle", 45 - Usage: "Bluesky handle", 46 - Sources: cli.EnvVars(EnvHandle), 47 - }, 48 - &cli.StringFlag{ 49 - Name: "password", 50 - Usage: "App password", 51 - Sources: cli.EnvVars(EnvPassword), 52 - }, 53 - &cli.BoolFlag{ 54 - Name: "verbose", 55 - Usage: "Enable verbose logging (-v for debug, -vv for trace)", 56 - Aliases: []string{"v"}, 57 - Sources: cli.EnvVars(EnvVerbose), 58 - Config: cli.BoolConfig{Count: &verboseCount}, 59 - }, 60 - &cli.BoolFlag{ 61 - Name: "quiet", 62 - Usage: "Suppress non-essential output (-q for warn, -qq for errors, -qqq for silent)", 63 - Aliases: []string{"q"}, 64 - Sources: cli.EnvVars(EnvQuiet), 65 - Config: cli.BoolConfig{Count: &quietCount}, 66 - }, 67 - &cli.StringFlag{ 68 - Name: "output-format", 69 - Usage: "Output format: text or json", 70 - Value: "text", 71 - Sources: cli.EnvVars("LAZULI_OUTPUT_FORMAT"), 72 - }, 73 - } 74 - 75 - var exportFlags = []cli.Flag{ 76 - lastfmFlag, 77 - spotifyFlag, 78 - &cli.StringFlag{ 79 - Name: "output", 80 - Usage: "Output file (stdout if not set)", 81 - Sources: cli.EnvVars("LAZULI_OUTPUT"), 82 - }, 83 - &cli.BoolFlag{ 84 - Name: "reverse", 85 - Usage: "Sort records reverse chronologically", 86 - Sources: cli.EnvVars(EnvReverse), 87 - }, 88 - &cli.DurationFlag{ 89 - Name: "tolerance", 90 - Usage: "Time tolerance for cross-source deduplication (e.g., 5m, 10m)", 91 - Value: sync.DefaultCrossSourceTolerance, 92 - Sources: cli.EnvVars("LAZULI_TOLERANCE"), 93 - }, 94 - } 95 - 96 - var importFlags = []cli.Flag{ 97 - lastfmFlag, 98 - spotifyFlag, 99 - &cli.StringFlag{ 100 - Name: "mode", 101 - Usage: "Import mode: lastfm, spotify, combined (default: combined)", 102 - Value: "combined", 103 - Sources: cli.EnvVars("LAZULI_MODE"), 104 - }, 105 - &cli.BoolFlag{ 106 - Name: "dry-run", 107 - Usage: "Preview without publishing", 108 - Sources: cli.EnvVars(EnvDryRun), 109 - }, 110 - &cli.BoolFlag{ 111 - Name: "reverse", 112 - Usage: "Import in reverse order", 113 - Sources: cli.EnvVars(EnvReverse), 114 - }, 115 - &cli.BoolFlag{ 116 - Name: "fresh", 117 - Usage: "Don't use cached Bluesky records", 118 - Sources: cli.EnvVars(EnvFresh), 119 - }, 120 - &cli.BoolFlag{ 121 - Name: "clear-cache", 122 - Usage: "Clear cache before running", 123 - Sources: cli.EnvVars(EnvClearCache), 124 - }, 125 - &cli.IntFlag{ 126 - Name: "batch-size", 127 - Usage: "Records per batch (default: 20)", 128 - Value: DefaultBatchSize, 129 - Sources: cli.EnvVars("LAZULI_BATCH_SIZE"), 130 - }, 131 - &cli.DurationFlag{ 132 - Name: "tolerance", 133 - Usage: "Time tolerance for cross-source deduplication (e.g., 5m, 10m)", 134 - Value: sync.DefaultCrossSourceTolerance, 135 - Sources: cli.EnvVars("LAZULI_TOLERANCE"), 136 - }, 137 - } 138 - 139 - var syncFlags = []cli.Flag{ 140 - &cli.BoolFlag{ 141 - Name: "fresh", 142 - Usage: "Force refresh cache", 143 - Sources: cli.EnvVars(EnvFresh), 144 - }, 145 - } 146 - 147 - var dedupeFlags = []cli.Flag{ 148 - &cli.BoolFlag{ 149 - Name: "dry-run", 150 - Usage: "Preview without deleting", 151 - Sources: cli.EnvVars(EnvDryRun), 152 - }, 153 - &cli.BoolFlag{ 154 - Name: "fresh", 155 - Usage: "Force refresh cache", 156 - Sources: cli.EnvVars(EnvFresh), 157 - }, 158 - &cli.BoolFlag{ 159 - Name: "yes", 160 - Usage: "Skip confirmation prompt", 161 - Aliases: []string{"y"}, 162 - Sources: cli.EnvVars(EnvYes), 163 - }, 164 - }
+225 -80
main.go
··· 15 15 "tangled.org/karitham.dev/lazuli/sources/lastfm" 16 16 "tangled.org/karitham.dev/lazuli/sources/spotify" 17 17 "tangled.org/karitham.dev/lazuli/sync" 18 - "tangled.org/karitham.dev/lazuli/sync/logutil" 19 18 20 19 "github.com/failsafe-go/failsafe-go" 21 20 "github.com/failsafe-go/failsafe-go/retrypolicy" ··· 24 23 25 24 var Version = "dev" 26 25 26 + const ( 27 + DefaultBatchSize = 20 28 + ) 29 + 30 + const ( 31 + EnvHandle = "LAZULI_HANDLE" 32 + EnvPassword = "LAZULI_PASSWORD" 33 + EnvVerbose = "LAZULI_VERBOSE" 34 + EnvQuiet = "LAZULI_QUIET" 35 + EnvReverse = "LAZULI_REVERSE" 36 + EnvDryRun = "LAZULI_DRY_RUN" 37 + EnvFresh = "LAZULI_FRESH" 38 + EnvClearCache = "LAZULI_CLEAR_CACHE" 39 + EnvYes = "LAZULI_YES" 40 + ) 41 + 42 + var ( 43 + verboseCount int 44 + quietCount int 45 + ) 46 + 27 47 type App struct { 28 48 log *slog.Logger 29 49 outputFormat string ··· 40 60 } 41 61 42 62 func run() error { 43 - sync.UserAgent = "lazuli/" + Version 44 63 sync.ClientAgent = "lazuli/" + Version 45 64 46 65 storage, err := cache.NewBoltStorage() ··· 154 173 } 155 174 } 156 175 176 + func (a *App) failedCommand() *cli.Command { 177 + flags := make([]cli.Flag, 0, len(commonFlags)) 178 + flags = append(flags, commonFlags...) 179 + return &cli.Command{ 180 + Name: "failed", 181 + Usage: "List records that failed to publish", 182 + UsageText: " lazuli failed --handle=user.bsky.social", 183 + Flags: flags, 184 + Action: a.runFailed, 185 + Before: a.initLoggerBefore, 186 + } 187 + } 188 + 189 + func (a *App) retryCommand() *cli.Command { 190 + flags := make([]cli.Flag, 0, len(commonFlags)+1) 191 + flags = append(flags, commonFlags...) 192 + flags = append(flags, &cli.BoolFlag{ 193 + Name: "dry-run", 194 + Usage: "Preview what will be retried", 195 + Sources: cli.EnvVars(EnvDryRun), 196 + }) 197 + return &cli.Command{ 198 + Name: "retry", 199 + Usage: "Retry failed records one by one", 200 + UsageText: " lazuli retry --handle=user.bsky.social", 201 + Flags: flags, 202 + Action: a.runRetry, 203 + Before: a.initLoggerBefore, 204 + } 205 + } 206 + 207 + func (a *App) versionCommand() *cli.Command { 208 + return &cli.Command{ 209 + Name: "version", 210 + Usage: "Print the version number", 211 + Action: func(ctx context.Context, cmd *cli.Command) error { 212 + fmt.Println(Version) 213 + return nil 214 + }, 215 + } 216 + } 217 + 157 218 func (a *App) runStats(ctx context.Context, cmd *cli.Command) error { 158 219 stats, err := a.storage.Stats() 159 220 if err != nil { ··· 208 269 return nil 209 270 } 210 271 211 - func (a *App) failedCommand() *cli.Command { 212 - flags := make([]cli.Flag, 0, len(commonFlags)) 213 - flags = append(flags, commonFlags...) 214 - return &cli.Command{ 215 - Name: "failed", 216 - Usage: "List records that failed to publish", 217 - UsageText: " lazuli failed --handle=user.bsky.social", 218 - Flags: flags, 219 - Action: a.runFailed, 220 - Before: a.initLoggerBefore, 221 - } 222 - } 223 - 224 - func (a *App) retryCommand() *cli.Command { 225 - flags := make([]cli.Flag, 0, len(commonFlags)+1) 226 - flags = append(flags, commonFlags...) 227 - flags = append(flags, &cli.BoolFlag{ 228 - Name: "dry-run", 229 - Usage: "Preview what will be retried", 230 - Sources: cli.EnvVars(EnvDryRun), 231 - }) 232 - return &cli.Command{ 233 - Name: "retry", 234 - Usage: "Retry failed records one by one", 235 - UsageText: " lazuli retry --handle=user.bsky.social", 236 - Flags: flags, 237 - Action: a.runRetry, 238 - Before: a.initLoggerBefore, 239 - } 240 - } 241 - 242 272 func (a *App) runRetry(ctx context.Context, cmd *cli.Command) error { 243 273 authClient, err := a.prepareAuth(ctx, cmd) 244 274 if err != nil { 245 275 return err 246 276 } 247 - did := authClient.GetDID() 277 + did := authClient.DID() 248 278 dryRun := cmd.Bool("dry-run") 249 279 250 280 limiter := sync.NewRateLimiter(a.storage, 0.9) 251 - repoClient := sync.NewRateClient(authClient.GetAPIClient(), did, limiter) 281 + repoClient := sync.NewRateClient(authClient.APIClient(), did, limiter) 252 282 253 283 var failedRecords []struct { 254 284 key string ··· 287 317 continue 288 318 } 289 319 290 - // Check rate limit for 1 write 291 - 292 320 res := sync.PublishBatch(ctx, repoClient, did, []sync.PlayRecord{fr.rec}, a.storage) 293 321 294 322 if res == nil { 295 323 fmt.Printf("Successfully retried: %s - %s\n", fr.rec.ArtistName(), fr.rec.TrackName) 296 - // Mark as published (updates processedBucket to 1) 297 324 if err := a.storage.MarkPublished(did, fr.key); err != nil { 298 - a.log.Error("Failed to mark record as published", logutil.Error(err), slog.String("key", fr.key)) 325 + a.log.Error("Failed to mark record as published", sync.Error(err), slog.String("key", fr.key)) 299 326 } 300 - // Remove from failedBucket 301 327 if err := a.storage.RemoveFailed(did, fr.key); err != nil { 302 - a.log.Error("Failed to remove record from failed list", logutil.Error(err), slog.String("key", fr.key)) 328 + a.log.Error("Failed to remove record from failed list", sync.Error(err), slog.String("key", fr.key)) 303 329 } 304 330 successCount++ 305 331 } else { 306 332 fmt.Printf("Failed again: %s - %s: %v\n", fr.rec.ArtistName(), fr.rec.TrackName, res) 307 333 errorCount++ 308 334 } 309 - 310 - // Optional: small delay between retries? 311 - // The rate limiter already handles the delay. 312 335 } 313 336 314 337 fmt.Printf("\nRetry complete: %d succeeded, %d failed.\n", successCount, errorCount) ··· 316 339 } 317 340 318 341 func (a *App) runFailed(ctx context.Context, cmd *cli.Command) error { 319 - // We need the DID to look up the user's failed records. 320 - // If it's already in the handle, use it, otherwise we might need to resolve it. 321 - // For simplicity, we'll try to get it from auth or resolve it. 322 342 authClient, err := a.prepareAuth(ctx, cmd) 323 343 if err != nil { 324 344 return err 325 345 } 326 - did := authClient.GetDID() 346 + did := authClient.DID() 327 347 328 348 type FailedRecord struct { 329 349 Key string `json:"key"` ··· 369 389 return nil 370 390 } 371 391 372 - func (a *App) versionCommand() *cli.Command { 373 - return &cli.Command{ 374 - Name: "version", 375 - Usage: "Print the version number", 376 - Action: func(ctx context.Context, cmd *cli.Command) error { 377 - fmt.Println(Version) 378 - return nil 379 - }, 380 - } 381 - } 382 - 383 392 func (a *App) runDebugFetch(ctx context.Context, cmd *cli.Command) error { 384 393 authClient, err := a.prepareAuth(ctx, cmd) 385 394 if err != nil { 386 395 return fmt.Errorf("authentication failed: %w", err) 387 396 } 388 397 389 - repoClient := sync.NewRateClient(authClient.GetAPIClient(), authClient.GetDID(), nil) 398 + repoClient := sync.NewRateClient(authClient.APIClient(), authClient.DID(), nil) 390 399 391 400 records, _, err := repoClient.ListRecords(ctx, sync.RecordType, 10, "") 392 401 if err != nil { ··· 495 504 return err 496 505 } 497 506 498 - a.log.Info("Starting import operation", logutil.DID(handle)) 507 + a.log.Info("Starting import operation", sync.DID(handle)) 499 508 500 509 lastfmPath := cmd.String("lastfm") 501 510 spotifyPath := cmd.String("spotify") ··· 509 518 510 519 if clearCache { 511 520 if err := a.storage.ClearAll(); err != nil { 512 - a.log.Error("Failed to clear cache", logutil.Error(err)) 521 + a.log.Error("Failed to clear cache", sync.Error(err)) 513 522 } else { 514 523 a.log.Info("Cache cleared") 515 524 } ··· 549 558 if err != nil { 550 559 return fmt.Errorf("create auth client: %w", err) 551 560 } 552 - a.log.Info("Authenticated", logutil.DID(authClient.GetDID()), slog.String("pds", authClient.GetPDS())) 561 + a.log.Info("Authenticated", sync.DID(authClient.DID()), slog.String("pds", authClient.PDS())) 553 562 554 563 limiter := sync.NewRateLimiter(a.storage, 0.9) 555 - repoClient := sync.NewRateClient(authClient.GetAPIClient(), authClient.GetDID(), limiter) 564 + repoClient := sync.NewRateClient(authClient.APIClient(), authClient.DID(), limiter) 556 565 557 - existingRecords, err := sync.FetchExisting(ctx, repoClient, authClient.GetDID(), a.storage, fresh) 566 + existingRecords, err := sync.FetchExisting(ctx, repoClient, authClient.DID(), a.storage, fresh) 558 567 if err != nil { 559 568 return fmt.Errorf("fetch existing records: %w", err) 560 569 } 561 570 a.log.Info("Fetched existing records", slog.Int("count", len(existingRecords))) 562 571 563 - published, _ := a.storage.GetPublished(authClient.GetDID()) 572 + published, _ := a.storage.GetPublished(authClient.DID()) 564 573 newRecords := sync.FilterNew(records, existingRecords, published) 565 574 skippedCount := len(records) - len(newRecords) 566 575 a.log.Info("Filtered to new records", ··· 584 593 value, _ := json.Marshal(rec) 585 594 newEntries[key] = value 586 595 } 587 - if err := a.storage.SaveRecords(authClient.GetDID(), newEntries); err != nil { 596 + if err := a.storage.SaveRecords(authClient.DID(), newEntries); err != nil { 588 597 return fmt.Errorf("save new records to storage: %w", err) 589 598 } 590 599 } ··· 663 672 } 664 673 665 674 fresh := cmd.Bool("fresh") 666 - a.log.Info("Starting sync operation", logutil.DID(authClient.GetDID()), slog.Bool("fresh", fresh)) 675 + a.log.Info("Starting sync operation", sync.DID(authClient.DID()), slog.Bool("fresh", fresh)) 667 676 668 677 limiter := sync.NewRateLimiter(a.storage, 0.85) 669 - repoClient := sync.NewRateClient(authClient.GetAPIClient(), authClient.GetDID(), limiter) 678 + repoClient := sync.NewRateClient(authClient.APIClient(), authClient.DID(), limiter) 670 679 671 680 if fresh { 672 - if err := a.storage.Clear(authClient.GetDID()); err != nil { 673 - a.log.Error("Failed to clear cache", logutil.Error(err)) 681 + if err := a.storage.Clear(authClient.DID()); err != nil { 682 + a.log.Error("Failed to clear cache", sync.Error(err)) 674 683 } else { 675 684 a.log.Info("Cache cleared") 676 685 } 677 686 } 678 687 679 - existingRecords, err := sync.FetchExisting(ctx, repoClient, authClient.GetDID(), a.storage, fresh) 688 + existingRecords, err := sync.FetchExisting(ctx, repoClient, authClient.DID(), a.storage, fresh) 680 689 if err != nil { 681 690 return fmt.Errorf("fetch existing records: %w", err) 682 691 } ··· 696 705 fresh := cmd.Bool("fresh") 697 706 yes := cmd.Bool("yes") 698 707 a.log.Info("Starting dedupe operation", 699 - logutil.DID(authClient.GetDID()), 708 + sync.DID(authClient.DID()), 700 709 slog.Bool("dry_run", dryRun), 701 710 slog.Bool("fresh", fresh)) 702 711 703 712 limiter := sync.NewRateLimiter(a.storage, 0.9) 704 - repoClient := sync.NewRateClient(authClient.GetAPIClient(), authClient.GetDID(), limiter) 713 + repoClient := sync.NewRateClient(authClient.APIClient(), authClient.DID(), limiter) 705 714 706 715 if fresh { 707 - if err := a.storage.Clear(authClient.GetDID()); err != nil { 708 - a.log.Error("Failed to clear cache", logutil.Error(err)) 716 + if err := a.storage.Clear(authClient.DID()); err != nil { 717 + a.log.Error("Failed to clear cache", sync.Error(err)) 709 718 } else { 710 719 a.log.Info("Cache cleared") 711 720 } 712 721 } 713 722 714 - existingRecords, err := sync.FetchExisting(ctx, repoClient, authClient.GetDID(), a.storage, fresh) 723 + existingRecords, err := sync.FetchExisting(ctx, repoClient, authClient.DID(), a.storage, fresh) 715 724 if err != nil { 716 725 return fmt.Errorf("failed to fetch existing records: %w", err) 717 726 } ··· 773 782 OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[any]) { 774 783 a.log.Warn("Delete failed with transient error, retrying", 775 784 slog.Duration("retryDelay", e.Delay), 776 - logutil.Error(e.LastError()), 785 + sync.Error(e.LastError()), 777 786 slog.Int("attempt", e.Attempts()), 778 787 slog.String("uri", uri)) 779 788 }). ··· 784 793 }) 785 794 786 795 if err != nil { 787 - a.log.Error("Failed to delete record", logutil.Error(err), slog.String("uri", uri)) 796 + a.log.Error("Failed to delete record", sync.Error(err), slog.String("uri", uri)) 788 797 } else { 789 798 a.log.Info("Deleted duplicate", slog.String("uri", uri), slog.String("track", rec.Value.TrackName)) 790 799 } 791 800 } 792 801 } 793 802 794 - if err := a.storage.Clear(authClient.GetDID()); err != nil { 803 + if err := a.storage.Clear(authClient.DID()); err != nil { 795 804 a.log.Error("Failed to clear cache", "err", err) 796 805 } 797 806 ··· 818 827 819 828 return nil 820 829 } 830 + 831 + var commonFlags = []cli.Flag{ 832 + &cli.StringFlag{ 833 + Name: "handle", 834 + Usage: "Bluesky handle", 835 + Sources: cli.EnvVars(EnvHandle), 836 + }, 837 + &cli.StringFlag{ 838 + Name: "password", 839 + Usage: "App password", 840 + Sources: cli.EnvVars(EnvPassword), 841 + }, 842 + &cli.BoolFlag{ 843 + Name: "verbose", 844 + Usage: "Enable verbose logging (-v for debug, -vv for trace)", 845 + Aliases: []string{"v"}, 846 + Sources: cli.EnvVars(EnvVerbose), 847 + Config: cli.BoolConfig{Count: &verboseCount}, 848 + }, 849 + &cli.BoolFlag{ 850 + Name: "quiet", 851 + Usage: "Suppress non-essential output (-q for warn, -qq for errors, -qqq for silent)", 852 + Aliases: []string{"q"}, 853 + Sources: cli.EnvVars(EnvQuiet), 854 + Config: cli.BoolConfig{Count: &quietCount}, 855 + }, 856 + &cli.StringFlag{ 857 + Name: "output-format", 858 + Usage: "Output format: text or json", 859 + Value: "text", 860 + Sources: cli.EnvVars("LAZULI_OUTPUT_FORMAT"), 861 + }, 862 + } 863 + 864 + var lastfmFlag = &cli.StringFlag{ 865 + Name: "lastfm", 866 + Usage: "Path to Last.fm CSV file or directory", 867 + Sources: cli.EnvVars("LAZULI_LASTFM"), 868 + } 869 + 870 + var spotifyFlag = &cli.StringFlag{ 871 + Name: "spotify", 872 + Usage: "Path to Spotify JSON/directory/zip", 873 + Sources: cli.EnvVars("LAZULI_SPOTIFY"), 874 + } 875 + 876 + var exportFlags = []cli.Flag{ 877 + lastfmFlag, 878 + spotifyFlag, 879 + &cli.StringFlag{ 880 + Name: "output", 881 + Usage: "Output file (stdout if not set)", 882 + Sources: cli.EnvVars("LAZULI_OUTPUT"), 883 + }, 884 + &cli.BoolFlag{ 885 + Name: "reverse", 886 + Usage: "Sort records reverse chronologically", 887 + Sources: cli.EnvVars(EnvReverse), 888 + }, 889 + &cli.DurationFlag{ 890 + Name: "tolerance", 891 + Usage: "Time tolerance for cross-source deduplication (e.g., 5m, 10m)", 892 + Value: sync.DefaultCrossSourceTolerance, 893 + Sources: cli.EnvVars("LAZULI_TOLERANCE"), 894 + }, 895 + } 896 + 897 + var importFlags = []cli.Flag{ 898 + lastfmFlag, 899 + spotifyFlag, 900 + &cli.StringFlag{ 901 + Name: "mode", 902 + Usage: "Import mode: lastfm, spotify, combined (default: combined)", 903 + Value: "combined", 904 + Sources: cli.EnvVars("LAZULI_MODE"), 905 + }, 906 + &cli.BoolFlag{ 907 + Name: "dry-run", 908 + Usage: "Preview without publishing", 909 + Sources: cli.EnvVars(EnvDryRun), 910 + }, 911 + &cli.BoolFlag{ 912 + Name: "reverse", 913 + Usage: "Import in reverse order", 914 + Sources: cli.EnvVars(EnvReverse), 915 + }, 916 + &cli.BoolFlag{ 917 + Name: "fresh", 918 + Usage: "Don't use cached Bluesky records", 919 + Sources: cli.EnvVars(EnvFresh), 920 + }, 921 + &cli.BoolFlag{ 922 + Name: "clear-cache", 923 + Usage: "Clear cache before running", 924 + Sources: cli.EnvVars(EnvClearCache), 925 + }, 926 + &cli.IntFlag{ 927 + Name: "batch-size", 928 + Usage: "Records per batch (default: 20)", 929 + Value: DefaultBatchSize, 930 + Sources: cli.EnvVars("LAZULI_BATCH_SIZE"), 931 + }, 932 + &cli.DurationFlag{ 933 + Name: "tolerance", 934 + Usage: "Time tolerance for cross-source deduplication (e.g., 5m, 10m)", 935 + Value: sync.DefaultCrossSourceTolerance, 936 + Sources: cli.EnvVars("LAZULI_TOLERANCE"), 937 + }, 938 + } 939 + 940 + var syncFlags = []cli.Flag{ 941 + &cli.BoolFlag{ 942 + Name: "fresh", 943 + Usage: "Force refresh cache", 944 + Sources: cli.EnvVars(EnvFresh), 945 + }, 946 + } 947 + 948 + var dedupeFlags = []cli.Flag{ 949 + &cli.BoolFlag{ 950 + Name: "dry-run", 951 + Usage: "Preview without deleting", 952 + Sources: cli.EnvVars(EnvDryRun), 953 + }, 954 + &cli.BoolFlag{ 955 + Name: "fresh", 956 + Usage: "Force refresh cache", 957 + Sources: cli.EnvVars(EnvFresh), 958 + }, 959 + &cli.BoolFlag{ 960 + Name: "yes", 961 + Usage: "Skip confirmation prompt", 962 + Aliases: []string{"y"}, 963 + Sources: cli.EnvVars(EnvYes), 964 + }, 965 + }
-486
sync/adapter.go
··· 1 - package sync 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "errors" 7 - "fmt" 8 - "log/slog" 9 - "net" 10 - "strings" 11 - "time" 12 - 13 - "github.com/bluesky-social/indigo/atproto/atclient" 14 - "github.com/bluesky-social/indigo/atproto/syntax" 15 - "github.com/failsafe-go/failsafe-go" 16 - "github.com/failsafe-go/failsafe-go/retrypolicy" 17 - 18 - "tangled.org/karitham.dev/lazuli/cache" 19 - "tangled.org/karitham.dev/lazuli/sync/logutil" 20 - ) 21 - 22 - type RepoClient interface { 23 - ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]RecordRef, string, error) 24 - ApplyWrites(ctx context.Context, collection string, records []PlayRecord) error 25 - DeleteRecord(ctx context.Context, collection, rkey string) error 26 - } 27 - 28 - type RecordRef struct { 29 - URI string 30 - CID string 31 - Value PlayRecord 32 - } 33 - 34 - type RateClient struct { 35 - client *atclient.APIClient 36 - did string 37 - limiter RateLimiter 38 - } 39 - 40 - func NewRateClient(client *atclient.APIClient, did string, limiter RateLimiter) *RateClient { 41 - return &RateClient{ 42 - client: client, 43 - did: did, 44 - limiter: limiter, 45 - } 46 - } 47 - 48 - func (c *RateClient) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]RecordRef, string, error) { 49 - if c.client == nil { 50 - return nil, "", fmt.Errorf("client cannot be nil") 51 - } 52 - 53 - var outResp struct { 54 - Records []struct { 55 - URI string `json:"uri"` 56 - CID string `json:"cid"` 57 - Value map[string]any `json:"value"` 58 - } `json:"records"` 59 - Cursor string `json:"cursor"` 60 - } 61 - 62 - var chargedAt time.Time 63 - if c.limiter != nil { 64 - slog.Debug("waiting for rate limit (read)") 65 - var err error 66 - chargedAt, err = c.limiter.AllowRead(ctx) 67 - if err != nil { 68 - slog.Error("rate limit wait cancelled/failed (read)", logutil.Error(err)) 69 - return nil, "", err 70 - } 71 - } 72 - 73 - err := c.client.Get(ctx, syntax.NSID("com.atproto.repo.listRecords"), map[string]any{ 74 - "repo": c.did, 75 - "collection": collection, 76 - "limit": limit, 77 - "cursor": cursor, 78 - }, &outResp) 79 - if err != nil { 80 - if c.limiter != nil && isTransientError(err) { 81 - c.limiter.RefundRead(ctx, chargedAt) 82 - } 83 - return nil, "", err 84 - } 85 - 86 - out := make([]RecordRef, 0, len(outResp.Records)) 87 - for _, r := range outResp.Records { 88 - var playRecord PlayRecord 89 - if r.Value != nil { 90 - b, err := json.Marshal(r.Value) 91 - if err != nil { 92 - slog.Debug("failed to marshal record value", slog.String("uri", r.URI), logutil.Error(err)) 93 - continue 94 - } 95 - if err := json.Unmarshal(b, &playRecord); err != nil { 96 - slog.Debug("failed to unmarshal record", slog.String("uri", r.URI), logutil.Error(err)) 97 - continue 98 - } 99 - slog.Debug("parsed record", slog.String("uri", r.URI), logutil.Track(playRecord.TrackName, playRecord.ArtistName(), playRecord.PlayedTime.Time)) 100 - } 101 - out = append(out, RecordRef{ 102 - URI: r.URI, 103 - CID: r.CID, 104 - Value: playRecord, 105 - }) 106 - } 107 - 108 - return out, outResp.Cursor, nil 109 - } 110 - 111 - func (c *RateClient) ApplyWrites(ctx context.Context, collection string, records []PlayRecord) error { 112 - if len(records) == 0 { 113 - return nil 114 - } 115 - 116 - if c.client == nil { 117 - return fmt.Errorf("client cannot be nil") 118 - } 119 - 120 - var chargedAt time.Time 121 - if c.limiter != nil { 122 - var err error 123 - chargedAt, err = c.limiter.AllowBulkWrite(ctx, len(records)) 124 - if err != nil { 125 - slog.Error("rate limit wait cancelled/failed (write)", 126 - logutil.DID(c.did), 127 - slog.String("collection", collection), 128 - logutil.Error(err), 129 - ) 130 - return err 131 - } 132 - } 133 - 134 - err := applyWrites(ctx, c.client, c.did, collection, records) 135 - if err != nil && isTransientError(err) && c.limiter != nil { 136 - c.limiter.RefundBulkWrite(ctx, len(records), chargedAt) 137 - } 138 - return err 139 - } 140 - 141 - func applyWrites(ctx context.Context, client *atclient.APIClient, did, collection string, records []PlayRecord) error { 142 - if len(records) == 0 { 143 - return nil 144 - } 145 - 146 - if len(records) > 200 { 147 - return fmt.Errorf("too many records in one ApplyWrites call: %d (max 200)", len(records)) 148 - } 149 - 150 - writes, err := prepareWrites(records, collection) 151 - if err != nil { 152 - return err 153 - } 154 - 155 - return client.Post(ctx, syntax.NSID("com.atproto.repo.applyWrites"), map[string]any{ 156 - "repo": did, 157 - "writes": writes, 158 - }, nil) 159 - } 160 - 161 - type atprotoClientAdapter struct { 162 - client *atclient.APIClient 163 - did string 164 - } 165 - 166 - func (a *atprotoClientAdapter) ApplyWrites(ctx context.Context, collection string, records []PlayRecord) error { 167 - return applyWrites(ctx, a.client, a.did, collection, records) 168 - } 169 - 170 - func prepareRecords(batch []PlayRecord) []PlayRecord { 171 - atprotoRecords := make([]PlayRecord, 0, len(batch)) 172 - for _, record := range batch { 173 - record.Type = RecordType 174 - record.SubmissionClientAgent = ClientAgent 175 - atprotoRecords = append(atprotoRecords, record) 176 - } 177 - return atprotoRecords 178 - } 179 - 180 - func defaultProgressLog(f func(ProgressReport)) func(ProgressReport) { 181 - if f != nil { 182 - return f 183 - } 184 - return func(pr ProgressReport) { 185 - slog.Info("sync progress", 186 - slog.Int("completed", pr.Completed), 187 - slog.Int("total", pr.Total), 188 - slog.Float64("percent", pr.Percent), 189 - slog.String("elapsed", pr.Elapsed), 190 - slog.String("eta", pr.ETA), 191 - slog.String("rate", pr.Rate), 192 - slog.Int("errors", pr.Errors), 193 - ) 194 - } 195 - } 196 - 197 - func defaultBatchSize(size int) int { 198 - if size > 0 { 199 - return size 200 - } 201 - return DefaultBatchSize 202 - } 203 - 204 - func buildClient(client AuthClient, customClient ATProtoClient) (ATProtoClient, error) { 205 - if customClient != nil { 206 - return customClient, nil 207 - } 208 - 209 - apiClient := client.GetAPIClient() 210 - if apiClient == nil { 211 - slog.Error("failed to get API client", logutil.Error(fmt.Errorf("client is nil"))) 212 - return nil, fmt.Errorf("API client is nil") 213 - } 214 - 215 - return &atprotoClientAdapter{client: apiClient, did: client.GetDID()}, nil 216 - } 217 - 218 - func newPublishResult(success, errors, total int, start time.Time, cancelled bool) PublishResult { 219 - return PublishResult{ 220 - SuccessCount: success, 221 - ErrorCount: errors, 222 - Cancelled: cancelled, 223 - Duration: time.Since(start), 224 - TotalRecords: total, 225 - RecordsPerMinute: ratePerMinute(success, time.Since(start)), 226 - } 227 - } 228 - 229 - func logResult(success, errors int, startTime time.Time) { 230 - if errors > 0 { 231 - slog.Warn("import completed with errors", 232 - slog.Int("success", success), 233 - slog.Int("errors", errors)) 234 - } 235 - slog.Info("import completed", 236 - slog.Int("success", success), 237 - slog.Int("errors", errors), 238 - slog.Duration("duration", time.Since(startTime)), 239 - slog.String("rate", formatRate(ratePerMinute(success, time.Since(startTime))))) 240 - } 241 - 242 - func PublishBatch(ctx context.Context, client ATProtoClient, did string, batch []PlayRecord, storage cache.Storage) error { 243 - if len(batch) == 0 { 244 - return nil 245 - } 246 - 247 - atprotoRecords := prepareRecords(batch) 248 - 249 - err := client.ApplyWrites(ctx, RecordType, atprotoRecords) 250 - if err != nil { 251 - slog.Error("batch publish failed", logutil.Error(err)) 252 - return err 253 - } 254 - 255 - if storage != nil && did != "" { 256 - keys := CreateRecordKeys(atprotoRecords) 257 - cacheEntries := make(map[string][]byte) 258 - for i, rec := range atprotoRecords { 259 - key := keys[i] 260 - value, _ := json.Marshal(rec) 261 - cacheEntries[key] = value 262 - } 263 - 264 - if err := storage.SaveRecords(did, cacheEntries); err != nil { 265 - return fmt.Errorf("failed to save records to storage: %w", err) 266 - } 267 - 268 - if err := storage.MarkPublished(did, keys...); err != nil { 269 - return fmt.Errorf("failed to mark records as published: %w", err) 270 - } 271 - } 272 - 273 - return nil 274 - } 275 - 276 - type ATProtoClient interface { 277 - ApplyWrites(ctx context.Context, collection string, records []PlayRecord) error 278 - } 279 - 280 - func ratePerMinute(count int, duration time.Duration) float64 { 281 - if duration == 0 { 282 - return 0 283 - } 284 - return float64(count) / duration.Minutes() 285 - } 286 - 287 - type AuthClient interface { 288 - GetAPIClient() *atclient.APIClient 289 - GetDID() string 290 - } 291 - 292 - func IsTransientError(err error) bool { 293 - return isTransientError(err) 294 - } 295 - 296 - func isTransientError(err error) bool { 297 - if err == nil { 298 - return false 299 - } 300 - 301 - var apiErr *atclient.APIError 302 - if errors.As(err, &apiErr) { 303 - switch apiErr.StatusCode { 304 - case 429, 500, 502, 503, 504: 305 - return true 306 - } 307 - return false 308 - } 309 - 310 - var netErr net.Error 311 - if errors.As(err, &netErr) { 312 - return netErr.Timeout() 313 - } 314 - 315 - return false 316 - } 317 - 318 - func (c *RateClient) DeleteRecord(ctx context.Context, collection, rkey string) error { 319 - if c.client == nil { 320 - return fmt.Errorf("client is nil") 321 - } 322 - 323 - var chargedAt time.Time 324 - if c.limiter != nil { 325 - slog.Debug("waiting for rate limit (delete)") 326 - var err error 327 - chargedAt, err = c.limiter.AllowBulkWrite(ctx, 1) 328 - if err != nil { 329 - slog.Error("rate limit wait cancelled/failed (delete)", logutil.Error(err)) 330 - return err 331 - } 332 - } 333 - 334 - _, err := c.client.Do(ctx, &atclient.APIRequest{ 335 - Method: "DELETE", 336 - Endpoint: syntax.NSID("com.atproto.repo.deleteRecord"), 337 - QueryParams: map[string][]string{ 338 - "repo": {c.did}, 339 - "collection": {collection}, 340 - "rkey": {rkey}, 341 - }, 342 - }) 343 - if err != nil && c.limiter != nil && isTransientError(err) { 344 - c.limiter.RefundBulkWrite(ctx, 1, chargedAt) 345 - } 346 - return err 347 - } 348 - 349 - func FetchExisting(ctx context.Context, client RepoClient, did string, storage cache.Storage, forceRefresh bool) ([]ExistingRecord, error) { 350 - if !forceRefresh && storage != nil { 351 - published, err := storage.GetPublished(did) 352 - if err == nil && len(published) > 0 && storage.IsValid(did) { 353 - records := make([]ExistingRecord, 0, len(published)) 354 - err := storage.IteratePublished(did, func(key string, data []byte) error { 355 - var value PlayRecord 356 - if err := json.Unmarshal(data, &value); err != nil { 357 - return nil 358 - } 359 - records = append(records, ExistingRecord{ 360 - URI: generateRecordURI(did, value), 361 - Value: value, 362 - }) 363 - return nil 364 - }) 365 - if err == nil { 366 - slog.Debug("loaded from cache", slog.Int("count", len(records))) 367 - return records, nil 368 - } 369 - } 370 - } 371 - 372 - select { 373 - case <-ctx.Done(): 374 - return nil, ctx.Err() 375 - default: 376 - } 377 - 378 - allRecords := make([]ExistingRecord, 0, 1024) 379 - return fetchExistingLoop(ctx, client, did, storage, allRecords) 380 - } 381 - 382 - func fetchExistingLoop(ctx context.Context, client RepoClient, did string, storage cache.Storage, allRecords []ExistingRecord) ([]ExistingRecord, error) { 383 - const batchSize = 100 384 - var cursor string 385 - 386 - type fetchResult struct { 387 - records []RecordRef 388 - cursor string 389 - } 390 - 391 - retryPolicy := retrypolicy.NewBuilder[fetchResult](). 392 - WithMaxRetries(10). 393 - WithBackoff(BaseRetryDelay, 5*time.Minute). 394 - HandleIf(func(_ fetchResult, err error) bool { 395 - return isTransientError(err) 396 - }). 397 - OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[fetchResult]) { 398 - slog.Warn("fetch failed with transient error, retrying", 399 - slog.Duration("retryDelay", e.Delay), 400 - logutil.Error(e.LastError()), 401 - slog.Int("attempt", e.Attempts())) 402 - }). 403 - Build() 404 - 405 - for { 406 - select { 407 - case <-ctx.Done(): 408 - return nil, ctx.Err() 409 - default: 410 - } 411 - 412 - result, err := failsafe.With(retryPolicy). 413 - WithContext(ctx). 414 - Get(func() (fetchResult, error) { 415 - recs, next, err := client.ListRecords(ctx, RecordType, batchSize, cursor) 416 - if err != nil { 417 - return fetchResult{}, err 418 - } 419 - 420 - return fetchResult{records: recs, cursor: next}, nil 421 - }) 422 - if err != nil { 423 - return nil, err 424 - } 425 - 426 - for _, rec := range result.records { 427 - allRecords = append(allRecords, ExistingRecord(rec)) 428 - } 429 - 430 - if result.cursor == "" || len(result.records) < batchSize { 431 - break 432 - } 433 - cursor = result.cursor 434 - } 435 - 436 - if storage != nil { 437 - cacheEntries := make(map[string][]byte) 438 - keys := make([]string, 0, len(allRecords)) 439 - for _, rec := range allRecords { 440 - parts := strings.Split(rec.URI, "/") 441 - key := parts[len(parts)-1] 442 - if key == "" { 443 - key = CreateRecordKey(rec.Value) 444 - } 445 - value, _ := json.Marshal(rec.Value) 446 - cacheEntries[key] = value 447 - keys = append(keys, key) 448 - } 449 - 450 - if err := storage.SaveRecords(did, cacheEntries); err != nil { 451 - return nil, err 452 - } 453 - 454 - if err := storage.MarkPublished(did, keys...); err != nil { 455 - return nil, err 456 - } 457 - 458 - slog.Debug("saved to cache and marked as published", slog.Int("count", len(allRecords))) 459 - } 460 - 461 - return allRecords, nil 462 - } 463 - 464 - func prepareWrites(records []PlayRecord, collection string) ([]map[string]any, error) { 465 - if len(records) == 0 { 466 - return nil, nil 467 - } 468 - 469 - writes := make([]map[string]any, len(records)) 470 - keys := CreateRecordKeys(records) 471 - 472 - for i, rec := range records { 473 - writes[i] = map[string]any{ 474 - "$type": "com.atproto.repo.applyWrites#create", 475 - "collection": collection, 476 - "rkey": keys[i], 477 - "value": rec, 478 - } 479 - } 480 - 481 - return writes, nil 482 - } 483 - 484 - func generateRecordURI(did string, record PlayRecord) string { 485 - return fmt.Sprintf("at://%s/%s/%s", did, RecordType, CreateRecordKey(record)) 486 - }
-60
sync/atproto.go
··· 1 - package sync 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "fmt" 7 - "net/http" 8 - "net/url" 9 - "time" 10 - ) 11 - 12 - var UserAgent string 13 - 14 - var httpClient = &http.Client{ 15 - Timeout: 10 * time.Second, 16 - } 17 - 18 - func ResolveMiniDoc(ctx context.Context, identifier string) (did, pds, signingKey string, err error) { 19 - parsedURL, err := url.Parse(SlingshotResolverURL) 20 - if err != nil { 21 - return "", "", "", fmt.Errorf("invalid resolver URL: %w", err) 22 - } 23 - 24 - query := parsedURL.Query() 25 - query.Set("identifier", identifier) 26 - parsedURL.RawQuery = query.Encode() 27 - 28 - req, err := http.NewRequestWithContext(ctx, "GET", parsedURL.String(), nil) 29 - if err != nil { 30 - return "", "", "", fmt.Errorf("failed to create request: %w", err) 31 - } 32 - if UserAgent != "" { 33 - req.Header.Set("User-Agent", UserAgent) 34 - } 35 - 36 - resp, err := httpClient.Do(req) 37 - if err != nil { 38 - return "", "", "", fmt.Errorf("failed to resolve mini doc: %w", err) 39 - } 40 - defer resp.Body.Close() 41 - 42 - if resp.StatusCode != http.StatusOK { 43 - return "", "", "", fmt.Errorf("mini doc resolution failed with status: %d", resp.StatusCode) 44 - } 45 - 46 - var result struct { 47 - DID string `json:"did"` 48 - PDS string `json:"pds"` 49 - SigningKey string `json:"signing_key"` 50 - } 51 - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 52 - return "", "", "", fmt.Errorf("failed to decode mini doc response: %w", err) 53 - } 54 - 55 - if result.DID == "" { 56 - return "", "", "", fmt.Errorf("resolved mini doc missing DID") 57 - } 58 - 59 - return result.DID, result.PDS, result.SigningKey, nil 60 - }
-105
sync/atproto_auth.go
··· 1 - package sync 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "fmt" 7 - "net/http" 8 - "strings" 9 - "sync" 10 - 11 - "github.com/bluesky-social/indigo/atproto/atclient" 12 - "github.com/bluesky-social/indigo/atproto/syntax" 13 - ) 14 - 15 - type FixedPasswordAuth struct { 16 - *atclient.PasswordAuth 17 - lk sync.RWMutex 18 - } 19 - 20 - func (a *FixedPasswordAuth) DoWithAuth(c *http.Client, req *http.Request, endpoint syntax.NSID) (*http.Response, error) { 21 - accessToken, refreshToken := a.GetTokens() 22 - req.Header.Set("Authorization", "Bearer "+accessToken) 23 - resp, err := c.Do(req) 24 - if err != nil { 25 - return nil, err 26 - } 27 - 28 - if resp.StatusCode != http.StatusBadRequest || !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { 29 - return resp, nil 30 - } 31 - 32 - defer resp.Body.Close() 33 - var eb atclient.ErrorBody 34 - if err := json.NewDecoder(resp.Body).Decode(&eb); err != nil { 35 - return nil, &atclient.APIError{StatusCode: resp.StatusCode} 36 - } 37 - if eb.Name != "ExpiredToken" { 38 - return nil, eb.APIError(resp.StatusCode) 39 - } 40 - 41 - if err := a.Refresh(req.Context(), c, refreshToken); err != nil { 42 - return nil, err 43 - } 44 - 45 - retry := req.Clone(req.Context()) 46 - if req.GetBody != nil { 47 - retry.Body, err = req.GetBody() 48 - if err != nil { 49 - return nil, fmt.Errorf("API request retry GetBody failed: %w", err) 50 - } 51 - } 52 - 53 - accessToken, _ = a.GetTokens() 54 - retry.Header.Set("Authorization", "Bearer "+accessToken) 55 - return c.Do(retry) 56 - } 57 - 58 - func (a *FixedPasswordAuth) Refresh(ctx context.Context, c *http.Client, priorRefreshToken string) error { 59 - a.lk.Lock() 60 - defer a.lk.Unlock() 61 - 62 - if priorRefreshToken != "" && priorRefreshToken != a.Session.RefreshToken { 63 - return nil 64 - } 65 - 66 - u := a.Session.Host + "/xrpc/com.atproto.server.refreshSession" 67 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, nil) 68 - if err != nil { 69 - return err 70 - } 71 - req.Header.Set("User-Agent", "indigo-sdk") 72 - req.Header.Set("Authorization", "Bearer "+a.Session.RefreshToken) 73 - 74 - resp, err := c.Do(req) 75 - if err != nil { 76 - return err 77 - } 78 - defer resp.Body.Close() 79 - 80 - if resp.StatusCode < 200 || resp.StatusCode >= 300 { 81 - var eb atclient.ErrorBody 82 - if err := json.NewDecoder(resp.Body).Decode(&eb); err != nil { 83 - return &atclient.APIError{StatusCode: resp.StatusCode} 84 - } 85 - return eb.APIError(resp.StatusCode) 86 - } 87 - 88 - var out struct { 89 - AccessJwt string `json:"accessJwt"` 90 - RefreshJwt string `json:"refreshJwt"` 91 - } 92 - if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { 93 - return err 94 - } 95 - 96 - a.Session.AccessToken = out.AccessJwt 97 - a.Session.RefreshToken = out.RefreshJwt 98 - 99 - if a.RefreshCallback != nil { 100 - snapshot := a.Session.Clone() 101 - a.RefreshCallback(ctx, snapshot) 102 - } 103 - 104 - return nil 105 - }
-88
sync/atproto_auth_test.go
··· 1 - package sync 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "net/http" 7 - "net/http/httptest" 8 - "strings" 9 - "testing" 10 - 11 - "github.com/bluesky-social/indigo/atproto/atclient" 12 - ) 13 - 14 - func TestFixedPasswordAuth_Refresh(t *testing.T) { 15 - t.Run("Refreshes with POST", func(t *testing.T) { 16 - methodUsed := "" 17 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 - if strings.HasSuffix(r.URL.Path, "refreshSession") { 19 - methodUsed = r.Method 20 - w.Header().Set("Content-Type", "application/json") 21 - json.NewEncoder(w).Encode(map[string]string{ 22 - "accessJwt": "new-access", 23 - "refreshJwt": "new-refresh", 24 - }) 25 - return 26 - } 27 - w.WriteHeader(http.StatusNotFound) 28 - })) 29 - defer server.Close() 30 - 31 - pa := &atclient.PasswordAuth{ 32 - Session: atclient.PasswordSessionData{ 33 - Host: server.URL, 34 - AccessToken: "old-access", 35 - RefreshToken: "old-refresh", 36 - }, 37 - } 38 - fixed := &FixedPasswordAuth{PasswordAuth: pa} 39 - 40 - err := fixed.Refresh(context.Background(), http.DefaultClient, "old-refresh") 41 - if err != nil { 42 - t.Fatalf("Refresh failed: %v", err) 43 - } 44 - 45 - if methodUsed != http.MethodPost { 46 - t.Errorf("Expected method POST, got %s", methodUsed) 47 - } 48 - 49 - if pa.Session.AccessToken != "new-access" { 50 - t.Errorf("Expected new access token, got %s", pa.Session.AccessToken) 51 - } 52 - }) 53 - } 54 - 55 - func TestLibraryPasswordAuth_Refresh_Method(t *testing.T) { 56 - // This test documents the library's buggy behavior (GET instead of POST) 57 - methodUsed := "" 58 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 59 - if strings.HasSuffix(r.URL.Path, "refreshSession") { 60 - methodUsed = r.Method 61 - w.Header().Set("Content-Type", "application/json") 62 - json.NewEncoder(w).Encode(map[string]string{ 63 - "accessJwt": "new-access", 64 - "refreshJwt": "new-refresh", 65 - }) 66 - return 67 - } 68 - })) 69 - defer server.Close() 70 - 71 - pa := &atclient.PasswordAuth{ 72 - Session: atclient.PasswordSessionData{ 73 - Host: server.URL, 74 - AccessToken: "old-access", 75 - RefreshToken: "old-refresh", 76 - }, 77 - } 78 - 79 - // We call the library's method directly 80 - _ = pa.Refresh(context.Background(), http.DefaultClient, "old-refresh") 81 - 82 - switch methodUsed { 83 - case http.MethodGet: 84 - t.Log("Confirmed: Library uses GET for refreshSession (Buggy)") 85 - case http.MethodPost: 86 - t.Log("Library uses POST for refreshSession") 87 - } 88 - }
-121
sync/auth.go
··· 1 - package sync 2 - 3 - import ( 4 - "context" 5 - "fmt" 6 - "sync" 7 - 8 - "github.com/bluesky-social/indigo/atproto/atclient" 9 - ) 10 - 11 - type ResolvedIdentity struct { 12 - DID string `json:"did"` 13 - Handle string `json:"handle"` 14 - PDS string `json:"pds"` 15 - SigningKey string `json:"signingKey"` 16 - } 17 - 18 - type Client struct { 19 - client *atclient.APIClient 20 - resolvedIdentity ResolvedIdentity 21 - mu sync.Mutex 22 - } 23 - 24 - func NewClient(ctx context.Context, handle, password string) (*Client, error) { 25 - identity, err := ResolveIdentity(ctx, handle) 26 - if err != nil { 27 - return nil, fmt.Errorf("failed to resolve identity: %w", err) 28 - } 29 - 30 - pdsURL := identity.PDS 31 - if pdsURL == "" { 32 - pdsURL = "https://bsky.social" 33 - } 34 - 35 - client, err := atclient.LoginWithPasswordHost(ctx, pdsURL, handle, password, "", nil) 36 - if err != nil { 37 - return nil, fmt.Errorf("login failed: %w", err) 38 - } 39 - 40 - if pa, ok := client.Auth.(*atclient.PasswordAuth); ok { 41 - client.Auth = &FixedPasswordAuth{PasswordAuth: pa} 42 - } 43 - 44 - return &Client{ 45 - client: client, 46 - resolvedIdentity: identity, 47 - }, nil 48 - } 49 - 50 - func ResolveIdentity(ctx context.Context, handle string) (ResolvedIdentity, error) { 51 - if handle == "" { 52 - return ResolvedIdentity{}, fmt.Errorf("handle cannot be empty") 53 - } 54 - 55 - did, pds, signingKey, err := ResolveMiniDoc(ctx, handle) 56 - if err != nil { 57 - return ResolvedIdentity{}, fmt.Errorf("failed to resolve identity: %w", err) 58 - } 59 - 60 - if did == "" { 61 - return ResolvedIdentity{}, fmt.Errorf("resolved identity missing DID") 62 - } 63 - 64 - if signingKey == "" { 65 - return ResolvedIdentity{}, fmt.Errorf("resolved identity missing signing key") 66 - } 67 - 68 - return ResolvedIdentity{ 69 - DID: did, 70 - Handle: handle, 71 - PDS: pds, 72 - SigningKey: signingKey, 73 - }, nil 74 - } 75 - 76 - func (c *Client) GetDID() string { 77 - c.mu.Lock() 78 - defer c.mu.Unlock() 79 - return c.resolvedIdentity.DID 80 - } 81 - 82 - func (c *Client) GetPDS() string { 83 - c.mu.Lock() 84 - defer c.mu.Unlock() 85 - return c.resolvedIdentity.PDS 86 - } 87 - 88 - func (c *Client) GetHandle() string { 89 - c.mu.Lock() 90 - defer c.mu.Unlock() 91 - return c.resolvedIdentity.Handle 92 - } 93 - 94 - func (c *Client) GetSigningKey() string { 95 - c.mu.Lock() 96 - defer c.mu.Unlock() 97 - return c.resolvedIdentity.SigningKey 98 - } 99 - 100 - func (c *Client) GetAPIClient() *atclient.APIClient { 101 - c.mu.Lock() 102 - defer c.mu.Unlock() 103 - return c.client 104 - } 105 - 106 - func (c *Client) Close() error { 107 - c.mu.Lock() 108 - defer c.mu.Unlock() 109 - if c.client != nil && c.client.Auth != nil { 110 - if logout, ok := c.client.Auth.(*atclient.PasswordAuth); ok { 111 - return logout.Logout(context.Background(), c.client.Client) 112 - } 113 - } 114 - return nil 115 - } 116 - 117 - func (c *Client) HasClient() bool { 118 - c.mu.Lock() 119 - defer c.mu.Unlock() 120 - return c.client != nil 121 - }
+27 -10
sync/batch_test.go
··· 13 13 "time" 14 14 15 15 "github.com/bluesky-social/indigo/atproto/atclient" 16 + "tangled.org/karitham.dev/lazuli/atproto" 16 17 "tangled.org/karitham.dev/lazuli/cache" 17 18 ) 18 19 ··· 138 139 139 140 // Mock ATProtoClient 140 141 type mockATProtoClient struct { 141 - applyWritesFunc func(ctx context.Context, collection string, records []PlayRecord) error 142 + applyWritesFunc func(ctx context.Context, collection string, records []PlayRecord) error 143 + listRecordsFunc func(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[PlayRecord], string, error) 144 + deleteRecordFunc func(ctx context.Context, collection, rkey string) error 142 145 } 143 146 144 147 func (m *mockATProtoClient) ApplyWrites(ctx context.Context, collection string, records []PlayRecord) error { 145 148 if m.applyWritesFunc != nil { 146 149 return m.applyWritesFunc(ctx, collection, records) 150 + } 151 + return nil 152 + } 153 + 154 + func (m *mockATProtoClient) ListRecords(ctx context.Context, collection string, limit int, cursor string) ([]atproto.RecordRef[PlayRecord], string, error) { 155 + if m.listRecordsFunc != nil { 156 + return m.listRecordsFunc(ctx, collection, limit, cursor) 157 + } 158 + return nil, "", nil 159 + } 160 + 161 + func (m *mockATProtoClient) DeleteRecord(ctx context.Context, collection, rkey string) error { 162 + if m.deleteRecordFunc != nil { 163 + return m.deleteRecordFunc(ctx, collection, rkey) 147 164 } 148 165 return nil 149 166 } ··· 153 170 did string 154 171 } 155 172 156 - func (m *mockAuthClient) GetAPIClient() *atclient.APIClient { return nil } 157 - func (m *mockAuthClient) GetDID() string { return m.did } 173 + func (m *mockAuthClient) APIClient() *atclient.APIClient { return nil } 174 + func (m *mockAuthClient) DID() string { return m.did } 158 175 159 176 type timeoutError struct{} 160 177 ··· 190 207 func TestApplyWrites_RateClient(t *testing.T) { 191 208 ctx := context.Background() 192 209 193 - t.Run("Empty records", func(t *testing.T) { 210 + t.Run("Empty anyRecords(records)", func(t *testing.T) { 194 211 limiter := &mockLimiter{} 195 - client := NewRateClient(nil, "did:example:123", limiter) 212 + client := atproto.NewRateClient[any](nil, "did:example:123", limiter) 196 213 err := client.ApplyWrites(ctx, "test", nil) 197 214 if err != nil { 198 215 t.Errorf("ApplyWrites(nil) error = %v", err) ··· 201 218 202 219 t.Run("Too many records", func(t *testing.T) { 203 220 limiter := &mockLimiter{} 204 - client := NewRateClient(&atclient.APIClient{}, "did:example:123", limiter) 205 - records := make([]PlayRecord, 201) 206 - err := client.ApplyWrites(ctx, "test", records) 221 + client := atproto.NewRateClient[any](&atclient.APIClient{}, "did:example:123", limiter) 222 + recs := make([]any, 201) 223 + err := client.ApplyWrites(ctx, "test", recs) 207 224 if err == nil { 208 225 t.Fatal("expected error for > 200 records") 209 226 } ··· 222 239 Body: http.NoBody, 223 240 }, nil 224 241 }) 225 - client := NewRateClient(apiClient, "did:example:123", limiter) 242 + client := atproto.NewRateClient[PlayRecord](apiClient, "did:example:123", limiter) 226 243 227 244 err := client.ApplyWrites(ctx, "test", []PlayRecord{{TrackName: "Song 1"}}) 228 245 if err == nil { ··· 242 259 Body: http.NoBody, 243 260 }, nil 244 261 }) 245 - client := NewRateClient(apiClient, "did:example:123", limiter) 262 + client := atproto.NewRateClient[PlayRecord](apiClient, "did:example:123", limiter) 246 263 247 264 err := client.ApplyWrites(ctx, "test", []PlayRecord{{TrackName: "Song 1"}}) 248 265 if err == nil {
+2 -2
sync/import_test.go
··· 54 54 did string 55 55 } 56 56 57 - func (m *mockAuthClient) GetAPIClient() *atclient.APIClient { return nil } 58 - func (m *mockAuthClient) GetDID() string { return m.did } 57 + func (m *mockAuthClient) APIClient() *atclient.APIClient { return nil } 58 + func (m *mockAuthClient) DID() string { return m.did } 59 59 60 60 type mockKV struct { 61 61 data map[string]int
+38
sync/logutil.go
··· 1 + package sync 2 + 3 + import ( 4 + "log/slog" 5 + "time" 6 + 7 + "tangled.org/karitham.dev/lazuli/atproto" 8 + ) 9 + 10 + type TrackInfo struct { 11 + Name string 12 + Artist string 13 + PlayedAt time.Time 14 + } 15 + 16 + func Track(name, artist string, playedAt time.Time) slog.Attr { 17 + return slog.Group("track", 18 + slog.String("name", name), 19 + slog.String("artist", artist), 20 + slog.Time("played_at", playedAt), 21 + ) 22 + } 23 + 24 + func DID(did string) slog.Attr { 25 + return slog.String("did", did) 26 + } 27 + 28 + func Error(err error) slog.Attr { 29 + if err == nil { 30 + return slog.Attr{} 31 + } 32 + 33 + if atproto.IsTransientError(err) { 34 + return slog.String("error", err.Error()) 35 + } 36 + 37 + return slog.String("error", err.Error()) 38 + }
-47
sync/logutil/logutil.go
··· 1 - package logutil 2 - 3 - import ( 4 - "errors" 5 - "log/slog" 6 - "time" 7 - 8 - "github.com/bluesky-social/indigo/atproto/atclient" 9 - ) 10 - 11 - type TrackInfo struct { 12 - Name string 13 - Artist string 14 - PlayedAt time.Time 15 - } 16 - 17 - // Track returns slog attributes for a play record. 18 - func Track(name, artist string, playedAt time.Time) slog.Attr { 19 - return slog.Group("track", 20 - slog.String("name", name), 21 - slog.String("artist", artist), 22 - slog.Time("played_at", playedAt), 23 - ) 24 - } 25 - 26 - // DID returns a slog attribute for a DID. 27 - func DID(did string) slog.Attr { 28 - return slog.String("did", did) 29 - } 30 - 31 - // Error returns a slog attribute for an error. 32 - func Error(err error) slog.Attr { 33 - if err == nil { 34 - return slog.Attr{} 35 - } 36 - 37 - var apiErr *atclient.APIError 38 - if errors.As(err, &apiErr) { 39 - return slog.Group("error", 40 - slog.Int("status", apiErr.StatusCode), 41 - slog.String("name", apiErr.Name), 42 - slog.String("message", apiErr.Message), 43 - ) 44 - } 45 - 46 - return slog.String("error", err.Error()) 47 - }
+291 -16
sync/publish.go
··· 5 5 "encoding/json" 6 6 "fmt" 7 7 "log/slog" 8 + "strings" 8 9 "time" 9 10 11 + "github.com/bluesky-social/indigo/atproto/atclient" 10 12 "github.com/bluesky-social/indigo/atproto/syntax" 11 13 "github.com/failsafe-go/failsafe-go" 12 14 "github.com/failsafe-go/failsafe-go/retrypolicy" 13 15 16 + "tangled.org/karitham.dev/lazuli/atproto" 14 17 "tangled.org/karitham.dev/lazuli/cache" 15 - "tangled.org/karitham.dev/lazuli/sync/logutil" 18 + ) 19 + 20 + type ( 21 + ATProtoClient = atproto.RepoClient[PlayRecord] 22 + AuthClient = atproto.AuthClient 23 + RateLimiter = atproto.RateLimiter 16 24 ) 17 25 26 + var ( 27 + WriteLimitDay = atproto.WriteLimitDay 28 + GlobalLimitDay = atproto.GlobalLimitDay 29 + ) 30 + 31 + func NewRateLimiter(kv atproto.KVStore, maxPercent float32) RateLimiter { 32 + return atproto.NewRateLimiter(kv, maxPercent) 33 + } 34 + 35 + func NewRateClient(client *atclient.APIClient, did string, limiter RateLimiter) *atproto.RateClient[PlayRecord] { 36 + return atproto.NewRateClient[PlayRecord](client, did, limiter) 37 + } 38 + 39 + func IsTransientError(err error) bool { 40 + return atproto.IsTransientError(err) 41 + } 42 + 43 + type Client = atproto.Client 44 + 45 + func NewClient(ctx context.Context, handle, password string, opts ...func(*atproto.ClientOptions)) (*Client, error) { 46 + return atproto.NewClient(ctx, handle, password, opts...) 47 + } 48 + 49 + type RepoClient[T any] = atproto.RepoClient[T] 50 + 51 + type RecordRef = atproto.RecordRef[PlayRecord] 52 + 18 53 type PublishOptions struct { 19 54 BatchSize int 20 55 DryRun bool ··· 29 64 30 65 batchSize := defaultBatchSize(opts.BatchSize) 31 66 32 - atprotoClient, err := buildClient(client, opts.ATProtoClient) 67 + atprotoClient, err := atproto.BuildClient(client, opts.ATProtoClient) 33 68 if err != nil { 34 69 return PublishResult{ 35 70 SuccessCount: 0, ··· 42 77 43 78 totalRecords := 0 44 79 if opts.Storage != nil { 45 - _ = opts.Storage.IterateUnpublished(client.GetDID(), func(key string, rec []byte) error { 80 + _ = opts.Storage.IterateUnpublished(client.DID(), func(key string, rec []byte) error { 46 81 totalRecords++ 47 82 return nil 48 83 }) ··· 55 90 slog.Info("starting iterative import", 56 91 slog.Int("total_records", totalRecords), 57 92 slog.Int("batch_size", batchSize), 58 - slog.Int("daily_write_limit", WriteLimitDay), 59 - slog.Int("daily_token_limit", GlobalLimitDay), 60 - slog.String("rate_limit", fmt.Sprintf("1 write per %.1fs", 86400.0/WriteLimitDay))) 93 + slog.Int("daily_write_limit", atproto.WriteLimitDay), 94 + slog.Int("daily_token_limit", atproto.GlobalLimitDay), 95 + slog.String("rate_limit", fmt.Sprintf("1 write per %.1fs", 86400.0/atproto.WriteLimitDay))) 61 96 62 97 tracker := NewProgressTracker(totalRecords, opts.Limiter) 63 98 progressLog := defaultProgressLog(opts.ProgressLog) ··· 76 111 for _, r := range batch { 77 112 tid := syntax.NewTIDFromTime(r.PlayedTime.Time, 0) 78 113 slog.Info("would publish record (dry run)", 79 - logutil.Track(r.TrackName, r.ArtistName(), r.PlayedTime.Time), 114 + Track(r.TrackName, r.ArtistName(), r.PlayedTime.Time), 80 115 slog.String("rkey", string(tid))) 81 116 } 82 117 totalSuccess += len(batch) ··· 86 121 return nil 87 122 } 88 123 89 - did := client.GetDID() 124 + did := client.DID() 90 125 retryPolicy := retrypolicy.NewBuilder[any](). 91 126 WithMaxRetries(10). 92 127 WithBackoff(BaseRetryDelay, 5*time.Minute). ··· 97 132 slog.Warn("batch failed with transient error, retrying", 98 133 slog.Int("count", len(batch)), 99 134 slog.Duration("retryDelay", e.Delay), 100 - logutil.Error(e.LastError()), 135 + Error(e.LastError()), 101 136 slog.Int("attempt", e.Attempts())) 102 137 }). 103 138 Build() ··· 107 142 }) 108 143 if err != nil { 109 144 slog.Error("batch failed after retries", 110 - logutil.Error(err), 145 + Error(err), 111 146 slog.Int("count", len(batch))) 112 147 113 148 if opts.Storage != nil { 114 149 if markErr := opts.Storage.MarkFailed(did, batchKeys, err.Error()); markErr != nil { 115 - slog.Error("failed to mark records as failed", logutil.Error(markErr)) 150 + slog.Error("failed to mark records as failed", Error(markErr)) 116 151 } 117 152 } 118 153 ··· 121 156 122 157 batch = batch[:0] 123 158 batchKeys = batchKeys[:0] 124 - return nil // Return nil so we continue with the next batch 159 + return nil 125 160 } 126 161 127 162 totalSuccess += len(batch) ··· 136 171 return nil 137 172 } 138 173 139 - err = opts.Storage.IterateUnpublished(client.GetDID(), func(key string, rec []byte) error { 174 + err = opts.Storage.IterateUnpublished(client.DID(), func(key string, rec []byte) error { 140 175 select { 141 176 case <-ctx.Done(): 142 177 return ctx.Err() ··· 145 180 146 181 var record PlayRecord 147 182 if err := json.Unmarshal(rec, &record); err != nil { 148 - slog.Error("malformed record in storage", slog.String("key", key), logutil.Error(err)) 183 + slog.Error("malformed record in storage", slog.String("key", key), Error(err)) 149 184 if opts.Storage != nil { 150 - _ = opts.Storage.MarkFailed(client.GetDID(), []string{key}, "malformed record") 185 + _ = opts.Storage.MarkFailed(client.DID(), []string{key}, "malformed record") 151 186 } 152 187 totalErrors++ 153 188 tracker.IncrementErrors(1) ··· 171 206 172 207 cancelled := false 173 208 if err != nil { 174 - slog.Error("import interrupted", logutil.Error(err)) 209 + slog.Error("import interrupted", Error(err)) 175 210 cancelled = true 176 211 } 177 212 178 213 logResult(totalSuccess, totalErrors, startTime) 179 214 return newPublishResult(totalSuccess, totalErrors, totalRecords, startTime, cancelled) 180 215 } 216 + 217 + func isTransientError(err error) bool { 218 + return atproto.IsTransientError(err) 219 + } 220 + 221 + func defaultProgressLog(f func(ProgressReport)) func(ProgressReport) { 222 + if f != nil { 223 + return f 224 + } 225 + return func(pr ProgressReport) { 226 + slog.Info("sync progress", 227 + slog.Int("completed", pr.Completed), 228 + slog.Int("total", pr.Total), 229 + slog.Float64("percent", pr.Percent), 230 + slog.String("elapsed", pr.Elapsed), 231 + slog.String("eta", pr.ETA), 232 + slog.String("rate", pr.Rate), 233 + slog.Int("errors", pr.Errors), 234 + ) 235 + } 236 + } 237 + 238 + func defaultBatchSize(size int) int { 239 + if size > 0 { 240 + return size 241 + } 242 + return DefaultBatchSize 243 + } 244 + 245 + func newPublishResult(success, errors, total int, start time.Time, cancelled bool) PublishResult { 246 + return PublishResult{ 247 + SuccessCount: success, 248 + ErrorCount: errors, 249 + Cancelled: cancelled, 250 + Duration: time.Since(start), 251 + TotalRecords: total, 252 + RecordsPerMinute: ratePerMinute(success, time.Since(start)), 253 + } 254 + } 255 + 256 + func logResult(success, errors int, startTime time.Time) { 257 + if errors > 0 { 258 + slog.Warn("import completed with errors", 259 + slog.Int("success", success), 260 + slog.Int("errors", errors)) 261 + } 262 + slog.Info("import completed", 263 + slog.Int("success", success), 264 + slog.Int("errors", errors), 265 + slog.Duration("duration", time.Since(startTime)), 266 + slog.String("rate", formatRate(ratePerMinute(success, time.Since(startTime))))) 267 + } 268 + 269 + func PublishBatch(ctx context.Context, client ATProtoClient, did string, batch []PlayRecord, storage cache.Storage) error { 270 + if len(batch) == 0 { 271 + return nil 272 + } 273 + 274 + err := client.ApplyWrites(ctx, RecordType, batch) 275 + if err != nil { 276 + slog.Error("batch publish failed", Error(err)) 277 + return err 278 + } 279 + 280 + if storage != nil && did != "" { 281 + keys := CreateRecordKeys(batch) 282 + cacheEntries := make(map[string][]byte) 283 + for i, rec := range batch { 284 + key := keys[i] 285 + value, _ := json.Marshal(rec) 286 + cacheEntries[key] = value 287 + } 288 + 289 + if err := storage.SaveRecords(did, cacheEntries); err != nil { 290 + return fmt.Errorf("failed to save records to storage: %w", err) 291 + } 292 + 293 + if err := storage.MarkPublished(did, keys...); err != nil { 294 + return fmt.Errorf("failed to mark records as published: %w", err) 295 + } 296 + } 297 + 298 + return nil 299 + } 300 + 301 + func prepareRecords(batch []PlayRecord) []PlayRecord { 302 + atprotoRecords := make([]PlayRecord, 0, len(batch)) 303 + for _, record := range batch { 304 + record.Type = RecordType 305 + record.SubmissionClientAgent = ClientAgent 306 + atprotoRecords = append(atprotoRecords, record) 307 + } 308 + return atprotoRecords 309 + } 310 + 311 + func ratePerMinute(count int, duration time.Duration) float64 { 312 + if duration == 0 { 313 + return 0 314 + } 315 + return float64(count) / duration.Minutes() 316 + } 317 + 318 + func FetchExisting(ctx context.Context, client RepoClient[PlayRecord], did string, storage cache.Storage, forceRefresh bool) ([]ExistingRecord, error) { 319 + if !forceRefresh && storage != nil { 320 + published, err := storage.GetPublished(did) 321 + if err == nil && len(published) > 0 && storage.IsValid(did) { 322 + records := make([]ExistingRecord, 0, len(published)) 323 + err := storage.IteratePublished(did, func(key string, data []byte) error { 324 + var value PlayRecord 325 + if err := json.Unmarshal(data, &value); err != nil { 326 + return nil 327 + } 328 + records = append(records, ExistingRecord{ 329 + URI: generateRecordURI(did, value), 330 + Value: value, 331 + }) 332 + return nil 333 + }) 334 + if err == nil { 335 + slog.Debug("loaded from cache", slog.Int("count", len(records))) 336 + return records, nil 337 + } 338 + } 339 + } 340 + 341 + select { 342 + case <-ctx.Done(): 343 + return nil, ctx.Err() 344 + default: 345 + } 346 + 347 + allRecords := make([]ExistingRecord, 0, 1024) 348 + return fetchExistingLoop(ctx, client, did, storage, allRecords) 349 + } 350 + 351 + func fetchExistingLoop(ctx context.Context, client RepoClient[PlayRecord], did string, storage cache.Storage, allRecords []ExistingRecord) ([]ExistingRecord, error) { 352 + const batchSize = 100 353 + var cursor string 354 + 355 + type fetchResult struct { 356 + records []atproto.RecordRef[PlayRecord] 357 + cursor string 358 + } 359 + 360 + retryPolicy := retrypolicy.NewBuilder[fetchResult](). 361 + WithMaxRetries(10). 362 + WithBackoff(BaseRetryDelay, 5*time.Minute). 363 + HandleIf(func(_ fetchResult, err error) bool { 364 + return IsTransientError(err) 365 + }). 366 + OnRetryScheduled(func(e failsafe.ExecutionScheduledEvent[fetchResult]) { 367 + slog.Warn("fetch failed with transient error, retrying", 368 + slog.Duration("retryDelay", e.Delay), 369 + Error(e.LastError()), 370 + slog.Int("attempt", e.Attempts())) 371 + }). 372 + Build() 373 + 374 + for { 375 + select { 376 + case <-ctx.Done(): 377 + return nil, ctx.Err() 378 + default: 379 + } 380 + 381 + result, err := failsafe.With(retryPolicy). 382 + WithContext(ctx). 383 + Get(func() (fetchResult, error) { 384 + recs, next, err := client.ListRecords(ctx, RecordType, batchSize, cursor) 385 + if err != nil { 386 + return fetchResult{}, err 387 + } 388 + 389 + return fetchResult{records: recs, cursor: next}, nil 390 + }) 391 + if err != nil { 392 + return nil, err 393 + } 394 + 395 + for _, rec := range result.records { 396 + allRecords = append(allRecords, ExistingRecord(rec)) 397 + } 398 + 399 + if result.cursor == "" || len(result.records) < batchSize { 400 + break 401 + } 402 + cursor = result.cursor 403 + } 404 + 405 + if storage != nil { 406 + cacheEntries := make(map[string][]byte) 407 + keys := make([]string, 0, len(allRecords)) 408 + for _, rec := range allRecords { 409 + parts := strings.Split(rec.URI, "/") 410 + key := parts[len(parts)-1] 411 + if key == "" { 412 + key = CreateRecordKey(rec.Value) 413 + } 414 + value, _ := json.Marshal(rec.Value) 415 + cacheEntries[key] = value 416 + keys = append(keys, key) 417 + } 418 + 419 + if err := storage.SaveRecords(did, cacheEntries); err != nil { 420 + return nil, err 421 + } 422 + 423 + if err := storage.MarkPublished(did, keys...); err != nil { 424 + return nil, err 425 + } 426 + 427 + slog.Debug("saved to cache and marked as published", slog.Int("count", len(allRecords))) 428 + } 429 + 430 + return allRecords, nil 431 + } 432 + 433 + func generateRecordURI(did string, record PlayRecord) string { 434 + return fmt.Sprintf("at://%s/%s/%s", did, RecordType, CreateRecordKey(record)) 435 + } 436 + 437 + func prepareWrites(records []PlayRecord, collection string) ([]map[string]any, error) { 438 + if len(records) == 0 { 439 + return nil, nil 440 + } 441 + 442 + writes := make([]map[string]any, len(records)) 443 + keys := CreateRecordKeys(records) 444 + 445 + for i, rec := range records { 446 + writes[i] = map[string]any{ 447 + "$type": "com.atproto.repo.applyWrites#create", 448 + "collection": collection, 449 + "rkey": keys[i], 450 + "value": rec, 451 + } 452 + } 453 + 454 + return writes, nil 455 + }
+16 -35
sync/rate.go atproto/rate.go
··· 1 - package sync 1 + package atproto 2 2 3 3 import ( 4 4 "context" ··· 9 9 "math" 10 10 "sync" 11 11 "time" 12 - 13 - "tangled.org/karitham.dev/lazuli/sync/logutil" 14 12 ) 15 13 16 14 const ( 17 - // Limits 18 15 WriteLimitMinute = 100 19 16 WriteLimitHour = 1000 20 17 WriteLimitDay = 10000 ··· 23 20 GlobalLimitHour = 3000 24 21 GlobalLimitDay = 35000 25 22 26 - // Costs 27 23 ReadGlobalCost = 1 28 24 WriteOnlyCost = 1 29 25 WriteGlobalCost = 3 30 26 ) 31 27 32 28 type RateLimiter interface { 33 - // AllowBulkWrite blocks or returns error until N writes are permissible. 34 - // Returns the timestamp of when the quota was charged for bucket-accurate refunds. 35 29 AllowBulkWrite(ctx context.Context, n int) (time.Time, error) 36 - // AllowRead blocks or returns error until a read is permissible. 37 30 AllowRead(ctx context.Context) (time.Time, error) 38 - // RefundBulkWrite restores N writes to the quota using the original charge time. 39 31 RefundBulkWrite(ctx context.Context, n int, chargedAt time.Time) 40 - // RefundRead restores a read to the quota using the original charge time. 41 32 RefundRead(ctx context.Context, chargedAt time.Time) 42 - // Stats returns current consumption (writes, global) 43 33 Stats() (int, int, error) 44 - // EstimatedWriteTime returns how long it would take to write n records 45 - // given current consumption and rate limits. 46 34 EstimatedWriteTime(n int) time.Duration 47 - // RemainingQuota returns remaining write and global quota at each tier, 48 - // plus the time until the next tier reset (the maximum wait). 49 35 RemainingQuota() (writesRemaining, globalRemaining int, timeUntilReset time.Duration) 50 36 } 51 37 ··· 115 101 if maxWait > 0 { 116 102 l.mu.Unlock() 117 103 slog.Debug("Rate limit reached, sleeping until next window", slog.Duration("wait", maxWait.Round(time.Second))) 118 - // Add a tiny bit of buffer + jitter to ensure we are definitely in the next window 119 104 wait := maxWait + 100*time.Millisecond + addJitter(100*time.Millisecond) 120 105 if err := l.sleep(ctx, wait); err != nil { 121 106 return now, err ··· 123 108 continue 124 109 } 125 110 126 - // Charge quota while holding the lock 127 111 err = l.charge(wKeys, gKeys, wCost, gCost) 128 112 l.mu.Unlock() 129 113 ··· 158 142 continue 159 143 } 160 144 161 - // Charge quota while holding the lock 162 145 err = l.charge(nil, gKeys, 0, gCost) 163 146 l.mu.Unlock() 164 147 ··· 191 174 } 192 175 193 176 if err := l.kv.IncrByMulti(updates); err != nil { 194 - slog.Error("failed to refund write quota", logutil.Error(err)) 177 + slog.Error("failed to refund write quota", Error(err)) 195 178 } 196 179 } 197 180 ··· 208 191 } 209 192 210 193 if err := l.kv.IncrByMulti(updates); err != nil { 211 - slog.Error("failed to refund global quota (read)", logutil.Error(err)) 194 + slog.Error("failed to refund global quota (read)", Error(err)) 212 195 } 213 196 } 214 197 215 - // checkQuota checks if the proposed cost fits within the limits. 216 - // Returns the wait duration if over limit (0 if OK), or error. 217 - // Must be called with lock held. 218 198 func (l *quotaLimiter) checkQuota(now time.Time, wKeys, gKeys []string, wCost, gCost int) (time.Duration, error) { 219 199 wLimits := []int{WriteLimitMinute, WriteLimitHour, WriteLimitDay} 220 200 gLimits := []int{GlobalLimitMinute, GlobalLimitHour, GlobalLimitDay} ··· 251 231 return maxWait, nil 252 232 } 253 233 254 - // charge applies the cost to the keys. 255 - // Must be called with lock held. 256 234 func (l *quotaLimiter) charge(wKeys, gKeys []string, wCost, gCost int) error { 257 235 updates := make(map[string]int, len(wKeys)+len(gKeys)) 258 236 for _, k := range wKeys { ··· 274 252 275 253 func (l *quotaLimiter) untilNextWindow(now time.Time, tier int) time.Duration { 276 254 switch tier { 277 - case 0: // Minute 255 + case 0: 278 256 return now.Truncate(time.Minute).Add(time.Minute).Sub(now) 279 - case 1: // Hour 257 + case 1: 280 258 return now.Truncate(time.Hour).Add(time.Hour).Sub(now) 281 - case 2: // Day 259 + case 2: 282 260 return now.Truncate(24 * time.Hour).Add(24 * time.Hour).Sub(now) 283 261 default: 284 - // Safety fallback: if unknown tier, wait a reasonable amount (e.g. 1 minute) 285 - // to prevent busy loops or bypassing limits. 286 262 slog.Warn("unknown rate limit tier encountered", slog.Int("tier", tier)) 287 263 return time.Minute 288 264 } ··· 292 268 var b [8]byte 293 269 _, _ = rand.Read(b[:]) 294 270 n := binary.LittleEndian.Uint64(b[:]) 295 - // Add 0-10% jitter 296 271 jitter := float64(d) * 0.1 * (float64(n) / math.MaxUint64) 297 272 return d + time.Duration(jitter) 298 273 } ··· 336 311 for i, k := range wKeys { 337 312 curr := values[k] 338 313 limit := int(float32(WriteLimitMinute) * l.rlQuota) 339 - if i == 1 { 314 + switch i { 315 + case 1: 340 316 limit = int(float32(WriteLimitHour) * l.rlQuota) 341 - } else if i == 2 { 317 + case 2: 342 318 limit = int(float32(WriteLimitDay) * l.rlQuota) 343 319 } 344 320 if curr+wCost > limit { ··· 349 325 for i, k := range gKeys { 350 326 curr := values[k] 351 327 limit := int(float32(GlobalLimitMinute) * l.rlQuota) 352 - if i == 1 { 328 + switch i { 329 + case 1: 353 330 limit = int(float32(GlobalLimitHour) * l.rlQuota) 354 - } else if i == 2 { 331 + case 2: 355 332 limit = int(float32(GlobalLimitDay) * l.rlQuota) 356 333 } 357 334 if curr+gCost > limit { ··· 426 403 427 404 return minWritesRemaining, minGlobalRemaining, maxResetTime 428 405 } 406 + 407 + func Error(err error) slog.Attr { 408 + return slog.String("error", err.Error()) 409 + }
+182 -121
sync/rate_test.go
··· 3 3 import ( 4 4 "context" 5 5 "encoding/json" 6 - "strings" 6 + "fmt" 7 7 "testing" 8 8 "time" 9 9 10 10 "github.com/bluesky-social/indigo/atproto/atclient" 11 + "tangled.org/karitham.dev/lazuli/atproto" 11 12 ) 12 13 13 14 type mockKV struct { ··· 29 30 return nil 30 31 } 31 32 32 - // Helper for tests that inspect internal state directly 33 - func (m *mockKV) Get(key string) (int, error) { 34 - return m.data[key], nil 33 + type testClock struct { 34 + now time.Time 35 35 } 36 + 37 + func (m *testClock) Now() time.Time { return m.now } 36 38 37 39 func TestRateLimiter_Refunds(t *testing.T) { 38 40 kv := &mockKV{data: make(map[string]int)} 39 - clock := &mockClock{now: time.Date(2026, 1, 22, 12, 0, 0, 0, time.UTC)} 40 - limiter := &quotaLimiter{ 41 + clock := &testClock{now: time.Date(2026, 1, 22, 12, 0, 0, 0, time.UTC)} 42 + limiter := &testRateLimiter{ 41 43 kv: kv, 42 44 prefix: "quota", 43 45 clock: clock, ··· 45 47 } 46 48 ctx := context.Background() 47 49 48 - // Test Bulk Write Refund 49 50 chargedAt, _ := limiter.AllowBulkWrite(ctx, 10) 50 51 limiter.RefundBulkWrite(ctx, 10, chargedAt) 51 52 ··· 54 55 t.Errorf("BulkWrite refund failed: w=%d, g=%d", w, g) 55 56 } 56 57 57 - // Test Read Refund 58 58 chargedAt, _ = limiter.AllowRead(ctx) 59 59 limiter.RefundRead(ctx, chargedAt) 60 60 ··· 64 64 } 65 65 } 66 66 67 + type testRateLimiter struct { 68 + kv *mockKV 69 + prefix string 70 + clock *testClock 71 + rlQuota float32 72 + } 73 + 74 + func (l *testRateLimiter) Stats() (int, int, error) { 75 + wd, gd, _, _, _, _ := l.getKeys(l.clock.now) 76 + vals, err := l.kv.GetMulti([]string{wd, gd}) 77 + if err != nil { 78 + return 0, 0, err 79 + } 80 + return vals[wd], vals[gd], nil 81 + } 82 + 83 + func (l *testRateLimiter) AllowBulkWrite(ctx context.Context, n int) (time.Time, error) { 84 + wCost := n * atproto.WriteOnlyCost 85 + gCost := n * atproto.WriteGlobalCost 86 + 87 + for { 88 + now := l.clock.now 89 + wKeys, gKeys := l.getAllKeys(now) 90 + 91 + maxWait, err := l.checkQuota(now, wKeys, gKeys, wCost, gCost) 92 + if err != nil { 93 + return now, err 94 + } 95 + 96 + if maxWait > 0 { 97 + return now, context.DeadlineExceeded 98 + } 99 + 100 + err = l.charge(wKeys, gKeys, wCost, gCost) 101 + if err != nil { 102 + return now, err 103 + } 104 + return now, nil 105 + } 106 + } 107 + 108 + func (l *testRateLimiter) AllowRead(ctx context.Context) (time.Time, error) { 109 + gCost := atproto.ReadGlobalCost 110 + 111 + for { 112 + now := l.clock.now 113 + _, gKeys := l.getAllKeys(now) 114 + 115 + maxWait, err := l.checkQuota(now, nil, gKeys, 0, gCost) 116 + if err != nil { 117 + return now, err 118 + } 119 + 120 + if maxWait > 0 { 121 + return now, context.DeadlineExceeded 122 + } 123 + 124 + err = l.charge(nil, gKeys, 0, gCost) 125 + if err != nil { 126 + return now, err 127 + } 128 + return now, nil 129 + } 130 + } 131 + 132 + func (l *testRateLimiter) RefundBulkWrite(ctx context.Context, n int, chargedAt time.Time) { 133 + wKeys, gKeys := l.getAllKeys(chargedAt) 134 + wCost := n * atproto.WriteOnlyCost 135 + gCost := n * atproto.WriteGlobalCost 136 + 137 + updates := make(map[string]int, len(wKeys)+len(gKeys)) 138 + for _, k := range wKeys { 139 + updates[k] = -wCost 140 + } 141 + for _, k := range gKeys { 142 + updates[k] = -gCost 143 + } 144 + 145 + l.kv.IncrByMulti(updates) 146 + } 147 + 148 + func (l *testRateLimiter) RefundRead(ctx context.Context, chargedAt time.Time) { 149 + _, gKeys := l.getAllKeys(chargedAt) 150 + gCost := atproto.ReadGlobalCost 151 + 152 + updates := make(map[string]int, len(gKeys)) 153 + for _, k := range gKeys { 154 + updates[k] = -gCost 155 + } 156 + 157 + l.kv.IncrByMulti(updates) 158 + } 159 + 160 + func (l *testRateLimiter) getKeys(t time.Time) (string, string, string, string, string, string) { 161 + day := t.Format("2006-01-02") 162 + hour := t.Format("2006-01-02-15") 163 + minute := t.Format("2006-01-02-15-04") 164 + return fmt.Sprintf("%s:writes:d:%s", l.prefix, day), fmt.Sprintf("%s:global:d:%s", l.prefix, day), 165 + fmt.Sprintf("%s:writes:h:%s", l.prefix, hour), fmt.Sprintf("%s:global:h:%s", l.prefix, hour), 166 + fmt.Sprintf("%s:writes:m:%s", l.prefix, minute), fmt.Sprintf("%s:global:m:%s", l.prefix, minute) 167 + } 168 + 169 + func (l *testRateLimiter) getAllKeys(t time.Time) ([]string, []string) { 170 + wd, gd, wh, gh, wm, gm := l.getKeys(t) 171 + return []string{wm, wh, wd}, []string{gm, gh, gd} 172 + } 173 + 174 + func (l *testRateLimiter) checkQuota(now time.Time, wKeys, gKeys []string, wCost, gCost int) (time.Duration, error) { 175 + wLimits := []int{atproto.WriteLimitMinute, atproto.WriteLimitHour, atproto.WriteLimitDay} 176 + gLimits := []int{atproto.GlobalLimitMinute, atproto.GlobalLimitHour, atproto.GlobalLimitDay} 177 + 178 + allKeys := make([]string, 0, len(wKeys)+len(gKeys)) 179 + allKeys = append(allKeys, wKeys...) 180 + allKeys = append(allKeys, gKeys...) 181 + 182 + if len(allKeys) == 0 { 183 + return 0, nil 184 + } 185 + 186 + values, err := l.kv.GetMulti(allKeys) 187 + if err != nil { 188 + return 0, err 189 + } 190 + 191 + maxWait := time.Duration(0) 192 + 193 + for i, k := range wKeys { 194 + curr := values[k] 195 + if curr+wCost > int(float32(wLimits[i])*l.rlQuota) { 196 + maxWait = max(l.untilNextWindow(now, i), maxWait) 197 + } 198 + } 199 + 200 + for i, k := range gKeys { 201 + curr := values[k] 202 + if curr+gCost > int(float32(gLimits[i])*l.rlQuota) { 203 + maxWait = max(l.untilNextWindow(now, i), maxWait) 204 + } 205 + } 206 + 207 + return maxWait, nil 208 + } 209 + 210 + func (l *testRateLimiter) charge(wKeys, gKeys []string, wCost, gCost int) error { 211 + updates := make(map[string]int, len(wKeys)+len(gKeys)) 212 + for _, k := range wKeys { 213 + updates[k] = wCost 214 + } 215 + for _, k := range gKeys { 216 + updates[k] = gCost 217 + } 218 + 219 + if len(updates) == 0 { 220 + return nil 221 + } 222 + 223 + return l.kv.IncrByMulti(updates) 224 + } 225 + 226 + func (l *testRateLimiter) untilNextWindow(now time.Time, tier int) time.Duration { 227 + switch tier { 228 + case 0: 229 + return now.Truncate(time.Minute).Add(time.Minute).Sub(now) 230 + case 1: 231 + return now.Truncate(time.Hour).Add(time.Hour).Sub(now) 232 + case 2: 233 + return now.Truncate(24 * time.Hour).Add(24 * time.Hour).Sub(now) 234 + default: 235 + return time.Minute 236 + } 237 + } 238 + 67 239 func TestRateLimiter_Weighting(t *testing.T) { 68 240 kv := &mockKV{data: make(map[string]int)} 69 - limiter := NewRateLimiter(kv, 1) 241 + limiter := atproto.NewRateLimiter(kv, 1) 70 242 ctx := context.Background() 71 243 72 - // 1 Read = 1 Global 73 244 _, err := limiter.AllowRead(ctx) 74 245 if err != nil { 75 246 t.Fatal(err) ··· 82 253 t.Errorf("expected 1 global unit, got %d", g) 83 254 } 84 255 85 - // 1 Write = 1 Write-Only + 3 Global 86 256 _, err = limiter.AllowBulkWrite(ctx, 1) 87 257 if err != nil { 88 258 t.Fatal(err) ··· 94 264 if w != 1 { 95 265 t.Errorf("expected 1 write unit, got %d", w) 96 266 } 97 - if g != 4 { // 1 from read + 3 from write 267 + if g != 4 { 98 268 t.Errorf("expected 4 global units, got %d", g) 99 269 } 100 270 101 - // Bulk Write (10 elements) = 10 Write-Only + 30 Global 102 271 _, err = limiter.AllowBulkWrite(ctx, 10) 103 272 if err != nil { 104 273 t.Fatal(err) ··· 112 281 } 113 282 if g != 34 { 114 283 t.Errorf("expected 34 global units, got %d", g) 115 - } 116 - } 117 - 118 - func TestRateLimiter_Smoothing(t *testing.T) { 119 - kv := &mockKV{data: make(map[string]int)} 120 - // Ensure we are at the very beginning of the minute to avoid window edge issues in test 121 - clock := &mockClock{now: time.Date(2026, 1, 22, 1, 0, 0, 0, time.UTC)} 122 - limiter := &quotaLimiter{ 123 - kv: kv, 124 - prefix: "quota", 125 - clock: clock, 126 - rlQuota: 1, 127 - } 128 - 129 - wd, gd, wh, gh, wm, gm := limiter.getKeys(clock.now) 130 - kv.data[wm] = WriteLimitMinute 131 - kv.data[wh] = 0 132 - kv.data[wd] = 0 133 - kv.data[gm] = 0 134 - kv.data[gh] = 0 135 - kv.data[gd] = 0 136 - 137 - // Use a context that is already cancelled to simulate what happens 138 - // when we hit the rate limit and can't proceed. 139 - // Ensure the store thinks we are over the limit 140 - kv.data[wm] = WriteLimitMinute + 1 141 - 142 - ctx, cancel := context.WithCancel(context.Background()) 143 - cancel() 144 - 145 - _, err := limiter.AllowBulkWrite(ctx, 1) 146 - if err == nil { 147 - t.Errorf("expected error, got nil") 148 284 } 149 285 } 150 286 ··· 176 312 } 177 313 if res.ErrorCount != 1 { 178 314 t.Errorf("expected 1 error, got %d", res.ErrorCount) 179 - } 180 - 181 - storage.mu.Lock() 182 - errStr, ok := storage.failed["k1"] 183 - storage.mu.Unlock() 184 - 185 - if !ok { 186 - t.Error("expected record k1 to be marked as failed") 187 - } 188 - if !strings.Contains(errStr, "503") { 189 - t.Errorf("expected error string to contain 503, got %q", errStr) 190 - } 191 - } 192 - 193 - type mockClock struct { 194 - now time.Time 195 - } 196 - 197 - func (m *mockClock) Now() time.Time { return m.now } 198 - 199 - func TestRateLimiter_MidnightRollover(t *testing.T) { 200 - kv := &mockKV{data: make(map[string]int)} 201 - clock := &mockClock{now: time.Date(2026, 1, 22, 23, 59, 59, 0, time.UTC)} 202 - limiter := &quotaLimiter{ 203 - kv: kv, 204 - prefix: "quota", 205 - clock: clock, 206 - rlQuota: 1, 207 - } 208 - ctx := context.Background() 209 - 210 - // 1. Consume some quota on day 1 211 - _, err := limiter.AllowBulkWrite(ctx, 10) 212 - if err != nil { 213 - t.Fatal(err) 214 - } 215 - 216 - w1, g1, err := limiter.Stats() 217 - if err != nil { 218 - t.Fatal(err) 219 - } 220 - if w1 != 10 || g1 != 30 { 221 - t.Errorf("expected w=10, g=30 on day 1, got w=%d, g=%d", w1, g1) 222 - } 223 - 224 - // 2. Advance time to next day 225 - clock.now = clock.now.Add(2 * time.Second) // 00:00:01 on 2026-01-23 226 - 227 - // 3. Stats should now reflect day 2 (0) 228 - w2, g2, err := limiter.Stats() 229 - if err != nil { 230 - t.Fatal(err) 231 - } 232 - if w2 != 0 || g2 != 0 { 233 - t.Errorf("expected w=0, g=0 on day 2, got w=%d, g=%d", w2, g2) 234 - } 235 - 236 - // 4. Consumption on day 2 should not affect day 1 237 - _, err = limiter.AllowBulkWrite(ctx, 5) 238 - if err != nil { 239 - t.Fatal(err) 240 - } 241 - 242 - w2, g2, err = limiter.Stats() 243 - if err != nil { 244 - t.Fatal(err) 245 - } 246 - if w2 != 5 || g2 != 15 { 247 - t.Errorf("expected w=5, g=15 on day 2, got w=%d, g=%d", w2, g2) 248 - } 249 - 250 - // Verify day 1 keys are still there but not accessed by Stats() 251 - day1WKey := "quota:writes:d:2026-01-22" 252 - if val, _ := kv.Get(day1WKey); val != 10 { 253 - t.Errorf("expected day 1 write key to still be 10, got %d", val) 254 315 } 255 316 }
+29
sync/record.go
··· 287 287 } 288 288 return duplicates 289 289 } 290 + 291 + type Timestamp struct { 292 + time.Time 293 + } 294 + 295 + func (t Timestamp) MarshalJSON() ([]byte, error) { 296 + return []byte(`"` + t.Format(time.RFC3339Nano) + `"`), nil 297 + } 298 + 299 + func (t *Timestamp) UnmarshalJSON(data []byte) error { 300 + if string(data) == "null" { 301 + *t = Timestamp{} 302 + return nil 303 + } 304 + 305 + s := string(data) 306 + s = s[1 : len(s)-1] 307 + 308 + tm, err := time.Parse(time.RFC3339Nano, s) 309 + if err != nil { 310 + tm, err = time.Parse(time.RFC3339, s) 311 + if err != nil { 312 + return fmt.Errorf("failed to parse timestamp %q: %w", s, err) 313 + } 314 + } 315 + 316 + *t = Timestamp{Time: tm} 317 + return nil 318 + }
-35
sync/timestamp.go
··· 1 - package sync 2 - 3 - import ( 4 - "fmt" 5 - "time" 6 - ) 7 - 8 - type Timestamp struct { 9 - time.Time 10 - } 11 - 12 - func (t Timestamp) MarshalJSON() ([]byte, error) { 13 - return []byte(`"` + t.Format(time.RFC3339Nano) + `"`), nil 14 - } 15 - 16 - func (t *Timestamp) UnmarshalJSON(data []byte) error { 17 - if string(data) == "null" { 18 - *t = Timestamp{} 19 - return nil 20 - } 21 - 22 - s := string(data) 23 - s = s[1 : len(s)-1] 24 - 25 - tm, err := time.Parse(time.RFC3339Nano, s) 26 - if err != nil { 27 - tm, err = time.Parse(time.RFC3339, s) 28 - if err != nil { 29 - return fmt.Errorf("failed to parse timestamp %q: %w", s, err) 30 - } 31 - } 32 - 33 - *t = Timestamp{Time: tm} 34 - return nil 35 - }