···2233import (
44 "context"
55+ "errors"
66+ "fmt"
57 "strings"
6879 "github.com/bluesky-social/indigo/cmd/rerelay/relay/models"
1010+1111+ "gorm.io/gorm"
812)
9131010-// DomainIsBanned checks if the given host is banned, starting with the host
1111-// itself, then checking every parent domain up to the tld
1212-func (r *Relay) DomainIsBanned(ctx context.Context, host string) (bool, error) {
1313- // ignore ports when checking for ban status
1414- hostport := strings.Split(host, ":")
1414+// XXX: tests for domain ban logic (which hit an actual database)
15151616- segments := strings.Split(hostport[0], ".")
1616+// DomainIsBanned checks if the given hostname is banned. It checks all domain suffixs.
1717+//
1818+// Hostname is assumed to have been parsed/normalized (eg, lower-case).
1919+func (r *Relay) DomainIsBanned(ctx context.Context, hostname string) (bool, error) {
17201818- // TODO: use normalize method once that merges
1919- var cleaned []string
2020- for _, s := range segments {
2121- if s == "" {
2222- continue
2323- }
2424- s = strings.ToLower(s)
2121+ if strings.HasPrefix(hostname, "localhost:") {
2222+ // XXX: check localhost config separately
2323+ }
25242626- cleaned = append(cleaned, s)
2525+ // otherwise we shouldn't have a port/colon
2626+ if strings.Contains(hostname, ":") {
2727+ return false, fmt.Errorf("unexpected colon in hostname: %s", hostname)
2728 }
2828- segments = cleaned
29293030+ // try entire host, and then all domain suffixes
3131+ segments := strings.Split(hostname, ".")
3032 for i := 0; i < len(segments)-1; i++ {
3133 dchk := strings.Join(segments[i:], ".")
3234 found, err := r.findDomainBan(ctx, dchk)
3335 if err != nil {
3436 return false, err
3537 }
3636-3738 if found {
3839 return true, nil
3940 }
···4142 return false, nil
4243}
43444444-func (r *Relay) findDomainBan(ctx context.Context, host string) (bool, error) {
4545+func (r *Relay) findDomainBan(ctx context.Context, domain string) (bool, error) {
4546 var ban models.DomainBan
4646- if err := r.db.Find(&ban, "domain = ?", host).Error; err != nil {
4747+ if err := r.db.Model(&models.DomainBan{}).Where("domain = ?", domain).First(&ban).Error; err != nil {
4848+ if errors.Is(err, gorm.ErrRecordNotFound) {
4949+ return false, nil
5050+ }
4751 return false, err
4852 }
5353+ return true, nil
5454+}
49555050- if ban.ID == 0 {
5151- return false, nil
5252- }
5656+func (r *Relay) CreateDomainBan(ctx context.Context, domain string) error {
5757+ domainBan := models.DomainBan{Domain: domain}
5858+ return r.db.Create(&domainBan).Error
5959+}
53605454- return true, nil
6161+func (r *Relay) RemoveDomainBan(ctx context.Context, domain string) error {
6262+ return r.db.Delete(&models.DomainBan{}, "domain = ?", domain).Error
6363+}
6464+6565+// returns all domain bans
6666+func (r *Relay) ListDomainBans(ctx context.Context) ([]models.DomainBan, error) {
6767+ bans := []models.DomainBan{}
6868+ if err := r.db.Model(&models.DomainBan{}).Find(&bans).Error; err != nil {
6969+ return nil, err
7070+ }
7171+ return bans, nil
5572}
+5-6
cmd/rerelay/relay/slurper.go
···489489 return err
490490}
491491492492-// TODO: called from admin endpoint
493493-func (s *Slurper) GetActiveList() []string {
492492+func (s *Slurper) GetActiveSubHostnames() []string {
494493 s.subsLk.Lock()
495494 defer s.subsLk.Unlock()
496496- var out []string
495495+496496+ var keys []string
497497 for k := range s.subs {
498498- out = append(out, k)
498498+ keys = append(keys, k)
499499 }
500500-501501- return out
500500+ return keys
502501}
503502504503func (s *Slurper) KillUpstreamConnection(hostname string, ban bool) error {
+10-12
cmd/rerelay/service.go
···148148 e.GET("/xrpc/com.atproto.sync.getRepoStatus", svc.HandleComAtprotoSyncGetRepoStatus)
149149 e.GET("/xrpc/com.atproto.sync.getLatestCommit", svc.HandleComAtprotoSyncGetLatestCommit)
150150151151- /* XXX: disabled while refactoring
152151 admin := e.Group("/admin", svc.checkAdminAuth)
153152154153 // Slurper-related Admin API
155154 admin.GET("/subs/getUpstreamConns", svc.handleAdminGetUpstreamConns)
155155+ admin.POST("/subs/killUpstream", svc.handleAdminKillUpstreamConn)
156156 admin.GET("/subs/getEnabled", svc.handleAdminGetSubsEnabled)
157157- admin.GET("/subs/perDayLimit", svc.handleAdminGetNewPDSPerDayRateLimit)
158157 admin.POST("/subs/setEnabled", svc.handleAdminSetSubsEnabled)
159159- admin.POST("/subs/killUpstream", svc.handleAdminKillUpstreamConn)
160160- admin.POST("/subs/setPerDayLimit", svc.handleAdminSetNewPDSPerDayRateLimit)
158158+ admin.GET("/subs/perDayLimit", svc.handleAdminGetNewHostPerDayRateLimit)
159159+ admin.POST("/subs/setPerDayLimit", svc.handleAdminSetNewHostPerDayRateLimit)
161160162161 // Domain-related Admin API
163162 admin.GET("/subs/listDomainBans", svc.handleAdminListDomainBans)
···165164 admin.POST("/subs/unbanDomain", svc.handleAdminUnbanDomain)
166165167166 // Repo-related Admin API
167167+ admin.GET("/repo/takedowns", svc.handleAdminListRepoTakeDowns)
168168 admin.POST("/repo/takeDown", svc.handleAdminTakeDownRepo)
169169 admin.POST("/repo/reverseTakedown", svc.handleAdminReverseTakedown)
170170- admin.GET("/repo/takedowns", svc.handleAdminListRepoTakeDowns)
171170172172- // PDS-related Admin API
171171+ // Host-related Admin API
172172+ admin.GET("/pds/list", svc.handleListHosts)
173173 admin.POST("/pds/requestCrawl", svc.handleAdminRequestCrawl)
174174- admin.GET("/pds/list", svc.handleListPDSs)
175175- admin.POST("/pds/changeLimits", svc.handleAdminChangePDSRateLimits)
176176- admin.POST("/pds/block", svc.handleBlockPDS)
177177- admin.POST("/pds/unblock", svc.handleUnblockPDS)
178178- admin.POST("/pds/addTrustedDomain", svc.handleAdminAddTrustedDomain)
174174+ // TODO: admin.POST("/pds/changeLimits", svc.handleAdminChangeHostRateLimits)
175175+ admin.POST("/pds/block", svc.handleBlockHost)
176176+ admin.POST("/pds/unblock", svc.handleUnblockHost)
177177+ // removed: admin.POST("/pds/addTrustedDomain", svc.handleAdminAddTrustedDomain)
179178180179 // Consumer-related Admin API
181180 admin.GET("/consumers/list", svc.handleAdminListConsumers)
182182- */
183181184182 // In order to support booting on random ports in tests, we need to tell the
185183 // Echo instance it's already got a port, and then use its StartServer