Sync reading position from Moon Reader app to Bookhive atproto records
atproto bookhive ereader moonreader
3
fork

Configure Feed

Select the types of activity you want to include in your feed.

Refresh session on 400 ExpiredToken, not just 401

The PDS returns 400 with `{"error": "ExpiredToken"}` when an access
token has aged out mid-session; the previous code only refreshed on
401 (unauthenticated) and forwarded the 400 to the caller, which
surfaced to Moon+ Reader as a sync failure after ~2h of activity.

+108 -2
+20 -2
src/waggle/atproto/client.py
··· 103 103 did=data["did"], 104 104 ) 105 105 106 + @staticmethod 107 + def _is_expired_token(resp: httpx.Response) -> bool: 108 + """True if the PDS is telling us to refresh. 109 + 110 + atproto's convention: a fully-unauthenticated call returns 401, but a 111 + call with a previously-valid token that has since expired returns 400 112 + with `{"error": "ExpiredToken"}`. Both should trigger a refresh. 113 + """ 114 + if resp.status_code == 401: 115 + return True 116 + if resp.status_code == 400: 117 + try: 118 + return resp.json().get("error") in ("ExpiredToken", "InvalidToken") 119 + except ValueError: 120 + return False 121 + return False 122 + 106 123 async def request( 107 124 self, 108 125 method: str, ··· 113 130 content: bytes | None = None, 114 131 headers: dict | None = None, 115 132 ) -> httpx.Response: 116 - """XRPC call with auto-refresh on 401.""" 133 + """XRPC call with auto-refresh on ExpiredToken / 401.""" 117 134 sess = await self._ensure_session() 118 135 url = f"{self.pds_url}/xrpc/{nsid}" 119 136 hdrs = {"Authorization": f"Bearer {sess.access_jwt}"} ··· 123 140 resp = await self._http.request( 124 141 method, url, params=params, json=json, content=content, headers=hdrs 125 142 ) 126 - if resp.status_code == 401: 143 + if self._is_expired_token(resp): 144 + log.info("Access token expired; refreshing and retrying %s", nsid) 127 145 async with self._lock: 128 146 await self._refresh_session() 129 147 assert self._session is not None
+88
tests/test_atproto_client.py
··· 1 + """Token-refresh behavior of ATProtoClient.request().""" 2 + 3 + from __future__ import annotations 4 + 5 + import httpx 6 + import pytest 7 + import respx 8 + 9 + from waggle.atproto.client import ATProtoClient 10 + 11 + 12 + @pytest.fixture 13 + def client(): 14 + return ATProtoClient("pds.example", "tester.example", "app-pw") 15 + 16 + 17 + def _session_body(jwt: str = "access.jwt") -> dict: 18 + return { 19 + "accessJwt": jwt, 20 + "refreshJwt": "refresh.jwt", 21 + "did": "did:plc:tester", 22 + "handle": "tester.example", 23 + } 24 + 25 + 26 + @respx.mock 27 + async def test_retries_on_400_expired_token(client): 28 + """The real incident: PDS returns 400 ExpiredToken, not 401.""" 29 + respx.post("https://pds.example/xrpc/com.atproto.server.createSession").mock( 30 + return_value=httpx.Response(200, json=_session_body("old.jwt")) 31 + ) 32 + respx.post("https://pds.example/xrpc/com.atproto.server.refreshSession").mock( 33 + return_value=httpx.Response(200, json=_session_body("new.jwt")) 34 + ) 35 + 36 + put = respx.post("https://pds.example/xrpc/com.atproto.repo.putRecord").mock( 37 + side_effect=[ 38 + httpx.Response(400, json={"error": "ExpiredToken", "message": "expired"}), 39 + httpx.Response(200, json={"uri": "...", "cid": "..."}), 40 + ] 41 + ) 42 + 43 + resp = await client.request("POST", "com.atproto.repo.putRecord", json={"x": 1}) 44 + assert resp.status_code == 200 45 + assert put.call_count == 2 46 + # Second call carries the refreshed token. 47 + assert put.calls[1].request.headers["Authorization"] == "Bearer new.jwt" 48 + await client.close() 49 + 50 + 51 + @respx.mock 52 + async def test_retries_on_401(client): 53 + """Original contract — plain 401 still triggers refresh.""" 54 + respx.post("https://pds.example/xrpc/com.atproto.server.createSession").mock( 55 + return_value=httpx.Response(200, json=_session_body("old.jwt")) 56 + ) 57 + respx.post("https://pds.example/xrpc/com.atproto.server.refreshSession").mock( 58 + return_value=httpx.Response(200, json=_session_body("new.jwt")) 59 + ) 60 + route = respx.get("https://pds.example/xrpc/com.atproto.repo.listRecords").mock( 61 + side_effect=[ 62 + httpx.Response(401, json={"error": "AuthenticationRequired"}), 63 + httpx.Response(200, json={"records": []}), 64 + ] 65 + ) 66 + resp = await client.request("GET", "com.atproto.repo.listRecords", params={"q": 1}) 67 + assert resp.status_code == 200 68 + assert route.call_count == 2 69 + await client.close() 70 + 71 + 72 + @respx.mock 73 + async def test_does_not_retry_on_unrelated_400(client): 74 + """A non-token 400 (e.g. lexicon validation) must NOT trigger a refresh.""" 75 + respx.post("https://pds.example/xrpc/com.atproto.server.createSession").mock( 76 + return_value=httpx.Response(200, json=_session_body()) 77 + ) 78 + refresh = respx.post("https://pds.example/xrpc/com.atproto.server.refreshSession").mock( 79 + return_value=httpx.Response(200, json=_session_body("new.jwt")) 80 + ) 81 + put = respx.post("https://pds.example/xrpc/com.atproto.repo.putRecord").mock( 82 + return_value=httpx.Response(400, json={"error": "InvalidRequest", "message": "bad"}), 83 + ) 84 + resp = await client.request("POST", "com.atproto.repo.putRecord", json={}) 85 + assert resp.status_code == 400 86 + assert put.call_count == 1 87 + assert not refresh.called 88 + await client.close()