A container registry that uses the AT Protocol for manifest storage and S3 for blob storage. atcr.io
docker container atproto go
80
fork

Configure Feed

Select the types of activity you want to include in your feed.

more labeler improvements. standardize did work between labeler and hold. improve sql race conditions on local-only db

+3714 -1532
+25
.air.labeler.toml
··· 1 + root = "." 2 + tmp_dir = "tmp" 3 + 4 + [build] 5 + cmd = "go build -buildvcs=false -o ./tmp/atcr-labeler ./cmd/labeler" 6 + entrypoint = ["./tmp/atcr-labeler", "serve", "--config", "config-labeler.example.yaml"] 7 + include_ext = ["go", "html", "css", "js"] 8 + exclude_dir = ["bin", "tmp", "vendor", "deploy", "docs", ".git", "dist", "pkg/appview", "pkg/hold", "node_modules"] 9 + exclude_regex = ["_test\\.go$", "cbor_gen\\.go$", "\\.min\\.js$", "public/css/style\\.css$", "public/icons\\.svg$"] 10 + delay = 3000 11 + stop_on_error = true 12 + send_interrupt = true 13 + kill_delay = 500 14 + 15 + [log] 16 + time = false 17 + 18 + [color] 19 + main = "cyan" 20 + watcher = "magenta" 21 + build = "yellow" 22 + runner = "green" 23 + 24 + [misc] 25 + clean_on_exit = true
+157 -92
cmd/hold/plc.go
··· 5 5 "fmt" 6 6 "log/slog" 7 7 8 + "atcr.io/pkg/atproto/did" 8 9 "atcr.io/pkg/auth/oauth" 9 10 "atcr.io/pkg/hold" 10 - "atcr.io/pkg/hold/pds" 11 11 12 12 "github.com/bluesky-social/indigo/atproto/atcrypto" 13 - didplc "github.com/did-method-plc/go-didplc" 14 13 "github.com/spf13/cobra" 15 14 ) 16 15 ··· 20 19 } 21 20 22 21 var plcConfigFile string 22 + 23 + var ( 24 + plcAddRotationKeyFirst bool 25 + plcAddRotationKeyLast bool 26 + ) 23 27 24 28 var plcAddRotationKeyCmd = &cobra.Command{ 25 - Use: "add-rotation-key <multibase-key>", 29 + Use: "add-rotation-key [multibase-key]", 26 30 Short: "Add a rotation key to this hold's PLC identity", 27 31 Long: `Add an additional rotation key to the hold's did:plc document. 28 - The key must be a multibase-encoded private key (K-256 or P-256, starting with 'z'). 32 + 33 + If a multibase-encoded private key (K-256 or P-256, starting with 'z') is supplied as 34 + the positional argument, that key is added. If no argument is given, a fresh K-256 35 + keypair is generated and the private half is printed to stdout. Save it offline as 36 + your recovery key, since it will not be shown again. 37 + 38 + By default the new key is inserted at the highest priority position (--first), which 39 + allows it to override ops signed by lower-priority keys within PLC's 72-hour recovery 40 + window. Pass --last to append at the lowest priority instead. 41 + 29 42 The hold's configured rotation key is used to sign the PLC update. 30 43 31 - atcr-hold plc add-rotation-key --config config.yaml z...`, 32 - Args: cobra.ExactArgs(1), 44 + atcr-hold plc add-rotation-key --config config.yaml # generate + print 45 + atcr-hold plc add-rotation-key --config config.yaml --last # append, low priority 46 + atcr-hold plc add-rotation-key --config config.yaml z... # use supplied key`, 47 + Args: cobra.MaximumNArgs(1), 33 48 RunE: func(cmd *cobra.Command, args []string) error { 49 + firstSet := cmd.Flags().Changed("first") 50 + lastSet := cmd.Flags().Changed("last") 51 + if firstSet && lastSet { 52 + return fmt.Errorf("--first and --last are mutually exclusive") 53 + } 54 + prepend := !plcAddRotationKeyLast 55 + 34 56 cfg, err := hold.LoadConfig(plcConfigFile) 35 57 if err != nil { 36 58 return fmt.Errorf("failed to load config: %w", err) 37 59 } 38 - 39 60 if cfg.Database.DIDMethod != "plc" { 40 61 return fmt.Errorf("this command only works with did:plc (database.did_method is %q)", cfg.Database.DIDMethod) 41 62 } 42 63 43 64 ctx := context.Background() 44 - 45 - // Resolve the hold's DID 46 - holdDID, err := pds.LoadOrCreateDID(ctx, pds.DIDConfig{ 47 - DID: cfg.Database.DID, 48 - DIDMethod: cfg.Database.DIDMethod, 49 - PublicURL: cfg.Server.PublicURL, 50 - DBPath: cfg.Database.Path, 51 - SigningKeyPath: cfg.Database.KeyPath, 52 - RotationKey: cfg.Database.RotationKey, 53 - PLCDirectoryURL: cfg.Database.PLCDirectoryURL, 54 - }) 65 + holdDID, rotationKey, signingKey, err := loadHoldPLCIdentity(ctx, cfg) 55 66 if err != nil { 56 - return fmt.Errorf("failed to resolve hold DID: %w", err) 67 + return err 57 68 } 58 69 59 - // Parse the rotation key from config (required for signing PLC updates) 60 - if cfg.Database.RotationKey == "" { 61 - return fmt.Errorf("database.rotation_key must be set to sign PLC updates") 70 + var newKey atcrypto.PrivateKeyExportable 71 + if len(args) == 1 { 72 + newKey, err = atcrypto.ParsePrivateMultibase(args[0]) 73 + if err != nil { 74 + return fmt.Errorf("failed to parse key argument: %w", err) 75 + } 62 76 } 63 - rotationKey, err := atcrypto.ParsePrivateMultibase(cfg.Database.RotationKey) 77 + 78 + res, err := did.AddRotationKey(ctx, did.AddRotationKeyOptions{ 79 + DID: holdDID, 80 + PLCDirectoryURL: cfg.Database.PLCDirectoryURL, 81 + RotationKey: rotationKey, 82 + SigningKey: signingKey, 83 + VerificationKeyName: "atproto", 84 + NewKey: newKey, 85 + Prepend: prepend, 86 + }) 64 87 if err != nil { 65 - return fmt.Errorf("failed to parse rotation_key from config: %w", err) 88 + return err 66 89 } 67 90 68 - // Parse the new key to add (K-256 or P-256) 69 - newKey, err := atcrypto.ParsePrivateMultibase(args[0]) 70 - if err != nil { 71 - return fmt.Errorf("failed to parse key argument: %w", err) 91 + if res.AlreadyPresent { 92 + fmt.Printf("Key %s is already a rotation key for %s (priority %d of %d)\n", 93 + res.NewKeyDIDKey, holdDID, res.ExistingAt, res.TotalKeys) 94 + return nil 72 95 } 73 - newKeyPub, err := newKey.PublicKey() 74 - if err != nil { 75 - return fmt.Errorf("failed to get public key from argument: %w", err) 76 - } 77 - newKeyDIDKey := newKeyPub.DIDKey() 78 96 79 - // Load signing key for verification methods 80 - keyPath := cfg.Database.KeyPath 81 - if keyPath == "" { 82 - keyPath = cfg.Database.Path + "/signing.key" 83 - } 84 - signingKey, err := oauth.GenerateOrLoadPDSKey(keyPath) 85 - if err != nil { 86 - return fmt.Errorf("failed to load signing key: %w", err) 97 + if res.Generated { 98 + fmt.Println("=========================================================================") 99 + fmt.Println("GENERATED NEW ROTATION KEY. SAVE THIS NOW. IT WILL NOT BE SHOWN AGAIN.") 100 + fmt.Println("Store it offline (password manager, paper, hardware token).") 101 + fmt.Println() 102 + fmt.Printf("Private key (multibase): %s\n", res.NewKey.Multibase()) 103 + fmt.Printf("Public key (did:key): %s\n", res.NewKeyDIDKey) 104 + fmt.Println("=========================================================================") 87 105 } 88 106 89 - // Fetch current PLC state 90 - plcDirectoryURL := cfg.Database.PLCDirectoryURL 91 - if plcDirectoryURL == "" { 92 - plcDirectoryURL = "https://plc.directory" 93 - } 94 - client := &didplc.Client{DirectoryURL: plcDirectoryURL} 107 + slog.Info("Added rotation key to PLC identity", 108 + "did", holdDID, 109 + "new_key", res.NewKeyDIDKey, 110 + "priority", res.InsertedAt, 111 + "total_rotation_keys", res.TotalKeys, 112 + "generated", res.Generated, 113 + ) 114 + fmt.Printf("Added rotation key %s to %s (priority %d of %d)\n", 115 + res.NewKeyDIDKey, holdDID, res.InsertedAt, res.TotalKeys) 116 + return nil 117 + }, 118 + } 119 + 120 + var plcListRotationKeysCmd = &cobra.Command{ 121 + Use: "list-rotation-keys", 122 + Short: "List rotation keys in this hold's PLC document", 123 + Long: `Fetch the hold's did:plc document from the PLC directory and print its 124 + rotation keys in priority order (index 0 is highest priority and can override 125 + ops signed by lower-priority keys within PLC's 72-hour recovery window). 95 126 96 - opLog, err := client.OpLog(ctx, holdDID) 127 + The key matching the local database.rotation_key is marked as LOCAL.`, 128 + Args: cobra.NoArgs, 129 + RunE: func(cmd *cobra.Command, args []string) error { 130 + cfg, err := hold.LoadConfig(plcConfigFile) 97 131 if err != nil { 98 - return fmt.Errorf("failed to fetch PLC op log: %w", err) 132 + return fmt.Errorf("failed to load config: %w", err) 99 133 } 100 - if len(opLog) == 0 { 101 - return fmt.Errorf("empty op log for %s", holdDID) 134 + if cfg.Database.DIDMethod != "plc" { 135 + return fmt.Errorf("this command only works with did:plc (database.did_method is %q)", cfg.Database.DIDMethod) 102 136 } 103 137 104 - lastEntry := opLog[len(opLog)-1] 105 - lastOp := lastEntry.Regular 106 - if lastOp == nil { 107 - return fmt.Errorf("last PLC operation is not a regular op") 138 + ctx := context.Background() 139 + holdDID, err := did.LoadOrCreate(ctx, cfg.DIDConfig()) 140 + if err != nil { 141 + return fmt.Errorf("failed to resolve hold DID: %w", err) 108 142 } 109 143 110 - // Check if key already present 111 - for _, k := range lastOp.RotationKeys { 112 - if k == newKeyDIDKey { 113 - fmt.Printf("Key %s is already a rotation key for %s\n", newKeyDIDKey, holdDID) 114 - return nil 144 + var localRotationKey atcrypto.PrivateKey 145 + if cfg.Database.RotationKey != "" { 146 + localRotationKey, err = atcrypto.ParsePrivateMultibase(cfg.Database.RotationKey) 147 + if err != nil { 148 + return fmt.Errorf("failed to parse rotation_key from config: %w", err) 115 149 } 116 150 } 117 151 118 - // Build updated rotation keys: keep existing, append new 119 - rotationKeys := make([]string, len(lastOp.RotationKeys)) 120 - copy(rotationKeys, lastOp.RotationKeys) 121 - rotationKeys = append(rotationKeys, newKeyDIDKey) 122 - 123 - // Build update: preserve everything else from current state 124 - sigPub, err := signingKey.PublicKey() 152 + res, err := did.ListRotationKeys(ctx, did.ListRotationKeysOptions{ 153 + DID: holdDID, 154 + PLCDirectoryURL: cfg.Database.PLCDirectoryURL, 155 + LocalRotationKey: localRotationKey, 156 + }) 125 157 if err != nil { 126 - return fmt.Errorf("failed to get signing public key: %w", err) 158 + return err 127 159 } 160 + printRotationKeys(res) 161 + return nil 162 + }, 163 + } 128 164 129 - prevCID := lastEntry.AsOperation().CID().String() 165 + // loadHoldPLCIdentity is the shared "load DID + rotation key + signing key" helper used 166 + // by every PLC command. It enforces that database.rotation_key is set since every PLC 167 + // command needs a rotation key to either sign updates or verify the LOCAL marker. 168 + func loadHoldPLCIdentity(ctx context.Context, cfg *hold.Config) (string, atcrypto.PrivateKey, *atcrypto.PrivateKeyK256, error) { 169 + holdDID, err := did.LoadOrCreate(ctx, cfg.DIDConfig()) 170 + if err != nil { 171 + return "", nil, nil, fmt.Errorf("failed to resolve hold DID: %w", err) 172 + } 130 173 131 - op := &didplc.RegularOp{ 132 - Type: "plc_operation", 133 - RotationKeys: rotationKeys, 134 - VerificationMethods: map[string]string{ 135 - "atproto": sigPub.DIDKey(), 136 - }, 137 - AlsoKnownAs: lastOp.AlsoKnownAs, 138 - Services: lastOp.Services, 139 - Prev: &prevCID, 140 - } 174 + if cfg.Database.RotationKey == "" { 175 + return "", nil, nil, fmt.Errorf("database.rotation_key must be set to sign PLC updates") 176 + } 177 + rotationKey, err := atcrypto.ParsePrivateMultibase(cfg.Database.RotationKey) 178 + if err != nil { 179 + return "", nil, nil, fmt.Errorf("failed to parse rotation_key from config: %w", err) 180 + } 141 181 142 - if err := op.Sign(rotationKey); err != nil { 143 - return fmt.Errorf("failed to sign PLC update: %w", err) 144 - } 182 + keyPath := cfg.Database.KeyPath 183 + if keyPath == "" { 184 + keyPath = cfg.Database.Path + "/signing.key" 185 + } 186 + signingKey, err := oauth.GenerateOrLoadPDSKey(keyPath) 187 + if err != nil { 188 + return "", nil, nil, fmt.Errorf("failed to load signing key: %w", err) 189 + } 190 + return holdDID, rotationKey, signingKey, nil 191 + } 145 192 146 - if err := client.Submit(ctx, holdDID, op); err != nil { 147 - return fmt.Errorf("failed to submit PLC update: %w", err) 193 + // printRotationKeys is the shared CLI output for `list-rotation-keys`. 194 + func printRotationKeys(res *did.ListRotationKeysResult) { 195 + fmt.Printf("DID: %s\n", res.DID) 196 + fmt.Printf("PLC directory: %s\n", res.Directory) 197 + fmt.Printf("Rotation keys (%d):\n", len(res.Keys)) 198 + for i, k := range res.Keys { 199 + marker := "" 200 + switch { 201 + case len(res.Keys) == 1: 202 + marker = "(only key)" 203 + case i == 0: 204 + marker = "(highest priority)" 205 + case i == len(res.Keys)-1: 206 + marker = "(lowest priority)" 148 207 } 208 + localTag := "" 209 + if res.LocalDIDKey != "" && k == res.LocalDIDKey { 210 + localTag = " [LOCAL — database.rotation_key]" 211 + } 212 + fmt.Printf(" [%d] %s %s%s\n", i, k, marker, localTag) 213 + } 149 214 150 - slog.Info("Added rotation key to PLC identity", 151 - "did", holdDID, 152 - "new_key", newKeyDIDKey, 153 - "total_rotation_keys", len(rotationKeys), 154 - ) 155 - fmt.Printf("Added rotation key %s to %s\n", newKeyDIDKey, holdDID) 156 - return nil 157 - }, 215 + if res.LocalDIDKey != "" && !res.LocalPresent { 216 + fmt.Printf("\nWARNING: local rotation_key (%s) is NOT present in the PLC document.\n", res.LocalDIDKey) 217 + fmt.Println("This service cannot sign PLC updates. Possible compromise or out-of-band rotation.") 218 + } 158 219 } 159 220 160 221 func init() { 161 222 plcCmd.PersistentFlags().StringVarP(&plcConfigFile, "config", "c", "", "path to YAML configuration file") 162 223 224 + plcAddRotationKeyCmd.Flags().BoolVar(&plcAddRotationKeyFirst, "first", true, "insert at highest priority (default)") 225 + plcAddRotationKeyCmd.Flags().BoolVar(&plcAddRotationKeyLast, "last", false, "insert at lowest priority") 226 + 163 227 plcCmd.AddCommand(plcAddRotationKeyCmd) 228 + plcCmd.AddCommand(plcListRotationKeysCmd) 164 229 }
+3 -10
cmd/hold/repo.go
··· 6 6 "log/slog" 7 7 "os" 8 8 9 + "atcr.io/pkg/atproto/did" 9 10 "atcr.io/pkg/hold" 10 11 holddb "atcr.io/pkg/hold/db" 11 12 "atcr.io/pkg/hold/pds" ··· 39 40 } 40 41 defer cleanup() 41 42 42 - if err := holdPDS.ExportToCAR(ctx, os.Stdout); err != nil { 43 + if err := holdPDS.RepomgrRef().ReadRepo(ctx, holdPDS.UID(), "", os.Stdout); err != nil { 43 44 return fmt.Errorf("failed to export: %w", err) 44 45 } 45 46 ··· 105 106 // openHoldPDS creates a HoldPDS from config for offline CLI operations. 106 107 // Returns the PDS and a cleanup function that must be deferred. 107 108 func openHoldPDS(ctx context.Context, cfg *hold.Config) (*pds.HoldPDS, func(), error) { 108 - holdDID, err := pds.LoadOrCreateDID(ctx, pds.DIDConfig{ 109 - DID: cfg.Database.DID, 110 - DIDMethod: cfg.Database.DIDMethod, 111 - PublicURL: cfg.Server.PublicURL, 112 - DBPath: cfg.Database.Path, 113 - SigningKeyPath: cfg.Database.KeyPath, 114 - RotationKey: cfg.Database.RotationKey, 115 - PLCDirectoryURL: cfg.Database.PLCDirectoryURL, 116 - }) 109 + holdDID, err := did.LoadOrCreate(ctx, cfg.DIDConfig()) 117 110 if err != nil { 118 111 return nil, nil, fmt.Errorf("failed to resolve hold DID: %w", err) 119 112 }
+1
cmd/labeler/main.go
··· 73 73 74 74 rootCmd.AddCommand(serveCmd) 75 75 rootCmd.AddCommand(configCmd) 76 + rootCmd.AddCommand(plcCmd) 76 77 } 77 78 78 79 func main() {
+225
cmd/labeler/plc.go
··· 1 + package main 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "log/slog" 7 + 8 + "atcr.io/pkg/atproto/did" 9 + "atcr.io/pkg/auth/oauth" 10 + "atcr.io/pkg/labeler" 11 + 12 + "github.com/bluesky-social/indigo/atproto/atcrypto" 13 + "github.com/spf13/cobra" 14 + ) 15 + 16 + var plcCmd = &cobra.Command{ 17 + Use: "plc", 18 + Short: "PLC directory management commands", 19 + } 20 + 21 + var plcConfigFile string 22 + 23 + var ( 24 + plcAddRotationKeyFirst bool 25 + plcAddRotationKeyLast bool 26 + ) 27 + 28 + var plcAddRotationKeyCmd = &cobra.Command{ 29 + Use: "add-rotation-key [multibase-key]", 30 + Short: "Add a rotation key to this labeler's PLC identity", 31 + Long: `Add an additional rotation key to the labeler's did:plc document. 32 + 33 + If a multibase-encoded private key (K-256 or P-256, starting with 'z') is supplied as 34 + the positional argument, that key is added. If no argument is given, a fresh K-256 35 + keypair is generated and the private half is printed to stdout. Save it offline as 36 + your recovery key, since it will not be shown again. 37 + 38 + By default the new key is inserted at the highest priority position (--first), which 39 + allows it to override ops signed by lower-priority keys within PLC's 72-hour recovery 40 + window. Pass --last to append at the lowest priority instead. 41 + 42 + The labeler's configured rotation key is used to sign the PLC update. 43 + 44 + atcr-labeler plc add-rotation-key --config config.yaml # generate + print 45 + atcr-labeler plc add-rotation-key --config config.yaml --last # append, low priority 46 + atcr-labeler plc add-rotation-key --config config.yaml z... # use supplied key`, 47 + Args: cobra.MaximumNArgs(1), 48 + RunE: func(cmd *cobra.Command, args []string) error { 49 + firstSet := cmd.Flags().Changed("first") 50 + lastSet := cmd.Flags().Changed("last") 51 + if firstSet && lastSet { 52 + return fmt.Errorf("--first and --last are mutually exclusive") 53 + } 54 + prepend := !plcAddRotationKeyLast 55 + 56 + cfg, err := labeler.LoadConfig(plcConfigFile) 57 + if err != nil { 58 + return fmt.Errorf("failed to load config: %w", err) 59 + } 60 + if cfg.Labeler.DIDMethod != "plc" { 61 + return fmt.Errorf("this command only works with did:plc (labeler.did_method is %q)", cfg.Labeler.DIDMethod) 62 + } 63 + 64 + ctx := context.Background() 65 + labelerDID, rotationKey, signingKey, err := loadLabelerPLCIdentity(ctx, cfg) 66 + if err != nil { 67 + return err 68 + } 69 + 70 + var newKey atcrypto.PrivateKeyExportable 71 + if len(args) == 1 { 72 + newKey, err = atcrypto.ParsePrivateMultibase(args[0]) 73 + if err != nil { 74 + return fmt.Errorf("failed to parse key argument: %w", err) 75 + } 76 + } 77 + 78 + res, err := did.AddRotationKey(ctx, did.AddRotationKeyOptions{ 79 + DID: labelerDID, 80 + PLCDirectoryURL: cfg.PLCDirectoryURL(), 81 + RotationKey: rotationKey, 82 + SigningKey: signingKey, 83 + VerificationKeyName: "atproto_label", 84 + NewKey: newKey, 85 + Prepend: prepend, 86 + }) 87 + if err != nil { 88 + return err 89 + } 90 + 91 + if res.AlreadyPresent { 92 + fmt.Printf("Key %s is already a rotation key for %s (priority %d of %d)\n", 93 + res.NewKeyDIDKey, labelerDID, res.ExistingAt, res.TotalKeys) 94 + return nil 95 + } 96 + 97 + if res.Generated { 98 + fmt.Println("=========================================================================") 99 + fmt.Println("GENERATED NEW ROTATION KEY. SAVE THIS NOW. IT WILL NOT BE SHOWN AGAIN.") 100 + fmt.Println("Store it offline (password manager, paper, hardware token).") 101 + fmt.Println() 102 + fmt.Printf("Private key (multibase): %s\n", res.NewKey.Multibase()) 103 + fmt.Printf("Public key (did:key): %s\n", res.NewKeyDIDKey) 104 + fmt.Println("=========================================================================") 105 + } 106 + 107 + slog.Info("Added rotation key to PLC identity", 108 + "did", labelerDID, 109 + "new_key", res.NewKeyDIDKey, 110 + "priority", res.InsertedAt, 111 + "total_rotation_keys", res.TotalKeys, 112 + "generated", res.Generated, 113 + ) 114 + fmt.Printf("Added rotation key %s to %s (priority %d of %d)\n", 115 + res.NewKeyDIDKey, labelerDID, res.InsertedAt, res.TotalKeys) 116 + return nil 117 + }, 118 + } 119 + 120 + var plcListRotationKeysCmd = &cobra.Command{ 121 + Use: "list-rotation-keys", 122 + Short: "List rotation keys in this labeler's PLC document", 123 + Long: `Fetch the labeler's did:plc document from the PLC directory and print its 124 + rotation keys in priority order (index 0 is highest priority and can override 125 + ops signed by lower-priority keys within PLC's 72-hour recovery window). 126 + 127 + The key matching the local labeler.rotation_key is marked as LOCAL.`, 128 + Args: cobra.NoArgs, 129 + RunE: func(cmd *cobra.Command, args []string) error { 130 + cfg, err := labeler.LoadConfig(plcConfigFile) 131 + if err != nil { 132 + return fmt.Errorf("failed to load config: %w", err) 133 + } 134 + if cfg.Labeler.DIDMethod != "plc" { 135 + return fmt.Errorf("this command only works with did:plc (labeler.did_method is %q)", cfg.Labeler.DIDMethod) 136 + } 137 + 138 + ctx := context.Background() 139 + labelerDID, _, _, err := loadLabelerPLCIdentity(ctx, cfg) 140 + if err != nil { 141 + return err 142 + } 143 + 144 + var localRotationKey atcrypto.PrivateKey 145 + if cfg.Labeler.RotationKey != "" { 146 + localRotationKey, err = atcrypto.ParsePrivateMultibase(cfg.Labeler.RotationKey) 147 + if err != nil { 148 + return fmt.Errorf("failed to parse rotation_key from config: %w", err) 149 + } 150 + } 151 + 152 + res, err := did.ListRotationKeys(ctx, did.ListRotationKeysOptions{ 153 + DID: labelerDID, 154 + PLCDirectoryURL: cfg.PLCDirectoryURL(), 155 + LocalRotationKey: localRotationKey, 156 + }) 157 + if err != nil { 158 + return err 159 + } 160 + printRotationKeys(res) 161 + return nil 162 + }, 163 + } 164 + 165 + // loadLabelerPLCIdentity is the shared "load DID + rotation key + signing key" helper 166 + // used by every PLC command. Mirrors loadHoldPLCIdentity over in cmd/hold/plc.go. 167 + func loadLabelerPLCIdentity(ctx context.Context, cfg *labeler.Config) (string, atcrypto.PrivateKey, *atcrypto.PrivateKeyK256, error) { 168 + labelerDID, _, err := labeler.LoadIdentity(ctx, cfg) 169 + if err != nil { 170 + return "", nil, nil, err 171 + } 172 + 173 + if cfg.Labeler.RotationKey == "" { 174 + return "", nil, nil, fmt.Errorf("labeler.rotation_key must be set to sign PLC updates") 175 + } 176 + rotationKey, err := atcrypto.ParsePrivateMultibase(cfg.Labeler.RotationKey) 177 + if err != nil { 178 + return "", nil, nil, fmt.Errorf("failed to parse rotation_key from config: %w", err) 179 + } 180 + 181 + signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath()) 182 + if err != nil { 183 + return "", nil, nil, fmt.Errorf("failed to load signing key: %w", err) 184 + } 185 + return labelerDID, rotationKey, signingKey, nil 186 + } 187 + 188 + // printRotationKeys is the shared CLI output for list-rotation-keys, kept identical to 189 + // the hold version since the formatting is service-agnostic. 190 + func printRotationKeys(res *did.ListRotationKeysResult) { 191 + fmt.Printf("DID: %s\n", res.DID) 192 + fmt.Printf("PLC directory: %s\n", res.Directory) 193 + fmt.Printf("Rotation keys (%d):\n", len(res.Keys)) 194 + for i, k := range res.Keys { 195 + marker := "" 196 + switch { 197 + case len(res.Keys) == 1: 198 + marker = "(only key)" 199 + case i == 0: 200 + marker = "(highest priority)" 201 + case i == len(res.Keys)-1: 202 + marker = "(lowest priority)" 203 + } 204 + localTag := "" 205 + if res.LocalDIDKey != "" && k == res.LocalDIDKey { 206 + localTag = " [LOCAL — labeler.rotation_key]" 207 + } 208 + fmt.Printf(" [%d] %s %s%s\n", i, k, marker, localTag) 209 + } 210 + 211 + if res.LocalDIDKey != "" && !res.LocalPresent { 212 + fmt.Printf("\nWARNING: local rotation_key (%s) is NOT present in the PLC document.\n", res.LocalDIDKey) 213 + fmt.Println("This service cannot sign PLC updates. Possible compromise or out-of-band rotation.") 214 + } 215 + } 216 + 217 + func init() { 218 + plcCmd.PersistentFlags().StringVarP(&plcConfigFile, "config", "c", "", "path to YAML configuration file") 219 + 220 + plcAddRotationKeyCmd.Flags().BoolVar(&plcAddRotationKeyFirst, "first", true, "insert at highest priority (default)") 221 + plcAddRotationKeyCmd.Flags().BoolVar(&plcAddRotationKeyLast, "last", false, "insert at lowest priority") 222 + 223 + plcCmd.AddCommand(plcAddRotationKeyCmd) 224 + plcCmd.AddCommand(plcListRotationKeysCmd) 225 + }
+55
config-labeler.example.yaml
··· 1 + # ATCR Labeler Configuration 2 + # Generated with defaults — edit as needed. 3 + 4 + # Configuration format version. 5 + version: "0.1" 6 + # Log level: debug, info, warn, error. 7 + log_level: info 8 + # Labeler service settings. 9 + labeler: 10 + # Enable the labeler service. 11 + enabled: true 12 + # Listen address for labeler (e.g., :5002). 13 + addr: :5002 14 + # Externally reachable labeler URL. Empty = derive from server.base_url. 15 + public_url: "" 16 + # DID of the labeler admin. Only this DID can log into the admin panel. 17 + owner_did: did:plc:your-did-here 18 + # Directory for labeler state (database, signing key, did.txt). 19 + data_dir: /var/lib/atcr-labeler 20 + # DID method: "plc" (recommended) or "web". 21 + did_method: plc 22 + # Explicit did:plc identifier for adoption/recovery (optional). 23 + did: "" 24 + # Path to K-256 signing key (defaults to <data_dir>/signing.key). 25 + key_path: "" 26 + # Multibase-encoded rotation key (K-256 or P-256). Required to update the PLC document. 27 + rotation_key: "" 28 + # PLC directory URL (default https://plc.directory). 29 + plc_directory_url: https://plc.directory 30 + # Optional libSQL/Bunny remote sync URL. Empty = local-only. 31 + libsql_sync_url: "" 32 + # Auth token for libsql_sync_url. 33 + libsql_auth_token: "" 34 + # Embedded-replica pull interval (e.g. 30s). 0 = manual sync only. 35 + libsql_sync_interval: 0s 36 + # AppView server settings (shared config). 37 + server: 38 + base_url: https://atcr.io 39 + client_name: AT Container Registry 40 + client_short_name: ATCR 41 + test_mode: false 42 + # Remote log shipping settings. 43 + log_shipper: 44 + # Log shipping backend: "victoria", "opensearch", or "loki". Empty disables shipping. 45 + backend: "" 46 + # Remote log service endpoint, e.g. "http://victorialogs:9428". 47 + url: "" 48 + # Number of log entries to buffer before flushing to the remote service. 49 + batch_size: 0 50 + # Maximum time between flushes, even if batch is not full. 51 + flush_interval: 0s 52 + # Basic auth username for the log service (optional). 53 + username: "" 54 + # Basic auth password for the log service (optional). 55 + password: ""
+6 -1
deploy/upcloud/configs/labeler.yaml.tmpl
··· 11 11 enabled: true 12 12 addr: :5002 13 13 owner_did: "" 14 - db_path: "{{.BasePath}}/labeler/labeler.db" 14 + data_dir: "{{.BasePath}}/labeler" 15 + did_method: plc 16 + did: "" 17 + key_path: "" 18 + rotation_key: "" 19 + plc_directory_url: https://plc.directory 15 20 server: 16 21 base_url: "https://seamark.dev" 17 22 client_name: Seamark
+50
docker-compose.yml
··· 19 19 # ATCR_SERVER_CLIENT_SHORT_NAME: "Seamark" 20 20 ATCR_SERVER_MANAGED_HOLDS: did:web:172.28.0.3%3A8080 21 21 ATCR_SERVER_DEFAULT_HOLD_DID: did:web:172.28.0.3%3A8080 22 + # Labeler URL (HTTP for dev — ParseLabelerURL accepts it directly so we don't 23 + # have to round-trip through did:web → https:// resolution). 24 + ATCR_LABELER_DID: http://172.28.0.4:5002 22 25 ATCR_SERVER_TEST_MODE: true 23 26 ATCR_LOG_LEVEL: debug 24 27 LOG_SHIPPER_BACKEND: victoria ··· 97 100 atcr-network: 98 101 ipv4_address: 172.28.0.3 99 102 103 + atcr-labeler: 104 + # Base config: config-labeler.example.yaml (passed via Air entrypoint). 105 + # Env vars below override config file values for local dev. 106 + # 107 + # Why did:web for dev: did:plc would submit a real PLC operation to plc.directory 108 + # for every fresh dev environment, polluting production with throwaway DIDs that 109 + # point at 172.28.0.x. did:web is purely self-served via /.well-known/did.json so 110 + # nothing leaks. Switch to plc + a real public_url for production. 111 + environment: 112 + LABELER_LABELER_DID_METHOD: web 113 + LABELER_LABELER_PUBLIC_URL: http://172.28.0.4:5002 114 + LABELER_LABELER_OWNER_DID: did:plc:pddp4xt5lgnv2qsegbzzs4xg 115 + LABELER_LABELER_DATA_DIR: /var/lib/atcr-labeler 116 + LABELER_SERVER_TEST_MODE: true 117 + LABELER_LOG_LEVEL: debug 118 + LOG_SHIPPER_BACKEND: victoria 119 + LOG_SHIPPER_URL: http://172.28.0.10:9428 120 + logging: 121 + driver: json-file 122 + options: 123 + max-size: "10m" 124 + max-file: "1" 125 + build: 126 + context: . 127 + dockerfile: Dockerfile.dev 128 + args: 129 + AIR_CONFIG: .air.labeler.toml 130 + image: atcr-labeler-dev:latest 131 + container_name: atcr-labeler 132 + ports: 133 + - "5002:5002" 134 + volumes: 135 + # Mount source code for Air hot reload 136 + - .:/app:z 137 + - go-mod-cache:/go/pkg/mod 138 + # Persist signing key + did.txt + label database across container restarts so 139 + # dev signatures stay verifiable. Wipe with `docker compose down -v` to reset. 140 + - atcr-labeler:/var/lib/atcr-labeler 141 + restart: unless-stopped 142 + dns: 143 + - 8.8.8.8 144 + - 1.1.1.1 145 + networks: 146 + atcr-network: 147 + ipv4_address: 172.28.0.4 148 + 100 149 # Victoria Logs for centralized log storage 101 150 # Uncomment to enable, then set LOG_SHIPPER_* env vars above 102 151 victorialogs: ··· 123 172 124 173 volumes: 125 174 atcr-hold: 175 + atcr-labeler: 126 176 atcr-auth: 127 177 atcr-ui: 128 178 go-mod-cache:
pkg/appview/db/migrations/0017_create_labels.yaml pkg/appview/db/migrations/0023_create_labels.yaml
+28 -15
pkg/appview/db/queries.go
··· 16 16 return fmt.Sprintf("https://imgs.blue/%s/%s", did, cid) 17 17 } 18 18 19 + // activeTakedownClause returns a SQL fragment ready to drop into a `WHERE NOT 20 + // EXISTS (...)` filter for excluding rows whose `(did, repository)` pair is currently 21 + // taken down. The `alias` argument is the outer table alias (e.g. "m" for manifests, 22 + // "lm" for latest_manifests) and must already be in scope at the use site. 23 + // 24 + // Mirrors the semantics of `IsTakenDown` (defined in labels.go) so listings stay 25 + // consistent with the per-repo page check: a label only counts as active when it has 26 + // neg=0, no newer neg=1 row with the same (src, uri, val), and a non-expired `exp`. 27 + // Without these clauses listings hide a repo forever once you've ever taken it down, 28 + // even after a reversal. 29 + func activeTakedownClause(alias string) string { 30 + return `NOT EXISTS ( 31 + SELECT 1 FROM labels l1 32 + WHERE l1.subject_did = ` + alias + `.did 33 + AND (l1.subject_repo = ` + alias + `.repository OR l1.subject_repo = '') 34 + AND l1.val = '!takedown' AND l1.neg = 0 35 + AND NOT EXISTS ( 36 + SELECT 1 FROM labels l2 37 + WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val 38 + AND l2.neg = 1 AND l2.id > l1.id 39 + ) 40 + AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP) 41 + )` 42 + } 43 + 19 44 // accessibleHoldsSubquery returns SQL that evaluates to the set of hold DIDs 20 45 // the viewer is allowed to see in listings. Requires the viewerDID to be 21 46 // passed twice as query arguments (once for the owner_did check and once ··· 107 132 WHERE ra.did = lm.did AND ra.repository = lm.repository 108 133 AND ra.value LIKE ? ESCAPE '\' 109 134 )) 110 - AND NOT EXISTS ( 111 - SELECT 1 FROM labels 112 - WHERE (subject_did = lm.did AND (subject_repo = lm.repository OR subject_repo = '')) 113 - AND val = '!takedown' AND neg = 0 114 - ) 135 + AND ` + activeTakedownClause("lm") + ` 115 136 ), 116 137 repo_stats AS ( 117 138 SELECT ··· 2122 2143 JOIN users u ON m.did = u.did 2123 2144 LEFT JOIN repository_stats rs ON m.did = rs.did AND m.repository = rs.repository 2124 2145 LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository 2125 - WHERE NOT EXISTS ( 2126 - SELECT 1 FROM labels 2127 - WHERE (subject_did = m.did AND (subject_repo = m.repository OR subject_repo = '')) 2128 - AND val = '!takedown' AND neg = 0 2129 - ) 2146 + WHERE ` + activeTakedownClause("m") + ` 2130 2147 ORDER BY ` + orderBy + ` 2131 2148 LIMIT ? 2132 2149 ` ··· 2205 2222 JOIN users u ON m.did = u.did 2206 2223 LEFT JOIN repository_stats rs ON m.did = rs.did AND m.repository = rs.repository 2207 2224 LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository 2208 - WHERE NOT EXISTS ( 2209 - SELECT 1 FROM labels 2210 - WHERE (subject_did = m.did AND (subject_repo = m.repository OR subject_repo = '')) 2211 - AND val = '!takedown' AND neg = 0 2212 - ) 2225 + WHERE ` + activeTakedownClause("m") + ` 2213 2226 ORDER BY MAX(rs.last_push, m.created_at) DESC 2214 2227 ` 2215 2228
+8 -9
pkg/appview/db/readonly.go
··· 39 39 } else { 40 40 roDSN += "?mode=ro" 41 41 } 42 - readOnlyDB, err := sql.Open("libsql", roDSN) 42 + // Wrap with busyTimeoutConnector so every pooled read-only connection 43 + // gets PRAGMA busy_timeout. Without this, reads return SQLITE_BUSY 44 + // immediately when a write is in progress on the read-write connection 45 + // (busy_timeout is per-connection, so a one-shot PRAGMA only configures 46 + // whichever conn served it). 47 + roBase, err := openLibsqlLocalConnector(roDSN) 43 48 if err != nil { 44 - slog.Warn("Failed to open read-only database connection", "error", err) 49 + slog.Warn("Failed to open read-only database connector", "error", err) 45 50 return nil, nil, nil 46 51 } 47 - 48 - // busy_timeout is per-connection — without this, reads return SQLITE_BUSY 49 - // immediately when a write is in progress on the read-write connection. 50 - var busyTimeout int 51 - if err := readOnlyDB.QueryRow("PRAGMA busy_timeout = 5000").Scan(&busyTimeout); err != nil { 52 - slog.Warn("Failed to set busy_timeout on read-only connection", "error", err) 53 - } 52 + readOnlyDB := sql.OpenDB(&busyTimeoutConnector{base: roBase, timeoutMs: 5000}) 54 53 55 54 slog.Info("UI database initialized", "mode", "readonly", "path", dbPath) 56 55
+71 -15
pkg/appview/db/schema.go
··· 5 5 package db 6 6 7 7 import ( 8 + "context" 8 9 "database/sql" 10 + "database/sql/driver" 9 11 "embed" 10 12 "fmt" 11 13 "io/fs" ··· 55 57 db = sql.OpenDB(connector) 56 58 slog.Info("Database opened in embedded replica mode", "path", path, "sync_url", cfg.SyncURL) 57 59 } else { 58 - // Local-only mode: plain file via libsql driver 59 - // Paths starting with "file:" or ":memory:" are already valid libsql URIs 60 + // Local-only mode: plain file via libsql driver, wrapped so every new 61 + // connection gets PRAGMA busy_timeout. SQLite's busy_timeout is 62 + // per-connection, so a one-shot db.Exec only configures whichever 63 + // pooled conn served the call — leaving the rest to fail SQLITE_BUSY 64 + // instantly on any write contention with the jetstream/backfill workers. 65 + // Paths starting with "file:" or ":memory:" are already valid libsql URIs. 60 66 dsn := path 61 67 if !strings.HasPrefix(path, "file:") && !strings.HasPrefix(path, ":memory:") { 62 68 dsn = "file:" + path 63 69 } 64 - var err error 65 - db, err = sql.Open("libsql", dsn) 70 + baseConnector, err := openLibsqlLocalConnector(dsn) 66 71 if err != nil { 67 72 return nil, err 68 73 } 74 + db = sql.OpenDB(&busyTimeoutConnector{base: baseConnector, timeoutMs: 5000}) 69 75 slog.Info("Database opened in local-only mode", "path", path) 70 76 } 71 77 72 - // In local-only mode, configure WAL and busy_timeout locally. 73 - // In embedded replica mode, the remote server manages these settings 74 - // and PRAGMA assignments are rejected as "unsupported statement" 75 - // (observed with Bunny Database; Turso may behave similarly). 78 + // In local-only mode, set WAL mode (database-wide setting, persists 79 + // across connections — single call is sufficient unlike busy_timeout). 80 + // In embedded replica mode, the remote server manages this and the 81 + // PRAGMA is rejected as "unsupported statement" (observed with Bunny; 82 + // Turso may behave similarly). 76 83 if cfg.SyncURL == "" { 77 - // Enable WAL mode for concurrent read/write access 78 84 var journalMode string 79 85 if err := db.QueryRow("PRAGMA journal_mode = WAL").Scan(&journalMode); err != nil { 80 - return nil, err 81 - } 82 - 83 - // Retry on lock instead of failing immediately (5s timeout) 84 - var busyTimeout int 85 - if err := db.QueryRow("PRAGMA busy_timeout = 5000").Scan(&busyTimeout); err != nil { 86 86 return nil, err 87 87 } 88 88 } ··· 377 377 378 378 return version, name, nil 379 379 } 380 + 381 + // openLibsqlLocalConnector returns a driver.Connector for a local libsql DSN. 382 + // go-libsql exports NewEmbeddedReplicaConnector for replica mode but no public 383 + // constructor for local files, so we obtain the driver via a probe sql.Open 384 + // (which is lazy and opens no connection) and ask it for a Connector. 385 + func openLibsqlLocalConnector(dsn string) (driver.Connector, error) { 386 + probe, err := sql.Open("libsql", dsn) 387 + if err != nil { 388 + return nil, fmt.Errorf("probe libsql driver: %w", err) 389 + } 390 + drv := probe.Driver() 391 + _ = probe.Close() 392 + 393 + dctx, ok := drv.(driver.DriverContext) 394 + if !ok { 395 + return nil, fmt.Errorf("libsql driver does not implement driver.DriverContext") 396 + } 397 + return dctx.OpenConnector(dsn) 398 + } 399 + 400 + // busyTimeoutConnector wraps a driver.Connector and runs PRAGMA busy_timeout 401 + // on every newly opened connection. SQLite's busy_timeout is per-connection, 402 + // so this is the only way to ensure every conn in the pool waits on lock 403 + // contention instead of returning SQLITE_BUSY immediately. 404 + type busyTimeoutConnector struct { 405 + base driver.Connector 406 + timeoutMs int 407 + } 408 + 409 + func (c *busyTimeoutConnector) Connect(ctx context.Context) (driver.Conn, error) { 410 + conn, err := c.base.Connect(ctx) 411 + if err != nil { 412 + return nil, err 413 + } 414 + 415 + // libsql treats PRAGMA assignments as queries that return a row, so we 416 + // must use QueryerContext rather than ExecerContext. 417 + queryer, ok := conn.(driver.QueryerContext) 418 + if !ok { 419 + _ = conn.Close() 420 + return nil, fmt.Errorf("libsql conn does not support QueryerContext") 421 + } 422 + 423 + rows, err := queryer.QueryContext(ctx, fmt.Sprintf("PRAGMA busy_timeout = %d", c.timeoutMs), nil) 424 + if err != nil { 425 + _ = conn.Close() 426 + return nil, fmt.Errorf("set busy_timeout on new conn: %w", err) 427 + } 428 + _ = rows.Close() 429 + 430 + return conn, nil 431 + } 432 + 433 + func (c *busyTimeoutConnector) Driver() driver.Driver { 434 + return c.base.Driver() 435 + }
+1 -1
pkg/appview/handlers/diff.go
··· 234 234 return 235 235 } 236 236 if owner.Handle != resolvedHandle { 237 - _ = db.UpdateUserHandle(h.ReadOnlyDB, did, resolvedHandle) 237 + _ = db.UpdateUserHandle(h.DB, did, resolvedHandle) 238 238 owner.Handle = resolvedHandle 239 239 } 240 240
+1 -1
pkg/appview/handlers/digest.go
··· 105 105 return 106 106 } 107 107 if owner.Handle != resolvedHandle { 108 - _ = db.UpdateUserHandle(h.ReadOnlyDB, did, resolvedHandle) 108 + _ = db.UpdateUserHandle(h.DB, did, resolvedHandle) 109 109 owner.Handle = resolvedHandle 110 110 } 111 111
+1 -1
pkg/appview/handlers/repository.go
··· 63 63 64 64 // Opportunistically update cached handle if it changed 65 65 if owner.Handle != resolvedHandle { 66 - _ = db.UpdateUserHandle(h.ReadOnlyDB, did, resolvedHandle) 66 + _ = db.UpdateUserHandle(h.DB, did, resolvedHandle) 67 67 owner.Handle = resolvedHandle 68 68 } 69 69
+1 -1
pkg/appview/handlers/user.go
··· 44 44 } 45 45 } else if viewedUser.Handle != resolvedHandle { 46 46 // Opportunistically update cached handle if it changed 47 - _ = db.UpdateUserHandle(h.ReadOnlyDB, did, resolvedHandle) 47 + _ = db.UpdateUserHandle(h.DB, did, resolvedHandle) 48 48 viewedUser.Handle = resolvedHandle 49 49 } 50 50
+83 -38
pkg/appview/labeler/subscriber.go
··· 3 3 package labeler 4 4 5 5 import ( 6 + "bytes" 6 7 "database/sql" 7 - "encoding/json" 8 + "errors" 8 9 "fmt" 9 10 "log/slog" 10 11 "net/url" ··· 13 14 14 15 "atcr.io/pkg/appview/db" 15 16 17 + comatproto "github.com/bluesky-social/indigo/api/atproto" 18 + "github.com/bluesky-social/indigo/events" 16 19 "github.com/gorilla/websocket" 17 20 ) 18 - 19 - // LabelsMessage is the wire format for subscribeLabels events. 20 - type LabelsMessage struct { 21 - Seq int64 `json:"seq"` 22 - Labels []LabelEvent `json:"labels"` 23 - } 24 - 25 - // LabelEvent is a single label from the labeler. 26 - type LabelEvent struct { 27 - Src string `json:"src"` 28 - URI string `json:"uri"` 29 - CID string `json:"cid,omitempty"` 30 - Val string `json:"val"` 31 - Neg bool `json:"neg"` 32 - Cts string `json:"cts"` 33 - Exp string `json:"exp,omitempty"` 34 - } 35 21 36 22 // Subscriber connects to a labeler's subscribeLabels endpoint 37 23 // and mirrors labels into the appview database. ··· 121 107 default: 122 108 } 123 109 124 - var msg LabelsMessage 125 - if err := conn.ReadJSON(&msg); err != nil { 110 + mt, payload, err := conn.ReadMessage() 111 + if err != nil { 126 112 return fmt.Errorf("read error: %w", err) 127 113 } 114 + // Per the ATProto event-stream spec each frame is a binary message; reject text. 115 + if mt != websocket.BinaryMessage { 116 + slog.Warn("Ignoring non-binary frame from labeler", "type", mt) 117 + continue 118 + } 128 119 129 - for _, le := range msg.Labels { 120 + seq, labels, err := decodeFrame(payload) 121 + if err != nil { 122 + if errors.Is(err, errInfoFrame) { 123 + continue // already logged inside decodeFrame 124 + } 125 + return fmt.Errorf("decode frame: %w", err) 126 + } 127 + 128 + for _, le := range labels { 130 129 cts, _ := time.Parse(time.RFC3339, le.Cts) 131 - did, repo := extractSubjectFromURI(le.URI) 130 + did, repo := extractSubjectFromURI(le.Uri) 132 131 133 132 label := &db.Label{ 134 133 Src: le.Src, 135 - URI: le.URI, 134 + URI: le.Uri, 136 135 Val: le.Val, 137 - Neg: le.Neg, 136 + Neg: le.Neg != nil && *le.Neg, 138 137 Cts: cts, 139 138 SubjectDID: did, 140 139 SubjectRepo: repo, 141 - Seq: msg.Seq, 140 + Seq: seq, 142 141 } 143 142 144 143 if err := db.UpsertLabel(s.database, label); err != nil { 145 - slog.Warn("Failed to upsert label", "uri", le.URI, "error", err) 144 + slog.Warn("Failed to upsert label", "uri", le.Uri, "error", err) 146 145 continue 147 146 } 148 147 149 - slog.Info("Mirrored label", 150 - "uri", le.URI, 148 + // "Mirrored label X" reads as an apply; reversals are a different action 149 + // from the operator's POV (and a different SQL effect — the NOT EXISTS 150 + // negation clause kicks in), so log them distinctly. 151 + msg := "Mirrored label" 152 + if label.Neg { 153 + msg = "Mirrored label reversal" 154 + } 155 + slog.Info(msg, 156 + "uri", le.Uri, 151 157 "val", le.Val, 152 - "neg", le.Neg, 158 + "neg", label.Neg, 153 159 "subject_did", did, 154 160 "subject_repo", repo, 155 161 ) 156 162 } 163 + } 164 + } 165 + 166 + // errInfoFrame is returned by decodeFrame when the frame is informational and the 167 + // caller should just continue to the next message. 168 + var errInfoFrame = errors.New("labeler: info frame") 169 + 170 + // decodeFrame parses a single subscribeLabels binary frame. ATProto event-stream framing 171 + // is two concatenated CBOR objects: a {op,t} header and a body. We dispatch on the 172 + // header op/t pair and return the labels body for op=1, t="#labels". For #info frames 173 + // we log and signal errInfoFrame so the caller skips. Error frames (op=-1) become Go 174 + // errors so the run loop reconnects with backoff. 175 + func decodeFrame(payload []byte) (int64, []*comatproto.LabelDefs_Label, error) { 176 + r := bytes.NewReader(payload) 177 + var header events.EventHeader 178 + if err := header.UnmarshalCBOR(r); err != nil { 179 + return 0, nil, fmt.Errorf("unmarshal header: %w", err) 180 + } 181 + 182 + switch { 183 + case header.Op == events.EvtKindErrorFrame: 184 + var ef events.ErrorFrame 185 + if err := ef.UnmarshalCBOR(r); err != nil { 186 + return 0, nil, fmt.Errorf("unmarshal error frame: %w", err) 187 + } 188 + return 0, nil, fmt.Errorf("labeler error frame: %s — %s", ef.Error, ef.Message) 189 + 190 + case header.Op == events.EvtKindMessage && header.MsgType == "#labels": 191 + var body comatproto.LabelSubscribeLabels_Labels 192 + if err := body.UnmarshalCBOR(r); err != nil { 193 + return 0, nil, fmt.Errorf("unmarshal labels body: %w", err) 194 + } 195 + return body.Seq, body.Labels, nil 196 + 197 + case header.Op == events.EvtKindMessage && header.MsgType == "#info": 198 + var info comatproto.LabelSubscribeLabels_Info 199 + if err := info.UnmarshalCBOR(r); err != nil { 200 + return 0, nil, fmt.Errorf("unmarshal info body: %w", err) 201 + } 202 + message := "" 203 + if info.Message != nil { 204 + message = *info.Message 205 + } 206 + slog.Info("Labeler info frame", "name", info.Name, "message", message) 207 + return 0, nil, errInfoFrame 208 + 209 + default: 210 + return 0, nil, fmt.Errorf("unexpected frame op=%d t=%q", header.Op, header.MsgType) 157 211 } 158 212 } 159 213 ··· 228 282 labelerURL := ParseLabelerURL(labelerDIDOrURL) 229 283 return NewSubscriber(labelerURL, database) 230 284 } 231 - 232 - // DecodeLabelsFromJSON decodes a JSON-encoded labels message. 233 - func DecodeLabelsFromJSON(data []byte) (*LabelsMessage, error) { 234 - var msg LabelsMessage 235 - if err := json.Unmarshal(data, &msg); err != nil { 236 - return nil, err 237 - } 238 - return &msg, nil 239 - }
+208
pkg/atproto/did/cmd.go
··· 1 + package did 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + 7 + "github.com/bluesky-social/indigo/atproto/atcrypto" 8 + didplc "github.com/did-method-plc/go-didplc" 9 + ) 10 + 11 + // AddRotationKeyOptions configures a rotation-key insert operation. 12 + type AddRotationKeyOptions struct { 13 + // DID is the resolved did:plc identifier of the service. 14 + DID string 15 + 16 + // PLCDirectoryURL is the PLC directory endpoint (defaults to https://plc.directory if empty). 17 + PLCDirectoryURL string 18 + 19 + // RotationKey is the currently-authorized rotation key used to sign the update op. 20 + RotationKey atcrypto.PrivateKey 21 + 22 + // SigningKey is the local k256 verification key — its public part goes into the 23 + // new op's VerificationMethods so we don't accidentally drop it during the update. 24 + SigningKey *atcrypto.PrivateKeyK256 25 + 26 + // VerificationKeyName is the fragment under which SigningKey is registered 27 + // (e.g. "atproto" for a PDS, "atproto_label" for a labeler). 28 + VerificationKeyName string 29 + 30 + // NewKey is the rotation key to add. If nil, a fresh K-256 key is generated and 31 + // returned in the result so the caller can print/persist it. 32 + NewKey atcrypto.PrivateKeyExportable 33 + 34 + // Prepend places the new key at the highest priority position. When false the key 35 + // is appended at the lowest priority — only set false when the operator explicitly 36 + // asks for it. 37 + Prepend bool 38 + } 39 + 40 + // AddRotationKeyResult describes the outcome of an AddRotationKey call. 41 + type AddRotationKeyResult struct { 42 + NewKey atcrypto.PrivateKeyExportable 43 + NewKeyDIDKey string 44 + Generated bool 45 + AlreadyPresent bool 46 + ExistingAt int 47 + InsertedAt int 48 + TotalKeys int 49 + } 50 + 51 + // AddRotationKey fetches the current PLC op log, inserts NewKey (generating one if nil), 52 + // signs the update with RotationKey, and submits it. Caller is responsible for printing 53 + // the generated key material — this function returns it on the result so prints can 54 + // happen in the binary's own format. 55 + func AddRotationKey(ctx context.Context, opt AddRotationKeyOptions) (*AddRotationKeyResult, error) { 56 + if opt.DID == "" { 57 + return nil, fmt.Errorf("plc: DID is required") 58 + } 59 + if opt.RotationKey == nil { 60 + return nil, fmt.Errorf("plc: rotation key is required to sign updates") 61 + } 62 + if opt.SigningKey == nil { 63 + return nil, fmt.Errorf("plc: signing key is required (becomes verificationMethods.%s)", opt.VerificationKeyName) 64 + } 65 + if opt.VerificationKeyName == "" { 66 + return nil, fmt.Errorf("plc: VerificationKeyName is required") 67 + } 68 + 69 + directory := opt.PLCDirectoryURL 70 + if directory == "" { 71 + directory = "https://plc.directory" 72 + } 73 + client := &didplc.Client{DirectoryURL: directory} 74 + 75 + res := &AddRotationKeyResult{NewKey: opt.NewKey} 76 + if res.NewKey == nil { 77 + raw, err := atcrypto.GeneratePrivateKeyK256() 78 + if err != nil { 79 + return nil, fmt.Errorf("plc: failed to generate rotation key: %w", err) 80 + } 81 + res.NewKey = raw 82 + res.Generated = true 83 + } 84 + newPub, err := res.NewKey.PublicKey() 85 + if err != nil { 86 + return nil, fmt.Errorf("plc: failed to derive new public key: %w", err) 87 + } 88 + res.NewKeyDIDKey = newPub.DIDKey() 89 + 90 + opLog, err := client.OpLog(ctx, opt.DID) 91 + if err != nil { 92 + return nil, fmt.Errorf("plc: failed to fetch op log for %s: %w", opt.DID, err) 93 + } 94 + if len(opLog) == 0 { 95 + return nil, fmt.Errorf("plc: empty op log for %s", opt.DID) 96 + } 97 + lastEntry := opLog[len(opLog)-1] 98 + lastOp := lastEntry.Regular 99 + if lastOp == nil { 100 + return nil, fmt.Errorf("plc: last operation is not a regular op") 101 + } 102 + 103 + for i, k := range lastOp.RotationKeys { 104 + if k == res.NewKeyDIDKey { 105 + res.AlreadyPresent = true 106 + res.ExistingAt = i 107 + res.TotalKeys = len(lastOp.RotationKeys) 108 + return res, nil 109 + } 110 + } 111 + 112 + rotationKeys := make([]string, 0, len(lastOp.RotationKeys)+1) 113 + if opt.Prepend { 114 + rotationKeys = append(rotationKeys, res.NewKeyDIDKey) 115 + rotationKeys = append(rotationKeys, lastOp.RotationKeys...) 116 + res.InsertedAt = 0 117 + } else { 118 + rotationKeys = append(rotationKeys, lastOp.RotationKeys...) 119 + rotationKeys = append(rotationKeys, res.NewKeyDIDKey) 120 + res.InsertedAt = len(rotationKeys) - 1 121 + } 122 + res.TotalKeys = len(rotationKeys) 123 + 124 + sigPub, err := opt.SigningKey.PublicKey() 125 + if err != nil { 126 + return nil, fmt.Errorf("plc: failed to derive signing public key: %w", err) 127 + } 128 + prevCID := lastEntry.AsOperation().CID().String() 129 + 130 + op := &didplc.RegularOp{ 131 + Type: "plc_operation", 132 + RotationKeys: rotationKeys, 133 + VerificationMethods: map[string]string{ 134 + opt.VerificationKeyName: sigPub.DIDKey(), 135 + }, 136 + AlsoKnownAs: lastOp.AlsoKnownAs, 137 + Services: lastOp.Services, 138 + Prev: &prevCID, 139 + } 140 + if err := op.Sign(opt.RotationKey); err != nil { 141 + return nil, fmt.Errorf("plc: failed to sign update: %w", err) 142 + } 143 + if err := client.Submit(ctx, opt.DID, op); err != nil { 144 + return nil, fmt.Errorf("plc: failed to submit update: %w", err) 145 + } 146 + return res, nil 147 + } 148 + 149 + // ListRotationKeysOptions configures a list-rotation-keys read. 150 + type ListRotationKeysOptions struct { 151 + DID string 152 + PLCDirectoryURL string 153 + LocalRotationKey atcrypto.PrivateKey // optional — used to compute the LOCAL marker 154 + } 155 + 156 + // ListRotationKeysResult holds the priority-ordered rotation keys plus the local 157 + // rotation key's did:key form (if provided), so callers can mark and warn appropriately. 158 + type ListRotationKeysResult struct { 159 + DID string 160 + Directory string 161 + Keys []string 162 + LocalDIDKey string 163 + LocalPresent bool 164 + } 165 + 166 + // ListRotationKeys fetches the current PLC op and returns its rotation keys in priority order. 167 + func ListRotationKeys(ctx context.Context, opt ListRotationKeysOptions) (*ListRotationKeysResult, error) { 168 + if opt.DID == "" { 169 + return nil, fmt.Errorf("plc: DID is required") 170 + } 171 + directory := opt.PLCDirectoryURL 172 + if directory == "" { 173 + directory = "https://plc.directory" 174 + } 175 + client := &didplc.Client{DirectoryURL: directory} 176 + 177 + opLog, err := client.OpLog(ctx, opt.DID) 178 + if err != nil { 179 + return nil, fmt.Errorf("plc: failed to fetch op log for %s: %w", opt.DID, err) 180 + } 181 + if len(opLog) == 0 { 182 + return nil, fmt.Errorf("plc: empty op log for %s", opt.DID) 183 + } 184 + lastOp := opLog[len(opLog)-1].Regular 185 + if lastOp == nil { 186 + return nil, fmt.Errorf("plc: last operation is not a regular op") 187 + } 188 + 189 + res := &ListRotationKeysResult{ 190 + DID: opt.DID, 191 + Directory: directory, 192 + Keys: append([]string(nil), lastOp.RotationKeys...), 193 + } 194 + if opt.LocalRotationKey != nil { 195 + pub, err := opt.LocalRotationKey.PublicKey() 196 + if err != nil { 197 + return nil, fmt.Errorf("plc: failed to derive local rotation public key: %w", err) 198 + } 199 + res.LocalDIDKey = pub.DIDKey() 200 + for _, k := range res.Keys { 201 + if k == res.LocalDIDKey { 202 + res.LocalPresent = true 203 + break 204 + } 205 + } 206 + } 207 + return res, nil 208 + }
+304
pkg/atproto/did/cmd_test.go
··· 1 + package did 2 + 3 + import ( 4 + "context" 5 + "strings" 6 + "testing" 7 + 8 + "github.com/bluesky-social/indigo/atproto/atcrypto" 9 + ) 10 + 11 + // TestAddRotationKey_AppendNew confirms a fresh key is appended at the lowest priority 12 + // when Prepend is false. 13 + func TestAddRotationKey_AppendNew(t *testing.T) { 14 + ctx := context.Background() 15 + 16 + serverRot := generateK256(t) 17 + signing := generateK256(t) 18 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing) 19 + defer fake.Close() 20 + 21 + newKey := generateK256(t) 22 + res, err := AddRotationKey(ctx, AddRotationKeyOptions{ 23 + DID: fake.did, 24 + PLCDirectoryURL: fake.URL(), 25 + RotationKey: serverRot, 26 + SigningKey: signing, 27 + VerificationKeyName: "atproto", 28 + NewKey: newKey, 29 + Prepend: false, 30 + }) 31 + if err != nil { 32 + t.Fatalf("AddRotationKey: %v", err) 33 + } 34 + if res.AlreadyPresent { 35 + t.Fatal("AlreadyPresent should be false") 36 + } 37 + if res.Generated { 38 + t.Error("Generated should be false when NewKey provided") 39 + } 40 + if res.TotalKeys != 2 { 41 + t.Errorf("TotalKeys: got %d want 2", res.TotalKeys) 42 + } 43 + if res.InsertedAt != 1 { 44 + t.Errorf("InsertedAt: got %d want 1 (appended)", res.InsertedAt) 45 + } 46 + 47 + if len(fake.submitted) != 1 { 48 + t.Fatalf("expected one update submission, got %d", len(fake.submitted)) 49 + } 50 + got := fake.submitted[0] 51 + newPub, _ := newKey.PublicKey() 52 + if got.RotationKeys[len(got.RotationKeys)-1] != newPub.DIDKey() { 53 + t.Errorf("appended key not at last position: %v", got.RotationKeys) 54 + } 55 + } 56 + 57 + // TestAddRotationKey_Prepend confirms Prepend=true puts the new key at index 0 58 + // (highest priority position). 59 + func TestAddRotationKey_Prepend(t *testing.T) { 60 + ctx := context.Background() 61 + 62 + serverRot := generateK256(t) 63 + signing := generateK256(t) 64 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing) 65 + defer fake.Close() 66 + 67 + newKey := generateK256(t) 68 + res, err := AddRotationKey(ctx, AddRotationKeyOptions{ 69 + DID: fake.did, 70 + PLCDirectoryURL: fake.URL(), 71 + RotationKey: serverRot, 72 + SigningKey: signing, 73 + VerificationKeyName: "atproto", 74 + NewKey: newKey, 75 + Prepend: true, 76 + }) 77 + if err != nil { 78 + t.Fatalf("AddRotationKey: %v", err) 79 + } 80 + if res.InsertedAt != 0 { 81 + t.Errorf("InsertedAt: got %d want 0", res.InsertedAt) 82 + } 83 + if len(fake.submitted) != 1 { 84 + t.Fatalf("expected one update, got %d", len(fake.submitted)) 85 + } 86 + newPub, _ := newKey.PublicKey() 87 + if fake.submitted[0].RotationKeys[0] != newPub.DIDKey() { 88 + t.Errorf("prepended key not at first position: %v", fake.submitted[0].RotationKeys) 89 + } 90 + } 91 + 92 + // TestAddRotationKey_GeneratesWhenNil confirms the helper generates a fresh key 93 + // and reports it via Result. 94 + func TestAddRotationKey_GeneratesWhenNil(t *testing.T) { 95 + ctx := context.Background() 96 + 97 + serverRot := generateK256(t) 98 + signing := generateK256(t) 99 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing) 100 + defer fake.Close() 101 + 102 + res, err := AddRotationKey(ctx, AddRotationKeyOptions{ 103 + DID: fake.did, 104 + PLCDirectoryURL: fake.URL(), 105 + RotationKey: serverRot, 106 + SigningKey: signing, 107 + VerificationKeyName: "atproto", 108 + NewKey: nil, 109 + }) 110 + if err != nil { 111 + t.Fatalf("AddRotationKey: %v", err) 112 + } 113 + if !res.Generated { 114 + t.Error("Generated should be true when NewKey is nil") 115 + } 116 + if res.NewKey == nil { 117 + t.Fatal("NewKey on result should not be nil") 118 + } 119 + if res.NewKeyDIDKey == "" { 120 + t.Error("NewKeyDIDKey should be populated") 121 + } 122 + } 123 + 124 + // TestAddRotationKey_AlreadyPresent confirms a no-op when the key is already in the list. 125 + func TestAddRotationKey_AlreadyPresent(t *testing.T) { 126 + ctx := context.Background() 127 + 128 + serverRot := generateK256(t) 129 + signing := generateK256(t) 130 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing) 131 + defer fake.Close() 132 + 133 + res, err := AddRotationKey(ctx, AddRotationKeyOptions{ 134 + DID: fake.did, 135 + PLCDirectoryURL: fake.URL(), 136 + RotationKey: serverRot, 137 + SigningKey: signing, 138 + VerificationKeyName: "atproto", 139 + NewKey: serverRot, 140 + }) 141 + if err != nil { 142 + t.Fatalf("AddRotationKey: %v", err) 143 + } 144 + if !res.AlreadyPresent { 145 + t.Error("AlreadyPresent should be true") 146 + } 147 + if res.ExistingAt != 0 { 148 + t.Errorf("ExistingAt: got %d want 0", res.ExistingAt) 149 + } 150 + if len(fake.submitted) != 0 { 151 + t.Errorf("no submission expected when key already present, got %d", len(fake.submitted)) 152 + } 153 + } 154 + 155 + // TestAddRotationKey_ValidationErrors covers the early-return guard clauses in AddRotationKey. 156 + func TestAddRotationKey_ValidationErrors(t *testing.T) { 157 + ctx := context.Background() 158 + signing := generateK256(t) 159 + rot := generateK256(t) 160 + 161 + cases := []struct { 162 + name string 163 + opt AddRotationKeyOptions 164 + wantSub string 165 + }{ 166 + { 167 + name: "missing DID", 168 + opt: AddRotationKeyOptions{RotationKey: rot, SigningKey: signing, VerificationKeyName: "atproto"}, 169 + wantSub: "DID is required", 170 + }, 171 + { 172 + name: "missing rotation key", 173 + opt: AddRotationKeyOptions{DID: "did:plc:abc", SigningKey: signing, VerificationKeyName: "atproto"}, 174 + wantSub: "rotation key is required", 175 + }, 176 + { 177 + name: "missing signing key", 178 + opt: AddRotationKeyOptions{DID: "did:plc:abc", RotationKey: rot, VerificationKeyName: "atproto"}, 179 + wantSub: "signing key is required", 180 + }, 181 + { 182 + name: "missing verification key name", 183 + opt: AddRotationKeyOptions{DID: "did:plc:abc", RotationKey: rot, SigningKey: signing}, 184 + wantSub: "VerificationKeyName is required", 185 + }, 186 + } 187 + for _, tc := range cases { 188 + t.Run(tc.name, func(t *testing.T) { 189 + _, err := AddRotationKey(ctx, tc.opt) 190 + if err == nil { 191 + t.Fatal("expected error, got nil") 192 + } 193 + if !strings.Contains(err.Error(), tc.wantSub) { 194 + t.Errorf("error: got %q want substring %q", err.Error(), tc.wantSub) 195 + } 196 + }) 197 + } 198 + } 199 + 200 + // TestListRotationKeys returns the priority-ordered keys from the latest op. 201 + func TestListRotationKeys(t *testing.T) { 202 + ctx := context.Background() 203 + 204 + rot1 := generateK256(t) 205 + rot2 := generateK256(t) 206 + signing := generateK256(t) 207 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{rot1, rot2}, rot1, signing) 208 + defer fake.Close() 209 + 210 + res, err := ListRotationKeys(ctx, ListRotationKeysOptions{ 211 + DID: fake.did, 212 + PLCDirectoryURL: fake.URL(), 213 + }) 214 + if err != nil { 215 + t.Fatalf("ListRotationKeys: %v", err) 216 + } 217 + if res.DID != fake.did { 218 + t.Errorf("DID: got %s want %s", res.DID, fake.did) 219 + } 220 + if res.Directory != fake.URL() { 221 + t.Errorf("Directory: got %s want %s", res.Directory, fake.URL()) 222 + } 223 + if len(res.Keys) != 2 { 224 + t.Fatalf("Keys length: got %d want 2", len(res.Keys)) 225 + } 226 + pub1, _ := rot1.PublicKey() 227 + pub2, _ := rot2.PublicKey() 228 + if res.Keys[0] != pub1.DIDKey() { 229 + t.Errorf("Keys[0]: got %s want %s", res.Keys[0], pub1.DIDKey()) 230 + } 231 + if res.Keys[1] != pub2.DIDKey() { 232 + t.Errorf("Keys[1]: got %s want %s", res.Keys[1], pub2.DIDKey()) 233 + } 234 + if res.LocalDIDKey != "" { 235 + t.Errorf("LocalDIDKey should be empty when no LocalRotationKey provided, got %s", res.LocalDIDKey) 236 + } 237 + if res.LocalPresent { 238 + t.Error("LocalPresent should be false when no LocalRotationKey provided") 239 + } 240 + } 241 + 242 + // TestListRotationKeys_LocalPresent confirms LocalPresent flips to true when the local 243 + // key matches one in the published list. 244 + func TestListRotationKeys_LocalPresent(t *testing.T) { 245 + ctx := context.Background() 246 + 247 + serverRot := generateK256(t) 248 + signing := generateK256(t) 249 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing) 250 + defer fake.Close() 251 + 252 + res, err := ListRotationKeys(ctx, ListRotationKeysOptions{ 253 + DID: fake.did, 254 + PLCDirectoryURL: fake.URL(), 255 + LocalRotationKey: serverRot, 256 + }) 257 + if err != nil { 258 + t.Fatalf("ListRotationKeys: %v", err) 259 + } 260 + if !res.LocalPresent { 261 + t.Error("LocalPresent should be true") 262 + } 263 + pub, _ := serverRot.PublicKey() 264 + if res.LocalDIDKey != pub.DIDKey() { 265 + t.Errorf("LocalDIDKey: got %s want %s", res.LocalDIDKey, pub.DIDKey()) 266 + } 267 + } 268 + 269 + // TestListRotationKeys_LocalNotPresent flags a rotated-out local key. 270 + func TestListRotationKeys_LocalNotPresent(t *testing.T) { 271 + ctx := context.Background() 272 + 273 + serverRot := generateK256(t) 274 + signing := generateK256(t) 275 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing) 276 + defer fake.Close() 277 + 278 + stranger := generateK256(t) 279 + res, err := ListRotationKeys(ctx, ListRotationKeysOptions{ 280 + DID: fake.did, 281 + PLCDirectoryURL: fake.URL(), 282 + LocalRotationKey: stranger, 283 + }) 284 + if err != nil { 285 + t.Fatalf("ListRotationKeys: %v", err) 286 + } 287 + if res.LocalPresent { 288 + t.Error("LocalPresent should be false for a stranger key") 289 + } 290 + if res.LocalDIDKey == "" { 291 + t.Error("LocalDIDKey should still be populated even when not present") 292 + } 293 + } 294 + 295 + // TestListRotationKeys_MissingDID surfaces the early-return validation. 296 + func TestListRotationKeys_MissingDID(t *testing.T) { 297 + _, err := ListRotationKeys(context.Background(), ListRotationKeysOptions{}) 298 + if err == nil { 299 + t.Fatal("expected error for missing DID") 300 + } 301 + if !strings.Contains(err.Error(), "DID is required") { 302 + t.Errorf("error: got %q", err.Error()) 303 + } 304 + }
+243
pkg/atproto/did/did.go
··· 1 + // Package did provides shared did:web and did:plc identity management for ATCR services. 2 + // 3 + // Both the hold and labeler services declare an ATProto identity with a signing key 4 + // and one or more service endpoints. This package generalizes the genesis/update/load 5 + // flow so callers only have to specify their verification key fragment name and the 6 + // service entries they want to register. 7 + package did 8 + 9 + import ( 10 + "context" 11 + "encoding/json" 12 + "fmt" 13 + "log/slog" 14 + "net/url" 15 + "os" 16 + "path/filepath" 17 + "strings" 18 + 19 + "atcr.io/pkg/auth/oauth" 20 + "github.com/bluesky-social/indigo/atproto/atcrypto" 21 + ) 22 + 23 + // Service is a service entry in a DID document or PLC operation. 24 + type Service struct { 25 + Type string 26 + Endpoint string 27 + } 28 + 29 + // Config configures DID identity loading or creation. 30 + type Config struct { 31 + // Method is "web" or "plc". 32 + Method string 33 + 34 + // PublicURL is the externally reachable URL of the service. 35 + PublicURL string 36 + 37 + // DBPath is a directory used to persist did.txt for did:plc identities. 38 + DBPath string 39 + 40 + // SigningKeyPath is the on-disk path for the K-256 signing key (will be generated if missing). 41 + SigningKeyPath string 42 + 43 + // RotationKey is a multibase-encoded private key used to sign PLC operations (optional). 44 + // If empty for did:plc, a new rotation key is generated and logged once for the operator. 45 + RotationKey string 46 + 47 + // PLCDirectoryURL is the PLC directory endpoint. 48 + PLCDirectoryURL string 49 + 50 + // DID overrides the persisted DID (used for adoption/recovery of an existing did:plc). 51 + DID string 52 + 53 + // VerificationKeyName is the fragment used in the DID document and PLC operation 54 + // for the signing key (e.g. "atproto" for a PDS, "atproto_label" for a labeler). 55 + VerificationKeyName string 56 + 57 + // Services lists service entries keyed by service id (e.g. "atproto_pds", "atproto_labeler"). 58 + Services map[string]Service 59 + } 60 + 61 + // LoadOrCreate returns the service's DID. did:web is derived deterministically from 62 + // PublicURL; did:plc is loaded from disk or created and registered with the PLC directory. 63 + func LoadOrCreate(ctx context.Context, cfg Config) (string, error) { 64 + if cfg.Method != "plc" { 65 + return GenerateDIDFromURL(cfg.PublicURL), nil 66 + } 67 + 68 + if cfg.VerificationKeyName == "" { 69 + return "", fmt.Errorf("did: VerificationKeyName is required for did:plc") 70 + } 71 + if len(cfg.Services) == 0 { 72 + return "", fmt.Errorf("did: at least one service entry is required for did:plc") 73 + } 74 + 75 + didPath := filepath.Join(cfg.DBPath, "did.txt") 76 + 77 + var d string 78 + if cfg.DID != "" { 79 + if !strings.HasPrefix(cfg.DID, "did:plc:") { 80 + return "", fmt.Errorf("did: DID must be a did:plc identifier, got %q", cfg.DID) 81 + } 82 + d = cfg.DID 83 + slog.Info("Using DID from config (adoption/recovery)", "did", d) 84 + } else if data, err := os.ReadFile(didPath); err == nil { 85 + val := strings.TrimSpace(string(data)) 86 + if strings.HasPrefix(val, "did:plc:") { 87 + d = val 88 + slog.Info("Loaded existing did:plc identity", "did", d) 89 + } 90 + } 91 + 92 + if d != "" { 93 + if err := os.MkdirAll(filepath.Dir(didPath), 0755); err != nil { 94 + return "", fmt.Errorf("did: failed to create did.txt directory: %w", err) 95 + } 96 + if err := os.WriteFile(didPath, []byte(d+"\n"), 0600); err != nil { 97 + return "", fmt.Errorf("did: failed to write did.txt: %w", err) 98 + } 99 + 100 + signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath) 101 + if err != nil { 102 + return "", fmt.Errorf("did: failed to load signing key: %w", err) 103 + } 104 + rotationKey, _ := parseOptionalMultibaseKey(cfg.RotationKey) 105 + 106 + if err := EnsureCurrent(ctx, d, rotationKey, signingKey, cfg); err != nil { 107 + slog.Warn("Failed to verify PLC identity is current (will retry on next restart)", 108 + "did", d, "error", err) 109 + } 110 + return d, nil 111 + } 112 + 113 + slog.Info("Creating new did:plc identity") 114 + 115 + signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath) 116 + if err != nil { 117 + return "", fmt.Errorf("did: failed to load signing key: %w", err) 118 + } 119 + 120 + var rotationKey atcrypto.PrivateKeyExportable 121 + if cfg.RotationKey != "" { 122 + rotationKey, err = parseOptionalMultibaseKey(cfg.RotationKey) 123 + if err != nil { 124 + return "", fmt.Errorf("did: failed to parse rotation_key: %w", err) 125 + } 126 + } else { 127 + rawKey, genErr := atcrypto.GeneratePrivateKeyK256() 128 + if genErr != nil { 129 + return "", fmt.Errorf("did: failed to generate rotation key: %w", genErr) 130 + } 131 + rotationKey = rawKey 132 + slog.Warn("Generated new rotation key — save this in your config as rotation_key", 133 + "rotation_key", rawKey.Multibase()) 134 + } 135 + 136 + d, err = CreateIdentity(ctx, rotationKey, signingKey, cfg) 137 + if err != nil { 138 + return "", fmt.Errorf("did: failed to create PLC identity: %w", err) 139 + } 140 + 141 + if err := os.MkdirAll(filepath.Dir(didPath), 0755); err != nil { 142 + return "", fmt.Errorf("did: failed to create did.txt directory: %w", err) 143 + } 144 + if err := os.WriteFile(didPath, []byte(d+"\n"), 0600); err != nil { 145 + return "", fmt.Errorf("did: failed to write did.txt: %w", err) 146 + } 147 + 148 + slog.Info("Created did:plc identity", "did", d, "plc_directory", cfg.PLCDirectoryURL) 149 + slog.Warn("Back up your rotation_key. It is only needed for DID updates (URL changes, key rotation).") 150 + return d, nil 151 + } 152 + 153 + // DIDDocument is the JSON shape we serve for did:web identities. 154 + type DIDDocument struct { 155 + Context []string `json:"@context"` 156 + ID string `json:"id"` 157 + AlsoKnownAs []string `json:"alsoKnownAs,omitempty"` 158 + VerificationMethod []VerificationMethod `json:"verificationMethod"` 159 + Authentication []string `json:"authentication,omitempty"` 160 + AssertionMethod []string `json:"assertionMethod,omitempty"` 161 + Service []DIDService `json:"service,omitempty"` 162 + } 163 + 164 + // VerificationMethod is a public key entry in a DID document. 165 + type VerificationMethod struct { 166 + ID string `json:"id"` 167 + Type string `json:"type"` 168 + Controller string `json:"controller"` 169 + PublicKeyMultibase string `json:"publicKeyMultibase"` 170 + } 171 + 172 + // DIDService is a service entry in a DID document. 173 + type DIDService struct { 174 + ID string `json:"id"` 175 + Type string `json:"type"` 176 + ServiceEndpoint string `json:"serviceEndpoint"` 177 + } 178 + 179 + // BuildDIDDocument constructs a DID document for a did:web identity. The verification 180 + // method fragment matches verificationKeyName (e.g. "#atproto" or "#atproto_label"); 181 + // pass "" to default to "atproto". Authentication is only added for the standard 182 + // "atproto" key per the bsky/PDS pattern. 183 + func BuildDIDDocument(did, publicURL string, signingKey *atcrypto.PrivateKeyK256, verificationKeyName string, services map[string]Service) (*DIDDocument, error) { 184 + host, err := hostWithPort(publicURL) 185 + if err != nil { 186 + return nil, err 187 + } 188 + pub, err := signingKey.PublicKey() 189 + if err != nil { 190 + return nil, fmt.Errorf("did: failed to get public key: %w", err) 191 + } 192 + 193 + keyName := verificationKeyName 194 + if keyName == "" { 195 + keyName = "atproto" 196 + } 197 + 198 + doc := &DIDDocument{ 199 + Context: []string{ 200 + "https://www.w3.org/ns/did/v1", 201 + "https://w3id.org/security/multikey/v1", 202 + "https://w3id.org/security/suites/secp256k1-2019/v1", 203 + }, 204 + ID: did, 205 + AlsoKnownAs: []string{"at://" + host}, 206 + VerificationMethod: []VerificationMethod{ 207 + { 208 + ID: fmt.Sprintf("%s#%s", did, keyName), 209 + Type: "Multikey", 210 + Controller: did, 211 + PublicKeyMultibase: pub.Multibase(), 212 + }, 213 + }, 214 + } 215 + if keyName == "atproto" { 216 + doc.Authentication = []string{fmt.Sprintf("%s#atproto", did)} 217 + } 218 + for id, svc := range services { 219 + doc.Service = append(doc.Service, DIDService{ 220 + ID: "#" + id, 221 + Type: svc.Type, 222 + ServiceEndpoint: svc.Endpoint, 223 + }) 224 + } 225 + return doc, nil 226 + } 227 + 228 + // MarshalDIDDocument is a convenience for serving a DID doc as indented JSON. 229 + func MarshalDIDDocument(doc *DIDDocument) ([]byte, error) { 230 + return json.MarshalIndent(doc, "", " ") 231 + } 232 + 233 + func hostWithPort(publicURL string) (string, error) { 234 + u, err := url.Parse(publicURL) 235 + if err != nil { 236 + return "", fmt.Errorf("did: failed to parse public URL: %w", err) 237 + } 238 + host := u.Hostname() 239 + if port := u.Port(); port != "" && port != "80" && port != "443" { 240 + host = host + ":" + port 241 + } 242 + return host, nil 243 + }
+356
pkg/atproto/did/did_test.go
··· 1 + package did 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "os" 7 + "path/filepath" 8 + "strings" 9 + "testing" 10 + ) 11 + 12 + func testServices(publicURL string) map[string]Service { 13 + return map[string]Service{ 14 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: publicURL}, 15 + "atcr_hold": {Type: "AtcrHoldService", Endpoint: publicURL}, 16 + } 17 + } 18 + 19 + // TestBuildDIDDocument verifies the standard atproto DID document layout for a did:web service. 20 + func TestBuildDIDDocument(t *testing.T) { 21 + publicURL := "https://hold.example.com" 22 + signingKey := generateK256(t) 23 + 24 + doc, err := BuildDIDDocument("did:web:hold.example.com", publicURL, signingKey, "atproto", testServices(publicURL)) 25 + if err != nil { 26 + t.Fatalf("BuildDIDDocument: %v", err) 27 + } 28 + 29 + if doc.ID != "did:web:hold.example.com" { 30 + t.Errorf("ID: got %s want did:web:hold.example.com", doc.ID) 31 + } 32 + 33 + expectedContexts := []string{ 34 + "https://www.w3.org/ns/did/v1", 35 + "https://w3id.org/security/multikey/v1", 36 + "https://w3id.org/security/suites/secp256k1-2019/v1", 37 + } 38 + if len(doc.Context) != len(expectedContexts) { 39 + t.Errorf("Context length: got %d want %d", len(doc.Context), len(expectedContexts)) 40 + } 41 + for i, expected := range expectedContexts { 42 + if doc.Context[i] != expected { 43 + t.Errorf("Context[%d]: got %s want %s", i, doc.Context[i], expected) 44 + } 45 + } 46 + 47 + if len(doc.AlsoKnownAs) != 1 || doc.AlsoKnownAs[0] != "at://hold.example.com" { 48 + t.Errorf("AlsoKnownAs: got %v want [at://hold.example.com]", doc.AlsoKnownAs) 49 + } 50 + 51 + if len(doc.VerificationMethod) != 1 { 52 + t.Fatalf("VerificationMethod length: got %d want 1", len(doc.VerificationMethod)) 53 + } 54 + vm := doc.VerificationMethod[0] 55 + if vm.ID != "did:web:hold.example.com#atproto" { 56 + t.Errorf("VerificationMethod.ID: got %s", vm.ID) 57 + } 58 + if vm.Type != "Multikey" { 59 + t.Errorf("VerificationMethod.Type: got %s want Multikey", vm.Type) 60 + } 61 + if vm.Controller != "did:web:hold.example.com" { 62 + t.Errorf("VerificationMethod.Controller: got %s", vm.Controller) 63 + } 64 + if vm.PublicKeyMultibase == "" { 65 + t.Error("VerificationMethod.PublicKeyMultibase is empty") 66 + } 67 + 68 + pub, _ := signingKey.PublicKey() 69 + if vm.PublicKeyMultibase != pub.Multibase() { 70 + t.Errorf("VerificationMethod.PublicKeyMultibase: got %s want %s", vm.PublicKeyMultibase, pub.Multibase()) 71 + } 72 + 73 + if len(doc.Authentication) != 1 || doc.Authentication[0] != "did:web:hold.example.com#atproto" { 74 + t.Errorf("Authentication: got %v", doc.Authentication) 75 + } 76 + 77 + if len(doc.Service) != 2 { 78 + t.Fatalf("Service length: got %d want 2", len(doc.Service)) 79 + } 80 + 81 + svcByID := map[string]DIDService{} 82 + for _, s := range doc.Service { 83 + svcByID[s.ID] = s 84 + } 85 + pdsService, ok := svcByID["#atproto_pds"] 86 + if !ok { 87 + t.Fatalf("missing #atproto_pds service in %v", svcByID) 88 + } 89 + if pdsService.Type != "AtprotoPersonalDataServer" { 90 + t.Errorf("#atproto_pds Type: got %s", pdsService.Type) 91 + } 92 + if pdsService.ServiceEndpoint != publicURL { 93 + t.Errorf("#atproto_pds Endpoint: got %s want %s", pdsService.ServiceEndpoint, publicURL) 94 + } 95 + holdService, ok := svcByID["#atcr_hold"] 96 + if !ok { 97 + t.Fatalf("missing #atcr_hold service in %v", svcByID) 98 + } 99 + if holdService.Type != "AtcrHoldService" { 100 + t.Errorf("#atcr_hold Type: got %s", holdService.Type) 101 + } 102 + if holdService.ServiceEndpoint != publicURL { 103 + t.Errorf("#atcr_hold Endpoint: got %s want %s", holdService.ServiceEndpoint, publicURL) 104 + } 105 + } 106 + 107 + // TestBuildDIDDocument_WithPort confirms non-standard ports flow into AlsoKnownAs. 108 + func TestBuildDIDDocument_WithPort(t *testing.T) { 109 + publicURL := "https://hold.example.com:8443" 110 + signingKey := generateK256(t) 111 + 112 + doc, err := BuildDIDDocument("did:web:hold.example.com%3A8443", publicURL, signingKey, "atproto", testServices(publicURL)) 113 + if err != nil { 114 + t.Fatalf("BuildDIDDocument: %v", err) 115 + } 116 + 117 + if doc.ID != "did:web:hold.example.com%3A8443" { 118 + t.Errorf("ID: got %s", doc.ID) 119 + } 120 + if doc.AlsoKnownAs[0] != "at://hold.example.com:8443" { 121 + t.Errorf("AlsoKnownAs: got %s want at://hold.example.com:8443", doc.AlsoKnownAs[0]) 122 + } 123 + } 124 + 125 + // TestBuildDIDDocument_StandardPortsStripped verifies port 80/443 are not appended to alsoKnownAs. 126 + func TestBuildDIDDocument_StandardPortsStripped(t *testing.T) { 127 + signingKey := generateK256(t) 128 + 129 + cases := []struct { 130 + name string 131 + publicURL string 132 + wantAKA string 133 + }{ 134 + {"http port 80", "http://hold.example.com:80", "at://hold.example.com"}, 135 + {"https port 443", "https://hold.example.com:443", "at://hold.example.com"}, 136 + } 137 + for _, tc := range cases { 138 + t.Run(tc.name, func(t *testing.T) { 139 + doc, err := BuildDIDDocument("did:web:hold.example.com", tc.publicURL, signingKey, "atproto", nil) 140 + if err != nil { 141 + t.Fatalf("BuildDIDDocument: %v", err) 142 + } 143 + if doc.AlsoKnownAs[0] != tc.wantAKA { 144 + t.Errorf("AlsoKnownAs: got %s want %s", doc.AlsoKnownAs[0], tc.wantAKA) 145 + } 146 + }) 147 + } 148 + } 149 + 150 + // TestBuildDIDDocument_InvalidURL confirms malformed URLs surface as errors. 151 + func TestBuildDIDDocument_InvalidURL(t *testing.T) { 152 + signingKey := generateK256(t) 153 + _, err := BuildDIDDocument("did:web:bogus", "ht!tp://invalid url", signingKey, "atproto", nil) 154 + if err == nil { 155 + t.Fatal("expected error for invalid URL, got nil") 156 + } 157 + } 158 + 159 + // TestBuildDIDDocument_DefaultVerificationKeyName confirms the empty fragment defaults to "atproto" 160 + // and adds Authentication. 161 + func TestBuildDIDDocument_DefaultVerificationKeyName(t *testing.T) { 162 + signingKey := generateK256(t) 163 + doc, err := BuildDIDDocument("did:web:example.com", "https://example.com", signingKey, "", nil) 164 + if err != nil { 165 + t.Fatalf("BuildDIDDocument: %v", err) 166 + } 167 + if doc.VerificationMethod[0].ID != "did:web:example.com#atproto" { 168 + t.Errorf("VerificationMethod.ID: got %s want did:web:example.com#atproto", doc.VerificationMethod[0].ID) 169 + } 170 + if len(doc.Authentication) != 1 || doc.Authentication[0] != "did:web:example.com#atproto" { 171 + t.Errorf("Authentication: got %v", doc.Authentication) 172 + } 173 + } 174 + 175 + // TestBuildDIDDocument_LabelerKey confirms a non-"atproto" verification key (e.g. labeler) 176 + // does not add Authentication, mirroring the bsky labeler pattern. 177 + func TestBuildDIDDocument_LabelerKey(t *testing.T) { 178 + signingKey := generateK256(t) 179 + services := map[string]Service{ 180 + "atproto_labeler": {Type: "AtprotoLabeler", Endpoint: "https://labeler.example.com"}, 181 + } 182 + doc, err := BuildDIDDocument("did:web:labeler.example.com", "https://labeler.example.com", signingKey, "atproto_label", services) 183 + if err != nil { 184 + t.Fatalf("BuildDIDDocument: %v", err) 185 + } 186 + if doc.VerificationMethod[0].ID != "did:web:labeler.example.com#atproto_label" { 187 + t.Errorf("VerificationMethod.ID: got %s", doc.VerificationMethod[0].ID) 188 + } 189 + if len(doc.Authentication) != 0 { 190 + t.Errorf("Authentication should be empty for non-atproto key, got %v", doc.Authentication) 191 + } 192 + if len(doc.Service) != 1 || doc.Service[0].ID != "#atproto_labeler" { 193 + t.Errorf("Service: got %v", doc.Service) 194 + } 195 + } 196 + 197 + // TestBuildDIDDocument_NoServices confirms a DID document can be built without any service entries. 198 + func TestBuildDIDDocument_NoServices(t *testing.T) { 199 + signingKey := generateK256(t) 200 + doc, err := BuildDIDDocument("did:web:example.com", "https://example.com", signingKey, "atproto", nil) 201 + if err != nil { 202 + t.Fatalf("BuildDIDDocument: %v", err) 203 + } 204 + if len(doc.Service) != 0 { 205 + t.Errorf("Service should be empty, got %v", doc.Service) 206 + } 207 + } 208 + 209 + // TestMarshalDIDDocument confirms marshaling produces parseable, indented JSON. 210 + func TestMarshalDIDDocument(t *testing.T) { 211 + signingKey := generateK256(t) 212 + doc, err := BuildDIDDocument("did:web:example.com", "https://example.com", signingKey, "atproto", testServices("https://example.com")) 213 + if err != nil { 214 + t.Fatalf("BuildDIDDocument: %v", err) 215 + } 216 + 217 + data, err := MarshalDIDDocument(doc) 218 + if err != nil { 219 + t.Fatalf("MarshalDIDDocument: %v", err) 220 + } 221 + 222 + if !strings.Contains(string(data), " ") { 223 + t.Error("expected indented JSON output") 224 + } 225 + 226 + var parsed DIDDocument 227 + if err := json.Unmarshal(data, &parsed); err != nil { 228 + t.Fatalf("Unmarshal: %v", err) 229 + } 230 + if parsed.ID != doc.ID { 231 + t.Errorf("ID round-trip: got %s want %s", parsed.ID, doc.ID) 232 + } 233 + if len(parsed.Service) != len(doc.Service) { 234 + t.Errorf("Service length round-trip: got %d want %d", len(parsed.Service), len(doc.Service)) 235 + } 236 + } 237 + 238 + // TestLoadOrCreate_DIDWeb confirms did:web mode returns a deterministic identifier 239 + // without touching disk or any external service. 240 + func TestLoadOrCreate_DIDWeb(t *testing.T) { 241 + cfg := Config{ 242 + Method: "web", 243 + PublicURL: "https://hold.example.com", 244 + } 245 + d, err := LoadOrCreate(context.Background(), cfg) 246 + if err != nil { 247 + t.Fatalf("LoadOrCreate: %v", err) 248 + } 249 + if d != "did:web:hold.example.com" { 250 + t.Errorf("DID: got %s want did:web:hold.example.com", d) 251 + } 252 + } 253 + 254 + // TestLoadOrCreate_DIDWebDefaultsToWebWhenMethodEmpty confirms an empty method behaves like did:web. 255 + func TestLoadOrCreate_DIDWebDefaultsToWebWhenMethodEmpty(t *testing.T) { 256 + cfg := Config{ 257 + PublicURL: "https://example.com:8443", 258 + } 259 + d, err := LoadOrCreate(context.Background(), cfg) 260 + if err != nil { 261 + t.Fatalf("LoadOrCreate: %v", err) 262 + } 263 + if d != "did:web:example.com%3A8443" { 264 + t.Errorf("DID: got %s want did:web:example.com%%3A8443", d) 265 + } 266 + } 267 + 268 + // TestLoadOrCreate_PLCRequiresVerificationKeyName confirms missing required PLC fields error early. 269 + func TestLoadOrCreate_PLCRequiresVerificationKeyName(t *testing.T) { 270 + cfg := Config{ 271 + Method: "plc", 272 + PublicURL: "https://example.com", 273 + Services: map[string]Service{ 274 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.com"}, 275 + }, 276 + } 277 + _, err := LoadOrCreate(context.Background(), cfg) 278 + if err == nil { 279 + t.Fatal("expected error when VerificationKeyName is empty") 280 + } 281 + if !strings.Contains(err.Error(), "VerificationKeyName") { 282 + t.Errorf("expected error about VerificationKeyName, got: %v", err) 283 + } 284 + } 285 + 286 + // TestLoadOrCreate_PLCRequiresServices confirms PLC mode demands at least one service entry. 287 + func TestLoadOrCreate_PLCRequiresServices(t *testing.T) { 288 + cfg := Config{ 289 + Method: "plc", 290 + PublicURL: "https://example.com", 291 + VerificationKeyName: "atproto", 292 + } 293 + _, err := LoadOrCreate(context.Background(), cfg) 294 + if err == nil { 295 + t.Fatal("expected error when Services is empty") 296 + } 297 + if !strings.Contains(err.Error(), "service") { 298 + t.Errorf("expected error about services, got: %v", err) 299 + } 300 + } 301 + 302 + // TestLoadOrCreate_PLCRejectsNonPLCAdoption confirms a configured DID must be a did:plc identifier. 303 + func TestLoadOrCreate_PLCRejectsNonPLCAdoption(t *testing.T) { 304 + tmp := t.TempDir() 305 + cfg := Config{ 306 + Method: "plc", 307 + PublicURL: "https://example.com", 308 + DBPath: tmp, 309 + SigningKeyPath: filepath.Join(tmp, "signing.key"), 310 + VerificationKeyName: "atproto", 311 + DID: "did:web:example.com", 312 + Services: map[string]Service{ 313 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.com"}, 314 + }, 315 + } 316 + _, err := LoadOrCreate(context.Background(), cfg) 317 + if err == nil { 318 + t.Fatal("expected error for non-did:plc adoption") 319 + } 320 + if !strings.Contains(err.Error(), "did:plc") { 321 + t.Errorf("expected error about did:plc, got: %v", err) 322 + } 323 + } 324 + 325 + // TestLoadOrCreate_PLCAdoptionPersistsDID confirms a configured did:plc is written to did.txt 326 + // even when the PLC directory call fails (the failure is logged, not returned). 327 + func TestLoadOrCreate_PLCAdoptionPersistsDID(t *testing.T) { 328 + tmp := t.TempDir() 329 + cfg := Config{ 330 + Method: "plc", 331 + PublicURL: "https://example.com", 332 + DBPath: tmp, 333 + SigningKeyPath: filepath.Join(tmp, "signing.key"), 334 + VerificationKeyName: "atproto", 335 + DID: "did:plc:abcdefghijklmnopqrstuvwx", 336 + PLCDirectoryURL: "http://127.0.0.1:1", // unreachable; EnsureCurrent failure is non-fatal 337 + Services: map[string]Service{ 338 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.com"}, 339 + }, 340 + } 341 + d, err := LoadOrCreate(context.Background(), cfg) 342 + if err != nil { 343 + t.Fatalf("LoadOrCreate: %v", err) 344 + } 345 + if d != "did:plc:abcdefghijklmnopqrstuvwx" { 346 + t.Errorf("DID: got %s", d) 347 + } 348 + 349 + got, err := os.ReadFile(filepath.Join(tmp, "did.txt")) 350 + if err != nil { 351 + t.Fatalf("read did.txt: %v", err) 352 + } 353 + if strings.TrimSpace(string(got)) != "did:plc:abcdefghijklmnopqrstuvwx" { 354 + t.Errorf("did.txt contents: got %q", string(got)) 355 + } 356 + }
+172
pkg/atproto/did/plc.go
··· 1 + package did 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "log/slog" 7 + 8 + "github.com/bluesky-social/indigo/atproto/atcrypto" 9 + didplc "github.com/did-method-plc/go-didplc" 10 + ) 11 + 12 + // CreateIdentity builds a genesis PLC operation with the configured verification key and 13 + // services, signs it with the rotation key, and submits it to the PLC directory. 14 + func CreateIdentity(ctx context.Context, rotationKey atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256, cfg Config) (string, error) { 15 + rotPub, err := rotationKey.PublicKey() 16 + if err != nil { 17 + return "", fmt.Errorf("did: failed to get rotation public key: %w", err) 18 + } 19 + sigPub, err := signingKey.PublicKey() 20 + if err != nil { 21 + return "", fmt.Errorf("did: failed to get signing public key: %w", err) 22 + } 23 + 24 + host, err := hostWithPort(cfg.PublicURL) 25 + if err != nil { 26 + return "", err 27 + } 28 + 29 + op := &didplc.RegularOp{ 30 + Type: "plc_operation", 31 + RotationKeys: []string{rotPub.DIDKey()}, 32 + VerificationMethods: map[string]string{ 33 + cfg.VerificationKeyName: sigPub.DIDKey(), 34 + }, 35 + AlsoKnownAs: []string{"at://" + host}, 36 + Services: toOpServices(cfg.Services), 37 + Prev: nil, 38 + } 39 + if err := op.Sign(rotationKey); err != nil { 40 + return "", fmt.Errorf("did: failed to sign genesis operation: %w", err) 41 + } 42 + 43 + d, err := op.DID() 44 + if err != nil { 45 + return "", fmt.Errorf("did: failed to compute DID from genesis: %w", err) 46 + } 47 + 48 + client := &didplc.Client{DirectoryURL: cfg.PLCDirectoryURL} 49 + if err := client.Submit(ctx, d, op); err != nil { 50 + return "", fmt.Errorf("did: failed to submit genesis operation: %w", err) 51 + } 52 + return d, nil 53 + } 54 + 55 + // EnsureCurrent reconciles the published DID document with the local config; if the local 56 + // signing key, public URL, or service set differs, an update operation is signed and submitted. 57 + // Without a rotation key, mismatches log a warning but are not fatal. 58 + func EnsureCurrent(ctx context.Context, did string, rotationKey atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256, cfg Config) error { 59 + client := &didplc.Client{DirectoryURL: cfg.PLCDirectoryURL} 60 + 61 + opLog, err := client.OpLog(ctx, did) 62 + if err != nil { 63 + return fmt.Errorf("did: failed to fetch op log for %s: %w", did, err) 64 + } 65 + if len(opLog) == 0 { 66 + return fmt.Errorf("did: empty op log for %s", did) 67 + } 68 + lastEntry := opLog[len(opLog)-1] 69 + lastOp := lastEntry.Regular 70 + if lastOp == nil { 71 + slog.Warn("Last PLC operation is not a regular op, skipping auto-update", "did", did) 72 + return nil 73 + } 74 + 75 + sigPub, err := signingKey.PublicKey() 76 + if err != nil { 77 + return fmt.Errorf("did: failed to get signing public key: %w", err) 78 + } 79 + localKey := sigPub.DIDKey() 80 + plcKey := lastOp.VerificationMethods[cfg.VerificationKeyName] 81 + keyMatch := localKey == plcKey 82 + 83 + servicesMatch := true 84 + for name, svc := range cfg.Services { 85 + plcSvc, ok := lastOp.Services[name] 86 + if !ok || plcSvc.Type != svc.Type || plcSvc.Endpoint != svc.Endpoint { 87 + servicesMatch = false 88 + break 89 + } 90 + } 91 + 92 + if keyMatch && servicesMatch { 93 + slog.Info("PLC identity is current", "did", did) 94 + return nil 95 + } 96 + 97 + slog.Info("PLC identity needs update", 98 + "did", did, "signing_key_changed", !keyMatch, "services_changed", !servicesMatch) 99 + 100 + if rotationKey == nil { 101 + slog.Warn("PLC document doesn't match local state but no rotation key available. Provide rotation key to auto-update.", 102 + "did", did, "signing_key_changed", !keyMatch, "services_changed", !servicesMatch) 103 + return nil 104 + } 105 + 106 + rotPub, err := rotationKey.PublicKey() 107 + if err != nil { 108 + return fmt.Errorf("did: failed to get rotation public key: %w", err) 109 + } 110 + localRotKey := rotPub.DIDKey() 111 + 112 + // Verify the local rotation key still has authority on the PLC document. 113 + // If it's been rotated out (possibly maliciously), refuse to submit — PLC would 114 + // reject anyway, and silent failure here would be confusing. 115 + localRotKeyPresent := false 116 + for _, k := range lastOp.RotationKeys { 117 + if k == localRotKey { 118 + localRotKeyPresent = true 119 + break 120 + } 121 + } 122 + if !localRotKeyPresent { 123 + slog.Warn("Local rotation key not present in PLC document — refusing to update. Possible compromise or out-of-band rotation. Recover with offline key if available.", 124 + "did", did, "local_rotation_key", localRotKey, "plc_rotation_keys", lastOp.RotationKeys) 125 + return nil 126 + } 127 + 128 + host, err := hostWithPort(cfg.PublicURL) 129 + if err != nil { 130 + return err 131 + } 132 + prevCID := lastEntry.AsOperation().CID().String() 133 + 134 + op := &didplc.RegularOp{ 135 + Type: "plc_operation", 136 + RotationKeys: lastOp.RotationKeys, 137 + VerificationMethods: map[string]string{ 138 + cfg.VerificationKeyName: localKey, 139 + }, 140 + AlsoKnownAs: []string{"at://" + host}, 141 + Services: toOpServices(cfg.Services), 142 + Prev: &prevCID, 143 + } 144 + if err := op.Sign(rotationKey); err != nil { 145 + return fmt.Errorf("did: failed to sign update operation: %w", err) 146 + } 147 + if err := client.Submit(ctx, did, op); err != nil { 148 + return fmt.Errorf("did: failed to submit update: %w", err) 149 + } 150 + slog.Info("Updated PLC identity", 151 + "did", did, "signing_key_rotated", !keyMatch, "services_changed", !servicesMatch) 152 + return nil 153 + } 154 + 155 + func toOpServices(in map[string]Service) map[string]didplc.OpService { 156 + out := make(map[string]didplc.OpService, len(in)) 157 + for name, svc := range in { 158 + out[name] = didplc.OpService{Type: svc.Type, Endpoint: svc.Endpoint} 159 + } 160 + return out 161 + } 162 + 163 + func parseOptionalMultibaseKey(encoded string) (atcrypto.PrivateKeyExportable, error) { 164 + if encoded == "" { 165 + return nil, nil 166 + } 167 + key, err := atcrypto.ParsePrivateMultibase(encoded) 168 + if err != nil { 169 + return nil, fmt.Errorf("did: failed to parse multibase key: %w", err) 170 + } 171 + return key, nil 172 + }
+243
pkg/atproto/did/plc_test.go
··· 1 + package did 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "io" 7 + "net/http" 8 + "net/http/httptest" 9 + "os" 10 + "path/filepath" 11 + "strings" 12 + "testing" 13 + 14 + "github.com/bluesky-social/indigo/atproto/atcrypto" 15 + didplc "github.com/did-method-plc/go-didplc" 16 + ) 17 + 18 + // fakePLC stands up an httptest server that serves a single op log entry and 19 + // captures any submitted update op for inspection. 20 + type fakePLC struct { 21 + server *httptest.Server 22 + did string 23 + logEntries []didplc.OpEnum 24 + submitted []didplc.RegularOp 25 + } 26 + 27 + func (f *fakePLC) URL() string { return f.server.URL } 28 + 29 + func (f *fakePLC) Close() { f.server.Close() } 30 + 31 + // newFakePLC creates a fake PLC directory pre-loaded with a single signed 32 + // genesis op containing the given rotation keys (in priority order). Returns 33 + // the fake server and the resulting did:plc DID derived from the genesis op. 34 + func newFakePLC(t *testing.T, rotationKeys []*atcrypto.PrivateKeyK256, signer atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256) *fakePLC { 35 + t.Helper() 36 + 37 + rotationDIDKeys := make([]string, 0, len(rotationKeys)) 38 + for _, k := range rotationKeys { 39 + pub, err := k.PublicKey() 40 + if err != nil { 41 + t.Fatalf("rotation key public: %v", err) 42 + } 43 + rotationDIDKeys = append(rotationDIDKeys, pub.DIDKey()) 44 + } 45 + 46 + sigPub, err := signingKey.PublicKey() 47 + if err != nil { 48 + t.Fatalf("signing key public: %v", err) 49 + } 50 + 51 + op := &didplc.RegularOp{ 52 + Type: "plc_operation", 53 + RotationKeys: rotationDIDKeys, 54 + VerificationMethods: map[string]string{ 55 + "atproto": sigPub.DIDKey(), 56 + }, 57 + AlsoKnownAs: []string{"at://example.test"}, 58 + Services: map[string]didplc.OpService{ 59 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.test"}, 60 + }, 61 + Prev: nil, 62 + } 63 + if err := op.Sign(signer); err != nil { 64 + t.Fatalf("sign genesis: %v", err) 65 + } 66 + did, err := op.DID() 67 + if err != nil { 68 + t.Fatalf("compute DID: %v", err) 69 + } 70 + 71 + f := &fakePLC{ 72 + did: did, 73 + logEntries: []didplc.OpEnum{{Regular: op}}, 74 + } 75 + 76 + mux := http.NewServeMux() 77 + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 78 + if r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/log") { 79 + w.Header().Set("Content-Type", "application/json") 80 + _ = json.NewEncoder(w).Encode(f.logEntries) 81 + return 82 + } 83 + if r.Method == http.MethodPost && r.URL.Path == "/"+did { 84 + body, err := io.ReadAll(r.Body) 85 + if err != nil { 86 + http.Error(w, err.Error(), http.StatusBadRequest) 87 + return 88 + } 89 + var op didplc.RegularOp 90 + if err := json.Unmarshal(body, &op); err != nil { 91 + http.Error(w, err.Error(), http.StatusBadRequest) 92 + return 93 + } 94 + f.submitted = append(f.submitted, op) 95 + w.WriteHeader(http.StatusOK) 96 + return 97 + } 98 + http.NotFound(w, r) 99 + }) 100 + f.server = httptest.NewServer(mux) 101 + return f 102 + } 103 + 104 + // generateK256 returns a fresh K-256 keypair, failing the test on error. 105 + func generateK256(t *testing.T) *atcrypto.PrivateKeyK256 { 106 + t.Helper() 107 + k, err := atcrypto.GeneratePrivateKeyK256() 108 + if err != nil { 109 + t.Fatalf("generate K-256: %v", err) 110 + } 111 + return k 112 + } 113 + 114 + // writeSigningKey persists a signing key to a temp file and returns its path. 115 + // Used so EnsureCurrent can load it via oauth.GenerateOrLoadPDSKey. 116 + func writeSigningKey(t *testing.T, dir string, key *atcrypto.PrivateKeyK256) string { 117 + t.Helper() 118 + path := filepath.Join(dir, "signing.key") 119 + if err := os.WriteFile(path, key.Bytes(), 0600); err != nil { 120 + t.Fatalf("write signing key: %v", err) 121 + } 122 + return path 123 + } 124 + 125 + func TestEnsureCurrent_PreservesRotationKeys(t *testing.T) { 126 + ctx := context.Background() 127 + tmp := t.TempDir() 128 + 129 + // Server-side rotation key (the one stored in database.rotation_key) and an 130 + // "offline" recovery key that lives only in PLC. The genesis op lists offline 131 + // FIRST (highest priority). 132 + serverRot := generateK256(t) 133 + offlineRot := generateK256(t) 134 + 135 + // Original signing key used to build genesis; the local signing key on disk 136 + // will be different to force EnsureCurrent into the update path. 137 + originalSigning := generateK256(t) 138 + localSigning := generateK256(t) 139 + writeSigningKey(t, tmp, localSigning) 140 + 141 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{offlineRot, serverRot}, serverRot, originalSigning) 142 + defer fake.Close() 143 + 144 + cfg := Config{ 145 + PublicURL: "https://example.test", 146 + PLCDirectoryURL: fake.URL(), 147 + VerificationKeyName: "atproto", 148 + Services: map[string]Service{ 149 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.test"}, 150 + }, 151 + } 152 + 153 + if err := EnsureCurrent(ctx, fake.did, serverRot, localSigning, cfg); err != nil { 154 + t.Fatalf("EnsureCurrent: %v", err) 155 + } 156 + 157 + if len(fake.submitted) != 1 { 158 + t.Fatalf("expected exactly one update op submitted, got %d", len(fake.submitted)) 159 + } 160 + got := fake.submitted[0] 161 + 162 + offlinePub, _ := offlineRot.PublicKey() 163 + serverPub, _ := serverRot.PublicKey() 164 + want := []string{offlinePub.DIDKey(), serverPub.DIDKey()} 165 + 166 + if len(got.RotationKeys) != len(want) { 167 + t.Fatalf("rotation keys length: got %d want %d", len(got.RotationKeys), len(want)) 168 + } 169 + for i := range want { 170 + if got.RotationKeys[i] != want[i] { 171 + t.Errorf("rotation key [%d]: got %s want %s", i, got.RotationKeys[i], want[i]) 172 + } 173 + } 174 + 175 + // Verify signing key was actually rotated (sanity check we hit the update path). 176 + localPub, _ := localSigning.PublicKey() 177 + if got.VerificationMethods["atproto"] != localPub.DIDKey() { 178 + t.Errorf("expected signing key to update to local key %s, got %s", 179 + localPub.DIDKey(), got.VerificationMethods["atproto"]) 180 + } 181 + } 182 + 183 + func TestEnsureCurrent_RefusesUpdateWhenLocalKeyMissing(t *testing.T) { 184 + ctx := context.Background() 185 + tmp := t.TempDir() 186 + 187 + // Genesis lists only the offline key. The local server has been rotated out. 188 + offlineRot := generateK256(t) 189 + localRot := generateK256(t) // not in PLC 190 + 191 + originalSigning := generateK256(t) 192 + localSigning := generateK256(t) 193 + writeSigningKey(t, tmp, localSigning) 194 + 195 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{offlineRot}, offlineRot, originalSigning) 196 + defer fake.Close() 197 + 198 + cfg := Config{ 199 + PublicURL: "https://example.test", 200 + PLCDirectoryURL: fake.URL(), 201 + VerificationKeyName: "atproto", 202 + Services: map[string]Service{ 203 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.test"}, 204 + }, 205 + } 206 + 207 + // Local signing key has drifted, which would normally trigger an update. 208 + if err := EnsureCurrent(ctx, fake.did, localRot, localSigning, cfg); err != nil { 209 + t.Fatalf("EnsureCurrent returned error: %v", err) 210 + } 211 + 212 + if len(fake.submitted) != 0 { 213 + t.Fatalf("expected no update submission when local rotation key isn't in PLC list, got %d", len(fake.submitted)) 214 + } 215 + } 216 + 217 + func TestEnsureCurrent_NoOpWhenCurrent(t *testing.T) { 218 + ctx := context.Background() 219 + tmp := t.TempDir() 220 + 221 + serverRot := generateK256(t) 222 + signing := generateK256(t) 223 + writeSigningKey(t, tmp, signing) 224 + 225 + fake := newFakePLC(t, []*atcrypto.PrivateKeyK256{serverRot}, serverRot, signing) 226 + defer fake.Close() 227 + 228 + cfg := Config{ 229 + PublicURL: "https://example.test", 230 + PLCDirectoryURL: fake.URL(), 231 + VerificationKeyName: "atproto", 232 + Services: map[string]Service{ 233 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: "https://example.test"}, 234 + }, 235 + } 236 + 237 + if err := EnsureCurrent(ctx, fake.did, serverRot, signing, cfg); err != nil { 238 + t.Fatalf("EnsureCurrent: %v", err) 239 + } 240 + if len(fake.submitted) != 0 { 241 + t.Fatalf("expected no update when state is current, got %d submitted", len(fake.submitted)) 242 + } 243 + }
+24
pkg/atproto/did/web.go
··· 1 + package did 2 + 3 + import ( 4 + "fmt" 5 + "net/url" 6 + ) 7 + 8 + // GenerateDIDFromURL computes a did:web identifier from a public URL. 9 + // Per the did:web spec, ports are percent-encoded (`:` → `%3A`). 10 + func GenerateDIDFromURL(publicURL string) string { 11 + u, err := url.Parse(publicURL) 12 + if err != nil { 13 + return fmt.Sprintf("did:web:%s", publicURL) 14 + } 15 + hostname := u.Hostname() 16 + if hostname == "" { 17 + hostname = "localhost" 18 + } 19 + port := u.Port() 20 + if port != "" && port != "80" && port != "443" { 21 + return fmt.Sprintf("did:web:%s%%3A%s", hostname, port) 22 + } 23 + return fmt.Sprintf("did:web:%s", hostname) 24 + }
+38
pkg/atproto/did/web_test.go
··· 1 + package did 2 + 3 + import "testing" 4 + 5 + // TestGenerateDIDFromURL covers host extraction and port encoding for did:web. 6 + func TestGenerateDIDFromURL(t *testing.T) { 7 + cases := []struct { 8 + name string 9 + publicURL string 10 + want string 11 + }{ 12 + {"https no port", "https://hold.example.com", "did:web:hold.example.com"}, 13 + {"http no port", "http://hold.example.com", "did:web:hold.example.com"}, 14 + {"https port 443 stripped", "https://hold.example.com:443", "did:web:hold.example.com"}, 15 + {"http port 80 stripped", "http://hold.example.com:80", "did:web:hold.example.com"}, 16 + {"non-standard port encoded", "https://hold.example.com:8443", "did:web:hold.example.com%3A8443"}, 17 + {"localhost with port", "http://localhost:3000", "did:web:localhost%3A3000"}, 18 + {"trailing path ignored", "https://hold.example.com/foo/bar", "did:web:hold.example.com"}, 19 + {"subdomain preserved", "https://api.hold.example.com", "did:web:api.hold.example.com"}, 20 + } 21 + for _, tc := range cases { 22 + t.Run(tc.name, func(t *testing.T) { 23 + got := GenerateDIDFromURL(tc.publicURL) 24 + if got != tc.want { 25 + t.Errorf("GenerateDIDFromURL(%q): got %s want %s", tc.publicURL, got, tc.want) 26 + } 27 + }) 28 + } 29 + } 30 + 31 + // TestGenerateDIDFromURL_EmptyHost confirms a URL without a hostname falls back to localhost. 32 + // (url.Parse on a bare path does not error, so the fallback is the only signal we have.) 33 + func TestGenerateDIDFromURL_EmptyHost(t *testing.T) { 34 + got := GenerateDIDFromURL("") 35 + if got != "did:web:localhost" { 36 + t.Errorf("empty URL: got %s want did:web:localhost", got) 37 + } 38 + }
+4 -4
pkg/billing/billing.go
··· 644 644 } 645 645 646 646 var ( 647 - minQuota int64 = -1 648 - maxQuota int64 649 - scanCount int 650 - totalHolds int 647 + minQuota int64 = -1 648 + maxQuota int64 649 + scanCount int 650 + totalHolds int 651 651 ) 652 652 653 653 for _, cached := range m.holdTierCache {
+19
pkg/hold/config.go
··· 15 15 16 16 "github.com/spf13/viper" 17 17 18 + "atcr.io/pkg/atproto/did" 18 19 "atcr.io/pkg/config" 19 20 "atcr.io/pkg/hold/gc" 21 + "atcr.io/pkg/hold/pds" 20 22 "atcr.io/pkg/hold/quota" 21 23 ) 22 24 ··· 56 58 // ConfigPath returns the path to the YAML configuration file used to load this config. 57 59 // Subsystems (e.g. billing) use this to re-read the same file for extended fields. 58 60 func (c *Config) ConfigPath() string { return c.configPath } 61 + 62 + // DIDConfig builds the did.Config used to load or create the hold's identity. 63 + // The verification key fragment and service set are hold-specific (atproto PDS + 64 + // AtcrHoldService), so they're filled in here rather than at every callsite. 65 + func (c *Config) DIDConfig() did.Config { 66 + return did.Config{ 67 + DID: c.Database.DID, 68 + Method: c.Database.DIDMethod, 69 + PublicURL: c.Server.PublicURL, 70 + DBPath: c.Database.Path, 71 + SigningKeyPath: c.Database.KeyPath, 72 + RotationKey: c.Database.RotationKey, 73 + PLCDirectoryURL: c.Database.PLCDirectoryURL, 74 + VerificationKeyName: "atproto", 75 + Services: pds.HoldServices(c.Server.PublicURL), 76 + } 77 + } 59 78 60 79 // AdminConfig defines admin panel settings 61 80 type AdminConfig struct {
-435
pkg/hold/pds/did.go
··· 1 - package pds 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "fmt" 7 - "log/slog" 8 - "net/url" 9 - "os" 10 - "path/filepath" 11 - "strings" 12 - 13 - "atcr.io/pkg/auth/oauth" 14 - "github.com/bluesky-social/indigo/atproto/atcrypto" 15 - didplc "github.com/did-method-plc/go-didplc" 16 - ) 17 - 18 - // DIDDocument represents a did:web document 19 - type DIDDocument struct { 20 - Context []string `json:"@context"` 21 - ID string `json:"id"` 22 - AlsoKnownAs []string `json:"alsoKnownAs,omitempty"` 23 - VerificationMethod []VerificationMethod `json:"verificationMethod"` 24 - Authentication []string `json:"authentication,omitempty"` 25 - AssertionMethod []string `json:"assertionMethod,omitempty"` 26 - Service []Service `json:"service,omitempty"` 27 - } 28 - 29 - // VerificationMethod represents a public key in a DID document 30 - type VerificationMethod struct { 31 - ID string `json:"id"` 32 - Type string `json:"type"` 33 - Controller string `json:"controller"` 34 - PublicKeyMultibase string `json:"publicKeyMultibase"` 35 - } 36 - 37 - // Service represents a service endpoint in a DID document 38 - type Service struct { 39 - ID string `json:"id"` 40 - Type string `json:"type"` 41 - ServiceEndpoint string `json:"serviceEndpoint"` 42 - } 43 - 44 - // GenerateDIDDocument creates a DID document for the hold's identity. 45 - // It uses the hold's stored DID (which may be did:web or did:plc). 46 - func (p *HoldPDS) GenerateDIDDocument(publicURL string) (*DIDDocument, error) { 47 - did := p.did 48 - 49 - // Parse URL for alsoKnownAs 50 - u, err := url.Parse(publicURL) 51 - if err != nil { 52 - return nil, fmt.Errorf("failed to parse public URL: %w", err) 53 - } 54 - host := u.Hostname() 55 - if port := u.Port(); port != "" && port != "80" && port != "443" { 56 - host = fmt.Sprintf("%s:%s", host, port) 57 - } 58 - 59 - // Get public key in multibase format using indigo's crypto 60 - pubKey, err := p.signingKey.PublicKey() 61 - if err != nil { 62 - return nil, fmt.Errorf("failed to get public key: %w", err) 63 - } 64 - publicKeyMultibase := pubKey.Multibase() 65 - 66 - doc := &DIDDocument{ 67 - Context: []string{ 68 - "https://www.w3.org/ns/did/v1", 69 - "https://w3id.org/security/multikey/v1", 70 - "https://w3id.org/security/suites/secp256k1-2019/v1", 71 - }, 72 - ID: did, 73 - AlsoKnownAs: []string{ 74 - fmt.Sprintf("at://%s", host), 75 - }, 76 - VerificationMethod: []VerificationMethod{ 77 - { 78 - ID: fmt.Sprintf("%s#atproto", did), 79 - Type: "Multikey", 80 - Controller: did, 81 - PublicKeyMultibase: publicKeyMultibase, 82 - }, 83 - }, 84 - Authentication: []string{ 85 - fmt.Sprintf("%s#atproto", did), 86 - }, 87 - Service: []Service{ 88 - { 89 - ID: "#atproto_pds", 90 - Type: "AtprotoPersonalDataServer", 91 - ServiceEndpoint: publicURL, 92 - }, 93 - { 94 - ID: "#atcr_hold", 95 - Type: "AtcrHoldService", 96 - ServiceEndpoint: publicURL, 97 - }, 98 - }, 99 - } 100 - 101 - return doc, nil 102 - } 103 - 104 - // MarshalDIDDocument converts a DID document to JSON using the stored public URL 105 - func (p *HoldPDS) MarshalDIDDocument() ([]byte, error) { 106 - doc, err := p.GenerateDIDDocument(p.PublicURL) 107 - if err != nil { 108 - return nil, err 109 - } 110 - 111 - return json.MarshalIndent(doc, "", " ") 112 - } 113 - 114 - // DIDConfig holds parameters for DID creation/loading. 115 - type DIDConfig struct { 116 - DID string // Explicit DID for adoption/recovery (optional) 117 - DIDMethod string // "web" or "plc" 118 - PublicURL string 119 - DBPath string 120 - SigningKeyPath string 121 - RotationKey string // Multibase-encoded private key, K-256 or P-256 (optional) 122 - PLCDirectoryURL string 123 - } 124 - 125 - // LoadOrCreateDID returns the hold's DID, either by deriving it from the URL (did:web) 126 - // or by loading/creating a did:plc identity registered with the PLC directory. 127 - // 128 - // For did:plc, the priority is: config DID > did.txt > create new. 129 - // When an existing DID is found (config or did.txt), EnsurePLCCurrent is called 130 - // to auto-update the PLC directory if the signing key or URL has changed. 131 - func LoadOrCreateDID(ctx context.Context, cfg DIDConfig) (string, error) { 132 - if cfg.DIDMethod != "plc" { 133 - return GenerateDIDFromURL(cfg.PublicURL), nil 134 - } 135 - 136 - didPath := filepath.Join(cfg.DBPath, "did.txt") 137 - 138 - // Priority: config DID > did.txt > create new 139 - var did string 140 - if cfg.DID != "" { 141 - if !strings.HasPrefix(cfg.DID, "did:plc:") { 142 - return "", fmt.Errorf("database.did must be a did:plc identifier, got %q", cfg.DID) 143 - } 144 - did = cfg.DID 145 - slog.Info("Using DID from config (adoption/recovery)", "did", did) 146 - } else if data, err := os.ReadFile(didPath); err == nil { 147 - d := strings.TrimSpace(string(data)) 148 - if strings.HasPrefix(d, "did:plc:") { 149 - did = d 150 - slog.Info("Loaded existing did:plc identity", "did", did) 151 - } 152 - } 153 - 154 - if did != "" { 155 - // Persist to did.txt (may be from config on first adoption) 156 - if err := os.MkdirAll(filepath.Dir(didPath), 0755); err != nil { 157 - return "", fmt.Errorf("failed to create directory for did.txt: %w", err) 158 - } 159 - if err := os.WriteFile(didPath, []byte(did+"\n"), 0600); err != nil { 160 - return "", fmt.Errorf("failed to write did.txt: %w", err) 161 - } 162 - 163 - // Load signing key (generate if missing — recovery case) 164 - signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath) 165 - if err != nil { 166 - return "", fmt.Errorf("failed to load signing key: %w", err) 167 - } 168 - 169 - // Try to parse rotation key (optional — may not be configured) 170 - rotationKey, _ := parseOptionalMultibaseKey(cfg.RotationKey) 171 - 172 - if err := EnsurePLCCurrent(ctx, did, rotationKey, signingKey, cfg.PublicURL, cfg.PLCDirectoryURL); err != nil { 173 - slog.Warn("Failed to verify PLC identity is current (will retry on next restart)", 174 - "did", did, 175 - "error", err, 176 - ) 177 - } 178 - 179 - return did, nil 180 - } 181 - 182 - // No existing DID — create new genesis operation 183 - slog.Info("Creating new did:plc identity") 184 - 185 - // Load or generate signing key 186 - signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath) 187 - if err != nil { 188 - return "", fmt.Errorf("failed to load signing key: %w", err) 189 - } 190 - 191 - // Parse or generate rotation key 192 - var rotationKey atcrypto.PrivateKeyExportable 193 - if cfg.RotationKey != "" { 194 - rotationKey, err = parseOptionalMultibaseKey(cfg.RotationKey) 195 - if err != nil { 196 - return "", fmt.Errorf("failed to parse rotation_key: %w", err) 197 - } 198 - } else { 199 - // Generate a new rotation key — user must save the multibase output 200 - rawKey, genErr := atcrypto.GeneratePrivateKeyK256() 201 - if genErr != nil { 202 - return "", fmt.Errorf("failed to generate rotation key: %w", genErr) 203 - } 204 - rotationKey = rawKey 205 - slog.Warn("Generated new rotation key — save this in your config as database.rotation_key", 206 - "rotation_key", rawKey.Multibase(), 207 - ) 208 - } 209 - 210 - did, err = CreatePLCIdentity(ctx, rotationKey, signingKey, cfg.PublicURL, cfg.PLCDirectoryURL) 211 - if err != nil { 212 - return "", fmt.Errorf("failed to create PLC identity: %w", err) 213 - } 214 - 215 - // Persist DID 216 - if err := os.MkdirAll(filepath.Dir(didPath), 0755); err != nil { 217 - return "", fmt.Errorf("failed to create directory for did.txt: %w", err) 218 - } 219 - if err := os.WriteFile(didPath, []byte(did+"\n"), 0600); err != nil { 220 - return "", fmt.Errorf("failed to write did.txt: %w", err) 221 - } 222 - 223 - slog.Info("Created did:plc identity", 224 - "did", did, 225 - "plc_directory", cfg.PLCDirectoryURL, 226 - ) 227 - slog.Warn("Back up your rotation_key. It is only needed for DID updates (URL changes, key rotation).") 228 - 229 - return did, nil 230 - } 231 - 232 - // parseOptionalMultibaseKey parses a multibase-encoded private key string (K-256 or P-256). 233 - // Returns nil, nil if the input is empty (key not configured). 234 - func parseOptionalMultibaseKey(encoded string) (atcrypto.PrivateKeyExportable, error) { 235 - if encoded == "" { 236 - return nil, nil 237 - } 238 - key, err := atcrypto.ParsePrivateMultibase(encoded) 239 - if err != nil { 240 - return nil, fmt.Errorf("failed to parse rotation key multibase string: %w", err) 241 - } 242 - return key, nil 243 - } 244 - 245 - // EnsurePLCCurrent checks the PLC directory for the given DID and updates it 246 - // if the local signing key or public URL doesn't match what's registered. 247 - // If rotationKey is nil, mismatches are logged as warnings but not fatal. 248 - func EnsurePLCCurrent(ctx context.Context, did string, rotationKey atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256, publicURL, plcDirectoryURL string) error { 249 - client := &didplc.Client{DirectoryURL: plcDirectoryURL} 250 - 251 - // Fetch current op log 252 - opLog, err := client.OpLog(ctx, did) 253 - if err != nil { 254 - return fmt.Errorf("failed to fetch PLC op log for %s: %w", did, err) 255 - } 256 - if len(opLog) == 0 { 257 - return fmt.Errorf("empty op log for %s", did) 258 - } 259 - 260 - lastEntry := opLog[len(opLog)-1] 261 - lastOp := lastEntry.Regular 262 - if lastOp == nil { 263 - // Last op is not a regular op (could be legacy or tombstone) — skip update 264 - slog.Warn("Last PLC operation is not a regular op, skipping auto-update", "did", did) 265 - return nil 266 - } 267 - 268 - // Compare local state vs PLC state 269 - sigPub, err := signingKey.PublicKey() 270 - if err != nil { 271 - return fmt.Errorf("failed to get signing public key: %w", err) 272 - } 273 - localVerificationKey := sigPub.DIDKey() 274 - plcVerificationKey := lastOp.VerificationMethods["atproto"] 275 - 276 - localEndpoint := publicURL 277 - var plcEndpoint string 278 - if svc, ok := lastOp.Services["atproto_pds"]; ok { 279 - plcEndpoint = svc.Endpoint 280 - } 281 - 282 - keyMatch := localVerificationKey == plcVerificationKey 283 - endpointMatch := localEndpoint == plcEndpoint 284 - 285 - if keyMatch && endpointMatch { 286 - slog.Info("PLC identity is current", "did", did) 287 - return nil 288 - } 289 - 290 - slog.Info("PLC identity needs update", 291 - "did", did, 292 - "signing_key_changed", !keyMatch, 293 - "endpoint_changed", !endpointMatch, 294 - ) 295 - 296 - if rotationKey == nil { 297 - slog.Warn("PLC document doesn't match local state but no rotation key available. Provide rotation key to auto-update PLC directory.", 298 - "did", did, 299 - "signing_key_changed", !keyMatch, 300 - "endpoint_changed", !endpointMatch, 301 - ) 302 - return nil 303 - } 304 - 305 - // Build update operation 306 - rotPub, err := rotationKey.PublicKey() 307 - if err != nil { 308 - return fmt.Errorf("failed to get rotation public key: %w", err) 309 - } 310 - 311 - // Extract hostname for alsoKnownAs 312 - u, err := url.Parse(publicURL) 313 - if err != nil { 314 - return fmt.Errorf("failed to parse public URL: %w", err) 315 - } 316 - host := u.Hostname() 317 - if port := u.Port(); port != "" && port != "80" && port != "443" { 318 - host = host + ":" + port 319 - } 320 - 321 - prevCID := lastEntry.AsOperation().CID().String() 322 - 323 - op := &didplc.RegularOp{ 324 - Type: "plc_operation", 325 - RotationKeys: []string{rotPub.DIDKey()}, 326 - VerificationMethods: map[string]string{ 327 - "atproto": localVerificationKey, 328 - }, 329 - AlsoKnownAs: []string{"at://" + host}, 330 - Services: map[string]didplc.OpService{ 331 - "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: publicURL}, 332 - "atcr_hold": {Type: "AtcrHoldService", Endpoint: publicURL}, 333 - }, 334 - Prev: &prevCID, 335 - } 336 - 337 - if err := op.Sign(rotationKey); err != nil { 338 - return fmt.Errorf("failed to sign PLC update operation: %w", err) 339 - } 340 - 341 - if err := client.Submit(ctx, did, op); err != nil { 342 - return fmt.Errorf("failed to submit PLC update: %w", err) 343 - } 344 - 345 - slog.Info("Updated PLC identity", 346 - "did", did, 347 - "signing_key_rotated", !keyMatch, 348 - "endpoint_changed", !endpointMatch, 349 - ) 350 - 351 - return nil 352 - } 353 - 354 - // CreatePLCIdentity creates a new did:plc identity by building a genesis operation, 355 - // signing it with the rotation key, and submitting it to the PLC directory. 356 - func CreatePLCIdentity(ctx context.Context, rotationKey atcrypto.PrivateKey, signingKey *atcrypto.PrivateKeyK256, publicURL, plcDirectoryURL string) (string, error) { 357 - rotPub, err := rotationKey.PublicKey() 358 - if err != nil { 359 - return "", fmt.Errorf("failed to get rotation public key: %w", err) 360 - } 361 - 362 - sigPub, err := signingKey.PublicKey() 363 - if err != nil { 364 - return "", fmt.Errorf("failed to get signing public key: %w", err) 365 - } 366 - 367 - // Extract hostname for alsoKnownAs 368 - u, err := url.Parse(publicURL) 369 - if err != nil { 370 - return "", fmt.Errorf("failed to parse public URL: %w", err) 371 - } 372 - host := u.Hostname() 373 - if port := u.Port(); port != "" && port != "80" && port != "443" { 374 - host = host + ":" + port 375 - } 376 - 377 - op := &didplc.RegularOp{ 378 - Type: "plc_operation", 379 - RotationKeys: []string{rotPub.DIDKey()}, 380 - VerificationMethods: map[string]string{ 381 - "atproto": sigPub.DIDKey(), 382 - }, 383 - AlsoKnownAs: []string{"at://" + host}, 384 - Services: map[string]didplc.OpService{ 385 - "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: publicURL}, 386 - "atcr_hold": {Type: "AtcrHoldService", Endpoint: publicURL}, 387 - }, 388 - Prev: nil, 389 - } 390 - 391 - if err := op.Sign(rotationKey); err != nil { 392 - return "", fmt.Errorf("failed to sign PLC genesis operation: %w", err) 393 - } 394 - 395 - did, err := op.DID() 396 - if err != nil { 397 - return "", fmt.Errorf("failed to compute DID from genesis operation: %w", err) 398 - } 399 - 400 - client := &didplc.Client{DirectoryURL: plcDirectoryURL} 401 - if err := client.Submit(ctx, did, op); err != nil { 402 - return "", fmt.Errorf("failed to submit genesis operation to PLC directory: %w", err) 403 - } 404 - 405 - return did, nil 406 - } 407 - 408 - // GenerateDIDFromURL creates a did:web identifier from a public URL. 409 - // Per the did:web spec, ports are percent-encoded: the colon becomes %3A. 410 - // Example: "http://hold1.example.com:8080" -> "did:web:hold1.example.com%3A8080" 411 - func GenerateDIDFromURL(publicURL string) string { 412 - // Parse URL 413 - u, err := url.Parse(publicURL) 414 - if err != nil { 415 - // Fallback: assume it's just a hostname 416 - return fmt.Sprintf("did:web:%s", publicURL) 417 - } 418 - 419 - // Get hostname 420 - hostname := u.Hostname() 421 - if hostname == "" { 422 - hostname = "localhost" 423 - } 424 - 425 - // Get port 426 - port := u.Port() 427 - 428 - // Include port in DID if it's non-standard (not 80 for http, not 443 for https) 429 - // Per did:web spec, the colon is percent-encoded as %3A 430 - if port != "" && port != "80" && port != "443" { 431 - return fmt.Sprintf("did:web:%s%%3A%s", hostname, port) 432 - } 433 - 434 - return fmt.Sprintf("did:web:%s", hostname) 435 - }
-274
pkg/hold/pds/did_test.go
··· 1 - package pds 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "path/filepath" 7 - "testing" 8 - ) 9 - 10 - // TestGenerateDIDFromURL tests DID generation from various URL formats 11 - func TestGenerateDIDFromURL(t *testing.T) { 12 - tests := []struct { 13 - name string 14 - publicURL string 15 - expectedDID string 16 - }{ 17 - { 18 - name: "standard HTTP with standard port", 19 - publicURL: "http://hold.example.com", 20 - expectedDID: "did:web:hold.example.com", 21 - }, 22 - { 23 - name: "standard HTTPS with standard port", 24 - publicURL: "https://hold.example.com", 25 - expectedDID: "did:web:hold.example.com", 26 - }, 27 - { 28 - name: "HTTP with non-standard port", 29 - publicURL: "http://hold.example.com:8080", 30 - expectedDID: "did:web:hold.example.com%3A8080", 31 - }, 32 - { 33 - name: "HTTPS with non-standard port", 34 - publicURL: "https://hold.example.com:8443", 35 - expectedDID: "did:web:hold.example.com%3A8443", 36 - }, 37 - { 38 - name: "localhost with port", 39 - publicURL: "http://localhost:8080", 40 - expectedDID: "did:web:localhost%3A8080", 41 - }, 42 - { 43 - name: "HTTP with explicit port 80", 44 - publicURL: "http://hold.example.com:80", 45 - expectedDID: "did:web:hold.example.com", 46 - }, 47 - { 48 - name: "HTTPS with explicit port 443", 49 - publicURL: "https://hold.example.com:443", 50 - expectedDID: "did:web:hold.example.com", 51 - }, 52 - { 53 - name: "subdomain", 54 - publicURL: "https://hold1.atcr.io", 55 - expectedDID: "did:web:hold1.atcr.io", 56 - }, 57 - } 58 - 59 - for _, tt := range tests { 60 - t.Run(tt.name, func(t *testing.T) { 61 - did := GenerateDIDFromURL(tt.publicURL) 62 - if did != tt.expectedDID { 63 - t.Errorf("Expected DID %s, got %s", tt.expectedDID, did) 64 - } 65 - }) 66 - } 67 - } 68 - 69 - // TestGenerateDIDFromURL_InvalidURL tests handling of invalid URLs 70 - func TestGenerateDIDFromURL_InvalidURL(t *testing.T) { 71 - // Invalid URLs get parsed with empty hostname, which defaults to localhost 72 - did := GenerateDIDFromURL("not a url") 73 - if did != "did:web:localhost" { 74 - t.Errorf("Expected did:web:localhost for invalid URL, got %s", did) 75 - } 76 - } 77 - 78 - // TestGenerateDIDDocument tests DID document generation 79 - func TestGenerateDIDDocument(t *testing.T) { 80 - ctx := context.Background() 81 - tmpDir := t.TempDir() 82 - 83 - dbPath := filepath.Join(tmpDir, "pds.db") 84 - keyPath := filepath.Join(tmpDir, "signing-key") 85 - publicURL := "https://hold.example.com" 86 - 87 - pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", publicURL, "https://atcr.io", dbPath, keyPath, false) 88 - if err != nil { 89 - t.Fatalf("Failed to create PDS: %v", err) 90 - } 91 - 92 - doc, err := pds.GenerateDIDDocument(publicURL) 93 - if err != nil { 94 - t.Fatalf("Failed to generate DID document: %v", err) 95 - } 96 - 97 - // Verify required fields 98 - if doc.ID != "did:web:hold.example.com" { 99 - t.Errorf("Expected DID did:web:hold.example.com, got %s", doc.ID) 100 - } 101 - 102 - // Verify context 103 - if len(doc.Context) != 3 { 104 - t.Errorf("Expected 3 context entries, got %d", len(doc.Context)) 105 - } 106 - 107 - expectedContexts := []string{ 108 - "https://www.w3.org/ns/did/v1", 109 - "https://w3id.org/security/multikey/v1", 110 - "https://w3id.org/security/suites/secp256k1-2019/v1", 111 - } 112 - for i, expected := range expectedContexts { 113 - if doc.Context[i] != expected { 114 - t.Errorf("Expected context[%d] = %s, got %s", i, expected, doc.Context[i]) 115 - } 116 - } 117 - 118 - // Verify alsoKnownAs 119 - if len(doc.AlsoKnownAs) != 1 || doc.AlsoKnownAs[0] != "at://hold.example.com" { 120 - t.Errorf("Expected alsoKnownAs=['at://hold.example.com'], got %v", doc.AlsoKnownAs) 121 - } 122 - 123 - // Verify verification method 124 - if len(doc.VerificationMethod) != 1 { 125 - t.Fatalf("Expected 1 verification method, got %d", len(doc.VerificationMethod)) 126 - } 127 - 128 - vm := doc.VerificationMethod[0] 129 - if vm.ID != "did:web:hold.example.com#atproto" { 130 - t.Errorf("Expected verification method ID did:web:hold.example.com#atproto, got %s", vm.ID) 131 - } 132 - if vm.Type != "Multikey" { 133 - t.Errorf("Expected type Multikey, got %s", vm.Type) 134 - } 135 - if vm.Controller != "did:web:hold.example.com" { 136 - t.Errorf("Expected controller did:web:hold.example.com, got %s", vm.Controller) 137 - } 138 - if vm.PublicKeyMultibase == "" { 139 - t.Error("Expected non-empty publicKeyMultibase") 140 - } 141 - 142 - // Verify authentication 143 - if len(doc.Authentication) != 1 || doc.Authentication[0] != "did:web:hold.example.com#atproto" { 144 - t.Errorf("Expected authentication=['did:web:hold.example.com#atproto'], got %v", doc.Authentication) 145 - } 146 - 147 - // Verify services 148 - if len(doc.Service) != 2 { 149 - t.Fatalf("Expected 2 services, got %d", len(doc.Service)) 150 - } 151 - 152 - // Check PDS service 153 - pdsService := doc.Service[0] 154 - if pdsService.ID != "#atproto_pds" { 155 - t.Errorf("Expected service ID #atproto_pds, got %s", pdsService.ID) 156 - } 157 - if pdsService.Type != "AtprotoPersonalDataServer" { 158 - t.Errorf("Expected service type AtprotoPersonalDataServer, got %s", pdsService.Type) 159 - } 160 - if pdsService.ServiceEndpoint != publicURL { 161 - t.Errorf("Expected service endpoint %s, got %s", publicURL, pdsService.ServiceEndpoint) 162 - } 163 - 164 - // Check hold service 165 - holdService := doc.Service[1] 166 - if holdService.ID != "#atcr_hold" { 167 - t.Errorf("Expected service ID #atcr_hold, got %s", holdService.ID) 168 - } 169 - if holdService.Type != "AtcrHoldService" { 170 - t.Errorf("Expected service type AtcrHoldService, got %s", holdService.Type) 171 - } 172 - if holdService.ServiceEndpoint != publicURL { 173 - t.Errorf("Expected service endpoint %s, got %s", publicURL, holdService.ServiceEndpoint) 174 - } 175 - } 176 - 177 - // TestGenerateDIDDocument_WithPort tests DID document with non-standard port 178 - func TestGenerateDIDDocument_WithPort(t *testing.T) { 179 - ctx := context.Background() 180 - tmpDir := t.TempDir() 181 - 182 - dbPath := filepath.Join(tmpDir, "pds.db") 183 - keyPath := filepath.Join(tmpDir, "signing-key") 184 - publicURL := "https://hold.example.com:8443" 185 - 186 - pds, err := NewHoldPDS(ctx, "did:web:hold.example.com%3A8443", publicURL, "https://atcr.io", dbPath, keyPath, false) 187 - if err != nil { 188 - t.Fatalf("Failed to create PDS: %v", err) 189 - } 190 - 191 - doc, err := pds.GenerateDIDDocument(publicURL) 192 - if err != nil { 193 - t.Fatalf("Failed to generate DID document: %v", err) 194 - } 195 - 196 - // Verify DID includes percent-encoded port 197 - if doc.ID != "did:web:hold.example.com%3A8443" { 198 - t.Errorf("Expected DID did:web:hold.example.com%%3A8443, got %s", doc.ID) 199 - } 200 - 201 - // Verify alsoKnownAs includes port 202 - if doc.AlsoKnownAs[0] != "at://hold.example.com:8443" { 203 - t.Errorf("Expected alsoKnownAs with port, got %s", doc.AlsoKnownAs[0]) 204 - } 205 - } 206 - 207 - // TestMarshalDIDDocument tests DID document JSON marshaling 208 - func TestMarshalDIDDocument(t *testing.T) { 209 - ctx := context.Background() 210 - tmpDir := t.TempDir() 211 - 212 - dbPath := filepath.Join(tmpDir, "pds.db") 213 - keyPath := filepath.Join(tmpDir, "signing-key") 214 - publicURL := "https://hold.example.com" 215 - 216 - pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", publicURL, "https://atcr.io", dbPath, keyPath, false) 217 - if err != nil { 218 - t.Fatalf("Failed to create PDS: %v", err) 219 - } 220 - 221 - jsonBytes, err := pds.MarshalDIDDocument() 222 - if err != nil { 223 - t.Fatalf("Failed to marshal DID document: %v", err) 224 - } 225 - 226 - // Verify it's valid JSON 227 - var doc map[string]any 228 - if err := json.Unmarshal(jsonBytes, &doc); err != nil { 229 - t.Fatalf("Failed to unmarshal DID document JSON: %v", err) 230 - } 231 - 232 - // Verify required fields 233 - if id, ok := doc["id"].(string); !ok || id != "did:web:hold.example.com" { 234 - t.Errorf("Expected id='did:web:hold.example.com', got %v", doc["id"]) 235 - } 236 - 237 - if _, ok := doc["@context"]; !ok { 238 - t.Error("Expected @context field in JSON") 239 - } 240 - 241 - if _, ok := doc["verificationMethod"]; !ok { 242 - t.Error("Expected verificationMethod field in JSON") 243 - } 244 - 245 - if _, ok := doc["service"]; !ok { 246 - t.Error("Expected service field in JSON") 247 - } 248 - 249 - // Verify pretty-printed (has indentation) 250 - if len(jsonBytes) < 100 { 251 - t.Error("Expected pretty-printed JSON to be reasonably sized") 252 - } 253 - } 254 - 255 - // TestGenerateDIDDocument_InvalidURL tests error handling 256 - func TestGenerateDIDDocument_InvalidURL(t *testing.T) { 257 - ctx := context.Background() 258 - tmpDir := t.TempDir() 259 - 260 - dbPath := filepath.Join(tmpDir, "pds.db") 261 - keyPath := filepath.Join(tmpDir, "signing-key") 262 - publicURL := "https://hold.example.com" 263 - 264 - pds, err := NewHoldPDS(ctx, "did:web:hold.example.com", publicURL, "https://atcr.io", dbPath, keyPath, false) 265 - if err != nil { 266 - t.Fatalf("Failed to create PDS: %v", err) 267 - } 268 - 269 - // Try to generate DID document with invalid URL 270 - _, err = pds.GenerateDIDDocument("ht!tp://invalid url") 271 - if err == nil { 272 - t.Error("Expected error for invalid URL, got nil") 273 - } 274 - }
-11
pkg/hold/pds/export.go
··· 1 - package pds 2 - 3 - import ( 4 - "context" 5 - "io" 6 - ) 7 - 8 - // ExportToCAR streams the hold's repo as a CAR file to the writer. 9 - func (p *HoldPDS) ExportToCAR(ctx context.Context, w io.Writer) error { 10 - return p.repomgr.ReadRepo(ctx, p.uid, "", w) 11 - }
+12
pkg/hold/pds/server.go pkg/hold/pds/hold_pds.go
··· 11 11 "strings" 12 12 13 13 "atcr.io/pkg/atproto" 14 + "atcr.io/pkg/atproto/did" 14 15 "atcr.io/pkg/auth/oauth" 15 16 holddb "atcr.io/pkg/hold/db" 16 17 "atcr.io/pkg/s3" ··· 34 35 lexutil.RegisterType(atproto.DailyStatsCollection, &atproto.DailyStatsRecord{}) 35 36 lexutil.RegisterType(atproto.ScanCollection, &atproto.ScanRecord{}) 36 37 lexutil.RegisterType(atproto.ImageConfigCollection, &atproto.ImageConfigRecord{}) 38 + } 39 + 40 + // HoldServices returns the service entries the hold publishes in its DID document 41 + // and PLC operations: an atproto PDS endpoint plus the ATCR hold service endpoint. 42 + // Single source of truth, used by both boot-time identity loading and DID-document 43 + // serving. 44 + func HoldServices(publicURL string) map[string]did.Service { 45 + return map[string]did.Service{ 46 + "atproto_pds": {Type: "AtprotoPersonalDataServer", Endpoint: publicURL}, 47 + "atcr_hold": {Type: "AtcrHoldService", Endpoint: publicURL}, 48 + } 37 49 } 38 50 39 51 // HoldPDS is a minimal ATProto PDS implementation for a hold service
pkg/hold/pds/server_test.go pkg/hold/pds/hold_pds_test.go
+3 -2
pkg/hold/pds/xrpc.go
··· 7 7 "fmt" 8 8 9 9 "atcr.io/pkg/atproto" 10 + "atcr.io/pkg/atproto/did" 10 11 "atcr.io/pkg/hold/quota" 11 12 "atcr.io/pkg/s3" 12 13 "github.com/bluesky-social/indigo/api/bsky" ··· 425 426 } 426 427 427 428 // Generate DID document 428 - didDoc, err := h.pds.GenerateDIDDocument(h.pds.PublicURL) 429 + didDoc, err := did.BuildDIDDocument(h.pds.DID(), h.pds.PublicURL, h.pds.SigningKey(), "atproto", HoldServices(h.pds.PublicURL)) 429 430 if err != nil { 430 431 http.Error(w, fmt.Sprintf("failed to generate DID document: %v", err), http.StatusInternalServerError) 431 432 return ··· 1387 1388 1388 1389 // HandleDIDDocument returns the DID document 1389 1390 func (h *XRPCHandler) HandleDIDDocument(w http.ResponseWriter, r *http.Request) { 1390 - doc, err := h.pds.GenerateDIDDocument(h.pds.PublicURL) 1391 + doc, err := did.BuildDIDDocument(h.pds.DID(), h.pds.PublicURL, h.pds.SigningKey(), "atproto", HoldServices(h.pds.PublicURL)) 1391 1392 if err != nil { 1392 1393 http.Error(w, fmt.Sprintf("failed to generate DID document: %v", err), http.StatusInternalServerError) 1393 1394 return
+2 -9
pkg/hold/server.go
··· 11 11 "time" 12 12 13 13 "atcr.io/pkg/atproto" 14 + "atcr.io/pkg/atproto/did" 14 15 "atcr.io/pkg/hold/admin" 15 16 holddb "atcr.io/pkg/hold/db" 16 17 "atcr.io/pkg/hold/gc" ··· 76 77 if cfg.Database.Path != "" { 77 78 ctx := context.Background() 78 79 79 - holdDID, err := pds.LoadOrCreateDID(ctx, pds.DIDConfig{ 80 - DID: cfg.Database.DID, 81 - DIDMethod: cfg.Database.DIDMethod, 82 - PublicURL: cfg.Server.PublicURL, 83 - DBPath: cfg.Database.Path, 84 - SigningKeyPath: cfg.Database.KeyPath, 85 - RotationKey: cfg.Database.RotationKey, 86 - PLCDirectoryURL: cfg.Database.PLCDirectoryURL, 87 - }) 80 + holdDID, err := did.LoadOrCreate(ctx, cfg.DIDConfig()) 88 81 if err != nil { 89 82 return nil, fmt.Errorf("failed to resolve hold DID: %w", err) 90 83 }
+188 -21
pkg/labeler/auth.go
··· 1 1 package labeler 2 2 3 3 import ( 4 + "context" 4 5 "crypto/rand" 6 + "crypto/subtle" 5 7 "encoding/base64" 8 + "fmt" 9 + "html/template" 10 + "log/slog" 11 + "net" 6 12 "net/http" 13 + "strings" 7 14 "sync" 15 + "time" 8 16 ) 9 17 10 - // Session represents an authenticated admin session. 18 + const ( 19 + sessionCookieName = "labeler_session" 20 + sessionTTL = 24 * time.Hour 21 + csrfHeaderName = "X-CSRF-Token" 22 + csrfFormField = "csrf_token" 23 + ) 24 + 25 + // Session represents an authenticated admin session. Restart wipes the in-memory 26 + // map so any stolen cookie token becomes useless after a restart, by design. 27 + // 28 + // UserAgent and IPPrefix are captured at login and rechecked on every request — 29 + // a stolen token replayed from a different browser or network prefix is rejected 30 + // and the session is torn down. Empty bound values (Unix sockets, tests, unusual 31 + // proxies) opt out rather than locking users out. Binding at /24 (IPv4) / /64 32 + // (IPv6) tolerates DHCP renewals within a prefix without inviting cross-network 33 + // replay. 11 34 type Session struct { 12 - DID string 13 - Handle string 35 + DID string 36 + Handle string 37 + CSRFToken string 38 + CreatedAt time.Time 39 + UserAgent string 40 + IPPrefix string 14 41 } 15 42 16 - // Auth manages admin authentication. 43 + // Auth manages in-memory admin sessions for the labeler. 17 44 type Auth struct { 18 45 ownerDID string 19 46 sessions map[string]*Session 20 47 sessionsMu sync.RWMutex 21 48 } 22 49 23 - // NewAuth creates a new Auth manager. 50 + // NewAuth wires a fresh in-memory session store keyed to the configured owner DID. 24 51 func NewAuth(ownerDID string) *Auth { 25 52 return &Auth{ 26 53 ownerDID: ownerDID, ··· 28 55 } 29 56 } 30 57 31 - func (a *Auth) createSession(did, handle string) (string, error) { 58 + func randToken() (string, error) { 32 59 b := make([]byte, 32) 33 60 if _, err := rand.Read(b); err != nil { 34 - return "", err 61 + return "", fmt.Errorf("generate token: %w", err) 35 62 } 36 - token := base64.URLEncoding.EncodeToString(b) 63 + return base64.URLEncoding.EncodeToString(b), nil 64 + } 37 65 66 + // CreateSession installs a new in-memory session and returns its cookie token 67 + // alongside the embedded CSRF token (for echoing into forms). 68 + func (a *Auth) CreateSession(did, handle, userAgent, ipPrefix string) (string, *Session, error) { 69 + token, err := randToken() 70 + if err != nil { 71 + return "", nil, err 72 + } 73 + csrfToken, err := randToken() 74 + if err != nil { 75 + return "", nil, err 76 + } 77 + s := &Session{ 78 + DID: did, 79 + Handle: handle, 80 + CSRFToken: csrfToken, 81 + CreatedAt: time.Now(), 82 + UserAgent: userAgent, 83 + IPPrefix: ipPrefix, 84 + } 38 85 a.sessionsMu.Lock() 39 - a.sessions[token] = &Session{DID: did, Handle: handle} 86 + a.sessions[token] = s 40 87 a.sessionsMu.Unlock() 41 - 42 - return token, nil 88 + return token, s, nil 43 89 } 44 90 45 - func (a *Auth) getSession(token string) *Session { 91 + // GetSession returns the session for the cookie token, evicting expired entries on access. 92 + func (a *Auth) GetSession(token string) *Session { 46 93 a.sessionsMu.RLock() 47 - defer a.sessionsMu.RUnlock() 48 - return a.sessions[token] 94 + s := a.sessions[token] 95 + a.sessionsMu.RUnlock() 96 + if s == nil { 97 + return nil 98 + } 99 + if !s.CreatedAt.IsZero() && time.Since(s.CreatedAt) > sessionTTL { 100 + a.sessionsMu.Lock() 101 + delete(a.sessions, token) 102 + a.sessionsMu.Unlock() 103 + return nil 104 + } 105 + return s 49 106 } 50 107 51 - func (a *Auth) deleteSession(token string) { 108 + // DeleteSession removes a session by cookie token (logout). 109 + func (a *Auth) DeleteSession(token string) { 52 110 a.sessionsMu.Lock() 53 111 delete(a.sessions, token) 54 112 a.sessionsMu.Unlock() 55 113 } 56 114 57 - const sessionCookieName = "labeler_session" 58 - 59 115 func setSessionCookie(w http.ResponseWriter, r *http.Request, token string) { 60 116 secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" 61 117 http.SetCookie(w, &http.Cookie{ 62 118 Name: sessionCookieName, 63 119 Value: token, 64 120 Path: "/", 65 - MaxAge: 86400, // 24 hours 121 + MaxAge: int(sessionTTL.Seconds()), 66 122 HttpOnly: true, 67 123 Secure: secure, 68 124 SameSite: http.SameSiteLaxMode, ··· 88 144 return cookie.Value, true 89 145 } 90 146 91 - // RequireOwner is middleware that checks the session belongs to the owner DID. 147 + type sessionContextKeyT struct{} 148 + 149 + var sessionContextKey = sessionContextKeyT{} 150 + 151 + // SessionFromContext returns the session attached to the request context, if any. 152 + func SessionFromContext(ctx context.Context) *Session { 153 + s, _ := ctx.Value(sessionContextKey).(*Session) 154 + return s 155 + } 156 + 157 + // RequireOwner enforces a valid session bound to the owner DID, with UA / IP-prefix 158 + // replay defense. State-mutating methods then go through the CSRF check below. 92 159 func (a *Auth) RequireOwner(next http.Handler) http.Handler { 93 160 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 94 161 token, ok := getSessionCookie(r) ··· 96 163 http.Redirect(w, r, "/auth/login", http.StatusFound) 97 164 return 98 165 } 99 - session := a.getSession(token) 100 - if session == nil || session.DID != a.ownerDID { 166 + session := a.GetSession(token) 167 + if session == nil { 168 + clearSessionCookie(w) 101 169 http.Redirect(w, r, "/auth/login", http.StatusFound) 102 170 return 103 171 } 172 + if session.DID != a.ownerDID { 173 + a.DeleteSession(token) 174 + clearSessionCookie(w) 175 + http.Redirect(w, r, "/auth/login?error=access+denied", http.StatusFound) 176 + return 177 + } 178 + if session.UserAgent != "" && session.UserAgent != r.UserAgent() { 179 + slog.Warn("Admin session UA mismatch — suspected token replay", "did", session.DID) 180 + a.DeleteSession(token) 181 + clearSessionCookie(w) 182 + http.Redirect(w, r, "/auth/login?error=access+denied", http.StatusFound) 183 + return 184 + } 185 + if session.IPPrefix != "" { 186 + if now := clientIPPrefix(r); now != "" && now != session.IPPrefix { 187 + slog.Warn("Admin session IP-prefix mismatch — suspected token replay", 188 + "did", session.DID, "session", session.IPPrefix, "request", now) 189 + a.DeleteSession(token) 190 + clearSessionCookie(w) 191 + http.Redirect(w, r, "/auth/login?error=access+denied", http.StatusFound) 192 + return 193 + } 194 + } 195 + 196 + ctx := context.WithValue(r.Context(), sessionContextKey, session) 197 + next.ServeHTTP(w, r.WithContext(ctx)) 198 + }) 199 + } 200 + 201 + // RequireCSRF validates a per-session CSRF token on state-mutating requests. Safe 202 + // methods pass through. Token comes from X-CSRF-Token header or the csrf_token form 203 + // field for application/x-www-form-urlencoded bodies. Must run after RequireOwner. 204 + func (a *Auth) RequireCSRF(next http.Handler) http.Handler { 205 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 206 + switch r.Method { 207 + case http.MethodGet, http.MethodHead, http.MethodOptions: 208 + next.ServeHTTP(w, r) 209 + return 210 + } 211 + session := SessionFromContext(r.Context()) 212 + if session == nil || session.CSRFToken == "" { 213 + http.Error(w, "Forbidden: missing session", http.StatusForbidden) 214 + return 215 + } 216 + got := r.Header.Get(csrfHeaderName) 217 + if got == "" { 218 + contentType := r.Header.Get("Content-Type") 219 + if idx := strings.IndexByte(contentType, ';'); idx >= 0 { 220 + contentType = contentType[:idx] 221 + } 222 + contentType = strings.TrimSpace(strings.ToLower(contentType)) 223 + if contentType == "application/x-www-form-urlencoded" { 224 + if err := r.ParseForm(); err == nil { 225 + got = r.PostFormValue(csrfFormField) 226 + } 227 + } 228 + } 229 + if subtle.ConstantTimeCompare([]byte(got), []byte(session.CSRFToken)) != 1 { 230 + slog.Warn("Labeler CSRF mismatch", "path", r.URL.Path, "did", session.DID) 231 + http.Error(w, "Forbidden: CSRF token mismatch — reload the page and try again.", http.StatusForbidden) 232 + return 233 + } 104 234 next.ServeHTTP(w, r) 105 235 }) 106 236 } 237 + 238 + // csrfInputHTML emits a hidden form input carrying the per-session CSRF token. 239 + func csrfInputHTML(token string) template.HTML { 240 + escaped := template.HTMLEscapeString(token) 241 + return template.HTML(`<input type="hidden" name="` + csrfFormField + `" value="` + escaped + `">`) 242 + } 243 + 244 + // clientIPPrefix returns a stable prefix key for the request's client IP — /24 for 245 + // IPv4, /64 for IPv6. Empty string means "don't bind" (avoids locking users behind 246 + // unusual proxies / Unix sockets / tests). 247 + func clientIPPrefix(r *http.Request) string { 248 + var host string 249 + if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" { 250 + if comma := strings.IndexByte(fwd, ','); comma >= 0 { 251 + host = strings.TrimSpace(fwd[:comma]) 252 + } else { 253 + host = strings.TrimSpace(fwd) 254 + } 255 + } else { 256 + h, _, err := net.SplitHostPort(r.RemoteAddr) 257 + if err == nil { 258 + host = h 259 + } else { 260 + host = r.RemoteAddr 261 + } 262 + } 263 + ip := net.ParseIP(host) 264 + if ip == nil { 265 + return "" 266 + } 267 + if v4 := ip.To4(); v4 != nil { 268 + return fmt.Sprintf("v4:%d.%d.%d", v4[0], v4[1], v4[2]) 269 + } 270 + v6 := ip.To16() 271 + return fmt.Sprintf("v6:%02x%02x%02x%02x%02x%02x%02x%02x", 272 + v6[0], v6[1], v6[2], v6[3], v6[4], v6[5], v6[6], v6[7]) 273 + }
+78 -18
pkg/labeler/config.go
··· 5 5 import ( 6 6 "fmt" 7 7 "net/url" 8 + "path/filepath" 8 9 "strings" 10 + "time" 9 11 10 12 "github.com/spf13/viper" 11 13 ··· 30 32 // Listen address for the labeler HTTP server. 31 33 Addr string `yaml:"addr" comment:"Listen address for labeler (e.g., :5002)."` 32 34 35 + // PublicURL is the externally reachable URL of the labeler. When empty the URL is 36 + // derived from server.base_url by prefixing "labeler." (so https://atcr.io → 37 + // https://labeler.atcr.io). Set explicitly for IP-based dev environments. 38 + PublicURL string `yaml:"public_url" comment:"Externally reachable labeler URL. Empty = derive from server.base_url."` 39 + 33 40 // DID of the labeler admin. Only this DID can log into the admin panel. 34 41 OwnerDID string `yaml:"owner_did" comment:"DID of the labeler admin. Only this DID can log into the admin panel."` 35 42 36 - // Path to labeler SQLite database. 37 - DBPath string `yaml:"db_path" comment:"Path to labeler SQLite database."` 43 + // Directory for labeler state: SQLite database, signing key, did.txt. 44 + DataDir string `yaml:"data_dir" comment:"Directory for labeler state (database, signing key, did.txt)."` 45 + 46 + // DID method: "plc" (recommended, portable) or "web" (hostname-bound). 47 + DIDMethod string `yaml:"did_method" comment:"DID method: \"plc\" (recommended) or \"web\"."` 48 + 49 + // Explicit DID for did:plc adoption/recovery (optional). 50 + DID string `yaml:"did" comment:"Explicit did:plc identifier for adoption/recovery (optional)."` 51 + 52 + // Signing key path (defaults to <DataDir>/signing.key). 53 + KeyPath string `yaml:"key_path" comment:"Path to K-256 signing key (defaults to <data_dir>/signing.key)."` 54 + 55 + // Rotation key multibase (K-256 or P-256). Required to update the PLC document. 56 + RotationKey string `yaml:"rotation_key" comment:"Multibase-encoded rotation key (K-256 or P-256). Required to update the PLC document."` 57 + 58 + // PLC directory URL (default https://plc.directory). 59 + PLCDirectoryURL string `yaml:"plc_directory_url" comment:"PLC directory URL (default https://plc.directory)."` 60 + 61 + // LibsqlSyncURL enables embedded-replica sync to a remote libSQL/Bunny database when set. 62 + // Empty = local-only mode (default). 63 + LibsqlSyncURL string `yaml:"libsql_sync_url" comment:"Optional libSQL/Bunny remote sync URL. Empty = local-only."` 64 + 65 + // LibsqlAuthToken is the auth token for the remote libSQL database. 66 + LibsqlAuthToken string `yaml:"libsql_auth_token" comment:"Auth token for libsql_sync_url."` 67 + 68 + // LibsqlSyncInterval is how often the embedded replica pulls from the remote. 69 + LibsqlSyncInterval time.Duration `yaml:"libsql_sync_interval" comment:"Embedded-replica pull interval (e.g. 30s). 0 = manual sync only."` 38 70 } 39 71 40 72 // AppviewServerConfig is a subset of the appview ServerConfig that the labeler needs. ··· 45 77 TestMode bool `yaml:"test_mode"` 46 78 } 47 79 48 - // PublicURL returns the labeler's public URL derived from the appview base URL. 49 - // If appview is https://atcr.io, labeler is https://labeler.atcr.io. 80 + // PublicURL returns the labeler's externally reachable URL. When labeler.public_url 81 + // is set explicitly it wins; otherwise it's derived from server.base_url by prefixing 82 + // "labeler." (so https://atcr.io → https://labeler.atcr.io). 50 83 func (c *Config) PublicURL() string { 84 + if c.Labeler.PublicURL != "" { 85 + return c.Labeler.PublicURL 86 + } 51 87 u, err := url.Parse(c.Server.BaseURL) 52 88 if err != nil { 53 89 return "" ··· 56 92 return u.String() 57 93 } 58 94 59 - // DID returns the labeler's did:web identity derived from its public URL. 60 - func (c *Config) DID() string { 61 - u, err := url.Parse(c.PublicURL()) 62 - if err != nil { 63 - return "" 95 + // DBPath returns the path to the SQLite database file inside the data dir. 96 + func (c *Config) DBPath() string { 97 + return filepath.Join(c.Labeler.DataDir, "labeler.db") 98 + } 99 + 100 + // SigningKeyPath returns the configured signing key path or the default inside DataDir. 101 + func (c *Config) SigningKeyPath() string { 102 + if c.Labeler.KeyPath != "" { 103 + return c.Labeler.KeyPath 64 104 } 65 - host := u.Hostname() 66 - if port := u.Port(); port != "" { 67 - host += "%3A" + port 105 + return filepath.Join(c.Labeler.DataDir, "signing.key") 106 + } 107 + 108 + // PLCDirectoryURL returns the configured PLC directory URL or the canonical default. 109 + func (c *Config) PLCDirectoryURL() string { 110 + if c.Labeler.PLCDirectoryURL != "" { 111 + return c.Labeler.PLCDirectoryURL 68 112 } 69 - return "did:web:" + host 113 + return "https://plc.directory" 70 114 } 71 115 72 116 func setDefaults(v *viper.Viper) { ··· 76 120 // Labeler defaults 77 121 v.SetDefault("labeler.enabled", false) 78 122 v.SetDefault("labeler.addr", ":5002") 123 + v.SetDefault("labeler.public_url", "") 79 124 v.SetDefault("labeler.owner_did", "") 80 - v.SetDefault("labeler.db_path", "/var/lib/atcr-labeler/labeler.db") 125 + v.SetDefault("labeler.data_dir", "/var/lib/atcr-labeler") 126 + v.SetDefault("labeler.did_method", "plc") 127 + v.SetDefault("labeler.did", "") 128 + v.SetDefault("labeler.key_path", "") 129 + v.SetDefault("labeler.rotation_key", "") 130 + v.SetDefault("labeler.plc_directory_url", "https://plc.directory") 131 + v.SetDefault("labeler.libsql_sync_url", "") 132 + v.SetDefault("labeler.libsql_auth_token", "") 133 + v.SetDefault("labeler.libsql_sync_interval", 0) 81 134 82 135 // Server defaults (read from shared appview config) 83 136 v.SetDefault("server.base_url", "") ··· 121 174 if !strings.HasPrefix(cfg.Labeler.OwnerDID, "did:") { 122 175 return nil, fmt.Errorf("labeler.owner_did must be a DID (got %q)", cfg.Labeler.OwnerDID) 123 176 } 177 + switch cfg.Labeler.DIDMethod { 178 + case "plc", "web": 179 + default: 180 + return nil, fmt.Errorf("labeler.did_method must be \"plc\" or \"web\" (got %q)", cfg.Labeler.DIDMethod) 181 + } 124 182 125 183 return cfg, nil 126 184 } ··· 136 194 ClientShortName: "ATCR", 137 195 }, 138 196 Labeler: LabelerConfig{ 139 - Enabled: true, 140 - Addr: ":5002", 141 - OwnerDID: "did:plc:your-did-here", 142 - DBPath: "/var/lib/atcr-labeler/labeler.db", 197 + Enabled: true, 198 + Addr: ":5002", 199 + OwnerDID: "did:plc:your-did-here", 200 + DataDir: "/var/lib/atcr-labeler", 201 + DIDMethod: "plc", 202 + PLCDirectoryURL: "https://plc.directory", 143 203 }, 144 204 } 145 205 return config.MarshalCommentedYAML("ATCR Labeler Configuration", cfg)
-24
pkg/labeler/config_test.go
··· 25 25 }) 26 26 } 27 27 } 28 - 29 - func TestConfig_DID(t *testing.T) { 30 - tests := []struct { 31 - name string 32 - baseURL string 33 - want string 34 - }{ 35 - {"standard", "https://atcr.io", "did:web:labeler.atcr.io"}, 36 - {"with port", "https://atcr.io:8080", "did:web:labeler.atcr.io%3A8080"}, 37 - {"localhost", "http://localhost:5000", "did:web:labeler.localhost%3A5000"}, 38 - } 39 - 40 - for _, tt := range tests { 41 - t.Run(tt.name, func(t *testing.T) { 42 - cfg := &Config{ 43 - Server: AppviewServerConfig{BaseURL: tt.baseURL}, 44 - } 45 - got := cfg.DID() 46 - if got != tt.want { 47 - t.Errorf("DID() = %q, want %q", got, tt.want) 48 - } 49 - }) 50 - } 51 - }
+240 -86
pkg/labeler/db.go
··· 3 3 import ( 4 4 "database/sql" 5 5 "fmt" 6 + "io" 7 + "log/slog" 6 8 "os" 7 9 "path/filepath" 10 + "strings" 8 11 "time" 9 12 10 - _ "github.com/tursodatabase/go-libsql" 13 + "github.com/bluesky-social/indigo/atproto/atcrypto" 14 + "github.com/bluesky-social/indigo/atproto/labeling" 15 + "github.com/tursodatabase/go-libsql" 11 16 ) 17 + 18 + // LabelVersion is the ATProto label format version (currently 1). 19 + const LabelVersion int64 = labeling.ATPROTO_LABEL_VERSION 12 20 13 21 const schema = ` 14 22 CREATE TABLE IF NOT EXISTS labels ( ··· 20 28 neg BOOLEAN NOT NULL DEFAULT 0, 21 29 cts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 22 30 exp TIMESTAMP, 31 + ver INTEGER NOT NULL DEFAULT 1, 32 + sig BLOB NOT NULL, 23 33 subject_did TEXT NOT NULL, 24 - subject_repo TEXT NOT NULL DEFAULT '', 25 - UNIQUE(src, uri, val, neg) 34 + subject_repo TEXT NOT NULL DEFAULT '' 26 35 ); 27 36 CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo); 28 37 CREATE INDEX IF NOT EXISTS idx_labels_cts ON labels(cts DESC); 38 + CREATE INDEX IF NOT EXISTS idx_labels_uri ON labels(uri); 29 39 ` 30 40 31 - // Label represents an ATProto label (com.atproto.label.defs#label). 41 + // Label represents an ATProto label record stored locally. Its on-the-wire representation 42 + // is produced by ToLabeling() which round-trips through indigo's labeling package so the 43 + // signature stays valid byte-for-byte. 32 44 type Label struct { 33 45 ID int64 34 46 Src string ··· 38 50 Neg bool 39 51 Cts time.Time 40 52 Exp *time.Time 53 + Ver int64 54 + Sig []byte 41 55 SubjectDID string 42 56 SubjectRepo string 43 57 } 44 58 45 - // OpenDB opens or creates the labeler database. 46 - func OpenDB(dbPath string) (*sql.DB, error) { 59 + // LibsqlSync configures optional embedded-replica sync to a remote libSQL database. 60 + // SyncURL empty means local-only mode. 61 + type LibsqlSync struct { 62 + SyncURL string 63 + AuthToken string 64 + SyncInterval time.Duration 65 + } 66 + 67 + // LabelerDB wraps the *sql.DB plus its libsql connector (when in embedded-replica mode) 68 + // so the caller can release file locks on shutdown. 69 + type LabelerDB struct { 70 + DB *sql.DB 71 + connector io.Closer 72 + } 73 + 74 + // Close closes the database and the libsql connector (if any). The connector close is 75 + // what releases file locks; without it a subsequent local-only open errors with 76 + // "database is locked" — the same gotcha the hold ran into. 77 + func (l *LabelerDB) Close() error { 78 + var dbErr, connErr error 79 + if l.DB != nil { 80 + dbErr = l.DB.Close() 81 + } 82 + if l.connector != nil { 83 + connErr = l.connector.Close() 84 + } 85 + if dbErr != nil { 86 + return dbErr 87 + } 88 + return connErr 89 + } 90 + 91 + // OpenDB opens or creates the labeler database. When sync.SyncURL is set, the DB runs 92 + // in embedded-replica mode (writes go to the remote, frames replicate to the local 93 + // file); otherwise it's a plain local libSQL file. Schema is applied either way. 94 + func OpenDB(dbPath string, sync LibsqlSync) (*LabelerDB, error) { 47 95 if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { 48 96 return nil, fmt.Errorf("failed to create db directory: %w", err) 49 97 } 50 98 51 - db, err := sql.Open("libsql", "file:"+dbPath) 52 - if err != nil { 53 - return nil, fmt.Errorf("failed to open database: %w", err) 99 + var ( 100 + db *sql.DB 101 + connector io.Closer 102 + ) 103 + 104 + if sync.SyncURL != "" { 105 + opts := []libsql.Option{libsql.WithAuthToken(sync.AuthToken)} 106 + if sync.SyncInterval > 0 { 107 + opts = append(opts, libsql.WithSyncInterval(sync.SyncInterval)) 108 + } 109 + conn, err := libsql.NewEmbeddedReplicaConnector(dbPath, sync.SyncURL, opts...) 110 + if err != nil { 111 + return nil, fmt.Errorf("failed to create libsql embedded replica connector: %w", err) 112 + } 113 + db = sql.OpenDB(conn) 114 + connector = conn 115 + slog.Info("Labeler database opened in embedded replica mode", "path", dbPath, "sync_url", sync.SyncURL) 116 + } else { 117 + dsn := dbPath 118 + if !strings.HasPrefix(dsn, "file:") && !strings.HasPrefix(dsn, ":memory:") { 119 + dsn = "file:" + dsn 120 + } 121 + var err error 122 + db, err = sql.Open("libsql", dsn) 123 + if err != nil { 124 + return nil, fmt.Errorf("failed to open database: %w", err) 125 + } 126 + slog.Info("Labeler database opened in local-only mode", "path", dbPath) 127 + } 128 + 129 + // Local PRAGMAs only — Bunny rejects PRAGMA assignments forwarded over the 130 + // replication protocol (same caveat as pkg/hold/db). 131 + if sync.SyncURL == "" { 132 + var journalMode string 133 + if err := db.QueryRow("PRAGMA journal_mode = WAL").Scan(&journalMode); err != nil { 134 + _ = closeIfNonNil(db, connector) 135 + return nil, fmt.Errorf("failed to set journal mode: %w", err) 136 + } 137 + var busyTimeout int 138 + if err := db.QueryRow("PRAGMA busy_timeout = 5000").Scan(&busyTimeout); err != nil { 139 + _ = closeIfNonNil(db, connector) 140 + return nil, fmt.Errorf("failed to set busy_timeout: %w", err) 141 + } 54 142 } 55 143 56 - // Apply schema 57 144 for _, stmt := range splitStatements(schema) { 58 145 if _, err := db.Exec(stmt); err != nil { 146 + _ = closeIfNonNil(db, connector) 59 147 return nil, fmt.Errorf("failed to apply schema: %w", err) 60 148 } 61 149 } 150 + return &LabelerDB{DB: db, connector: connector}, nil 151 + } 62 152 63 - return db, nil 153 + // closeIfNonNil is the defensive cleanup for the failure path on OpenDB so we don't 154 + // leave file locks dangling if schema application fails. 155 + func closeIfNonNil(db *sql.DB, connector io.Closer) error { 156 + if db != nil { 157 + _ = db.Close() 158 + } 159 + if connector != nil { 160 + return connector.Close() 161 + } 162 + return nil 64 163 } 65 164 66 - // splitStatements splits SQL by semicolons (go-libsql doesn't support multi-statement exec). 67 165 func splitStatements(sql string) []string { 68 - var stmts []string 69 - for _, s := range splitOnSemicolon(sql) { 70 - s = trimSpace(s) 166 + parts := strings.Split(sql, ";") 167 + out := make([]string, 0, len(parts)) 168 + for _, s := range parts { 169 + s = strings.TrimSpace(s) 71 170 if s != "" { 72 - stmts = append(stmts, s) 171 + out = append(out, s) 73 172 } 74 173 } 75 - return stmts 174 + return out 76 175 } 77 176 78 - func splitOnSemicolon(s string) []string { 79 - var parts []string 80 - start := 0 81 - for i := 0; i < len(s); i++ { 82 - if s[i] == ';' { 83 - parts = append(parts, s[start:i]) 84 - start = i + 1 85 - } 177 + // ToLabeling converts the row into indigo's label struct (deterministic CBOR shape). 178 + func (l *Label) ToLabeling() labeling.Label { 179 + out := labeling.Label{ 180 + CreatedAt: l.Cts.UTC().Format(time.RFC3339), 181 + SourceDID: l.Src, 182 + URI: l.URI, 183 + Val: l.Val, 184 + Version: l.Ver, 86 185 } 87 - if start < len(s) { 88 - parts = append(parts, s[start:]) 186 + if l.CID != "" { 187 + s := l.CID 188 + out.CID = &s 89 189 } 90 - return parts 190 + if l.Exp != nil { 191 + s := l.Exp.UTC().Format(time.RFC3339) 192 + out.ExpiresAt = &s 193 + } 194 + if l.Neg { 195 + t := true 196 + out.Negated = &t 197 + } 198 + if len(l.Sig) > 0 { 199 + out.Sig = l.Sig 200 + } 201 + return out 91 202 } 92 203 93 - func trimSpace(s string) string { 94 - // Simple trim that handles newlines and spaces 95 - i := 0 96 - for i < len(s) && (s[i] == ' ' || s[i] == '\t' || s[i] == '\n' || s[i] == '\r') { 97 - i++ 204 + // Sign computes a k256 signature over the deterministic CBOR encoding of the label 205 + // (without the sig field) and stores it on the row. 206 + func (l *Label) Sign(key *atcrypto.PrivateKeyK256) error { 207 + if l.Ver == 0 { 208 + l.Ver = LabelVersion 209 + } 210 + if l.Cts.IsZero() { 211 + l.Cts = time.Now().UTC() 98 212 } 99 - j := len(s) 100 - for j > i && (s[j-1] == ' ' || s[j-1] == '\t' || s[j-1] == '\n' || s[j-1] == '\r') { 101 - j-- 213 + pre := l.ToLabeling() 214 + pre.Sig = nil 215 + if err := pre.Sign(key); err != nil { 216 + return fmt.Errorf("failed to sign label: %w", err) 102 217 } 103 - return s[i:j] 218 + l.Sig = pre.Sig 219 + return nil 104 220 } 105 221 106 - // CreateLabel inserts a new label into the database. 222 + // CreateLabel inserts a freshly signed label and returns its sequence id. 223 + // Caller must Sign() first — CreateLabel rejects rows missing a signature. 107 224 func CreateLabel(db *sql.DB, l *Label) (int64, error) { 225 + if len(l.Sig) == 0 { 226 + return 0, fmt.Errorf("refusing to insert unsigned label") 227 + } 228 + if l.Ver == 0 { 229 + l.Ver = LabelVersion 230 + } 231 + var expStr *string 232 + if l.Exp != nil { 233 + s := l.Exp.UTC().Format(time.RFC3339) 234 + expStr = &s 235 + } 108 236 result, err := db.Exec( 109 - `INSERT INTO labels (src, uri, cid, val, neg, cts, exp, subject_did, subject_repo) 110 - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 111 - ON CONFLICT(src, uri, val, neg) DO UPDATE SET cts = excluded.cts`, 112 - l.Src, l.URI, l.CID, l.Val, l.Neg, l.Cts.UTC().Format(time.RFC3339), l.Exp, 237 + `INSERT INTO labels (src, uri, cid, val, neg, cts, exp, ver, sig, subject_did, subject_repo) 238 + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 239 + l.Src, l.URI, nullableString(l.CID), l.Val, l.Neg, 240 + l.Cts.UTC().Format(time.RFC3339), expStr, l.Ver, l.Sig, 113 241 l.SubjectDID, l.SubjectRepo, 114 242 ) 115 243 if err != nil { 116 - return 0, fmt.Errorf("failed to create label: %w", err) 244 + return 0, fmt.Errorf("failed to insert label: %w", err) 245 + } 246 + id, err := result.LastInsertId() 247 + if err != nil { 248 + return 0, err 117 249 } 118 - return result.LastInsertId() 250 + l.ID = id 251 + return id, nil 119 252 } 120 253 121 - // NegateLabel creates a negation label to reverse a previous label. 122 - func NegateLabel(db *sql.DB, src, uri, val string, subjectDID, subjectRepo string) error { 123 - _, err := db.Exec( 124 - `INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo) 125 - VALUES (?, ?, ?, 1, ?, ?, ?)`, 126 - src, uri, val, time.Now().UTC().Format(time.RFC3339), subjectDID, subjectRepo, 127 - ) 128 - return err 254 + func nullableString(s string) any { 255 + if s == "" { 256 + return nil 257 + } 258 + return s 129 259 } 130 260 131 261 // GetLabelsSince returns labels with id > cursor, ordered by id ascending. 132 262 func GetLabelsSince(db *sql.DB, cursor int64, limit int) ([]Label, error) { 133 263 rows, err := db.Query( 134 - `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, subject_did, subject_repo 264 + `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo 135 265 FROM labels WHERE id > ? ORDER BY id ASC LIMIT ?`, 136 266 cursor, limit, 137 267 ) ··· 139 269 return nil, err 140 270 } 141 271 defer rows.Close() 272 + return scanLabels(rows) 273 + } 142 274 143 - return scanLabels(rows) 275 + // LatestSeq returns the highest sequence id in the database, or 0 if empty. 276 + func LatestSeq(db *sql.DB) (int64, error) { 277 + var seq sql.NullInt64 278 + if err := db.QueryRow(`SELECT MAX(id) FROM labels`).Scan(&seq); err != nil { 279 + return 0, err 280 + } 281 + if !seq.Valid { 282 + return 0, nil 283 + } 284 + return seq.Int64, nil 144 285 } 145 286 146 287 // ListActiveTakedowns returns active (non-negated) takedown labels. ··· 161 302 } 162 303 163 304 rows, err := db.Query( 164 - `SELECT l1.id, l1.src, l1.uri, COALESCE(l1.cid, ''), l1.val, l1.neg, l1.cts, l1.exp, l1.subject_did, l1.subject_repo 305 + `SELECT l1.id, l1.src, l1.uri, COALESCE(l1.cid, ''), l1.val, l1.neg, l1.cts, l1.exp, l1.ver, l1.sig, l1.subject_did, l1.subject_repo 165 306 FROM labels l1 166 307 WHERE l1.val = '!takedown' AND l1.neg = 0 167 308 AND NOT EXISTS ( ··· 182 323 return labels, total, err 183 324 } 184 325 185 - // GetLabelsForRepo returns all active labels for a specific DID + repository. 326 + // GetLabelsForRepo returns all labels for a specific DID + repository. 186 327 func GetLabelsForRepo(db *sql.DB, did, repo string) ([]Label, error) { 187 328 rows, err := db.Query( 188 - `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, subject_did, subject_repo 329 + `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo 189 330 FROM labels 190 331 WHERE subject_did = ? AND subject_repo = ? 191 332 ORDER BY cts DESC`, ··· 198 339 return scanLabels(rows) 199 340 } 200 341 201 - // NegateRepoLabels creates negation labels for all active takedown labels on a (DID, repo) pair. 202 - func NegateRepoLabels(db *sql.DB, src, did, repo string) error { 342 + // newNegationLabel constructs an unsigned negation label awaiting Sign(). 343 + func newNegationLabel(src, uri, val, did, repo string) *Label { 344 + return &Label{ 345 + Src: src, 346 + URI: uri, 347 + Val: val, 348 + Neg: true, 349 + Cts: time.Now().UTC(), 350 + SubjectDID: did, 351 + SubjectRepo: repo, 352 + } 353 + } 354 + 355 + // NegateRepoLabels signs+inserts negation labels for all active takedown labels on (DID, repo). 356 + func NegateRepoLabels(db *sql.DB, key *atcrypto.PrivateKeyK256, src, did, repo string) ([]Label, error) { 203 357 rows, err := db.Query( 204 358 `SELECT uri FROM labels 205 359 WHERE subject_did = ? AND subject_repo = ? AND val = '!takedown' AND neg = 0`, 206 360 did, repo, 207 361 ) 208 362 if err != nil { 209 - return err 363 + return nil, err 210 364 } 211 - 212 365 var uris []string 213 366 for rows.Next() { 214 367 var uri string 215 368 if err := rows.Scan(&uri); err != nil { 216 369 rows.Close() 217 - return err 370 + return nil, err 218 371 } 219 372 uris = append(uris, uri) 220 373 } 221 374 rows.Close() 222 375 if err := rows.Err(); err != nil { 223 - return err 376 + return nil, err 224 377 } 225 378 226 - now := time.Now().UTC().Format(time.RFC3339) 379 + out := make([]Label, 0, len(uris)) 227 380 for _, uri := range uris { 228 - if _, err := db.Exec( 229 - `INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo) 230 - VALUES (?, ?, '!takedown', 1, ?, ?, ?)`, 231 - src, uri, now, did, repo, 232 - ); err != nil { 233 - return err 381 + neg := newNegationLabel(src, uri, "!takedown", did, repo) 382 + if err := neg.Sign(key); err != nil { 383 + return out, err 384 + } 385 + if _, err := CreateLabel(db, neg); err != nil { 386 + return out, err 234 387 } 388 + out = append(out, *neg) 235 389 } 236 - return nil 390 + return out, nil 237 391 } 238 392 239 - // NegateUserLabels creates negation labels for all active takedown labels on a DID (user-level). 240 - func NegateUserLabels(db *sql.DB, src, did string) error { 393 + // NegateUserLabels signs+inserts negation labels for all active takedown labels on a DID. 394 + func NegateUserLabels(db *sql.DB, key *atcrypto.PrivateKeyK256, src, did string) ([]Label, error) { 241 395 rows, err := db.Query( 242 396 `SELECT uri, subject_repo FROM labels 243 397 WHERE subject_did = ? AND val = '!takedown' AND neg = 0`, 244 398 did, 245 399 ) 246 400 if err != nil { 247 - return err 401 + return nil, err 248 402 } 249 - 250 403 type uriRepo struct { 251 404 uri string 252 405 repo string ··· 256 409 var e uriRepo 257 410 if err := rows.Scan(&e.uri, &e.repo); err != nil { 258 411 rows.Close() 259 - return err 412 + return nil, err 260 413 } 261 414 entries = append(entries, e) 262 415 } 263 416 rows.Close() 264 417 if err := rows.Err(); err != nil { 265 - return err 418 + return nil, err 266 419 } 267 420 268 - now := time.Now().UTC().Format(time.RFC3339) 421 + out := make([]Label, 0, len(entries)) 269 422 for _, e := range entries { 270 - if _, err := db.Exec( 271 - `INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo) 272 - VALUES (?, ?, '!takedown', 1, ?, ?, ?)`, 273 - src, e.uri, now, did, e.repo, 274 - ); err != nil { 275 - return err 423 + neg := newNegationLabel(src, e.uri, "!takedown", did, e.repo) 424 + if err := neg.Sign(key); err != nil { 425 + return out, err 426 + } 427 + if _, err := CreateLabel(db, neg); err != nil { 428 + return out, err 276 429 } 430 + out = append(out, *neg) 277 431 } 278 - return nil 432 + return out, nil 279 433 } 280 434 281 435 func scanLabels(rows *sql.Rows) ([]Label, error) { ··· 284 438 var l Label 285 439 var cts string 286 440 var exp *string 287 - if err := rows.Scan(&l.ID, &l.Src, &l.URI, &l.CID, &l.Val, &l.Neg, &cts, &exp, &l.SubjectDID, &l.SubjectRepo); err != nil { 441 + if err := rows.Scan(&l.ID, &l.Src, &l.URI, &l.CID, &l.Val, &l.Neg, &cts, &exp, &l.Ver, &l.Sig, &l.SubjectDID, &l.SubjectRepo); err != nil { 288 442 return nil, err 289 443 } 290 444 if t, err := time.Parse(time.RFC3339, cts); err == nil {
+141 -180
pkg/labeler/db_test.go
··· 1 1 package labeler 2 2 3 3 import ( 4 + "database/sql" 4 5 "os" 5 6 "path/filepath" 6 7 "testing" 7 8 "time" 9 + 10 + "github.com/bluesky-social/indigo/atproto/atcrypto" 8 11 ) 9 12 13 + func newTestKey(t *testing.T) *atcrypto.PrivateKeyK256 { 14 + t.Helper() 15 + k, err := atcrypto.GeneratePrivateKeyK256() 16 + if err != nil { 17 + t.Fatalf("generate key: %v", err) 18 + } 19 + return k 20 + } 21 + 22 + // openTestDB opens a fresh local-only labeler DB and registers cleanup. Returns the 23 + // raw *sql.DB so existing tests can keep using it; the wrapper lifecycle is handled 24 + // here so tests don't have to know about the embedded-replica machinery. 25 + func openTestDB(t *testing.T, path string) *sql.DB { 26 + t.Helper() 27 + storage, err := OpenDB(path, LibsqlSync{}) 28 + if err != nil { 29 + t.Fatalf("OpenDB: %v", err) 30 + } 31 + t.Cleanup(func() { _ = storage.Close() }) 32 + return storage.DB 33 + } 34 + 35 + // signAndCreate is a helper that signs the label and inserts it; it returns the row id. 36 + func signAndCreate(t *testing.T, db *sql.DB, key *atcrypto.PrivateKeyK256, l *Label) int64 { 37 + t.Helper() 38 + if err := l.Sign(key); err != nil { 39 + t.Fatalf("sign: %v", err) 40 + } 41 + id, err := CreateLabel(db, l) 42 + if err != nil { 43 + t.Fatalf("create: %v", err) 44 + } 45 + return id 46 + } 47 + 10 48 func TestOpenDB(t *testing.T) { 11 49 dir := t.TempDir() 12 50 dbPath := filepath.Join(dir, "subdir", "test.db") 13 51 14 - db, err := OpenDB(dbPath) 52 + storage, err := OpenDB(dbPath, LibsqlSync{}) 15 53 if err != nil { 16 54 t.Fatalf("OpenDB failed: %v", err) 17 55 } 18 - defer db.Close() 56 + defer storage.Close() 19 57 20 - // Verify directory was created 21 58 if _, err := os.Stat(filepath.Dir(dbPath)); os.IsNotExist(err) { 22 59 t.Error("expected directory to be created") 23 60 } 24 61 25 - // Verify tables exist 26 62 var count int 27 - err = db.QueryRow("SELECT COUNT(*) FROM labels").Scan(&count) 28 - if err != nil { 63 + if err := storage.DB.QueryRow("SELECT COUNT(*) FROM labels").Scan(&count); err != nil { 29 64 t.Fatalf("failed to query labels table: %v", err) 30 65 } 31 66 if count != 0 { ··· 35 70 36 71 func TestCreateLabel(t *testing.T) { 37 72 dir := t.TempDir() 38 - db, err := OpenDB(filepath.Join(dir, "test.db")) 39 - if err != nil { 40 - t.Fatal(err) 41 - } 42 - defer db.Close() 73 + db := openTestDB(t, filepath.Join(dir, "test.db")) 74 + key := newTestKey(t) 43 75 44 76 now := time.Now().UTC().Truncate(time.Second) 45 77 label := &Label{ 46 - Src: "did:web:labeler.atcr.io", 78 + Src: "did:plc:labeler-1", 47 79 URI: "at://did:plc:abc/io.atcr.manifest/sha256-123", 48 80 Val: "!takedown", 49 81 Cts: now, 50 82 SubjectDID: "did:plc:abc", 51 83 SubjectRepo: "myimage", 52 84 } 53 - 54 - id, err := CreateLabel(db, label) 55 - if err != nil { 56 - t.Fatalf("CreateLabel failed: %v", err) 57 - } 85 + id := signAndCreate(t, db, key, label) 58 86 if id <= 0 { 59 87 t.Errorf("expected positive id, got %d", id) 60 88 } 89 + if len(label.Sig) == 0 { 90 + t.Error("expected signature populated by Sign()") 91 + } 61 92 62 - // Verify it was stored 63 93 labels, err := GetLabelsSince(db, 0, 10) 64 94 if err != nil { 65 95 t.Fatal(err) ··· 67 97 if len(labels) != 1 { 68 98 t.Fatalf("expected 1 label, got %d", len(labels)) 69 99 } 70 - if labels[0].Src != "did:web:labeler.atcr.io" { 71 - t.Errorf("expected src did:web:labeler.atcr.io, got %s", labels[0].Src) 72 - } 73 - if labels[0].Val != "!takedown" { 74 - t.Errorf("expected val !takedown, got %s", labels[0].Val) 100 + if labels[0].Src != label.Src { 101 + t.Errorf("Src = %s, want %s", labels[0].Src, label.Src) 75 102 } 76 103 if labels[0].SubjectDID != "did:plc:abc" { 77 - t.Errorf("expected subject_did did:plc:abc, got %s", labels[0].SubjectDID) 104 + t.Errorf("SubjectDID = %s", labels[0].SubjectDID) 78 105 } 79 106 if labels[0].SubjectRepo != "myimage" { 80 - t.Errorf("expected subject_repo myimage, got %s", labels[0].SubjectRepo) 107 + t.Errorf("SubjectRepo = %s", labels[0].SubjectRepo) 108 + } 109 + if labels[0].Ver != LabelVersion { 110 + t.Errorf("Ver = %d, want %d", labels[0].Ver, LabelVersion) 111 + } 112 + if len(labels[0].Sig) == 0 { 113 + t.Error("expected stored sig to be populated") 81 114 } 82 115 } 83 116 84 - func TestCreateLabel_Upsert(t *testing.T) { 117 + func TestCreateLabel_RejectsUnsigned(t *testing.T) { 85 118 dir := t.TempDir() 86 - db, err := OpenDB(filepath.Join(dir, "test.db")) 87 - if err != nil { 88 - t.Fatal(err) 89 - } 90 - defer db.Close() 119 + db := openTestDB(t, filepath.Join(dir, "test.db")) 91 120 92 - now := time.Now().UTC() 93 121 label := &Label{ 94 - Src: "did:web:labeler.atcr.io", 95 - URI: "at://did:plc:abc/io.atcr.manifest/sha256-123", 96 - Val: "!takedown", 97 - Cts: now, 98 - SubjectDID: "did:plc:abc", 99 - SubjectRepo: "myimage", 100 - } 101 - 102 - // First insert 103 - _, err = CreateLabel(db, label) 104 - if err != nil { 105 - t.Fatal(err) 106 - } 107 - 108 - // Same (src, uri, val) - should upsert, not error 109 - label.Cts = now.Add(time.Hour) 110 - _, err = CreateLabel(db, label) 111 - if err != nil { 112 - t.Fatalf("upsert should not fail: %v", err) 113 - } 114 - 115 - // Should still be 1 label 116 - labels, err := GetLabelsSince(db, 0, 10) 117 - if err != nil { 118 - t.Fatal(err) 122 + Src: "did:plc:labeler-1", URI: "at://did:plc:abc", 123 + Val: "!takedown", Cts: time.Now().UTC(), 124 + SubjectDID: "did:plc:abc", 119 125 } 120 - if len(labels) != 1 { 121 - t.Errorf("expected 1 label after upsert, got %d", len(labels)) 126 + if _, err := CreateLabel(db, label); err == nil { 127 + t.Fatal("expected CreateLabel to reject an unsigned label") 122 128 } 123 129 } 124 130 125 - func TestNegateLabel(t *testing.T) { 126 - dir := t.TempDir() 127 - db, err := OpenDB(filepath.Join(dir, "test.db")) 128 - if err != nil { 129 - t.Fatal(err) 131 + func TestSignAndVerify(t *testing.T) { 132 + key := newTestKey(t) 133 + label := &Label{ 134 + Src: "did:plc:labeler-1", 135 + URI: "at://did:plc:abc", 136 + Val: "!takedown", 137 + Cts: time.Now().UTC(), 138 + Ver: LabelVersion, 130 139 } 131 - defer db.Close() 132 - 133 - src := "did:web:labeler.atcr.io" 134 - now := time.Now().UTC() 135 - 136 - // Create a label 137 - _, err = CreateLabel(db, &Label{ 138 - Src: src, URI: "at://did:plc:abc/io.atcr.manifest/sha256-123", 139 - Val: "!takedown", Cts: now, 140 - SubjectDID: "did:plc:abc", SubjectRepo: "myimage", 141 - }) 142 - if err != nil { 140 + if err := label.Sign(key); err != nil { 143 141 t.Fatal(err) 144 142 } 145 143 146 - // Negate it 147 - err = NegateLabel(db, src, "at://did:plc:abc/io.atcr.manifest/sha256-123", "!takedown", "did:plc:abc", "myimage") 148 - if err != nil { 149 - t.Fatalf("NegateLabel failed: %v", err) 150 - } 151 - 152 - // Should have 2 labels now (original + negation) 153 - labels, err := GetLabelsSince(db, 0, 10) 144 + pub, err := key.PublicKey() 154 145 if err != nil { 155 146 t.Fatal(err) 156 147 } 157 - if len(labels) != 2 { 158 - t.Fatalf("expected 2 labels, got %d", len(labels)) 159 - } 160 - 161 - // The negation label should have neg=true 162 - negLabel := labels[1] 163 - if !negLabel.Neg { 164 - t.Error("expected negation label to have neg=true") 148 + wire := label.ToLabeling() 149 + if err := wire.VerifySignature(pub); err != nil { 150 + t.Fatalf("signature did not verify: %v", err) 165 151 } 166 152 } 167 153 168 154 func TestListActiveTakedowns(t *testing.T) { 169 155 dir := t.TempDir() 170 - db, err := OpenDB(filepath.Join(dir, "test.db")) 171 - if err != nil { 172 - t.Fatal(err) 173 - } 174 - defer db.Close() 156 + db := openTestDB(t, filepath.Join(dir, "test.db")) 157 + key := newTestKey(t) 175 158 176 - src := "did:web:labeler.atcr.io" 159 + src := "did:plc:labeler-1" 177 160 now := time.Now().UTC() 178 161 179 - // Create 3 labels 180 162 for i, repo := range []string{"repo1", "repo2", "repo3"} { 181 - _, err = CreateLabel(db, &Label{ 163 + signAndCreate(t, db, key, &Label{ 182 164 Src: src, URI: "at://did:plc:abc/io.atcr.repo/" + repo, 183 165 Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute), 184 166 SubjectDID: "did:plc:abc", SubjectRepo: repo, 185 167 }) 186 - if err != nil { 187 - t.Fatal(err) 188 - } 189 168 } 190 169 191 - // All 3 should be active 192 170 labels, total, err := ListActiveTakedowns(db, 10, 0) 193 171 if err != nil { 194 172 t.Fatal(err) 195 173 } 196 - if total != 3 { 197 - t.Errorf("expected 3 active takedowns, got %d", total) 198 - } 199 - if len(labels) != 3 { 200 - t.Errorf("expected 3 labels returned, got %d", len(labels)) 174 + if total != 3 || len(labels) != 3 { 175 + t.Errorf("expected 3 active takedowns, got total=%d returned=%d", total, len(labels)) 201 176 } 202 177 203 - // Negate one 204 - err = NegateLabel(db, src, "at://did:plc:abc/io.atcr.repo/repo2", "!takedown", "did:plc:abc", "repo2") 205 - if err != nil { 178 + if _, err := NegateRepoLabels(db, key, src, "did:plc:abc", "repo2"); err != nil { 206 179 t.Fatal(err) 207 180 } 208 181 209 - // Should be 2 active 210 182 _, total, err = ListActiveTakedowns(db, 10, 0) 211 183 if err != nil { 212 184 t.Fatal(err) ··· 218 190 219 191 func TestNegateRepoLabels(t *testing.T) { 220 192 dir := t.TempDir() 221 - db, err := OpenDB(filepath.Join(dir, "test.db")) 222 - if err != nil { 223 - t.Fatal(err) 224 - } 225 - defer db.Close() 193 + db := openTestDB(t, filepath.Join(dir, "test.db")) 194 + key := newTestKey(t) 226 195 227 - src := "did:web:labeler.atcr.io" 196 + src := "did:plc:labeler-1" 228 197 now := time.Now().UTC() 229 198 did := "did:plc:abc" 230 199 231 - // Create multiple labels for same repo 232 200 uris := []string{ 233 201 "at://did:plc:abc/io.atcr.manifest/sha256-111", 234 202 "at://did:plc:abc/io.atcr.manifest/sha256-222", 235 203 "at://did:plc:abc/io.atcr.tag/myimage-latest", 236 204 } 237 205 for _, uri := range uris { 238 - _, err = CreateLabel(db, &Label{ 206 + signAndCreate(t, db, key, &Label{ 239 207 Src: src, URI: uri, Val: "!takedown", Cts: now, 240 208 SubjectDID: did, SubjectRepo: "myimage", 241 209 }) 242 - if err != nil { 243 - t.Fatal(err) 244 - } 245 210 } 246 211 247 - // Negate all labels for the repo 248 - err = NegateRepoLabels(db, src, did, "myimage") 212 + negs, err := NegateRepoLabels(db, key, src, did, "myimage") 249 213 if err != nil { 250 214 t.Fatal(err) 251 215 } 216 + if len(negs) != len(uris) { 217 + t.Errorf("expected %d negation labels, got %d", len(uris), len(negs)) 218 + } 252 219 253 - // Should have 0 active takedowns 254 220 _, total, err := ListActiveTakedowns(db, 10, 0) 255 221 if err != nil { 256 222 t.Fatal(err) ··· 262 228 263 229 func TestNegateUserLabels(t *testing.T) { 264 230 dir := t.TempDir() 265 - db, err := OpenDB(filepath.Join(dir, "test.db")) 266 - if err != nil { 267 - t.Fatal(err) 268 - } 269 - defer db.Close() 231 + db := openTestDB(t, filepath.Join(dir, "test.db")) 232 + key := newTestKey(t) 270 233 271 - src := "did:web:labeler.atcr.io" 234 + src := "did:plc:labeler-1" 272 235 now := time.Now().UTC() 273 236 did := "did:plc:abc" 274 237 275 - // Create labels for different repos + a user-level label 276 - _, err = CreateLabel(db, &Label{ 238 + signAndCreate(t, db, key, &Label{ 277 239 Src: src, URI: "at://did:plc:abc", Val: "!takedown", Cts: now, 278 - SubjectDID: did, SubjectRepo: "", 240 + SubjectDID: did, 279 241 }) 280 - if err != nil { 281 - t.Fatal(err) 282 - } 283 - _, err = CreateLabel(db, &Label{ 242 + signAndCreate(t, db, key, &Label{ 284 243 Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1", Val: "!takedown", Cts: now, 285 244 SubjectDID: did, SubjectRepo: "repo1", 286 245 }) 287 - if err != nil { 288 - t.Fatal(err) 289 - } 290 246 291 - // Negate all labels for the user 292 - err = NegateUserLabels(db, src, did) 247 + negs, err := NegateUserLabels(db, key, src, did) 293 248 if err != nil { 294 249 t.Fatal(err) 295 250 } 251 + if len(negs) != 2 { 252 + t.Errorf("expected 2 negation labels, got %d", len(negs)) 253 + } 296 254 297 - // Should have 0 active 298 255 _, total, err := ListActiveTakedowns(db, 10, 0) 299 256 if err != nil { 300 257 t.Fatal(err) ··· 306 263 307 264 func TestGetLabelsSince(t *testing.T) { 308 265 dir := t.TempDir() 309 - db, err := OpenDB(filepath.Join(dir, "test.db")) 310 - if err != nil { 311 - t.Fatal(err) 312 - } 313 - defer db.Close() 266 + db := openTestDB(t, filepath.Join(dir, "test.db")) 267 + key := newTestKey(t) 314 268 315 - src := "did:web:labeler.atcr.io" 269 + src := "did:plc:labeler-1" 316 270 now := time.Now().UTC() 317 271 318 - // Create 5 labels 319 272 for i := 0; i < 5; i++ { 320 - _, err = CreateLabel(db, &Label{ 273 + signAndCreate(t, db, key, &Label{ 321 274 Src: src, URI: "at://did:plc:abc/io.atcr.manifest/" + string(rune('a'+i)), 322 275 Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute), 323 276 SubjectDID: "did:plc:abc", SubjectRepo: "repo", 324 277 }) 325 - if err != nil { 326 - t.Fatal(err) 327 - } 328 278 } 329 279 330 - // Get all since 0 331 280 labels, err := GetLabelsSince(db, 0, 10) 332 281 if err != nil { 333 282 t.Fatal(err) ··· 336 285 t.Errorf("expected 5 labels, got %d", len(labels)) 337 286 } 338 287 339 - // Get since cursor (skip first 3) 340 - if len(labels) >= 3 { 341 - cursor := labels[2].ID 342 - after, err := GetLabelsSince(db, cursor, 10) 343 - if err != nil { 344 - t.Fatal(err) 345 - } 346 - if len(after) != 2 { 347 - t.Errorf("expected 2 labels after cursor %d, got %d", cursor, len(after)) 348 - } 288 + cursor := labels[2].ID 289 + after, err := GetLabelsSince(db, cursor, 10) 290 + if err != nil { 291 + t.Fatal(err) 292 + } 293 + if len(after) != 2 { 294 + t.Errorf("expected 2 labels after cursor %d, got %d", cursor, len(after)) 349 295 } 350 296 351 - // Get with limit 352 297 limited, err := GetLabelsSince(db, 0, 2) 353 298 if err != nil { 354 299 t.Fatal(err) ··· 358 303 } 359 304 } 360 305 361 - func TestGetLabelsForRepo(t *testing.T) { 306 + func TestLatestSeq(t *testing.T) { 362 307 dir := t.TempDir() 363 - db, err := OpenDB(filepath.Join(dir, "test.db")) 308 + db := openTestDB(t, filepath.Join(dir, "test.db")) 309 + key := newTestKey(t) 310 + 311 + if seq, err := LatestSeq(db); err != nil || seq != 0 { 312 + t.Fatalf("expected empty seq=0, got %d (err=%v)", seq, err) 313 + } 314 + 315 + id := signAndCreate(t, db, key, &Label{ 316 + Src: "did:plc:labeler-1", URI: "at://did:plc:abc", 317 + Val: "!takedown", Cts: time.Now().UTC(), 318 + SubjectDID: "did:plc:abc", 319 + }) 320 + seq, err := LatestSeq(db) 364 321 if err != nil { 365 322 t.Fatal(err) 366 323 } 367 - defer db.Close() 324 + if seq != id { 325 + t.Errorf("LatestSeq = %d, want %d", seq, id) 326 + } 327 + } 368 328 369 - src := "did:web:labeler.atcr.io" 329 + func TestGetLabelsForRepo(t *testing.T) { 330 + dir := t.TempDir() 331 + db := openTestDB(t, filepath.Join(dir, "test.db")) 332 + key := newTestKey(t) 333 + 334 + src := "did:plc:labeler-1" 370 335 now := time.Now().UTC() 371 336 372 - // Labels for different repos 373 - _, _ = CreateLabel(db, &Label{ 337 + signAndCreate(t, db, key, &Label{ 374 338 Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1", 375 339 Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo1", 376 340 }) 377 - _, _ = CreateLabel(db, &Label{ 341 + signAndCreate(t, db, key, &Label{ 378 342 Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo2", 379 343 Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo2", 380 344 }) 381 - _, _ = CreateLabel(db, &Label{ 345 + signAndCreate(t, db, key, &Label{ 382 346 Src: src, URI: "at://did:plc:def/io.atcr.repo/repo1", 383 347 Val: "!takedown", Cts: now, SubjectDID: "did:plc:def", SubjectRepo: "repo1", 384 348 }) 385 349 386 - // Get labels for specific did+repo 387 350 labels, err := GetLabelsForRepo(db, "did:plc:abc", "repo1") 388 351 if err != nil { 389 352 t.Fatal(err) ··· 392 355 t.Errorf("expected 1 label for did:plc:abc/repo1, got %d", len(labels)) 393 356 } 394 357 395 - // Different user same repo 396 358 labels, err = GetLabelsForRepo(db, "did:plc:def", "repo1") 397 359 if err != nil { 398 360 t.Fatal(err) ··· 401 363 t.Errorf("expected 1 label for did:plc:def/repo1, got %d", len(labels)) 402 364 } 403 365 404 - // No labels 405 366 labels, err = GetLabelsForRepo(db, "did:plc:xyz", "repo1") 406 367 if err != nil { 407 368 t.Fatal(err)
+3 -4
pkg/labeler/handlers.go
··· 13 13 // Auth handlers 14 14 15 15 func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { 16 - // If already logged in, redirect to dashboard 17 16 if token, ok := getSessionCookie(r); ok { 18 - if session := s.auth.getSession(token); session != nil && session.DID == s.config.Labeler.OwnerDID { 17 + if session := s.auth.GetSession(token); session != nil && session.DID == s.config.Labeler.OwnerDID { 19 18 http.Redirect(w, r, "/", http.StatusFound) 20 19 return 21 20 } ··· 99 98 return 100 99 } 101 100 102 - token, err := s.auth.createSession(did, handle) 101 + token, _, err := s.auth.CreateSession(did, handle, r.UserAgent(), clientIPPrefix(r)) 103 102 if err != nil { 104 103 http.Error(w, "Failed to create session", http.StatusInternalServerError) 105 104 return ··· 111 110 112 111 func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { 113 112 if token, ok := getSessionCookie(r); ok { 114 - s.auth.deleteSession(token) 113 + s.auth.DeleteSession(token) 115 114 } 116 115 clearSessionCookie(w) 117 116 http.Redirect(w, r, "/auth/login", http.StatusFound)
+75
pkg/labeler/hub.go
··· 1 + package labeler 2 + 3 + import ( 4 + "sync" 5 + ) 6 + 7 + // hubSubscriber is one connected subscribeLabels client. The hub fans out new labels 8 + // to each subscriber's bounded channel; if a slow client fills the buffer, the hub 9 + // drops them rather than blocking the writer. 10 + type hubSubscriber struct { 11 + ch chan *Label 12 + closed bool 13 + } 14 + 15 + // Hub broadcasts newly-inserted labels to all live subscribeLabels clients. 16 + type Hub struct { 17 + mu sync.Mutex 18 + subs map[*hubSubscriber]struct{} 19 + } 20 + 21 + // NewHub returns an empty hub ready to accept subscribers. 22 + func NewHub() *Hub { 23 + return &Hub{subs: make(map[*hubSubscriber]struct{})} 24 + } 25 + 26 + // subscribe registers a new subscriber and returns its event channel + a cancel func. 27 + // The buffer size bounds backpressure tolerance per client. 28 + func (h *Hub) subscribe(buffer int) (*hubSubscriber, func()) { 29 + s := &hubSubscriber{ch: make(chan *Label, buffer)} 30 + h.mu.Lock() 31 + h.subs[s] = struct{}{} 32 + h.mu.Unlock() 33 + return s, func() { h.unsubscribe(s) } 34 + } 35 + 36 + func (h *Hub) unsubscribe(s *hubSubscriber) { 37 + h.mu.Lock() 38 + defer h.mu.Unlock() 39 + if _, ok := h.subs[s]; !ok { 40 + return 41 + } 42 + delete(h.subs, s) 43 + if !s.closed { 44 + s.closed = true 45 + close(s.ch) 46 + } 47 + } 48 + 49 + // Broadcast sends a copy of the label to every live subscriber. Subscribers whose 50 + // buffer is full are evicted on the spot rather than slowing down the writer. 51 + func (h *Hub) Broadcast(l *Label) { 52 + if l == nil { 53 + return 54 + } 55 + h.mu.Lock() 56 + dead := make([]*hubSubscriber, 0) 57 + for s := range h.subs { 58 + select { 59 + case s.ch <- l: 60 + default: 61 + dead = append(dead, s) 62 + } 63 + } 64 + h.mu.Unlock() 65 + for _, s := range dead { 66 + h.unsubscribe(s) 67 + } 68 + } 69 + 70 + // Len returns the number of live subscribers (mostly for tests / metrics). 71 + func (h *Hub) Len() int { 72 + h.mu.Lock() 73 + defer h.mu.Unlock() 74 + return len(h.subs) 75 + }
+38 -21
pkg/labeler/identity.go
··· 1 1 package labeler 2 2 3 3 import ( 4 + "context" 4 5 "encoding/json" 5 6 "fmt" 6 7 "net/http" 8 + 9 + "atcr.io/pkg/atproto/did" 10 + "atcr.io/pkg/auth/oauth" 11 + "github.com/bluesky-social/indigo/atproto/atcrypto" 7 12 ) 8 13 9 - // DIDDocument represents a did:web DID document. 10 - type DIDDocument struct { 11 - Context []string `json:"@context"` 12 - ID string `json:"id"` 13 - Service []DIDService `json:"service,omitempty"` 14 + // labelerServices returns the service entries the labeler publishes in its DID document 15 + // and PLC operations: a single AtprotoLabeler endpoint at #atproto_labeler. 16 + func labelerServices(publicURL string) map[string]did.Service { 17 + return map[string]did.Service{ 18 + "atproto_labeler": {Type: "AtprotoLabeler", Endpoint: publicURL}, 19 + } 14 20 } 15 21 16 - // DIDService represents a service entry in a DID document. 17 - type DIDService struct { 18 - ID string `json:"id"` 19 - Type string `json:"type"` 20 - ServiceEndpoint string `json:"serviceEndpoint"` 22 + // LoadIdentity resolves the labeler's DID and loads its k256 signing key. 23 + // For did:plc this calls into the shared PLC package (loading or creating); for did:web 24 + // the DID is derived from PublicURL and the signing key is generated on disk if missing. 25 + func LoadIdentity(ctx context.Context, cfg *Config) (string, *atcrypto.PrivateKeyK256, error) { 26 + labelerDID, err := did.LoadOrCreate(ctx, did.Config{ 27 + Method: cfg.Labeler.DIDMethod, 28 + PublicURL: cfg.PublicURL(), 29 + DBPath: cfg.Labeler.DataDir, 30 + SigningKeyPath: cfg.SigningKeyPath(), 31 + RotationKey: cfg.Labeler.RotationKey, 32 + PLCDirectoryURL: cfg.PLCDirectoryURL(), 33 + DID: cfg.Labeler.DID, 34 + VerificationKeyName: "atproto_label", 35 + Services: labelerServices(cfg.PublicURL()), 36 + }) 37 + if err != nil { 38 + return "", nil, fmt.Errorf("labeler: failed to resolve DID: %w", err) 39 + } 40 + signingKey, err := oauth.GenerateOrLoadPDSKey(cfg.SigningKeyPath()) 41 + if err != nil { 42 + return "", nil, fmt.Errorf("labeler: failed to load signing key: %w", err) 43 + } 44 + return labelerDID, signingKey, nil 21 45 } 22 46 23 47 func (s *Server) handleDIDDocument(w http.ResponseWriter, r *http.Request) { 24 - doc := DIDDocument{ 25 - Context: []string{"https://www.w3.org/ns/did/v1"}, 26 - ID: s.config.DID(), 27 - Service: []DIDService{ 28 - { 29 - ID: "#atproto_labeler", 30 - Type: "AtprotoLabeler", 31 - ServiceEndpoint: s.config.PublicURL(), 32 - }, 33 - }, 48 + doc, err := did.BuildDIDDocument(s.did, s.config.PublicURL(), s.signingKey, "atproto_label", labelerServices(s.config.PublicURL())) 49 + if err != nil { 50 + http.Error(w, "failed to build DID document", http.StatusInternalServerError) 51 + return 34 52 } 35 - 36 53 w.Header().Set("Content-Type", "application/json") 37 54 _ = json.NewEncoder(w).Encode(doc) 38 55 }
+50 -16
pkg/labeler/server.go
··· 5 5 "database/sql" 6 6 "fmt" 7 7 "log/slog" 8 + "net" 8 9 "net/http" 9 10 "net/url" 10 11 "os" 11 12 "os/signal" 12 - "strings" 13 13 "syscall" 14 + "time" 14 15 15 16 "atcr.io/pkg/atproto" 17 + "github.com/bluesky-social/indigo/atproto/atcrypto" 16 18 indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth" 17 19 "github.com/go-chi/chi/v5" 18 20 ) 19 21 20 22 // Server is the labeler HTTP server. 21 23 type Server struct { 22 - config *Config 23 - db *sql.DB 24 - router chi.Router 25 - clientApp *indigooauth.ClientApp 26 - auth *Auth 24 + config *Config 25 + storage *LabelerDB 26 + db *sql.DB 27 + router chi.Router 28 + clientApp *indigooauth.ClientApp 29 + auth *Auth 30 + did string 31 + signingKey *atcrypto.PrivateKeyK256 32 + hub *Hub 27 33 } 28 34 29 35 // NewServer creates a new labeler server. 30 36 func NewServer(cfg *Config) (*Server, error) { 31 - db, err := OpenDB(cfg.Labeler.DBPath) 37 + storage, err := OpenDB(cfg.DBPath(), LibsqlSync{ 38 + SyncURL: cfg.Labeler.LibsqlSyncURL, 39 + AuthToken: cfg.Labeler.LibsqlAuthToken, 40 + SyncInterval: cfg.Labeler.LibsqlSyncInterval, 41 + }) 32 42 if err != nil { 33 43 return nil, fmt.Errorf("failed to open database: %w", err) 44 + } 45 + 46 + ctx := context.Background() 47 + did, signingKey, err := LoadIdentity(ctx, cfg) 48 + if err != nil { 49 + _ = storage.Close() 50 + return nil, err 34 51 } 35 52 36 53 publicURL := cfg.PublicURL() ··· 68 85 auth := NewAuth(cfg.Labeler.OwnerDID) 69 86 70 87 s := &Server{ 71 - config: cfg, 72 - db: db, 73 - clientApp: clientApp, 74 - auth: auth, 88 + config: cfg, 89 + storage: storage, 90 + db: storage.DB, 91 + clientApp: clientApp, 92 + auth: auth, 93 + did: did, 94 + signingKey: signingKey, 95 + hub: NewHub(), 75 96 } 76 97 77 98 s.setupRoutes() ··· 97 118 r.Get("/xrpc/com.atproto.label.subscribeLabels", s.handleSubscribeLabels) 98 119 r.Get("/xrpc/com.atproto.label.queryLabels", s.handleQueryLabels) 99 120 100 - // Protected routes (require owner) 121 + // Protected routes (require owner). CSRF is enforced for state-mutating 122 + // methods inside the same group, so it sees the session on the context. 101 123 r.Group(func(r chi.Router) { 102 124 r.Use(s.auth.RequireOwner) 125 + r.Use(s.auth.RequireCSRF) 103 126 104 127 r.Get("/", s.handleDashboard) 105 128 r.Get("/takedown", s.handleTakedownForm) ··· 115 138 slog.Info("Starting labeler service", 116 139 "addr", s.config.Labeler.Addr, 117 140 "public_url", s.config.PublicURL(), 118 - "did", s.config.DID(), 141 + "did", s.did, 119 142 "owner", s.config.Labeler.OwnerDID, 120 143 ) 121 144 ··· 140 163 } 141 164 case <-ctx.Done(): 142 165 slog.Info("Shutting down labeler service") 143 - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5000000000) // 5s 166 + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 144 167 defer cancel() 145 168 if err := srv.Shutdown(shutdownCtx); err != nil { 146 169 return fmt.Errorf("shutdown error: %w", err) 147 170 } 148 171 } 149 172 150 - s.db.Close() 173 + if err := s.storage.Close(); err != nil { 174 + slog.Warn("Error closing labeler database", "error", err) 175 + } 151 176 return nil 152 177 } 153 178 179 + // isLocalhost returns true when the host is reachable only from the local machine / 180 + // docker host — anything that an external PDS can't reach. Matches the hold's policy: 181 + // any IP literal counts (covers 127.0.0.1, 192.168.*, 172.16-31.*, 10.*, ::1, etc.) plus 182 + // the literal "localhost". When this is true, OAuth uses indigo's `NewLocalhostConfig` 183 + // which sets a `http://localhost`-form client_id that PDSes accept under the loopback 184 + // exception — so the PDS never has to fetch the client metadata URL we publish. 154 185 func isLocalhost(host string) bool { 155 - return host == "localhost" || host == "127.0.0.1" || strings.HasPrefix(host, "192.168.") 186 + if host == "localhost" { 187 + return true 188 + } 189 + return net.ParseIP(host) != nil 156 190 }
+30
pkg/labeler/server_test.go
··· 1 + package labeler 2 + 3 + import "testing" 4 + 5 + // TestIsLocalhost covers the OAuth-mode decision: any IP-literal host (including 6 + // docker-compose private addresses like 172.28.0.x and the standard 127.0.0.1) plus 7 + // the literal "localhost" routes through the loopback OAuth path so PDSes don't have 8 + // to fetch our published client metadata. Domain names go through the public-client 9 + // path and require the metadata endpoint to be reachable from the PDS. 10 + func TestIsLocalhost(t *testing.T) { 11 + tests := []struct { 12 + host string 13 + want bool 14 + }{ 15 + {"localhost", true}, 16 + {"127.0.0.1", true}, 17 + {"::1", true}, 18 + {"192.168.1.10", true}, 19 + {"172.28.0.4", true}, // docker-compose private network 20 + {"10.0.0.5", true}, // RFC 1918 21 + {"labeler.atcr.io", false}, 22 + {"labeler.example.com", false}, 23 + {"", false}, 24 + } 25 + for _, tt := range tests { 26 + if got := isLocalhost(tt.host); got != tt.want { 27 + t.Errorf("isLocalhost(%q) = %v, want %v", tt.host, got, tt.want) 28 + } 29 + } 30 + }
+244 -94
pkg/labeler/subscribe.go
··· 1 1 package labeler 2 2 3 3 import ( 4 + "bytes" 5 + "database/sql" 4 6 "encoding/json" 7 + "errors" 8 + "fmt" 5 9 "log/slog" 6 10 "net/http" 7 11 "strconv" 8 - "time" 12 + "strings" 9 13 14 + comatproto "github.com/bluesky-social/indigo/api/atproto" 15 + "github.com/bluesky-social/indigo/events" 10 16 "github.com/gorilla/websocket" 17 + cbg "github.com/whyrusleeping/cbor-gen" 18 + ) 19 + 20 + const ( 21 + subscriberBuffer = 64 22 + backfillPageLimit = 200 11 23 ) 12 24 13 25 var upgrader = websocket.Upgrader{ 26 + // CheckOrigin is permissive: the firehose is a public stream by design and ATProto 27 + // consumers are not browsers, so the same-origin policy doesn't apply to them anyway. 14 28 CheckOrigin: func(r *http.Request) bool { return true }, 15 29 } 16 30 17 - // LabelsMessage is the ATProto subscribeLabels wire format. 18 - type LabelsMessage struct { 19 - Seq int64 `json:"seq"` 20 - Labels []LabelOutput `json:"labels"` 31 + // frameLabels builds the binary frame for a labels event: CBOR-encoded 32 + // {op:1, t:"#labels"} header concatenated with CBOR-encoded {seq, labels:[...]} body. 33 + func frameLabels(seq int64, labels []*comatproto.LabelDefs_Label) ([]byte, error) { 34 + var buf bytes.Buffer 35 + w := cbg.NewCborWriter(&buf) 36 + 37 + header := events.EventHeader{Op: events.EvtKindMessage, MsgType: "#labels"} 38 + if err := header.MarshalCBOR(w); err != nil { 39 + return nil, fmt.Errorf("marshal header: %w", err) 40 + } 41 + body := comatproto.LabelSubscribeLabels_Labels{Seq: seq, Labels: labels} 42 + if err := body.MarshalCBOR(w); err != nil { 43 + return nil, fmt.Errorf("marshal body: %w", err) 44 + } 45 + return buf.Bytes(), nil 21 46 } 22 47 23 - // LabelOutput is the ATProto label format for subscribeLabels/queryLabels output. 24 - type LabelOutput struct { 25 - Src string `json:"src"` 26 - URI string `json:"uri"` 27 - CID string `json:"cid,omitempty"` 28 - Val string `json:"val"` 29 - Neg bool `json:"neg"` 30 - Cts string `json:"cts"` 31 - Exp string `json:"exp,omitempty"` 48 + // frameInfo builds the binary frame for an info event: header {op:1, t:"#info"} plus body. 49 + func frameInfo(name, message string) ([]byte, error) { 50 + var buf bytes.Buffer 51 + w := cbg.NewCborWriter(&buf) 52 + 53 + header := events.EventHeader{Op: events.EvtKindMessage, MsgType: "#info"} 54 + if err := header.MarshalCBOR(w); err != nil { 55 + return nil, err 56 + } 57 + body := comatproto.LabelSubscribeLabels_Info{Name: name} 58 + if message != "" { 59 + body.Message = &message 60 + } 61 + if err := body.MarshalCBOR(w); err != nil { 62 + return nil, err 63 + } 64 + return buf.Bytes(), nil 32 65 } 33 66 34 - func labelToOutput(l Label) LabelOutput { 35 - out := LabelOutput{ 36 - Src: l.Src, 37 - URI: l.URI, 38 - CID: l.CID, 39 - Val: l.Val, 40 - Neg: l.Neg, 41 - Cts: l.Cts.UTC().Format(time.RFC3339), 67 + // frameError builds an error frame: header {op:-1} plus {error, message}. 68 + func frameError(name, message string) ([]byte, error) { 69 + var buf bytes.Buffer 70 + w := cbg.NewCborWriter(&buf) 71 + 72 + header := events.EventHeader{Op: events.EvtKindErrorFrame} 73 + if err := header.MarshalCBOR(w); err != nil { 74 + return nil, err 42 75 } 43 - if l.Exp != nil { 44 - out.Exp = l.Exp.UTC().Format(time.RFC3339) 76 + body := events.ErrorFrame{Error: name, Message: message} 77 + if err := body.MarshalCBOR(w); err != nil { 78 + return nil, err 45 79 } 46 - return out 80 + return buf.Bytes(), nil 81 + } 82 + 83 + // labelToLexicon converts a stored row into the indigo lexicon type used in the wire format. 84 + func labelToLexicon(l *Label) *comatproto.LabelDefs_Label { 85 + tmp := l.ToLabeling() 86 + lex := tmp.ToLexicon() 87 + return &lex 47 88 } 48 89 49 - // handleSubscribeLabels implements com.atproto.label.subscribeLabels (WebSocket). 90 + // handleSubscribeLabels implements com.atproto.label.subscribeLabels. 91 + // 92 + // Wire format: each WebSocket binary message is two concatenated CBOR objects (header 93 + // + body) matching the firehose convention. Backfill pages historical labels since the 94 + // cursor, then the connection joins the broadcast hub for live deliveries. 50 95 func (s *Server) handleSubscribeLabels(w http.ResponseWriter, r *http.Request) { 51 96 cursorStr := r.URL.Query().Get("cursor") 52 97 var cursor int64 53 98 if cursorStr != "" { 54 - var err error 55 - cursor, err = strconv.ParseInt(cursorStr, 10, 64) 99 + v, err := strconv.ParseInt(cursorStr, 10, 64) 56 100 if err != nil { 57 101 http.Error(w, "invalid cursor", http.StatusBadRequest) 58 102 return 59 103 } 104 + cursor = v 60 105 } 61 106 62 107 conn, err := upgrader.Upgrade(w, r, nil) ··· 68 113 69 114 slog.Info("subscribeLabels client connected", "cursor", cursor) 70 115 71 - // Send historical labels since cursor 72 - labels, err := GetLabelsSince(s.db, cursor, 1000) 116 + latest, err := LatestSeq(s.db) 73 117 if err != nil { 74 - slog.Error("Failed to get labels", "error", err) 118 + slog.Error("Failed to read latest seq", "error", err) 75 119 return 76 120 } 121 + if cursor > latest { 122 + if frame, ferr := frameError("FutureCursor", "cursor is in the future"); ferr == nil { 123 + _ = conn.WriteMessage(websocket.BinaryMessage, frame) 124 + } 125 + return 126 + } 127 + 128 + // Subscribe to the broadcast hub BEFORE backfilling so we don't lose events 129 + // that arrive while we're streaming the historical tail. 130 + sub, cancel := s.hub.subscribe(subscriberBuffer) 131 + defer cancel() 77 132 78 - for _, l := range labels { 79 - msg := LabelsMessage{ 80 - Seq: l.ID, 81 - Labels: []LabelOutput{labelToOutput(l)}, 133 + if cursor > 0 { 134 + if frame, ferr := frameInfo("OutdatedCursor", "starting backfill from cursor"); ferr == nil { 135 + if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil { 136 + return 137 + } 82 138 } 83 - if err := conn.WriteJSON(msg); err != nil { 139 + } 140 + 141 + // Backfill historical labels in pages until we catch up. 142 + for { 143 + labels, err := GetLabelsSince(s.db, cursor, backfillPageLimit) 144 + if err != nil { 145 + slog.Error("Failed to read labels for backfill", "error", err) 84 146 return 85 147 } 86 - cursor = l.ID 148 + if len(labels) == 0 { 149 + break 150 + } 151 + for i := range labels { 152 + frame, ferr := frameLabels(labels[i].ID, []*comatproto.LabelDefs_Label{labelToLexicon(&labels[i])}) 153 + if ferr != nil { 154 + slog.Error("Failed to encode label frame", "error", ferr) 155 + return 156 + } 157 + if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil { 158 + return 159 + } 160 + cursor = labels[i].ID 161 + } 162 + if len(labels) < backfillPageLimit { 163 + break 164 + } 87 165 } 88 166 89 - // Poll for new labels 90 - ticker := time.NewTicker(5 * time.Second) 91 - defer ticker.Stop() 92 - 93 - // Read pump (detect client disconnect) 167 + // Live delivery: a goroutine monitors the read side so we notice client disconnects; 168 + // the main loop pulls from the hub and writes frames until either side closes. 94 169 done := make(chan struct{}) 95 170 go func() { 96 171 defer close(done) 97 172 for { 98 - if _, _, err := conn.ReadMessage(); err != nil { 173 + if _, _, rerr := conn.ReadMessage(); rerr != nil { 99 174 return 100 175 } 101 176 } ··· 105 180 select { 106 181 case <-done: 107 182 return 108 - case <-ticker.C: 109 - labels, err := GetLabelsSince(s.db, cursor, 100) 110 - if err != nil { 111 - slog.Error("Failed to poll labels", "error", err) 112 - continue 183 + case lbl, ok := <-sub.ch: 184 + if !ok { 185 + return 113 186 } 114 - for _, l := range labels { 115 - msg := LabelsMessage{ 116 - Seq: l.ID, 117 - Labels: []LabelOutput{labelToOutput(l)}, 118 - } 119 - if err := conn.WriteJSON(msg); err != nil { 120 - return 121 - } 122 - cursor = l.ID 187 + if lbl.ID <= cursor { 188 + continue // already delivered during backfill 123 189 } 190 + frame, ferr := frameLabels(lbl.ID, []*comatproto.LabelDefs_Label{labelToLexicon(lbl)}) 191 + if ferr != nil { 192 + slog.Error("Failed to encode label frame", "error", ferr) 193 + return 194 + } 195 + if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil { 196 + return 197 + } 198 + cursor = lbl.ID 124 199 } 125 200 } 126 201 } 127 202 128 - // handleQueryLabels implements com.atproto.label.queryLabels (HTTP GET). 203 + // queryLabelsResponse mirrors the lexicon JSON shape for queryLabels. 204 + type queryLabelsResponse struct { 205 + Cursor string `json:"cursor,omitempty"` 206 + Labels []*comatproto.LabelDefs_Label `json:"labels"` 207 + } 208 + 209 + // handleQueryLabels implements com.atproto.label.queryLabels. 210 + // 211 + // Filters (uriPatterns, sources) are applied in SQL so the LIMIT cap operates on the 212 + // filtered result, not the raw scan. URI patterns support a single trailing `*` glob 213 + // (LIKE), with `%` and `_` escaped to remain literal. 129 214 func (s *Server) handleQueryLabels(w http.ResponseWriter, r *http.Request) { 130 - uriPatterns := r.URL.Query()["uriPatterns"] 131 - cursorStr := r.URL.Query().Get("cursor") 132 - limitStr := r.URL.Query().Get("limit") 215 + q := r.URL.Query() 216 + patterns := q["uriPatterns"] 217 + sources := q["sources"] 133 218 134 219 var cursor int64 135 - if cursorStr != "" { 136 - cursor, _ = strconv.ParseInt(cursorStr, 10, 64) 220 + if cs := q.Get("cursor"); cs != "" { 221 + v, err := strconv.ParseInt(cs, 10, 64) 222 + if err != nil { 223 + http.Error(w, "invalid cursor", http.StatusBadRequest) 224 + return 225 + } 226 + cursor = v 137 227 } 228 + 138 229 limit := 50 139 - if limitStr != "" { 140 - if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 250 { 230 + if ls := q.Get("limit"); ls != "" { 231 + if l, err := strconv.Atoi(ls); err == nil && l > 0 && l <= 250 { 141 232 limit = l 142 233 } 143 234 } 144 235 145 - labels, err := GetLabelsSince(s.db, cursor, limit) 236 + rows, err := queryLabelsSQL(s.db, patterns, sources, cursor, limit) 146 237 if err != nil { 147 - http.Error(w, "failed to query labels", http.StatusInternalServerError) 238 + if errors.Is(err, errInvalidPattern) { 239 + http.Error(w, "invalid uriPattern: wildcard '*' must be at the end", http.StatusBadRequest) 240 + return 241 + } 242 + slog.Error("queryLabels failed", "error", err) 243 + http.Error(w, "internal error", http.StatusInternalServerError) 148 244 return 149 245 } 150 246 151 - // Filter by URI patterns if provided 152 - var filtered []LabelOutput 153 - for _, l := range labels { 154 - if len(uriPatterns) == 0 || matchesAnyPattern(l.URI, uriPatterns) { 155 - filtered = append(filtered, labelToOutput(l)) 247 + out := &queryLabelsResponse{Labels: make([]*comatproto.LabelDefs_Label, 0, len(rows))} 248 + for i := range rows { 249 + out.Labels = append(out.Labels, labelToLexicon(&rows[i])) 250 + } 251 + if len(rows) > 0 { 252 + out.Cursor = strconv.FormatInt(rows[len(rows)-1].ID, 10) 253 + } 254 + 255 + w.Header().Set("Content-Type", "application/json") 256 + _ = json.NewEncoder(w).Encode(out) 257 + } 258 + 259 + var errInvalidPattern = errors.New("invalid uriPattern") 260 + 261 + // queryLabelsSQL builds the WHERE clause from filter args and runs the query. All 262 + // filtering happens in SQL so LIMIT operates on already-filtered rows. 263 + func queryLabelsSQL(db *sql.DB, patterns, sources []string, cursor int64, limit int) ([]Label, error) { 264 + var ( 265 + where []string 266 + args []any 267 + ) 268 + where = append(where, "id > ?") 269 + args = append(args, cursor) 270 + 271 + if len(patterns) > 0 { 272 + var ors []string 273 + var matchAll bool 274 + for _, p := range patterns { 275 + if p == "" { 276 + continue 277 + } 278 + if p == "*" { 279 + matchAll = true 280 + break 281 + } 282 + like, err := patternToLike(p) 283 + if err != nil { 284 + return nil, err 285 + } 286 + if strings.ContainsAny(like, `%_\`) { 287 + ors = append(ors, "uri LIKE ? ESCAPE '\\'") 288 + } else { 289 + ors = append(ors, "uri = ?") 290 + } 291 + args = append(args, like) 292 + } 293 + if !matchAll && len(ors) > 0 { 294 + where = append(where, "("+strings.Join(ors, " OR ")+")") 156 295 } 157 296 } 158 297 159 - var nextCursor string 160 - if len(labels) > 0 { 161 - nextCursor = strconv.FormatInt(labels[len(labels)-1].ID, 10) 298 + if len(sources) > 0 { 299 + placeholders := strings.Repeat("?,", len(sources)) 300 + placeholders = placeholders[:len(placeholders)-1] 301 + where = append(where, "src IN ("+placeholders+")") 302 + for _, s := range sources { 303 + args = append(args, s) 304 + } 162 305 } 163 306 164 - resp := struct { 165 - Cursor string `json:"cursor,omitempty"` 166 - Labels []LabelOutput `json:"labels"` 167 - }{ 168 - Cursor: nextCursor, 169 - Labels: filtered, 307 + args = append(args, limit) 308 + q := `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo 309 + FROM labels WHERE ` + strings.Join(where, " AND ") + 310 + ` ORDER BY id ASC LIMIT ?` 311 + 312 + rows, err := db.Query(q, args...) 313 + if err != nil { 314 + return nil, err 170 315 } 171 - if resp.Labels == nil { 172 - resp.Labels = []LabelOutput{} 173 - } 174 - 175 - w.Header().Set("Content-Type", "application/json") 176 - _ = json.NewEncoder(w).Encode(resp) 316 + defer rows.Close() 317 + return scanLabels(rows) 177 318 } 178 319 179 - func matchesAnyPattern(uri string, patterns []string) bool { 180 - for _, p := range patterns { 181 - // Simple prefix matching (ATProto spec allows glob-like patterns) 182 - if p == uri || (len(p) > 0 && p[len(p)-1] == '*' && len(uri) >= len(p)-1 && uri[:len(p)-1] == p[:len(p)-1]) { 183 - return true 184 - } 320 + // patternToLike converts a uriPattern into a SQLite LIKE expression. The only 321 + // wildcard supported is a trailing `*`, which becomes `%`. Literal `%`, `_`, and `\` 322 + // in the input are escaped via the LIKE ESCAPE clause used at query time. 323 + func patternToLike(p string) (string, error) { 324 + if idx := strings.Index(p, "*"); idx >= 0 && idx != len(p)-1 { 325 + return "", errInvalidPattern 326 + } 327 + literal := p 328 + suffix := "" 329 + if strings.HasSuffix(p, "*") { 330 + literal = p[:len(p)-1] 331 + suffix = "%" 185 332 } 186 - return false 333 + literal = strings.ReplaceAll(literal, `\`, `\\`) 334 + literal = strings.ReplaceAll(literal, `%`, `\%`) 335 + literal = strings.ReplaceAll(literal, `_`, `\_`) 336 + return literal + suffix, nil 187 337 }
+70 -47
pkg/labeler/subscribe_test.go
··· 1 1 package labeler 2 2 3 3 import ( 4 + "strings" 4 5 "testing" 5 6 "time" 7 + 8 + comatproto "github.com/bluesky-social/indigo/api/atproto" 6 9 ) 7 10 8 - func TestLabelToOutput(t *testing.T) { 11 + // TestLabelToLexicon checks the wire-format conversion populates the indigo lexicon 12 + // fields (which is what's serialized into both the WS frame and the queryLabels JSON). 13 + func TestLabelToLexicon(t *testing.T) { 9 14 now := time.Date(2026, 3, 22, 10, 0, 0, 0, time.UTC) 10 15 exp := time.Date(2026, 4, 22, 10, 0, 0, 0, time.UTC) 11 16 12 17 label := Label{ 13 18 ID: 1, 14 - Src: "did:web:labeler.atcr.io", 19 + Src: "did:plc:abc", 15 20 URI: "at://did:plc:abc/io.atcr.manifest/sha256-123", 16 21 CID: "bafyabc", 17 22 Val: "!takedown", 18 - Neg: false, 19 23 Cts: now, 20 24 Exp: &exp, 25 + Ver: LabelVersion, 26 + Sig: []byte{0x01, 0x02, 0x03}, 21 27 SubjectDID: "did:plc:abc", 22 28 SubjectRepo: "myimage", 23 29 } 30 + lex := labelToLexicon(&label) 24 31 25 - out := labelToOutput(label) 26 - if out.Src != "did:web:labeler.atcr.io" { 27 - t.Errorf("Src = %q, want did:web:labeler.atcr.io", out.Src) 32 + if lex.Src != label.Src { 33 + t.Errorf("Src = %q", lex.Src) 28 34 } 29 - if out.URI != "at://did:plc:abc/io.atcr.manifest/sha256-123" { 30 - t.Errorf("URI = %q", out.URI) 35 + if lex.Uri != label.URI { 36 + t.Errorf("Uri = %q", lex.Uri) 31 37 } 32 - if out.CID != "bafyabc" { 33 - t.Errorf("CID = %q, want bafyabc", out.CID) 38 + if lex.Cid == nil || *lex.Cid != "bafyabc" { 39 + t.Errorf("Cid = %v", lex.Cid) 34 40 } 35 - if out.Val != "!takedown" { 36 - t.Errorf("Val = %q", out.Val) 41 + if lex.Cts != "2026-03-22T10:00:00Z" { 42 + t.Errorf("Cts = %q", lex.Cts) 37 43 } 38 - if out.Neg { 39 - t.Error("expected Neg=false") 44 + if lex.Exp == nil || *lex.Exp != "2026-04-22T10:00:00Z" { 45 + t.Errorf("Exp = %v", lex.Exp) 40 46 } 41 - if out.Cts != "2026-03-22T10:00:00Z" { 42 - t.Errorf("Cts = %q", out.Cts) 43 - } 44 - if out.Exp != "2026-04-22T10:00:00Z" { 45 - t.Errorf("Exp = %q", out.Exp) 47 + if len(lex.Sig) != 3 { 48 + t.Errorf("Sig length = %d, want 3", len(lex.Sig)) 46 49 } 47 50 } 48 51 49 - func TestLabelToOutput_NoExpiration(t *testing.T) { 50 - label := Label{ 51 - Src: "did:web:labeler.atcr.io", 52 - URI: "at://did:plc:abc", 53 - Val: "!takedown", 54 - Cts: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), 55 - } 56 - 57 - out := labelToOutput(label) 58 - if out.Exp != "" { 59 - t.Errorf("expected empty Exp, got %q", out.Exp) 60 - } 61 - } 62 - 63 - func TestMatchesAnyPattern(t *testing.T) { 52 + func TestPatternToLike(t *testing.T) { 64 53 tests := []struct { 65 - name string 66 - uri string 67 - patterns []string 68 - want bool 54 + in string 55 + want string 56 + wantErr bool 69 57 }{ 70 - {"exact match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:abc/io.atcr.manifest/sha256-123"}, true}, 71 - {"no match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/io.atcr.manifest/sha256-123"}, false}, 72 - {"wildcard match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:abc/*"}, true}, 73 - {"wildcard no match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/*"}, false}, 74 - {"empty patterns", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{}, false}, 75 - {"multiple patterns", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/*", "at://did:plc:abc/*"}, true}, 58 + {"at://did:plc:abc/foo", "at://did:plc:abc/foo", false}, 59 + {"at://did:plc:abc/*", "at://did:plc:abc/%", false}, 60 + {"at://did:plc:abc%/foo", `at://did:plc:abc\%/foo`, false}, 61 + {"at://did:plc:abc_/foo", `at://did:plc:abc\_/foo`, false}, 62 + {"at://*/foo", "", true}, 76 63 } 77 64 78 65 for _, tt := range tests { 79 - t.Run(tt.name, func(t *testing.T) { 80 - got := matchesAnyPattern(tt.uri, tt.patterns) 66 + t.Run(tt.in, func(t *testing.T) { 67 + got, err := patternToLike(tt.in) 68 + if tt.wantErr { 69 + if err == nil { 70 + t.Errorf("expected error, got %q", got) 71 + } 72 + return 73 + } 74 + if err != nil { 75 + t.Fatalf("unexpected err: %v", err) 76 + } 81 77 if got != tt.want { 82 - t.Errorf("matchesAnyPattern(%q, %v) = %v, want %v", tt.uri, tt.patterns, got, tt.want) 78 + t.Errorf("patternToLike(%q) = %q, want %q", tt.in, got, tt.want) 83 79 } 84 80 }) 85 81 } 86 82 } 83 + 84 + // TestFrameLabels exercises the CBOR framing — produces non-empty bytes whose first 85 + // byte is a CBOR map header and whose body contains the labels payload keys. 86 + func TestFrameLabels(t *testing.T) { 87 + now := time.Date(2026, 3, 22, 10, 0, 0, 0, time.UTC) 88 + l := &Label{ 89 + Src: "did:plc:abc", 90 + URI: "at://did:plc:abc", 91 + Val: "!takedown", 92 + Cts: now, 93 + Ver: LabelVersion, 94 + Sig: []byte("sig-bytes"), 95 + } 96 + frame, err := frameLabels(42, []*comatproto.LabelDefs_Label{labelToLexicon(l)}) 97 + if err != nil { 98 + t.Fatalf("frame: %v", err) 99 + } 100 + if len(frame) < 2 { 101 + t.Fatalf("frame too short: %d bytes", len(frame)) 102 + } 103 + if frame[0]&0xe0 != 0xa0 { 104 + t.Errorf("first byte %#x is not a CBOR map header", frame[0]) 105 + } 106 + if !strings.Contains(string(frame), "labels") || !strings.Contains(string(frame), "seq") { 107 + t.Errorf("frame missing expected keys") 108 + } 109 + }
+146 -66
pkg/labeler/takedown.go
··· 7 7 "html/template" 8 8 "log/slog" 9 9 "net/http" 10 + "net/url" 10 11 "strings" 11 12 "time" 12 13 ··· 21 22 } 22 23 23 24 // ParseTakedownInput parses various input formats into a TakedownInput. 24 - // Supported formats: 25 - // - atcr.io/r/handle/repo 26 - // - handle/repo 27 - // - at://did:plc:xyz/io.atcr.repo.page/repo 28 - // - at://did:plc:xyz (user-level) 29 - // - handle (user-level) 30 - // - did:plc:xyz (user-level) 25 + // 26 + // Supported shapes (dispatched in order): 27 + // 28 + // - at://<did-or-handle>[/collection/rkey] — ATProto AT URI 29 + // - did:plc:..., did:web:... — bare DID, user-level takedown 30 + // - URL with /u/<handle> or /r/<handle>/<repo> — appview routes (with or without scheme) 31 + // - <handle> — bare handle, user-level takedown 32 + // 33 + // Anything else is rejected. The appview's /r/ route uses the repo name as a single 34 + // path segment so any trailing path (digest pages, tag tabs) is discarded; URL 35 + // fragments and query strings are dropped in all cases. 31 36 func ParseTakedownInput(ctx context.Context, input string) (*TakedownInput, error) { 32 37 input = strings.TrimSpace(input) 38 + if input == "" { 39 + return nil, fmt.Errorf("empty takedown input") 40 + } 33 41 34 - // AT URI format 35 42 if strings.HasPrefix(input, "at://") { 36 43 return parseATURI(ctx, input) 37 44 } 38 45 39 - // Strip URL prefix if present 40 - input = strings.TrimPrefix(input, "https://") 41 - input = strings.TrimPrefix(input, "http://") 42 - 43 - // Remove atcr.io/r/ or similar prefix 44 - for _, prefix := range []string{"atcr.io/r/", "localhost/r/"} { 45 - if strings.HasPrefix(input, prefix) { 46 - input = strings.TrimPrefix(input, prefix) 47 - break 48 - } 49 - } 50 - // Also handle custom domains: anything ending in /r/ 51 - if idx := strings.Index(input, "/r/"); idx >= 0 { 52 - input = input[idx+3:] 46 + // Bare DID — no slashes, no scheme. did:plc:..., did:web:..., did:web:host%3Aport. 47 + if strings.HasPrefix(input, "did:") && !strings.Contains(input, "/") { 48 + return resolveBareIdentifier(ctx, input) 53 49 } 54 50 55 - // Now input should be "handle/repo" or "handle" or "did:xxx" 56 - parts := strings.SplitN(input, "/", 2) 57 - identifier := parts[0] 58 - var repo string 59 - if len(parts) > 1 { 60 - repo = parts[1] 61 - repo = strings.TrimSuffix(repo, "/") 51 + // URL-shaped: contains a scheme or a slash. Parse and dispatch on the path. 52 + if hasURLShape(input) { 53 + return parseTakedownURL(ctx, input) 62 54 } 63 55 64 - did, handle, err := resolveIdentifier(ctx, identifier) 56 + // Otherwise: bare handle. 57 + return resolveBareIdentifier(ctx, input) 58 + } 59 + 60 + // hasURLShape reports whether the input looks like a URL or a path. A bare handle like 61 + // "alice.bsky.social" is not URL-shaped (no slashes, no scheme). 62 + func hasURLShape(s string) bool { 63 + return strings.Contains(s, "://") || strings.Contains(s, "/") 64 + } 65 + 66 + // parseTakedownURL parses a URL whose path is one of the appview's takedown-relevant 67 + // routes: /u/<handle> for user-level, /r/<handle>/<repo> for repo-level. The host part 68 + // is irrelevant — we only use the path — so this also accepts schemeless input like 69 + // "atcr.io/r/handle/repo" by prepending https:// before parsing. 70 + func parseTakedownURL(ctx context.Context, input string) (*TakedownInput, error) { 71 + if !strings.Contains(input, "://") { 72 + input = "https://" + input 73 + } 74 + u, err := url.Parse(input) 65 75 if err != nil { 66 - return nil, err 76 + return nil, fmt.Errorf("invalid URL: %w", err) 77 + } 78 + 79 + parts := strings.Split(strings.Trim(u.Path, "/"), "/") 80 + if len(parts) == 0 || parts[0] == "" { 81 + // No path — treat the host as the identifier (e.g. "alice.bsky.social/"). 82 + return resolveBareIdentifier(ctx, u.Host) 67 83 } 68 84 69 - return &TakedownInput{ 70 - DID: did, 71 - Handle: handle, 72 - Repository: repo, 73 - }, nil 85 + switch parts[0] { 86 + case "u": 87 + if len(parts) < 2 || parts[1] == "" { 88 + return nil, fmt.Errorf("missing handle in /u/<handle>") 89 + } 90 + return resolveBareIdentifier(ctx, parts[1]) 91 + case "r": 92 + if len(parts) < 3 || parts[1] == "" || parts[2] == "" { 93 + return nil, fmt.Errorf("missing handle or repo in /r/<handle>/<repo>") 94 + } 95 + base, err := resolveBareIdentifier(ctx, parts[1]) 96 + if err != nil { 97 + return nil, err 98 + } 99 + // parts[2] only — discard any /digest/..., /tags/..., etc. trailing path. 100 + base.Repository = parts[2] 101 + return base, nil 102 + default: 103 + return nil, fmt.Errorf("unsupported URL path %q (expected /u/<handle> or /r/<handle>/<repo>)", u.Path) 104 + } 74 105 } 75 106 107 + // parseATURI parses an at:// URI. The authority is a DID or handle; the path's third 108 + // segment (rkey) becomes the repo for repo-level takedowns. Fragment and query are 109 + // stripped first since paste-of-browser-AT-URI may include them. 76 110 func parseATURI(ctx context.Context, uri string) (*TakedownInput, error) { 77 - // at://did:plc:xyz/collection/rkey 78 111 trimmed := strings.TrimPrefix(uri, "at://") 112 + if idx := strings.IndexAny(trimmed, "#?"); idx >= 0 { 113 + trimmed = trimmed[:idx] 114 + } 79 115 parts := strings.SplitN(trimmed, "/", 3) 116 + authority := parts[0] 117 + if authority == "" { 118 + return nil, fmt.Errorf("at:// URI missing authority") 119 + } 80 120 81 - did := parts[0] 82 - if !strings.HasPrefix(did, "did:") { 83 - // It's a handle 84 - resolvedDID, handle, err := resolveIdentifier(ctx, did) 121 + var ( 122 + did, handle string 123 + err error 124 + ) 125 + if strings.HasPrefix(authority, "did:") { 126 + did = authority 127 + _, handle, _, _ = atproto.ResolveIdentity(ctx, did) 128 + } else { 129 + did, handle, err = resolveIdentifier(ctx, authority) 85 130 if err != nil { 86 131 return nil, err 87 132 } 88 - did = resolvedDID 89 - if len(parts) >= 3 { 90 - return &TakedownInput{DID: did, Handle: handle, Repository: parts[2]}, nil 91 - } 92 - return &TakedownInput{DID: did, Handle: handle}, nil 93 133 } 94 134 95 - // Resolve handle from DID 96 - _, handle, _, _ := atproto.ResolveIdentity(ctx, did) 97 - 98 - if len(parts) < 3 { 99 - // User-level takedown 100 - return &TakedownInput{DID: did, Handle: handle}, nil 135 + out := &TakedownInput{DID: did, Handle: handle} 136 + if len(parts) >= 3 { 137 + out.Repository = parts[2] 101 138 } 139 + return out, nil 140 + } 102 141 103 - // Extract repository from rkey (third part) 104 - repo := parts[2] 105 - return &TakedownInput{DID: did, Handle: handle, Repository: repo}, nil 142 + // resolveBareIdentifier resolves a handle or DID to a user-level TakedownInput. When 143 + // the input is already a DID, the resolve is best-effort (DID is the source of truth; 144 + // handle is just for display) and we don't fail if PLC/web resolution is unreachable. 145 + // For a handle, resolution is required since we need a DID to label. 146 + func resolveBareIdentifier(ctx context.Context, id string) (*TakedownInput, error) { 147 + if strings.HasPrefix(id, "did:") { 148 + _, handle, _, _ := atproto.ResolveIdentity(ctx, id) 149 + return &TakedownInput{DID: id, Handle: handle}, nil 150 + } 151 + did, handle, err := resolveIdentifier(ctx, id) 152 + if err != nil { 153 + return nil, err 154 + } 155 + return &TakedownInput{DID: did, Handle: handle}, nil 106 156 } 107 157 108 158 func resolveIdentifier(ctx context.Context, identifier string) (did, handle string, err error) { ··· 124 174 125 175 // ExecuteTakedown creates takedown labels for a repo or user. 126 176 func (s *Server) ExecuteTakedown(ctx context.Context, input *TakedownInput) (*TakedownResult, error) { 127 - src := s.config.DID() 177 + src := s.did 128 178 now := time.Now().UTC() 129 179 result := &TakedownResult{ 130 180 DID: input.DID, ··· 143 193 SubjectDID: input.DID, 144 194 SubjectRepo: "", 145 195 } 196 + if err := label.Sign(s.signingKey); err != nil { 197 + return nil, fmt.Errorf("failed to sign user-level label: %w", err) 198 + } 146 199 if _, err := CreateLabel(s.db, label); err != nil { 147 200 return nil, fmt.Errorf("failed to create user-level label: %w", err) 148 201 } 202 + s.hub.Broadcast(label) 149 203 result.Labels = append(result.Labels, *label) 150 204 slog.Info("Created user-level takedown", "did", input.DID, "handle", input.Handle) 151 205 return result, nil ··· 168 222 SubjectDID: input.DID, 169 223 SubjectRepo: input.Repository, 170 224 } 225 + if err := summaryLabel.Sign(s.signingKey); err != nil { 226 + return nil, fmt.Errorf("failed to sign summary label: %w", err) 227 + } 171 228 if _, err := CreateLabel(s.db, summaryLabel); err != nil { 172 229 return nil, fmt.Errorf("failed to create summary label: %w", err) 173 230 } 231 + s.hub.Broadcast(summaryLabel) 174 232 result.Labels = append(result.Labels, *summaryLabel) 175 233 176 234 slog.Info("Created repo-level takedown", ··· 225 283 SubjectDID: did, 226 284 SubjectRepo: repo, 227 285 } 286 + if err := label.Sign(s.signingKey); err != nil { 287 + slog.Warn("Failed to sign label", "uri", uri, "error", err) 288 + continue 289 + } 228 290 if _, err := CreateLabel(s.db, label); err != nil { 229 291 slog.Warn("Failed to create label", "uri", uri, "error", err) 230 292 continue 231 293 } 294 + s.hub.Broadcast(label) 232 295 labels = append(labels, *label) 233 296 } 234 297 } ··· 255 318 if err != nil { 256 319 http.Error(w, "Failed to list takedowns", http.StatusInternalServerError) 257 320 return 321 + } 322 + csrf := "" 323 + if session := SessionFromContext(r.Context()); session != nil { 324 + csrf = session.CSRFToken 258 325 } 259 326 260 327 w.Header().Set("Content-Type", "text/html") ··· 301 368 <td>%s</td> 302 369 <td><code>%s</code></td> 303 370 <td>%s</td> 304 - <td><form method="POST" action="/reverse"><input type="hidden" name="did" value="%s"><input type="hidden" name="repo" value="%s"><button type="submit" class="btn btn-danger" onclick="return confirm('Reverse this takedown?')">Reverse</button></form></td> 371 + <td><form method="POST" action="/reverse">%s<input type="hidden" name="did" value="%s"><input type="hidden" name="repo" value="%s"><button type="submit" class="btn btn-danger" onclick="return confirm('Reverse this takedown?')">Reverse</button></form></td> 305 372 </tr>`, 306 373 template.HTMLEscapeString(l.SubjectDID), 307 374 repoDisplay, 308 375 template.HTMLEscapeString(l.URI), 309 376 l.Cts.Format("2006-01-02 15:04"), 377 + csrfInputHTML(csrf), 310 378 template.HTMLEscapeString(l.SubjectDID), 311 379 template.HTMLEscapeString(l.SubjectRepo), 312 380 ) ··· 320 388 func (s *Server) handleTakedownForm(w http.ResponseWriter, r *http.Request) { 321 389 msg := r.URL.Query().Get("msg") 322 390 errorMsg := r.URL.Query().Get("error") 391 + csrf := "" 392 + if session := SessionFromContext(r.Context()); session != nil { 393 + csrf = session.CSRFToken 394 + } 323 395 324 396 w.Header().Set("Content-Type", "text/html") 325 397 fmt.Fprintf(w, `<!DOCTYPE html> ··· 353 425 fmt.Fprintf(w, `<div class="error">%s</div>`, template.HTMLEscapeString(errorMsg)) 354 426 } 355 427 356 - fmt.Fprint(w, ` 428 + fmt.Fprintf(w, ` 357 429 <form method="POST" action="/takedown"> 430 + %s 358 431 <label for="target"><strong>Target</strong></label> 359 - <input type="text" id="target" name="target" placeholder="atcr.io/r/handle/repo, at://did/collection/rkey, or handle" required> 360 - <p class="help">Accepts repo URLs, AT URIs, handles, or DIDs. Omit the repo for a user-level takedown.</p> 432 + <input type="text" id="target" name="target" placeholder="/r/handle/repo, /u/handle, at://did/collection/rkey, handle, or did:..." required> 433 + <p class="help">Repo: <code>/r/handle/repo</code> (or full atcr.io URL). User-level: <code>/u/handle</code>, a bare handle, or a DID. AT URIs (<code>at://...</code>) also work.</p> 361 434 <br> 362 435 <button type="submit" class="btn" onclick="return confirm('Issue takedown? This will suppress the content immediately.')">Issue Takedown</button> 363 436 </form> 364 - </body></html>`) 437 + </body></html>`, csrfInputHTML(csrf)) 365 438 } 366 439 367 440 func (s *Server) handleTakedownSubmit(w http.ResponseWriter, r *http.Request) { ··· 399 472 return 400 473 } 401 474 402 - src := s.config.DID() 403 - var err error 475 + src := s.did 476 + var ( 477 + negs []Label 478 + err error 479 + ) 404 480 if repo == "" { 405 - err = NegateUserLabels(s.db, src, did) 481 + negs, err = NegateUserLabels(s.db, s.signingKey, src, did) 406 482 } else { 407 - err = NegateRepoLabels(s.db, src, did, repo) 483 + negs, err = NegateRepoLabels(s.db, s.signingKey, src, did, repo) 408 484 } 409 485 410 486 if err != nil { ··· 413 489 return 414 490 } 415 491 416 - slog.Info("Reversed takedown", "did", did, "repo", repo) 492 + for i := range negs { 493 + s.hub.Broadcast(&negs[i]) 494 + } 495 + 496 + slog.Info("Reversed takedown", "did", did, "repo", repo, "negations", len(negs)) 417 497 http.Redirect(w, r, "/", http.StatusFound) 418 498 }
+67 -36
pkg/labeler/takedown_test.go
··· 5 5 "testing" 6 6 ) 7 7 8 - func TestParseTakedownInput_RepoURL(t *testing.T) { 9 - // These tests only exercise parsing logic, not PDS resolution. 10 - // ResolveIdentity calls are tested with mock server below. 8 + // TestParseTakedownURL exercises the URL-shaped parsing path offline by feeding only 9 + // inputs whose identifier is a DID — DIDs short-circuit ResolveIdentity entirely, so 10 + // this whole table runs without network. Every case here came up while debugging 11 + // browser-paste bugs (URL fragments, query strings, trailing /digest/... segments). 12 + func TestParseTakedownURL(t *testing.T) { 11 13 tests := []struct { 12 14 name string 13 15 input string 16 + wantDID string 14 17 wantRepo string 18 + wantErr bool 15 19 }{ 16 - {"full URL", "https://atcr.io/r/handle/myimage", "myimage"}, 17 - {"no scheme", "atcr.io/r/handle/myimage", "myimage"}, 18 - {"handle/repo", "handle/myimage", "myimage"}, 19 - {"trailing slash", "atcr.io/r/handle/myimage/", "myimage"}, 20 - {"custom domain", "https://registry.example.com/r/handle/myimage", "myimage"}, 21 - } 20 + // /r/<handle>/<repo> 21 + {"r path with scheme", "https://atcr.io/r/did:plc:abc/myimage", "did:plc:abc", "myimage", false}, 22 + {"r path no scheme", "atcr.io/r/did:plc:abc/myimage", "did:plc:abc", "myimage", false}, 23 + {"r path trailing slash", "atcr.io/r/did:plc:abc/myimage/", "did:plc:abc", "myimage", false}, 24 + {"r path custom domain", "https://seamark.dev/r/did:plc:abc/myimage", "did:plc:abc", "myimage", false}, 25 + {"r path subroute (digest)", "https://atcr.io/r/did:plc:abc/myimage/digest/sha256-deadbeef", "did:plc:abc", "myimage", false}, 26 + {"r path with hash fragment", "https://atcr.io/r/did:plc:abc/myimage#overview", "did:plc:abc", "myimage", false}, 27 + {"r path with query string", "https://atcr.io/r/did:plc:abc/myimage?tag=latest", "did:plc:abc", "myimage", false}, 28 + {"r path missing repo", "https://atcr.io/r/did:plc:abc", "", "", true}, 22 29 30 + // /u/<handle> 31 + {"u path with scheme", "https://atcr.io/u/did:plc:abc", "did:plc:abc", "", false}, 32 + {"u path no scheme", "atcr.io/u/did:plc:abc", "did:plc:abc", "", false}, 33 + {"u path with hash", "https://atcr.io/u/did:plc:abc#tab", "did:plc:abc", "", false}, 34 + {"u path missing handle", "https://atcr.io/u", "", "", true}, 35 + 36 + // Unknown route 37 + {"unsupported path", "https://atcr.io/foo/did:plc:abc", "", "", true}, 38 + } 23 39 for _, tt := range tests { 24 40 t.Run(tt.name, func(t *testing.T) { 25 - // These will fail on ResolveIdentity since there's no real PDS, 26 - // but we can at least verify the parsing doesn't panic 27 - _, err := ParseTakedownInput(context.Background(), tt.input) 28 - if err == nil { 29 - t.Skip("ResolveIdentity succeeded unexpectedly (network available)") 41 + got, err := parseTakedownURL(context.Background(), tt.input) 42 + if tt.wantErr { 43 + if err == nil { 44 + t.Errorf("expected error, got %+v", got) 45 + } 46 + return 30 47 } 31 - // The error should be from resolution, not parsing 32 48 if err != nil { 33 - t.Logf("Expected resolution error: %v", err) 49 + t.Fatalf("unexpected error: %v", err) 50 + } 51 + if got.DID != tt.wantDID { 52 + t.Errorf("DID = %q, want %q", got.DID, tt.wantDID) 53 + } 54 + if got.Repository != tt.wantRepo { 55 + t.Errorf("Repository = %q, want %q", got.Repository, tt.wantRepo) 34 56 } 35 57 }) 36 58 } 37 59 } 38 60 61 + func TestParseTakedownInput_BareDID(t *testing.T) { 62 + // Bare DIDs short-circuit network resolution. 63 + got, err := ParseTakedownInput(context.Background(), "did:plc:abc123") 64 + if err != nil { 65 + t.Fatalf("unexpected error: %v", err) 66 + } 67 + if got.DID != "did:plc:abc123" { 68 + t.Errorf("DID = %q, want did:plc:abc123", got.DID) 69 + } 70 + if got.Repository != "" { 71 + t.Errorf("Repository = %q, want empty (user-level)", got.Repository) 72 + } 73 + } 74 + 39 75 func TestParseTakedownInput_ATURI(t *testing.T) { 40 76 tests := []struct { 41 77 name string ··· 61 97 "did:plc:xyz", 62 98 "sha256-deadbeef", 63 99 }, 100 + { 101 + "AT URI with hash fragment", 102 + "at://did:plc:abc#frag", 103 + "did:plc:abc", 104 + "", 105 + }, 64 106 } 65 - 66 107 for _, tt := range tests { 67 108 t.Run(tt.name, func(t *testing.T) { 68 - input, err := ParseTakedownInput(context.Background(), tt.input) 109 + got, err := ParseTakedownInput(context.Background(), tt.input) 69 110 if err != nil { 70 - // Resolution may fail for handle-based AT URIs 71 - t.Logf("Parse error (may be expected): %v", err) 72 - return 111 + t.Fatalf("unexpected error: %v", err) 73 112 } 74 - if input.DID != tt.wantDID { 75 - t.Errorf("DID = %q, want %q", input.DID, tt.wantDID) 113 + if got.DID != tt.wantDID { 114 + t.Errorf("DID = %q, want %q", got.DID, tt.wantDID) 76 115 } 77 - if input.Repository != tt.wantRepo { 78 - t.Errorf("Repository = %q, want %q", input.Repository, tt.wantRepo) 116 + if got.Repository != tt.wantRepo { 117 + t.Errorf("Repository = %q, want %q", got.Repository, tt.wantRepo) 79 118 } 80 119 }) 81 120 } 82 121 } 83 122 84 - func TestParseTakedownInput_DID(t *testing.T) { 85 - // Direct DID input (user-level takedown) 86 - input, err := ParseTakedownInput(context.Background(), "at://did:plc:abc123") 87 - if err != nil { 88 - t.Fatalf("unexpected error: %v", err) 89 - } 90 - if input.DID != "did:plc:abc123" { 91 - t.Errorf("DID = %q, want did:plc:abc123", input.DID) 92 - } 93 - if input.Repository != "" { 94 - t.Errorf("Repository = %q, want empty (user-level)", input.Repository) 123 + func TestParseTakedownInput_Empty(t *testing.T) { 124 + if _, err := ParseTakedownInput(context.Background(), ""); err == nil { 125 + t.Error("expected error for empty input") 95 126 } 96 127 } 97 128