this repo has no description
4
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 5cc2d9e30a1eccab5c679c534f957af233b173da 313 lines 9.6 kB view raw
import argparse
import asyncio
from datetime import datetime
import decimal
import gzip
import json
import logging
import os
import sys
import time
from typing import Tuple, List, Dict

from atproto import AsyncClient
from atproto import exceptions as at_exceptions
from atproto_client.models.app.bsky.embed.record import ViewBlocked
from atproto_client.models.app.bsky.feed.defs import FeedViewPost
import pandas as pd
from rich import print

from crawl_follows import RateLimit

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


# Number of accounts fetched concurrently per batch (also the rate-limit width).
BATCH_SIZE = 10
# Only crawl accounts followed by at least this many accounts in the graph.
FOLLOWER_THRESHOLD = 150
REQUIRED_ENV = ("BSKY_USER", "BSKY_APP_PW")


def process_post(top: FeedViewPost):
    """Flatten one feed item into a JSON-serializable dict.

    Captures author/text/cid/timestamps, whether the item is a repost, any
    embed (external link, quote post, images, or video), and the parent/root
    of a reply when they are plain post views.
    """
    post = top.post
    data = {
        "author": post.author.did,
        "text": post.record.text,
        "cid": post.cid,
        "created_at": post.record.created_at,
        "repost": False,
    }
    if (
        top.reason is not None
        and top.reason.py_type == "app.bsky.feed.defs#reasonRepost"
    ):
        data["repost"] = True

    if post.embed:
        data["embed"] = {}
        if post.embed.py_type == "app.bsky.embed.external#view":
            data["embed"] = {
                "title": post.embed.external.title,
                "description": post.embed.external.description,
                "uri": post.embed.external.uri,
                "thumb": post.embed.external.thumb,
            }
        elif post.embed.py_type == "app.bsky.embed.record#view":
            # Ignore everything that's not a quote-tweet
            if post.embed.record.py_type == "app.bsky.embed.record#viewRecord":
                data["embed"] = {
                    "author": post.embed.record.author.did,
                    "text": post.embed.record.value.text,
                    "cid": post.embed.record.cid,
                    "created_at": post.embed.record.value.created_at,
                }
        elif post.embed.py_type == "app.bsky.embed.images#view":
            data["embed"]["images"] = []
            for image in post.embed.images:
                data["embed"]["images"].append(
                    {
                        "alt_text": image.alt,
                        "full_url": image.fullsize,
                        "thumb_url": image.thumb,
                    }
                )
        elif post.embed.py_type == "app.bsky.embed.video#view":
            data["embed"]["video"] = {
                "alt": post.embed.alt,
                "full_url": post.embed.playlist,
                "thumb_url": post.embed.thumbnail,
            }

    if top.reply:
        if top.reply.parent.py_type == "app.bsky.feed.defs#postView":
            data["reply_parent"] = {}
            data["reply_parent"]["author"] = top.reply.parent.author.did
            data["reply_parent"]["text"] = top.reply.parent.record.text
            data["reply_parent"]["cid"] = top.reply.parent.cid
            data["reply_parent"]["created_at"] = top.reply.parent.record.created_at
            if top.reply.root.py_type == "app.bsky.feed.defs#postView":
                data["reply_parent"]["root_cid"] = top.reply.root.cid

    return data


async def get_all_posts(
    client: AsyncClient,
    rate_limit: RateLimit,
    account_did: str,
    start_dt: datetime,
    end_dt: datetime,
) -> Tuple[List[Dict], str]:
    """Collect all posts by `account_did` indexed in [start_dt, end_dt).

    Pages through the author feed (newest first) and stops once a page
    contains a post older than `start_dt`, since everything after it is
    outside the window. Returns (posts, account_did) so callers using
    `asyncio.as_completed` can tell which account a result belongs to.
    """
    posts: List[Dict] = []
    cursor = None
    while True:
        await rate_limit.acquire()
        data = await client.get_author_feed(
            actor=account_did,
            filter="posts_and_author_threads",
            cursor=cursor,
        )

        hit_start_window = False
        for top in data.feed:
            dt = datetime.strptime(top.post.indexed_at, "%Y-%m-%dT%H:%M:%S.%fZ")
            if start_dt <= dt < end_dt:
                parsed = process_post(top)
                if parsed is not None:
                    posts.append(parsed)
            if dt < start_dt:
                # Feed is reverse-chronological; no in-window posts remain.
                hit_start_window = True

        cursor = data.cursor
        if not cursor or hit_start_window:
            break

    return posts, account_did


async def retrieve_posts(
    user: str,
    app_pw: str,
    graph_file: str,
    checkpoint_dir: str,
    start_dt: datetime,
    end_dt: datetime,
):
    """Crawl posts for every sufficiently-followed account in the graph file.

    Writes one gzipped JSON-lines file per account DID into `checkpoint_dir`;
    accounts that already have a checkpoint file are skipped, so the crawl is
    resumable. Exits the process on unrecoverable setup errors or after three
    unexpected per-account failures.
    """
    # If checkpoint dir doesn't exist, try to create it
    if not os.path.isdir(checkpoint_dir):
        logger.info("Checkpoint dir doesn't exist, creating...")
        try:
            os.makedirs(checkpoint_dir, exist_ok=True)
        except Exception as e:
            logger.error(f"Failed to create checkpoint dir, {checkpoint_dir}\n{e}")
            sys.exit(1)

    # Checkpoint folders contain one file per user
    completed_accounts = set()
    try:
        for file in os.listdir(checkpoint_dir):
            # Grab entire file name except for the .gz extension; ignore
            # anything else that may have ended up in the directory.
            if file.endswith(".gz"):
                completed_accounts.add(file[:-3])
    except Exception as e:
        logger.error(
            f"Failed to recover from checkpoint dir, {checkpoint_dir}\n{e}",
            exc_info=True,
        )
        sys.exit(1)

    # Load follow graph parquet file
    to_explore = dict()
    try:
        logger.info("Parsing follower graph file...")
        follow_df = pd.read_parquet(graph_file)
    except Exception as e:
        logger.error(f"Failed to open follow graph file, {graph_file}\n{e}")
        sys.exit(1)

    # Count in-graph followers per account, skipping already-crawled ones.
    for _, row in follow_df.iterrows():
        for acct in row["follows"]:
            if acct not in completed_accounts:
                if acct not in to_explore:
                    to_explore[acct] = 0
                to_explore[acct] += 1

    # Keep only popular accounts, most-followed first.
    accts = [
        (acct, follows)
        for acct, follows in to_explore.items()
        if follows >= FOLLOWER_THRESHOLD
    ]
    accts.sort(key=lambda x: -1 * x[1])

    logger.info(f"Num of accounts to retrieve posts from: {len(accts)}")

    client = AsyncClient()
    await client.login(user, app_pw)

    # Get all posts for accounts
    batch_count = 0
    fail_count = 0
    rate_limiter = RateLimit(BATCH_SIZE)
    for i in range(0, len(accts), BATCH_SIZE):
        batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]]
        for result in asyncio.as_completed(
            [
                get_all_posts(client, rate_limiter, did, start_dt, end_dt)
                for did in batch
            ]
        ):
            try:
                posts, did = await result
                # Save posts as one gzipped JSON-lines checkpoint per DID
                with gzip.open(
                    os.path.join(checkpoint_dir, did + ".gz"), "wt"
                ) as out_file:
                    for post in posts:
                        out_file.write(json.dumps(post) + "\n")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted
                logger.info(f"Bad Request: {e.response.content.error}")
                continue
            except Exception as e:
                logger.error(f"Failed to get posts: {e}", exc_info=True)
                fail_count += 1
                # Bail out entirely after repeated unexpected failures.
                if fail_count >= 3:
                    sys.exit(1)
                continue

        batch_count += 1
        if batch_count % 10 == 0:
            logger.info(f"Completed batch: {batch_count}")


def main():
    """Parse CLI args and env credentials, then run the crawl."""
    for key in REQUIRED_ENV:
        if key not in os.environ:
            raise ValueError(f"Must set '{key}' env var")

    user_name = os.environ["BSKY_USER"]
    app_pw = os.environ["BSKY_APP_PW"]

    parser = argparse.ArgumentParser(
        prog="GetPosts",
        description="Get all posts for accounts in provided follow graph",
    )
    parser.add_argument(
        "--graph-file",
        dest="graph_file",
        required=True,
        help="File with follow graph",
    )
    parser.add_argument(
        "--save-dir",
        dest="save_dir",
        required=True,
        help="Where to store crawl data",
    )
    parser.add_argument(
        "--start",
        dest="start",
        required=True,
        help="Date to start saving posts from (YYYY-MM-DD)",
    )
    parser.add_argument(
        "--end",
        dest="end",
        required=True,
        help="Date to stop (exclusive) saving posts from (YYYY-MM-DD)",
    )
    args = parser.parse_args()

    try:
        start = datetime.strptime(args.start, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid start date")
        sys.exit(1)

    try:
        end = datetime.strptime(args.end, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid end date")
        sys.exit(1)

    if end <= start:
        logger.error(
            "Start date has to be before end date, what're you trying to do man..."
        )
        sys.exit(1)

    asyncio.run(
        retrieve_posts(
            user_name,
            app_pw,
            graph_file=args.graph_file,
            checkpoint_dir=args.save_dir,
            start_dt=start,
            end_dt=end,
        )
    )


if __name__ == "__main__":
    main()