# This repo has no description.
1import argparse
2import asyncio
3from datetime import datetime
4import decimal
5import gzip
6import json
7import logging
8import os
9import sys
10import time
11from typing import Tuple, List, Dict
12
13from atproto import AsyncClient
14from atproto import exceptions as at_exceptions
15from atproto_client.models.app.bsky.embed.record import ViewBlocked
16from atproto_client.models.app.bsky.feed.defs import FeedViewPost
17import pandas as pd
18from rich import print
19
20from crawl_follows import RateLimit
21
# Module-level logger: INFO and above, timestamped, written to stdout.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


# Number of accounts fetched concurrently per batch; also seeds RateLimit.
BATCH_SIZE = 10
# Minimum number of follow edges (within the graph file) pointing at an
# account before its posts are crawled.
FOLLOWER_THRESHOLD = 150
# Environment variables that must be set before main() will run.
REQUIRED_ENV = ("BSKY_USER", "BSKY_APP_PW")
39
40
def process_post(top: FeedViewPost):
    """Flatten a feed entry into a plain, JSON-serializable dict.

    Captures the author DID, text, CID and creation time of the post, plus
    (when present) repost status, one embed (external link, quote-post,
    images, or video), and the reply parent/root.

    Args:
        top: One entry from an author feed.

    Returns:
        Dict describing the post.
    """
    post = top.post
    parsed = {
        "author": post.author.did,
        "text": post.record.text,
        "cid": post.cid,
        "created_at": post.record.created_at,
        "repost": False,
    }

    reason = top.reason
    if reason is not None and reason.py_type == "app.bsky.feed.defs#reasonRepost":
        parsed["repost"] = True

    embed = post.embed
    if embed:
        parsed["embed"] = {}
        embed_type = embed.py_type
        if embed_type == "app.bsky.embed.external#view":
            external = embed.external
            parsed["embed"] = {
                "title": external.title,
                "description": external.description,
                "uri": external.uri,
                "thumb": external.thumb,
            }
        elif embed_type == "app.bsky.embed.record#view":
            # Ignore everything thats not a quote-tweet
            quoted = embed.record
            if quoted.py_type == "app.bsky.embed.record#viewRecord":
                parsed["embed"] = {
                    "author": quoted.author.did,
                    "text": quoted.value.text,
                    "cid": quoted.cid,
                    "created_at": quoted.value.created_at,
                }
        elif embed_type == "app.bsky.embed.images#view":
            parsed["embed"]["images"] = [
                {
                    "alt_text": image.alt,
                    "full_url": image.fullsize,
                    "thumb_url": image.thumb,
                }
                for image in embed.images
            ]
        elif embed_type == "app.bsky.embed.video#view":
            parsed["embed"]["video"] = {
                "alt": embed.alt,
                "full_url": embed.playlist,
                "thumb_url": embed.thumbnail,
            }

    reply = top.reply
    if reply:
        parent = reply.parent
        # Only record the parent when it is a visible post (not blocked/deleted).
        if parent.py_type == "app.bsky.feed.defs#postView":
            parsed["reply_parent"] = {
                "author": parent.author.did,
                "text": parent.record.text,
                "cid": parent.cid,
                "created_at": parent.record.created_at,
            }
            if reply.root.py_type == "app.bsky.feed.defs#postView":
                parsed["reply_parent"]["root_cid"] = reply.root.cid

    return parsed
102
103
async def get_all_posts(
    client: AsyncClient,
    rate_limit: RateLimit,
    account_did: str,
    start_dt: datetime,
    end_dt: datetime,
) -> Tuple[List[Dict], str]:
    """Fetch an account's feed posts indexed within [start_dt, end_dt).

    Pages through the author feed, parsing each in-window entry with
    process_post(). Pagination stops when the cursor is exhausted or a
    post older than start_dt is seen (the code relies on the feed being
    returned newest-first, as the original pagination logic did).

    Args:
        client: Logged-in Bluesky client.
        rate_limit: Shared limiter; acquired once per API request.
        account_did: DID of the account whose feed is fetched.
        start_dt: Inclusive lower bound on the post's indexed_at time.
        end_dt: Exclusive upper bound on the post's indexed_at time.

    Returns:
        (parsed posts, account_did) — the DID is echoed back so callers
        awaiting results out of order can match them to accounts.
    """
    posts: List[Dict] = []
    cursor = None
    while True:
        await rate_limit.acquire()
        # cursor=None on the first request is equivalent to omitting it.
        data = await client.get_author_feed(
            actor=account_did,
            filter="posts_and_author_threads",
            cursor=cursor,
        )

        hit_start_window = False
        for top in data.feed:
            # NOTE(review): assumes indexed_at always carries fractional
            # seconds and a literal "Z" suffix — confirm against API output.
            dt = datetime.strptime(top.post.indexed_at, "%Y-%m-%dT%H:%M:%S.%fZ")
            if start_dt <= dt < end_dt:
                parsed = process_post(top)
                if parsed is not None:
                    posts.append(parsed)
            if dt < start_dt:
                # Everything on later pages is older still; stop after this page.
                hit_start_window = True

        cursor = data.cursor
        if not cursor or hit_start_window:
            break

    return posts, account_did
142
143
async def retrieve_posts(
    user: str,
    app_pw: str,
    graph_file: str,
    checkpoint_dir: str,
    start_dt: datetime,
    end_dt: datetime,
):
    """Crawl posts for every popular account in the follow graph.

    Accounts followed by at least FOLLOWER_THRESHOLD crawled users are
    fetched in batches of BATCH_SIZE; each account's posts are written to
    `<checkpoint_dir>/<did>.gz` (one JSON object per line), which doubles
    as the checkpoint record for resuming interrupted crawls.

    Args:
        user: Bluesky handle to log in with.
        app_pw: App password for `user`.
        graph_file: Parquet file whose rows each carry a "follows" list of DIDs.
        checkpoint_dir: Output/checkpoint directory (created if missing).
        start_dt: Inclusive lower bound on post indexed_at time.
        end_dt: Exclusive upper bound on post indexed_at time.

    Exits the process (status 1) on unrecoverable setup errors or after
    three fetch failures.
    """
    # If checkpoint dir doesn't exist, try to create it
    if not os.path.isdir(checkpoint_dir):
        logger.info("Checkpoint dir doesn't exist, creating...")
        try:
            # makedirs handles nested paths; exist_ok avoids a create race.
            os.makedirs(checkpoint_dir, exist_ok=True)
        except OSError as e:
            logger.error(f"Failed to create checkpoint dir, {checkpoint_dir}\n{e}")
            sys.exit(1)

    # Checkpoint folders contain one <did>.gz file per completed account.
    completed_accounts = set()
    try:
        for file in os.listdir(checkpoint_dir):
            # Only checkpoint files count; previously any stray file (e.g.
            # ".DS_Store") had its last 3 chars chopped and polluted the set.
            if file.endswith(".gz"):
                completed_accounts.add(file[:-3])
    except Exception as e:
        logger.error(
            f"Failed to recover from checkpoint dir, {checkpoint_dir}\n{e}",
            exc_info=True,
        )
        sys.exit(1)

    # Load follow graph parquet file
    try:
        logger.info("Parsing follower graph file...")
        follow_df = pd.read_parquet(graph_file)
    except Exception as e:
        logger.error(f"Failed to open follow graph file, {graph_file}\n{e}")
        sys.exit(1)

    # Count, per not-yet-crawled account, how many rows of the graph follow it.
    follower_counts: Dict[str, int] = {}
    for _, row in follow_df.iterrows():
        for acct in row["follows"]:
            if acct not in completed_accounts:
                follower_counts[acct] = follower_counts.get(acct, 0) + 1

    # Keep only sufficiently-followed accounts, most-followed first.
    accts = [
        (acct, count)
        for acct, count in follower_counts.items()
        if count >= FOLLOWER_THRESHOLD
    ]
    accts.sort(key=lambda item: item[1], reverse=True)

    logger.info(f"Num of accounts to retrieve posts from: {len(accts)}")

    client = AsyncClient()
    await client.login(user, app_pw)

    # Get all posts for accounts
    batch_count = 0
    fail_count = 0
    rate_limiter = RateLimit(BATCH_SIZE)
    for i in range(0, len(accts), BATCH_SIZE):
        batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]]
        for result in asyncio.as_completed(
            [
                get_all_posts(client, rate_limiter, did, start_dt, end_dt)
                for did in batch
            ]
        ):
            try:
                posts, did = await result
                # Save posts; the file's existence marks the DID as done.
                with gzip.open(
                    os.path.join(checkpoint_dir, did + ".gz"), "wt"
                ) as out_file:
                    for post in posts:
                        out_file.write(json.dumps(post) + "\n")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted
                logger.info(f"Bad Request: {e.response.content.error}")
                continue
            except Exception as e:
                logger.error(f"Failed to get posts: {e}", exc_info=True)
                fail_count += 1
                # Bail out after three failures rather than hammering the API.
                if fail_count >= 3:
                    sys.exit(1)
                continue

        batch_count += 1
        if batch_count % 10 == 0:
            logger.info(f"Completed batch: {batch_count}")
238
239
def main():
    """CLI entry point: validate env/args and run the crawl.

    Requires BSKY_USER and BSKY_APP_PW in the environment, plus
    --graph-file, --save-dir, --start, and --end arguments.

    Raises:
        ValueError: If a required environment variable is missing.
    """
    for key in REQUIRED_ENV:
        if key not in os.environ:
            raise ValueError(f"Must set '{key}' env var")

    user_name = os.environ["BSKY_USER"]
    app_pw = os.environ["BSKY_APP_PW"]

    parser = argparse.ArgumentParser(
        prog="GetPosts",
        description="Get all posts for accounts in provided follow graph",
    )
    parser.add_argument(
        "--graph-file",
        dest="graph_file",
        required=True,
        help="File with follow graph",
    )
    parser.add_argument(
        "--save-dir",
        dest="save_dir",
        required=True,
        help="Where to store crawl data",
    )
    parser.add_argument(
        "--start",
        dest="start",
        required=True,
        help="Date to start saving posts from (YYYY-MM-DD)",
    )
    parser.add_argument(
        "--end",
        dest="end",
        required=True,
        help="Date to stop (exclusive) saving posts from (YYYY-MM-DD)",
    )
    args = parser.parse_args()

    # NOTE: the old `args.save_dir is None and args.ckpt is None` check was
    # removed: --save-dir is required=True so it can never be None, and
    # `args.ckpt` was never defined (evaluating it would raise AttributeError).

    try:
        start = datetime.strptime(args.start, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid start date")
        sys.exit(1)

    try:
        end = datetime.strptime(args.end, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid end date")
        sys.exit(1)

    if end <= start:
        logger.error(
            "Start date has to be before end date, what're you trying to do man..."
        )
        sys.exit(1)

    asyncio.run(
        retrieve_posts(
            user_name,
            app_pw,
            graph_file=args.graph_file,
            checkpoint_dir=args.save_dir,
            start_dt=start,
            end_dt=end,
        )
    )


if __name__ == "__main__":
    main()