this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor SQL tables

+559 -493
+22 -25
cmd/relayered/handlers.go
··· 11 11 "strings" 12 12 13 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 14 - "github.com/bluesky-social/indigo/cmd/relayered/relay" 14 + "github.com/bluesky-social/indigo/atproto/syntax" 15 15 "github.com/bluesky-social/indigo/cmd/relayered/relay/models" 16 16 "github.com/bluesky-social/indigo/xrpc" 17 17 ··· 128 128 } 129 129 130 130 // Fetch the repo roots for each user 131 - for i := range accounts { 132 - user := accounts[i] 133 - 134 - root, err := s.relay.GetRepoRoot(ctx, user.ID) 131 + for i, acc := range accounts { 132 + repo, err := s.relay.GetAccountRepo(ctx, acc.UID) 135 133 if err != nil { 136 - s.logger.Error("failed to get repo root", "err", err, "did", user.Did) 137 - return nil, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to get repo root for (%s): %v", user.Did, err.Error())) 134 + s.logger.Error("failed to get repo root", "err", err, "did", acc.DID) 135 + return nil, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to get repo root for (%s): %v", acc.DID, err.Error())) 138 136 } 139 137 140 138 resp.Repos[i] = &comatproto.SyncListRepos_Repo{ 141 - Did: user.Did, 142 - Head: root.String(), 139 + Did: acc.DID, 140 + Head: repo.CommitData, // XXX: is this what is expected here? 143 141 } 144 142 } 145 143 146 144 // If this is not the last page, set the cursor 147 145 if len(accounts) >= limit && len(accounts) > 1 { 148 - nextCursor := fmt.Sprintf("%d", accounts[len(accounts)-1].ID) 146 + nextCursor := fmt.Sprintf("%d", accounts[len(accounts)-1].UID) 149 147 resp.Cursor = &nextCursor 150 148 } 151 149 152 150 return resp, nil 153 151 } 154 152 155 - func (s *Service) handleComAtprotoSyncGetLatestCommit(ctx context.Context, did string) (*comatproto.SyncGetLatestCommit_Output, error) { 156 - u, err := s.relay.LookupUserByDid(ctx, did) 153 + func (s *Service) handleComAtprotoSyncGetLatestCommit(ctx context.Context, rawDID string) (*comatproto.SyncGetLatestCommit_Output, error) { 154 + did, err := syntax.ParseDID(rawDID) 155 + if err != nil { 156 + return nil, fmt.Errorf("invalid DID parameter: %w", err) 157 + } 158 + acc, err := s.relay.GetAccount(ctx, did) 157 159 if err != nil { 158 160 if errors.Is(err, gorm.ErrRecordNotFound) { 159 161 return nil, echo.NewHTTPError(http.StatusNotFound, "user not found") ··· 161 163 return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to lookup user") 162 164 } 163 165 164 - if u.GetTakenDown() { 166 + if acc.Status == models.AccountStatusTakendown { 165 167 return nil, fmt.Errorf("account was taken down by the Relay") 166 168 } 167 169 168 - ustatus := u.GetUpstreamStatus() 169 - if ustatus == models.AccountStatusTakendown { 170 + if acc.UpstreamStatus == models.AccountStatusTakendown { 170 171 return nil, fmt.Errorf("account was taken down by its PDS") 171 172 } 172 173 173 - if ustatus == models.AccountStatusDeactivated { 174 + if acc.Status == models.AccountStatusDeactivated { 174 175 return nil, fmt.Errorf("account is temporarily deactivated") 175 176 } 176 177 177 - if ustatus == models.AccountStatusSuspended { 178 + if acc.Status == models.AccountStatusSuspended { 178 179 return nil, fmt.Errorf("account is suspended by its PDS") 179 180 } 180 181 181 - prevState, err := s.relay.GetAccountPreviousState(ctx, u.ID) 182 + repo, err := s.relay.GetAccountRepo(ctx, acc.UID) 182 183 if err != nil { 183 - if errors.Is(err, gorm.ErrRecordNotFound) { 184 - return nil, relay.ErrAccountLastUnavailable 185 - } 186 - s.logger.Error("user db err", "err", err) 187 - return nil, fmt.Errorf("user prev db err, %w", err) 184 + return nil, err 188 185 } 189 186 190 187 return &comatproto.SyncGetLatestCommit_Output{ 191 - Cid: prevState.Cid.CID.String(), 192 - Rev: prevState.Rev, 188 + Cid: repo.CommitData, // XXX: this is probably not what is wanted here 189 + Rev: repo.Rev, 193 190 }, nil 194 191 } 195 192
+4 -4
cmd/relayered/main.go
··· 221 221 svcConfig := DefaultServiceConfig() 222 222 relayConfig := relay.DefaultRelayConfig() 223 223 relayConfig.SSL = !cctx.Bool("crawl-insecure-ws") 224 - relayConfig.ConcurrencyPerPDS = cctx.Int64("host-concurrency") 225 - relayConfig.MaxQueuePerPDS = cctx.Int64("max-queue-per-host") 224 + relayConfig.ConcurrencyPerHost = cctx.Int64("host-concurrency") 225 + relayConfig.MaxQueuePerHost = cctx.Int64("max-queue-per-host") 226 226 relayConfig.DefaultRepoLimit = cctx.Int64("default-account-limit") 227 227 ratelimitBypass := cctx.String("bsky-social-rate-limit-skip") 228 - relayConfig.ApplyPDSClientSettings = makePdsClientSetup(ratelimitBypass) 228 + relayConfig.ApplyHostClientSettings = makePdsClientSetup(ratelimitBypass) 229 229 nextCrawlers := cctx.StringSlice("forward-crawl-requests") 230 230 if len(nextCrawlers) > 0 { 231 231 nextCrawlerUrls := make([]*url.URL, len(nextCrawlers)) ··· 324 324 } 325 325 } 326 326 } else { 327 - // Generic PDS timeout 327 + // Generic host timeout 328 328 c.Client.Timeout = time.Minute * 1 329 329 } 330 330 }
+127 -128
cmd/relayered/relay/account.go
··· 13 13 "github.com/bluesky-social/indigo/cmd/relayered/relay/models" 14 14 "github.com/bluesky-social/indigo/xrpc" 15 15 16 - "github.com/ipfs/go-cid" 17 16 "gorm.io/gorm" 18 17 ) 19 18 20 - var ( 21 - ErrAccountNotFound = errors.New("account not found") 22 - ErrAccountLastUnavailable = errors.New("account last commit not available") 23 - ErrCommitNoUser = errors.New("commit no user") // TODO 24 - ) 25 - 19 + // this function with exact name and args implements the `diskpersist.UidSource` interface 26 20 func (r *Relay) DidToUid(ctx context.Context, did string) (uint64, error) { 27 - xu, err := r.LookupUserByDid(ctx, did) 21 + // NOTE: assuming DID is correct syntax (this is usually "loop back") 22 + xu, err := r.GetAccount(ctx, syntax.DID(did)) 28 23 if err != nil { 29 24 return 0, err 30 25 } 31 26 if xu == nil { 32 27 return 0, ErrAccountNotFound 33 28 } 34 - return xu.ID, nil 29 + return xu.UID, nil 35 30 } 36 31 37 - func (r *Relay) LookupUserByDid(ctx context.Context, did string) (*models.Account, error) { 38 - ctx, span := tracer.Start(ctx, "lookupUserByDid") 32 + /* XXX: unused? 33 + func (r *Relay) GetAccountByUID(ctx context.Context, uid uint64) (*models.Account, error) { 34 + ctx, span := tracer.Start(ctx, "getAccount") 39 35 defer span.End() 40 36 41 - cu, ok := r.userCache.Get(did) 42 - if ok { 43 - return cu, nil 44 - } 45 - 46 - var u models.Account 47 - if err := r.db.Find(&u, "did = ?", did).Error; err != nil { 37 + var acc models.Account 38 + if err := r.db.First(&acc, uid).Error; err != nil { 39 + if errors.Is(result.Error, gorm.ErrRecordNotFound) { 40 + return nil, ErrAccountNotFound 41 + } 48 42 return nil, err 49 43 } 50 44 51 - if u.ID == 0 { 52 - return nil, gorm.ErrRecordNotFound 45 + // TODO: is this further check needed? 46 + if acc.ID == 0 { 47 + return nil, ErrAccountNotFound 53 48 } 54 49 55 - r.userCache.Add(did, &u) 56 - 57 - return &u, nil 50 + return &acc, nil 58 51 } 52 + */ 59 53 60 - func (r *Relay) LookupUserByUID(ctx context.Context, uid uint64) (*models.Account, error) { 61 - ctx, span := tracer.Start(ctx, "lookupUserByUID") 54 + func (r *Relay) GetAccount(ctx context.Context, did syntax.DID) (*models.Account, error) { 55 + ctx, span := tracer.Start(ctx, "getAccount") 62 56 defer span.End() 63 57 64 - var u models.Account 65 - if err := r.db.Find(&u, "id = ?", uid).Error; err != nil { 58 + /* XXX 59 + cu, ok := r.accountCache.Get(did) 60 + if ok { 61 + return cu, nil 62 + } 63 + */ 64 + 65 + var acc models.Account 66 + // XXX: this needs to be a "find where" 67 + if err := r.db.Where("did = ?", did).First(&acc).Error; err != nil { 68 + if errors.Is(err, gorm.ErrRecordNotFound) { 69 + return nil, ErrAccountNotFound 70 + } 66 71 return nil, err 67 72 } 68 73 69 - if u.ID == 0 { 70 - return nil, gorm.ErrRecordNotFound 74 + // TODO: is this further check needed? 75 + if acc.UID == 0 { 76 + return nil, ErrAccountNotFound 71 77 } 72 78 73 - return &u, nil 79 + /* XXX: 80 + r.accountCache.Add(did, &u) 81 + */ 82 + 83 + return &acc, nil 74 84 } 75 85 76 - func (r *Relay) newUser(ctx context.Context, host *models.PDS, did string) (*models.Account, error) { 86 + func (r *Relay) GetAccountRepo(ctx context.Context, uid uint64) (*models.AccountRepo, error) { 87 + var repo models.AccountRepo 88 + if err := r.db.First(&repo, uid).Error; err != nil { 89 + if errors.Is(err, gorm.ErrRecordNotFound) { 90 + return nil, ErrAccountRepoNotFound 91 + } 92 + // TODO: log here? 93 + return nil, err 94 + } 95 + return &repo, nil 96 + } 97 + 98 + /* XXX: refactor in to syncHostAccount? */ 99 + func (r *Relay) CreateAccount(ctx context.Context, host *models.Host, did syntax.DID) (*models.Account, error) { 77 100 newUsersDiscovered.Inc() 78 101 start := time.Now() 79 - account, err := r.syncPDSAccount(ctx, did, host, nil) 102 + account, err := r.syncHostAccount(ctx, did, host, nil) 80 103 newUserDiscoveryDuration.Observe(time.Since(start).Seconds()) 81 104 if err != nil { 82 - repoCommitsResultCounter.WithLabelValues(host.Host, "uerr").Inc() 105 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "uerr").Inc() 83 106 return nil, fmt.Errorf("fed event create external user: %w", err) 84 107 } 85 108 return account, nil 86 109 } 87 110 88 - // syncPDSAccount ensures that a DID has an account record in the database attached to a PDS record in the database 111 + // syncHostAccount ensures that a DID has an account record in the database attached to a Host record in the database 89 112 // Some fields may be updated if needed. 90 113 // did is the user 91 - // host is the PDS we received this from, not necessarily the canonical PDS in the DID document 114 + // host is the Host we received this from, not necessarily the canonical Host in the DID document 92 115 // cachedAccount is (optionally) the account that we have already looked up from cache or database 93 - func (r *Relay) syncPDSAccount(ctx context.Context, did string, host *models.PDS, cachedAccount *models.Account) (*models.Account, error) { 94 - ctx, span := tracer.Start(ctx, "syncPDSAccount") 116 + func (r *Relay) syncHostAccount(ctx context.Context, did syntax.DID, host *models.Host, cachedAccount *models.Account) (*models.Account, error) { 117 + ctx, span := tracer.Start(ctx, "syncHostAccount") 95 118 defer span.End() 96 119 97 120 externalUserCreationAttempts.Inc() 98 121 99 122 r.Logger.Debug("create external user", "did", did) 100 123 101 - // lookup identity so that we know a DID's canonical source PDS 102 - pdid, err := syntax.ParseDID(did) 103 - if err != nil { 104 - return nil, fmt.Errorf("bad did %#v, %w", did, err) 105 - } 106 - ident, err := r.dir.LookupDID(ctx, pdid) 124 + // lookup identity so that we know a DID's canonical source Host 125 + ident, err := r.dir.LookupDID(ctx, did) 107 126 if err != nil { 108 127 return nil, fmt.Errorf("no ident for did %s, %w", did, err) 109 128 } 110 - if len(ident.Services) == 0 { 111 - return nil, fmt.Errorf("no services for did %s", did) 112 - } 113 - pdsRelay, ok := ident.Services["atproto_pds"] 114 - if !ok { 115 - return nil, fmt.Errorf("no atproto_pds service for did %s", did) 129 + pdsEndpoint := ident.PDSEndpoint() 130 + if pdsEndpoint == "" { 131 + return nil, fmt.Errorf("account has no PDS endpoint registered: %s", did) 116 132 } 117 - durl, err := url.Parse(pdsRelay.URL) 133 + durl, err := url.Parse(pdsEndpoint) 118 134 if err != nil { 119 - return nil, fmt.Errorf("pds bad url %#v, %w", pdsRelay.URL, err) 135 + return nil, fmt.Errorf("account has bad url (%#v): %w", pdsEndpoint, err) 120 136 } 121 137 122 - // is the canonical PDS banned? 138 + // is the canonical Host banned? 123 139 ban, err := r.DomainIsBanned(ctx, durl.Host) 124 140 if err != nil { 125 141 return nil, fmt.Errorf("failed to check pds ban status: %w", err) ··· 132 148 durl.Scheme = "http" 133 149 } 134 150 135 - var canonicalHost *models.PDS 136 - if host.Host == durl.Host { 137 - // we got the message from the canonical PDS, convenient! 151 + var canonicalHost *models.Host 152 + if host.Hostname == durl.Host { 153 + // we got the message from the canonical Host, convenient! 138 154 canonicalHost = host 139 155 } else { 140 156 // we got the message from an intermediate relay 141 - // check our db for info on canonical PDS 142 - var peering models.PDS 143 - if err := r.db.Find(&peering, "host = ?", durl.Host).Error; err != nil { 144 - r.Logger.Error("failed to find pds", "host", durl.Host) 157 + // check our db for info on canonical Host 158 + // XXX: rename "peering" 159 + var peering models.Host 160 + if err := r.db.Find(&peering, "hostname = ?", durl.Host).Error; err != nil { 161 + r.Logger.Error("failed to find host", "host", durl.Host) 145 162 return nil, err 146 163 } 147 164 canonicalHost = &peering 148 165 } 149 166 150 - if canonicalHost.Blocked { 151 - return nil, fmt.Errorf("refusing to create user with blocked PDS") 167 + if canonicalHost.Status == models.HostStatusBanned { 168 + return nil, fmt.Errorf("refusing to create user with banned Host") 152 169 } 153 170 154 171 if canonicalHost.ID == 0 { 155 - // we got an event from a non-canonical PDS (an intermediate relay) 156 - // a non-canonical PDS we haven't seen before; ping it to make sure it's real 157 - // TODO: what do we actually want to track about the source we immediately got this message from vs the canonical PDS? 172 + // we got an event from a non-canonical Host (an intermediate relay) 173 + // a non-canonical Host we haven't seen before; ping it to make sure it's real 174 + // TODO: what do we actually want to track about the source we immediately got this message from vs the canonical Host? 158 175 r.Logger.Warn("pds discovered in new user flow", "pds", durl.String(), "did", did) 159 176 160 - // Do a trivial API request against the PDS to verify that it exists 177 + // Do a trivial API request against the Host to verify that it exists 161 178 pclient := &xrpc.Client{Host: durl.String()} 162 - if r.Config.ApplyPDSClientSettings != nil { 163 - r.Config.ApplyPDSClientSettings(pclient) 179 + if r.Config.ApplyHostClientSettings != nil { 180 + r.Config.ApplyHostClientSettings(pclient) 164 181 } 165 182 cfg, err := comatproto.ServerDescribeServer(ctx, pclient) 166 183 if err != nil { ··· 172 189 _ = cfg 173 190 174 191 // could check other things, a valid response is good enough for now 175 - canonicalHost.Host = durl.Host 176 - canonicalHost.SSL = (durl.Scheme == "https") 177 - canonicalHost.RateLimit = float64(r.Slurper.Config.DefaultPerSecondLimit) 178 - canonicalHost.HourlyEventLimit = r.Slurper.Config.DefaultPerHourLimit 179 - canonicalHost.DailyEventLimit = r.Slurper.Config.DefaultPerDayLimit 180 - canonicalHost.RepoLimit = r.Slurper.Config.DefaultRepoLimit 192 + canonicalHost.Hostname = durl.Host 193 + canonicalHost.NoSSL = !(durl.Scheme == "https") 194 + // XXX canonicalHost.RateLimit = float64(r.Slurper.Config.DefaultPerSecondLimit) 195 + // XXX canonicalHost.HourlyEventLimit = r.Slurper.Config.DefaultPerHourLimit 196 + // XXX canonicalHost.DailyEventLimit = r.Slurper.Config.DefaultPerDayLimit 197 + canonicalHost.AccountLimit = r.Slurper.Config.DefaultRepoLimit 181 198 182 - if r.Config.SSL && !canonicalHost.SSL { 183 - return nil, fmt.Errorf("did references non-ssl PDS, this is disallowed in prod: %q %q", did, pdsRelay.URL) 199 + if r.Config.SSL && canonicalHost.NoSSL { 200 + return nil, fmt.Errorf("did references non-ssl Host, this is disallowed in prod: %q %q", did, pdsEndpoint) 184 201 } 185 202 186 203 if err := r.db.Create(&canonicalHost).Error; err != nil { ··· 192 209 panic("somehow failed to create a pds entry?") 193 210 } 194 211 195 - if canonicalHost.RepoCount >= canonicalHost.RepoLimit { 212 + if canonicalHost.AccountCount >= canonicalHost.AccountLimit { 196 213 // TODO: soft-limit / hard-limit ? create account in 'throttled' state, unless there are _really_ too many accounts 197 - return nil, fmt.Errorf("refusing to create user on PDS at max repo limit for pds %q", canonicalHost.Host) 214 + return nil, fmt.Errorf("refusing to create user on Host at max repo limit for pds %q", canonicalHost.Hostname) 198 215 } 199 216 200 217 // this lock just governs the lower half of this function ··· 202 219 defer r.extUserLk.Unlock() 203 220 204 221 if cachedAccount == nil { 205 - cachedAccount, err = r.LookupUserByDid(ctx, did) 222 + cachedAccount, err = r.GetAccount(ctx, did) 206 223 } 207 224 if errors.Is(err, ErrAccountNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { 208 225 err = nil ··· 211 228 return nil, err 212 229 } 213 230 if cachedAccount != nil { 214 - caPDS := cachedAccount.GetPDS() 215 - if caPDS != canonicalHost.ID { 216 - // Account is now on a different PDS, update 231 + // XXX: caHost := cachedAccount.GetHost() 232 + caHost := cachedAccount.HostID 233 + if caHost != canonicalHost.ID { 234 + // Account is now on a different Host, update 217 235 err = r.db.Transaction(func(tx *gorm.DB) error { 218 - if caPDS != 0 { 219 - // decrement prior PDS's account count 220 - tx.Model(&models.PDS{}).Where("id = ?", caPDS).Update("repo_count", gorm.Expr("repo_count - 1")) 236 + if caHost != 0 { 237 + // decrement prior Host's account count 238 + tx.Model(&models.Host{}).Where("id = ?", caHost).Update("account_count", gorm.Expr("account_count - 1")) 221 239 } 222 - // update user's PDS ID 223 - res := tx.Model(models.Account{}).Where("id = ?", cachedAccount.ID).Update("pds", canonicalHost.ID) 240 + // update user's Host ID 241 + res := tx.Model(models.Account{}).Where("id = ?", cachedAccount.UID).Update("pds", canonicalHost.ID) 224 242 if res.Error != nil { 225 243 return fmt.Errorf("failed to update users pds: %w", res.Error) 226 244 } 227 - // increment new PDS's account count 228 - res = tx.Model(&models.PDS{}).Where("id = ? AND repo_count < repo_limit", canonicalHost.ID).Update("repo_count", gorm.Expr("repo_count + 1")) 245 + // increment new Host's account count 246 + res = tx.Model(&models.Host{}).Where("id = ? AND account_count < account_limit", canonicalHost.ID).Update("account_count", gorm.Expr("account_count + 1")) 229 247 return nil 230 248 }) 231 249 232 - cachedAccount.SetPDS(canonicalHost.ID) 250 + // XXX: cachedAccount.SetHost(canonicalHost.ID) 251 + cachedAccount.HostID = canonicalHost.ID 233 252 } 234 253 return cachedAccount, nil 235 254 } 236 255 237 256 newAccount := models.Account{ 238 - Did: did, 239 - PDS: canonicalHost.ID, 257 + DID: did.String(), 258 + HostID: canonicalHost.ID, 259 + Status: models.AccountStatusActive, 260 + UpstreamStatus: models.AccountStatusActive, 240 261 } 241 262 242 263 err = r.db.Transaction(func(tx *gorm.DB) error { 243 - res := tx.Model(&models.PDS{}).Where("id = ? AND repo_count < repo_limit", canonicalHost.ID).Update("repo_count", gorm.Expr("repo_count + 1")) 264 + res := tx.Model(&models.Host{}).Where("id = ? AND account_count < account_limit", canonicalHost.ID).Update("account_count", gorm.Expr("account_count + 1")) 244 265 if res.Error != nil { 245 - return fmt.Errorf("failed to increment repo count for pds %q: %w", canonicalHost.Host, res.Error) 266 + return fmt.Errorf("failed to increment repo count for pds %q: %w", canonicalHost.Hostname, res.Error) 246 267 } 268 + r.Logger.Warn("XXX creating new account", "did", newAccount.DID) 247 269 if terr := tx.Create(&newAccount).Error; terr != nil { 248 - r.Logger.Error("failed to create user", "did", newAccount.Did, "err", terr) 270 + r.Logger.Error("failed to create user", "did", newAccount.DID, "err", terr) 249 271 return fmt.Errorf("failed to create other pds user: %w", terr) 250 272 } 251 273 return nil ··· 255 277 return nil, err 256 278 } 257 279 258 - r.userCache.Add(did, &newAccount) 280 + r.accountCache.Add(did.String(), &newAccount) 259 281 260 282 return &newAccount, nil 261 283 } 262 284 263 - func (r *Relay) TakeDownRepo(ctx context.Context, did string) error { 264 - u, err := r.LookupUserByDid(ctx, did) 285 + func (r *Relay) UpdateAccountStatus(ctx context.Context, did syntax.DID, status models.AccountStatus) error { 286 + acc, err := r.GetAccount(ctx, did) 265 287 if err != nil { 266 288 return err 267 289 } 268 290 269 - if err := r.db.Model(models.Account{}).Where("id = ?", u.ID).Update("taken_down", true).Error; err != nil { 291 + if err := r.db.Model(models.Account{}).Where("uid = ?", acc.UID).Update("status", status).Error; err != nil { 270 292 return err 271 293 } 272 - u.SetTakenDown(true) 273 - 274 - // NOTE: not wiping events for user from backfill window 294 + // XXX: u.SetTakenDown(true) 275 295 296 + // NOTE: not wiping events for user from persister (backfill window) 276 297 return nil 277 298 } 278 299 279 - func (r *Relay) ReverseTakedown(ctx context.Context, did string) error { 280 - u, err := r.LookupUserByDid(ctx, did) 281 - if err != nil { 282 - return err 283 - } 284 - 285 - if err := r.db.Model(models.Account{}).Where("id = ?", u.ID).Update("taken_down", false).Error; err != nil { 286 - return err 287 - } 288 - u.SetTakenDown(false) 289 - 290 - return nil 291 - } 292 - 293 - func (r *Relay) GetAccountPreviousState(ctx context.Context, uid uint64) (*models.AccountPreviousState, error) { 294 - var prevState models.AccountPreviousState 295 - if err := r.db.First(&prevState, uid).Error; err != nil { 296 - if errors.Is(err, gorm.ErrRecordNotFound) { 297 - return nil, ErrAccountLastUnavailable 298 - } 299 - r.Logger.Error("user db err", "err", err) 300 - return nil, err 301 - } 302 - return &prevState, nil 303 - } 304 - 300 + /* XXX 305 301 func (r *Relay) GetRepoRoot(ctx context.Context, uid uint64) (cid.Cid, error) { 306 302 var prevState models.AccountPreviousState 307 303 err := r.db.First(&prevState, uid).Error 308 304 if err == nil { 309 305 return prevState.Cid.CID, nil 310 306 } else if errors.Is(err, gorm.ErrRecordNotFound) { 311 - return cid.Cid{}, ErrAccountLastUnavailable 307 + return cid.Cid{}, ErrAccountRepoNotFound 312 308 } else { 313 309 r.Logger.Error("user db err", "err", err) 314 310 return cid.Cid{}, fmt.Errorf("user prev db err, %w", err) 315 311 } 316 312 } 313 + */ 317 314 315 + /* 318 316 func (r *Relay) GetHostForDID(ctx context.Context, did string) (string, error) { 319 317 var pdsHostname string 320 318 // TODO: use gorm, not "Raw" ··· 324 322 } 325 323 return pdsHostname, nil 326 324 } 325 + */ 327 326 328 327 func (r *Relay) ListAccounts(ctx context.Context, cursor int64, limit int) ([]*models.Account, error) { 329 328
+11
cmd/relayered/relay/errors.go
··· 1 + package relay 2 + 3 + import ( 4 + "errors" 5 + ) 6 + 7 + var ( 8 + ErrHostNotFound = errors.New("unknown host or PDS") 9 + ErrAccountNotFound = errors.New("unknown account") 10 + ErrAccountRepoNotFound = errors.New("repository state not available") 11 + )
+106 -90
cmd/relayered/relay/firehose.go
··· 4 4 "context" 5 5 "errors" 6 6 "fmt" 7 - "strconv" 8 7 "time" 9 8 10 9 comatproto "github.com/bluesky-social/indigo/api/atproto" ··· 18 17 ) 19 18 20 19 // handleFedEvent() is the callback passed to Slurper called from Slurper.handleConnection() 21 - func (r *Relay) handleFedEvent(ctx context.Context, host *models.PDS, env *stream.XRPCStreamEvent) error { 20 + // XXX: evt not env 21 + func (r *Relay) handleFedEvent(ctx context.Context, host *models.Host, env *stream.XRPCStreamEvent) error { 22 22 ctx, span := tracer.Start(ctx, "handleFedEvent") 23 23 defer span.End() 24 24 25 25 start := time.Now() 26 26 defer func() { 27 - eventsHandleDuration.WithLabelValues(host.Host).Observe(time.Since(start).Seconds()) 27 + eventsHandleDuration.WithLabelValues(host.Hostname).Observe(time.Since(start).Seconds()) 28 28 }() 29 29 30 - EventsReceivedCounter.WithLabelValues(host.Host).Add(1) 30 + EventsReceivedCounter.WithLabelValues(host.Hostname).Add(1) 31 31 32 32 switch { 33 33 case env.RepoCommit != nil: 34 - repoCommitsReceivedCounter.WithLabelValues(host.Host).Add(1) 34 + repoCommitsReceivedCounter.WithLabelValues(host.Hostname).Add(1) 35 35 return r.handleCommit(ctx, host, env.RepoCommit) 36 36 case env.RepoSync != nil: 37 - repoSyncReceivedCounter.WithLabelValues(host.Host).Add(1) 37 + repoSyncReceivedCounter.WithLabelValues(host.Hostname).Add(1) 38 38 return r.handleSync(ctx, host, env.RepoSync) 39 39 case env.RepoHandle != nil: 40 - eventsWarningsCounter.WithLabelValues(host.Host, "handle").Add(1) 41 - // TODO: rate limit warnings per PDS before we (temporarily?) block them 40 + eventsWarningsCounter.WithLabelValues(host.Hostname, "handle").Add(1) 41 + // TODO: rate limit warnings per Host before we (temporarily?) block them 42 42 return nil 43 43 case env.RepoIdentity != nil: 44 44 r.Logger.Info("relay got identity event", "did", env.RepoIdentity.Did) 45 + 46 + did, err := syntax.ParseDID(env.RepoIdentity.Did) 47 + if err != nil { 48 + return fmt.Errorf("invalid DID in message: %w", err) 49 + } 50 + 45 51 // Flush any cached DID documents for this user 46 - r.purgeDidCache(ctx, env.RepoIdentity.Did) 52 + r.purgeDidCache(ctx, did.String()) 47 53 48 54 // Refetch the DID doc and update our cached keys and handle etc. 49 - account, err := r.syncPDSAccount(ctx, env.RepoIdentity.Did, host, nil) 55 + account, err := r.syncHostAccount(ctx, did, host, nil) 50 56 if err != nil { 51 57 return err 52 58 } ··· 54 60 // Broadcast the identity event to all consumers 55 61 err = r.Events.AddEvent(ctx, &stream.XRPCStreamEvent{ 56 62 RepoIdentity: &comatproto.SyncSubscribeRepos_Identity{ 57 - Did: env.RepoIdentity.Did, 63 + Did: did.String(), 58 64 Seq: env.RepoIdentity.Seq, 59 65 Time: env.RepoIdentity.Time, 60 66 Handle: env.RepoIdentity.Handle, 61 67 }, 62 - PrivUid: account.ID, 68 + PrivUid: account.UID, 63 69 }) 64 70 if err != nil { 65 - r.Logger.Error("failed to broadcast Identity event", "error", err, "did", env.RepoIdentity.Did) 71 + r.Logger.Error("failed to broadcast Identity event", "error", err, "did", did) 66 72 return fmt.Errorf("failed to broadcast Identity event: %w", err) 67 73 } 68 74 ··· 73 79 attribute.Int64("seq", env.RepoAccount.Seq), 74 80 attribute.Bool("active", env.RepoAccount.Active), 75 81 ) 82 + 83 + did, err := syntax.ParseDID(env.RepoAccount.Did) 84 + if err != nil { 85 + return fmt.Errorf("invalid DID in message: %w", err) 86 + } 76 87 77 88 if env.RepoAccount.Status != nil { 78 89 span.SetAttributes(attribute.String("repo_status", *env.RepoAccount.Status)) ··· 82 93 if !env.RepoAccount.Active && env.RepoAccount.Status == nil { 83 94 // TODO: semantics here aren't really clear 84 95 r.Logger.Warn("dropping invalid account event", "did", env.RepoAccount.Did, "active", env.RepoAccount.Active, "status", env.RepoAccount.Status) 85 - accountVerifyWarnings.WithLabelValues(host.Host, "nostat").Inc() 96 + accountVerifyWarnings.WithLabelValues(host.Hostname, "nostat").Inc() 86 97 return nil 87 98 } 88 99 89 100 // Flush any cached DID documents for this user 90 - r.purgeDidCache(ctx, env.RepoAccount.Did) 101 + r.purgeDidCache(ctx, did.String()) 91 102 92 - // Refetch the DID doc to make sure the PDS is still authoritative 93 - account, err := r.syncPDSAccount(ctx, env.RepoAccount.Did, host, nil) 103 + // Refetch the DID doc to make sure the Host is still authoritative 104 + account, err := r.syncHostAccount(ctx, did, host, nil) 94 105 if err != nil { 95 106 span.RecordError(err) 96 107 return err 97 108 } 98 109 99 - // Check if the PDS is still authoritative 110 + // Check if the Host is still authoritative 100 111 // if not we don't want to be propagating this account event 101 - if account.GetPDS() != host.ID && !r.Config.SkipAccountHostCheck { 112 + // XXX: lock 113 + if account.HostID != host.ID && !r.Config.SkipAccountHostCheck { 102 114 r.Logger.Error("account event from non-authoritative pds", 103 115 "seq", env.RepoAccount.Seq, 104 116 "did", env.RepoAccount.Did, 105 - "event_from", host.Host, 106 - "did_doc_declared_pds", account.GetPDS(), 117 + "event_from", host.Hostname, 118 + "did_doc_declared_pds", account.HostID, 107 119 "account_evt", env.RepoAccount, 108 120 ) 109 121 return fmt.Errorf("event from non-authoritative pds") ··· 112 124 // Process the account status change 113 125 repoStatus := models.AccountStatusActive 114 126 if !env.RepoAccount.Active && env.RepoAccount.Status != nil { 115 - repoStatus = *env.RepoAccount.Status 127 + repoStatus = models.AccountStatus(*env.RepoAccount.Status) 116 128 } 117 129 118 - account.SetUpstreamStatus(repoStatus) 130 + // XXX: lock, and parse 131 + account.UpstreamStatus = models.AccountStatus(repoStatus) 119 132 err = r.db.Save(account).Error 120 133 if err != nil { 121 134 span.RecordError(err) ··· 126 139 status := env.RepoAccount.Status 127 140 128 141 // override with local status 129 - if account.GetTakenDown() { 142 + // XXX: lock 143 + if account.Status == "takendown" { 130 144 shouldBeActive = false 131 - status = &models.AccountStatusTakendown 145 + s := string(models.AccountStatusTakendown) 146 + status = &s 132 147 } 133 148 134 149 // Broadcast the account event to all consumers ··· 140 155 Status: status, 141 156 Time: env.RepoAccount.Time, 142 157 }, 143 - PrivUid: account.ID, 158 + PrivUid: account.UID, 144 159 }) 145 160 if err != nil { 146 161 r.Logger.Error("failed to broadcast Account event", "error", err, "did", env.RepoAccount.Did) ··· 149 164 150 165 return nil 151 166 case env.RepoMigrate != nil: 152 - eventsWarningsCounter.WithLabelValues(host.Host, "migrate").Add(1) 153 - // TODO: rate limit warnings per PDS before we (temporarily?) block them 167 + eventsWarningsCounter.WithLabelValues(host.Hostname, "migrate").Add(1) 168 + // TODO: rate limit warnings per Host before we (temporarily?) block them 154 169 return nil 155 170 case env.RepoTombstone != nil: 156 - eventsWarningsCounter.WithLabelValues(host.Host, "tombstone").Add(1) 157 - // TODO: rate limit warnings per PDS before we (temporarily?) block them 171 + eventsWarningsCounter.WithLabelValues(host.Hostname, "tombstone").Add(1) 172 + // TODO: rate limit warnings per Host before we (temporarily?) block them 158 173 return nil 159 174 default: 160 175 return fmt.Errorf("invalid fed event") 161 176 } 162 177 } 163 178 164 - func (r *Relay) handleCommit(ctx context.Context, host *models.PDS, evt *comatproto.SyncSubscribeRepos_Commit) error { 165 - r.Logger.Debug("relay got repo append event", "seq", evt.Seq, "pdsHost", host.Host, "repo", evt.Repo) 179 + func (r *Relay) handleCommit(ctx context.Context, host *models.Host, evt *comatproto.SyncSubscribeRepos_Commit) error { 180 + r.Logger.Debug("relay got repo append event", "seq", evt.Seq, "host", host.Hostname, "repo", evt.Repo) 166 181 167 - account, err := r.LookupUserByDid(ctx, evt.Repo) 182 + did, err := syntax.ParseDID(evt.Repo) 183 + if err != nil { 184 + return fmt.Errorf("invalid DID in message: %w", err) 185 + } 186 + // XXX: did = did.Normalize() 187 + account, err := r.GetAccount(ctx, did) 168 188 if err != nil { 169 189 if !errors.Is(err, gorm.ErrRecordNotFound) { 170 - repoCommitsResultCounter.WithLabelValues(host.Host, "nou").Inc() 190 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "nou").Inc() 171 191 return fmt.Errorf("looking up event user: %w", err) 172 192 } 173 193 174 - account, err = r.newUser(ctx, host, evt.Repo) 194 + account, err = r.CreateAccount(ctx, host, did) 175 195 if err != nil { 176 - repoCommitsResultCounter.WithLabelValues(host.Host, "nuerr").Inc() 196 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "nuerr").Inc() 177 197 return err 178 198 } 179 199 } 180 200 if account == nil { 181 - repoCommitsResultCounter.WithLabelValues(host.Host, "nou2").Inc() 182 - return ErrCommitNoUser 201 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "nou2").Inc() 202 + return ErrAccountNotFound 183 203 } 184 204 185 - ustatus := account.GetUpstreamStatus() 205 + // XXX: lock on account 206 + ustatus := account.UpstreamStatus 186 207 187 - if account.GetTakenDown() || ustatus == models.AccountStatusTakendown { 188 - r.Logger.Debug("dropping commit event from taken down user", "did", evt.Repo, "seq", evt.Seq, "pdsHost", host.Host) 189 - repoCommitsResultCounter.WithLabelValues(host.Host, "tdu").Inc() 208 + // XXX: lock on account 209 + if account.Status == models.AccountStatusTakendown || ustatus == models.AccountStatusTakendown { 210 + r.Logger.Debug("dropping commit event from taken down user", "did", evt.Repo, "seq", evt.Seq, "host", host.Hostname) 211 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "tdu").Inc() 190 212 return nil 191 213 } 192 214 193 215 if ustatus == models.AccountStatusSuspended { 194 - r.Logger.Debug("dropping commit event from suspended user", "did", evt.Repo, "seq", evt.Seq, "pdsHost", host.Host) 195 - repoCommitsResultCounter.WithLabelValues(host.Host, "susu").Inc() 216 + r.Logger.Debug("dropping commit event from suspended user", "did", evt.Repo, "seq", evt.Seq, "host", host.Hostname) 217 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "susu").Inc() 196 218 return nil 197 219 } 198 220 199 221 if ustatus == models.AccountStatusDeactivated { 200 - r.Logger.Debug("dropping commit event from deactivated user", "did", evt.Repo, "seq", evt.Seq, "pdsHost", host.Host) 201 - repoCommitsResultCounter.WithLabelValues(host.Host, "du").Inc() 222 + r.Logger.Debug("dropping commit event from deactivated user", "did", evt.Repo, "seq", evt.Seq, "host", host.Hostname) 223 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "du").Inc() 202 224 return nil 203 225 } 204 226 205 227 if evt.Rebase { 206 - repoCommitsResultCounter.WithLabelValues(host.Host, "rebase").Inc() 207 - return fmt.Errorf("rebase was true in event seq:%d,host:%s", evt.Seq, host.Host) 228 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "rebase").Inc() 229 + return fmt.Errorf("rebase was true in event seq:%d,host:%s", evt.Seq, host.Hostname) 208 230 } 209 231 210 - accountPDSId := account.GetPDS() 211 - if host.ID != accountPDSId && accountPDSId != 0 { 212 - r.Logger.Warn("received event for repo from different pds than expected", "repo", evt.Repo, "expPds", accountPDSId, "gotPds", host.Host) 232 + accountHostId := account.HostID 233 + if host.ID != accountHostId && accountHostId != 0 { 234 + r.Logger.Warn("received event for repo from different pds than expected", "repo", evt.Repo, "expPds", accountHostId, "gotPds", host.Hostname) 213 235 // Flush any cached DID documents for this user 214 236 r.purgeDidCache(ctx, evt.Repo) 215 237 216 - account, err = r.syncPDSAccount(ctx, evt.Repo, host, account) 238 + account, err = r.syncHostAccount(ctx, did, host, account) 217 239 if err != nil { 218 - repoCommitsResultCounter.WithLabelValues(host.Host, "uerr2").Inc() 240 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "uerr2").Inc() 219 241 return err 220 242 } 221 243 222 - if account.GetPDS() != host.ID && !r.Config.SkipAccountHostCheck { 223 - repoCommitsResultCounter.WithLabelValues(host.Host, "noauth").Inc() 244 + if account.HostID != host.ID && !r.Config.SkipAccountHostCheck { 245 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "noauth").Inc() 224 246 return fmt.Errorf("event from non-authoritative pds") 225 247 } 226 248 } 227 249 228 250 // TODO: very messy fetch code here 229 - var prevState models.AccountPreviousState 230 - err = r.db.First(&prevState, account.ID).Error 231 - prevP := &prevState 232 - if errors.Is(err, gorm.ErrRecordNotFound) { 233 - prevP = nil 234 - } else if err != nil { 251 + var repo *models.AccountRepo 252 + err = r.db.First(repo, account.UID).Error 253 + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { 235 254 r.Logger.Error("failed to get previous root", "err", err) 236 - prevP = nil 255 + repo = nil 237 256 } 238 - dbPrevRootStr := "" 239 - dbPrevSeqStr := "" 240 257 var prevRev *syntax.TID 241 258 var prevData *cid.Cid 242 - if prevP != nil { 243 - if prevState.Seq >= evt.Seq && ((prevState.Seq - evt.Seq) < 2000) { 244 - // ignore catchup overlap of 200 on some subscribeRepos restarts 245 - repoCommitsResultCounter.WithLabelValues(host.Host, "dup").Inc() 246 - return nil 247 - } 248 - prevData = &prevState.Cid.CID 249 - t := syntax.TID(prevState.Rev) 259 + if repo != nil { 260 + // XXX: repo.CommitData 261 + //prevData = &repo.Cid.CID 262 + t := syntax.TID(repo.Rev) 250 263 prevRev = &t 251 - dbPrevRootStr = prevState.Cid.CID.String() 252 - dbPrevSeqStr = strconv.FormatInt(prevState.Seq, 10) 253 264 } 254 265 evtPrevDataStr := "" 255 266 if evt.PrevData != nil { ··· 258 269 newRootCid, err := r.Validator.HandleCommit(ctx, host, account, evt, prevRev, prevData) 259 270 if err != nil { 260 271 // XXX: induction trace log 261 - r.Logger.Error("commit bad", "seq", evt.Seq, "pseq", dbPrevSeqStr, "pdsHost", host.Host, "repo", evt.Repo, "prev", evtPrevDataStr, "dbprev", dbPrevRootStr, "err", err) 262 - r.Logger.Warn("failed handling event", "err", err, "pdsHost", host.Host, "seq", evt.Seq, "repo", account.Did, "commit", evt.Commit.String()) 263 - repoCommitsResultCounter.WithLabelValues(host.Host, "err").Inc() 272 + r.Logger.Error("commit bad", "seq", evt.Seq, "host", host.Hostname, "repo", evt.Repo, "prev", evtPrevDataStr, "err", err) 273 + r.Logger.Warn("failed handling event", "err", err, "host", host.Hostname, "seq", evt.Seq, "repo", account.DID, "commit", evt.Commit.String()) 274 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "err").Inc() 264 275 return fmt.Errorf("handle user event failed: %w", err) 265 276 } else { 266 277 // store now verified new repo state 267 - err = r.upsertPrevState(account.ID, newRootCid, evt.Rev, evt.Seq) 278 + err = r.upsertPrevState(account.UID, newRootCid, evt.Rev, evt.Seq) 268 279 if err != nil { 269 - return fmt.Errorf("failed to set previous root uid=%d: %w", account.ID, err) 280 + return fmt.Errorf("failed to set previous root uid=%d: %w", account.UID, err) 270 281 } 271 282 } 272 283 273 - repoCommitsResultCounter.WithLabelValues(host.Host, "ok").Inc() 284 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "ok").Inc() 274 285 275 286 // Broadcast the identity event to all consumers 276 287 commitCopy := *evt 277 288 err = r.Events.AddEvent(ctx, &stream.XRPCStreamEvent{ 278 289 RepoCommit: &commitCopy, 279 - PrivUid: account.GetUid(), 290 + PrivUid: account.UID, 280 291 }) 281 292 if err != nil { 282 293 r.Logger.Error("failed to broadcast commit event", "error", err, "did", evt.Repo) ··· 287 298 } 288 299 289 300 // handleSync processes #sync messages 290 - func (r *Relay) handleSync(ctx context.Context, host *models.PDS, evt *comatproto.SyncSubscribeRepos_Sync) error { 291 - account, err := r.LookupUserByDid(ctx, evt.Did) 301 + func (r *Relay) handleSync(ctx context.Context, host *models.Host, evt *comatproto.SyncSubscribeRepos_Sync) error { 302 + did, err := syntax.ParseDID(evt.Did) 303 + if err != nil { 304 + return fmt.Errorf("invalid DID in message: %s", did) 305 + } 306 + // XXX: did.Normalize() 307 + account, err := r.GetAccount(ctx, did) 292 308 if err != nil { 293 309 if !errors.Is(err, gorm.ErrRecordNotFound) { 294 - repoCommitsResultCounter.WithLabelValues(host.Host, "nou").Inc() 310 + repoCommitsResultCounter.WithLabelValues(host.Hostname, "nou").Inc() 295 311 return fmt.Errorf("looking up event user: %w", err) 296 312 } 297 313 298 - account, err = r.newUser(ctx, host, evt.Did) 314 + account, err = r.CreateAccount(ctx, host, did) 299 315 } 300 316 if err != nil { 301 317 return fmt.Errorf("could not get user for did %#v: %w", evt.Did, err) ··· 305 321 if err != nil { 306 322 return err 307 323 } 308 - err = r.upsertPrevState(account.ID, newRootCid, evt.Rev, evt.Seq) 324 + err = r.upsertPrevState(account.UID, newRootCid, evt.Rev, evt.Seq) 309 325 if err != nil { 310 - return fmt.Errorf("could not sync set previous state uid=%d: %w", account.ID, err) 326 + return fmt.Errorf("could not sync set previous state uid=%d: %w", account.UID, err) 311 327 } 312 328 313 329 // Broadcast the sync event to all consumers ··· 324 340 } 325 341 326 342 func (r *Relay) upsertPrevState(uid uint64, newRootCid *cid.Cid, rev string, seq int64) error { 327 - cidBytes := newRootCid.Bytes() 343 + // XXX: which is this actually 328 344 return r.db.Exec( 329 - "INSERT INTO account_previous_states (uid, cid, rev, seq) VALUES (?, ?, ?, ?) ON CONFLICT (uid) DO UPDATE SET cid = EXCLUDED.cid, rev = EXCLUDED.rev, seq = EXCLUDED.seq", 330 - uid, cidBytes, rev, seq, 345 + "INSERT INTO account_repo (uid, rev, commit_data) VALUES (?, ?, ?) ON CONFLICT (uid) DO UPDATE SET commit_data = EXCLUDED.commit_data , rev = EXCLUDED.rev", 346 + uid, rev, newRootCid.String(), 331 347 ).Error 332 348 } 333 349
+30
cmd/relayered/relay/host.go
··· 1 + package relay 2 + 3 + import ( 4 + "context" 5 + "errors" 6 + 7 + "github.com/bluesky-social/indigo/cmd/relayered/relay/models" 8 + 9 + "gorm.io/gorm" 10 + ) 11 + 12 + func (r *Relay) GetHost(ctx context.Context, hostID uint64) (*models.Host, error) { 13 + ctx, span := tracer.Start(ctx, "getHost") 14 + defer span.End() 15 + 16 + var host models.Host 17 + if err := r.db.Find(&host, hostID).Error; err != nil { 18 + if errors.Is(err, gorm.ErrRecordNotFound) { 19 + return nil, ErrHostNotFound 20 + } 21 + return nil, err 22 + } 23 + 24 + // TODO: is this further check needed? 25 + if host.ID == 0 { 26 + return nil, ErrAccountNotFound 27 + } 28 + 29 + return &host, nil 30 + }
+54 -77
cmd/relayered/relay/models/models.go
··· 1 1 package models 2 2 3 3 import ( 4 - "sync" 5 4 "time" 6 5 7 - "github.com/bluesky-social/indigo/atproto/syntax" 8 - 9 - "github.com/ipfs/go-cid" 10 6 "gorm.io/gorm" 11 7 ) 12 8 9 + // TODO: revisit this 13 10 type DomainBan struct { 14 11 gorm.Model 15 12 Domain string `gorm:"unique"` 16 13 } 17 14 18 - type PDS struct { 19 - gorm.Model 15 + type HostStatus string 20 16 21 - Host string `gorm:"unique"` 22 - SSL bool 23 - Cursor int64 24 - Registered bool 25 - Blocked bool 17 + const ( 18 + HostStatusActive = HostStatus("active") 19 + HostStatusIdle = HostStatus("idle") 20 + HostStatusBanned = HostStatus("banned") 21 + ) 26 22 27 - RateLimit float64 23 + type Host struct { 24 + ID uint64 `gorm:"column:id;primarykey"` 25 + CreatedAt time.Time 26 + UpdatedAt time.Time 28 27 29 - RepoCount int64 30 - RepoLimit int64 28 + // hostname, without URL scheme. might include a port number if localhost, otherwise should not 29 + Hostname string `gorm:"column:hostname;uniqueIndex;not null"` 31 30 32 - HourlyEventLimit int64 33 - DailyEventLimit int64 34 - } 31 + // indicates ws:// not wss:// 32 + NoSSL bool `gorm:"column:no_ssl;default:false"` 35 33 36 - type Account struct { 37 - ID uint64 `gorm:"primarykey"` 38 - CreatedAt time.Time 39 - UpdatedAt time.Time 40 - DeletedAt gorm.DeletedAt `gorm:"index"` 41 - Did string `gorm:"uniqueIndex"` 42 - PDS uint // foreign key on models.PDS.ID 34 + // maximum number of active accounts 35 + AccountLimit int64 `gorm:"column:account_limit"` 43 36 44 - // TakenDown is set to true if the user in question has been taken down by an admin action at this relay. 45 - // A user in this state will have all future events related to it dropped 46 - // and no data about this user will be served. 47 - TakenDown bool 37 + // TODO: ThrottleUntil time.Time 48 38 49 - // UpstreamStatus is the state of the user as reported by the upstream PDS through #account messages. 50 - // Additionally, the non-standard string "active" is set to represent an upstream #account message with the active bool true. 51 - UpstreamStatus string `gorm:"index"` 39 + // indicates this is a highly trusted PDS (different limits apply) 40 + Trusted bool `gorm:"column:trusted;default:false"` 52 41 53 - lk sync.Mutex 54 - } 42 + // enum of account status 43 + Status HostStatus `gorm:"column:status;default:active"` 55 44 56 - func (account *Account) GetDid() string { 57 - return account.Did 45 + // negative number indicates no sequence recorded 46 + LastSeq int64 `gorm:"column:last_seq"` 47 + AccountCount int64 `gorm:"column:account_count"` 58 48 } 59 49 60 - func (account *Account) GetUid() uint64 { 61 - return account.ID 50 + func (Host) TableName() string { 51 + return "host" 62 52 } 63 53 64 - func (account *Account) SetTakenDown(v bool) { 65 - account.lk.Lock() 66 - defer account.lk.Unlock() 67 - account.TakenDown = v 68 - } 54 + type AccountStatus string 69 55 70 - func (account *Account) GetTakenDown() bool { 71 - account.lk.Lock() 72 - defer account.lk.Unlock() 73 - return account.TakenDown 74 - } 56 + var ( 57 + // AccountStatusActive is not in the spec but used internally 58 + AccountStatusActive = AccountStatus("active") 75 59 76 - func (account *Account) SetPDS(pdsId uint) { 77 - account.lk.Lock() 78 - defer account.lk.Unlock() 79 - account.PDS = pdsId 80 - } 60 + AccountStatusDeactivated = AccountStatus("deactivated") 61 + AccountStatusDeleted = AccountStatus("deleted") 62 + AccountStatusDesynchronized = AccountStatus("desynchronized") 63 + AccountStatusSuspended = AccountStatus("suspended") 64 + AccountStatusTakendown = AccountStatus("takendown") 65 + AccountStatusThrottled = AccountStatus("throttled") 66 + AccountStatusHostThrottled = AccountStatus("host-throttled") 67 + ) 81 68 82 - func (account *Account) GetPDS() uint { 83 - account.lk.Lock() 84 - defer account.lk.Unlock() 85 - return account.PDS 69 + type Account struct { 70 + UID uint64 `gorm:"column:uid;primarykey"` 71 + DID string `gorm:"column:did;uniqueIndex;not null"` 72 + HostID uint64 `gorm:"column:host_id;not null"` 73 + Status AccountStatus `gorm:"column:status;default:active"` 74 + UpstreamStatus AccountStatus `gorm:"column:upstream_status;default:active"` 75 + ThrottleUntil time.Time `gorm:"column:throttle_util"` 86 76 } 87 77 88 - func (account *Account) SetUpstreamStatus(v string) { 89 - account.lk.Lock() 90 - defer account.lk.Unlock() 91 - account.UpstreamStatus = v 92 - } 93 - 94 - func (account *Account) GetUpstreamStatus() string { 95 - account.lk.Lock() 96 - defer account.lk.Unlock() 97 - return account.UpstreamStatus 78 + func (Account) TableName() string { 79 + return "account" 98 80 } 99 81 100 - type AccountPreviousState struct { 101 - Uid uint64 `gorm:"column:uid;primaryKey"` 102 - Cid DbCID `gorm:"column:cid"` 103 - Rev string `gorm:"column:rev"` 104 - Seq int64 `gorm:"column:seq"` 82 + type AccountRepo struct { 83 + UID uint64 `gorm:"column:uid;primarykey"` 84 + Rev string `gorm:"column:rev"` 85 + CommitData string `gorm:"column:commit_data"` 105 86 } 106 87 107 - func (ups *AccountPreviousState) GetCid() cid.Cid { 108 - return ups.Cid.CID 109 - } 110 - func (ups *AccountPreviousState) GetRev() syntax.TID { 111 - xt, _ := syntax.ParseTID(ups.Rev) 112 - return xt 88 + func (AccountRepo) TableName() string { 89 + return "account_repo" 113 90 }
+2 -2
cmd/relayered/relay/rate_limits.go
··· 1 1 package relay 2 2 3 - type PDSRates struct { 3 + type HostRates struct { 4 4 // core event rate, counts firehose events 5 5 PerSecond int64 `json:"per_second,omitempty"` 6 6 PerHour int64 `json:"per_hour,omitempty"` ··· 9 9 RepoLimit int64 `json:"repo_limit,omitempty"` 10 10 } 11 11 12 - func (pr *PDSRates) FromSlurper(s *Slurper) { 12 + func (pr *HostRates) FromSlurper(s *Slurper) { 13 13 if pr.PerSecond == 0 { 14 14 pr.PerHour = s.Config.DefaultPerSecondLimit 15 15 }
+17 -17
cmd/relayered/relay/relay.go
··· 25 25 Validator *Validator 26 26 Config RelayConfig 27 27 28 - // extUserLk serializes a section of syncPDSAccount() 28 + // extUserLk serializes a section of syncHostAccount() 29 29 // TODO: at some point we will want to lock specific DIDs, this lock as is 30 30 // is overly broad, but i dont expect it to be a bottleneck for now 31 31 extUserLk sync.Mutex ··· 36 36 consumers map[uint64]*SocketConsumer 37 37 38 38 // Account cache 39 - userCache *lru.Cache[string, *models.Account] 39 + accountCache *lru.Cache[string, *models.Account] 40 40 } 41 41 42 42 type RelayConfig struct { 43 - SSL bool 44 - DefaultRepoLimit int64 45 - ConcurrencyPerPDS int64 46 - MaxQueuePerPDS int64 47 - ApplyPDSClientSettings func(c *xrpc.Client) 48 - SkipAccountHostCheck bool // XXX: only used for testing 43 + SSL bool 44 + DefaultRepoLimit int64 45 + ConcurrencyPerHost int64 46 + MaxQueuePerHost int64 47 + ApplyHostClientSettings func(c *xrpc.Client) 48 + SkipAccountHostCheck bool // XXX: only used for testing 49 49 } 50 50 51 51 func DefaultRelayConfig() *RelayConfig { 52 52 return &RelayConfig{ 53 - SSL: true, 54 - DefaultRepoLimit: 100, 55 - ConcurrencyPerPDS: 100, 56 - MaxQueuePerPDS: 1_000, 53 + SSL: true, 54 + DefaultRepoLimit: 100, 55 + ConcurrencyPerHost: 100, 56 + MaxQueuePerHost: 1_000, 57 57 } 58 58 } 59 59 ··· 76 76 consumersLk: sync.RWMutex{}, 77 77 consumers: make(map[uint64]*SocketConsumer), 78 78 79 - userCache: uc, 79 + accountCache: uc, 80 80 } 81 81 82 82 if err := r.MigrateDatabase(); err != nil { ··· 86 86 slOpts := DefaultSlurperConfig() 87 87 slOpts.SSL = config.SSL 88 88 slOpts.DefaultRepoLimit = config.DefaultRepoLimit 89 - slOpts.ConcurrencyPerPDS = config.ConcurrencyPerPDS 90 - slOpts.MaxQueuePerPDS = config.MaxQueuePerPDS 89 + slOpts.ConcurrencyPerHost = config.ConcurrencyPerHost 90 + slOpts.MaxQueuePerHost = config.MaxQueuePerHost 91 91 s, err := NewSlurper(db, r.handleFedEvent, slOpts, r.Logger) 92 92 if err != nil { 93 93 return nil, err ··· 104 104 if err := r.db.AutoMigrate(models.DomainBan{}); err != nil { 105 105 return err 106 106 } 107 - if err := r.db.AutoMigrate(models.PDS{}); err != nil { 107 + if err := r.db.AutoMigrate(models.Host{}); err != nil { 108 108 return err 109 109 } 110 110 if err := r.db.AutoMigrate(models.Account{}); err != nil { 111 111 return err 112 112 } 113 - if err := r.db.AutoMigrate(models.AccountPreviousState{}); err != nil { 113 + if err := r.db.AutoMigrate(models.AccountRepo{}); err != nil { 114 114 return err 115 115 } 116 116 return nil
+155 -134
cmd/relayered/relay/slurper.go
··· 20 20 "gorm.io/gorm" 21 21 ) 22 22 23 - type IndexCallback func(context.Context, *models.PDS, *stream.XRPCStreamEvent) error 23 + var ErrTimeoutShutdown = fmt.Errorf("timed out waiting for new events") 24 + 25 + type ProcessMessageFunc func(context.Context, *models.Host, *stream.XRPCStreamEvent) error 24 26 25 27 type Slurper struct { 26 - cb IndexCallback 28 + cb ProcessMessageFunc 27 29 db *gorm.DB 28 30 Config *SlurperConfig 29 31 30 32 lk sync.Mutex 31 - active map[string]*activeSub 32 - 33 - LimitMux sync.RWMutex 34 - Limiters map[uint]*Limiters 33 + active map[string]*Subscriber 35 34 36 - DefaultRepoLimit int64 37 - ConcurrencyPerPDS int64 38 - MaxQueuePerPDS int64 35 + LimitMtx sync.RWMutex 36 + Limiters map[uint64]*Limiters 39 37 40 - NewPDSPerDayLimiter *slidingwindow.Limiter 38 + NewHostPerDayLimiter *slidingwindow.Limiter 41 39 42 40 shutdownChan chan bool 43 41 shutdownResult chan []error ··· 57 55 DefaultPerHourLimit int64 58 56 DefaultPerDayLimit int64 59 57 DefaultRepoLimit int64 60 - ConcurrencyPerPDS int64 61 - MaxQueuePerPDS int64 58 + ConcurrencyPerHost int64 59 + MaxQueuePerHost int64 62 60 NewSubsDisabled bool 63 61 TrustedDomains []string 64 - NewPDSPerDayLimit int64 62 + NewHostPerDayLimit int64 65 63 } 66 64 67 65 func DefaultSlurperConfig() *SlurperConfig { ··· 71 69 DefaultPerHourLimit: 2500, 72 70 DefaultPerDayLimit: 20_000, 73 71 DefaultRepoLimit: 100, 74 - ConcurrencyPerPDS: 100, 75 - MaxQueuePerPDS: 1_000, 72 + ConcurrencyPerHost: 100, 73 + MaxQueuePerHost: 1_000, 76 74 } 77 75 } 78 76 79 - type activeSub struct { 80 - pds *models.PDS 77 + // represents an active client connection 78 + type Subscriber struct { 79 + Host *models.Host 80 + LastSeq int64 // XXX: switch to an atomic 81 + Limiters *Limiters 82 + 81 83 lk sync.RWMutex 82 84 ctx context.Context 83 85 cancel func() 84 86 } 85 87 86 - func (sub *activeSub) updateCursor(curs int64) { 88 + func (sub *Subscriber) UpdateSeq(seq int64) { 87 89 sub.lk.Lock() 88 90 defer sub.lk.Unlock() 89 - sub.pds.Cursor = curs 91 + sub.Host.LastSeq = seq 90 92 } 91 93 92 - func NewSlurper(db *gorm.DB, cb IndexCallback, config *SlurperConfig, logger *slog.Logger) (*Slurper, error) { 94 + func NewSlurper(db *gorm.DB, cb ProcessMessageFunc, config *SlurperConfig, logger *slog.Logger) (*Slurper, error) { 93 95 if config == nil { 94 96 config = DefaultSlurperConfig() 95 97 } ··· 98 100 } 99 101 100 102 // NOTE: unused second argument is not an 'error 101 - newPDSPerDayLimiter, _ := slidingwindow.NewLimiter(time.Hour*24, config.NewPDSPerDayLimit, windowFunc) 103 + newHostPerDayLimiter, _ := slidingwindow.NewLimiter(time.Hour*24, config.NewHostPerDayLimit, windowFunc) 102 104 103 105 s := &Slurper{ 104 - cb: cb, 105 - db: db, 106 - Config: config, 107 - active: make(map[string]*activeSub), 108 - Limiters: make(map[uint]*Limiters), 109 - shutdownChan: make(chan bool), 110 - shutdownResult: make(chan []error), 111 - NewPDSPerDayLimiter: newPDSPerDayLimiter, 112 - log: logger, 106 + cb: cb, 107 + db: db, 108 + Config: config, 109 + active: make(map[string]*Subscriber), 110 + Limiters: make(map[uint64]*Limiters), 111 + shutdownChan: make(chan bool), 112 + shutdownResult: make(chan []error), 113 + NewHostPerDayLimiter: newHostPerDayLimiter, 114 + log: logger, 113 115 } 114 116 115 117 // Start a goroutine to flush cursors to the DB every 30s ··· 117 119 for { 118 120 select { 119 121 case <-s.shutdownChan: 120 - s.log.Info("flushing PDS cursors on shutdown") 122 + s.log.Info("flushing Host cursors on shutdown") 121 123 ctx := context.Background() 122 124 var errs []error 123 125 if errs = s.flushCursors(ctx); len(errs) > 0 { ··· 125 127 s.log.Error("failed to flush cursors on shutdown", "err", err) 126 128 } 127 129 } 128 - s.log.Info("done flushing PDS cursors on shutdown") 130 + s.log.Info("done flushing Host cursors on shutdown") 129 131 s.shutdownResult <- errs 130 132 return 131 133 case <-time.After(time.Second * 10): 132 - s.log.Debug("flushing PDS cursors") 134 + s.log.Debug("flushing Host cursors") 133 135 ctx := context.Background() 134 136 if errs := s.flushCursors(ctx); len(errs) > 0 { 135 137 for _, err := range errs { 136 138 s.log.Error("failed to flush cursors", "err", err) 137 139 } 138 140 } 139 - s.log.Debug("done flushing PDS cursors") 141 + s.log.Debug("done flushing Host cursors") 140 142 } 141 143 } 142 144 }() ··· 148 150 return slidingwindow.NewLocalWindow() 149 151 } 150 152 151 - func (s *Slurper) GetLimiters(pdsID uint) *Limiters { 152 - s.LimitMux.RLock() 153 - defer s.LimitMux.RUnlock() 154 - return s.Limiters[pdsID] 153 + func (s *Slurper) GetLimiters(hostID uint64) *Limiters { 154 + s.LimitMtx.RLock() 155 + defer s.LimitMtx.RUnlock() 156 + return s.Limiters[hostID] 155 157 } 156 158 157 - func (s *Slurper) GetOrCreateLimiters(pdsID uint, perSecLimit int64, perHourLimit int64, perDayLimit int64) *Limiters { 158 - s.LimitMux.RLock() 159 - defer s.LimitMux.RUnlock() 160 - lim, ok := s.Limiters[pdsID] 159 + /* 160 + XXX 161 + 162 + func (s *Slurper) GetOrCreateLimiters(pdsID uint64, perSecLimit int64, perHourLimit int64, perDayLimit int64) *Limiters { 163 + s.LimitMtx.RLock() 164 + defer s.LimitMtx.RUnlock() 165 + lim, ok := s.Limiters[pdsID] 166 + if !ok { 167 + perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc) 168 + perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc) 169 + perDay, _ := slidingwindow.NewLimiter(time.Hour*24, perDayLimit, windowFunc) 170 + lim = &Limiters{ 171 + PerSecond: perSec, 172 + PerHour: perHour, 173 + PerDay: perDay, 174 + } 175 + s.Limiters[pdsID] = lim 176 + } 177 + 178 + return lim 179 + } 180 + */ 181 + func (s *Slurper) GetOrCreateLimiters(hostID uint64) *Limiters { 182 + s.LimitMtx.RLock() 183 + defer s.LimitMtx.RUnlock() 184 + lim, ok := s.Limiters[hostID] 161 185 if !ok { 162 - perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc) 163 - perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc) 164 - perDay, _ := slidingwindow.NewLimiter(time.Hour*24, perDayLimit, windowFunc) 186 + perSec, _ := slidingwindow.NewLimiter(time.Second, s.Config.DefaultPerSecondLimit, windowFunc) 187 + perHour, _ := slidingwindow.NewLimiter(time.Hour, s.Config.DefaultPerHourLimit, windowFunc) 188 + perDay, _ := slidingwindow.NewLimiter(time.Hour*24, s.Config.DefaultPerDayLimit, windowFunc) 165 189 lim = &Limiters{ 166 190 PerSecond: perSec, 167 191 PerHour: perHour, 168 192 PerDay: perDay, 169 193 } 170 - s.Limiters[pdsID] = lim 194 + s.Limiters[hostID] = lim 171 195 } 172 196 173 197 return lim 174 198 } 175 199 176 - func (s *Slurper) SetLimits(pdsID uint, perSecLimit int64, perHourLimit int64, perDayLimit int64) { 177 - s.LimitMux.Lock() 178 - defer s.LimitMux.Unlock() 179 - lim, ok := s.Limiters[pdsID] 200 + func (s *Slurper) SetLimits(hostID uint64, perSecLimit int64, perHourLimit int64, perDayLimit int64) { 201 + s.LimitMtx.Lock() 202 + defer s.LimitMtx.Unlock() 203 + lim, ok := s.Limiters[hostID] 180 204 if !ok { 181 205 perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc) 182 206 perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc) ··· 186 210 PerHour: perHour, 187 211 PerDay: perDay, 188 212 } 189 - s.Limiters[pdsID] = lim 213 + s.Limiters[hostID] = lim 190 214 } 191 215 192 216 lim.PerSecond.SetLimit(perSecLimit) ··· 213 237 // Checks whether a host is allowed to be subscribed to 214 238 // must be called with the slurper lock held 215 239 func (s *Slurper) canSlurpHost(host string) bool { 216 - // Check if we're over the limit for new PDSs today 217 - if !s.NewPDSPerDayLimiter.Allow() { 240 + // Check if we're over the limit for new hosts today 241 + if !s.NewHostPerDayLimiter.Allow() { 218 242 return false 219 243 } 220 244 ··· 236 260 return !s.Config.NewSubsDisabled 237 261 } 238 262 239 - func (s *Slurper) SubscribeToPds(ctx context.Context, host string, reg bool, adminOverride bool, rateOverrides *PDSRates) error { 263 + func (s *Slurper) SubscribeToPds(ctx context.Context, hostname string, reg bool, adminOverride bool, rateOverrides *HostRates) error { 240 264 // TODO: for performance, lock on the hostname instead of global 241 265 s.lk.Lock() 242 266 defer s.lk.Unlock() 243 267 244 - _, ok := s.active[host] 268 + _, ok := s.active[hostname] 245 269 if ok { 246 270 return nil 247 271 } 248 272 249 - var peering models.PDS 250 - if err := s.db.Find(&peering, "host = ?", host).Error; err != nil { 273 + var host models.Host 274 + if err := s.db.Find(&host, "hostname = ?", hostname).Error; err != nil { 251 275 return err 252 276 } 253 277 254 - if peering.Blocked { 255 - return fmt.Errorf("cannot subscribe to blocked pds") 256 - } 257 - 258 278 newHost := false 259 279 260 - if peering.ID == 0 { 261 - if !adminOverride && !s.canSlurpHost(host) { 280 + if host.ID == 0 { 281 + if !adminOverride && !s.canSlurpHost(hostname) { 262 282 return ErrNewSubsDisabled 263 283 } 264 284 // New PDS! 265 - npds := models.PDS{ 266 - Host: host, 267 - SSL: s.Config.SSL, 268 - Registered: reg, 269 - RateLimit: float64(s.Config.DefaultPerSecondLimit), 270 - HourlyEventLimit: s.Config.DefaultPerHourLimit, 271 - DailyEventLimit: s.Config.DefaultPerDayLimit, 272 - RepoLimit: s.Config.DefaultRepoLimit, 285 + npds := models.Host{ 286 + Hostname: hostname, 287 + NoSSL: !s.Config.SSL, 288 + Status: models.HostStatusActive, 289 + AccountLimit: s.Config.DefaultRepoLimit, 273 290 } 291 + /* XXX 274 292 if rateOverrides != nil { 275 293 npds.RateLimit = float64(rateOverrides.PerSecond) 276 294 npds.HourlyEventLimit = rateOverrides.PerHour 277 295 npds.DailyEventLimit = rateOverrides.PerDay 278 296 npds.RepoLimit = rateOverrides.RepoLimit 279 297 } 298 + */ 280 299 if err := s.db.Create(&npds).Error; err != nil { 281 300 return err 282 301 } 283 302 284 303 newHost = true 285 - peering = npds 304 + host = npds 305 + } else if host.Status == models.HostStatusBanned { 306 + return fmt.Errorf("cannot subscribe to banned pds") 286 307 } 287 308 288 - if !peering.Registered && reg { 289 - peering.Registered = true 290 - if err := s.db.Model(models.PDS{}).Where("id = ?", peering.ID).Update("registered", true).Error; err != nil { 309 + /* XXX 310 + if !host.Registered && reg { 311 + host.Registered = true 312 + if err := s.db.Model(models.Host{}).Where("id = ?", host.ID).Update("registered", true).Error; err != nil { 291 313 return err 292 314 } 293 315 } 316 + */ 294 317 295 318 ctx, cancel := context.WithCancel(context.Background()) 296 - sub := activeSub{ 297 - pds: &peering, 319 + sub := Subscriber{ 320 + Host: &host, 298 321 ctx: ctx, 299 322 cancel: cancel, 300 323 } 301 - s.active[host] = &sub 324 + s.active[hostname] = &sub 302 325 303 - s.GetOrCreateLimiters(peering.ID, int64(peering.RateLimit), peering.HourlyEventLimit, peering.DailyEventLimit) 326 + s.GetOrCreateLimiters(host.ID) 304 327 305 - go s.subscribeWithRedialer(ctx, &peering, &sub, newHost) 328 + go s.subscribeWithRedialer(ctx, &host, &sub, newHost) 306 329 307 330 return nil 308 331 } ··· 311 334 s.lk.Lock() 312 335 defer s.lk.Unlock() 313 336 314 - var all []models.PDS 315 - if err := s.db.Find(&all, "registered = true AND blocked = false").Error; err != nil { 337 + var all []models.Host 338 + if err := s.db.Find(&all, "status = \"active\"").Error; err != nil { 316 339 return err 317 340 } 318 341 319 - for _, pds := range all { 320 - pds := pds 342 + for _, host := range all { 343 + host := host 321 344 322 345 ctx, cancel := context.WithCancel(context.Background()) 323 - sub := activeSub{ 324 - pds: &pds, 346 + sub := Subscriber{ 347 + Host: &host, 325 348 ctx: ctx, 326 349 cancel: cancel, 327 350 } 328 - s.active[pds.Host] = &sub 351 + s.active[host.Hostname] = &sub 329 352 330 - // Check if we've already got a limiter for this PDS 331 - s.GetOrCreateLimiters(pds.ID, int64(pds.RateLimit), pds.HourlyEventLimit, pds.DailyEventLimit) 332 - go s.subscribeWithRedialer(ctx, &pds, &sub, false) 353 + // Check if we've already got a limiter for this host 354 + // XXX: s.GetOrCreateLimiters(host.ID, int64(host.RateLimit), host.HourlyEventLimit, host.DailyEventLimit) 355 + go s.subscribeWithRedialer(ctx, &host, &sub, false) 333 356 } 334 357 335 358 return nil 336 359 } 337 360 338 - func (s *Slurper) subscribeWithRedialer(ctx context.Context, host *models.PDS, sub *activeSub, newHost bool) { 361 + func (s *Slurper) subscribeWithRedialer(ctx context.Context, host *models.Host, sub *Subscriber, newHost bool) { 339 362 defer func() { 340 363 s.lk.Lock() 341 364 defer s.lk.Unlock() 342 365 343 - delete(s.active, host.Host) 366 + delete(s.active, host.Hostname) 344 367 }() 345 368 346 369 d := websocket.Dialer{ ··· 352 375 protocol = "wss" 353 376 } 354 377 355 - // Special case `.host.bsky.network` PDSs to rewind cursor by 200 events to smooth over unclean shutdowns 356 - if strings.HasSuffix(host.Host, ".host.bsky.network") && host.Cursor > 200 { 357 - host.Cursor -= 200 378 + // Special case `.host.bsky.network` Host to rewind cursor by 200 events to smooth over unclean shutdowns 379 + if strings.HasSuffix(host.Hostname, ".host.bsky.network") && host.LastSeq > 200 { 380 + host.LastSeq -= 200 358 381 } 359 382 360 - cursor := host.Cursor 383 + cursor := host.LastSeq 361 384 362 385 connectedInbound.Inc() 363 386 defer connectedInbound.Dec() ··· 373 396 374 397 var url string 375 398 if newHost { 376 - url = fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos", protocol, host.Host) 399 + url = fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos", protocol, host.Hostname) 377 400 } else { 378 - url = fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos?cursor=%d", protocol, host.Host, cursor) 401 + url = fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos?cursor=%d", protocol, host.Hostname, cursor) 379 402 } 380 403 con, res, err := d.DialContext(ctx, url, nil) 381 404 if err != nil { 382 - s.log.Warn("dialing failed", "pdsHost", host.Host, "err", err, "backoff", backoff) 405 + s.log.Warn("dialing failed", "host", host.Hostname, "err", err, "backoff", backoff) 383 406 time.Sleep(sleepForBackoff(backoff)) 384 407 backoff++ 385 408 386 409 if backoff > 15 { 387 - s.log.Warn("pds does not appear to be online, disabling for now", "pdsHost", host.Host) 388 - if err := s.db.Model(&models.PDS{}).Where("id = ?", host.ID).Update("registered", false).Error; err != nil { 389 - s.log.Error("failed to unregister failing pds", "err", err) 410 + s.log.Warn("host does not appear to be online, disabling for now", "host", host.Hostname) 411 + if err := s.db.Model(&models.Host{}).Where("id = ?", host.ID).Update("registered", false).Error; err != nil { 412 + s.log.Error("failed to unregister failing host", "err", err) 390 413 } 391 414 392 415 return ··· 400 423 curCursor := cursor 401 424 if err := s.handleConnection(ctx, host, con, &cursor, sub); err != nil { 402 425 if errors.Is(err, ErrTimeoutShutdown) { 403 - s.log.Info("shutting down pds subscription after timeout", "host", host.Host, "time", EventsTimeout) 426 + s.log.Info("shutting down host subscription after timeout", "host", host.Hostname, "time", EventsTimeout) 404 427 return 405 428 } 406 - s.log.Warn("connection to failed", "host", host.Host, "err", err) 429 + s.log.Warn("connection to failed", "host", host.Hostname, "err", err) 407 430 // TODO: measure the last N connection error times and if they're coming too fast reconnect slower or don't reconnect and wait for requestCrawl 408 431 } 409 432 ··· 425 448 return time.Second * 30 426 449 } 427 450 428 - var ErrTimeoutShutdown = fmt.Errorf("timed out waiting for new events") 429 - 430 451 var EventsTimeout = time.Minute 431 452 432 - func (s *Slurper) handleConnection(ctx context.Context, host *models.PDS, con *websocket.Conn, lastCursor *int64, sub *activeSub) error { 453 + func (s *Slurper) handleConnection(ctx context.Context, host *models.Host, con *websocket.Conn, lastCursor *int64, sub *Subscriber) error { 433 454 ctx, cancel := context.WithCancel(ctx) 434 455 defer cancel() 435 456 436 457 rsc := &stream.RepoStreamCallbacks{ 437 458 RepoCommit: func(evt *comatproto.SyncSubscribeRepos_Commit) error { 438 - s.log.Debug("got remote repo event", "pdsHost", host.Host, "repo", evt.Repo, "seq", evt.Seq) 459 + s.log.Debug("got remote repo event", "host", host.Hostname, "repo", evt.Repo, "seq", evt.Seq) 439 460 if err := s.cb(context.TODO(), host, &stream.XRPCStreamEvent{ 440 461 RepoCommit: evt, 441 462 }); err != nil { 442 - s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 463 + s.log.Error("failed handling event", "host", host.Hostname, "seq", evt.Seq, "err", err) 443 464 } 444 465 *lastCursor = evt.Seq 445 466 446 - sub.updateCursor(*lastCursor) 467 + sub.UpdateSeq(*lastCursor) 447 468 448 469 return nil 449 470 }, 450 471 RepoSync: func(evt *comatproto.SyncSubscribeRepos_Sync) error { 451 - s.log.Debug("got remote repo event", "pdsHost", host.Host, "repo", evt.Did, "seq", evt.Seq) 472 + s.log.Debug("got remote repo event", "host", host.Hostname, "repo", evt.Did, "seq", evt.Seq) 452 473 if err := s.cb(context.TODO(), host, &stream.XRPCStreamEvent{ 453 474 RepoSync: evt, 454 475 }); err != nil { 455 - s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 476 + s.log.Error("failed handling event", "host", host.Hostname, "seq", evt.Seq, "err", err) 456 477 } 457 478 *lastCursor = evt.Seq 458 479 459 - sub.updateCursor(*lastCursor) 480 + sub.UpdateSeq(*lastCursor) 460 481 461 482 return nil 462 483 }, 463 484 RepoHandle: func(evt *comatproto.SyncSubscribeRepos_Handle) error { 464 - s.log.Debug("got remote handle update event", "pdsHost", host.Host, "did", evt.Did, "handle", evt.Handle) 485 + s.log.Debug("got remote handle update event", "host", host.Hostname, "did", evt.Did, "handle", evt.Handle) 465 486 if err := s.cb(context.TODO(), host, &stream.XRPCStreamEvent{ 466 487 RepoHandle: evt, 467 488 }); err != nil { 468 - s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 489 + s.log.Error("failed handling event", "host", host.Hostname, "seq", evt.Seq, "err", err) 469 490 } 470 491 *lastCursor = evt.Seq 471 492 472 - sub.updateCursor(*lastCursor) 493 + sub.UpdateSeq(*lastCursor) 473 494 474 495 return nil 475 496 }, 476 497 RepoMigrate: func(evt *comatproto.SyncSubscribeRepos_Migrate) error { 477 - s.log.Debug("got remote repo migrate event", "pdsHost", host.Host, "did", evt.Did, "migrateTo", evt.MigrateTo) 498 + s.log.Debug("got remote repo migrate event", "host", host.Hostname, "did", evt.Did, "migrateTo", evt.MigrateTo) 478 499 if err := s.cb(context.TODO(), host, &stream.XRPCStreamEvent{ 479 500 RepoMigrate: evt, 480 501 }); err != nil { 481 - s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 502 + s.log.Error("failed handling event", "host", host.Hostname, "seq", evt.Seq, "err", err) 482 503 } 483 504 *lastCursor = evt.Seq 484 505 485 - sub.updateCursor(*lastCursor) 506 + sub.UpdateSeq(*lastCursor) 486 507 487 508 return nil 488 509 }, 489 510 RepoTombstone: func(evt *comatproto.SyncSubscribeRepos_Tombstone) error { 490 - s.log.Debug("got remote repo tombstone event", "pdsHost", host.Host, "did", evt.Did) 511 + s.log.Debug("got remote repo tombstone event", "host", host.Hostname, "did", evt.Did) 491 512 if err := s.cb(context.TODO(), host, &stream.XRPCStreamEvent{ 492 513 RepoTombstone: evt, 493 514 }); err != nil { 494 - s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 515 + s.log.Error("failed handling event", "host", host.Hostname, "seq", evt.Seq, "err", err) 495 516 } 496 517 *lastCursor = evt.Seq 497 518 498 - sub.updateCursor(*lastCursor) 519 + sub.UpdateSeq(*lastCursor) 499 520 500 521 return nil 501 522 }, 502 523 RepoInfo: func(info *comatproto.SyncSubscribeRepos_Info) error { 503 - s.log.Debug("info event", "name", info.Name, "message", info.Message, "pdsHost", host.Host) 524 + s.log.Debug("info event", "name", info.Name, "message", info.Message, "host", host.Hostname) 504 525 return nil 505 526 }, 506 527 RepoIdentity: func(ident *comatproto.SyncSubscribeRepos_Identity) error { ··· 508 529 if err := s.cb(context.TODO(), host, &stream.XRPCStreamEvent{ 509 530 RepoIdentity: ident, 510 531 }); err != nil { 511 - s.log.Error("failed handling event", "host", host.Host, "seq", ident.Seq, "err", err) 532 + s.log.Error("failed handling event", "host", host.Hostname, "seq", ident.Seq, "err", err) 512 533 } 513 534 *lastCursor = ident.Seq 514 535 515 - sub.updateCursor(*lastCursor) 536 + sub.UpdateSeq(*lastCursor) 516 537 517 538 return nil 518 539 }, ··· 521 542 if err := s.cb(context.TODO(), host, &stream.XRPCStreamEvent{ 522 543 RepoAccount: acct, 523 544 }); err != nil { 524 - s.log.Error("failed handling event", "host", host.Host, "seq", acct.Seq, "err", err) 545 + s.log.Error("failed handling event", "host", host.Hostname, "seq", acct.Seq, "err", err) 525 546 } 526 547 *lastCursor = acct.Seq 527 548 528 - sub.updateCursor(*lastCursor) 549 + sub.UpdateSeq(*lastCursor) 529 550 530 551 return nil 531 552 }, ··· 546 567 }, 547 568 } 548 569 549 - lims := s.GetOrCreateLimiters(host.ID, int64(host.RateLimit), host.HourlyEventLimit, host.DailyEventLimit) 570 + lims := s.GetOrCreateLimiters(host.ID) 550 571 551 572 limiters := []*slidingwindow.Limiter{ 552 573 lims.PerSecond, ··· 566 587 } 567 588 568 589 type cursorSnapshot struct { 569 - id uint 590 + id uint64 570 591 cursor int64 571 592 } 572 593 573 - // flushCursors updates the PDS cursors in the DB for all active subscriptions 594 + // flushCursors updates the Host cursors in the DB for all active subscriptions 574 595 func (s *Slurper) flushCursors(ctx context.Context) []error { 575 596 start := time.Now() 576 597 //ctx, span := otel.Tracer("feedmgr").Start(ctx, "flushCursors") ··· 583 604 for _, sub := range s.active { 584 605 sub.lk.RLock() 585 606 cursors = append(cursors, cursorSnapshot{ 586 - id: sub.pds.ID, 587 - cursor: sub.pds.Cursor, 607 + id: sub.Host.ID, 608 + cursor: sub.Host.LastSeq, 588 609 }) 589 610 sub.lk.RUnlock() 590 611 } ··· 595 616 596 617 tx := s.db.WithContext(ctx).Begin() 597 618 for _, cursor := range cursors { 598 - if err := tx.WithContext(ctx).Model(models.PDS{}).Where("id = ?", cursor.id).UpdateColumn("cursor", cursor.cursor).Error; err != nil { 619 + if err := tx.WithContext(ctx).Model(models.Host{}).Where("id = ?", cursor.id).UpdateColumn("cursor", cursor.cursor).Error; err != nil { 599 620 errs = append(errs, err) 600 621 } else { 601 622 okcount++ ··· 635 656 // cleanup in the run thread subscribeWithRedialer() will delete(s.active, host) 636 657 637 658 if block { 638 - if err := s.db.Model(models.PDS{}).Where("id = ?", ac.pds.ID).UpdateColumn("blocked", true).Error; err != nil { 659 + if err := s.db.Model(models.Host{}).Where("id = ?", ac.Host.ID).UpdateColumn("blocked", true).Error; err != nil { 639 660 return fmt.Errorf("failed to set host as blocked: %w", err) 640 661 } 641 662 }
+12 -12
cmd/relayered/relay/validator.go
··· 58 58 } 59 59 60 60 type NextCommitHandler interface { 61 - HandleCommit(ctx context.Context, host *models.PDS, uid uint64, did string, commit *comatproto.SyncSubscribeRepos_Commit) error 61 + HandleCommit(ctx context.Context, host *models.Host, uid uint64, did string, commit *comatproto.SyncSubscribeRepos_Commit) error 62 62 } 63 63 64 64 type userLock struct { ··· 99 99 } 100 100 } 101 101 102 - func (val *Validator) HandleCommit(ctx context.Context, host *models.PDS, account *models.Account, commit *comatproto.SyncSubscribeRepos_Commit, prevRev *syntax.TID, prevData *cid.Cid) (newRoot *cid.Cid, err error) { 103 - uid := account.GetUid() 102 + func (val *Validator) HandleCommit(ctx context.Context, host *models.Host, account *models.Account, commit *comatproto.SyncSubscribeRepos_Commit, prevRev *syntax.TID, prevData *cid.Cid) (newRoot *cid.Cid, err error) { 103 + uid := account.UID 104 104 unlock := val.lockUser(ctx, uid) 105 105 defer unlock() 106 106 repoFragment, err := val.VerifyCommitMessage(ctx, host, commit, prevRev, prevData) ··· 124 124 125 125 var ErrNewRevBeforePrevRev = &revOutOfOrderError{} 126 126 127 - func (val *Validator) VerifyCommitMessage(ctx context.Context, host *models.PDS, msg *comatproto.SyncSubscribeRepos_Commit, prevRev *syntax.TID, prevData *cid.Cid) (*repo.Repo, error) { 128 - hostname := host.Host 127 + func (val *Validator) VerifyCommitMessage(ctx context.Context, host *models.Host, msg *comatproto.SyncSubscribeRepos_Commit, prevRev *syntax.TID, prevData *cid.Cid) (*repo.Repo, error) { 128 + hostname := host.Hostname 129 129 hasWarning := false 130 130 commitVerifyStarts.Inc() 131 131 logger := slog.Default().With("did", msg.Repo, "rev", msg.Rev, "seq", msg.Seq, "time", msg.Time) ··· 163 163 //logger.Warn("event with tooBig flag set") 164 164 commitVerifyWarnings.WithLabelValues(hostname, "big").Inc() 165 165 // XXX: induction trace log 166 - val.log.Warn("commit tooBig", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 166 + val.log.Warn("commit tooBig", "seq", msg.Seq, "host", host.Hostname, "repo", msg.Repo) 167 167 hasWarning = true 168 168 } 169 169 if msg.Rebase { 170 170 //logger.Warn("event with rebase flag set") 171 171 commitVerifyWarnings.WithLabelValues(hostname, "reb").Inc() 172 172 // XXX: induction trace log 173 - val.log.Warn("commit rebase", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 173 + val.log.Warn("commit rebase", "seq", msg.Seq, "host", host.Hostname, "repo", msg.Repo) 174 174 hasWarning = true 175 175 } 176 176 ··· 228 228 if o.Prev == nil { 229 229 logger.Debug("can't invert legacy op", "action", o.Action) 230 230 // XXX: induction trace log 231 - val.log.Warn("commit delete op", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 231 + val.log.Warn("commit delete op", "seq", msg.Seq, "host", host.Hostname, "repo", msg.Repo) 232 232 commitVerifyOkish.WithLabelValues(hostname, "del").Inc() 233 233 return repoFragment, nil 234 234 } ··· 236 236 if o.Prev == nil { 237 237 logger.Debug("can't invert legacy op", "action", o.Action) 238 238 // XXX: induction trace log 239 - val.log.Warn("commit update op", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 239 + val.log.Warn("commit update op", "seq", msg.Seq, "host", host.Hostname, "repo", msg.Repo) 240 240 commitVerifyOkish.WithLabelValues(hostname, "up").Inc() 241 241 return repoFragment, nil 242 242 } ··· 249 249 if *c != *prevData { 250 250 commitVerifyWarnings.WithLabelValues(hostname, "pr").Inc() 251 251 // XXX: induction trace log 252 - val.log.Warn("commit prevData mismatch", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 252 + val.log.Warn("commit prevData mismatch", "seq", msg.Seq, "host", host.Hostname, "repo", msg.Repo) 253 253 hasWarning = true 254 254 } 255 255 } else { ··· 305 305 } 306 306 307 307 // HandleSync checks signed commit from a #sync message 308 - func (val *Validator) HandleSync(ctx context.Context, host *models.PDS, msg *comatproto.SyncSubscribeRepos_Sync) (newRoot *cid.Cid, err error) { 309 - hostname := host.Host 308 + func (val *Validator) HandleSync(ctx context.Context, host *models.Host, msg *comatproto.SyncSubscribeRepos_Sync) (newRoot *cid.Cid, err error) { 309 + hostname := host.Hostname 310 310 hasWarning := false 311 311 312 312 did, err := syntax.ParseDID(msg.Did)
+19 -4
cmd/relayered/stubs.go
··· 85 85 func (s *Service) HandleComAtprotoSyncGetRepo(c echo.Context) error { 86 86 ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleComAtprotoSyncGetRepo") 87 87 defer span.End() 88 + // XXX: this is not how to fetch query params... 88 89 // no request object, only params 89 90 params := c.QueryParams() 90 - var did string 91 + var did syntax.DID 91 92 hasDid := false 92 93 for paramName, pvl := range params { 93 94 switch paramName { 94 95 case "did": 95 96 if len(pvl) == 1 { 96 - did = pvl[0] 97 + d, err := syntax.ParseDID(pvl[0]) 98 + if err != nil { 99 + return err // XXX: better error 100 + } 101 + did = d 97 102 hasDid = true 98 103 } else if len(pvl) > 1 { 99 104 return c.JSON(http.StatusBadRequest, XRPCError{Message: "only allow one did param"}) ··· 108 113 return c.JSON(http.StatusBadRequest, XRPCError{Message: "need did param"}) 109 114 } 110 115 111 - pdsHostname, err := s.relay.GetHostForDID(ctx, did) 116 + acc, err := s.relay.GetAccount(ctx, did) 117 + if err != nil { 118 + // TODO: better error 119 + return err 120 + } 121 + 122 + host, err := s.relay.GetHost(ctx, acc.HostID) 123 + if err != nil { 124 + // TODO: better error 125 + return err 126 + } 112 127 113 128 // TODO: proper error responses 114 129 if err != nil { ··· 120 135 } 121 136 122 137 nextUrl := *(c.Request().URL) 123 - nextUrl.Host = pdsHostname 138 + nextUrl.Host = host.Hostname 124 139 if nextUrl.Scheme == "" { 125 140 nextUrl.Scheme = "https" 126 141 }