this repo has no description
4
fork

Configure Feed

Select the types of activity you want to include in your feed.

Add scripts for retrieving follows

+200 -2
+1
data/follows/.gitignore
··· 1 + ./*
+198
scripts/get_follows.py
··· 1 + import argparse 2 + import asyncio 3 + from datetime import datetime 4 + import gzip 5 + import logging 6 + import os 7 + import sys 8 + from typing import Tuple, List 9 + 10 + from atproto import AsyncClient 11 + from atproto import exceptions as at_exceptions 12 + from atproto_client.models.app.bsky.graph.follow import Record as FollowRecord 13 + import pandas as pd 14 + 15 + from crawl_follows import RateLimit 16 + 17 + logger = logging.getLogger(__name__) 18 + logger.setLevel(logging.INFO) 19 + 20 + # Create formatter 21 + formatter = logging.Formatter( 22 + "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 23 + ) 24 + 25 + # Console handler 26 + console_handler = logging.StreamHandler(sys.stdout) 27 + console_handler.setFormatter(formatter) 28 + logger.addHandler(console_handler) 29 + 30 + 31 + BATCH_SIZE = 10 32 + FOLLOWER_THRESHOLD = 150 33 + 34 + 35 + async def get_all_follows( 36 + client: AsyncClient, 37 + rate_limit: RateLimit, 38 + account_did: str, 39 + ) -> Tuple[List[FollowRecord], str]: 40 + follows: List[FollowRecord] = [] 41 + await rate_limit.acquire() 42 + try: 43 + data = await client.com.atproto.repo.list_records( 44 + { 45 + "collection": "app.bsky.graph.follow", 46 + "repo": account_did, 47 + "limit": 100, 48 + } 49 + ) 50 + # If user can't be accessed just return an empty list to skip next time 51 + except at_exceptions.BadRequestError as e: 52 + if e.response.status_code == 400: 53 + return [], account_did 54 + else: 55 + logger.info(f"Error status code: {e.response.status_code}") 56 + raise e 57 + 58 + for follow in data.records: 59 + follows.append(follow) 60 + 61 + # Limit to 1000 follows per account 62 + while data.cursor and len(follows) < 1000: 63 + await rate_limit.acquire() 64 + data = await client.com.atproto.repo.list_records( 65 + { 66 + "collection": "app.bsky.graph.follow", 67 + "repo": account_did, 68 + "cursor": data.cursor, 69 + "limit": 100, 70 + } 71 + ) 72 + 73 + for follow in data.records: 74 + follows.append(follow) 75 + 76 + return follows, account_did 77 + 78 + 79 + async def retrieve_follows( 80 + graph_file: str, 81 + checkpoint_dir: str, 82 + ): 83 + 84 + # If checkpoint dir doesn't exist, try to create it 85 + if not os.path.isdir(checkpoint_dir): 86 + logger.info("Checkpoint dir doesn't exist, creating...") 87 + try: 88 + os.mkdir(checkpoint_dir) 89 + except Exception as e: 90 + logger.error(f"Failed to created checkpoint dir, {checkpoint_dir}\n{e}") 91 + sys.exit(1) 92 + 93 + # Checkpoint folders contain one file per user 94 + completed_accounts = set() 95 + try: 96 + files = os.listdir(checkpoint_dir) 97 + for file in files: 98 + # Grab entire file name except for .gz extension 99 + completed_accounts.add(file[:-3]) 100 + except Exception as e: 101 + logger.error( 102 + f"Failed to recover from checkpoint dir, {checkpoint_dir}\n{e}", 103 + exc_info=1, 104 + ) 105 + sys.exit(1) 106 + 107 + # Load follow graph parquet file 108 + to_explore = dict() 109 + try: 110 + logger.info("Parsing follower graph file...") 111 + follow_df = pd.read_parquet(graph_file) 112 + # Limit to only accounts following between 100 and 1000 followers 113 + follow_df = follow_df.loc[follow_df["follows"].str.len().between(100, 1000)] 114 + except Exception as e: 115 + logger.error(f"Failed to open follow graph file, {graph_file}\n{e}") 116 + sys.exit(1) 117 + 118 + for _, row in follow_df.iterrows(): 119 + for acct in row["follows"]: 120 + if acct not in completed_accounts: 121 + if acct not in to_explore: 122 + to_explore[acct] = 0 123 + to_explore[acct] += 1 124 + 125 + accts = [(acct, follows) for acct, follows in to_explore.items()] 126 + accts.sort(key=lambda x: -1 * x[1]) 127 + 128 + logger.info(f"Num of accounts to retrieve follows from: {len(accts)}") 129 + 130 + client = AsyncClient() 131 + 132 + # Get all follows for accounts 133 + batch_count = 0 134 + fail_count = 0 135 + rate_limiter = RateLimit(BATCH_SIZE) 136 + for i in range(0, len(accts), BATCH_SIZE): 137 + batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]] 138 + for result in asyncio.as_completed( 139 + [get_all_follows(client, rate_limiter, did) for did in batch] 140 + ): 141 + try: 142 + follows, did = await result 143 + # Save follows 144 + with gzip.open( 145 + os.path.join(checkpoint_dir, did + ".gz"), "wt" 146 + ) as out_file: 147 + for follow in follows: 148 + out_file.write(follow.model_dump_json() + "\n") 149 + except at_exceptions.BadRequestError as e: 150 + # Bad request is probably a profile that's private or deleted 151 + logger.info(f"Bad Request: {e.response.content.error}") 152 + continue 153 + except Exception as e: 154 + logger.error(f"Failed to get follows: {e}", exc_info=1) 155 + fail_count += 1 156 + if fail_count >= 100: 157 + logger.error("Hitting error threshold, exiting...") 158 + sys.exit(1) 159 + continue 160 + 161 + batch_count += 1 162 + if batch_count % 10 == 0: 163 + logger.info(f"Completed batch: {batch_count}") 164 + 165 + 166 + def main(): 167 + parser = argparse.ArgumentParser( 168 + prog="GetFollows", 169 + description="Get all follows for accounts in provided follow graph", 170 + ) 171 + parser.add_argument( 172 + "--graph-file", 173 + dest="graph_file", 174 + required=True, 175 + help="File with follow graph", 176 + ) 177 + parser.add_argument( 178 + "--save-dir", 179 + dest="save_dir", 180 + required=True, 181 + help="Where to store crawl data", 182 + ) 183 + args = parser.parse_args() 184 + 185 + if args.save_dir is None and args.ckpt is None: 186 + logger.error("Must provide save dir or checkpoint dir") 187 + sys.exit(1) 188 + 189 + asyncio.run( 190 + retrieve_follows( 191 + graph_file=args.graph_file, 192 + checkpoint_dir=args.save_dir, 193 + ) 194 + ) 195 + 196 + 197 + if __name__ == "__main__": 198 + main()
+1 -2
scripts/get_likes.py
··· 2 2 import asyncio 3 3 from datetime import datetime 4 4 import gzip 5 - import json 6 5 import logging 7 6 import os 8 7 import sys 9 - from typing import Tuple, List, Dict 8 + from typing import Tuple, List 10 9 11 10 from atproto import AsyncClient 12 11 from atproto import exceptions as at_exceptions