···44 "context"
55 "errors"
66 "fmt"
77+ "strings"
7889 "github.com/bluesky-social/indigo/atproto/syntax"
910 "github.com/bluesky-social/indigo/cmd/relay/relay/models"
···193194 return nil
194195}
195196197197+// Returns the of active accounts (based on local and upstream status). The sort order is by UID, ascending.
196198func (r *Relay) ListAccounts(ctx context.Context, cursor int64, limit int) ([]*models.Account, error) {
197199198198- // XXX: what status filter should be in place here? not deleted in addition to not takendown?
199200 accounts := []*models.Account{}
200200- if err := r.db.Model(&models.Account{}).Where("uid > ? AND status IS NOT 'takendown' AND (upstream_status IS NULL OR upstream_status = 'active')", cursor).Order("uid").Limit(limit).Find(&accounts).Error; err != nil {
201201+ if err := r.db.Model(&models.Account{}).Where("uid > ? AND status = 'active' AND upstream_status = 'active'", cursor).Order("uid").Limit(limit).Find(&accounts).Error; err != nil {
201202 return nil, err
202203 }
203204 return accounts, nil
···216217 return r.db.Exec("INSERT INTO account_repo (uid, rev, commit_cid, commit_data) VALUES (?, ?, ?, ?) ON CONFLICT (uid) DO UPDATE SET rev = EXCLUDED.rev, commit_cid = EXCLUDED.commit_cid, commit_data = EXCLUDED.commit_data", uid, rev, commitCID, commitDataCID).Error
217218}
218219219219-// this function with exact name and args implements the `diskpersist.UidSource` interface
220220+// This implements the `diskpersist.UidSource` interface
220221func (r *Relay) DidToUid(ctx context.Context, did string) (uint64, error) {
221222 // NOTE: not re-parsing DID here (this function is called "loopback" from persister)
222223 xu, err := r.GetAccount(ctx, syntax.DID(did))
···228229 }
229230 return xu.UID, nil
230231}
232232+233233+// In the general case, DIDs are case-sensitive. But PLC and did:web should not be, and should normalize to lower-case.
234234+func NormalizeDID(orig syntax.DID) syntax.DID {
235235+ lower := strings.ToLower(string(orig))
236236+ if strings.HasPrefix(lower, "did:plc:") || strings.HasPrefix(lower, "did:web:") {
237237+ return syntax.DID(lower)
238238+ }
239239+ return orig
240240+}
···3131 }
3232 }
33333434- return !r.Config.DisableNewHosts
3434+ return true
3535}
36363737func (r *Relay) SubscribeToHost(hostname string, noSSL, adminForce bool) error {
38383939- // if we already have an active subscription going, exit early
3939+ // if we already have an active subscription, exit early
4040 if r.Slurper.CheckIfSubscribed(hostname) {
4141 return nil
4242 }
43434444- // XXX: new PDS daily rate-limit
4545-4444+ // fetch host info from database. this query will not error if host does not yet exist
4645 newHost := false
4746 var host models.Host
4847 if err := r.db.Find(&host, "hostname = ?", hostname).Error; err != nil {
···5049 }
51505251 if host.ID == 0 {
5252+ newHost = true
5353 if !adminForce && !r.canSlurpHost(hostname) {
5454+ // TODO: is this the correct error code?
5455 return ErrNewSubsDisabled
5556 }
5656- // New PDS!
5757- npds := models.Host{
5757+5858+ // XXX: new host daily rate-limit
5959+6060+ host = models.Host{
5861 Hostname: hostname,
5962 NoSSL: noSSL,
6063 Status: models.HostStatusActive,
6164 AccountLimit: r.Config.DefaultRepoLimit,
6265 }
6363- /* XXX
6464- if rateOverrides != nil {
6565- npds.RateLimit = float64(rateOverrides.PerSecond)
6666- npds.HourlyEventLimit = rateOverrides.PerHour
6767- npds.DailyEventLimit = rateOverrides.PerDay
6868- npds.RepoLimit = rateOverrides.RepoLimit
6969- }
7070- */
7171- if err := r.db.Create(&npds).Error; err != nil {
6666+6767+ if err := r.db.Create(&newHost).Error; err != nil {
7268 return err
7369 }
74707575- newHost = true
7676- host = npds
7171+ r.Logger.Info("adding new host subscription", "hostname", hostname, "noSSL", noSSL, "adminForce", adminForce)
7772 } else if host.Status == models.HostStatusBanned {
7873 return fmt.Errorf("cannot subscribe to banned pds")
7974 }
80758181- /* XXX
8282- if !host.Registered && reg {
8383- host.Registered = true
8484- if err := s.db.Model(models.Host{}).Where("id = ?", host.ID).Update("registered", true).Error; err != nil {
8585- return err
8686- }
8787- }
8888- */
8989-9076 return r.Slurper.Subscribe(&host, newHost)
9177}
9278···9985 }
1008610187 for _, host := range all {
102102- // copy host
8888+ logger := r.Logger.With("hostID", host.ID, "hostname", host.Hostname)
8989+ logger.Info("re-subscribing to active host")
9090+ // make a copy of host
10391 host := host
10492 err := r.Slurper.Subscribe(&host, false)
10593 if err != nil {
106106- r.Logger.Warn("failed to re-subscribe to host", "hostID", host.ID, "hostname", host.Hostname, "err", err)
9494+ logger.Warn("failed to re-subscribe to host", "err", err)
10795 }
10896 }
10997 return nil
+2-1
cmd/relay/relay/domain_ban.go
···1919func (r *Relay) DomainIsBanned(ctx context.Context, hostname string) (bool, error) {
20202121 if strings.HasPrefix(hostname, "localhost:") {
2222- // XXX: check localhost config separately
2222+ // this method never allows localhost; need to use admin-mode for that
2323+ return true, nil
2324 }
24252526 // otherwise we shouldn't have a port/colon
+1
cmd/relay/relay/errors.go
···88 ErrHostNotFound = errors.New("unknown host or PDS")
99 ErrAccountNotFound = errors.New("unknown account")
1010 ErrAccountRepoNotFound = errors.New("repository state not available")
1111+ ErrNotPDS = errors.New("server is not a PDS")
11121213 // TODO: these might need better names
1314 ErrTimeoutShutdown = errors.New("timed out waiting for new events")
+8-3
cmd/relay/relay/host.go
···82828383// parses, normalizes, and validates a raw URL (HTTP or WebSocket) in to a hostname for subscriptions
8484//
8585-// Hostnames much be DNS names, not IP addresses
8585+// Hostnames must be DNS names, not IP addresses.
8686func ParseHostname(raw string) (hostname string, noSSL bool, err error) {
87878888 // handle case of bare hostname
8989 if !strings.Contains(raw, "://") {
9090- raw = "https://" + raw
9090+ if strings.HasPrefix(raw, "localhost:") {
9191+ raw = "http://" + raw
9292+ } else {
9393+ raw = "https://" + raw
9494+ }
9195 }
92969397 u, err := url.Parse(raw)
···100104 default:
101105 return "", false, fmt.Errorf("unsupported URL scheme: %s", u.Scheme)
102106 }
107107+103108 // 'localhost' (exact string) is allowed *with* a required port number; SSL is optional
104109 if u.Hostname() == "localhost" {
105105- if u.Port() == "" {
110110+ if u.Port() == "" || !strings.HasPrefix(u.Host, "localhost:") {
106111 return "", false, fmt.Errorf("port number is required for localhost")
107112 }
108113 return u.Host, noSSL, nil
-3
cmd/relay/relay/host_checker.go
···2233import (
44 "context"
55- "errors"
65 "fmt"
76 "net/http"
87···109 "github.com/bluesky-social/indigo/atproto/identity"
1110 "github.com/bluesky-social/indigo/xrpc"
1211)
1313-1414-var ErrNotPDS = errors.New("server is not a PDS")
15121613// Simple interface for doing host and account status checks.
1714//
···5757 }
5858}
59596060-// handles the shared part of event processing: that the account existing, is associated with this host, etc
6060+// Implements the shared part of event processing: that the account existing, is associated with this host, etc.
6161+//
6262+// If there is no error, the returned account is always non-nil, but the identity may be nil (if there was a resolution error).
6163func (r *Relay) preProcessEvent(ctx context.Context, didStr string, hostname string, hostID uint64, logger *slog.Logger) (*models.Account, *identity.Identity, error) {
62646365 did, err := syntax.ParseDID(didStr)
6466 if err != nil {
6567 return nil, nil, fmt.Errorf("invalid DID in message: %w", err)
6668 }
6767- // XXX: did = did.Normalize()
6969+ // TODO: add a test case for non-normalized DID
7070+ did = NormalizeDID(did)
68716972 acc, err := r.GetAccount(ctx, did)
7073 if err != nil {
···95989699 ident, err := r.Dir.LookupDID(ctx, did)
97100 if err != nil {
9898- // XXX: handle more granularly (eg, true NotFound vs other errors); and add tests
9999- logger.Warn("failed to load identity")
101101+ logger.Warn("failed to load identity", "did", did, "err", err)
100102 }
101103 return acc, ident, nil
102104}
···113115 if !acc.IsActive() {
114116 logger.Info("dropping commit message for non-active account", "status", acc.Status, "upstreamStatus", acc.UpstreamStatus)
115117 return nil
118118+ }
119119+120120+ if ident == nil {
121121+ // XXX: what to do if identity resolution fails
116122 }
117123118124 prevRepo, err := r.GetAccountRepo(ctx, acc.UID)
···159165 if !acc.IsActive() {
160166 logger.Info("dropping commit message for non-active account", "status", acc.Status, "upstreamStatus", acc.UpstreamStatus)
161167 return nil
168168+ }
169169+170170+ if ident == nil {
171171+ // XXX: what to do if identity resolution fails
162172 }
163173164174 newRepo, err := r.VerifyRepoSync(ctx, evt, ident, hostname)
-7
cmd/relay/relay/metrics.go
···3131 Help: "The total number of sync events received",
3232}, []string{"pds"})
33333434-/* XXX
3535-var repoCommitsResultCounter = promauto.NewCounterVec(prometheus.CounterOpts{
3636- Name: "repo_commits_result_counter",
3737- Help: "The results of commit events received",
3838-}, []string{"pds", "status"})
3939-*/
4040-4134var eventsSentCounter = promauto.NewCounterVec(prometheus.CounterOpts{
4235 Name: "events_sent_counter",
4336 Help: "The total number of events sent to consumers",
+8-8
cmd/relay/relay/models/models.go
···66 "gorm.io/gorm"
77)
8899-// TODO: revisit this
109type DomainBan struct {
1110 gorm.Model
1211 Domain string `gorm:"unique"`
···1514type HostStatus string
16151716const (
1818- HostStatusActive = HostStatus("active")
1919- HostStatusIdle = HostStatus("idle")
2020- HostStatusOffline = HostStatus("offline")
2121- HostStatusBanned = HostStatus("banned")
1717+ HostStatusActive = HostStatus("active")
1818+ HostStatusIdle = HostStatus("idle")
1919+ HostStatusOffline = HostStatus("offline")
2020+ HostStatusThrottled = HostStatus("throttled")
2121+ HostStatusBanned = HostStatus("banned")
2222)
23232424type Host struct {
···2626 CreatedAt time.Time
2727 UpdatedAt time.Time
28282929- // hostname, without URL scheme. might include a port number if localhost, otherwise should not
2929+ // hostname, without URL scheme. if localhost, must include a port number; otherwise must not include port
3030 Hostname string `gorm:"column:hostname;uniqueIndex;not null"`
31313232 // indicates ws:// not wss://
···37373838 // TODO: ThrottleUntil time.Time
39394040- // indicates this is a highly trusted PDS (different limits apply)
4040+ // indicates this is a highly trusted host (PDS) and different limits apply
4141 Trusted bool `gorm:"column:trusted;default:false"`
42424343 // enum of account status
···6464 AccountStatusSuspended = AccountStatus("suspended")
6565 AccountStatusTakendown = AccountStatus("takendown")
6666 AccountStatusThrottled = AccountStatus("throttled")
6767- AccountStatusHostThrottled = AccountStatus("host-throttled") // TODO: actually implement this
6767+ AccountStatusHostThrottled = AccountStatus("host-throttled") // TODO: not yet implemented
68686969 // generic "not active, but not known" status
7070 AccountStatusInactive = AccountStatus("inactive")
+7-10
cmd/relay/relay/relay.go
···3939}
40404141type RelayConfig struct {
4242- SSL bool
4342 DefaultRepoLimit int64
4443 ConcurrencyPerHost int
4544 QueueDepthPerHost int
4646- SkipAccountHostCheck bool // XXX: only used for testing
4747- LenientSyncValidation bool // XXX: wire through config
4545+ LenientSyncValidation bool
4646+ TrustedDomains []string
48474949- // if true, ignore "requestCrawl"
5050- DisableNewHosts bool
5151- TrustedDomains []string
4848+ // If true, skip validation that messages for a given account (DID) are coming from the expected upstream host (PDS). Currently only used in tests; might be used for intermediate relays in the future.
4949+ SkipAccountHostCheck bool
5250}
53515452func DefaultRelayConfig() *RelayConfig {
5555- // NOTE: many of these defaults are CLI arg defaults
5353+ // NOTE: many of these defaults are clobbered by CLI arguments
5654 return &RelayConfig{
5757- SSL: true,
5855 DefaultRepoLimit: 100,
5956 ConcurrencyPerHost: 40,
6057 QueueDepthPerHost: 1000,
···8986 return nil, err
9087 }
91889292- // XXX: need to pass-through more relay configs
9389 slurpConfig := DefaultSlurperConfig()
9494- slurpConfig.SSL = config.SSL
9590 slurpConfig.DefaultRepoLimit = config.DefaultRepoLimit
9691 slurpConfig.ConcurrencyPerHost = config.ConcurrencyPerHost
9792 slurpConfig.QueueDepthPerHost = config.QueueDepthPerHost
9393+9894 // register callbacks to persist cursors and host state in database
9995 slurpConfig.PersistCursorCallback = r.PersistHostCursors
10096 slurpConfig.PersistHostStatusCallback = r.UpdateHostStatus
9797+10198 s, err := NewSlurper(r.processRepoEvent, slurpConfig, r.Logger)
10299 if err != nil {
103100 return nil, err
-2
cmd/relay/relay/slurper.go
···5454}
55555656type SlurperConfig struct {
5757- SSL bool
5857 DefaultPerSecondLimit int64
5958 DefaultPerHourLimit int64
6059 DefaultPerDayLimit int64
···7069func DefaultSlurperConfig() *SlurperConfig {
7170 // NOTE: many of these defaults are overruled by DefaultRelayConfig, or even process CLI arg defaults
7271 return &SlurperConfig{
7373- SSL: false,
7472 NewHostPerDayLimit: 50,
7573 DefaultPerSecondLimit: 50,
7674 DefaultPerHourLimit: 2500,
+7-14
cmd/relay/service.go
···1515 "github.com/labstack/echo/v4"
1616 "github.com/labstack/echo/v4/middleware"
1717 "github.com/prometheus/client_golang/prometheus/promhttp"
1818- "gorm.io/gorm"
1918)
20192120type Service struct {
2222- db *gorm.DB // XXX
2321 logger *slog.Logger
2422 relay *relay.Relay
2523 config ServiceConfig
···2826}
29273028type ServiceConfig struct {
3131- // list of hosts which get forwarded com.atproto.sync.requestCrawl (HTTP POST)
3232- ForwardCrawlRequestHosts []string
2929+ // list of hosts which get forwarded admin state changes (takedowns, etc)
3030+ SiblingRelayHosts []string
33313432 // verified against Basic admin auth
3533 AdminPassword string
···39374038 // if true, don't process public (unauthenticated) requestCrawl
4139 DisableRequestCrawl bool
4040+4141+ // if true, allows non-SSL hosts to be added via public requestCrawl
4242+ AllowInsecureHosts bool
4243}
43444445func DefaultServiceConfig() *ServiceConfig {
···4748 }
4849}
49505050-func NewService(db *gorm.DB, r *relay.Relay, config *ServiceConfig) (*Service, error) {
5151+func NewService(r *relay.Relay, config *ServiceConfig) (*Service, error) {
51525253 if config == nil {
5354 config = DefaultServiceConfig()
5455 }
55565657 svc := &Service{
5757- db: db,
5858 logger: slog.Default().With("system", "relay"),
5959 relay: r,
6060 config: *config,
···9090 AllowOrigins: []string{"*"},
9191 AllowHeaders: []string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, echo.HeaderAuthorization},
9292 }))
9393-9494- if !svc.relay.Config.SSL {
9595- e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
9696- Format: "method=${method}, uri=${uri}, status=${status} latency=${latency_human}\n",
9797- }))
9898- } else {
9999- e.Use(middleware.LoggerWithConfig(middleware.DefaultLoggerConfig))
100100- }
9393+ e.Use(middleware.LoggerWithConfig(middleware.DefaultLoggerConfig))
1019410295 // React uses a virtual router, so we need to serve the index.html for all
10396 // routes that aren't otherwise handled or in the /assets directory.
+15-15
cmd/relay/stream/consumer.go
···115115// HandleRepoStream
116116// con is source of events
117117// sched gets AddWork for each event
118118-// log may be nil for default logger
119119-func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, log *slog.Logger) error {
120120- if log == nil {
121121- log = slog.Default().With("system", "events")
118118+// logger may be nil for default logger
119119+func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, logger *slog.Logger) error {
120120+ if logger == nil {
121121+ logger = slog.Default().With("system", "events")
122122 }
123123 ctx, cancel := context.WithCancel(ctx)
124124 defer cancel()
···136136 select {
137137 case <-t.C:
138138 if err := con.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil {
139139- log.Warn("failed to ping", "err", err)
139139+ logger.Warn("failed to ping", "err", err)
140140 failcount++
141141 if failcount >= 4 {
142142- log.Error("too many ping fails", "count", failcount)
142142+ logger.Error("too many ping fails", "count", failcount)
143143 con.Close()
144144 return
145145 }
···165165166166 con.SetPongHandler(func(_ string) error {
167167 if err := con.SetReadDeadline(time.Now().Add(time.Minute)); err != nil {
168168- log.Error("failed to set read deadline", "err", err)
168168+ logger.Error("failed to set read deadline", "err", err)
169169 }
170170171171 return nil
···214214 }
215215216216 if evt.Seq < lastSeq {
217217- log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
217217+ logger.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
218218 }
219219220220 lastSeq = evt.Seq
···231231 }
232232233233 if evt.Seq < lastSeq {
234234- log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
234234+ logger.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
235235 }
236236237237 lastSeq = evt.Seq
···249249 }
250250251251 if evt.Seq < lastSeq {
252252- log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
252252+ logger.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
253253 }
254254 lastSeq = evt.Seq
255255···265265 }
266266267267 if evt.Seq < lastSeq {
268268- log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
268268+ logger.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
269269 }
270270 lastSeq = evt.Seq
271271···281281 }
282282283283 if evt.Seq < lastSeq {
284284- log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
284284+ logger.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
285285 }
286286 lastSeq = evt.Seq
287287···310310 }
311311312312 if evt.Seq < lastSeq {
313313- log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
313313+ logger.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
314314 }
315315 lastSeq = evt.Seq
316316···327327 }
328328329329 if evt.Seq < lastSeq {
330330- log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
330330+ logger.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
331331 }
332332 lastSeq = evt.Seq
333333···343343 }
344344345345 if evt.Seq < lastSeq {
346346- log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
346346+ logger.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
347347 }
348348349349 lastSeq = evt.Seq