# Forked from cameron.stream/void.
# (This repository has no description.)
1import os
2import logging
3from typing import Optional, Dict, Any
4from atproto_client import Client, Session, SessionEvent, models
5
# Configure the root logger: INFO and above, one timestamped line per record
# in the form "time - logger name - level - message".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Named logger shared by the helpers defined in this module.
logger = logging.getLogger("bluesky_session_handler")
11
# Pull settings from a dotenv file into the process environment before any
# os.getenv() lookup below. override=True asks python-dotenv to replace
# variables that are already set in the environment.
import dotenv

dotenv.load_dotenv(override=True)
15
16import yaml
17import json
18
# Keys removed (at every nesting level) by strip_fields() when a thread is
# rendered for an LLM. The entries are matched by exact key name; the blank-line
# grouping below is only a reading aid based on the names themselves.
STRIP_FIELDS = [
    # Identifiers and revision markers
    "cid",
    "rev",
    "did",
    "uri",

    # Assorted post/profile metadata
    "langs",
    "threadgate",
    "py_type",
    "labels",
    "facets",
    "avatar",
    "viewer",
    "indexed_at",
    "tags",
    "associated",
    "thread_context",

    # Image-related fields
    "aspect_ratio",
    "thumb",
    "fullsize",

    "root",
    "created_at",
    "verification",

    # Engagement counters
    "like_count",
    "quote_count",
    "reply_count",
    "repost_count",

    # Per-post switches
    "embedding_disabled",
    "thread_muted",
    "reply_disabled",
    "pinned",

    # Relationship / interaction state
    "like",
    "repost",
    "blocked_by",
    "blocking",
    "blocking_by_list",
    "followed_by",
    "following",
    "known_followers",
    "muted",
    "muted_by_list",
    "root_author_like",

    # Blob / embed bookkeeping
    "entities",
    "ref",
    "mime_type",
    "size",
]
def convert_to_basic_types(obj):
    """Convert complex Python objects to basic types for JSON/YAML serialization.

    The result is built only from plain ``dict``, ``list``, ``str``, ``int``,
    ``float``, ``bool`` and ``None`` values (dictionary keys are kept as-is).

    Conversion rules, applied in this order:

    * mappings (``dict`` and subclasses) -> plain ``dict`` with converted values
    * ``list``, ``tuple``, ``set``, ``frozenset`` (and subclasses) -> plain ``list``
    * any other object exposing ``__dict__`` -> its attribute dictionary, converted
    * ``str``/``int``/``float``/``bool``/``None`` -> returned unchanged
    * everything else -> ``str(obj)``

    Args:
        obj: Any Python object.

    Returns:
        A new structure of basic types; the input is not modified.
    """
    # Containers are checked before ``__dict__``: subclasses such as
    # OrderedDict carry a (usually empty) ``__dict__`` of their own, and
    # looking at that first would silently discard their real contents.
    if isinstance(obj, dict):
        return {key: convert_to_basic_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        # Tuples and sets have no JSON/YAML counterpart; emit them as lists
        # instead of falling through to their repr string.
        return [convert_to_basic_types(item) for item in obj]
    if hasattr(obj, '__dict__'):
        # Convert objects with __dict__ to their dictionary representation
        return convert_to_basic_types(obj.__dict__)
    if isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    # For other types, fall back to the string representation
    return str(obj)
80
81
def strip_fields(obj, strip_field_list):
    """Recursively strip fields from a JSON-like object, in place.

    For every dictionary reachable from ``obj``:

    * keys listed in ``strip_field_list`` are removed;
    * string keys starting with ``"__"`` (dunder metadata) are removed;
    * after recursion, entries whose value is ``None``, an empty dict, an
      empty list, or a blank/whitespace-only string are removed.

    For every list, items are processed recursively and ``None`` items are
    dropped. Other items (including ones that became empty) stay in the list.

    Args:
        obj: The structure to clean. Dicts and lists are modified in place;
            any other value is returned untouched.
        strip_field_list: Collection of key names to remove.

    Returns:
        ``obj`` itself (the same object that was passed in).
    """
    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys because entries are deleted
        # while walking the mapping.
        for key in list(obj):
            # Only string keys can be dunder metadata. Guarding the
            # startswith() call keeps integer/other keys from raising
            # AttributeError.
            is_dunder = isinstance(key, str) and key.startswith("__")
            if is_dunder or key in strip_field_list:
                del obj[key]
                continue

            cleaned = strip_fields(obj[key], strip_field_list)

            # Remove empty/null values after processing
            if (
                cleaned is None
                or (isinstance(cleaned, (dict, list)) and not cleaned)
                or (isinstance(cleaned, str) and not cleaned.strip())
            ):
                del obj[key]
            else:
                obj[key] = cleaned

    elif isinstance(obj, list):
        # Slice assignment keeps the caller's list object while replacing
        # its contents with the cleaned, non-None items.
        obj[:] = [
            item
            for item in (strip_fields(value, strip_field_list) for value in obj)
            if item is not None
        ]

    return obj
115
116
def thread_to_yaml_string(thread, strip_metadata=True):
    """
    Render thread data as a YAML document intended for LLM consumption.

    Args:
        thread: The thread data from get_post_thread
        strip_metadata: When true, keys listed in STRIP_FIELDS (and empty
            values) are removed before dumping, for cleaner output

    Returns:
        YAML-formatted string representation of the thread
    """
    # Reduce model objects to plain dicts/lists/scalars first.
    plain = convert_to_basic_types(thread)

    # strip_fields() edits the structure it is given and returns that same
    # object, so no separate copy is made here.
    payload = strip_fields(plain, STRIP_FIELDS) if strip_metadata else plain

    return yaml.dump(
        payload,
        default_flow_style=False,
        allow_unicode=True,
        indent=2,
    )
138
139
140
141
142
def get_session(username: str) -> Optional[str]:
    """Return the saved session string for ``username``, if one exists.

    The session is read from ``session_<username>.txt`` (a relative path, so it
    is resolved against the current working directory).

    Returns:
        The file's full contents, or None when the file does not exist.
    """
    session_path = f"session_{username}.txt"
    try:
        handle = open(session_path, encoding="UTF-8")
    except FileNotFoundError:
        logger.debug(f"No existing session found for {username}")
        return None

    with handle:
        return handle.read()
150
def save_session(username: str, session_string: str) -> None:
    """Persist ``session_string`` to ``session_<username>.txt``.

    The stored string is what ``client.login(session_string=...)`` later
    accepts in place of a password, so the file is treated as a credential:
    it is created with owner-only permissions (0o600) and an already existing
    file is tightened to the same mode on a best-effort basis. On platforms
    that do not implement POSIX permission bits the mode is largely ignored.

    Args:
        username: Account name used to build the file name.
        session_string: Exported session to write; replaces previous contents.

    Raises:
        OSError: If the file cannot be created or written.
    """
    session_path = f"session_{username}.txt"

    # os.open lets us pass the creation mode up front, so a new file is never
    # readable by other users, not even briefly.
    fd = os.open(session_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="UTF-8") as f:
        try:
            # The creation mode is ignored for a file that already existed,
            # so tighten it explicitly before the secret is written.
            os.chmod(session_path, 0o600)
        except OSError as e:
            logger.debug(f"Could not restrict permissions on {session_path}: {e}")
        f.write(session_string)

    logger.debug(f"Session saved for {username}")
155
def on_session_change(username: str, event: SessionEvent, session: Session) -> None:
    """Session-change callback: log the event and persist new/refreshed sessions.

    Args:
        username: Account whose session file should be updated.
        event: The kind of change reported by the client.
        session: The session object accompanying the event.
    """
    logger.info(f"Session changed: {event} {repr(session)}")

    # Only CREATE and REFRESH events cause a save; anything else is just logged.
    if event not in (SessionEvent.CREATE, SessionEvent.REFRESH):
        return

    logger.info(f"Saving changed session for {username}")
    exported = session.export()
    save_session(username, exported)
161
def init_client(username: str, password: str) -> Client:
    """Create a logged-in client, reusing a saved session when possible.

    The PDS endpoint is taken from the ``PDS_URI`` environment variable and
    defaults to ``https://bsky.social`` when that variable is unset or empty.
    A session-change callback is registered so sessions get saved via
    ``on_session_change``.

    Login strategy:

    1. If ``get_session(username)`` returns a saved session string, log in
       with it.
    2. If there is no saved session, or the saved one is rejected (expired,
       revoked, corrupt file, ...), log in with ``username``/``password``.

    Args:
        username: Account identifier used for login and for the session file.
        password: Password used when no reusable session is available.

    Returns:
        The authenticated client.

    Raises:
        Exception: Whatever the password login raises when it fails.
    """
    pds_uri = os.getenv("PDS_URI")
    # Treat an empty value like a missing one; an empty base URL is never useful.
    if not pds_uri:
        logger.warning(
            "No PDS URI provided. Falling back to bsky.social. Note! If you are on a non-Bluesky PDS, this can cause logins to fail. Please provide a PDS URI using the PDS_URI environment variable."
        )
        pds_uri = "https://bsky.social"

    # Print the PDS URI
    logger.info(f"Using PDS URI: {pds_uri}")

    def build_client() -> Client:
        """Return a new client with the session-persistence callback attached."""
        new_client = Client(pds_uri)
        new_client.on_session_change(
            lambda event, session: on_session_change(username, event, session)
        )
        return new_client

    client = build_client()

    session_string = get_session(username)
    if session_string:
        logger.info(f"Reusing existing session for {username}")
        try:
            client.login(session_string=session_string)
            return client
        except Exception as e:
            # A stale or damaged session file must not lock the bot out.
            # Start over with a fresh client so nothing from the rejected
            # session carries into the password login below.
            logger.warning(
                f"Saved session for {username} could not be reused ({e}); falling back to password login"
            )
            client = build_client()

    logger.info(f"Creating new session for {username}")
    client.login(username, password)

    return client
187
188
def default_login() -> Client:
    """Log in using credentials from the environment.

    Reads ``BSKY_USERNAME`` and ``BSKY_PASSWORD`` and delegates to
    ``init_client``.

    Returns:
        The authenticated client returned by ``init_client``.

    Raises:
        SystemExit: With exit status 1 when either variable is missing or
            empty (an error is logged first).
    """
    username = os.getenv("BSKY_USERNAME")
    password = os.getenv("BSKY_PASSWORD")

    # ``raise SystemExit(1)`` is used instead of the ``exit()`` helper: that
    # helper is injected by the ``site`` module (not always present) and,
    # called without arguments, reports success (status 0) for what is an
    # error condition.
    if not username:
        logger.error(
            "No username provided. Please provide a username using the BSKY_USERNAME environment variable."
        )
        raise SystemExit(1)

    if not password:
        logger.error(
            "No password provided. Please provide a password using the BSKY_PASSWORD environment variable."
        )
        raise SystemExit(1)

    return init_client(username, password)
206
def reply_to_post(client: Client, text: str, reply_to_uri: str, reply_to_cid: str, root_uri: Optional[str] = None, root_cid: Optional[str] = None) -> Dict[str, Any]:
    """
    Reply to a post on Bluesky with rich text support.

    Mentions (``@handle.tld``) and ``http(s)://`` URLs found in ``text`` are
    turned into rich-text facets. Facet positions are byte offsets into the
    UTF-8 encoding of ``text``. A mention only becomes a facet if its handle
    can be resolved through ``get_profile``; handles that fail to resolve are
    left as plain text.

    Args:
        client: Authenticated Bluesky client
        text: The reply text
        reply_to_uri: The URI of the post being replied to (parent)
        reply_to_cid: The CID of the post being replied to (parent)
        root_uri: The URI of the root post (if replying to a reply). If None, uses reply_to_uri
        root_cid: The CID of the root post (if replying to a reply). If None, uses reply_to_cid

    Returns:
        The response from sending the post
    """
    import re

    # If root is not provided, this is a reply to the root post
    if root_uri is None:
        root_uri = reply_to_uri
        root_cid = reply_to_cid

    # Create references for the reply
    parent_ref = models.create_strong_ref(models.ComAtprotoRepoStrongRef.Main(uri=reply_to_uri, cid=reply_to_cid))
    root_ref = models.create_strong_ref(models.ComAtprotoRepoStrongRef.Main(uri=root_uri, cid=root_cid))

    # Parse rich text facets (mentions and URLs). Matching runs on the encoded
    # bytes so that m.start()/m.end() are byte offsets, as the facet index uses.
    facets = []
    text_bytes = text.encode("UTF-8")

    # Both patterns begin with (?:^|\W): the token must sit at the very start
    # of the text or directly after a non-word byte. (A character class such
    # as [$|\W] always consumes one preceding byte, so a mention or link that
    # opens the text would never be detected.) The group is non-capturing, so
    # group 1 is still the mention/URL itself.
    mention_regex = rb"(?:^|\W)(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)"

    for m in re.finditer(mention_regex, text_bytes):
        handle = m.group(1)[1:].decode("UTF-8")  # Remove @ prefix
        try:
            # Resolve handle to DID using the API
            resolve_resp = client.app.bsky.actor.get_profile({'actor': handle})
            if resolve_resp and hasattr(resolve_resp, 'did'):
                facets.append(
                    models.AppBskyRichtextFacet.Main(
                        index=models.AppBskyRichtextFacet.ByteSlice(
                            byteStart=m.start(1),
                            byteEnd=m.end(1)
                        ),
                        features=[models.AppBskyRichtextFacet.Mention(did=resolve_resp.did)]
                    )
                )
        except Exception as e:
            # Best effort: an unresolvable mention stays plain text rather
            # than preventing the reply from being sent.
            logger.debug(f"Failed to resolve handle {handle}: {e}")
            continue

    # Parse URLs
    url_regex = rb"(?:^|\W)(https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)"

    for m in re.finditer(url_regex, text_bytes):
        url = m.group(1).decode("UTF-8")
        facets.append(
            models.AppBskyRichtextFacet.Main(
                index=models.AppBskyRichtextFacet.ByteSlice(
                    byteStart=m.start(1),
                    byteEnd=m.end(1)
                ),
                features=[models.AppBskyRichtextFacet.Link(uri=url)]
            )
        )

    # Send the reply; an empty facet list is passed as None so a reply
    # without rich text goes out exactly as if no facets were given.
    response = client.send_post(
        text=text,
        reply_to=models.AppBskyFeedPost.ReplyRef(parent=parent_ref, root=root_ref),
        facets=facets or None
    )

    logger.info(f"Reply sent successfully: {response.uri}")
    return response
289
290
def get_post_thread(client: Client, uri: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the thread that contains a post (used to locate the root post).

    The request is made with ``parent_height=60`` and ``depth=10``. Any
    exception raised by the call is logged and turned into a None result.

    Args:
        client: Authenticated Bluesky client
        uri: The URI of the post

    Returns:
        The thread data or None if not found
    """
    params = {'uri': uri, 'parent_height': 60, 'depth': 10}
    try:
        return client.app.bsky.feed.get_post_thread(params)
    except Exception as e:
        logger.error(f"Error fetching post thread: {e}")
    return None
308
309
def _find_thread_root(thread: Any, default_uri: str, default_cid: str):
    """Return ``(root_uri, root_cid)`` for the post at the head of ``thread``.

    Resolution order:

    1. The root reference stored on the post's own record
       (``thread.post.record.reply.root``), when it has both a uri and a cid.
       This does not depend on how many ancestors the thread view includes.
    2. Otherwise the topmost ancestor reachable through ``parent`` links
       that exposes ``post.uri`` and ``post.cid``.
    3. Otherwise ``default_uri``/``default_cid`` (the post is its own root).
    """
    # getattr with a default tolerates nodes that lack any of these attributes.
    post = getattr(thread, 'post', None)
    record = getattr(post, 'record', None)
    reply = getattr(record, 'reply', None)
    declared_root = getattr(reply, 'root', None)
    declared_uri = getattr(declared_root, 'uri', None)
    declared_cid = getattr(declared_root, 'cid', None)
    if declared_uri and declared_cid:
        return declared_uri, declared_cid

    root_uri, root_cid = default_uri, default_cid

    # Keep going up until we run out of parents. The last ancestor that
    # carries a usable post wins.
    current = thread
    while hasattr(current, 'parent') and current.parent:
        current = current.parent
        if hasattr(current, 'post') and hasattr(current.post, 'uri') and hasattr(current.post, 'cid'):
            root_uri = current.post.uri
            root_cid = current.post.cid

    return root_uri, root_cid


def reply_to_notification(client: Client, notification: Any, reply_text: str) -> Optional[Dict[str, Any]]:
    """
    Reply to a notification (mention or reply).

    The notification's own post becomes the parent of the reply. The thread
    root is looked up through ``get_post_thread``. If the thread cannot be
    fetched, the notification's post is used as the root as well.

    Args:
        client: Authenticated Bluesky client
        notification: The notification object from list_notifications
            (either a dict or an object with ``uri``/``cid`` attributes)
        reply_text: The text to reply with

    Returns:
        The response from sending the reply or None if failed
    """
    try:
        # Get the post URI and CID from the notification (handle both dict and object)
        if isinstance(notification, dict):
            post_uri = notification.get('uri')
            post_cid = notification.get('cid')
        elif hasattr(notification, 'uri') and hasattr(notification, 'cid'):
            post_uri = notification.uri
            post_cid = notification.cid
        else:
            post_uri = None
            post_cid = None

        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        # Default: the post we answer is itself the root of the thread.
        root_uri = post_uri
        root_cid = post_cid

        # Get the thread to find the real root post, when available
        thread_data = get_post_thread(client, post_uri)
        if thread_data and hasattr(thread_data, 'thread'):
            root_uri, root_cid = _find_thread_root(thread_data.thread, post_uri, post_cid)

        # Reply to the notification
        return reply_to_post(
            client=client,
            text=reply_text,
            reply_to_uri=post_uri,
            reply_to_cid=post_cid,
            root_uri=root_uri,
            root_cid=root_cid
        )

    except Exception as e:
        # Deliberate best-effort boundary: a failed reply is reported to the
        # caller as None instead of propagating.
        logger.error(f"Error replying to notification: {e}")
        return None
379
380
if __name__ == "__main__":
    # Running the module directly performs a login with the credentials from
    # the environment and reports success; nothing else is done with the client.
    client = default_login()
    # do something with the client
    logger.info("Client is ready to use!")