// Package backstream streams backfilled atproto repository records to
// websocket clients, optionally zstd-compressed with a shared dictionary.
package backstream

import (
	//"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"runtime"
	"runtime/debug"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/ipfs/go-cid"
	"github.com/klauspost/compress/zstd"

	data "github.com/bluesky-social/indigo/atproto/atdata"
	atrepo "github.com/bluesky-social/indigo/atproto/repo"
	// "github.com/bluesky-social/indigo/repo"
	"github.com/bluesky-social/indigo/atproto/syntax"
)

const (
	// numWorkers is the size of the per-connection worker pool that
	// resolves DIDs and fetches their records.
	numWorkers = 20
)

// DefaultUpgrader accepts websocket upgrades from any origin.
var DefaultUpgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool { return true },
}

// BackfillHandler is an http.Handler that upgrades requests to websockets
// and streams historical records for the requested DIDs/collections.
type BackfillHandler struct {
	Upgrader       websocket.Upgrader
	SessionManager *SessionManager
	AtpClient      *ATProtoClient
	// ZstdDict, when non-nil, allows clients to opt into dictionary-based
	// zstd compression with ?compress=true.
	ZstdDict []byte
	// UseGetRepoMethod selects the full-CAR download strategy instead of
	// paged listRecords calls.
	UseGetRepoMethod bool
}

// BackfillParams captures one client's filter options for a session.
type BackfillParams struct {
	WantedDIDs        []string
	WantedCollections []string
	// GetRecordFormat switches output to the getRecord response shape
	// instead of the jetstream-like commit envelope.
	GetRecordFormat bool
}

// ServeHTTP upgrades the request to a websocket and runs one backfill
// session: a producer enumerates DIDs, a worker pool fetches records, and
// a single writer goroutine serializes results to the client.
func (h *BackfillHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Compression is only honored when a dictionary was configured.
	compress := (r.URL.Query().Get("compress") == "true") && (h.ZstdDict != nil)

	conn, err := h.Upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("Failed to upgrade connection: %v", err)
		return
	}
	defer conn.Close()

	if compress {
		log.Println("Client requested zstd compression. Enabling.")
	}

	params, ticket, err := h.parseQueryParams(r)
	if err != nil {
		h.sendError(conn, err.Error())
		return
	}

	log.Printf("New connection for ticket: %s. DIDs: %v, Collections: %v, Workers: %d", ticket, params.WantedDIDs, params.WantedCollections, numWorkers)

	session := h.SessionManager.GetOrCreate(ticket, params)
	// FIX: LastAccessed is written under session.mu in the producer; take
	// the same lock here to avoid a data race.
	session.mu.Lock()
	session.LastAccessed = time.Now()
	session.mu.Unlock()

	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	// Reader pump: we never expect client messages, but ReadMessage is the
	// only way to observe a client-side close promptly. Any read error
	// cancels the session.
	go func() {
		defer cancel()
		for {
			if _, _, err := conn.ReadMessage(); err != nil {
				if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
					log.Printf("Client disconnected for ticket %s (read error): %v", ticket, err)
				}
				break
			}
		}
	}()

	var wg sync.WaitGroup
	jobs := make(chan string, numWorkers)
	results := make(chan interface{}, 100)

	for i := 1; i <= numWorkers; i++ {
		go h.worker(ctx, i, &wg, jobs, results, session)
	}

	// Exactly one goroutine writes to the websocket while the session runs;
	// the final "complete" message below is only sent after it exits.
	writerDone := make(chan struct{})
	if compress {
		go h.compressedWriter(ctx, cancel, conn, results, writerDone)
	} else {
		go h.writer(ctx, cancel, conn, results, writerDone)
	}

	wg.Add(1)
	go h.producer(ctx, &wg, jobs, session)

	wg.Wait()      // all queued jobs finished (producer Adds before each send)
	close(results) // safe: no worker sends after its wg.Done
	<-writerDone

	log.Printf("Backfill completed for ticket: %s", session.Ticket)
	h.sendMessage(conn, map[string]string{"status": "complete", "message": "Backfill finished."})
}

// compressedWriter drains results, marshals each to JSON, compresses the
// payload with the shared dictionary, and writes it as a binary frame.
// Any failure cancels the whole session.
func (h *BackfillHandler) compressedWriter(ctx context.Context, cancel context.CancelFunc, conn *websocket.Conn, results <-chan interface{}, done chan<- struct{}) {
	defer close(done)

	encoder, err := zstd.NewWriter(nil, zstd.WithEncoderDict(h.ZstdDict))
	if err != nil {
		log.Printf("ERROR: [CompressedWriter] Failed to create zstd encoder with dictionary: %v", err)
		cancel()
		return
	}
	defer encoder.Close()

	for {
		select {
		case result, ok := <-results:
			if !ok {
				return
			}
			// FIX: renamed local from "data" to avoid shadowing the
			// imported atdata package alias.
			payload, err := json.Marshal(result)
			if err != nil {
				log.Printf("ERROR: [CompressedWriter] Failed to marshal JSON: %v", err)
				cancel()
				return
			}
			compressed := encoder.EncodeAll(payload, nil)
			if err := conn.WriteMessage(websocket.BinaryMessage, compressed); err != nil {
				log.Printf("ERROR: [CompressedWriter] Failed to write compressed message: %v", err)
				cancel()
				return
			}
		case <-ctx.Done():
			log.Printf("[CompressedWriter] Context cancelled, stopping.")
			return
		}
	}
}

// writer drains results and writes each as a JSON text frame. A write
// failure cancels the whole session.
func (h *BackfillHandler) writer(ctx context.Context, cancel context.CancelFunc, conn *websocket.Conn, results <-chan interface{}, done chan<- struct{}) {
	defer close(done)
	for {
		select {
		case result, ok := <-results:
			if !ok {
				return
			}
			if err := h.sendMessage(conn, result); err != nil {
				log.Printf("ERROR: [Writer] Failed to write message, closing connection: %v", err)
				cancel()
				return
			}
		case <-ctx.Done():
			log.Printf("[Writer] Context cancelled, stopping.")
			return
		}
	}
}

// enqueueJob registers one unit of work and hands the DID to the worker
// pool. Returns false when ctx is cancelled before the job could be
// queued; in that case the WaitGroup registration is rolled back so
// wg.Wait cannot hang on an abandoned job.
func enqueueJob(ctx context.Context, wg *sync.WaitGroup, jobs chan<- string, did string) bool {
	wg.Add(1)
	select {
	case jobs <- did:
		return true
	case <-ctx.Done():
		wg.Done()
		return false
	}
}

// producer enumerates the DIDs to backfill and feeds them to the worker
// pool, resuming from any cursor persisted on the session. It closes jobs
// when enumeration ends so workers terminate.
func (h *BackfillHandler) producer(ctx context.Context, wg *sync.WaitGroup, jobs chan<- string, session *Session) {
	defer close(jobs)
	defer wg.Done()

	isFullNetwork := len(session.Params.WantedDIDs) == 1 && session.Params.WantedDIDs[0] == "*"
	isAllCollections := len(session.Params.WantedCollections) == 1 && session.Params.WantedCollections[0] == "*"

	if isFullNetwork {
		if isAllCollections {
			// --- Case 1: Full Network, All Collections (dids=*&collections=*) ---
			// We need to list *all* repos from the relay.
			log.Printf("[Producer] Starting full network scan for all collections.")
			for {
				select {
				case <-ctx.Done():
					log.Printf("[Producer] Context cancelled, stopping full repo fetch.")
					return
				default:
				}
				log.Printf("[Producer] Fetching all repos with cursor: %s", session.ListReposCursor)
				repos, nextCursor, err := h.AtpClient.ListRepos(ctx, session.ListReposCursor)
				if err != nil {
					log.Printf("ERROR: [Producer] Failed to list all repos: %v", err)
					return
				}
				for _, repo := range repos {
					if session.IsDIDComplete(repo.DID) {
						continue
					}
					if !enqueueJob(ctx, wg, jobs, repo.DID) {
						return
					}
				}
				session.mu.Lock()
				session.ListReposCursor = nextCursor
				session.LastAccessed = time.Now()
				session.mu.Unlock()
				if nextCursor == "" {
					log.Printf("[Producer] Finished fetching all repos from relay.")
					break
				}
			}
		} else {
			// --- Case 2: Full Network, Specific Collections (dids=*&collections=a,b,c) ---
			// For each specific collection, page through all repos and send DIDs to workers.
			log.Printf("[Producer] Starting network scan for specific collections: %v", session.Params.WantedCollections)
			for _, collection := range session.Params.WantedCollections {
				for {
					select {
					case <-ctx.Done():
						log.Printf("[Producer] Context cancelled, stopping repo fetch.")
						return
					default:
					}
					log.Printf("[Producer] Fetching repos for %s with cursor: %s", collection, session.ListReposCursor)
					repos, nextCursor, err := h.AtpClient.ListReposByCollection(ctx, collection, session.ListReposCursor)
					if err != nil {
						log.Printf("ERROR: [Producer] Failed to list repos for collection %s: %v", collection, err)
						return
					}
					for _, repo := range repos {
						if session.IsDIDComplete(repo.DID) {
							continue
						}
						if !enqueueJob(ctx, wg, jobs, repo.DID) {
							return
						}
					}
					session.mu.Lock()
					session.ListReposCursor = nextCursor
					session.LastAccessed = time.Now()
					session.mu.Unlock()
					if nextCursor == "" {
						log.Printf("[Producer] Finished fetching all repos for collection %s", collection)
						break
					}
				}
			}
		}
	} else {
		// --- Case 3: Specific List of DIDs (dids=a,b,c) ---
		// Send user-provided DIDs to workers.
		for _, did := range session.Params.WantedDIDs {
			if session.IsDIDComplete(did) {
				log.Printf("[Producer] Skipping already completed DID: %s", did)
				continue
			}
			if !enqueueJob(ctx, wg, jobs, did) {
				log.Printf("[Producer] Context cancelled, stopping DID processing.")
				return
			}
		}
	}
}

// worker consumes DIDs from jobs until the channel closes, resolves each
// DID's PDS, and streams its records via one of the two fetch strategies.
// Every job is balanced with wg.Done, even on early return.
func (h *BackfillHandler) worker(ctx context.Context, id int, wg *sync.WaitGroup, jobs <-chan string, results chan<- interface{}, session *Session) {
	for did := range jobs {
		func(did string) {
			defer func() {
				wg.Done()
				// NOTE(review): forcing a GC + OS memory release after
				// every repo is very heavy-handed across 20 workers;
				// consider removing once memory pressure is understood.
				runtime.GC()
				debug.FreeOSMemory()
				log.Printf("[Worker %d] Cleaned up resources for DID: %s", id, did)
			}()
			select {
			case <-ctx.Done():
				return
			default:
			}
			log.Printf("[Worker %d] Processing DID: %s", id, did)
			pdsURL, err := h.AtpClient.ResolveDID(ctx, did)
			if err != nil {
				log.Printf("WARN: [Worker %d] Could not resolve DID %s, skipping. Error: %v", id, did, err)
				return
			}
			if h.UseGetRepoMethod {
				h.processDIDWithGetRepo(ctx, id, did, pdsURL, results, session)
			} else {
				h.processDIDWithListRecords(ctx, id, did, pdsURL, results, session)
			}
			session.MarkDIDComplete(did)
			log.Printf("[Worker %d] Finished DID: %s", id, did)
		}(did)
	}
}

// processDIDWithGetRepo downloads the repo's full CAR file to a temp file,
// then walks its MST and emits every record matching the session's
// collection filter. Cursor state is persisted per record for resumption.
func (h *BackfillHandler) processDIDWithGetRepo(ctx context.Context, id int, did, pdsURL string, results chan<- interface{}, session *Session) {
	log.Printf("[Worker %d] Using streaming getRepo method for %s", id, did)

	isAllCollections := len(session.Params.WantedCollections) == 1 && session.Params.WantedCollections[0] == "*"
	wantedSet := make(map[string]struct{})
	if !isAllCollections {
		for _, coll := range session.Params.WantedCollections {
			wantedSet[coll] = struct{}{}
		}
	}

	respBody, err := h.AtpClient.GetRepo(ctx, pdsURL, did)
	if err != nil {
		log.Printf("WARN: [Worker %d] Failed to get repo stream for %s: %v", id, did, err)
		return
	}
	defer respBody.Close()

	if err := os.MkdirAll("./temp", 0o755); err != nil {
		// FIX: don't panic inside a request-serving worker; log and skip
		// this DID instead of killing the process.
		log.Printf("ERROR: [Worker %d] Failed to create temp dir for %s: %v", id, did, err)
		return
	}
	// FIX: os.CreateTemp replaces the deprecated ioutil.TempFile.
	tempFile, err := os.CreateTemp("./temp", "backstream-repo-*.car")
	if err != nil {
		log.Printf("ERROR: [Worker %d] Failed to create temp file for %s: %v", id, did, err)
		return
	}
	defer os.Remove(tempFile.Name())

	if _, err := io.Copy(tempFile, respBody); err != nil {
		// FIX: close the handle on this error path too (was leaked).
		tempFile.Close()
		log.Printf("ERROR: [Worker %d] Failed to write repo to temp file for %s: %v", id, did, err)
		return
	}
	if err := tempFile.Close(); err != nil {
		log.Printf("ERROR: [Worker %d] Failed to close temp file for %s: %v", id, did, err)
		return
	}

	readHandle, err := os.Open(tempFile.Name())
	if err != nil {
		log.Printf("ERROR: [Worker %d] Failed to open temp file for reading %s: %v", id, did, err)
		return
	}
	defer readHandle.Close()

	_, r, err := atrepo.LoadRepoFromCAR(ctx, readHandle)
	if err != nil {
		log.Printf("WARN: [Worker %d] Failed to read CAR stream for %s from temp file: %v", id, did, err)
		return
	}

	err = r.MST.Walk(func(k []byte, v cid.Cid) error {
		select {
		case <-ctx.Done():
			// FIX: return ctx.Err() (context.Canceled) instead of a
			// bespoke error so the errors.Is filter below actually
			// recognizes cancellation and stays quiet.
			return ctx.Err()
		default:
		}

		path := string(k)
		collection, rkey, err := syntax.ParseRepoPath(path)
		if err != nil {
			log.Printf("WARN: [Worker %d] Could not parse repo path '%s' for %s, skipping record", id, path, did)
			return nil
		}
		if !isAllCollections {
			if _, ok := wantedSet[string(collection)]; !ok {
				return nil
			}
		}

		recBytes, _, err := r.GetRecordBytes(ctx, collection, rkey)
		if err != nil {
			log.Printf("WARN: [Worker %d] Failed to get record bytes for %s: %v", id, path, err)
			return nil
		}
		recordVal, err := data.UnmarshalCBOR(recBytes)
		if err != nil {
			log.Printf("WARN: [Worker %d] Failed to unmarshal record CBOR for %s: %v", id, path, err)
			return nil
		}

		record := Record{
			URI:   fmt.Sprintf("at://%s/%s", did, path),
			CID:   v.String(),
			Value: recordVal,
		}
		output := h.formatOutput(record, did, string(collection), session.Params.GetRecordFormat)

		select {
		case results <- output:
		case <-ctx.Done():
			return ctx.Err() // same FIX as above
		}
		session.SetListRecordsCursor(did, string(collection), string(rkey))
		return nil
	})
	if err != nil && !errors.Is(err, context.Canceled) {
		log.Printf("WARN: [Worker %d] Error while walking repo for %s: %v", id, did, err)
	}
}

// processDIDWithListRecords backfills one DID by paging through
// com.atproto.repo.listRecords per collection, resuming from any cursor
// stored on the session.
func (h *BackfillHandler) processDIDWithListRecords(ctx context.Context, id int, did, pdsURL string, results chan<- interface{}, session *Session) {
	log.Printf("[Worker %d] Using listRecords method for %s", id, did)

	isAllCollections := len(session.Params.WantedCollections) == 1 && session.Params.WantedCollections[0] == "*"

	var collectionsToProcess []string
	if isAllCollections {
		repoCollections, err := h.AtpClient.DescribeRepo(ctx, pdsURL, did)
		if err != nil {
			log.Printf("WARN: [Worker %d] Could not describe repo for %s to find collections, skipping. Error: %v", id, did, err)
			return
		}
		collectionsToProcess = repoCollections
		log.Printf("[Worker %d] Found %d collections for DID %s", id, len(collectionsToProcess), did)
	} else {
		collectionsToProcess = session.Params.WantedCollections
	}

	for _, collection := range collectionsToProcess {
		cursor := session.GetListRecordsCursor(did, collection)
		for {
			select {
			case <-ctx.Done():
				log.Printf("[Worker %d] Context cancelled for DID %s", id, did)
				return
			default:
			}
			records, nextCursor, err := h.AtpClient.ListRecords(ctx, pdsURL, did, collection, cursor)
			if err != nil {
				// A 400 typically just means the repo has no such
				// collection; stay quiet for that case.
				if !strings.Contains(err.Error(), "status: 400") {
					log.Printf("WARN: [Worker %d] Failed to list records for %s/%s, skipping. Error: %v", id, did, collection, err)
				}
				break
			}
			for _, record := range records {
				output := h.formatOutput(record, did, collection, session.Params.GetRecordFormat)
				select {
				case results <- output:
				case <-ctx.Done():
					log.Printf("[Worker %d] Context cancelled while sending results for %s", id, did)
					return
				}
			}
			session.SetListRecordsCursor(did, collection, nextCursor)
			cursor = nextCursor
			if cursor == "" {
				break
			}
		}
	}
}

// parseQueryParams extracts ticket/filter options from the request query,
// applying defaults: missing DID or collection filters mean "*" (all), a
// missing ticket is generated (or "jetstreamfalse" when no filters at all
// were supplied).
func (h *BackfillHandler) parseQueryParams(r *http.Request) (BackfillParams, string, error) {
	query := r.URL.Query()
	ticket := query.Get("ticket")
	wantedDidsStr := query.Get("wantedDids")
	wantedCollectionsStr := query.Get("wantedCollections")

	if wantedCollectionsStr == "" && wantedDidsStr == "" && ticket == "" {
		ticket = "jetstreamfalse"
	} else if ticket == "" {
		ticket = generateTicket()
	}

	if wantedDidsStr == "" {
		log.Println("Query parameter 'wantedDids' not specified, defaulting to '*' (all repos).")
		wantedDidsStr = "*"
	}
	if wantedCollectionsStr == "" {
		log.Println("Query parameter 'wantedCollections' not specified, defaulting to '*' (all collections).")
		wantedCollectionsStr = "*"
	}

	params := BackfillParams{
		WantedDIDs:        strings.Split(wantedDidsStr, ","),
		WantedCollections: strings.Split(wantedCollectionsStr, ","),
		GetRecordFormat:   query.Get("getRecordFormat") == "true",
	}
	return params, ticket, nil
}

// formatOutput wraps a record either as a getRecord-style response or as a
// jetstream-like commit envelope, depending on the session's preference.
func (h *BackfillHandler) formatOutput(record Record, did, collection string, getRecordFormat bool) interface{} {
	if getRecordFormat {
		return GetRecordOutput{
			URI:   record.URI,
			CID:   record.CID,
			Value: record.Value,
		}
	}
	// at://did/collection/rkey splits into 5 parts; the rkey is the last.
	uriParts := strings.Split(record.URI, "/")
	rkey := ""
	if len(uriParts) == 5 {
		rkey = uriParts[4]
	}
	return JetstreamLikeOutput{
		Did:  did,
		Kind: "commit",
		// NOTE(review): TimeUS is a fixed placeholder, not the real event
		// time — confirm whether consumers rely on this exact value.
		TimeUS: "1725911162329308",
		Commit: JetstreamLikeCommit{
			Rev:        rkey,
			Operation:  "create",
			Collection: collection,
			RKey:       rkey,
			Record:     record.Value,
			CID:        record.CID,
		},
	}
}

// sendError reports a fatal setup problem to the client as a JSON frame.
func (h *BackfillHandler) sendError(conn *websocket.Conn, message string) {
	log.Printf("Sending error to client: %s", message)
	_ = conn.WriteJSON(map[string]string{"error": message})
}

// sendMessage writes v to the websocket as a JSON text frame.
func (h *BackfillHandler) sendMessage(conn *websocket.Conn, v interface{}) error {
	return conn.WriteJSON(v)
}