Go boilerplate library for building atproto apps
atproto go
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: add tracing, public client, and jetstream consumer

+1026 -121
+2 -11
client.go
··· 29 29 return &Client{api: api, did: did} 30 30 } 31 31 32 - // DID returns the authenticated user's DID. 33 32 func (c *Client) DID() syntax.DID { return c.did } 34 33 35 34 // APIClient returns the underlying indigo APIClient for advanced usage 36 35 // (custom XRPC calls, service proxying, etc.). 37 36 func (c *Client) APIClient() *atclient.APIClient { return c.api } 38 37 39 - // CreateRecord creates a new record with an auto-generated TID key. 40 38 func (c *Client) CreateRecord(ctx context.Context, collection string, record any) (uri, cid string, err error) { 41 39 body := map[string]any{ 42 40 "repo": c.did.String(), ··· 54 52 return result.URI, result.CID, nil 55 53 } 56 54 57 - // CreateRecordWithRKey creates a new record with a specific record key. 58 55 func (c *Client) CreateRecordWithRKey(ctx context.Context, collection, rkey string, record any) (uri, cid string, err error) { 59 56 body := map[string]any{ 60 57 "repo": c.did.String(), ··· 92 89 return &Record{URI: result.URI, CID: result.CID, Value: result.Value}, nil 93 90 } 94 91 95 - // ListRecords retrieves a single page of records from a collection. 96 - // Pass limit <= 0 for the server default (usually 50). Pass empty cursor for the first page. 92 + // Pass limit <= 0 for the server default (usually 50). 93 + // Pass empty cursor for the first page. 97 94 func (c *Client) ListRecords(ctx context.Context, collection string, limit int, cursor string) (*ListResult, error) { 98 95 params := map[string]any{ 99 96 "repo": c.did.String(), ··· 130 127 return out, nil 131 128 } 132 129 133 - // ListAllRecords fetches every record in a collection, handling cursor 134 - // pagination automatically. Returns all records at once. 135 130 func (c *Client) ListAllRecords(ctx context.Context, collection string) ([]Record, error) { 136 131 var all []Record 137 132 cursor := "" ··· 157 152 return all, nil 158 153 } 159 154 160 - // PutRecord creates or updates a record at a specific record key. 161 155 func (c *Client) PutRecord(ctx context.Context, collection, rkey string, record any) (uri, cid string, err error) { 162 156 body := map[string]any{ 163 157 "repo": c.did.String(), ··· 176 170 return result.URI, result.CID, nil 177 171 } 178 172 179 - // DeleteRecord removes a record from the user's repository. 180 173 func (c *Client) DeleteRecord(ctx context.Context, collection, rkey string) error { 181 174 body := map[string]any{ 182 175 "repo": c.did.String(), ··· 191 184 return nil 192 185 } 193 186 194 - // UploadBlob uploads a blob to the user's PDS. 195 187 // Data must be at most 1 MB (MaxBlobSize). The mimeType should match the blob content. 196 188 func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (*BlobRef, error) { 197 189 if len(data) > MaxBlobSize { ··· 235 227 }, nil 236 228 } 237 229 238 - // GetBlob downloads a blob from the user's PDS by its CID. 239 230 func (c *Client) GetBlob(ctx context.Context, cid string) ([]byte, error) { 240 231 data, err := atproto.SyncGetBlob(ctx, c.api, cid, c.did.String()) 241 232 if err != nil {
+1
errors.go
··· 12 12 13 13 // WrapPDSError inspects an XRPC error for signals that the OAuth grant is no 14 14 // longer valid and, if so, wraps it with ErrSessionExpired. 15 + // TODO: handle other common error types 15 16 func WrapPDSError(err error) error { 16 17 if err == nil { 17 18 return nil
+2
go.mod
··· 4 4 5 5 require ( 6 6 github.com/bluesky-social/indigo v0.0.0-20260318212431-cbaa83aee9dd 7 + github.com/gorilla/websocket v1.5.3 8 + github.com/klauspost/compress v1.17.3 7 9 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c 8 10 go.etcd.io/bbolt v1.4.3 9 11 go.opentelemetry.io/otel v1.43.0
+4
go.sum
··· 30 30 github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= 31 31 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 32 32 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 33 + github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= 34 + github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 33 35 github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= 34 36 github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= 35 37 github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= 36 38 github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= 37 39 github.com/ipfs/go-cid v0.4.1 h1:A/T3qGvxi4kpKWWcPC/PgbvDA2bjVLO7n4UeVwnbs/s= 38 40 github.com/ipfs/go-cid v0.4.1/go.mod h1:uQHwDeX4c6CtyrFwdqyhpNcxVewur1M7l7fNU7LKwZk= 41 + github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= 42 + github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= 39 43 github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= 40 44 github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= 41 45 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+360
jetstream/jetstream.go
··· 1 + // Package jetstream consumes real-time AT Protocol events from a Jetstream relay. 2 + // 3 + // Jetstream is a WebSocket-based relay that delivers a filtered stream of AT 4 + // Protocol repository events (commits, identity changes, account changes). 5 + // This package handles connection management, reconnection with backoff, 6 + // endpoint rotation, cursor tracking, and optional zstd decompression. 7 + // 8 + // Basic usage: 9 + // 10 + // consumer := jetstream.New(&jetstream.Config{ 11 + // WantedCollections: []string{"app.bsky.feed.post"}, 12 + // }, func(ctx context.Context, evt *jetstream.Event) error { 13 + // fmt.Printf("new post from %s\n", evt.DID) 14 + // return nil 15 + // }) 16 + // consumer.Start(ctx) 17 + // defer consumer.Stop() 18 + package jetstream 19 + 20 + import ( 21 + "context" 22 + "encoding/json" 23 + "fmt" 24 + "net/url" 25 + "sync" 26 + "sync/atomic" 27 + "time" 28 + 29 + "github.com/gorilla/websocket" 30 + "github.com/klauspost/compress/zstd" 31 + ) 32 + 33 + // DefaultEndpoints are the public Jetstream relay endpoints. 34 + var DefaultEndpoints = []string{ 35 + "wss://jetstream1.us-east.bsky.network/subscribe", 36 + "wss://jetstream2.us-east.bsky.network/subscribe", 37 + "wss://jetstream1.us-west.bsky.network/subscribe", 38 + "wss://jetstream2.us-west.bsky.network/subscribe", 39 + } 40 + 41 + // Event is a single event from the Jetstream relay. 42 + type Event struct { 43 + DID string `json:"did"` 44 + TimeUS int64 `json:"time_us"` 45 + Kind string `json:"kind"` // "commit", "identity", "account" 46 + Commit *Commit `json:"commit,omitempty"` 47 + } 48 + 49 + // Commit is the commit payload within an Event. 50 + type Commit struct { 51 + Rev string `json:"rev"` 52 + Operation string `json:"operation"` // "create", "update", "delete" 53 + Collection string `json:"collection"` 54 + RKey string `json:"rkey"` 55 + Record json.RawMessage `json:"record,omitempty"` 56 + CID string `json:"cid"` 57 + } 58 + 59 + // Handler is called for each event received from Jetstream. 60 + // Returning an error logs a warning but does not stop the consumer. 61 + type Handler func(ctx context.Context, event *Event) error 62 + 63 + // CursorStore persists the Jetstream cursor across restarts. 64 + // If nil, the cursor is tracked in memory only (replay from live on restart). 65 + type CursorStore interface { 66 + GetCursor(ctx context.Context) (int64, error) 67 + SetCursor(ctx context.Context, cursor int64) error 68 + } 69 + 70 + // Config configures a Jetstream consumer. 71 + type Config struct { 72 + // Endpoints is the list of Jetstream WebSocket URLs. Defaults to DefaultEndpoints. 73 + Endpoints []string 74 + 75 + // WantedCollections filters events to specific NSIDs. 76 + // Empty means all collections (high volume). 77 + WantedCollections []string 78 + 79 + // Compress enables zstd compression. Disabled by default because Jetstream 80 + // uses a custom dictionary incompatible with the standard zstd decoder. 81 + Compress bool 82 + 83 + // CursorStore persists the cursor for resume after restart. 84 + // If nil, the consumer starts from live on each restart. 85 + CursorStore CursorStore 86 + 87 + // CursorPersistEvery controls how often the cursor is flushed to CursorStore. 88 + // Defaults to every 1000 events. 89 + CursorPersistEvery int64 90 + 91 + // OnConnect is called each time a WebSocket connection is established. 92 + OnConnect func() 93 + 94 + // OnDisconnect is called each time a connection is lost. 95 + OnDisconnect func() 96 + 97 + // OnError is called when the handler returns an error. 98 + // If nil, errors are silently dropped (caller should log in the handler). 99 + OnError func(err error, event *Event) 100 + } 101 + 102 + func (c *Config) endpoints() []string { 103 + if len(c.Endpoints) > 0 { 104 + return c.Endpoints 105 + } 106 + return DefaultEndpoints 107 + } 108 + 109 + func (c *Config) cursorPersistEvery() int64 { 110 + if c.CursorPersistEvery > 0 { 111 + return c.CursorPersistEvery 112 + } 113 + return 1000 114 + } 115 + 116 + // Consumer consumes events from a Jetstream relay. 117 + type Consumer struct { 118 + cfg *Config 119 + handler Handler 120 + 121 + conn *websocket.Conn 122 + connMu sync.Mutex 123 + currentEndpointIdx int 124 + 125 + zstdDecoder *zstd.Decoder 126 + 127 + cursor atomic.Int64 128 + eventsReceived atomic.Int64 129 + bytesReceived atomic.Int64 130 + connected atomic.Bool 131 + 132 + stopCh chan struct{} 133 + wg sync.WaitGroup 134 + } 135 + 136 + // New creates a new Consumer. Call Start to begin consuming events. 137 + func New(cfg *Config, handler Handler) *Consumer { 138 + decoder, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1)) 139 + if err != nil { 140 + // zstd.NewReader with nil src only fails on bad options 141 + panic(fmt.Sprintf("jetstream: create zstd decoder: %v", err)) 142 + } 143 + 144 + c := &Consumer{ 145 + cfg: cfg, 146 + handler: handler, 147 + stopCh: make(chan struct{}), 148 + zstdDecoder: decoder, 149 + } 150 + 151 + if cfg.CursorStore != nil { 152 + if cursor, err := cfg.CursorStore.GetCursor(context.Background()); err == nil && cursor > 0 { 153 + c.cursor.Store(cursor) 154 + } 155 + } 156 + 157 + return c 158 + } 159 + 160 + // Start begins consuming events in a background goroutine. 161 + func (c *Consumer) Start(ctx context.Context) { 162 + c.wg.Add(1) 163 + go func() { 164 + defer c.wg.Done() 165 + c.run(ctx) 166 + }() 167 + } 168 + 169 + // Stop gracefully shuts down the consumer and waits for it to finish. 170 + func (c *Consumer) Stop() { 171 + close(c.stopCh) 172 + c.connMu.Lock() 173 + if c.conn != nil { 174 + c.conn.Close() 175 + } 176 + c.connMu.Unlock() 177 + c.wg.Wait() 178 + c.zstdDecoder.Close() 179 + } 180 + 181 + // IsConnected reports whether the consumer is currently connected. 182 + func (c *Consumer) IsConnected() bool { 183 + return c.connected.Load() 184 + } 185 + 186 + // Stats returns cumulative event and byte counts since Start was called. 187 + func (c *Consumer) Stats() (eventsReceived, bytesReceived int64) { 188 + return c.eventsReceived.Load(), c.bytesReceived.Load() 189 + } 190 + 191 + func (c *Consumer) run(ctx context.Context) { 192 + backoff := time.Second 193 + const maxBackoff = 30 * time.Second 194 + 195 + for { 196 + select { 197 + case <-ctx.Done(): 198 + return 199 + case <-c.stopCh: 200 + return 201 + default: 202 + } 203 + 204 + endpoints := c.cfg.endpoints() 205 + endpoint := endpoints[c.currentEndpointIdx] 206 + 207 + if err := c.connectAndConsume(ctx, endpoint); err != nil { 208 + c.connected.Store(false) 209 + if c.cfg.OnDisconnect != nil { 210 + c.cfg.OnDisconnect() 211 + } 212 + 213 + // Rotate to next endpoint 214 + c.currentEndpointIdx = (c.currentEndpointIdx + 1) % len(endpoints) 215 + 216 + select { 217 + case <-ctx.Done(): 218 + return 219 + case <-c.stopCh: 220 + return 221 + case <-time.After(backoff): 222 + } 223 + 224 + backoff *= 2 225 + if backoff > maxBackoff { 226 + backoff = maxBackoff 227 + } 228 + } else { 229 + backoff = time.Second 230 + } 231 + } 232 + } 233 + 234 + func (c *Consumer) connectAndConsume(ctx context.Context, endpoint string) error { 235 + wsURL, err := c.buildURL(endpoint) 236 + if err != nil { 237 + return fmt.Errorf("build URL: %w", err) 238 + } 239 + 240 + dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second} 241 + conn, _, err := dialer.DialContext(ctx, wsURL, nil) 242 + if err != nil { 243 + return fmt.Errorf("dial: %w", err) 244 + } 245 + 246 + c.connMu.Lock() 247 + c.conn = conn 248 + c.connMu.Unlock() 249 + 250 + c.connected.Store(true) 251 + if c.cfg.OnConnect != nil { 252 + c.cfg.OnConnect() 253 + } 254 + 255 + defer func() { 256 + c.connMu.Lock() 257 + if c.conn != nil { 258 + c.conn.Close() 259 + c.conn = nil 260 + } 261 + c.connMu.Unlock() 262 + c.connected.Store(false) 263 + }() 264 + 265 + for { 266 + select { 267 + case <-ctx.Done(): 268 + return ctx.Err() 269 + case <-c.stopCh: 270 + return nil 271 + default: 272 + } 273 + 274 + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) 275 + 276 + _, msg, err := conn.ReadMessage() 277 + if err != nil { 278 + return fmt.Errorf("read: %w", err) 279 + } 280 + 281 + c.bytesReceived.Add(int64(len(msg))) 282 + 283 + if err := c.process(ctx, msg); err != nil { 284 + if c.cfg.OnError != nil { 285 + // We don't have the event here since parsing may have failed, 286 + // pass nil to signal a parse/process error 287 + c.cfg.OnError(err, nil) 288 + } 289 + } 290 + } 291 + } 292 + 293 + func (c *Consumer) buildURL(endpoint string) (string, error) { 294 + u, err := url.Parse(endpoint) 295 + if err != nil { 296 + return "", err 297 + } 298 + 299 + q := u.Query() 300 + for _, coll := range c.cfg.WantedCollections { 301 + q.Add("wantedCollections", coll) 302 + } 303 + if c.cfg.Compress { 304 + q.Set("compress", "true") 305 + } 306 + if cursor := c.cursor.Load(); cursor > 0 { 307 + // Rewind 5 seconds to cover any gaps at reconnect 308 + rewind := cursor - (5 * time.Second.Microseconds()) 309 + q.Set("cursor", fmt.Sprintf("%d", rewind)) 310 + } 311 + 312 + u.RawQuery = q.Encode() 313 + return u.String(), nil 314 + } 315 + 316 + func (c *Consumer) process(ctx context.Context, data []byte) error { 317 + // Decompress if enabled and data has zstd magic bytes 318 + if c.cfg.Compress { 319 + if len(data) >= 4 && data[0] == 0x28 && data[1] == 0xB5 && data[2] == 0x2F && data[3] == 0xFD { 320 + decompressed, err := c.zstdDecoder.DecodeAll(data, nil) 321 + if err != nil { 322 + return fmt.Errorf("decompress: %w", err) 323 + } 324 + data = decompressed 325 + } else if len(data) > 0 && data[0] != '{' { 326 + // Try anyway in case magic bytes differ 327 + if decompressed, err := c.zstdDecoder.DecodeAll(data, nil); err == nil { 328 + data = decompressed 329 + } 330 + } 331 + } 332 + 333 + var event Event 334 + if err := json.Unmarshal(data, &event); err != nil { 335 + return fmt.Errorf("unmarshal event: %w", err) 336 + } 337 + 338 + c.eventsReceived.Add(1) 339 + 340 + if event.TimeUS > 0 { 341 + c.cursor.Store(event.TimeUS) 342 + 343 + if c.cfg.CursorStore != nil && c.eventsReceived.Load()%c.cfg.cursorPersistEvery() == 0 { 344 + if err := c.cfg.CursorStore.SetCursor(ctx, event.TimeUS); err != nil { 345 + // Non-fatal: log via OnError if configured 346 + if c.cfg.OnError != nil { 347 + c.cfg.OnError(fmt.Errorf("persist cursor: %w", err), nil) 348 + } 349 + } 350 + } 351 + } 352 + 353 + if err := c.handler(ctx, &event); err != nil { 354 + if c.cfg.OnError != nil { 355 + c.cfg.OnError(err, &event) 356 + } 357 + } 358 + 359 + return nil 360 + }
+143
jetstream/jetstream_test.go
··· 1 + package jetstream 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "testing" 7 + ) 8 + 9 + func TestNew(t *testing.T) { 10 + c := New(&Config{ 11 + WantedCollections: []string{"app.bsky.feed.post"}, 12 + }, func(ctx context.Context, evt *Event) error { 13 + return nil 14 + }) 15 + if c == nil { 16 + t.Fatal("expected non-nil consumer") 17 + } 18 + if c.IsConnected() { 19 + t.Fatal("should not be connected before Start") 20 + } 21 + } 22 + 23 + func TestConfig_Defaults(t *testing.T) { 24 + cfg := &Config{} 25 + if eps := cfg.endpoints(); len(eps) == 0 { 26 + t.Fatal("expected default endpoints") 27 + } 28 + if cfg.cursorPersistEvery() != 1000 { 29 + t.Fatalf("expected 1000, got %d", cfg.cursorPersistEvery()) 30 + } 31 + } 32 + 33 + func TestConfig_CustomEndpoints(t *testing.T) { 34 + cfg := &Config{Endpoints: []string{"wss://custom.example.com/subscribe"}} 35 + eps := cfg.endpoints() 36 + if len(eps) != 1 || eps[0] != "wss://custom.example.com/subscribe" { 37 + t.Fatalf("unexpected endpoints: %v", eps) 38 + } 39 + } 40 + 41 + func TestBuildURL_Collections(t *testing.T) { 42 + c := New(&Config{ 43 + Endpoints: []string{"wss://jetstream1.us-east.bsky.network/subscribe"}, 44 + WantedCollections: []string{"app.bsky.feed.post", "app.bsky.feed.like"}, 45 + }, func(ctx context.Context, evt *Event) error { return nil }) 46 + 47 + u, err := c.buildURL("wss://jetstream1.us-east.bsky.network/subscribe") 48 + if err != nil { 49 + t.Fatal(err) 50 + } 51 + if u == "" { 52 + t.Fatal("expected non-empty URL") 53 + } 54 + // Should contain wantedCollections params 55 + if !contains(u, "wantedCollections=app.bsky.feed.post") { 56 + t.Errorf("URL missing wantedCollections: %s", u) 57 + } 58 + } 59 + 60 + func TestBuildURL_Cursor(t *testing.T) { 61 + c := New(&Config{ 62 + Endpoints: []string{"wss://jetstream1.us-east.bsky.network/subscribe"}, 63 + }, func(ctx context.Context, evt *Event) error { return nil }) 64 + 65 + c.cursor.Store(1000000) 66 + u, err := c.buildURL("wss://jetstream1.us-east.bsky.network/subscribe") 67 + if err != nil { 68 + t.Fatal(err) 69 + } 70 + if !contains(u, "cursor=") { 71 + t.Errorf("URL missing cursor: %s", u) 72 + } 73 + } 74 + 75 + func TestProcess_ValidEvent(t *testing.T) { 76 + var received *Event 77 + c := New(&Config{}, func(ctx context.Context, evt *Event) error { 78 + received = evt 79 + return nil 80 + }) 81 + 82 + evt := Event{ 83 + DID: "did:plc:test", 84 + TimeUS: 1234567890, 85 + Kind: "commit", 86 + Commit: &Commit{ 87 + Operation: "create", 88 + Collection: "app.bsky.feed.post", 89 + RKey: "abc123", 90 + }, 91 + } 92 + data, _ := json.Marshal(evt) 93 + 94 + if err := c.process(context.Background(), data); err != nil { 95 + t.Fatal(err) 96 + } 97 + if received == nil { 98 + t.Fatal("expected handler to be called") 99 + } 100 + if received.DID != "did:plc:test" { 101 + t.Fatalf("got DID %q", received.DID) 102 + } 103 + } 104 + 105 + func TestProcess_UpdatesCursor(t *testing.T) { 106 + c := New(&Config{}, func(ctx context.Context, evt *Event) error { return nil }) 107 + 108 + data, _ := json.Marshal(Event{DID: "did:plc:test", TimeUS: 9999999}) 109 + c.process(context.Background(), data) 110 + 111 + if c.cursor.Load() != 9999999 { 112 + t.Fatalf("cursor not updated: %d", c.cursor.Load()) 113 + } 114 + } 115 + 116 + func TestProcess_InvalidJSON(t *testing.T) { 117 + c := New(&Config{}, func(ctx context.Context, evt *Event) error { return nil }) 118 + err := c.process(context.Background(), []byte("{bad json")) 119 + if err == nil { 120 + t.Fatal("expected error for invalid JSON") 121 + } 122 + } 123 + 124 + func TestStats(t *testing.T) { 125 + c := New(&Config{}, func(ctx context.Context, evt *Event) error { return nil }) 126 + evts, bytes := c.Stats() 127 + if evts != 0 || bytes != 0 { 128 + t.Fatalf("expected zero stats, got events=%d bytes=%d", evts, bytes) 129 + } 130 + } 131 + 132 + func contains(s, substr string) bool { 133 + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsStr(s, substr)) 134 + } 135 + 136 + func containsStr(s, substr string) bool { 137 + for i := 0; i <= len(s)-len(substr); i++ { 138 + if s[i:i+len(substr)] == substr { 139 + return true 140 + } 141 + } 142 + return false 143 + }
+1 -2
oauth.go
··· 98 98 // LoginCLI runs a complete loopback OAuth flow for CLI applications. 99 99 // It opens the user's browser, starts a temporary HTTP server to receive the 100 100 // callback, and blocks until authentication completes. 101 + // TODO: should this be part of the library? probably not? (removeds `browser` dep) 101 102 func (a *OAuthApp) LoginCLI(ctx context.Context, handle string) (*SessionInfo, error) { 102 103 authURL, err := a.app.StartAuthFlow(ctx, handle) 103 104 if err != nil { ··· 182 183 return meta 183 184 } 184 185 185 - // Store returns the underlying session store, useful for implementing 186 - // features like "list all sessions" or session cleanup. 187 186 func (a *OAuthApp) Store() oauth.ClientAuthStore { 188 187 return a.app.Store 189 188 }
+49 -54
oauth_test.go
··· 6 6 "github.com/bluesky-social/indigo/atproto/auth/oauth" 7 7 ) 8 8 9 - func TestNewOAuthApp_Localhost(t *testing.T) { 10 - app, err := NewOAuthApp(OAuthConfig{ 11 - ClientID: "", 12 - RedirectURI: "http://127.0.0.1:12345/callback", 13 - Scopes: []string{"atproto"}, 14 - Store: oauth.NewMemStore(), 15 - }) 16 - if err != nil { 17 - t.Fatal(err) 9 + func TestNewOAuthApp(t *testing.T) { 10 + tests := []struct { 11 + name string 12 + config OAuthConfig 13 + }{ 14 + { 15 + name: "localhost IP", 16 + config: OAuthConfig{ 17 + ClientID: "", 18 + RedirectURI: "http://127.0.0.1:12345/callback", 19 + Scopes: []string{"atproto"}, 20 + Store: oauth.NewMemStore(), 21 + }, 22 + }, 23 + { 24 + name: "localhost prefix", 25 + config: OAuthConfig{ 26 + ClientID: "http://localhost:8080", 27 + RedirectURI: "http://localhost:8080/oauth/callback", 28 + Scopes: []string{"atproto"}, 29 + Store: oauth.NewMemStore(), 30 + }, 31 + }, 32 + { 33 + name: "public client", 34 + config: OAuthConfig{ 35 + ClientID: "https://example.com/client-metadata.json", 36 + RedirectURI: "https://example.com/oauth/callback", 37 + Scopes: ScopesForCollections("x.y.bean"), 38 + Store: oauth.NewMemStore(), 39 + }, 40 + }, 41 + { 42 + name: "nil store", 43 + config: OAuthConfig{ 44 + RedirectURI: "http://127.0.0.1:12345/callback", 45 + Scopes: []string{"atproto"}, 46 + }, 47 + }, 18 48 } 19 - if app == nil { 20 - t.Fatal("expected non-nil app") 21 - } 22 - } 23 - 24 - func TestNewOAuthApp_LocalhostPrefix(t *testing.T) { 25 - app, err := NewOAuthApp(OAuthConfig{ 26 - ClientID: "http://localhost:8080", 27 - RedirectURI: "http://localhost:8080/oauth/callback", 28 - Scopes: []string{"atproto"}, 29 - Store: oauth.NewMemStore(), 30 - }) 31 - if err != nil { 32 - t.Fatal(err) 33 - } 34 - if app == nil { 35 - t.Fatal("expected non-nil app") 36 - } 37 - } 38 - 39 - func TestNewOAuthApp_Public(t *testing.T) { 40 - app, err := NewOAuthApp(OAuthConfig{ 41 - ClientID: "https://example.com/client-metadata.json", 42 - RedirectURI: "https://example.com/oauth/callback", 43 - Scopes: ScopesForCollections("x.y.bean"), 44 - Store: oauth.NewMemStore(), 45 - }) 46 - if err != nil { 47 - t.Fatal(err) 48 - } 49 - if app == nil { 50 - t.Fatal("expected non-nil app") 51 - } 52 - } 53 - 54 - func TestNewOAuthApp_NilStore(t *testing.T) { 55 - app, err := NewOAuthApp(OAuthConfig{ 56 - RedirectURI: "http://127.0.0.1:12345/callback", 57 - Scopes: []string{"atproto"}, 58 - }) 59 - if err != nil { 60 - t.Fatal(err) 61 - } 62 - if app == nil { 63 - t.Fatal("expected non-nil app") 49 + for _, tc := range tests { 50 + t.Run(tc.name, func(t *testing.T) { 51 + app, err := NewOAuthApp(tc.config) 52 + if err != nil { 53 + t.Fatal(err) 54 + } 55 + if app == nil { 56 + t.Fatal("expected non-nil app") 57 + } 58 + }) 64 59 } 65 60 } 66 61
+304
public.go
··· 1 + package atp 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "errors" 7 + "fmt" 8 + "net" 9 + "net/http" 10 + "net/url" 11 + "slices" 12 + "strings" 13 + "sync" 14 + "time" 15 + ) 16 + 17 + const ( 18 + // PublicAPIBase is the Bluesky public API endpoint used for profile and handle lookups. 19 + PublicAPIBase = "https://public.api.bsky.app" 20 + 21 + // PLCDirectory is used to resolve did:plc identifiers to DID documents. 22 + PLCDirectory = "https://plc.directory" 23 + ) 24 + 25 + // ErrSSRFBlocked is returned when a request is blocked due to a private/internal destination. 26 + var ErrSSRFBlocked = errors.New("request blocked: potential SSRF detected") 27 + 28 + // PublicClient provides unauthenticated read access to public AT Protocol APIs. 29 + // Use this to resolve handles, look up profiles, and read public records without 30 + // requiring an OAuth session. 31 + type PublicClient struct { 32 + httpClient *http.Client 33 + pdsCache map[string]string 34 + pdsCacheMu sync.RWMutex 35 + } 36 + 37 + // NewPublicClient creates a PublicClient with a 30-second timeout. 38 + // To add OTel instrumentation, use NewPublicClientWithHTTP and pass an 39 + // otelhttp-wrapped transport. 40 + func NewPublicClient() *PublicClient { 41 + return NewPublicClientWithHTTP(&http.Client{ 42 + Timeout: 30 * time.Second, 43 + }) 44 + } 45 + 46 + // NewPublicClientWithHTTP creates a PublicClient using the provided http.Client. 47 + // This lets callers inject custom transports (e.g. with OTel or rate limiting). 48 + func NewPublicClientWithHTTP(hc *http.Client) *PublicClient { 49 + return &PublicClient{ 50 + httpClient: hc, 51 + pdsCache: make(map[string]string), 52 + } 53 + } 54 + 55 + // ResolveHandle resolves an AT Protocol handle to a DID string. 56 + func (c *PublicClient) ResolveHandle(ctx context.Context, handle string) (string, error) { 57 + reqURL := fmt.Sprintf("%s/xrpc/com.atproto.identity.resolveHandle?handle=%s", 58 + PublicAPIBase, url.QueryEscape(handle)) 59 + 60 + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) 61 + if err != nil { 62 + return "", fmt.Errorf("build request: %w", err) 63 + } 64 + 65 + resp, err := c.httpClient.Do(req) 66 + if err != nil { 67 + return "", fmt.Errorf("resolve handle: %w", err) 68 + } 69 + defer resp.Body.Close() 70 + 71 + if resp.StatusCode == http.StatusNotFound { 72 + return "", fmt.Errorf("handle not found: %s", handle) 73 + } 74 + if resp.StatusCode != http.StatusOK { 75 + return "", fmt.Errorf("resolve handle: HTTP %d", resp.StatusCode) 76 + } 77 + 78 + var result struct { 79 + DID string `json:"did"` 80 + } 81 + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 82 + return "", fmt.Errorf("decode response: %w", err) 83 + } 84 + return result.DID, nil 85 + } 86 + 87 + // GetPDSEndpoint resolves a DID to the user's PDS base URL. 88 + // Results are cached in-memory for the lifetime of the client. 89 + func (c *PublicClient) GetPDSEndpoint(ctx context.Context, did string) (string, error) { 90 + c.pdsCacheMu.RLock() 91 + if pds, ok := c.pdsCache[did]; ok { 92 + c.pdsCacheMu.RUnlock() 93 + return pds, nil 94 + } 95 + c.pdsCacheMu.RUnlock() 96 + 97 + var pdsEndpoint string 98 + 99 + switch { 100 + case strings.HasPrefix(did, "did:plc:"): 101 + reqURL := fmt.Sprintf("%s/%s", PLCDirectory, did) 102 + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) 103 + if err != nil { 104 + return "", fmt.Errorf("build request: %w", err) 105 + } 106 + resp, err := c.httpClient.Do(req) 107 + if err != nil { 108 + return "", fmt.Errorf("fetch DID document: %w", err) 109 + } 110 + defer resp.Body.Close() 111 + 112 + if resp.StatusCode != http.StatusOK { 113 + return "", fmt.Errorf("DID resolution: HTTP %d", resp.StatusCode) 114 + } 115 + 116 + var didDoc struct { 117 + Service []struct { 118 + ID string `json:"id"` 119 + Type string `json:"type"` 120 + ServiceEndpoint string `json:"serviceEndpoint"` 121 + } `json:"service"` 122 + } 123 + if err := json.NewDecoder(resp.Body).Decode(&didDoc); err != nil { 124 + return "", fmt.Errorf("decode DID document: %w", err) 125 + } 126 + for _, svc := range didDoc.Service { 127 + if svc.ID == "#atproto_pds" || svc.Type == "AtprotoPersonalDataServer" { 128 + pdsEndpoint = svc.ServiceEndpoint 129 + break 130 + } 131 + } 132 + 133 + case strings.HasPrefix(did, "did:web:"): 134 + domain := strings.TrimPrefix(did, "did:web:") 135 + domain = strings.ReplaceAll(domain, "%3A", ":") 136 + if idx := strings.Index(domain, "/"); idx != -1 { 137 + domain = domain[:idx] 138 + } 139 + host := domain 140 + if h, _, err := net.SplitHostPort(domain); err == nil { 141 + host = h 142 + } 143 + if err := validateDomain(host); err != nil { 144 + return "", err 145 + } 146 + pdsEndpoint = "https://" + domain 147 + } 148 + 149 + if pdsEndpoint == "" { 150 + return "", fmt.Errorf("could not resolve PDS endpoint for %s", did) 151 + } 152 + 153 + c.pdsCacheMu.Lock() 154 + c.pdsCache[did] = pdsEndpoint 155 + c.pdsCacheMu.Unlock() 156 + 157 + return pdsEndpoint, nil 158 + } 159 + 160 + // PublicProfile is a user's public profile as returned by the Bluesky public API. 161 + type PublicProfile struct { 162 + DID string `json:"did"` 163 + Handle string `json:"handle"` 164 + DisplayName *string `json:"displayName,omitempty"` 165 + Avatar *string `json:"avatar,omitempty"` 166 + } 167 + 168 + // GetProfile fetches a user's public profile by DID or handle. 169 + func (c *PublicClient) GetProfile(ctx context.Context, actor string) (*PublicProfile, error) { 170 + reqURL := fmt.Sprintf("%s/xrpc/app.bsky.actor.getProfile?actor=%s", 171 + PublicAPIBase, url.QueryEscape(actor)) 172 + 173 + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) 174 + if err != nil { 175 + return nil, fmt.Errorf("build request: %w", err) 176 + } 177 + 178 + resp, err := c.httpClient.Do(req) 179 + if err != nil { 180 + return nil, fmt.Errorf("fetch profile: %w", err) 181 + } 182 + defer resp.Body.Close() 183 + 184 + if resp.StatusCode != http.StatusOK { 185 + return nil, fmt.Errorf("get profile: HTTP %d", resp.StatusCode) 186 + } 187 + 188 + var profile PublicProfile 189 + if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil { 190 + return nil, fmt.Errorf("decode profile: %w", err) 191 + } 192 + return &profile, nil 193 + } 194 + 195 + // ListPublicRecords fetches up to limit records from a public collection. 196 + // Queries the user's PDS directly, so it works with any collection NSID. 197 + func (c *PublicClient) ListPublicRecords(ctx context.Context, did, collection string, limit int) ([]Record, error) { 198 + pdsEndpoint, err := c.GetPDSEndpoint(ctx, did) 199 + if err != nil { 200 + return nil, fmt.Errorf("resolve PDS: %w", err) 201 + } 202 + 203 + reqURL := fmt.Sprintf("%s/xrpc/com.atproto.repo.listRecords?repo=%s&collection=%s&limit=%d", 204 + pdsEndpoint, url.QueryEscape(did), url.QueryEscape(collection), limit) 205 + 206 + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) 207 + if err != nil { 208 + return nil, fmt.Errorf("build request: %w", err) 209 + } 210 + 211 + resp, err := c.httpClient.Do(req) 212 + if err != nil { 213 + return nil, fmt.Errorf("list records: %w", err) 214 + } 215 + defer resp.Body.Close() 216 + 217 + if resp.StatusCode != http.StatusOK { 218 + return nil, fmt.Errorf("list records: HTTP %d", resp.StatusCode) 219 + } 220 + 221 + var result struct { 222 + Records []struct { 223 + URI string `json:"uri"` 224 + CID string `json:"cid"` 225 + Value map[string]any `json:"value"` 226 + } `json:"records"` 227 + } 228 + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 229 + return nil, fmt.Errorf("decode records: %w", err) 230 + } 231 + 232 + records := make([]Record, len(result.Records)) 233 + for i, r := range result.Records { 234 + records[i] = Record{URI: r.URI, CID: r.CID, Value: r.Value} 235 + } 236 + return records, nil 237 + } 238 + 239 + // GetPublicRecord fetches a single public record from a user's PDS. 240 + func (c *PublicClient) GetPublicRecord(ctx context.Context, did, collection, rkey string) (*Record, error) { 241 + pdsEndpoint, err := c.GetPDSEndpoint(ctx, did) 242 + if err != nil { 243 + return nil, fmt.Errorf("resolve PDS: %w", err) 244 + } 245 + 246 + reqURL := fmt.Sprintf("%s/xrpc/com.atproto.repo.getRecord?repo=%s&collection=%s&rkey=%s", 247 + pdsEndpoint, url.QueryEscape(did), url.QueryEscape(collection), url.QueryEscape(rkey)) 248 + 249 + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) 250 + if err != nil { 251 + return nil, fmt.Errorf("build request: %w", err) 252 + } 253 + 254 + resp, err := c.httpClient.Do(req) 255 + if err != nil { 256 + return nil, fmt.Errorf("get record: %w", err) 257 + } 258 + defer resp.Body.Close() 259 + 260 + if resp.StatusCode != http.StatusOK { 261 + return nil, fmt.Errorf("get record: HTTP %d", resp.StatusCode) 262 + } 263 + 264 + var r struct { 265 + URI string `json:"uri"` 266 + CID string `json:"cid"` 267 + Value map[string]any `json:"value"` 268 + } 269 + if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { 270 + return nil, fmt.Errorf("decode record: %w", err) 271 + } 272 + return &Record{URI: r.URI, CID: r.CID, Value: r.Value}, nil 273 + } 274 + 275 + // isPrivateIP reports whether ip is in a private/reserved range. 276 + func isPrivateIP(ip net.IP) bool { 277 + return ip.IsLoopback() || 278 + ip.IsLinkLocalUnicast() || 279 + ip.IsLinkLocalMulticast() || 280 + ip.IsPrivate() || 281 + ip.IsUnspecified() || 282 + ip.Equal(net.ParseIP("169.254.169.254")) // cloud metadata 283 + } 284 + 285 + // validateDomain blocks requests to private/internal hosts. 286 + func validateDomain(domain string) error { 287 + if domain == "localhost" || strings.HasSuffix(domain, ".local") { 288 + return ErrSSRFBlocked 289 + } 290 + if ip := net.ParseIP(domain); ip != nil { 291 + if isPrivateIP(ip) { 292 + return ErrSSRFBlocked 293 + } 294 + return nil 295 + } 296 + ips, err := net.LookupIP(domain) 297 + if err != nil { 298 + return nil // let the HTTP request fail naturally 299 + } 300 + if slices.ContainsFunc(ips, isPrivateIP) { 301 + return ErrSSRFBlocked 302 + } 303 + return nil 304 + }
+62
public_test.go
··· 1 + package atp 2 + 3 + import ( 4 + "net" 5 + "testing" 6 + ) 7 + 8 + func TestIsPrivateIP(t *testing.T) { 9 + cases := []struct { 10 + ip string 11 + private bool 12 + }{ 13 + {"127.0.0.1", true}, 14 + {"::1", true}, 15 + {"10.0.0.1", true}, 16 + {"172.16.0.1", true}, 17 + {"192.168.1.1", true}, 18 + {"169.254.169.254", true}, 19 + {"0.0.0.0", true}, 20 + {"8.8.8.8", false}, 21 + {"1.1.1.1", false}, 22 + } 23 + 24 + for _, tc := range cases { 25 + ip := net.ParseIP(tc.ip) 26 + got := isPrivateIP(ip) 27 + if got != tc.private { 28 + t.Errorf("isPrivateIP(%s) = %v, want %v", tc.ip, got, tc.private) 29 + } 30 + } 31 + } 32 + 33 + func TestValidateDomain_Localhost(t *testing.T) { 34 + if err := validateDomain("localhost"); err != ErrSSRFBlocked { 35 + t.Fatalf("expected ErrSSRFBlocked for localhost, got %v", err) 36 + } 37 + } 38 + 39 + func TestValidateDomain_DotLocal(t *testing.T) { 40 + if err := validateDomain("internal.local"); err != ErrSSRFBlocked { 41 + t.Fatalf("expected ErrSSRFBlocked for .local domain, got %v", err) 42 + } 43 + } 44 + 45 + func TestValidateDomain_PrivateIP(t *testing.T) { 46 + if err := validateDomain("192.168.1.1"); err != ErrSSRFBlocked { 47 + t.Fatalf("expected ErrSSRFBlocked for private IP, got %v", err) 48 + } 49 + } 50 + 51 + func TestValidateDomain_MetadataIP(t *testing.T) { 52 + if err := validateDomain("169.254.169.254"); err != ErrSSRFBlocked { 53 + t.Fatalf("expected ErrSSRFBlocked for metadata IP, got %v", err) 54 + } 55 + } 56 + 57 + func TestNewPublicClient(t *testing.T) { 58 + c := NewPublicClient() 59 + if c == nil { 60 + t.Fatal("expected non-nil client") 61 + } 62 + }
-1
record.go
··· 1 1 package atp 2 2 3 - // Record represents a single record returned from a PDS. 4 3 type Record struct { 5 4 URI string 6 5 CID string
+2
scopes.go
··· 6 6 // 7 7 // ScopesForCollections("x.y.bean", "x.y.brew") 8 8 // // => ["atproto", "repo:x.y.bean", "repo:x.y.brew"] 9 + // 10 + // TODO: add support for collections and more granular control than just full rw 9 11 func ScopesForCollections(collections ...string) []string { 10 12 scopes := make([]string, 0, 1+len(collections)) 11 13 scopes = append(scopes, "atproto")
+22 -11
scopes_test.go
··· 6 6 ) 7 7 8 8 func TestScopesForCollections(t *testing.T) { 9 - got := ScopesForCollections("social.arabica.alpha.bean", "social.arabica.alpha.brew") 10 - want := []string{"atproto", "repo:social.arabica.alpha.bean", "repo:social.arabica.alpha.brew"} 11 - if !slices.Equal(got, want) { 12 - t.Fatalf("got %v, want %v", got, want) 9 + tests := []struct { 10 + name string 11 + collections []string 12 + want []string 13 + }{ 14 + { 15 + name: "no collections", 16 + collections: nil, 17 + want: []string{"atproto"}, 18 + }, 19 + { 20 + name: "multiple collections", 21 + collections: []string{"social.arabica.alpha.bean", "social.arabica.alpha.brew"}, 22 + want: []string{"atproto", "repo:social.arabica.alpha.bean", "repo:social.arabica.alpha.brew"}, 23 + }, 13 24 } 14 - } 15 - 16 - func TestScopesForCollections_Empty(t *testing.T) { 17 - got := ScopesForCollections() 18 - want := []string{"atproto"} 19 - if !slices.Equal(got, want) { 20 - t.Fatalf("got %v, want %v", got, want) 25 + for _, tc := range tests { 26 + t.Run(tc.name, func(t *testing.T) { 27 + got := ScopesForCollections(tc.collections...) 28 + if !slices.Equal(got, tc.want) { 29 + t.Fatalf("got %v, want %v", got, tc.want) 30 + } 31 + }) 21 32 } 22 33 } 23 34
+2 -1
tracing/tracing.go
··· 24 24 25 25 // Init creates and registers a tracer provider with an OTLP HTTP exporter. 26 26 // It reads OTEL_EXPORTER_OTLP_ENDPOINT (default: localhost:4318). 27 - // The serviceName appears in your tracing backend (e.g. "arabica", "solanum"). 27 + // The serviceName appears in your tracing backend (e.g. "arabica"). 28 28 // Returns the provider so the caller can defer provider.Shutdown(ctx). 29 + // TODO: allow grpc exporting to port 4317 29 30 func Init(ctx context.Context, serviceName string) (*sdktrace.TracerProvider, error) { 30 31 endpoint := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT") 31 32 if endpoint == "" {
+21 -20
tracing/tracing_test.go
··· 7 7 "go.opentelemetry.io/otel/trace" 8 8 ) 9 9 10 - func TestBoltSpan_NoParent(t *testing.T) { 10 + func TestSpan_NoParent(t *testing.T) { 11 11 ctx := context.Background() 12 - // Without a parent span, should return a no-op span 13 - _, span := BoltSpan(ctx, "GetSession", "oauth_sessions") 14 - if span.SpanContext().IsValid() { 15 - t.Fatal("expected no-op span without parent") 12 + tests := []struct { 13 + name string 14 + fn func(context.Context) (context.Context, trace.Span) 15 + }{ 16 + {"bolt", func(ctx context.Context) (context.Context, trace.Span) { 17 + return BoltSpan(ctx, "GetSession", "oauth_sessions") 18 + }}, 19 + {"sqlite", func(ctx context.Context) (context.Context, trace.Span) { 20 + return SqliteSpan(ctx, "query", "records") 21 + }}, 22 + {"pds", func(ctx context.Context) (context.Context, trace.Span) { 23 + return PdsSpan(ctx, "createRecord", "x.y.z", "did:plc:test") 24 + }}, 16 25 } 17 - } 18 - 19 - func TestSqliteSpan_NoParent(t *testing.T) { 20 - ctx := context.Background() 21 - _, span := SqliteSpan(ctx, "query", "records") 22 - if span.SpanContext().IsValid() { 23 - t.Fatal("expected no-op span without parent") 24 - } 25 - } 26 - 27 - func TestPdsSpan_NoParent(t *testing.T) { 28 - ctx := context.Background() 29 - _, span := PdsSpan(ctx, "createRecord", "x.y.z", "did:plc:test") 30 - if span.SpanContext().IsValid() { 31 - t.Fatal("expected no-op span without parent") 26 + for _, tc := range tests { 27 + t.Run(tc.name, func(t *testing.T) { 28 + _, span := tc.fn(ctx) 29 + if span.SpanContext().IsValid() { 30 + t.Fatal("expected no-op span without parent") 31 + } 32 + }) 32 33 } 33 34 } 34 35
+1
uri.go
··· 7 7 ) 8 8 9 9 func BuildATURI(did, collection, rkey string) string { 10 + // TODO: add validation on each param (maybe just call ParseATURI?) 10 11 return fmt.Sprintf("at://%s/%s/%s", did, collection, rkey) 11 12 } 12 13
+50 -21
uri_test.go
··· 11 11 } 12 12 13 13 func TestParseATURI(t *testing.T) { 14 - did, collection, rkey, err := ParseATURI("at://did:plc:abc/app.bsky.feed.post/3jxy") 15 - if err != nil { 16 - t.Fatal(err) 14 + tests := []struct { 15 + name string 16 + input string 17 + wantDID string 18 + wantColl string 19 + wantRKey string 20 + wantErr bool 21 + }{ 22 + { 23 + name: "valid URI", 24 + input: "at://did:plc:abc/app.bsky.feed.post/3jxy", 25 + wantDID: "did:plc:abc", 26 + wantColl: "app.bsky.feed.post", 27 + wantRKey: "3jxy", 28 + }, 29 + { 30 + name: "invalid URI", 31 + input: "not-a-uri", 32 + wantErr: true, 33 + }, 17 34 } 18 - if did != "did:plc:abc" || collection != "app.bsky.feed.post" || rkey != "3jxy" { 19 - t.Fatalf("got did=%q collection=%q rkey=%q", did, collection, rkey) 20 - } 21 - } 22 - 23 - func TestParseATURI_Invalid(t *testing.T) { 24 - _, _, _, err := ParseATURI("not-a-uri") 25 - if err == nil { 26 - t.Fatal("expected error for invalid URI") 35 + for _, tc := range tests { 36 + t.Run(tc.name, func(t *testing.T) { 37 + did, collection, rkey, err := ParseATURI(tc.input) 38 + if tc.wantErr { 39 + if err == nil { 40 + t.Fatal("expected error for invalid URI") 41 + } 42 + return 43 + } 44 + if err != nil { 45 + t.Fatal(err) 46 + } 47 + if did != tc.wantDID || collection != tc.wantColl || rkey != tc.wantRKey { 48 + t.Fatalf("got did=%q collection=%q rkey=%q", did, collection, rkey) 49 + } 50 + }) 27 51 } 28 52 } 29 53 30 54 func TestRKeyFromURI(t *testing.T) { 31 - got := RKeyFromURI("at://did:plc:abc/app.bsky.feed.post/3jxy") 32 - if got != "3jxy" { 33 - t.Fatalf("got %q, want %q", got, "3jxy") 55 + tests := []struct { 56 + name string 57 + input string 58 + want string 59 + }{ 60 + {"valid URI", "at://did:plc:abc/app.bsky.feed.post/3jxy", "3jxy"}, 61 + {"invalid URI", "garbage", ""}, 34 62 } 35 - } 36 - 37 - func TestRKeyFromURI_Invalid(t *testing.T) { 38 - got := RKeyFromURI("garbage") 39 - if got != "" { 40 - t.Fatalf("expected empty string for invalid URI, got %q", got) 63 + for _, tc := range tests { 64 + t.Run(tc.name, func(t *testing.T) { 65 + got := RKeyFromURI(tc.input) 66 + if got != tc.want { 67 + t.Fatalf("got %q, want %q", got, tc.want) 68 + } 69 + }) 41 70 } 42 71 }