this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

push some DB code from slurper to relay

+143 -127
+1 -1
cmd/relayered/handlers.go
··· 103 103 } 104 104 } 105 105 106 - return s.relay.Slurper.SubscribeToPds(ctx, host, true, false, nil) 106 + return s.relay.SubscribeToHost(host, true, false, nil) 107 107 } 108 108 109 109 func (s *Service) handleComAtprotoSyncListRepos(ctx context.Context, cursor int64, limit int) (*comatproto.SyncListRepos_Output, error) {
+109
cmd/relayered/relay/crawl.go
··· 1 + package relay 2 + 3 + import ( 4 + "fmt" 5 + "strings" 6 + 7 + "github.com/bluesky-social/indigo/cmd/relayered/relay/models" 8 + ) 9 + 10 + // Checks whether a host is allowed to be subscribed to 11 + // 12 + // Must be called with the slurper lock held 13 + func (r *Relay) canSlurpHost(hostname string) bool { 14 + // Check if we're over the limit for new hosts today 15 + if !r.Slurper.NewHostPerDayLimiter.Allow() { 16 + return false 17 + } 18 + 19 + // Check if the host is a trusted domain 20 + for _, d := range r.Config.TrustedDomains { 21 + // If the domain starts with a *., it's a wildcard 22 + if strings.HasPrefix(d, "*.") { 23 + // Cut off the * so we have .domain.com 24 + if strings.HasSuffix(hostname, strings.TrimPrefix(d, "*")) { 25 + return true 26 + } 27 + } else { 28 + if hostname == d { 29 + return true 30 + } 31 + } 32 + } 33 + 34 + return !r.Config.DisableNewHosts 35 + } 36 + 37 + func (r *Relay) SubscribeToHost(hostname string, reg bool, adminOverride bool, rateOverrides *HostRates) error { 38 + 39 + // if we already have an active subscription going, exit early 40 + if r.Slurper.CheckIfSubscribed(hostname) { 41 + return nil 42 + } 43 + 44 + var host models.Host 45 + if err := r.db.Find(&host, "hostname = ?", hostname).Error; err != nil { 46 + return err 47 + } 48 + 49 + newHost := false 50 + 51 + if host.ID == 0 { 52 + if !adminOverride && !r.canSlurpHost(hostname) { 53 + return ErrNewSubsDisabled 54 + } 55 + // New PDS! 56 + npds := models.Host{ 57 + Hostname: hostname, 58 + NoSSL: !r.Config.SSL, 59 + Status: models.HostStatusActive, 60 + AccountLimit: r.Config.DefaultRepoLimit, 61 + } 62 + /* XXX 63 + if rateOverrides != nil { 64 + npds.RateLimit = float64(rateOverrides.PerSecond) 65 + npds.HourlyEventLimit = rateOverrides.PerHour 66 + npds.DailyEventLimit = rateOverrides.PerDay 67 + npds.RepoLimit = rateOverrides.RepoLimit 68 + } 69 + */ 70 + if err := r.db.Create(&npds).Error; err != nil { 71 + return err 72 + } 73 + 74 + newHost = true 75 + host = npds 76 + } else if host.Status == models.HostStatusBanned { 77 + return fmt.Errorf("cannot subscribe to banned pds") 78 + } 79 + 80 + /* XXX 81 + if !host.Registered && reg { 82 + host.Registered = true 83 + if err := s.db.Model(models.Host{}).Where("id = ?", host.ID).Update("registered", true).Error; err != nil { 84 + return err 85 + } 86 + } 87 + */ 88 + 89 + return r.Slurper.Subscribe(&host, newHost) 90 + } 91 + 92 + // This function expects to be run when starting up, to re-connect to known active hosts 93 + func (r *Relay) ResubscribeAllHosts() error { 94 + 95 + var all []models.Host 96 + if err := r.db.Find(&all, "status = \"active\"").Error; err != nil { 97 + return err 98 + } 99 + 100 + for _, host := range all { 101 + // copy host 102 + host := host 103 + err := r.Slurper.Subscribe(&host, false) 104 + if err != nil { 105 + r.Logger.Warn("failed to re-subscribe to host", "hostID", host.ID, "hostname", host.Hostname, "err", err) 106 + } 107 + } 108 + return nil 109 + }
+5
cmd/relayered/relay/errors.go
··· 8 8 ErrHostNotFound = errors.New("unknown host or PDS") 9 9 ErrAccountNotFound = errors.New("unknown account") 10 10 ErrAccountRepoNotFound = errors.New("repository state not available") 11 + 12 + // TODO: these might need better names 13 + ErrTimeoutShutdown = errors.New("timed out waiting for new events") 14 + ErrNewSubsDisabled = errors.New("new subscriptions temporarily disabled") 15 + ErrNoActiveConnection = errors.New("no active connection to host") 11 16 )
+6 -1
cmd/relayered/relay/relay.go
··· 47 47 MaxQueuePerHost int64 48 48 ApplyHostClientSettings func(c *xrpc.Client) 49 49 SkipAccountHostCheck bool // XXX: only used for testing 50 + 51 + // if true, ignore "requestCrawl" 52 + DisableNewHosts bool 53 + TrustedDomains []string 50 54 } 51 55 52 56 func DefaultRelayConfig() *RelayConfig { ··· 98 102 } 99 103 r.Slurper = s 100 104 101 - if err := r.Slurper.RestartAll(); err != nil { 105 + // TODO: should this happen in a separate "start" method, instead of "NewRelay()"? 106 + if err := r.ResubscribeAllHosts(); err != nil { 102 107 return nil, err 103 108 } 104 109 return r, nil
+21 -124
cmd/relayered/relay/slurper.go
··· 10 10 "sync" 11 11 "time" 12 12 13 - "github.com/RussellLuo/slidingwindow" 14 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 15 14 "github.com/bluesky-social/indigo/cmd/relayered/relay/models" 16 15 "github.com/bluesky-social/indigo/cmd/relayered/stream" 17 16 "github.com/bluesky-social/indigo/cmd/relayered/stream/schedulers/parallel" 18 17 18 + "github.com/RussellLuo/slidingwindow" 19 19 "github.com/gorilla/websocket" 20 20 "gorm.io/gorm" 21 21 ) 22 22 23 - var ErrTimeoutShutdown = fmt.Errorf("timed out waiting for new events") 24 - 25 23 type ProcessMessageFunc func(context.Context, *models.Host, *stream.XRPCStreamEvent) error 26 24 27 25 type Slurper struct { ··· 30 28 Config *SlurperConfig 31 29 32 30 lk sync.Mutex 33 - active map[string]*Subscriber 31 + active map[string]*Subscription 34 32 35 33 LimitMtx sync.RWMutex 36 34 Limiters map[uint64]*Limiters ··· 57 55 DefaultRepoLimit int64 58 56 ConcurrencyPerHost int64 59 57 MaxQueuePerHost int64 60 - NewSubsDisabled bool 61 - TrustedDomains []string 62 58 NewHostPerDayLimit int64 63 59 } 64 60 ··· 74 70 } 75 71 } 76 72 77 - // represents an active client connection 78 - type Subscriber struct { 73 + // represents an active client connection to a remote host 74 + type Subscription struct { 79 75 Host *models.Host 80 76 LastSeq int64 // XXX: switch to an atomic 81 77 Limiters *Limiters ··· 85 81 cancel func() 86 82 } 87 83 88 - func (sub *Subscriber) UpdateSeq(seq int64) { 84 + func (sub *Subscription) UpdateSeq(seq int64) { 89 85 sub.lk.Lock() 90 86 defer sub.lk.Unlock() 91 87 sub.Host.LastSeq = seq ··· 106 102 cb: cb, 107 103 db: db, 108 104 Config: config, 109 - active: make(map[string]*Subscriber), 105 + active: make(map[string]*Subscription), 110 106 Limiters: make(map[uint64]*Limiters), 111 107 shutdownChan: make(chan bool), 112 108 shutdownResult: make(chan []error), ··· 232 228 return errs 233 229 } 234 230 235 - var ErrNewSubsDisabled = fmt.Errorf("new subscriptions temporarily disabled") 231 + func (s *Slurper) CheckIfSubscribed(hostname string) bool { 232 + s.lk.Lock() 233 + defer s.lk.Unlock() 236 234 237 - // Checks whether a host is allowed to be subscribed to 238 - // must be called with the slurper lock held 239 - func (s *Slurper) canSlurpHost(host string) bool { 240 - // Check if we're over the limit for new hosts today 241 - if !s.NewHostPerDayLimiter.Allow() { 242 - return false 243 - } 244 - 245 - // Check if the host is a trusted domain 246 - for _, d := range s.Config.TrustedDomains { 247 - // If the domain starts with a *., it's a wildcard 248 - if strings.HasPrefix(d, "*.") { 249 - // Cut off the * so we have .domain.com 250 - if strings.HasSuffix(host, strings.TrimPrefix(d, "*")) { 251 - return true 252 - } 253 - } else { 254 - if host == d { 255 - return true 256 - } 257 - } 258 - } 259 - 260 - return !s.Config.NewSubsDisabled 235 + _, ok := s.active[hostname] 236 + return ok 261 237 } 262 238 263 - func (s *Slurper) SubscribeToPds(ctx context.Context, hostname string, reg bool, adminOverride bool, rateOverrides *HostRates) error { 239 + func (s *Slurper) Subscribe(host *models.Host, newHost bool) error { 264 240 // TODO: for performance, lock on the hostname instead of global 265 241 s.lk.Lock() 266 242 defer s.lk.Unlock() 267 243 268 - _, ok := s.active[hostname] 269 - if ok { 270 - return nil 271 - } 272 - 273 - var host models.Host 274 - if err := s.db.Find(&host, "hostname = ?", hostname).Error; err != nil { 275 - return err 276 - } 277 - 278 - newHost := false 279 - 280 - if host.ID == 0 { 281 - if !adminOverride && !s.canSlurpHost(hostname) { 282 - return ErrNewSubsDisabled 283 - } 284 - // New PDS! 285 - npds := models.Host{ 286 - Hostname: hostname, 287 - NoSSL: !s.Config.SSL, 288 - Status: models.HostStatusActive, 289 - AccountLimit: s.Config.DefaultRepoLimit, 290 - } 291 - /* XXX 292 - if rateOverrides != nil { 293 - npds.RateLimit = float64(rateOverrides.PerSecond) 294 - npds.HourlyEventLimit = rateOverrides.PerHour 295 - npds.DailyEventLimit = rateOverrides.PerDay 296 - npds.RepoLimit = rateOverrides.RepoLimit 297 - } 298 - */ 299 - if err := s.db.Create(&npds).Error; err != nil { 300 - return err 301 - } 302 - 303 - newHost = true 304 - host = npds 305 - } else if host.Status == models.HostStatusBanned { 306 - return fmt.Errorf("cannot subscribe to banned pds") 307 - } 308 - 309 - /* XXX 310 - if !host.Registered && reg { 311 - host.Registered = true 312 - if err := s.db.Model(models.Host{}).Where("id = ?", host.ID).Update("registered", true).Error; err != nil { 313 - return err 314 - } 315 - } 316 - */ 317 - 318 244 ctx, cancel := context.WithCancel(context.Background()) 319 - sub := Subscriber{ 320 - Host: &host, 245 + sub := Subscription{ 246 + Host: host, 321 247 ctx: ctx, 322 248 cancel: cancel, 323 249 } 324 - s.active[hostname] = &sub 250 + s.active[host.Hostname] = &sub 325 251 326 252 s.GetOrCreateLimiters(host.ID) 327 253 328 - go s.subscribeWithRedialer(ctx, &host, &sub, newHost) 329 - 330 - return nil 331 - } 332 - 333 - func (s *Slurper) RestartAll() error { 334 - s.lk.Lock() 335 - defer s.lk.Unlock() 336 - 337 - var all []models.Host 338 - if err := s.db.Find(&all, "status = \"active\"").Error; err != nil { 339 - return err 340 - } 341 - 342 - for _, host := range all { 343 - host := host 344 - 345 - ctx, cancel := context.WithCancel(context.Background()) 346 - sub := Subscriber{ 347 - Host: &host, 348 - ctx: ctx, 349 - cancel: cancel, 350 - } 351 - s.active[host.Hostname] = &sub 352 - 353 - // Check if we've already got a limiter for this host 354 - // XXX: s.GetOrCreateLimiters(host.ID, int64(host.RateLimit), host.HourlyEventLimit, host.DailyEventLimit) 355 - go s.subscribeWithRedialer(ctx, &host, &sub, false) 356 - } 254 + go s.subscribeWithRedialer(ctx, host, &sub, newHost) 357 255 358 256 return nil 359 257 } 360 258 361 - func (s *Slurper) subscribeWithRedialer(ctx context.Context, host *models.Host, sub *Subscriber, newHost bool) { 259 + func (s *Slurper) subscribeWithRedialer(ctx context.Context, host *models.Host, sub *Subscription, newHost bool) { 362 260 defer func() { 363 261 s.lk.Lock() 364 262 defer s.lk.Unlock() ··· 450 348 451 349 var EventsTimeout = time.Minute 452 350 453 - func (s *Slurper) handleConnection(ctx context.Context, host *models.Host, con *websocket.Conn, lastCursor *int64, sub *Subscriber) error { 351 + func (s *Slurper) handleConnection(ctx context.Context, host *models.Host, con *websocket.Conn, lastCursor *int64, sub *Subscription) error { 454 352 ctx, cancel := context.WithCancel(ctx) 455 353 defer cancel() 456 354 ··· 554 452 Error: func(errf *stream.ErrorFrame) error { 555 453 switch errf.Error { 556 454 case "FutureCursor": 455 + // XXX: need test coverage for this path 557 456 // if we get a FutureCursor frame, reset our sequence number for this host 558 - if err := s.db.Table("pds").Where("id = ?", host.ID).Update("cursor", 0).Error; err != nil { 457 + if err := s.db.Table("host").Where("id = ?", host.ID).Update("last_seq", 0).Error; err != nil { 559 458 return err 560 459 } 561 460 ··· 616 515 617 516 tx := s.db.WithContext(ctx).Begin() 618 517 for _, cursor := range cursors { 619 - if err := tx.WithContext(ctx).Model(models.Host{}).Where("id = ?", cursor.id).UpdateColumn("cursor", cursor.cursor).Error; err != nil { 518 + if err := tx.WithContext(ctx).Model(models.Host{}).Where("id = ?", cursor.id).UpdateColumn("last_seq", cursor.cursor).Error; err != nil { 620 519 errs = append(errs, err) 621 520 } else { 622 521 okcount++ ··· 641 540 642 541 return out 643 542 } 644 - 645 - var ErrNoActiveConnection = fmt.Errorf("no active connection to host") 646 543 647 544 func (s *Slurper) KillUpstreamConnection(host string, block bool) error { 648 545 s.lk.Lock()
+1 -1
cmd/relayered/testing/runner.go
··· 109 109 110 110 sr := MustSimpleRelay(&dir, tmpd) 111 111 112 - err = sr.Relay.Slurper.SubscribeToPds(ctx, fmt.Sprintf("localhost:%d", hostPort), true, true, nil) 112 + err = sr.Relay.SubscribeToHost(fmt.Sprintf("localhost:%d", hostPort), true, true, nil) 113 113 if err != nil { 114 114 return err 115 115 }