# Crawl Bluesky (atproto) author feeds for a follow graph and archive each
# account's posts as gzipped newline-delimited JSON checkpoint files.
1import argparse
2import asyncio
3from datetime import datetime
4import gzip
5import json
6import logging
7import os
8import sys
9from typing import Tuple, List, Dict
10
11from atproto import AsyncClient
12from atproto import exceptions as at_exceptions
13from atproto_client.models.app.bsky.feed.defs import FeedViewPost
14
15from utils import get_accounts, load_checkpoint, RateLimit, BSKY_API_LIMIT
16
# Module logger: INFO-level messages to stdout with a timestamped format.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


# Number of accounts fetched concurrently per asyncio batch.
BATCH_SIZE = 10
# Environment variables that must hold the Bluesky handle and app password.
REQUIRED_ENV = ("BSKY_USER", "BSKY_APP_PW")
33
34
def process_post(top: FeedViewPost):
    """Flatten one feed item into a plain dict of the fields we archive.

    Captures author/text/cid/timestamp, a repost flag, any supported embed
    (external link, quote-post, images, video), and reply-parent metadata.
    """
    post = top.post
    record = post.record
    data = {
        "author": post.author.did,
        "text": record.text,
        "cid": post.cid,
        "created_at": record.created_at,
        "repost": (
            top.reason is not None
            and top.reason.py_type == "app.bsky.feed.defs#reasonRepost"
        ),
    }

    embed = post.embed
    if embed:
        # Default to an empty dict so unsupported embed types still record
        # that *some* embed was present.
        data["embed"] = {}
        kind = embed.py_type
        if kind == "app.bsky.embed.external#view":
            external = embed.external
            data["embed"] = {
                "title": external.title,
                "description": external.description,
                "uri": external.uri,
                "thumb": external.thumb,
            }
        elif kind == "app.bsky.embed.record#view":
            quoted = embed.record
            # Ignore everything thats not a quote-tweet
            if quoted.py_type == "app.bsky.embed.record#viewRecord":
                data["embed"] = {
                    "author": quoted.author.did,
                    "text": quoted.value.text,
                    "cid": quoted.cid,
                    "created_at": quoted.value.created_at,
                }
        elif kind == "app.bsky.embed.images#view":
            data["embed"]["images"] = [
                {
                    "alt_text": image.alt,
                    "full_url": image.fullsize,
                    "thumb_url": image.thumb,
                }
                for image in embed.images
            ]
        elif kind == "app.bsky.embed.video#view":
            data["embed"]["video"] = {
                "alt": embed.alt,
                "full_url": embed.playlist,
                "thumb_url": embed.thumbnail,
            }

    if top.reply:
        parent = top.reply.parent
        # Only a fully-hydrated postView carries the fields we archive.
        if parent.py_type == "app.bsky.feed.defs#postView":
            data["reply_parent"] = {
                "author": parent.author.did,
                "text": parent.record.text,
                "cid": parent.cid,
                "created_at": parent.record.created_at,
            }
            root = top.reply.root
            if root.py_type == "app.bsky.feed.defs#postView":
                data["reply_parent"]["root_cid"] = root.cid

    return data
96
97
def _parse_indexed_at(timestamp: str) -> datetime:
    """Parse a Bluesky ``indexed_at`` timestamp into a naive datetime.

    The API usually emits fractional seconds, but not always, so fall back
    to the whole-second format before giving up.
    """
    for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ"):
        try:
            return datetime.strptime(timestamp, fmt)
        except ValueError:
            continue
    raise ValueError(f"Unrecognized indexed_at timestamp: {timestamp!r}")


def _collect_page(feed, start_dt: datetime, end_dt: datetime) -> Tuple[List[Dict], bool]:
    """Extract posts from one feed page that fall within [start_dt, end_dt).

    Returns (parsed_posts, hit_start_window) where the flag is True when the
    page contains a post older than start_dt — the feed is reverse
    chronological, so pagination can stop there.
    """
    page_posts: List[Dict] = []
    hit_start_window = False
    for top in feed:
        dt = _parse_indexed_at(top.post.indexed_at)
        if start_dt <= dt < end_dt:
            parsed = process_post(top)
            if parsed is not None:
                page_posts.append(parsed)
        if dt < start_dt:
            hit_start_window = True
    return page_posts, hit_start_window


async def get_all_posts(
    client: AsyncClient,
    rate_limit: RateLimit,
    account_did: str,
    start_dt: datetime,
    end_dt: datetime,
) -> Tuple[List[Dict], str]:
    """Collect every post by ``account_did`` indexed within [start_dt, end_dt).

    Pages through the author feed (rate-limited per request) until a post
    older than start_dt is seen or the feed is exhausted.  Returns
    (posts, account_did).  An inaccessible account (HTTP 400 — private or
    deleted) yields an empty post list so the caller can checkpoint it and
    skip it on future runs; other HTTP errors are logged and re-raised.
    """
    posts: List[Dict] = []
    await rate_limit.acquire()
    try:
        data = await client.get_author_feed(
            actor=account_did,
            filter="posts_and_author_threads",
        )
    # If user can't be accessed just return an empty list to skip next time
    except at_exceptions.BadRequestError as e:
        if e.response.status_code == 400:
            return [], account_did
        logger.info(f"Error status code: {e.response.status_code}")
        raise

    page_posts, hit_start_window = _collect_page(data.feed, start_dt, end_dt)
    posts.extend(page_posts)

    # Keep paginating until we run out of pages or step past the window start.
    while data.cursor and not hit_start_window:
        await rate_limit.acquire()
        data = await client.get_author_feed(
            actor=account_did, filter="posts_and_author_threads", cursor=data.cursor
        )
        page_posts, hit_start_window = _collect_page(data.feed, start_dt, end_dt)
        posts.extend(page_posts)

    return posts, account_did
144
145
async def retrieve_posts(
    user: str,
    app_pw: str,
    graph_file: str,
    checkpoint_dir: str,
    start_dt: datetime,
    end_dt: datetime,
):
    """Download posts for every account in the follow graph, checkpointing each.

    Logs in as ``user``/``app_pw``, then fetches accounts in batches of
    BATCH_SIZE concurrent tasks.  Each account's posts are written to
    ``<checkpoint_dir>/<did>.gz`` as newline-delimited JSON; the file's
    existence doubles as the checkpoint marker (inaccessible accounts get an
    empty file so they are skipped on restart).  Exits the process after 100
    unexpected failures.
    """
    # Checkpoint folders contain one file per user
    completed_accounts = load_checkpoint(checkpoint_dir)
    accts = get_accounts(graph_file, completed_accounts)
    logger.info(f"Num of accounts to retrieve posts from: {len(accts)}")

    client = AsyncClient()
    await client.login(user, app_pw)

    # Get all posts for accounts, BATCH_SIZE accounts at a time
    batch_count = 0
    fail_count = 0
    rate_limiter = RateLimit(BSKY_API_LIMIT)
    for i in range(0, len(accts), BATCH_SIZE):
        batch = [acct for acct, _ in accts[i : i + BATCH_SIZE]]
        for result in asyncio.as_completed(
            [
                get_all_posts(client, rate_limiter, did, start_dt, end_dt)
                for did in batch
            ]
        ):
            try:
                posts, did = await result
                # Save posts; an empty file still marks the DID as completed.
                with gzip.open(
                    os.path.join(checkpoint_dir, did + ".gz"), "wt"
                ) as out_file:
                    for post in posts:
                        out_file.write(json.dumps(post) + "\n")
            except at_exceptions.BadRequestError as e:
                # Bad request is probably a profile that's private or deleted
                logger.info(f"Bad Request: {e.response.content.error}")
                continue
            except Exception as e:
                # Tolerate sporadic failures, but bail out if they pile up.
                logger.error(f"Failed to get posts: {e}", exc_info=True)
                fail_count += 1
                if fail_count >= 100:
                    logger.error("Hitting error threshold, exiting...")
                    sys.exit(1)
                continue

        batch_count += 1
        if batch_count % 10 == 0:
            logger.info(f"Completed batch: {batch_count}")
198
def main():
    """CLI entry point: validate credentials and arguments, then run the crawl.

    Requires BSKY_USER / BSKY_APP_PW in the environment and --graph-file,
    --save-dir, --start, --end on the command line.  Exits with status 1 on
    invalid dates.
    """
    for key in REQUIRED_ENV:
        if key not in os.environ:
            raise ValueError(f"Must set '{key}' env var")

    user_name = os.environ["BSKY_USER"]
    app_pw = os.environ["BSKY_APP_PW"]

    parser = argparse.ArgumentParser(
        prog="GetPosts",
        description="Get all posts for accounts in provided follow graph",
    )
    parser.add_argument(
        "--graph-file",
        dest="graph_file",
        required=True,
        help="File with follow graph",
    )
    parser.add_argument(
        "--save-dir",
        dest="save_dir",
        required=True,
        help="Where to store crawl data",
    )
    parser.add_argument(
        "--start",
        dest="start",
        required=True,
        help="Date to start saving posts from (YYYY-MM-DD)",
    )
    parser.add_argument(
        "--end",
        dest="end",
        required=True,
        help="Date to stop (exclusive) saving posts from (YYYY-MM-DD)",
    )
    args = parser.parse_args()
    # NOTE(review): the old `args.save_dir is None and args.ckpt is None`
    # guard was removed — both flags are required=True (so it was dead code)
    # and `args.ckpt` referenced a nonexistent option, a latent
    # AttributeError.

    try:
        start = datetime.strptime(args.start, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid start date")
        sys.exit(1)

    try:
        end = datetime.strptime(args.end, "%Y-%m-%d")
    except ValueError:
        logger.error("Invalid end date")
        sys.exit(1)

    if end <= start:
        logger.error("Start date must be before end date")
        sys.exit(1)

    asyncio.run(
        retrieve_posts(
            user_name,
            app_pw,
            graph_file=args.graph_file,
            checkpoint_dir=args.save_dir,
            start_dt=start,
            end_dt=end,
        )
    )
269
270
# Script entry point.
if __name__ == "__main__":
    main()