A better Rust ATProto crate
1use std::future::Future;
2use std::sync::Arc;
3
4use dashmap::DashMap;
5use jacquard_common::{
6 IntoStatic,
7 session::{SessionStore, SessionStoreError},
8 types::did::Did,
9};
10use smol_str::{SmolStr, ToSmolStr, format_smolstr};
11
12use crate::session::{AuthRequestData, ClientSessionData};
13
14/// Persistent storage backend for OAuth client sessions and in-flight authorization requests.
15///
16/// Implementors are responsible for durably storing two categories of data:
17/// - Active client sessions (access tokens, refresh tokens, nonces) keyed by DID + session ID.
18/// - Pending authorization request state, keyed by the OAuth `state` parameter, which must
19/// survive the round-trip to the authorization server and be cleaned up after use.
20#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
21pub trait ClientAuthStore {
22 /// Retrieve an active session for the given DID and session identifier, if one exists.
23 fn get_session(
24 &self,
25 did: &Did<'_>,
26 session_id: &str,
27 ) -> impl Future<Output = Result<Option<ClientSessionData<'_>>, SessionStoreError>>;
28
29 /// Insert or update a session, replacing any existing entry for the same DID and session ID.
30 fn upsert_session(
31 &self,
32 session: ClientSessionData<'_>,
33 ) -> impl Future<Output = Result<(), SessionStoreError>>;
34
35 /// Delete the session for the given DID and session identifier.
36 fn delete_session(
37 &self,
38 did: &Did<'_>,
39 session_id: &str,
40 ) -> impl Future<Output = Result<(), SessionStoreError>>;
41
42 /// Retrieve the authorization request data associated with the given OAuth `state` value.
43 fn get_auth_req_info(
44 &self,
45 state: &str,
46 ) -> impl Future<Output = Result<Option<AuthRequestData<'_>>, SessionStoreError>>;
47
48 /// Persist authorization request data so it can be retrieved after the OAuth redirect.
49 fn save_auth_req_info(
50 &self,
51 auth_req_info: &AuthRequestData<'_>,
52 ) -> impl Future<Output = Result<(), SessionStoreError>>;
53
54 /// Remove authorization request data after the callback has been handled.
55 fn delete_auth_req_info(
56 &self,
57 state: &str,
58 ) -> impl Future<Output = Result<(), SessionStoreError>>;
59}
60
61/// An in-memory implementation of [`ClientAuthStore`], suitable for testing and single-process
62/// deployments where session persistence across restarts is not required.
63pub struct MemoryAuthStore {
64 sessions: DashMap<SmolStr, ClientSessionData<'static>>,
65 auth_reqs: DashMap<SmolStr, AuthRequestData<'static>>,
66}
67
68impl MemoryAuthStore {
69 /// Create a new, empty in-memory auth store.
70 pub fn new() -> Self {
71 Self {
72 sessions: DashMap::new(),
73 auth_reqs: DashMap::new(),
74 }
75 }
76}
77
78impl ClientAuthStore for MemoryAuthStore {
79 async fn get_session(
80 &self,
81 did: &Did<'_>,
82 session_id: &str,
83 ) -> Result<Option<ClientSessionData<'_>>, SessionStoreError> {
84 let key = format_smolstr!("{}_{}", did, session_id);
85 Ok(self.sessions.get(&key).map(|v| v.clone()))
86 }
87
88 async fn upsert_session(
89 &self,
90 session: ClientSessionData<'_>,
91 ) -> Result<(), SessionStoreError> {
92 let key = format_smolstr!("{}_{}", session.account_did, session.session_id);
93 self.sessions.insert(key, session.into_static());
94 Ok(())
95 }
96
97 async fn delete_session(
98 &self,
99 did: &Did<'_>,
100 session_id: &str,
101 ) -> Result<(), SessionStoreError> {
102 let key = format_smolstr!("{}_{}", did, session_id);
103 self.sessions.remove(&key);
104 Ok(())
105 }
106
107 async fn get_auth_req_info(
108 &self,
109 state: &str,
110 ) -> Result<Option<AuthRequestData<'_>>, SessionStoreError> {
111 Ok(self.auth_reqs.get(state).map(|v| v.clone()))
112 }
113
114 async fn save_auth_req_info(
115 &self,
116 auth_req_info: &AuthRequestData<'_>,
117 ) -> Result<(), SessionStoreError> {
118 self.auth_reqs.insert(
119 auth_req_info.state.clone().to_smolstr(),
120 auth_req_info.clone().into_static(),
121 );
122 Ok(())
123 }
124
125 async fn delete_auth_req_info(&self, state: &str) -> Result<(), SessionStoreError> {
126 self.auth_reqs.remove(state);
127 Ok(())
128 }
129}
130
131impl<T: ClientAuthStore + Send + Sync>
132 SessionStore<(Did<'static>, SmolStr), ClientSessionData<'static>> for Arc<T>
133{
134 /// Get the current session if present.
135 async fn get(&self, key: &(Did<'static>, SmolStr)) -> Option<ClientSessionData<'static>> {
136 let (did, session_id) = key;
137 self.as_ref()
138 .get_session(did, session_id)
139 .await
140 .ok()
141 .flatten()
142 .into_static()
143 }
144 /// Persist the given session.
145 async fn set(
146 &self,
147 _key: (Did<'static>, SmolStr),
148 session: ClientSessionData<'static>,
149 ) -> Result<(), SessionStoreError> {
150 self.as_ref().upsert_session(session).await
151 }
152 /// Delete the given session.
153 async fn del(&self, key: &(Did<'static>, SmolStr)) -> Result<(), SessionStoreError> {
154 let (did, session_id) = key;
155 self.as_ref().delete_session(did, session_id).await
156 }
157}