A container registry that uses the AT Protocol for manifest storage and S3 for blob storage. atcr.io
docker container atproto go
80
fork

Configure Feed

Select the types of activity you want to include in your feed.

fix more db query issues, improve labeler usage

+863 -289
+30
pkg/appview/db/annotations.go
··· 29 29 return annotations, rows.Err() 30 30 } 31 31 32 + // GetRepositoryAnnotationsByDID retrieves all annotations for every 33 + // repository owned by a DID, grouped as map[repository]map[key]value. 34 + // Used by bulk-fetch paths to avoid issuing one query per repository. 35 + func GetRepositoryAnnotationsByDID(db DBTX, did string) (map[string]map[string]string, error) { 36 + rows, err := db.Query(` 37 + SELECT repository, key, value 38 + FROM repository_annotations 39 + WHERE did = ? 40 + `, did) 41 + if err != nil { 42 + return nil, err 43 + } 44 + defer rows.Close() 45 + 46 + out := make(map[string]map[string]string) 47 + for rows.Next() { 48 + var repo, key, value string 49 + if err := rows.Scan(&repo, &key, &value); err != nil { 50 + return nil, err 51 + } 52 + m, ok := out[repo] 53 + if !ok { 54 + m = make(map[string]string) 55 + out[repo] = m 56 + } 57 + m[key] = value 58 + } 59 + return out, rows.Err() 60 + } 61 + 32 62 // UpsertRepositoryAnnotations upserts annotations for a repository. 33 63 // Stale keys not present in the new map are deleted. 34 64 // Unchanged values are skipped to avoid unnecessary writes.
+152 -82
pkg/appview/db/queries.go
··· 241 241 // GetUserRepositories fetches all repositories for a user. 242 242 // viewerDID scopes results to repositories whose manifests live on holds the 243 243 // viewer can access (empty viewerDID = anonymous → public + self-service only). 244 + // 245 + // Implementation: one summary query for the accessible repository set, then 246 + // four bulk queries (tags, manifests, annotations, repo_pages) all keyed by 247 + // did. Results are grouped in Go and assembled per repo. Total: 5 queries 248 + // regardless of how many repos the user owns. 244 249 func GetUserRepositories(db DBTX, did string, viewerDID string) ([]Repository, error) { 245 - // Get repository summary. 246 - // Both tags and manifests are filtered via join onto manifests.hold_endpoint 247 - // so repositories where every row lives on an inaccessible hold drop out. 250 + // Step 1: summary query. Both tags and manifests are filtered via join 251 + // onto manifests.hold_endpoint so repositories where every row lives on 252 + // an inaccessible hold drop out. 248 253 rows, err := db.Query(` 249 254 SELECT 250 255 repository, ··· 264 269 GROUP BY repository 265 270 ORDER BY last_push DESC 266 271 `, did, viewerDID, viewerDID, did, viewerDID, viewerDID) 267 - 268 272 if err != nil { 269 273 return nil, err 270 274 } 271 - defer rows.Close() 272 275 273 - var repos []Repository 276 + type repoSummary struct { 277 + Name string 278 + TagCount int 279 + ManifestCount int 280 + LastPushStr string 281 + } 282 + var summaries []repoSummary 274 283 for rows.Next() { 275 - var r Repository 276 - var lastPushStr string 277 - if err := rows.Scan(&r.Name, &r.TagCount, &r.ManifestCount, &lastPushStr); err != nil { 284 + var s repoSummary 285 + if err := rows.Scan(&s.Name, &s.TagCount, &s.ManifestCount, &s.LastPushStr); err != nil { 286 + rows.Close() 278 287 return nil, err 279 288 } 289 + summaries = append(summaries, s) 290 + } 291 + rows.Close() 280 292 281 - // Parse the timestamp string into time.Time 282 - if lastPushStr != "" { 283 - // Try multiple timestamp formats 284 - formats := []string{ 285 - time.RFC3339Nano, // 2006-01-02T15:04:05.999999999Z07:00 286 - "2006-01-02 15:04:05.999999999-07:00", // SQLite with microseconds and timezone 287 - "2006-01-02 15:04:05.999999999", // SQLite with microseconds 288 - time.RFC3339, // 2006-01-02T15:04:05Z07:00 289 - "2006-01-02 15:04:05", // SQLite default 290 - } 293 + if len(summaries) == 0 { 294 + return nil, nil 295 + } 291 296 292 - for _, format := range formats { 293 - if t, err := time.Parse(format, lastPushStr); err == nil { 294 - r.LastPush = t 295 - break 296 - } 297 - } 298 - } 297 + // Build the set of accessible repo names for filtering bulk-fetched rows 298 + // against repos that the viewer can't see (rows for repos owned by `did` 299 + // but stored on inaccessible holds). 300 + accessible := make(map[string]bool, len(summaries)) 301 + for _, s := range summaries { 302 + accessible[s.Name] = true 303 + } 299 304 300 - // Get tags for this repo 301 - tagRows, err := db.Query(` 302 - SELECT id, tag, digest, created_at 303 - FROM tags 304 - WHERE did = ? AND repository = ? 305 - ORDER BY created_at DESC 306 - `, did, r.Name) 307 - 308 - if err != nil { 309 - return nil, err 310 - } 311 - 312 - for tagRows.Next() { 313 - var t Tag 314 - t.DID = did 315 - t.Repository = r.Name 316 - if err := tagRows.Scan(&t.ID, &t.Tag, &t.Digest, &t.CreatedAt); err != nil { 317 - tagRows.Close() 318 - return nil, err 319 - } 320 - r.Tags = append(r.Tags, t) 321 - } 322 - tagRows.Close() 323 - 324 - // Get manifests for this repo 325 - manifestRows, err := db.Query(` 326 - SELECT id, digest, hold_endpoint, schema_version, media_type, 327 - config_digest, config_size, artifact_type, created_at 328 - FROM manifests 329 - WHERE did = ? AND repository = ? 330 - ORDER BY created_at DESC 331 - `, did, r.Name) 305 + // Step 2: bulk-fetch tags for all repos owned by did, grouped by repo. 306 + tagsByRepo, err := bulkTagsByRepo(db, did, accessible) 307 + if err != nil { 308 + return nil, err 309 + } 332 310 333 - if err != nil { 334 - return nil, err 335 - } 311 + // Step 3: bulk-fetch manifests, grouped by repo. 312 + manifestsByRepo, err := bulkManifestsByRepo(db, did, accessible) 313 + if err != nil { 314 + return nil, err 315 + } 336 316 337 - for manifestRows.Next() { 338 - var m Manifest 339 - m.DID = did 340 - m.Repository = r.Name 341 - 342 - if err := manifestRows.Scan(&m.ID, &m.Digest, &m.HoldEndpoint, &m.SchemaVersion, 343 - &m.MediaType, &m.ConfigDigest, &m.ConfigSize, &m.ArtifactType, &m.CreatedAt); err != nil { 344 - manifestRows.Close() 345 - return nil, err 346 - } 317 + // Step 4: bulk-fetch annotations, grouped by repo. 318 + annotationsByRepo, err := GetRepositoryAnnotationsByDID(db, did) 319 + if err != nil { 320 + return nil, err 321 + } 347 322 348 - r.Manifests = append(r.Manifests, m) 349 - } 350 - manifestRows.Close() 323 + // Step 5: bulk-fetch repo pages (existing helper), keyed by repo. 324 + pages, err := GetRepoPagesByDID(db, did) 325 + if err != nil { 326 + return nil, err 327 + } 328 + pagesByRepo := make(map[string]*RepoPage, len(pages)) 329 + for i := range pages { 330 + pagesByRepo[pages[i].Repository] = &pages[i] 331 + } 351 332 352 - // Fetch repository-level annotations from annotations table 353 - annotations, err := GetRepositoryAnnotations(db, did, r.Name) 354 - if err != nil { 355 - return nil, err 333 + // Assemble results in summary order (preserves last_push DESC). 334 + repos := make([]Repository, 0, len(summaries)) 335 + for _, s := range summaries { 336 + r := Repository{ 337 + Name: s.Name, 338 + TagCount: s.TagCount, 339 + ManifestCount: s.ManifestCount, 340 + LastPush: parseRepoTimestamp(s.LastPushStr), 341 + Tags: tagsByRepo[s.Name], 342 + Manifests: manifestsByRepo[s.Name], 356 343 } 357 344 345 + annotations := annotationsByRepo[s.Name] 358 346 r.Title = annotations["org.opencontainers.image.title"] 359 347 r.Description = annotations["org.opencontainers.image.description"] 360 348 r.SourceURL = annotations["org.opencontainers.image.source"] ··· 363 351 r.IconURL = annotations["io.atcr.icon"] 364 352 r.ReadmeURL = annotations["io.atcr.readme"] 365 353 366 - // Check for repo page avatar (overrides annotation icon) 367 - repoPage, err := GetRepoPage(db, did, r.Name) 368 - if err == nil && repoPage != nil && repoPage.AvatarCID != "" { 369 - r.IconURL = BlobCDNURL(did, repoPage.AvatarCID) 354 + // Repo page avatar overrides annotation icon when present. 355 + if page, ok := pagesByRepo[s.Name]; ok && page.AvatarCID != "" { 356 + r.IconURL = BlobCDNURL(did, page.AvatarCID) 370 357 } 371 358 372 359 repos = append(repos, r) 373 360 } 374 361 375 362 return repos, nil 363 + } 364 + 365 + // bulkTagsByRepo fetches every tag owned by did and groups by repository, 366 + // dropping repos not in the accessible set. Result preserves created_at DESC 367 + // ordering within each repo. 368 + func bulkTagsByRepo(db DBTX, did string, accessible map[string]bool) (map[string][]Tag, error) { 369 + rows, err := db.Query(` 370 + SELECT id, repository, tag, digest, created_at 371 + FROM tags 372 + WHERE did = ? 373 + ORDER BY repository, created_at DESC 374 + `, did) 375 + if err != nil { 376 + return nil, err 377 + } 378 + defer rows.Close() 379 + 380 + out := make(map[string][]Tag) 381 + for rows.Next() { 382 + var t Tag 383 + t.DID = did 384 + if err := rows.Scan(&t.ID, &t.Repository, &t.Tag, &t.Digest, &t.CreatedAt); err != nil { 385 + return nil, err 386 + } 387 + if !accessible[t.Repository] { 388 + continue 389 + } 390 + out[t.Repository] = append(out[t.Repository], t) 391 + } 392 + return out, rows.Err() 393 + } 394 + 395 + // bulkManifestsByRepo fetches every manifest owned by did and groups by 396 + // repository, dropping repos not in the accessible set. Result preserves 397 + // created_at DESC ordering within each repo. 398 + func bulkManifestsByRepo(db DBTX, did string, accessible map[string]bool) (map[string][]Manifest, error) { 399 + rows, err := db.Query(` 400 + SELECT id, repository, digest, hold_endpoint, schema_version, media_type, 401 + config_digest, config_size, artifact_type, created_at 402 + FROM manifests 403 + WHERE did = ? 404 + ORDER BY repository, created_at DESC 405 + `, did) 406 + if err != nil { 407 + return nil, err 408 + } 409 + defer rows.Close() 410 + 411 + out := make(map[string][]Manifest) 412 + for rows.Next() { 413 + var m Manifest 414 + m.DID = did 415 + if err := rows.Scan(&m.ID, &m.Repository, &m.Digest, &m.HoldEndpoint, &m.SchemaVersion, 416 + &m.MediaType, &m.ConfigDigest, &m.ConfigSize, &m.ArtifactType, &m.CreatedAt); err != nil { 417 + return nil, err 418 + } 419 + if !accessible[m.Repository] { 420 + continue 421 + } 422 + out[m.Repository] = append(out[m.Repository], m) 423 + } 424 + return out, rows.Err() 425 + } 426 + 427 + // parseRepoTimestamp tolerates the several timestamp formats SQLite/libsql 428 + // can return for MAX(created_at) depending on driver and schema history. 429 + func parseRepoTimestamp(s string) time.Time { 430 + if s == "" { 431 + return time.Time{} 432 + } 433 + formats := []string{ 434 + time.RFC3339Nano, // 2006-01-02T15:04:05.999999999Z07:00 435 + "2006-01-02 15:04:05.999999999-07:00", // SQLite with microseconds and timezone 436 + "2006-01-02 15:04:05.999999999", // SQLite with microseconds 437 + time.RFC3339, // 2006-01-02T15:04:05Z07:00 438 + "2006-01-02 15:04:05", // SQLite default 439 + } 440 + for _, format := range formats { 441 + if t, err := time.Parse(format, s); err == nil { 442 + return t 443 + } 444 + } 445 + return time.Time{} 376 446 } 377 447 378 448 // GetRepositoryMetadata retrieves metadata for a repository from annotations table
+160
pkg/appview/db/queries_test.go
··· 1607 1607 t.Errorf("crew viewer: expected both repos, got %d: %v", len(repos), repos) 1608 1608 } 1609 1609 } 1610 + 1611 + // TestGetUserRepositories_BulkGrouping verifies that the bulk-fetch 1612 + // implementation correctly groups tags, manifests, annotations, and repo-page 1613 + // avatars per repository — and that ordering (last_push DESC for repos, 1614 + // created_at DESC for tags/manifests within a repo) is preserved. 1615 + // 1616 + // Regression guard for the previous N+1 implementation, which issued one 1617 + // query per repo and per relation. 1618 + func TestGetUserRepositories_BulkGrouping(t *testing.T) { 1619 + db, err := InitDB("file:TestGetUserRepositories_BulkGrouping?mode=memory&cache=shared", LibsqlConfig{}) 1620 + if err != nil { 1621 + t.Fatalf("init db: %v", err) 1622 + } 1623 + defer db.Close() 1624 + 1625 + user := &User{DID: "did:plc:owner", Handle: "owner.test", PDSEndpoint: "https://pds.example", LastSeen: time.Now()} 1626 + if err := UpsertUser(db, user); err != nil { 1627 + t.Fatalf("upsert user: %v", err) 1628 + } 1629 + if err := UpsertCaptainRecord(db, &HoldCaptainRecord{ 1630 + HoldDID: "did:web:hold.example", OwnerDID: "did:plc:holdowner", Public: true, 1631 + }); err != nil { 1632 + t.Fatalf("seed captain: %v", err) 1633 + } 1634 + 1635 + now := time.Now().UTC().Truncate(time.Second) 1636 + mediaType := "application/vnd.oci.image.manifest.v1+json" 1637 + 1638 + // repoA: two manifests (oldest then newer) and two tags. last_push = now+10s. 1639 + manifestA1, err := InsertManifest(db, &Manifest{ 1640 + DID: user.DID, Repository: "repoA", Digest: "sha256:a1", 1641 + HoldEndpoint: "did:web:hold.example", SchemaVersion: 2, MediaType: mediaType, 1642 + CreatedAt: now, 1643 + }) 1644 + if err != nil { 1645 + t.Fatalf("insert manifest a1: %v", err) 1646 + } 1647 + manifestA2, err := InsertManifest(db, &Manifest{ 1648 + DID: user.DID, Repository: "repoA", Digest: "sha256:a2", 1649 + HoldEndpoint: "did:web:hold.example", SchemaVersion: 2, MediaType: mediaType, 1650 + CreatedAt: now.Add(5 * time.Second), 1651 + }) 1652 + if err != nil { 1653 + t.Fatalf("insert manifest a2: %v", err) 1654 + } 1655 + if err := UpsertTag(db, &Tag{DID: user.DID, Repository: "repoA", Tag: "v1", Digest: "sha256:a1", CreatedAt: now.Add(8 * time.Second)}); err != nil { 1656 + t.Fatalf("upsert tag v1: %v", err) 1657 + } 1658 + if err := UpsertTag(db, &Tag{DID: user.DID, Repository: "repoA", Tag: "v2", Digest: "sha256:a2", CreatedAt: now.Add(10 * time.Second)}); err != nil { 1659 + t.Fatalf("upsert tag v2: %v", err) 1660 + } 1661 + 1662 + // repoB: one manifest, one tag. last_push = now+1s (older than repoA → repoA sorts first). 1663 + if _, err := InsertManifest(db, &Manifest{ 1664 + DID: user.DID, Repository: "repoB", Digest: "sha256:b1", 1665 + HoldEndpoint: "did:web:hold.example", SchemaVersion: 2, MediaType: mediaType, 1666 + CreatedAt: now.Add(1 * time.Second), 1667 + }); err != nil { 1668 + t.Fatalf("insert manifest b1: %v", err) 1669 + } 1670 + if err := UpsertTag(db, &Tag{DID: user.DID, Repository: "repoB", Tag: "latest", Digest: "sha256:b1", CreatedAt: now.Add(1 * time.Second)}); err != nil { 1671 + t.Fatalf("upsert tag b latest: %v", err) 1672 + } 1673 + 1674 + // Annotations only on repoA, plus a repo-page avatar on repoB to exercise the icon override. 1675 + if err := UpsertRepositoryAnnotations(db, user.DID, "repoA", map[string]string{ 1676 + "org.opencontainers.image.title": "Repo A Title", 1677 + "org.opencontainers.image.description": "alpha", 1678 + "io.atcr.icon": "https://example.com/a.png", 1679 + }); err != nil { 1680 + t.Fatalf("upsert annotations: %v", err) 1681 + } 1682 + if err := UpsertRepoPage(db, user.DID, "repoB", "", "bafyrepob", false, now, now); err != nil { 1683 + t.Fatalf("upsert repo page: %v", err) 1684 + } 1685 + 1686 + repos, err := GetUserRepositories(db, user.DID, "") 1687 + if err != nil { 1688 + t.Fatalf("GetUserRepositories: %v", err) 1689 + } 1690 + 1691 + // Order: repoA first (newer last_push), then repoB. 1692 + if len(repos) != 2 { 1693 + t.Fatalf("expected 2 repos, got %d: %#v", len(repos), repos) 1694 + } 1695 + if repos[0].Name != "repoA" || repos[1].Name != "repoB" { 1696 + t.Fatalf("expected order [repoA, repoB] (last_push DESC), got [%s, %s]", repos[0].Name, repos[1].Name) 1697 + } 1698 + 1699 + // repoA grouping 1700 + a := repos[0] 1701 + if len(a.Tags) != 2 { 1702 + t.Errorf("repoA: expected 2 tags, got %d", len(a.Tags)) 1703 + } 1704 + // tags ordered created_at DESC → v2 first 1705 + if len(a.Tags) >= 2 && (a.Tags[0].Tag != "v2" || a.Tags[1].Tag != "v1") { 1706 + t.Errorf("repoA tags out of order, want [v2, v1] got [%s, %s]", a.Tags[0].Tag, a.Tags[1].Tag) 1707 + } 1708 + if len(a.Manifests) != 2 { 1709 + t.Errorf("repoA: expected 2 manifests, got %d", len(a.Manifests)) 1710 + } 1711 + // manifests ordered created_at DESC → a2 first 1712 + if len(a.Manifests) >= 2 && (a.Manifests[0].ID != manifestA2 || a.Manifests[1].ID != manifestA1) { 1713 + t.Errorf("repoA manifests out of order, want [a2, a1] got [%d, %d]", a.Manifests[0].ID, a.Manifests[1].ID) 1714 + } 1715 + if a.Title != "Repo A Title" || a.Description != "alpha" { 1716 + t.Errorf("repoA annotations not applied: title=%q desc=%q", a.Title, a.Description) 1717 + } 1718 + if a.IconURL != "https://example.com/a.png" { 1719 + t.Errorf("repoA icon: expected annotation URL, got %q", a.IconURL) 1720 + } 1721 + 1722 + // repoB grouping + page-avatar override 1723 + b := repos[1] 1724 + if len(b.Tags) != 1 || b.Tags[0].Tag != "latest" { 1725 + t.Errorf("repoB tags: %#v", b.Tags) 1726 + } 1727 + if len(b.Manifests) != 1 || b.Manifests[0].Digest != "sha256:b1" { 1728 + t.Errorf("repoB manifests: %#v", b.Manifests) 1729 + } 1730 + if b.IconURL == "" { 1731 + t.Errorf("repoB icon should be derived from repo-page avatar CID, got empty") 1732 + } 1733 + 1734 + // Cross-repo isolation: tags/manifests for repoB must not leak into repoA and vice versa. 1735 + for _, tag := range a.Tags { 1736 + if tag.Repository != "repoA" { 1737 + t.Errorf("repoA tag has wrong repository: %#v", tag) 1738 + } 1739 + } 1740 + for _, m := range b.Manifests { 1741 + if m.Repository != "repoB" { 1742 + t.Errorf("repoB manifest has wrong repository: %#v", m) 1743 + } 1744 + } 1745 + } 1746 + 1747 + // TestGetUserRepositories_Empty verifies the bulk-fetch path short-circuits 1748 + // cleanly when the summary query returns no rows (no extra queries issued, 1749 + // nil slice returned). 1750 + func TestGetUserRepositories_Empty(t *testing.T) { 1751 + db, err := InitDB("file:TestGetUserRepositories_Empty?mode=memory&cache=shared", LibsqlConfig{}) 1752 + if err != nil { 1753 + t.Fatalf("init db: %v", err) 1754 + } 1755 + defer db.Close() 1756 + 1757 + user := &User{DID: "did:plc:nobody", Handle: "nobody.test", PDSEndpoint: "https://pds.example", LastSeen: time.Now()} 1758 + if err := UpsertUser(db, user); err != nil { 1759 + t.Fatalf("upsert user: %v", err) 1760 + } 1761 + 1762 + repos, err := GetUserRepositories(db, user.DID, "") 1763 + if err != nil { 1764 + t.Fatalf("GetUserRepositories empty: %v", err) 1765 + } 1766 + if repos != nil { 1767 + t.Errorf("expected nil slice for user with no repos, got %#v", repos) 1768 + } 1769 + }
+207 -98
pkg/labeler/db.go
··· 19 19 const LabelVersion int64 = labeling.ATPROTO_LABEL_VERSION 20 20 21 21 const schema = ` 22 + CREATE TABLE IF NOT EXISTS takedowns ( 23 + id INTEGER PRIMARY KEY AUTOINCREMENT, 24 + input TEXT NOT NULL, 25 + subject_did TEXT NOT NULL, 26 + subject_repo TEXT NOT NULL DEFAULT '', 27 + subject_handle TEXT NOT NULL DEFAULT '', 28 + reason TEXT NOT NULL DEFAULT '', 29 + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 30 + created_by TEXT NOT NULL DEFAULT '', 31 + reversed_at TIMESTAMP, 32 + reversed_by TEXT NOT NULL DEFAULT '' 33 + ); 34 + CREATE INDEX IF NOT EXISTS idx_takedowns_active ON takedowns(reversed_at, created_at DESC); 35 + CREATE INDEX IF NOT EXISTS idx_takedowns_subject ON takedowns(subject_did, subject_repo); 22 36 CREATE TABLE IF NOT EXISTS labels ( 23 37 id INTEGER PRIMARY KEY AUTOINCREMENT, 24 38 src TEXT NOT NULL, ··· 31 45 ver INTEGER NOT NULL DEFAULT 1, 32 46 sig BLOB NOT NULL, 33 47 subject_did TEXT NOT NULL, 34 - subject_repo TEXT NOT NULL DEFAULT '' 48 + subject_repo TEXT NOT NULL DEFAULT '', 49 + takedown_id INTEGER REFERENCES takedowns(id) 35 50 ); 36 51 CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo); 37 52 CREATE INDEX IF NOT EXISTS idx_labels_cts ON labels(cts DESC); 38 53 CREATE INDEX IF NOT EXISTS idx_labels_uri ON labels(uri); 54 + CREATE INDEX IF NOT EXISTS idx_labels_takedown ON labels(takedown_id); 39 55 ` 40 56 41 57 // Label represents an ATProto label record stored locally. Its on-the-wire representation 42 58 // is produced by ToLabeling() which round-trips through indigo's labeling package so the 43 59 // signature stays valid byte-for-byte. 60 + // 61 + // TakedownID is a labeler-internal pointer to the takedown event that produced this 62 + // label (positive or negation). It's never serialized into ATProto wire format. 44 63 type Label struct { 45 64 ID int64 46 65 Src string ··· 54 73 Sig []byte 55 74 SubjectDID string 56 75 SubjectRepo string 76 + TakedownID *int64 77 + } 78 + 79 + // Takedown is a single operator-issued takedown action. Each Takedown owns one or more 80 + // Label rows linked by takedown_id. Reversal sets reversed_at / reversed_by in place. 81 + type Takedown struct { 82 + ID int64 83 + Input string 84 + SubjectDID string 85 + SubjectRepo string 86 + SubjectHandle string 87 + Reason string 88 + CreatedAt time.Time 89 + CreatedBy string 90 + ReversedAt *time.Time 91 + ReversedBy string 92 + LabelCount int 57 93 } 58 94 59 95 // LibsqlSync configures optional embedded-replica sync to a remote libSQL database. ··· 233 269 s := l.Exp.UTC().Format(time.RFC3339) 234 270 expStr = &s 235 271 } 272 + var takedownID any 273 + if l.TakedownID != nil { 274 + takedownID = *l.TakedownID 275 + } 236 276 result, err := db.Exec( 237 - `INSERT INTO labels (src, uri, cid, val, neg, cts, exp, ver, sig, subject_did, subject_repo) 238 - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 277 + `INSERT INTO labels (src, uri, cid, val, neg, cts, exp, ver, sig, subject_did, subject_repo, takedown_id) 278 + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 239 279 l.Src, l.URI, nullableString(l.CID), l.Val, l.Neg, 240 280 l.Cts.UTC().Format(time.RFC3339), expStr, l.Ver, l.Sig, 241 - l.SubjectDID, l.SubjectRepo, 281 + l.SubjectDID, l.SubjectRepo, takedownID, 242 282 ) 243 283 if err != nil { 244 284 return 0, fmt.Errorf("failed to insert label: %w", err) ··· 261 301 // GetLabelsSince returns labels with id > cursor, ordered by id ascending. 262 302 func GetLabelsSince(db *sql.DB, cursor int64, limit int) ([]Label, error) { 263 303 rows, err := db.Query( 264 - `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo 304 + `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo, takedown_id 265 305 FROM labels WHERE id > ? ORDER BY id ASC LIMIT ?`, 266 306 cursor, limit, 267 307 ) ··· 284 324 return seq.Int64, nil 285 325 } 286 326 287 - // ListActiveTakedowns returns active (non-negated) takedown labels. 288 - func ListActiveTakedowns(db *sql.DB, limit, offset int) ([]Label, int, error) { 327 + // CreateTakedown inserts a takedown event row and returns its id. The id should then 328 + // be stamped onto every label produced by this takedown (positive labels at issue time, 329 + // negation labels at reversal time) so the audit trail stays linked. 330 + func CreateTakedown(db *sql.DB, t *Takedown) (int64, error) { 331 + if t.CreatedAt.IsZero() { 332 + t.CreatedAt = time.Now().UTC() 333 + } 334 + result, err := db.Exec( 335 + `INSERT INTO takedowns (input, subject_did, subject_repo, subject_handle, reason, created_at, created_by) 336 + VALUES (?, ?, ?, ?, ?, ?, ?)`, 337 + t.Input, t.SubjectDID, t.SubjectRepo, t.SubjectHandle, t.Reason, 338 + t.CreatedAt.UTC().Format(time.RFC3339), t.CreatedBy, 339 + ) 340 + if err != nil { 341 + return 0, fmt.Errorf("failed to insert takedown: %w", err) 342 + } 343 + id, err := result.LastInsertId() 344 + if err != nil { 345 + return 0, err 346 + } 347 + t.ID = id 348 + return id, nil 349 + } 350 + 351 + // GetTakedown loads a single takedown row by id. Returns sql.ErrNoRows when missing. 352 + func GetTakedown(db *sql.DB, id int64) (*Takedown, error) { 353 + row := db.QueryRow( 354 + `SELECT t.id, t.input, t.subject_did, t.subject_repo, t.subject_handle, t.reason, 355 + t.created_at, t.created_by, t.reversed_at, t.reversed_by, 356 + (SELECT COUNT(*) FROM labels l WHERE l.takedown_id = t.id AND l.neg = 0) 357 + FROM takedowns t WHERE t.id = ?`, 358 + id, 359 + ) 360 + return scanTakedown(row.Scan) 361 + } 362 + 363 + // TakedownFilter scopes ListTakedowns to active, reversed, or all rows. 364 + type TakedownFilter int 365 + 366 + const ( 367 + TakedownAll TakedownFilter = iota // every takedown row, regardless of reversal state 368 + TakedownActive // only takedowns whose reversed_at is NULL 369 + TakedownReversed // only takedowns whose reversed_at is set 370 + ) 371 + 372 + // ListTakedowns returns takedown events ordered by created_at DESC, scoped by filter. 373 + // The total count reflects the same filter. 374 + func ListTakedowns(db *sql.DB, filter TakedownFilter, limit, offset int) ([]Takedown, int, error) { 375 + where := "" 376 + switch filter { 377 + case TakedownActive: 378 + where = "WHERE reversed_at IS NULL" 379 + case TakedownReversed: 380 + where = "WHERE reversed_at IS NOT NULL" 381 + } 382 + 289 383 var total int 290 - err := db.QueryRow( 291 - `SELECT COUNT(*) FROM labels l1 292 - WHERE l1.val = '!takedown' AND l1.neg = 0 293 - AND NOT EXISTS ( 294 - SELECT 1 FROM labels l2 295 - WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val 296 - AND l2.neg = 1 AND l2.id > l1.id 297 - ) 298 - AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP)`, 299 - ).Scan(&total) 300 - if err != nil { 384 + if err := db.QueryRow(`SELECT COUNT(*) FROM takedowns ` + where).Scan(&total); err != nil { 301 385 return nil, 0, err 302 386 } 303 387 304 388 rows, err := db.Query( 305 - `SELECT l1.id, l1.src, l1.uri, COALESCE(l1.cid, ''), l1.val, l1.neg, l1.cts, l1.exp, l1.ver, l1.sig, l1.subject_did, l1.subject_repo 306 - FROM labels l1 307 - WHERE l1.val = '!takedown' AND l1.neg = 0 308 - AND NOT EXISTS ( 309 - SELECT 1 FROM labels l2 310 - WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val 311 - AND l2.neg = 1 AND l2.id > l1.id 312 - ) 313 - AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP) 314 - ORDER BY l1.cts DESC LIMIT ? OFFSET ?`, 389 + `SELECT t.id, t.input, t.subject_did, t.subject_repo, t.subject_handle, t.reason, 390 + t.created_at, t.created_by, t.reversed_at, t.reversed_by, 391 + (SELECT COUNT(*) FROM labels l WHERE l.takedown_id = t.id AND l.neg = 0) 392 + FROM takedowns t `+where+` 393 + ORDER BY t.created_at DESC LIMIT ? OFFSET ?`, 315 394 limit, offset, 316 395 ) 317 396 if err != nil { ··· 319 398 } 320 399 defer rows.Close() 321 400 322 - labels, err := scanLabels(rows) 323 - return labels, total, err 401 + var out []Takedown 402 + for rows.Next() { 403 + t, err := scanTakedown(rows.Scan) 404 + if err != nil { 405 + return nil, 0, err 406 + } 407 + out = append(out, *t) 408 + } 409 + return out, total, rows.Err() 324 410 } 325 411 326 - // GetLabelsForRepo returns all labels for a specific DID + repository. 327 - func GetLabelsForRepo(db *sql.DB, did, repo string) ([]Label, error) { 328 - rows, err := db.Query( 329 - `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo 330 - FROM labels 331 - WHERE subject_did = ? AND subject_repo = ? 332 - ORDER BY cts DESC`, 333 - did, repo, 412 + // MarkTakedownReversed sets the reversed_at / reversed_by fields on the takedown row. 413 + // Refuses to overwrite an existing reversal. 414 + func MarkTakedownReversed(db *sql.DB, id int64, by string, at time.Time) error { 415 + if at.IsZero() { 416 + at = time.Now().UTC() 417 + } 418 + res, err := db.Exec( 419 + `UPDATE takedowns SET reversed_at = ?, reversed_by = ? 420 + WHERE id = ? AND reversed_at IS NULL`, 421 + at.UTC().Format(time.RFC3339), by, id, 334 422 ) 335 423 if err != nil { 336 - return nil, err 424 + return fmt.Errorf("failed to mark takedown reversed: %w", err) 337 425 } 338 - defer rows.Close() 339 - return scanLabels(rows) 340 - } 341 - 342 - // newNegationLabel constructs an unsigned negation label awaiting Sign(). 343 - func newNegationLabel(src, uri, val, did, repo string) *Label { 344 - return &Label{ 345 - Src: src, 346 - URI: uri, 347 - Val: val, 348 - Neg: true, 349 - Cts: time.Now().UTC(), 350 - SubjectDID: did, 351 - SubjectRepo: repo, 426 + n, err := res.RowsAffected() 427 + if err != nil { 428 + return err 429 + } 430 + if n == 0 { 431 + return fmt.Errorf("takedown %d not found or already reversed", id) 352 432 } 433 + return nil 353 434 } 354 435 355 - // NegateRepoLabels signs+inserts negation labels for all active takedown labels on (DID, repo). 356 - func NegateRepoLabels(db *sql.DB, key *atcrypto.PrivateKeyK256, src, did, repo string) ([]Label, error) { 436 + // GetLabelsByTakedown returns all labels (positive + negations) linked to a takedown. 437 + func GetLabelsByTakedown(db *sql.DB, takedownID int64) ([]Label, error) { 357 438 rows, err := db.Query( 358 - `SELECT uri FROM labels 359 - WHERE subject_did = ? AND subject_repo = ? AND val = '!takedown' AND neg = 0`, 360 - did, repo, 439 + `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo, takedown_id 440 + FROM labels WHERE takedown_id = ? ORDER BY id ASC`, 441 + takedownID, 361 442 ) 362 443 if err != nil { 363 444 return nil, err 364 445 } 365 - var uris []string 366 - for rows.Next() { 367 - var uri string 368 - if err := rows.Scan(&uri); err != nil { 369 - rows.Close() 370 - return nil, err 371 - } 372 - uris = append(uris, uri) 373 - } 374 - rows.Close() 375 - if err := rows.Err(); err != nil { 376 - return nil, err 377 - } 378 - 379 - out := make([]Label, 0, len(uris)) 380 - for _, uri := range uris { 381 - neg := newNegationLabel(src, uri, "!takedown", did, repo) 382 - if err := neg.Sign(key); err != nil { 383 - return out, err 384 - } 385 - if _, err := CreateLabel(db, neg); err != nil { 386 - return out, err 387 - } 388 - out = append(out, *neg) 389 - } 390 - return out, nil 446 + defer rows.Close() 447 + return scanLabels(rows) 391 448 } 392 449 393 - // NegateUserLabels signs+inserts negation labels for all active takedown labels on a DID. 394 - func NegateUserLabels(db *sql.DB, key *atcrypto.PrivateKeyK256, src, did string) ([]Label, error) { 450 + // NegateTakedownLabels signs+inserts negation labels for every active (non-negated) 451 + // label linked to the given takedown_id. Negations carry the same takedown_id so they 452 + // remain part of the takedown's audit trail. 453 + // 454 + // The NOT EXISTS subquery skips URIs that already have a later neg=1 row (from a prior 455 + // reversal call or from an external negation streamed in via subscribeLabels), so this 456 + // function is idempotent and won't emit duplicate negations. 457 + func NegateTakedownLabels(db *sql.DB, key *atcrypto.PrivateKeyK256, src string, takedownID int64) ([]Label, error) { 395 458 rows, err := db.Query( 396 - `SELECT uri, subject_repo FROM labels 397 - WHERE subject_did = ? AND val = '!takedown' AND neg = 0`, 398 - did, 459 + `SELECT l1.uri, l1.subject_did, l1.subject_repo FROM labels l1 460 + WHERE l1.takedown_id = ? AND l1.val = '!takedown' AND l1.neg = 0 461 + AND NOT EXISTS ( 462 + SELECT 1 FROM labels l2 463 + WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val 464 + AND l2.neg = 1 AND l2.id > l1.id 465 + )`, 466 + takedownID, 399 467 ) 400 468 if err != nil { 401 469 return nil, err 402 470 } 403 - type uriRepo struct { 471 + type entry struct { 404 472 uri string 473 + did string 405 474 repo string 406 475 } 407 - var entries []uriRepo 476 + var entries []entry 408 477 for rows.Next() { 409 - var e uriRepo 410 - if err := rows.Scan(&e.uri, &e.repo); err != nil { 478 + var e entry 479 + if err := rows.Scan(&e.uri, &e.did, &e.repo); err != nil { 411 480 rows.Close() 412 481 return nil, err 413 482 } ··· 418 487 return nil, err 419 488 } 420 489 490 + id := takedownID 421 491 out := make([]Label, 0, len(entries)) 422 492 for _, e := range entries { 423 - neg := newNegationLabel(src, e.uri, "!takedown", did, e.repo) 493 + neg := &Label{ 494 + Src: src, 495 + URI: e.uri, 496 + Val: "!takedown", 497 + Neg: true, 498 + Cts: time.Now().UTC(), 499 + SubjectDID: e.did, 500 + SubjectRepo: e.repo, 501 + TakedownID: &id, 502 + } 424 503 if err := neg.Sign(key); err != nil { 425 504 return out, err 426 505 } ··· 432 511 return out, nil 433 512 } 434 513 514 + func scanTakedown(scan func(...any) error) (*Takedown, error) { 515 + var ( 516 + t Takedown 517 + created string 518 + revAt *string 519 + ) 520 + if err := scan( 521 + &t.ID, &t.Input, &t.SubjectDID, &t.SubjectRepo, &t.SubjectHandle, &t.Reason, 522 + &created, &t.CreatedBy, &revAt, &t.ReversedBy, &t.LabelCount, 523 + ); err != nil { 524 + return nil, err 525 + } 526 + if ts, err := time.Parse(time.RFC3339, created); err == nil { 527 + t.CreatedAt = ts 528 + } 529 + if revAt != nil { 530 + if ts, err := time.Parse(time.RFC3339, *revAt); err == nil { 531 + t.ReversedAt = &ts 532 + } 533 + } 534 + return &t, nil 535 + } 536 + 435 537 func scanLabels(rows *sql.Rows) ([]Label, error) { 436 538 var labels []Label 437 539 for rows.Next() { 438 - var l Label 439 - var cts string 440 - var exp *string 441 - if err := rows.Scan(&l.ID, &l.Src, &l.URI, &l.CID, &l.Val, &l.Neg, &cts, &exp, &l.Ver, &l.Sig, &l.SubjectDID, &l.SubjectRepo); err != nil { 540 + var ( 541 + l Label 542 + cts string 543 + exp *string 544 + tdID sql.NullInt64 545 + ) 546 + if err := rows.Scan(&l.ID, &l.Src, &l.URI, &l.CID, &l.Val, &l.Neg, &cts, &exp, &l.Ver, &l.Sig, &l.SubjectDID, &l.SubjectRepo, &tdID); err != nil { 442 547 return nil, err 443 548 } 444 549 if t, err := time.Parse(time.RFC3339, cts); err == nil { ··· 448 553 if t, err := time.Parse(time.RFC3339, *exp); err == nil { 449 554 l.Exp = &t 450 555 } 556 + } 557 + if tdID.Valid { 558 + id := tdID.Int64 559 + l.TakedownID = &id 451 560 } 452 561 labels = append(labels, l) 453 562 }
+116 -50
pkg/labeler/db_test.go
··· 151 151 } 152 152 } 153 153 154 - func TestListActiveTakedowns(t *testing.T) { 154 + func TestListTakedowns(t *testing.T) { 155 155 dir := t.TempDir() 156 156 db := openTestDB(t, filepath.Join(dir, "test.db")) 157 157 key := newTestKey(t) ··· 159 159 src := "did:plc:labeler-1" 160 160 now := time.Now().UTC() 161 161 162 + // Create three takedown events, each with a single summary label, so the 163 + // label_count subquery has something to count. 164 + var ids []int64 162 165 for i, repo := range []string{"repo1", "repo2", "repo3"} { 166 + td := &Takedown{ 167 + Input: "atcr.io/r/did:plc:abc/" + repo, 168 + SubjectDID: "did:plc:abc", 169 + SubjectRepo: repo, 170 + CreatedAt: now.Add(time.Duration(i) * time.Minute), 171 + } 172 + id, err := CreateTakedown(db, td) 173 + if err != nil { 174 + t.Fatalf("CreateTakedown: %v", err) 175 + } 176 + ids = append(ids, id) 163 177 signAndCreate(t, db, key, &Label{ 164 178 Src: src, URI: "at://did:plc:abc/io.atcr.repo/" + repo, 165 - Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute), 179 + Val: "!takedown", Cts: td.CreatedAt, 166 180 SubjectDID: "did:plc:abc", SubjectRepo: repo, 181 + TakedownID: &id, 167 182 }) 168 183 } 169 184 170 - labels, total, err := ListActiveTakedowns(db, 10, 0) 185 + tds, total, err := ListTakedowns(db, TakedownActive, 10, 0) 171 186 if err != nil { 172 187 t.Fatal(err) 173 188 } 174 - if total != 3 || len(labels) != 3 { 175 - t.Errorf("expected 3 active takedowns, got total=%d returned=%d", total, len(labels)) 189 + if total != 3 || len(tds) != 3 { 190 + t.Errorf("expected 3 active takedowns, got total=%d returned=%d", total, len(tds)) 191 + } 192 + for _, td := range tds { 193 + if td.LabelCount != 1 { 194 + t.Errorf("takedown %d label_count = %d, want 1", td.ID, td.LabelCount) 195 + } 176 196 } 177 197 178 - if _, err := NegateRepoLabels(db, key, src, "did:plc:abc", "repo2"); err != nil { 198 + // Reverse the middle takedown. 199 + if _, err := NegateTakedownLabels(db, key, src, ids[1]); err != nil { 200 + t.Fatal(err) 201 + } 202 + if err := MarkTakedownReversed(db, ids[1], "did:plc:operator", time.Now().UTC()); err != nil { 179 203 t.Fatal(err) 180 204 } 181 205 182 - _, total, err = ListActiveTakedowns(db, 10, 0) 206 + _, activeTotal, err := ListTakedowns(db, TakedownActive, 10, 0) 207 + if err != nil { 208 + t.Fatal(err) 209 + } 210 + if activeTotal != 2 { 211 + t.Errorf("expected 2 active takedowns after reversal, got %d", activeTotal) 212 + } 213 + revs, revTotal, err := ListTakedowns(db, TakedownReversed, 10, 0) 183 214 if err != nil { 184 215 t.Fatal(err) 185 216 } 186 - if total != 2 { 187 - t.Errorf("expected 2 active takedowns after negation, got %d", total) 217 + if revTotal != 1 || len(revs) != 1 { 218 + t.Errorf("expected 1 reversed takedown, got total=%d returned=%d", revTotal, len(revs)) 219 + } 220 + if revs[0].ID != ids[1] || revs[0].ReversedAt == nil || revs[0].ReversedBy != "did:plc:operator" { 221 + t.Errorf("reversed takedown row has wrong fields: %+v", revs[0]) 188 222 } 189 223 } 190 224 191 - func TestNegateRepoLabels(t *testing.T) { 225 + func TestNegateTakedownLabels(t *testing.T) { 192 226 dir := t.TempDir() 193 227 db := openTestDB(t, filepath.Join(dir, "test.db")) 194 228 key := newTestKey(t) ··· 197 231 now := time.Now().UTC() 198 232 did := "did:plc:abc" 199 233 234 + tdID, err := CreateTakedown(db, &Takedown{ 235 + Input: "atcr.io/r/did:plc:abc/myimage", SubjectDID: did, SubjectRepo: "myimage", 236 + CreatedAt: now, 237 + }) 238 + if err != nil { 239 + t.Fatal(err) 240 + } 241 + 200 242 uris := []string{ 201 243 "at://did:plc:abc/io.atcr.manifest/sha256-111", 202 244 "at://did:plc:abc/io.atcr.manifest/sha256-222", ··· 205 247 for _, uri := range uris { 206 248 signAndCreate(t, db, key, &Label{ 207 249 Src: src, URI: uri, Val: "!takedown", Cts: now, 208 - SubjectDID: did, SubjectRepo: "myimage", 250 + SubjectDID: did, SubjectRepo: "myimage", TakedownID: &tdID, 209 251 }) 210 252 } 211 253 212 - negs, err := NegateRepoLabels(db, key, src, did, "myimage") 254 + negs, err := NegateTakedownLabels(db, key, src, tdID) 213 255 if err != nil { 214 256 t.Fatal(err) 215 257 } ··· 217 259 t.Errorf("expected %d negation labels, got %d", len(uris), len(negs)) 218 260 } 219 261 220 - _, total, err := ListActiveTakedowns(db, 10, 0) 262 + // Negations must carry the same takedown_id so they're part of the audit trail. 263 + all, err := GetLabelsByTakedown(db, tdID) 221 264 if err != nil { 222 265 t.Fatal(err) 223 266 } 224 - if total != 0 { 225 - t.Errorf("expected 0 active takedowns after repo negation, got %d", total) 267 + if len(all) != 2*len(uris) { 268 + t.Errorf("expected %d labels (positive + negation) for takedown %d, got %d", 2*len(uris), tdID, len(all)) 269 + } 270 + var pos, neg int 271 + for _, l := range all { 272 + if l.TakedownID == nil || *l.TakedownID != tdID { 273 + t.Errorf("label %d takedown_id = %v, want %d", l.ID, l.TakedownID, tdID) 274 + } 275 + if l.Neg { 276 + neg++ 277 + } else { 278 + pos++ 279 + } 280 + } 281 + if pos != len(uris) || neg != len(uris) { 282 + t.Errorf("expected pos=%d neg=%d, got pos=%d neg=%d", len(uris), len(uris), pos, neg) 283 + } 284 + 285 + // Calling negate again must be a no-op (no remaining positive labels to flip). 286 + negs2, err := NegateTakedownLabels(db, key, src, tdID) 287 + if err != nil { 288 + t.Fatal(err) 289 + } 290 + if len(negs2) != 0 { 291 + t.Errorf("expected 0 negations on second call, got %d", len(negs2)) 226 292 } 227 293 } 228 294 229 - func TestNegateUserLabels(t *testing.T) { 295 + func TestMarkTakedownReversed_RefusesDoubleReverse(t *testing.T) { 230 296 dir := t.TempDir() 231 297 db := openTestDB(t, filepath.Join(dir, "test.db")) 232 - key := newTestKey(t) 233 298 234 - src := "did:plc:labeler-1" 235 - now := time.Now().UTC() 236 - did := "did:plc:abc" 237 - 238 - signAndCreate(t, db, key, &Label{ 239 - Src: src, URI: "at://did:plc:abc", Val: "!takedown", Cts: now, 240 - SubjectDID: did, 299 + id, err := CreateTakedown(db, &Takedown{ 300 + Input: "did:plc:abc", SubjectDID: "did:plc:abc", 241 301 }) 242 - signAndCreate(t, db, key, &Label{ 243 - Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1", Val: "!takedown", Cts: now, 244 - SubjectDID: did, SubjectRepo: "repo1", 245 - }) 246 - 247 - negs, err := NegateUserLabels(db, key, src, did) 248 302 if err != nil { 249 303 t.Fatal(err) 250 304 } 251 - if len(negs) != 2 { 252 - t.Errorf("expected 2 negation labels, got %d", len(negs)) 305 + if err := MarkTakedownReversed(db, id, "did:plc:op", time.Now().UTC()); err != nil { 306 + t.Fatalf("first reverse: %v", err) 253 307 } 254 - 255 - _, total, err := ListActiveTakedowns(db, 10, 0) 256 - if err != nil { 257 - t.Fatal(err) 308 + if err := MarkTakedownReversed(db, id, "did:plc:op", time.Now().UTC()); err == nil { 309 + t.Error("expected second reverse to fail (already reversed)") 258 310 } 259 - if total != 0 { 260 - t.Errorf("expected 0 active takedowns after user negation, got %d", total) 311 + if err := MarkTakedownReversed(db, 9999, "did:plc:op", time.Now().UTC()); err == nil { 312 + t.Error("expected reverse on unknown id to fail") 261 313 } 262 314 } 263 315 ··· 326 378 } 327 379 } 328 380 329 - func TestGetLabelsForRepo(t *testing.T) { 381 + func TestGetLabelsByTakedown(t *testing.T) { 330 382 dir := t.TempDir() 331 383 db := openTestDB(t, filepath.Join(dir, "test.db")) 332 384 key := newTestKey(t) ··· 334 386 src := "did:plc:labeler-1" 335 387 now := time.Now().UTC() 336 388 389 + tdA, err := CreateTakedown(db, &Takedown{Input: "did:plc:abc/repo1", SubjectDID: "did:plc:abc", SubjectRepo: "repo1", CreatedAt: now}) 390 + if err != nil { 391 + t.Fatal(err) 392 + } 393 + tdB, err := CreateTakedown(db, &Takedown{Input: "did:plc:def/repo1", SubjectDID: "did:plc:def", SubjectRepo: "repo1", CreatedAt: now}) 394 + if err != nil { 395 + t.Fatal(err) 396 + } 397 + 337 398 signAndCreate(t, db, key, &Label{ 338 399 Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1", 339 - Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo1", 400 + Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo1", TakedownID: &tdA, 340 401 }) 341 402 signAndCreate(t, db, key, &Label{ 342 - Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo2", 343 - Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo2", 403 + Src: src, URI: "at://did:plc:abc/io.atcr.manifest/sha256-aaa", 404 + Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo1", TakedownID: &tdA, 344 405 }) 345 406 signAndCreate(t, db, key, &Label{ 346 407 Src: src, URI: "at://did:plc:def/io.atcr.repo/repo1", 347 - Val: "!takedown", Cts: now, SubjectDID: "did:plc:def", SubjectRepo: "repo1", 408 + Val: "!takedown", Cts: now, SubjectDID: "did:plc:def", SubjectRepo: "repo1", TakedownID: &tdB, 348 409 }) 349 410 350 - labels, err := GetLabelsForRepo(db, "did:plc:abc", "repo1") 411 + labels, err := GetLabelsByTakedown(db, tdA) 351 412 if err != nil { 352 413 t.Fatal(err) 353 414 } 354 - if len(labels) != 1 { 355 - t.Errorf("expected 1 label for did:plc:abc/repo1, got %d", len(labels)) 415 + if len(labels) != 2 { 416 + t.Errorf("takedown A: expected 2 labels, got %d", len(labels)) 417 + } 418 + for _, l := range labels { 419 + if l.TakedownID == nil || *l.TakedownID != tdA { 420 + t.Errorf("label %d has wrong takedown_id: %v", l.ID, l.TakedownID) 421 + } 356 422 } 357 423 358 - labels, err = GetLabelsForRepo(db, "did:plc:def", "repo1") 424 + labels, err = GetLabelsByTakedown(db, tdB) 359 425 if err != nil { 360 426 t.Fatal(err) 361 427 } 362 428 if len(labels) != 1 { 363 - t.Errorf("expected 1 label for did:plc:def/repo1, got %d", len(labels)) 429 + t.Errorf("takedown B: expected 1 label, got %d", len(labels)) 364 430 } 365 431 366 - labels, err = GetLabelsForRepo(db, "did:plc:xyz", "repo1") 432 + labels, err = GetLabelsByTakedown(db, 9999) 367 433 if err != nil { 368 434 t.Fatal(err) 369 435 } 370 436 if len(labels) != 0 { 371 - t.Errorf("expected 0 labels for unknown did, got %d", len(labels)) 437 + t.Errorf("unknown takedown id: expected 0 labels, got %d", len(labels)) 372 438 } 373 439 }
+198 -59
pkg/labeler/takedown.go
··· 8 8 "log/slog" 9 9 "net/http" 10 10 "net/url" 11 + "strconv" 11 12 "strings" 12 13 "time" 13 14 ··· 19 20 DID string 20 21 Handle string 21 22 Repository string // empty = user-level takedown 23 + // Operator-supplied context. Captured into the takedowns row so we can show 24 + // who/why/what-was-typed on the dashboard. None are required. 25 + RawInput string // exact string the operator submitted (URL, did, handle, AT URI) 26 + Reason string // optional free-text note 27 + CreatedBy string // operator DID from session, "" if unknown 22 28 } 23 29 24 30 // ParseTakedownInput parses various input formats into a TakedownInput. ··· 165 171 166 172 // TakedownResult contains the results of a takedown operation. 167 173 type TakedownResult struct { 174 + TakedownID int64 168 175 DID string 169 176 Handle string 170 177 Repository string ··· 172 179 UserLevel bool 173 180 } 174 181 175 - // ExecuteTakedown creates takedown labels for a repo or user. 182 + // ExecuteTakedown creates a takedown event row and the labels that belong to it. 183 + // Every label (the user-level label, the per-record labels discovered via PDS, and the 184 + // repo-level summary label) carries the new takedown_id so reversal can target the 185 + // exact set without re-querying by subject. 176 186 func (s *Server) ExecuteTakedown(ctx context.Context, input *TakedownInput) (*TakedownResult, error) { 177 187 src := s.did 178 188 now := time.Now().UTC() 189 + 190 + td := &Takedown{ 191 + Input: input.RawInput, 192 + SubjectDID: input.DID, 193 + SubjectRepo: input.Repository, 194 + SubjectHandle: input.Handle, 195 + Reason: input.Reason, 196 + CreatedAt: now, 197 + CreatedBy: input.CreatedBy, 198 + } 199 + if td.Input == "" { 200 + // Fallback so the dashboard always has something to show, even if a caller 201 + // (e.g. a future API) didn't pass the original string. 202 + if input.Repository != "" { 203 + td.Input = fmt.Sprintf("%s/%s", input.DID, input.Repository) 204 + } else { 205 + td.Input = input.DID 206 + } 207 + } 208 + takedownID, err := CreateTakedown(s.db, td) 209 + if err != nil { 210 + return nil, fmt.Errorf("failed to create takedown event: %w", err) 211 + } 212 + 179 213 result := &TakedownResult{ 214 + TakedownID: takedownID, 180 215 DID: input.DID, 181 216 Handle: input.Handle, 182 217 Repository: input.Repository, ··· 184 219 } 185 220 186 221 if input.Repository == "" { 187 - // User-level takedown 222 + // User-level takedown: a single label on at://<did>. 188 223 label := &Label{ 189 224 Src: src, 190 225 URI: "at://" + input.DID, ··· 192 227 Cts: now, 193 228 SubjectDID: input.DID, 194 229 SubjectRepo: "", 230 + TakedownID: &takedownID, 195 231 } 196 232 if err := label.Sign(s.signingKey); err != nil { 197 233 return nil, fmt.Errorf("failed to sign user-level label: %w", err) ··· 201 237 } 202 238 s.hub.Broadcast(label) 203 239 result.Labels = append(result.Labels, *label) 204 - slog.Info("Created user-level takedown", "did", input.DID, "handle", input.Handle) 240 + slog.Info("Created user-level takedown", "takedown_id", takedownID, "did", input.DID, "handle", input.Handle) 205 241 return result, nil 206 242 } 207 243 208 - // Repo-level takedown: discover all records from PDS 209 - labels, err := s.discoverAndLabelRecords(ctx, input.DID, input.Repository, src, now) 244 + // Repo-level takedown: discover all records from PDS and label each. 245 + labels, err := s.discoverAndLabelRecords(ctx, input.DID, input.Repository, src, now, takedownID) 210 246 if err != nil { 211 - // Even if PDS discovery fails, create a repo-level summary label 247 + // Even if PDS discovery fails, create a repo-level summary label so reads 248 + // against the well-known summary URI still see the takedown. 212 249 slog.Warn("PDS discovery failed, creating summary label only", "error", err) 213 250 } 214 251 result.Labels = append(result.Labels, labels...) 215 252 216 - // Always create a repo-level summary label for efficient filtering 253 + // Always create a repo-level summary label for efficient filtering. 217 254 summaryLabel := &Label{ 218 255 Src: src, 219 256 URI: fmt.Sprintf("at://%s/io.atcr.repo/%s", input.DID, input.Repository), ··· 221 258 Cts: now, 222 259 SubjectDID: input.DID, 223 260 SubjectRepo: input.Repository, 261 + TakedownID: &takedownID, 224 262 } 225 263 if err := summaryLabel.Sign(s.signingKey); err != nil { 226 264 return nil, fmt.Errorf("failed to sign summary label: %w", err) ··· 232 270 result.Labels = append(result.Labels, *summaryLabel) 233 271 234 272 slog.Info("Created repo-level takedown", 273 + "takedown_id", takedownID, 235 274 "did", input.DID, 236 275 "handle", input.Handle, 237 276 "repository", input.Repository, ··· 241 280 return result, nil 242 281 } 243 282 283 + // ReverseTakedown negates every active label belonging to the given takedown event and 284 + // marks the event row as reversed. Refuses to act on a takedown that doesn't exist or 285 + // has already been reversed. 286 + func (s *Server) ReverseTakedown(ctx context.Context, takedownID int64, reversedBy string) (*TakedownResult, error) { 287 + td, err := GetTakedown(s.db, takedownID) 288 + if err != nil { 289 + return nil, fmt.Errorf("failed to load takedown %d: %w", takedownID, err) 290 + } 291 + if td.ReversedAt != nil { 292 + return nil, fmt.Errorf("takedown %d already reversed at %s", takedownID, td.ReversedAt.Format(time.RFC3339)) 293 + } 294 + 295 + negs, err := NegateTakedownLabels(s.db, s.signingKey, s.did, takedownID) 296 + if err != nil { 297 + return nil, fmt.Errorf("failed to negate labels for takedown %d: %w", takedownID, err) 298 + } 299 + for i := range negs { 300 + s.hub.Broadcast(&negs[i]) 301 + } 302 + if err := MarkTakedownReversed(s.db, takedownID, reversedBy, time.Now().UTC()); err != nil { 303 + return nil, err 304 + } 305 + 306 + slog.Info("Reversed takedown", 307 + "takedown_id", takedownID, 308 + "did", td.SubjectDID, 309 + "repository", td.SubjectRepo, 310 + "reversed_by", reversedBy, 311 + "negations", len(negs), 312 + ) 313 + return &TakedownResult{ 314 + TakedownID: takedownID, 315 + DID: td.SubjectDID, 316 + Handle: td.SubjectHandle, 317 + Repository: td.SubjectRepo, 318 + Labels: negs, 319 + UserLevel: td.SubjectRepo == "", 320 + }, nil 321 + } 322 + 244 323 // discoverAndLabelRecords queries the user's PDS for all records in the given repo 245 - // and creates takedown labels for each. 246 - func (s *Server) discoverAndLabelRecords(ctx context.Context, did, repo, src string, now time.Time) ([]Label, error) { 324 + // and creates takedown labels for each, all linked to takedownID. 325 + func (s *Server) discoverAndLabelRecords(ctx context.Context, did, repo, src string, now time.Time, takedownID int64) ([]Label, error) { 247 326 _, _, pdsEndpoint, err := atproto.ResolveIdentity(ctx, did) 248 327 if err != nil { 249 328 return nil, fmt.Errorf("failed to resolve DID: %w", err) ··· 282 361 Cts: now, 283 362 SubjectDID: did, 284 363 SubjectRepo: repo, 364 + TakedownID: &takedownID, 285 365 } 286 366 if err := label.Sign(s.signingKey); err != nil { 287 367 slog.Warn("Failed to sign label", "uri", uri, "error", err) ··· 314 394 // Handlers 315 395 316 396 func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) { 317 - labels, total, err := ListActiveTakedowns(s.db, 50, 0) 397 + active, activeTotal, err := ListTakedowns(s.db, TakedownActive, 50, 0) 318 398 if err != nil { 319 399 http.Error(w, "Failed to list takedowns", http.StatusInternalServerError) 320 400 return 321 401 } 402 + reversed, reversedTotal, err := ListTakedowns(s.db, TakedownReversed, 50, 0) 403 + if err != nil { 404 + http.Error(w, "Failed to list reversed takedowns", http.StatusInternalServerError) 405 + return 406 + } 407 + 322 408 csrf := "" 323 409 if session := SessionFromContext(r.Context()); session != nil { 324 410 csrf = session.CSRFToken ··· 329 415 <html> 330 416 <head><title>%s Labeler</title> 331 417 <style> 332 - body{font-family:system-ui;max-width:900px;margin:40px auto;padding:0 20px} 418 + body{font-family:system-ui;max-width:1000px;margin:40px auto;padding:0 20px} 333 419 table{width:100%%;border-collapse:collapse;margin:20px 0} 334 - th,td{text-align:left;padding:8px;border-bottom:1px solid #ddd} 420 + th,td{text-align:left;padding:8px;border-bottom:1px solid #ddd;vertical-align:top} 335 421 th{background:#f5f5f5} 336 - .badge{background:#dc2626;color:white;padding:2px 8px;border-radius:4px;font-size:0.85em} 422 + .muted{color:#666;font-size:0.9em} 337 423 a{color:#2563eb} 338 424 nav{display:flex;gap:16px;margin-bottom:24px} 339 425 .btn{padding:8px 16px;background:#2563eb;color:white;text-decoration:none;border-radius:4px;border:none;cursor:pointer} 340 426 .btn-danger{background:#dc2626} 341 427 form{display:inline} 428 + code{background:#f4f4f5;padding:1px 4px;border-radius:3px} 429 + .reason{max-width:280px;white-space:pre-wrap} 342 430 </style> 343 431 </head> 344 432 <body> ··· 351 439 <h2>Active Takedowns (%d)</h2>`, 352 440 s.config.Server.ClientShortName, 353 441 s.config.Server.ClientShortName, 354 - total, 442 + activeTotal, 355 443 ) 356 444 357 - if len(labels) == 0 { 358 - fmt.Fprint(w, `<p>No active takedowns.</p>`) 445 + renderTakedownRows(w, active, csrf, true) 446 + 447 + fmt.Fprintf(w, `<h2>Reversed (%d)</h2>`, reversedTotal) 448 + renderTakedownRows(w, reversed, csrf, false) 449 + 450 + fmt.Fprint(w, `</body></html>`) 451 + } 452 + 453 + // renderTakedownRows writes either an active table (with a Reverse button) or a 454 + // reversed-history table (with a reversed-at column instead). 455 + func renderTakedownRows(w http.ResponseWriter, ts []Takedown, csrf string, withReverse bool) { 456 + if len(ts) == 0 { 457 + if withReverse { 458 + fmt.Fprint(w, `<p class="muted">No active takedowns.</p>`) 459 + } else { 460 + fmt.Fprint(w, `<p class="muted">No reversed takedowns yet.</p>`) 461 + } 462 + return 463 + } 464 + fmt.Fprint(w, `<table><tr><th>Input</th><th>Subject</th><th>Reason</th><th>Labels</th><th>Created</th>`) 465 + if withReverse { 466 + fmt.Fprint(w, `<th>Action</th>`) 359 467 } else { 360 - fmt.Fprint(w, `<table><tr><th>Subject</th><th>Repository</th><th>URI</th><th>Created</th><th>Action</th></tr>`) 361 - for _, l := range labels { 362 - repoDisplay := l.SubjectRepo 363 - if repoDisplay == "" { 364 - repoDisplay = "<em>all repos (user-level)</em>" 468 + fmt.Fprint(w, `<th>Reversed</th>`) 469 + } 470 + fmt.Fprint(w, `</tr>`) 471 + for _, t := range ts { 472 + subject := template.HTMLEscapeString(t.SubjectDID) 473 + if t.SubjectHandle != "" { 474 + subject = fmt.Sprintf(`%s<br><span class="muted">%s</span>`, 475 + template.HTMLEscapeString(t.SubjectHandle), subject) 476 + } 477 + if t.SubjectRepo != "" { 478 + subject += fmt.Sprintf(` / <code>%s</code>`, template.HTMLEscapeString(t.SubjectRepo)) 479 + } else { 480 + subject += ` <span class="muted">(user-level)</span>` 481 + } 482 + 483 + reason := template.HTMLEscapeString(t.Reason) 484 + if reason == "" { 485 + reason = `<span class="muted">—</span>` 486 + } 487 + 488 + var lastCol string 489 + if withReverse { 490 + lastCol = fmt.Sprintf( 491 + `<form method="POST" action="/reverse">%s<input type="hidden" name="takedown_id" value="%d"><button type="submit" class="btn btn-danger" onclick="return confirm('Reverse this takedown?')">Reverse</button></form>`, 492 + csrfInputHTML(csrf), t.ID, 493 + ) 494 + } else { 495 + rev := "" 496 + if t.ReversedAt != nil { 497 + rev = t.ReversedAt.Format("2006-01-02 15:04") 365 498 } 366 - fmt.Fprintf(w, `<tr> 499 + by := "" 500 + if t.ReversedBy != "" { 501 + by = fmt.Sprintf(`<br><span class="muted">by %s</span>`, template.HTMLEscapeString(t.ReversedBy)) 502 + } 503 + lastCol = rev + by 504 + } 505 + 506 + fmt.Fprintf(w, `<tr> 507 + <td><code>%s</code></td> 367 508 <td>%s</td> 509 + <td class="reason">%s</td> 510 + <td>%d</td> 368 511 <td>%s</td> 369 - <td><code>%s</code></td> 370 512 <td>%s</td> 371 - <td><form method="POST" action="/reverse">%s<input type="hidden" name="did" value="%s"><input type="hidden" name="repo" value="%s"><button type="submit" class="btn btn-danger" onclick="return confirm('Reverse this takedown?')">Reverse</button></form></td> 372 513 </tr>`, 373 - template.HTMLEscapeString(l.SubjectDID), 374 - repoDisplay, 375 - template.HTMLEscapeString(l.URI), 376 - l.Cts.Format("2006-01-02 15:04"), 377 - csrfInputHTML(csrf), 378 - template.HTMLEscapeString(l.SubjectDID), 379 - template.HTMLEscapeString(l.SubjectRepo), 380 - ) 381 - } 382 - fmt.Fprint(w, `</table>`) 514 + template.HTMLEscapeString(t.Input), 515 + subject, 516 + reason, 517 + t.LabelCount, 518 + t.CreatedAt.Format("2006-01-02 15:04"), 519 + lastCol, 520 + ) 383 521 } 384 - 385 - fmt.Fprint(w, `</body></html>`) 522 + fmt.Fprint(w, `</table>`) 386 523 } 387 524 388 525 func (s *Server) handleTakedownForm(w http.ResponseWriter, r *http.Request) { ··· 431 568 <label for="target"><strong>Target</strong></label> 432 569 <input type="text" id="target" name="target" placeholder="/r/handle/repo, /u/handle, at://did/collection/rkey, handle, or did:..." required> 433 570 <p class="help">Repo: <code>/r/handle/repo</code> (or full atcr.io URL). User-level: <code>/u/handle</code>, a bare handle, or a DID. AT URIs (<code>at://...</code>) also work.</p> 571 + 572 + <label for="reason"><strong>Reason</strong> <span class="help">(optional, internal note)</span></label> 573 + <textarea id="reason" name="reason" rows="3" placeholder="Why is this being taken down? Visible only to labeler operators."></textarea> 574 + 434 575 <br> 435 576 <button type="submit" class="btn" onclick="return confirm('Issue takedown? This will suppress the content immediately.')">Issue Takedown</button> 436 577 </form> ··· 443 584 http.Redirect(w, r, "/takedown?error=Target+is+required", http.StatusFound) 444 585 return 445 586 } 587 + reason := strings.TrimSpace(r.FormValue("reason")) 446 588 447 589 input, err := ParseTakedownInput(r.Context(), target) 448 590 if err != nil { 449 591 http.Redirect(w, r, "/takedown?error="+strings.ReplaceAll(err.Error(), " ", "+"), http.StatusFound) 450 592 return 451 593 } 594 + input.RawInput = target 595 + input.Reason = reason 596 + if session := SessionFromContext(r.Context()); session != nil { 597 + input.CreatedBy = session.DID 598 + } 452 599 453 600 result, err := s.ExecuteTakedown(r.Context(), input) 454 601 if err != nil { ··· 456 603 return 457 604 } 458 605 459 - msg := fmt.Sprintf("Takedown issued: %d labels created for %s", len(result.Labels), result.DID) 606 + msg := fmt.Sprintf("Takedown #%d issued: %d labels created for %s", result.TakedownID, len(result.Labels), result.DID) 460 607 if result.Repository != "" { 461 608 msg += "/" + result.Repository 462 609 } ··· 464 611 } 465 612 466 613 func (s *Server) handleReverse(w http.ResponseWriter, r *http.Request) { 467 - did := strings.TrimSpace(r.FormValue("did")) 468 - repo := strings.TrimSpace(r.FormValue("repo")) 469 - 470 - if did == "" { 471 - http.Redirect(w, r, "/?error=DID+is+required", http.StatusFound) 614 + idStr := strings.TrimSpace(r.FormValue("takedown_id")) 615 + if idStr == "" { 616 + http.Redirect(w, r, "/?error=Missing+takedown_id", http.StatusFound) 472 617 return 473 618 } 474 - 475 - src := s.did 476 - var ( 477 - negs []Label 478 - err error 479 - ) 480 - if repo == "" { 481 - negs, err = NegateUserLabels(s.db, s.signingKey, src, did) 482 - } else { 483 - negs, err = NegateRepoLabels(s.db, s.signingKey, src, did, repo) 619 + id, err := strconv.ParseInt(idStr, 10, 64) 620 + if err != nil || id <= 0 { 621 + http.Redirect(w, r, "/?error=Invalid+takedown_id", http.StatusFound) 622 + return 484 623 } 485 624 486 - if err != nil { 487 - slog.Error("Failed to reverse takedown", "did", did, "repo", repo, "error", err) 488 - http.Redirect(w, r, "/?error=Failed+to+reverse+takedown", http.StatusFound) 489 - return 625 + reversedBy := "" 626 + if session := SessionFromContext(r.Context()); session != nil { 627 + reversedBy = session.DID 490 628 } 491 629 492 - for i := range negs { 493 - s.hub.Broadcast(&negs[i]) 630 + if _, err := s.ReverseTakedown(r.Context(), id, reversedBy); err != nil { 631 + slog.Error("Failed to reverse takedown", "takedown_id", id, "error", err) 632 + http.Redirect(w, r, "/?error="+strings.ReplaceAll(err.Error(), " ", "+"), http.StatusFound) 633 + return 494 634 } 495 635 496 - slog.Info("Reversed takedown", "did", did, "repo", repo, "negations", len(negs)) 497 636 http.Redirect(w, r, "/", http.StatusFound) 498 637 }