This repository has no description.
4
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 272 lines 8.6 kB view raw
import argparse
import asyncio
import gzip
import json
import logging
import os
import sys
from datetime import datetime
from typing import Dict, List, Tuple

from atproto import AsyncClient
from atproto import exceptions as at_exceptions
from atproto_client.models.app.bsky.feed.defs import FeedViewPost

from utils import BSKY_API_LIMIT, RateLimit, get_accounts, load_checkpoint

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Console-only logging: one line per event with a wall-clock timestamp.
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


# Number of accounts whose feeds are fetched concurrently per batch.
BATCH_SIZE = 10
# Environment variables that must be set before the crawl can start.
REQUIRED_ENV = ("BSKY_USER", "BSKY_APP_PW")

# Timestamp format of the `indexed_at` field returned by the Bluesky API
# (e.g. "2024-01-31T12:34:56.789Z").
_INDEXED_AT_FMT = "%Y-%m-%dT%H:%M:%S.%fZ"


def process_post(top: FeedViewPost) -> Dict:
    """Flatten one feed item into a JSON-serializable dict.

    Captures author DID, text, CID, creation time, a repost flag, embed
    details (external link / quote post / images / video), and — when the
    post is a reply to a visible post — the reply parent (and root CID).

    Args:
        top: One entry of an author feed as returned by ``get_author_feed``.

    Returns:
        A plain dict describing the post.
    """
    post = top.post
    data = {
        "author": post.author.did,
        "text": post.record.text,
        "cid": post.cid,
        "created_at": post.record.created_at,
        "repost": False,
    }
    # A non-None reason of type reasonRepost marks this feed entry as a repost.
    if (
        top.reason is not None
        and top.reason.py_type == "app.bsky.feed.defs#reasonRepost"
    ):
        data["repost"] = True

    if post.embed:
        data["embed"] = {}
        if post.embed.py_type == "app.bsky.embed.external#view":
            # External link card: title/description/uri/thumbnail.
            data["embed"] = {
                "title": post.embed.external.title,
                "description": post.embed.external.description,
                "uri": post.embed.external.uri,
                "thumb": post.embed.external.thumb,
            }
        elif post.embed.py_type == "app.bsky.embed.record#view":
            # Ignore everything that's not a quote-post (e.g. blocked or
            # not-found records have different record py_types).
            if post.embed.record.py_type == "app.bsky.embed.record#viewRecord":
                data["embed"] = {
                    "author": post.embed.record.author.did,
                    "text": post.embed.record.value.text,
                    "cid": post.embed.record.cid,
                    "created_at": post.embed.record.value.created_at,
                }
        elif post.embed.py_type == "app.bsky.embed.images#view":
            data["embed"]["images"] = [
                {
                    "alt_text": image.alt,
                    "full_url": image.fullsize,
                    "thumb_url": image.thumb,
                }
                for image in post.embed.images
            ]
        elif post.embed.py_type == "app.bsky.embed.video#view":
            data["embed"]["video"] = {
                "alt": post.embed.alt,
                "full_url": post.embed.playlist,
                "thumb_url": post.embed.thumbnail,
            }

    if top.reply:
        # Only record the parent when it is an actual visible post
        # (not blocked/not-found).
        if top.reply.parent.py_type == "app.bsky.feed.defs#postView":
            data["reply_parent"] = {
                "author": top.reply.parent.author.did,
                "text": top.reply.parent.record.text,
                "cid": top.reply.parent.cid,
                "created_at": top.reply.parent.record.created_at,
            }
            if top.reply.root.py_type == "app.bsky.feed.defs#postView":
                data["reply_parent"]["root_cid"] = top.reply.root.cid

    return data


async def get_all_posts(
    client: AsyncClient,
    rate_limit: RateLimit,
    account_did: str,
    start_dt: datetime,
    end_dt: datetime,
) -> Tuple[List[Dict], str]:
    """Page through one account's feed and collect posts in [start_dt, end_dt).

    Pagination stops once a page contains a post older than ``start_dt``
    (feeds are returned newest-first) or the cursor is exhausted.

    Args:
        client: Authenticated Bluesky client.
        rate_limit: Shared limiter; acquired before every API request.
        account_did: DID of the account to crawl.
        start_dt: Inclusive lower bound on ``indexed_at``.
        end_dt: Exclusive upper bound on ``indexed_at``.

    Returns:
        Tuple of (parsed posts, account_did). The DID is echoed back so the
        caller can checkpoint completed accounts from ``as_completed`` results.

    Raises:
        at_exceptions.BadRequestError: For non-400 errors, or 400s after the
            first page. A 400 on the first request (private/deleted profile)
            yields an empty post list instead so the account is skipped.
    """
    posts: List[Dict] = []
    cursor = None
    first_page = True
    while True:
        await rate_limit.acquire()
        try:
            data = await client.get_author_feed(
                actor=account_did,
                filter="posts_and_author_threads",
                cursor=cursor,
            )
        except at_exceptions.BadRequestError as e:
            # A 400 on the very first page usually means the profile is
            # private or deleted; treat the account as done with no posts.
            if first_page and e.response.status_code == 400:
                return [], account_did
            logger.error(f"Error status code: {e.response.status_code}")
            raise
        first_page = False

        hit_start_window = False
        for top in data.feed:
            dt = datetime.strptime(top.post.indexed_at, _INDEXED_AT_FMT)
            if start_dt <= dt < end_dt:
                parsed = process_post(top)
                if parsed is not None:
                    posts.append(parsed)
            if dt < start_dt:
                # Feed is newest-first: anything older than the window means
                # no later page can contain in-window posts.
                hit_start_window = True

        cursor = data.cursor
        if not cursor or hit_start_window:
            break

    return posts, account_did


async def retrieve_posts(
    user: str,
    app_pw: str,
    graph_file: str,
    checkpoint_dir: str,
    start_dt: datetime,
    end_dt: datetime,
):
    """Crawl posts for every account in the follow graph, checkpointing per user.

    Accounts already present in ``checkpoint_dir`` (one ``<did>.gz`` file per
    completed user) are skipped, so the crawl is resumable.

    Args:
        user: Bluesky handle used to authenticate.
        app_pw: App password for ``user``.
        graph_file: File with the follow graph (consumed by ``get_accounts``).
        checkpoint_dir: Directory for per-account gzipped JSONL output.
        start_dt: Inclusive start of the collection window.
        end_dt: Exclusive end of the collection window.
    """
    # Checkpoint folders contain one file per user.
    completed_accounts = load_checkpoint(checkpoint_dir)
    accts = get_accounts(graph_file, completed_accounts)
    logger.info(f"Num of accounts to retrieve posts from: {len(accts)}")

    client = AsyncClient()
    await client.login(user, app_pw)

    batch_count = 0
    fail_count = 0
    rate_limiter = RateLimit(BSKY_API_LIMIT)
    for i in range(0, len(accts), BATCH_SIZE):
        batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]]
        for result in asyncio.as_completed(
            [
                get_all_posts(client, rate_limiter, did, start_dt, end_dt)
                for did in batch
            ]
        ):
            try:
                posts, did = await result
                # One gzipped JSON-lines file per account doubles as the
                # checkpoint marker for that account.
                with gzip.open(
                    os.path.join(checkpoint_dir, did + ".gz"), "wt"
                ) as out_file:
                    for post in posts:
                        out_file.write(json.dumps(post) + "\n")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted.
                logger.info(f"Bad Request: {e.response.content.error}")
                continue
            except Exception as e:
                # Best-effort crawl: tolerate sporadic failures but bail out
                # once they indicate something systemic.
                logger.error(f"Failed to get posts: {e}", exc_info=True)
                fail_count += 1
                if fail_count >= 100:
                    logger.error("Hitting error threshold, exiting...")
                    sys.exit(1)
                continue

        batch_count += 1
        if batch_count % 10 == 0:
            logger.info(f"Completed batch: {batch_count}")


def main():
    """CLI entry point: validate env/args and run the crawl."""
    for key in REQUIRED_ENV:
        if key not in os.environ:
            raise ValueError(f"Must set '{key}' env var")

    user_name = os.environ["BSKY_USER"]
    app_pw = os.environ["BSKY_APP_PW"]

    parser = argparse.ArgumentParser(
        prog="GetPosts",
        description="Get all posts for accounts in provided follow graph",
    )
    parser.add_argument(
        "--graph-file",
        dest="graph_file",
        required=True,
        help="File with follow graph",
    )
    parser.add_argument(
        "--save-dir",
        dest="save_dir",
        required=True,
        help="Where to store crawl data",
    )
    parser.add_argument(
        "--start",
        dest="start",
        required=True,
        help="Date to start saving posts from (YYYY-MM-DD)",
    )
    parser.add_argument(
        "--end",
        dest="end",
        required=True,
        help="Date to stop (exclusive) saving posts from (YYYY-MM-DD)",
    )
    args = parser.parse_args()
    # NOTE: both --graph-file and --save-dir are required=True, so argparse
    # already rejects missing values; no extra presence check is needed.

    try:
        start = datetime.strptime(args.start, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid start date")
        sys.exit(1)

    try:
        end = datetime.strptime(args.end, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid end date")
        sys.exit(1)

    if end <= start:
        logger.error(
            "Start date has to be before end date, what're you trying to do man..."
        )
        sys.exit(1)

    asyncio.run(
        retrieve_posts(
            user_name,
            app_pw,
            graph_file=args.graph_file,
            checkpoint_dir=args.save_dir,
            start_dt=start,
            end_dt=end,
        )
    )


if __name__ == "__main__":
    main()