this repo has no description
4
fork

Configure Feed

Select the types of activity you want to include in your feed.

Finish getting posts logic

+143 -43
+143 -43
get_posts.py
··· 1 1 import argparse 2 2 import asyncio 3 + from datetime import datetime 3 4 import decimal 5 + import gzip 4 6 import json 5 7 import logging 6 8 import os ··· 10 12 11 13 from atproto import AsyncClient 12 14 from atproto import exceptions as at_exceptions 15 + from atproto_client.models.app.bsky.embed.record import ViewBlocked 13 16 from atproto_client.models.app.bsky.feed.defs import FeedViewPost 14 17 import pandas as pd 15 18 from rich import print ··· 31 34 32 35 33 36 BATCH_SIZE = 10 34 - CHECKPOINT_THRESHOLD = 1_000 35 37 FOLLOWER_THRESHOLD = 150 36 38 REQUIRED_ENV = ("BSKY_USER", "BSKY_APP_PW") 37 39 ··· 61 63 "thumb": post.embed.external.thumb, 62 64 } 63 65 elif post.embed.py_type == "app.bsky.embed.record#view": 64 - data["embed"] = { 65 - "author": post.embed.record.author.did, 66 - "text": post.embed.record.value.text, 67 - "cid": post.embed.record.cid, 68 - "created_at": post.embed.record.value.created_at, 69 - } 66 + # Ignore everything thats not a quote-tweet 67 + if post.embed.record.py_type == "app.bsky.embed.record#viewRecord": 68 + data["embed"] = { 69 + "author": post.embed.record.author.did, 70 + "text": post.embed.record.value.text, 71 + "cid": post.embed.record.cid, 72 + "created_at": post.embed.record.value.created_at, 73 + } 70 74 elif post.embed.py_type == "app.bsky.embed.images#view": 71 75 data["embed"]["images"] = [] 72 76 for image in post.embed.images: ··· 85 89 } 86 90 87 91 if top.reply: 88 - data["reply_parent"] = {} 89 - data["reply_parent"]["author"] = top.reply.parent.author.did 90 - data["reply_parent"]["text"] = top.reply.parent.record.text 91 - data["reply_parent"]["cid"] = top.reply.parent.cid 92 - data["reply_parent"]["created_at"] = top.reply.parent.record.created_at 93 - data["reply_parent"]["root_cid"] = top.reply.root.cid 92 + if top.reply.parent.py_type == "app.bsky.feed.defs#postView": 93 + data["reply_parent"] = {} 94 + data["reply_parent"]["author"] = top.reply.parent.author.did 95 + data["reply_parent"]["text"] = top.reply.parent.record.text 96 + data["reply_parent"]["cid"] = top.reply.parent.cid 97 + data["reply_parent"]["created_at"] = top.reply.parent.record.created_at 98 + if top.reply.root.py_type == "app.bsky.feed.defs#postView": 99 + data["reply_parent"]["root_cid"] = top.reply.root.cid 94 100 95 101 return data 96 102 97 103 98 - async def get_all_posts(client: AsyncClient, rate_limit: RateLimit, account_did: str): 104 + async def get_all_posts( 105 + client: AsyncClient, 106 + rate_limit: RateLimit, 107 + account_did: str, 108 + start_dt: datetime, 109 + end_dt: datetime, 110 + ) -> Tuple[List[Dict], str]: 111 + posts: List[Dict] = [] 99 112 await rate_limit.acquire() 100 113 data = await client.get_author_feed( 101 114 actor=account_did, 102 115 filter="posts_and_author_threads", 103 116 ) 104 - for i in range(4): 105 - print(data.feed[i]) 106 - process_post(data.feed[i]) 117 + 118 + for top in data.feed: 119 + dt = datetime.strptime(top.post.indexed_at, "%Y-%m-%dT%H:%M:%S.%fZ") 120 + if start_dt <= dt and dt < end_dt: 121 + parsed = process_post(top) 122 + if parsed is not None: 123 + posts.append(parsed) 124 + 125 + hit_start_window = False 126 + while data.cursor and not hit_start_window: 127 + await rate_limit.acquire() 128 + data = await client.get_author_feed( 129 + actor=account_did, filter="posts_and_author_threads", cursor=data.cursor 130 + ) 131 + 132 + for top in data.feed: 133 + dt = datetime.strptime(top.post.indexed_at, "%Y-%m-%dT%H:%M:%S.%fZ") 134 + if start_dt <= dt and dt < end_dt: 135 + parsed = process_post(top) 136 + if parsed is not None: 137 + posts.append(parsed) 138 + if dt < start_dt: 139 + hit_start_window = True 107 140 141 + return posts, account_did 108 142 109 - async def retrieve_posts(user: str, app_pw: str, graph_file: str, checkpoint_dir: str): 143 + 144 + async def retrieve_posts( 145 + user: str, 146 + app_pw: str, 147 + graph_file: str, 148 + checkpoint_dir: str, 149 + start_dt: datetime, 150 + end_dt: datetime, 151 + ): 110 152 111 153 # If checkpoint dir doesn't exist, try to create it 112 154 if not os.path.isdir(checkpoint_dir): ··· 122 164 try: 123 165 files = os.listdir(checkpoint_dir) 124 166 for file in files: 125 - completed_accounts.add(file) 167 + # Grab entire file name except for .gz extension 168 + completed_accounts.add(file[:-3]) 126 169 except Exception as e: 127 170 logger.error( 128 171 f"Failed to recover from checkpoint dir, {checkpoint_dir}\n{e}", ··· 131 174 sys.exit(1) 132 175 133 176 # Load follow graph parquet file 134 - # to_explore = dict() 135 - # try: 136 - # logger.info("Parsing follower graph file...") 137 - # follow_df = pd.read_parquet(graph_file) 138 - # except Exception as e: 139 - # logger.error(f"Failed to open follow graph file, {graph_file}\n{e}") 140 - # sys.exit(1) 177 + to_explore = dict() 178 + try: 179 + logger.info("Parsing follower graph file...") 180 + follow_df = pd.read_parquet(graph_file) 181 + except Exception as e: 182 + logger.error(f"Failed to open follow graph file, {graph_file}\n{e}") 183 + sys.exit(1) 141 184 142 - # for _, row in follow_df.iterrows(): 143 - # for acct in row["follows"]: 144 - # if acct not in completed_accounts: 145 - # if acct not in to_explore: 146 - # to_explore[acct] = 0 147 - # to_explore[acct] += 1 185 + for _, row in follow_df.iterrows(): 186 + for acct in row["follows"]: 187 + if acct not in completed_accounts: 188 + if acct not in to_explore: 189 + to_explore[acct] = 0 190 + to_explore[acct] += 1 148 191 149 - # accts = [ 150 - # (acct, follows) 151 - # for acct, follows in to_explore.items() 152 - # if follows >= FOLLOWER_THRESHOLD 153 - # ] 154 - # accts.sort(key=lambda x: -1 * x[1]) 155 - accts = [("did:plc:5o6k7jvowuyaquloafzn3cfw", 8604)] 192 + accts = [ 193 + (acct, follows) 194 + for acct, follows in to_explore.items() 195 + if follows >= FOLLOWER_THRESHOLD 196 + ] 197 + accts.sort(key=lambda x: -1 * x[1]) 156 198 157 199 logger.info(f"Num of accounts to retrieve posts from: {len(accts)}") 158 200 ··· 160 202 await client.login(user, app_pw) 161 203 162 204 # Get all posts for accounts 163 - batch_count = 1 205 + batch_count = 0 164 206 fail_count = 0 165 207 rate_limiter = RateLimit(BATCH_SIZE) 166 208 for i in range(0, len(accts), BATCH_SIZE): 167 - batch = [acct for acct, follow_count in accts[i : i + BATCH_SIZE]] 168 - await get_all_posts(client, rate_limiter, batch[0]) 169 - sys.exit(1) 209 + batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]] 210 + for result in asyncio.as_completed( 211 + [ 212 + get_all_posts(client, rate_limiter, did, start_dt, end_dt) 213 + for did in batch 214 + ] 215 + ): 216 + try: 217 + posts, did = await result 218 + # Save posts 219 + with gzip.open( 220 + os.path.join(checkpoint_dir, did + ".gz"), "wt" 221 + ) as out_file: 222 + for post in posts: 223 + out_file.write(json.dumps(post) + "\n") 224 + except at_exceptions.BadRequestError as e: 225 + # Bad request is probably a profile that's private or deleted 226 + logger.info(f"Bad Request: {e.response.content.error}") 227 + continue 228 + except Exception as e: 229 + logger.error(f"Failed to get posts: {e}", exc_info=1) 230 + fail_count += 1 231 + if fail_count >= 3: 232 + sys.exit(1) 233 + continue 234 + 235 + batch_count += 1 236 + if batch_count % 10 == 0: 237 + logger.info(f"Completed batch: {batch_count}") 170 238 171 239 172 240 def main(): ··· 193 261 required=True, 194 262 help="Where to store crawl data", 195 263 ) 264 + parser.add_argument( 265 + "--start", 266 + dest="start", 267 + required=True, 268 + help="Date to start saving posts from (YYYY-MM-DD)", 269 + ) 270 + parser.add_argument( 271 + "--end", 272 + dest="end", 273 + required=True, 274 + help="Date to stop (exclusive) saving posts from (YYYY-MM-DD)", 275 + ) 196 276 args = parser.parse_args() 197 277 198 278 if args.save_dir is None and args.ckpt is None: 199 279 logger.error("Must provide save dir or checkpoint dir") 200 280 sys.exit(1) 201 281 282 + try: 283 + start = datetime.strptime(args.start, "%Y-%m-%d") 284 + except: 285 + logger.error("Invalid start date") 286 + sys.exit(1) 287 + 288 + try: 289 + end = datetime.strptime(args.end, "%Y-%m-%d") 290 + except: 291 + logger.error("Invalid end date") 292 + sys.exit(1) 293 + 294 + if end <= start: 295 + logger.error( 296 + "Start date has to be before date, what're you trying to do man..." 297 + ) 298 + sys.exit(1) 299 + 202 300 asyncio.run( 203 301 retrieve_posts( 204 302 user_name, 205 303 app_pw, 206 304 graph_file=args.graph_file, 207 305 checkpoint_dir=args.save_dir, 306 + start_dt=start, 307 + end_dt=end, 208 308 ) 209 309 ) 210 310