A social pastebin built on atproto.
1"""ATProto OAuth helpers.
2
3Handles DPoP proof generation, PAR requests, token exchange, and
4authenticated PDS requests. Adapted from the Bluesky cookbook example.
5"""
6
7import json
8import time
9import urllib.request
10from typing import Any, Tuple
11
12import requests_hardened
13from authlib.common.security import generate_token
14from authlib.jose import JsonWebKey, jwt
15from authlib.oauth2.rfc7636 import create_s256_code_challenge
16from requests import Response
17
# SSRF-safe HTTP client used by the request helpers in this module.
hardened_http = requests_hardened.Manager(
    requests_hardened.Config(
        # NOTE(review): presumably a (connect, read) timeout pair in seconds,
        # following the `requests` convention — confirm in requests_hardened.
        default_timeout=(2, 10),
        # NOTE(review): going by the option name, redirects are never
        # followed, so 3xx responses reach the caller as-is — confirm.
        never_redirect=True,
        # NOTE(review): going by the option names, destination-IP filtering
        # is enabled and loopback addresses are not allowed — confirm.
        ip_filter_enable=True,
        ip_filter_allow_loopback_ips=False,
        # NOTE(review): presumably replaces the default User-Agent header.
        user_agent_override="Morsels",
    )
)
28
29
def is_safe_url(url):
    """Crude SSRF check — only allows HTTPS URLs with public-looking hostnames.

    A URL is accepted only when all of the following hold:

    * it is a ``str`` that ``urlparse`` can parse,
    * the scheme is exactly ``https``,
    * the authority is a bare hostname (no userinfo, no port, no brackets),
    * the hostname has at least two non-empty dot-separated labels,
    * the last label is not ``local``, ``arpa``, ``internal`` or
      ``localhost`` and is not purely numeric (rejects dotted IPv4).

    Returns ``True`` when the URL passes, ``False`` otherwise. Never raises
    for malformed input; anything that cannot be parsed is simply unsafe.
    """
    from urllib.parse import urlparse

    if not isinstance(url, str):
        return False

    try:
        parts = urlparse(url)
        # ``hostname == netloc`` already implies no userinfo and no port; the
        # explicit checks are kept as defence in depth. ``.port`` may raise
        # ValueError on garbage, hence it lives inside the ``try``.
        bare_https_host = (
            parts.scheme == "https"
            and parts.hostname is not None
            and parts.hostname == parts.netloc
            and parts.username is None
            and parts.password is None
            and parts.port is None
        )
    except ValueError:
        # e.g. unbalanced IPv6 brackets such as "https://[::1"
        return False
    if not bare_https_host:
        return False

    labels = parts.hostname.split(".")
    # Empty labels cover a trailing dot ("localhost.", "foo.internal."), which
    # would otherwise hide the real last label from the suffix checks below.
    if len(labels) < 2 or not all(labels):
        return False

    last_label = labels[-1]
    if last_label in ("local", "arpa", "internal", "localhost"):
        return False
    if last_label.isdigit():
        return False
    return True
53
54
def is_valid_authserver_meta(obj, url):
    """Validate authorization server metadata against atproto requirements.

    ``obj`` is the parsed metadata document and ``url`` the URL it was
    fetched from. The issuer must be an ``https`` URL on the same hostname
    as ``url``, and the document must advertise the response types, grant
    types, PKCE method, client auth method, signing algorithms, scope and
    PAR/DPoP/client-metadata flags checked below.

    Returns ``True`` when every requirement is met.

    Raises:
        AssertionError: a requirement is not met. Raised explicitly rather
            than via ``assert`` so the checks still run under ``python -O``.
        KeyError: a required metadata field is missing.
    """
    from urllib.parse import urlparse

    def require(condition, message):
        # Explicit raise: ``assert`` statements are stripped under -O.
        if not condition:
            raise AssertionError(message)

    fetch_url = urlparse(url)
    issuer_url = urlparse(obj["issuer"])
    require(
        issuer_url.hostname == fetch_url.hostname,
        "issuer hostname does not match the metadata URL",
    )
    require(issuer_url.scheme == "https", "issuer must use https")
    require(
        "code" in obj["response_types_supported"],
        "response type 'code' not supported",
    )
    require(
        "authorization_code" in obj["grant_types_supported"],
        "grant type 'authorization_code' not supported",
    )
    require(
        "refresh_token" in obj["grant_types_supported"],
        "grant type 'refresh_token' not supported",
    )
    require(
        "S256" in obj["code_challenge_methods_supported"],
        "PKCE method 'S256' not supported",
    )
    require(
        "private_key_jwt" in obj["token_endpoint_auth_methods_supported"],
        "client auth method 'private_key_jwt' not supported",
    )
    require(
        "ES256" in obj["token_endpoint_auth_signing_alg_values_supported"],
        "client assertion algorithm 'ES256' not supported",
    )
    require("atproto" in obj["scopes_supported"], "scope 'atproto' not supported")
    require(
        obj["authorization_response_iss_parameter_supported"] is True,
        "authorization response 'iss' parameter not supported",
    )
    require(
        obj["pushed_authorization_request_endpoint"] is not None,
        "missing pushed authorization request endpoint",
    )
    require(
        obj["require_pushed_authorization_requests"] is True,
        "pushed authorization requests are not required",
    )
    require(
        "ES256" in obj["dpop_signing_alg_values_supported"],
        "DPoP algorithm 'ES256' not supported",
    )
    require(
        obj["client_id_metadata_document_supported"] is True,
        "client ID metadata documents not supported",
    )
    return True
76
77
def resolve_pds_authserver(url):
    """Given a PDS URL, find its authorization server.

    Fetches ``{url}/.well-known/oauth-protected-resource`` and returns the
    first entry of its ``authorization_servers`` list.

    Raises:
        AssertionError: ``url`` fails ``is_safe_url`` or the response status
            is not exactly 200. Raised explicitly (not via ``assert``) so the
            checks are not stripped under ``python -O``.
        Exception: whatever ``raise_for_status`` raises for an error status.
        KeyError / IndexError: the document has no authorization server.
    """
    if not is_safe_url(url):
        raise AssertionError(f"unsafe PDS URL: {url!r}")

    with hardened_http.get_session() as sess:
        resp = sess.get(f"{url}/.well-known/oauth-protected-resource")
    resp.raise_for_status()
    # raise_for_status() is followed by an exact-200 check so that any other
    # non-error status is rejected as well.
    if resp.status_code != 200:
        raise AssertionError(
            f"unexpected status {resp.status_code} from protected-resource metadata"
        )
    return resp.json()["authorization_servers"][0]
86
87
def fetch_authserver_meta(url):
    """Fetch and validate authorization server metadata.

    Fetches ``{url}/.well-known/oauth-authorization-server``, validates the
    parsed document with ``is_valid_authserver_meta`` and returns it.

    Raises:
        AssertionError: ``url`` fails ``is_safe_url`` or the metadata fails
            validation. Raised explicitly (not via ``assert``) so that the
            validation call itself is not removed under ``python -O``.
        Exception: whatever ``raise_for_status`` raises for an error status.
    """
    if not is_safe_url(url):
        raise AssertionError(f"unsafe authorization server URL: {url!r}")

    with hardened_http.get_session() as sess:
        resp = sess.get(f"{url}/.well-known/oauth-authorization-server")
    resp.raise_for_status()

    meta = resp.json()
    if not is_valid_authserver_meta(meta, url):
        raise AssertionError("authorization server metadata failed validation")
    return meta
97
98
def client_assertion_jwt(client_id, authserver_url, client_secret_jwk):
    """Build the signed JWT that asserts our client identity.

    The token is issued by and about ``client_id``, addressed to
    ``authserver_url``, carries a fresh ``jti`` and expires 60 seconds after
    issue. It is signed with ``client_secret_jwk`` (header ``alg`` ES256,
    ``kid`` taken from the key) and returned as a ``str``.
    """
    header = {"alg": "ES256", "kid": client_secret_jwk["kid"]}
    claims = {
        "iss": client_id,
        "sub": client_id,
        "aud": authserver_url,
        "jti": generate_token(),
        "iat": int(time.time()),
        "exp": int(time.time()) + 60,
    }
    signed = jwt.encode(header, claims, client_secret_jwk)
    return signed.decode("utf-8")
113
114
def authserver_dpop_jwt(method, url, nonce, dpop_private_jwk):
    """Build a DPoP proof JWT for a request to the auth server.

    The proof binds the HTTP ``method`` (``htm``) and ``url`` (``htu``),
    carries a fresh ``jti`` and expires 30 seconds after issue. A ``nonce``
    claim is added only when ``nonce`` is truthy. The public half of
    ``dpop_private_jwk`` is embedded in the header; the proof is signed with
    the private key and returned as a ``str``.
    """
    public_jwk = json.loads(dpop_private_jwk.as_json(is_private=False))
    claims = {
        "jti": generate_token(),
        "htm": method,
        "htu": url,
        "iat": int(time.time()),
        "exp": int(time.time()) + 30,
    }
    if nonce:
        claims.update(nonce=nonce)

    header = {"typ": "dpop+jwt", "alg": "ES256", "jwk": public_jwk}
    proof = jwt.encode(header, claims, dpop_private_jwk)
    return proof.decode("utf-8")
132
133
def _parse_www_authenticate(data):
    """Split a WWW-Authenticate header value into ``(scheme, params)``.

    Everything before the first space is the scheme; the remainder is parsed
    as a comma-separated list of ``key=value`` pairs into a dict.
    """
    scheme, _sep, remainder = data.partition(" ")
    pairs = urllib.request.parse_http_list(remainder)
    return scheme, urllib.request.parse_keqv_list(pairs)
139
140
def is_use_dpop_nonce_error_response(resp):
    """Tell whether ``resp`` asks us to retry with a new DPoP nonce.

    Only 400 and 401 responses qualify. The ``use_dpop_nonce`` error is
    looked for first in a ``DPoP`` WWW-Authenticate header, then in the JSON
    body's ``error`` field. Parsing problems in either place are ignored on
    purpose: an unparseable response is treated as "not a nonce error".
    """
    if resp.status_code not in (400, 401):
        return False

    header_value = resp.headers.get("WWW-Authenticate")
    if header_value:
        try:
            auth_scheme, auth_params = _parse_www_authenticate(header_value)
            header_match = (
                auth_scheme.lower() == "dpop"
                and auth_params.get("error") == "use_dpop_nonce"
            )
        except Exception:
            header_match = False
        if header_match:
            return True

    try:
        payload = resp.json()
        body_match = (
            isinstance(payload, dict) and payload.get("error") == "use_dpop_nonce"
        )
    except Exception:
        body_match = False
    return body_match
160
161
def auth_server_post(
    authserver_url,
    client_id,
    client_secret_jwk,
    dpop_private_jwk,
    dpop_authserver_nonce,
    post_url,
    post_data,
) -> Tuple[str, Response]:
    """POST to auth server with client assertion and DPoP, handling nonce rotation.

    The form sent is ``post_data`` plus ``client_id`` and a freshly signed
    client assertion; the caller's ``post_data`` dict is left unmodified.
    If the server answers with a ``use_dpop_nonce`` error, the request is
    retried once with the nonce from the ``DPoP-Nonce`` response header.

    Returns ``(dpop_nonce, response)`` where ``dpop_nonce`` is the nonce used
    for the last request sent (the input nonce unless the server rotated it).

    Raises:
        AssertionError: ``post_url`` fails ``is_safe_url``. Raised explicitly
            (not via ``assert``) so the check survives ``python -O``.
        KeyError: the server signalled ``use_dpop_nonce`` without sending a
            ``DPoP-Nonce`` header.
    """
    if not is_safe_url(post_url):
        raise AssertionError(f"unsafe auth server URL: {post_url!r}")

    client_assertion = client_assertion_jwt(
        client_id, authserver_url, client_secret_jwk
    )
    # Build a new dict rather than ``|=`` so the caller's dict is not mutated.
    form = post_data | {
        "client_id": client_id,
        "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        "client_assertion": client_assertion,
    }

    def send(nonce):
        # Each attempt needs its own proof: the nonce is part of the signed claims.
        dpop_proof = authserver_dpop_jwt("POST", post_url, nonce, dpop_private_jwk)
        with hardened_http.get_session() as sess:
            return sess.post(post_url, data=form, headers={"DPoP": dpop_proof})

    resp = send(dpop_authserver_nonce)
    if is_use_dpop_nonce_error_response(resp):
        dpop_authserver_nonce = resp.headers["DPoP-Nonce"]
        resp = send(dpop_authserver_nonce)

    return dpop_authserver_nonce, resp
197
198
def send_par_auth_request(
    authserver_url,
    authserver_meta,
    login_hint,
    client_id,
    redirect_uri,
    scope,
    client_secret_jwk,
    dpop_private_jwk,
) -> Tuple[str, str, str, Any]:
    """Send a Pushed Authorization Request.

    Generates a fresh ``state`` and PKCE verifier, then posts the request
    (with the S256 challenge, ``redirect_uri``, ``scope`` and, when truthy,
    ``login_hint``) to the PAR endpoint named in ``authserver_meta``. The
    first attempt is sent without a DPoP nonce.

    Returns ``(pkce_verifier, state, dpop_nonce, response)``. The response is
    returned as-is; its status is not checked here.

    Raises:
        AssertionError: the PAR endpoint fails ``is_safe_url``. Raised
            explicitly (not via ``assert``) so the check survives
            ``python -O``.
        KeyError: ``authserver_meta`` has no PAR endpoint entry.
    """
    par_url = authserver_meta["pushed_authorization_request_endpoint"]
    if not is_safe_url(par_url):
        raise AssertionError(f"unsafe PAR endpoint URL: {par_url!r}")

    state = generate_token()
    pkce_verifier = generate_token(48)
    code_challenge = create_s256_code_challenge(pkce_verifier)

    par_body = {
        "response_type": "code",
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        "state": state,
        "redirect_uri": redirect_uri,
        "scope": scope,
    }
    if login_hint:
        par_body["login_hint"] = login_hint

    dpop_authserver_nonce, resp = auth_server_post(
        authserver_url=authserver_url,
        client_id=client_id,
        client_secret_jwk=client_secret_jwk,
        dpop_private_jwk=dpop_private_jwk,
        dpop_authserver_nonce="",
        post_url=par_url,
        post_data=par_body,
    )

    return pkce_verifier, state, dpop_authserver_nonce, resp
238
239
def initial_token_request(
    auth_request, code, client_id, redirect_uri, client_secret_jwk
):
    """Exchange authorization code for tokens.

    ``auth_request`` supplies the values saved when the flow started:
    ``authserver_iss``, ``dpop_private_jwk`` (a JSON string),
    ``pkce_verifier`` and ``dpop_authserver_nonce``. The auth server
    metadata is re-fetched and validated to find the token endpoint.

    Returns ``(token_body, dpop_nonce)``.

    Raises:
        AssertionError: the token endpoint fails ``is_safe_url``. Raised
            explicitly (not via ``assert``) so the check survives
            ``python -O``.
        Exception: whatever ``raise_for_status`` raises for an error status.
    """
    authserver_url = auth_request["authserver_iss"]
    authserver_meta = fetch_authserver_meta(authserver_url)

    token_url = authserver_meta["token_endpoint"]
    if not is_safe_url(token_url):
        raise AssertionError(f"unsafe token endpoint URL: {token_url!r}")

    dpop_private_jwk = JsonWebKey.import_key(
        json.loads(auth_request["dpop_private_jwk"])
    )

    params = {
        "redirect_uri": redirect_uri,
        "grant_type": "authorization_code",
        "code": code,
        "code_verifier": auth_request["pkce_verifier"],
    }

    dpop_authserver_nonce, resp = auth_server_post(
        authserver_url=authserver_url,
        client_id=client_id,
        client_secret_jwk=client_secret_jwk,
        dpop_private_jwk=dpop_private_jwk,
        dpop_authserver_nonce=auth_request["dpop_authserver_nonce"],
        post_url=token_url,
        post_data=params,
    )

    resp.raise_for_status()
    return resp.json(), dpop_authserver_nonce
272
273
def refresh_token_request(user, client_id, client_secret_jwk):
    """Refresh an access token.

    ``user`` supplies ``authserver_iss``, ``dpop_private_jwk`` (a JSON
    string), ``refresh_token`` and ``dpop_authserver_nonce``. The auth server
    metadata is re-fetched and validated to find the token endpoint.

    Returns ``(token_body, dpop_nonce)``.

    Raises:
        AssertionError: the token endpoint fails ``is_safe_url``. Raised
            explicitly (not via ``assert``) so the check survives
            ``python -O``.
        Exception: whatever ``raise_for_status`` raises for an error status.
    """
    authserver_url = user["authserver_iss"]
    authserver_meta = fetch_authserver_meta(authserver_url)

    token_url = authserver_meta["token_endpoint"]
    if not is_safe_url(token_url):
        raise AssertionError(f"unsafe token endpoint URL: {token_url!r}")

    dpop_private_jwk = JsonWebKey.import_key(json.loads(user["dpop_private_jwk"]))

    params = {
        "grant_type": "refresh_token",
        "refresh_token": user["refresh_token"],
    }

    dpop_authserver_nonce, resp = auth_server_post(
        authserver_url=authserver_url,
        client_id=client_id,
        client_secret_jwk=client_secret_jwk,
        dpop_private_jwk=dpop_private_jwk,
        dpop_authserver_nonce=user["dpop_authserver_nonce"],
        post_url=token_url,
        post_data=params,
    )

    resp.raise_for_status()
    return resp.json(), dpop_authserver_nonce
300
301
def revoke_token_request(user, client_id, client_secret_jwk):
    """Revoke access and refresh tokens.

    Looks up the ``revocation_endpoint`` in the (re-fetched, validated) auth
    server metadata and returns ``None`` without doing anything if there is
    none. Otherwise posts one revocation request for ``user["access_token"]``
    and then one for ``user["refresh_token"]``, carrying the DPoP nonce from
    one request to the next. Stops at the first request that fails.

    Raises:
        AssertionError: the revocation endpoint fails ``is_safe_url``. Raised
            explicitly (not via ``assert``) so the check survives
            ``python -O``.
        Exception: whatever ``raise_for_status`` raises for an error status.
    """
    authserver_url = user["authserver_iss"]
    authserver_meta = fetch_authserver_meta(authserver_url)

    revoke_url = authserver_meta.get("revocation_endpoint")
    if not revoke_url:
        return
    if not is_safe_url(revoke_url):
        raise AssertionError(f"unsafe revocation endpoint URL: {revoke_url!r}")

    dpop_private_jwk = JsonWebKey.import_key(json.loads(user["dpop_private_jwk"]))
    dpop_authserver_nonce = user["dpop_authserver_nonce"]

    for token_type in ["access_token", "refresh_token"]:
        dpop_authserver_nonce, resp = auth_server_post(
            authserver_url=authserver_url,
            client_id=client_id,
            client_secret_jwk=client_secret_jwk,
            dpop_private_jwk=dpop_private_jwk,
            dpop_authserver_nonce=dpop_authserver_nonce,
            post_url=revoke_url,
            post_data={
                "token": user[token_type],
                "token_type_hint": token_type,
            },
        )
        resp.raise_for_status()
329
330
def pds_authed_req(method, url, user, db, body=None):
    """Make an authenticated request to a user's PDS with DPoP.

    ``user`` supplies ``dpop_private_jwk`` (a JSON string), ``access_token``,
    ``dpop_pds_nonce`` and ``did``. ``method`` is placed in the DPoP proof
    as-is; the request is sent as a GET when ``method == "GET"`` and as a
    POST with ``body`` as JSON for any other value.

    If the PDS answers with a ``use_dpop_nonce`` error, the nonce from the
    ``DPoP-Nonce`` header is stored in the ``oauth_session`` row for the
    user's DID via ``db`` and the request is retried once. At most two
    requests are sent; the last response is returned without a status check.

    Raises:
        KeyError: the PDS signalled ``use_dpop_nonce`` without sending a
            ``DPoP-Nonce`` header.
    """
    dpop_private_jwk = JsonWebKey.import_key(json.loads(user["dpop_private_jwk"]))
    dpop_pds_nonce = user["dpop_pds_nonce"] or ""
    access_token = user["access_token"]

    # Neither the public key nor the access-token hash changes between
    # attempts, so compute them once outside the retry loop.
    dpop_pub_jwk = json.loads(dpop_private_jwk.as_json(is_private=False))
    access_token_hash = create_s256_code_challenge(access_token)

    resp = None
    for _ in range(2):
        dpop_body = {
            "iat": int(time.time()),
            "exp": int(time.time()) + 10,
            "jti": generate_token(),
            "htm": method,
            "htu": url,
            "ath": access_token_hash,
        }
        if dpop_pds_nonce:
            dpop_body["nonce"] = dpop_pds_nonce
        dpop_jwt = jwt.encode(
            {"typ": "dpop+jwt", "alg": "ES256", "jwk": dpop_pub_jwk},
            dpop_body,
            dpop_private_jwk,
        ).decode("utf-8")

        headers = {"Authorization": f"DPoP {access_token}", "DPoP": dpop_jwt}
        with hardened_http.get_session() as sess:
            if method == "GET":
                resp = sess.get(url, headers=headers)
            else:
                resp = sess.post(url, headers=headers, json=body)

        if not is_use_dpop_nonce_error_response(resp):
            break

        # Persist the rotated nonce so later requests start with it.
        dpop_pds_nonce = resp.headers["DPoP-Nonce"]
        cur = db.cursor()
        try:
            cur.execute(
                "UPDATE oauth_session SET dpop_pds_nonce = ? WHERE did = ?;",
                [dpop_pds_nonce, user["did"]],
            )
            db.commit()
        finally:
            # Close the cursor even if the update or commit fails.
            cur.close()

    return resp
381 return resp